faer/linalg/matmul/
mod.rs

1//! matrix multiplication
2
3use super::temp_mat_scratch;
4use crate::col::ColRef;
5use crate::internal_prelude::*;
6use crate::mat::{MatMut, MatRef};
7use crate::row::RowRef;
8use crate::utils::bound::Dim;
9use crate::utils::simd::SimdCtx;
10use crate::{Conj, ContiguousFwd, Par, Shape};
11use core::mem::MaybeUninit;
12use dyn_stack::{MemBuffer, MemStack};
13use equator::assert;
14use faer_macros::math;
15use faer_traits::{ByRef, ComplexField, Conjugate};
16use pulp::Simd;
17use reborrow::*;
18
19const NANO_GEMM_THRESHOLD: usize = 16 * 16 * 16;
20
21pub(crate) mod internal;
22
23/// triangular matrix multiplication module, where some of the operands are treated as triangular
24/// matrices
25pub mod triangular;
26
27mod matmul_shared {
28	use super::*;
29
30	pub const NC: usize = 2048;
31	pub const KC: usize = 128;
32
33	pub struct SimdLaneCount<T: ComplexField> {
34		pub __marker: core::marker::PhantomData<fn() -> T>,
35	}
36	impl<T: ComplexField> pulp::WithSimd for SimdLaneCount<T> {
37		type Output = usize;
38
39		#[inline(always)]
40		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
41			let _ = simd;
42			core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
43		}
44	}
45
46	pub struct MicroKernelShape<T: ComplexField> {
47		pub __marker: core::marker::PhantomData<fn() -> T>,
48	}
49
50	impl<T: ComplexField> MicroKernelShape<T> {
51		pub const IS_1X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
52		pub const IS_2X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
53		pub const IS_2X2: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 2;
54		pub const MAX_MR_DIV_N: usize = Self::SHAPE.0;
55		pub const MAX_NR: usize = Self::SHAPE.1;
56		pub const SHAPE: (usize, usize) = {
57			if const { size_of::<T>() / size_of::<T::Unit>() <= 2 } {
58				(2, 2)
59			} else if const { size_of::<T>() / size_of::<T::Unit>() == 4 } {
60				(2, 1)
61			} else {
62				(1, 1)
63			}
64		};
65	}
66}
67
/// gemm path for a column-major (contiguous-rows) destination and lhs: each
/// microkernel invocation loads full simd columns of `lhs` and broadcasts
/// scalars of `rhs`
mod matmul_vertical {
	use super::*;
	use matmul_shared::*;

	/// microkernel dispatch payload: computes
	/// `dst = beta(dst) + alpha * conj?(a) * conj?(b)` on a register-sized tile.
	/// `MR_DIV_N` is the tile height in simd registers, `NR` the tile width in
	/// columns
	struct Ukr<'a, const MR_DIV_N: usize, const NR: usize, T: ComplexField> {
		dst: MatMut<'a, T, usize, usize, ContiguousFwd>,
		a: MatRef<'a, T, usize, usize, ContiguousFwd>,
		b: MatRef<'a, T, usize, usize>,
		conj_lhs: Conj,
		conj_rhs: Conj,
		alpha: &'a T,
		beta: Accum,
	}

	impl<const MR_DIV_N: usize, const NR: usize, T: ComplexField> pulp::WithSimd for Ukr<'_, MR_DIV_N, NR, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			let Self {
				dst,
				a,
				b,
				conj_lhs,
				conj_rhs,
				alpha,
				beta,
			} = self;

			// bind runtime sizes to branded dims for checked-at-construction indexing
			with_dim!(M, a.nrows());
			with_dim!(N, b.ncols());
			with_dim!(K, a.ncols());
			let a = a.as_shape(M, K);
			let b = b.as_shape(K, N);
			let mut dst = dst.as_shape_mut(M, N);

			// force a masked tail register so the last (possibly partial) simd row
			// always exists; `tail` is therefore always `Some`
			let simd = SimdCtx::<T, S>::new_force_mask(T::simd_ctx(simd), M);
			let (_, body, tail) = simd.indices();
			let tail = tail.unwrap();

			// NR columns of MR_DIV_N register accumulators, kept in registers
			let mut local_acc = [[simd.zero(); MR_DIV_N]; NR];

			// when both sides share the same conjugation, accumulate a plain
			// product and conjugate the whole result afterwards if needed;
			// otherwise accumulate `conj(b) * a` (see the post-loop fixup)
			if conj_lhs == conj_rhs {
				for depth in K.indices() {
					// load one register column of `a`: full registers from `body`,
					// then the masked `tail` register into the last slot
					let mut a_uninit = [MaybeUninit::<T::SimdVec<S>>::uninit(); MR_DIV_N];

					for (dst, src) in core::iter::zip(&mut a_uninit, body.clone()) {
						*dst = MaybeUninit::new(simd.read(a.col(depth), src));
					}
					a_uninit[MR_DIV_N - 1] = MaybeUninit::new(simd.read(a.col(depth), tail));

					// SAFETY: every element was written above (body fills the
					// leading slots, the tail write fills the last one)
					let a: [T::SimdVec<S>; MR_DIV_N] =
						unsafe { crate::hacks::transmute::<[MaybeUninit<T::SimdVec<S>>; MR_DIV_N], [T::SimdVec<S>; MR_DIV_N]>(a_uninit) };

					for j in N.indices() {
						// broadcast one rhs scalar across the register
						let b = simd.splat(&b[(depth, j)]);

						for i in 0..MR_DIV_N {
							let local_acc = &mut local_acc[*j][i];
							*local_acc = simd.mul_add(b, a[i], *local_acc);
						}
					}
				}
			} else {
				for depth in K.indices() {
					let mut a_uninit = [MaybeUninit::<T::SimdVec<S>>::uninit(); MR_DIV_N];

					for (dst, src) in core::iter::zip(&mut a_uninit, body.clone()) {
						*dst = MaybeUninit::new(simd.read(a.col(depth), src));
					}
					a_uninit[MR_DIV_N - 1] = MaybeUninit::new(simd.read(a.col(depth), tail));

					// SAFETY: same as above — all MR_DIV_N slots initialized
					let a: [T::SimdVec<S>; MR_DIV_N] =
						unsafe { crate::hacks::transmute::<[MaybeUninit<T::SimdVec<S>>; MR_DIV_N], [T::SimdVec<S>; MR_DIV_N]>(a_uninit) };

					for j in N.indices() {
						let b = simd.splat(&b[(depth, j)]);

						for i in 0..MR_DIV_N {
							let local_acc = &mut local_acc[*j][i];
							// conj_mul_add conjugates its first operand (b)
							*local_acc = simd.conj_mul_add(b, a[i], *local_acc);
						}
					}
				}
			}

			// fixup: if the lhs is conjugated, the accumulated value is the
			// conjugate of the desired product in both branches above
			if conj_lhs.is_conj() {
				for x in &mut local_acc {
					for x in x {
						*x = simd.conj(*x);
					}
				}
			}

			let alpha = simd.splat(alpha);

			// write back: either accumulate into dst or overwrite it
			match beta {
				Accum::Add => {
					for (result, j) in core::iter::zip(&local_acc, N.indices()) {
						for (result, i) in core::iter::zip(result, body.clone()) {
							let mut val = simd.read(dst.rb().col(j), i);
							val = simd.mul_add(alpha, *result, val);
							simd.write(dst.rb_mut().col_mut(j), i, val);
						}
						// masked tail register of this column
						let i = tail;
						let result = &result[MR_DIV_N - 1];

						let mut val = simd.read(dst.rb().col(j), i);
						val = simd.mul_add(alpha, *result, val);
						simd.write(dst.rb_mut().col_mut(j), i, val);
					}
				},
				Accum::Replace => {
					for (result, j) in core::iter::zip(&local_acc, N.indices()) {
						for (result, i) in core::iter::zip(result, body.clone()) {
							let val = simd.mul(alpha, *result);
							simd.write(dst.rb_mut().col_mut(j), i, val);
						}

						let i = tail;
						let result = &result[MR_DIV_N - 1];

						let val = simd.mul(alpha, *result);
						simd.write(dst.rb_mut().col_mut(j), i, val);
					}
				},
			}
		}
	}

	/// blocked gemm driver for the vertical path:
	/// `dst = beta(dst) + alpha * conj?(lhs) * conj?(rhs)`, tiling columns by
	/// `NC` and depth by `KC`, then dispatching register tiles to `Ukr`
	#[math]
	pub fn matmul_simd<'M, 'N, 'K, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		let dst = dst.as_dyn_mut();
		let lhs = lhs.as_dyn();
		let rhs = rhs.as_dyn();

		let (m, n) = dst.shape();
		let k = lhs.ncols();

		let arch = T::Arch::default();

		let lane_count = arch.dispatch(SimdLaneCount::<T> {
			__marker: core::marker::PhantomData,
		});

		// microkernel tile: nr columns × (mr_div_n registers = mr scalars) rows
		let nr = MicroKernelShape::<T>::MAX_NR;
		let mr_div_n = MicroKernelShape::<T>::MAX_MR_DIV_N;
		let mr = mr_div_n * lane_count;

		let mut col_outer = 0;
		while col_outer < n {
			let n_chunk = Ord::min(n - col_outer, NC);
			let mut beta = beta;

			let mut depth = 0;
			while depth < k {
				let k_chunk = Ord::min(k - depth, KC);

				// computes one mr×nr destination tile against the current depth slab
				let job = |row: usize, col_inner: usize| {
					let nrows = Ord::min(m - row, mr);
					let ukr_i = nrows.div_ceil(lane_count);
					let ncols = Ord::min(n_chunk - col_inner, nr);
					let ukr_j = ncols;

					// SAFETY(review): tiles are disjoint across jobs, so the
					// const_cast'd mutable views never alias — presumably the
					// invariant relied on by the rayon branch below
					let dst = unsafe { dst.rb().const_cast() }.submatrix_mut(row, col_outer + col_inner, nrows, ncols);
					let a = lhs.submatrix(row, depth, nrows, k_chunk);
					let b = rhs.submatrix(depth, col_outer + col_inner, k_chunk, ncols);

					macro_rules! call {
						($M: expr, $N: expr) => {
							arch.dispatch(Ukr::<'_, $M, $N, T> {
								dst,
								a,
								b,
								conj_lhs,
								conj_rhs,
								alpha,
								beta,
							})
						};
					}
					// monomorphized dispatch on the (possibly clipped) tile shape
					if const { MicroKernelShape::<T>::IS_2X2 } {
						match (ukr_i, ukr_j) {
							(2, 2) => call!(2, 2),
							(1, 2) => call!(1, 2),
							(2, 1) => call!(2, 1),
							(1, 1) => call!(1, 1),
							_ => unreachable!(),
						}
					} else if const { MicroKernelShape::<T>::IS_2X1 } {
						match (ukr_i, ukr_j) {
							(2, 1) => call!(2, 1),
							(1, 1) => call!(1, 1),
							_ => unreachable!(),
						}
					} else if const { MicroKernelShape::<T>::IS_1X1 } {
						call!(1, 1)
					} else {
						unreachable!()
					}
				};

				// flat job index over the row tiles × column tiles of this panel
				let job_count = m.div_ceil(mr) * n_chunk.div_ceil(nr);
				let d = n_chunk.div_ceil(nr);
				match par {
					Par::Seq => {
						for job_idx in 0..job_count {
							let col_inner = nr * (job_idx % d);
							let row = mr * (job_idx / d);
							job(row, col_inner);
						}
					},
					#[cfg(feature = "rayon")]
					Par::Rayon(nthreads) => {
						let nthreads = nthreads.get();
						use rayon::prelude::*;

						// work-stealing via a shared atomic counter: each worker
						// claims the next job index until all are taken
						let job_idx = core::sync::atomic::AtomicUsize::new(0);

						(0..nthreads).into_par_iter().for_each(|_| {
							loop {
								let job_idx = job_idx.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
								if job_idx < job_count {
									let col_inner = nr * (job_idx % d);
									let row = mr * (job_idx / d);
									job(row, col_inner);
								} else {
									return;
								}
							}
						});
					},
				}

				// after the first depth slab, subsequent slabs must accumulate
				beta = Accum::Add;
				depth += k_chunk;
			}
			col_outer += n_chunk;
		}
	}
}
318
319mod matmul_horizontal {
320	use super::*;
321	use matmul_shared::*;
322
323	struct Ukr<'a, const MR: usize, const NR: usize, T: ComplexField> {
324		dst: MatMut<'a, T, usize, usize>,
325		a: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
326		b: MatRef<'a, T, usize, usize, ContiguousFwd, isize>,
327		conj_lhs: Conj,
328		conj_rhs: Conj,
329		alpha: &'a T,
330		beta: Accum,
331	}
332
333	impl<const MR: usize, const NR: usize, T: ComplexField> pulp::WithSimd for Ukr<'_, MR, NR, T> {
334		type Output = ();
335
336		#[math]
337		#[inline(always)]
338		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
339			let Self {
340				dst,
341				a,
342				b,
343				conj_lhs,
344				conj_rhs,
345				alpha,
346				beta,
347			} = self;
348
349			with_dim!(M, a.nrows());
350			with_dim!(N, b.ncols());
351			with_dim!(K, a.ncols());
352			let a = a.as_shape(M, K);
353			let b = b.as_shape(K, N);
354			let mut dst = dst.as_shape_mut(M, N);
355
356			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), K);
357			let (_, body, tail) = simd.indices();
358
359			let mut local_acc = [[simd.zero(); MR]; NR];
360			let mut is = [M.idx(0usize); MR];
361			let mut js = [N.idx(0usize); NR];
362
363			for (idx, i) in is.iter_mut().enumerate() {
364				*i = M.idx(idx);
365			}
366			for (idx, j) in js.iter_mut().enumerate() {
367				*j = N.idx(idx);
368			}
369
370			if conj_lhs == conj_rhs {
371				macro_rules! do_it {
372					($depth: expr) => {{
373						let depth = $depth;
374						let a = is.map(
375							#[inline(always)]
376							|i| simd.read(a.row(i).transpose(), depth),
377						);
378						let b = js.map(
379							#[inline(always)]
380							|j| simd.read(b.col(j), depth),
381						);
382
383						for i in 0..MR {
384							for j in 0..NR {
385								local_acc[j][i] = simd.mul_add(b[j], a[i], local_acc[j][i]);
386							}
387						}
388					}};
389				}
390				for depth in body {
391					do_it!(depth);
392				}
393				if let Some(depth) = tail {
394					do_it!(depth);
395				}
396			} else {
397				macro_rules! do_it {
398					($depth: expr) => {{
399						let depth = $depth;
400						let a = is.map(
401							#[inline(always)]
402							|i| simd.read(a.row(i).transpose(), depth),
403						);
404						let b = js.map(
405							#[inline(always)]
406							|j| simd.read(b.col(j), depth),
407						);
408
409						for i in 0..MR {
410							for j in 0..NR {
411								local_acc[j][i] = simd.conj_mul_add(b[j], a[i], local_acc[j][i]);
412							}
413						}
414					}};
415				}
416				for depth in body {
417					do_it!(depth);
418				}
419				if let Some(depth) = tail {
420					do_it!(depth);
421				}
422			}
423
424			if conj_lhs.is_conj() {
425				for x in &mut local_acc {
426					for x in x {
427						*x = simd.conj(*x);
428					}
429				}
430			}
431			let result = local_acc;
432			let result = result.map(
433				#[inline(always)]
434				|result| {
435					result.map(
436						#[inline(always)]
437						|result| simd.reduce_sum(result),
438					)
439				},
440			);
441
442			let alpha = copy(*alpha);
443			match beta {
444				Accum::Add => {
445					for (result, j) in core::iter::zip(&result, js) {
446						for (result, i) in core::iter::zip(result, is) {
447							dst[(i, j)] = alpha * *result + dst[(i, j)];
448						}
449					}
450				},
451				Accum::Replace => {
452					for (result, j) in core::iter::zip(&result, js) {
453						for (result, i) in core::iter::zip(result, is) {
454							dst[(i, j)] = alpha * *result;
455						}
456					}
457				},
458			}
459		}
460	}
461
462	#[math]
463	pub fn matmul_simd<'M, 'N, 'K, T: ComplexField>(
464		dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
465		beta: Accum,
466		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
467		conj_lhs: Conj,
468		rhs: MatRef<'_, T, Dim<'K>, Dim<'N>, ContiguousFwd, isize>,
469		conj_rhs: Conj,
470		alpha: &T,
471		par: Par,
472	) {
473		let dst = dst.as_dyn_mut();
474		let lhs = lhs.as_dyn();
475		let rhs = rhs.as_dyn();
476
477		let (m, n) = dst.shape();
478		let k = lhs.ncols();
479
480		let nr = MicroKernelShape::<T>::MAX_NR;
481		let mr = MicroKernelShape::<T>::MAX_MR_DIV_N;
482
483		let arch = T::Arch::default();
484
485		let lane_count = arch.dispatch(SimdLaneCount::<T> {
486			__marker: core::marker::PhantomData,
487		});
488		let kc = KC * lane_count;
489
490		let mut col_outer = 0;
491		while col_outer < n {
492			let n_chunk = Ord::min(n - col_outer, NC);
493
494			let mut beta = beta;
495			let mut depth = 0;
496			while depth < k {
497				let k_chunk = Ord::min(k - depth, kc);
498
499				let job = |row: usize, col_inner: usize| {
500					let nrows = Ord::min(m - row, mr);
501					let ukr_i = nrows;
502					let ncols = Ord::min(n_chunk - col_inner, nr);
503					let ukr_j = ncols;
504
505					let dst = unsafe { dst.rb().const_cast() }.submatrix_mut(row, col_outer + col_inner, nrows, ncols);
506					let a = lhs.submatrix(row, depth, nrows, k_chunk);
507					let b = rhs.submatrix(depth, col_outer + col_inner, k_chunk, ncols);
508
509					macro_rules! call {
510						($M: expr, $N: expr) => {
511							arch.dispatch(Ukr::<'_, $M, $N, T> {
512								dst,
513								a,
514								b,
515								conj_lhs,
516								conj_rhs,
517								alpha,
518								beta,
519							})
520						};
521					}
522					if const { MicroKernelShape::<T>::IS_2X2 } {
523						match (ukr_i, ukr_j) {
524							(2, 2) => call!(2, 2),
525							(1, 2) => call!(1, 2),
526							(2, 1) => call!(2, 1),
527							(1, 1) => call!(1, 1),
528							_ => unreachable!(),
529						}
530					} else if const { MicroKernelShape::<T>::IS_2X1 } {
531						match (ukr_i, ukr_j) {
532							(2, 1) => call!(2, 1),
533							(1, 1) => call!(1, 1),
534							_ => unreachable!(),
535						}
536					} else if const { MicroKernelShape::<T>::IS_1X1 } {
537						call!(1, 1)
538					} else {
539						unreachable!()
540					}
541				};
542
543				let job_count = m.div_ceil(mr) * n.div_ceil(nr);
544				let d = n.div_ceil(nr);
545				match par {
546					Par::Seq => {
547						for job_idx in 0..job_count {
548							let col_inner = nr * (job_idx % d);
549							let row = mr * (job_idx / d);
550							job(row, col_inner);
551						}
552					},
553					#[cfg(feature = "rayon")]
554					Par::Rayon(nthreads) => {
555						let nthreads = nthreads.get();
556						use rayon::prelude::*;
557
558						let job_idx = core::sync::atomic::AtomicUsize::new(0);
559
560						(0..nthreads).into_par_iter().for_each(|_| {
561							loop {
562								let job_idx = job_idx.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
563								if job_idx < job_count {
564									let col_inner = nr * (job_idx % d);
565									let row = mr * (job_idx / d);
566									job(row, col_inner);
567								} else {
568									return;
569								}
570							}
571						});
572					},
573				}
574
575				beta = Accum::Add;
576				depth += k_chunk;
577			}
578			col_outer += n_chunk;
579		}
580	}
581}
582
583/// dot product
584pub mod dot {
585	use super::*;
586	use faer_traits::SimdArch;
587
588	/// returns `lhs * rhs`, implicitly conjugating the operands if needed
589	pub fn inner_prod<K: Shape, T: ComplexField>(lhs: RowRef<T, K>, conj_lhs: Conj, rhs: ColRef<T, K>, conj_rhs: Conj) -> T {
590		#[math]
591		pub fn imp<'K, T: ComplexField>(lhs: RowRef<T, Dim<'K>>, conj_lhs: Conj, rhs: ColRef<T, Dim<'K>>, conj_rhs: Conj) -> T {
592			if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
593				if let (Some(lhs), Some(rhs)) = (lhs.try_as_row_major(), rhs.try_as_col_major()) {
594					inner_prod_slice::<T>(lhs.ncols(), lhs.transpose(), conj_lhs, rhs, conj_rhs)
595				} else {
596					inner_prod_schoolbook(lhs, conj_lhs, rhs, conj_rhs)
597				}
598			} else {
599				inner_prod_schoolbook(lhs, conj_lhs, rhs, conj_rhs)
600			}
601		}
602
603		with_dim!(K, lhs.ncols().unbound());
604
605		imp(lhs.as_col_shape(K), conj_lhs, rhs.as_row_shape(K), conj_rhs)
606	}
607
608	#[inline(always)]
609	#[math]
610	fn inner_prod_slice<'K, T: ComplexField>(
611		len: Dim<'K>,
612		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
613		conj_lhs: Conj,
614		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
615		conj_rhs: Conj,
616	) -> T {
617		struct Impl<'a, 'K, T: ComplexField> {
618			len: Dim<'K>,
619			lhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
620			conj_lhs: Conj,
621			rhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
622			conj_rhs: Conj,
623		}
624		impl<'a, 'K, T: ComplexField> pulp::WithSimd for Impl<'_, '_, T> {
625			type Output = T;
626
627			#[inline(always)]
628			fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
629				let Self {
630					len,
631					lhs,
632					conj_lhs,
633					rhs,
634					conj_rhs,
635				} = self;
636
637				let simd = SimdCtx::new(T::simd_ctx(simd), len);
638
639				let mut tmp = if conj_lhs == conj_rhs {
640					inner_prod_no_conj_simd::<T, S>(simd, lhs, rhs)
641				} else {
642					inner_prod_conj_lhs_simd::<T, S>(simd, lhs, rhs)
643				};
644
645				if conj_rhs == Conj::Yes {
646					tmp = conj(tmp);
647				}
648				tmp
649			}
650		}
651
652		dispatch!(
653			Impl {
654				len,
655				lhs,
656				rhs,
657				conj_lhs,
658				conj_rhs
659			},
660			Impl,
661			T
662		)
663	}
664
665	#[inline(always)]
666	pub(crate) fn inner_prod_no_conj_simd<'K, T: ComplexField, S: Simd>(
667		simd: SimdCtx<'K, T, S>,
668		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
669		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
670	) -> T {
671		let mut acc0 = simd.zero();
672		let mut acc1 = simd.zero();
673		let mut acc2 = simd.zero();
674		let mut acc3 = simd.zero();
675
676		let (head, idx4, idx, tail) = simd.batch_indices::<4>();
677
678		if let Some(i0) = head {
679			let l0 = simd.read(lhs, i0);
680			let r0 = simd.read(rhs, i0);
681
682			acc0 = simd.mul_add(l0, r0, acc0);
683		}
684		for [i0, i1, i2, i3] in idx4 {
685			let l0 = simd.read(lhs, i0);
686			let l1 = simd.read(lhs, i1);
687			let l2 = simd.read(lhs, i2);
688			let l3 = simd.read(lhs, i3);
689
690			let r0 = simd.read(rhs, i0);
691			let r1 = simd.read(rhs, i1);
692			let r2 = simd.read(rhs, i2);
693			let r3 = simd.read(rhs, i3);
694
695			acc0 = simd.mul_add(l0, r0, acc0);
696			acc1 = simd.mul_add(l1, r1, acc1);
697			acc2 = simd.mul_add(l2, r2, acc2);
698			acc3 = simd.mul_add(l3, r3, acc3);
699		}
700		for i0 in idx {
701			let l0 = simd.read(lhs, i0);
702			let r0 = simd.read(rhs, i0);
703
704			acc0 = simd.mul_add(l0, r0, acc0);
705		}
706		if let Some(i0) = tail {
707			let l0 = simd.read(lhs, i0);
708			let r0 = simd.read(rhs, i0);
709
710			acc0 = simd.mul_add(l0, r0, acc0);
711		}
712		acc0 = simd.add(acc0, acc1);
713		acc2 = simd.add(acc2, acc3);
714		acc0 = simd.add(acc0, acc2);
715
716		simd.reduce_sum(acc0)
717	}
718
719	#[inline(always)]
720	pub(crate) fn inner_prod_conj_lhs_simd<'K, T: ComplexField, S: Simd>(
721		simd: SimdCtx<'K, T, S>,
722		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
723		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
724	) -> T {
725		let mut acc0 = simd.zero();
726		let mut acc1 = simd.zero();
727		let mut acc2 = simd.zero();
728		let mut acc3 = simd.zero();
729
730		let (head, idx4, idx, tail) = simd.batch_indices::<4>();
731
732		if let Some(i0) = head {
733			let l0 = simd.read(lhs, i0);
734			let r0 = simd.read(rhs, i0);
735
736			acc0 = simd.conj_mul_add(l0, r0, acc0);
737		}
738		for [i0, i1, i2, i3] in idx4 {
739			let l0 = simd.read(lhs, i0);
740			let l1 = simd.read(lhs, i1);
741			let l2 = simd.read(lhs, i2);
742			let l3 = simd.read(lhs, i3);
743
744			let r0 = simd.read(rhs, i0);
745			let r1 = simd.read(rhs, i1);
746			let r2 = simd.read(rhs, i2);
747			let r3 = simd.read(rhs, i3);
748
749			acc0 = simd.conj_mul_add(l0, r0, acc0);
750			acc1 = simd.conj_mul_add(l1, r1, acc1);
751			acc2 = simd.conj_mul_add(l2, r2, acc2);
752			acc3 = simd.conj_mul_add(l3, r3, acc3);
753		}
754		for i0 in idx {
755			let l0 = simd.read(lhs, i0);
756			let r0 = simd.read(rhs, i0);
757
758			acc0 = simd.conj_mul_add(l0, r0, acc0);
759		}
760		if let Some(i0) = tail {
761			let l0 = simd.read(lhs, i0);
762			let r0 = simd.read(rhs, i0);
763
764			acc0 = simd.conj_mul_add(l0, r0, acc0);
765		}
766		acc0 = simd.add(acc0, acc1);
767		acc2 = simd.add(acc2, acc3);
768		acc0 = simd.add(acc0, acc2);
769
770		simd.reduce_sum(acc0)
771	}
772
773	#[math]
774	pub(crate) fn inner_prod_schoolbook<'K, T: ComplexField>(
775		lhs: RowRef<'_, T, Dim<'K>>,
776		conj_lhs: Conj,
777		rhs: ColRef<'_, T, Dim<'K>>,
778		conj_rhs: Conj,
779	) -> T {
780		let mut acc = zero();
781
782		for k in lhs.ncols().indices() {
783			if try_const! { T::IS_REAL } {
784				acc = lhs[k] * rhs[k] + acc;
785			} else {
786				match (conj_lhs, conj_rhs) {
787					(Conj::No, Conj::No) => {
788						acc = lhs[k] * rhs[k] + acc;
789					},
790					(Conj::No, Conj::Yes) => {
791						acc = lhs[k] * conj(rhs[k]) + acc;
792					},
793					(Conj::Yes, Conj::No) => {
794						acc = conj(lhs[k]) * rhs[k] + acc;
795					},
796					(Conj::Yes, Conj::Yes) => {
797						acc = conj(lhs[k] * rhs[k]) + acc;
798					},
799				}
800			}
801		}
802
803		acc
804	}
805}
806
/// matrix-vector product for a row-major matrix: each destination entry is an
/// independent simd inner product of one lhs row with the rhs vector
mod matvec_rowmajor {
	use super::*;
	use crate::col::ColMut;
	use faer_traits::SimdArch;

	/// computes `dst = beta(dst) + alpha * conj?(lhs) * conj?(rhs)`
	/// for a row-major `lhs` and contiguous `rhs`
	#[math]
	pub fn matvec<'M, 'K, T: ComplexField>(
		dst: ColMut<'_, T, Dim<'M>>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		// this path is only reachable when simd is available for T
		core::assert!(try_const! { T::SIMD_CAPABILITIES.is_simd() });
		// small problems aren't worth the threading overhead
		let size = *lhs.nrows() * *lhs.ncols();
		let par = if size < 256 * 256usize { Par::Seq } else { par };

		match par {
			Par::Seq => {
				pub struct Impl<'a, 'M, 'K, T: ComplexField> {
					dst: ColMut<'a, T, Dim<'M>>,
					beta: Accum,
					lhs: MatRef<'a, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
					conj_lhs: Conj,
					rhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
					conj_rhs: Conj,
					alpha: &'a T,
				}

				impl<'a, 'M, 'K, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'K, T> {
					type Output = ();

					#[inline(always)]
					fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
						let Self {
							dst,
							beta,
							lhs,
							conj_lhs,
							rhs,
							conj_rhs,
							alpha,
						} = self;
						let simd = T::simd_ctx(simd);
						let mut dst = dst;

						// the simd context iterates along the shared k dimension
						let K = lhs.ncols();
						let simd = SimdCtx::new(simd, K);
						for i in lhs.nrows().indices() {
							let dst = &mut dst[i];
							let lhs = lhs.row(i);
							let rhs = rhs;
							// same conjugation-combination scheme as `dot::inner_prod_slice`:
							// equal conjs -> plain product; mixed -> conj(lhs)*rhs; then
							// conjugate the scalar result if the rhs is conjugated
							let mut tmp = if conj_lhs == conj_rhs {
								dot::inner_prod_no_conj_simd::<T, S>(simd, lhs.transpose(), rhs)
							} else {
								dot::inner_prod_conj_lhs_simd::<T, S>(simd, lhs.transpose(), rhs)
							};

							if conj_rhs == Conj::Yes {
								tmp = conj(tmp);
							}
							tmp = *alpha * tmp;
							if let Accum::Add = beta {
								tmp = *dst + tmp;
							}
							*dst = tmp;
						}
					}
				}

				dispatch!(
					Impl {
						dst,
						beta,
						lhs,
						conj_lhs,
						rhs,
						conj_rhs,
						alpha,
					},
					Impl,
					T
				);
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				let nthreads = nthreads.get();

				// rows are independent, so split dst and lhs into matching row
				// partitions and recurse sequentially on each
				use rayon::prelude::*;
				dst.par_partition_mut(nthreads)
					.zip_eq(lhs.par_row_partition(nthreads))
					.for_each(|(dst, lhs)| {
						make_guard!(M);
						let nrows = dst.nrows().bind(M);
						let dst = dst.as_row_shape_mut(nrows);
						let lhs = lhs.as_row_shape(nrows);

						matvec(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, Par::Seq);
					})
			},
		}
	}
}
913
/// matrix-vector product for a column-major matrix: dst is updated by simd
/// axpy passes, one per column of lhs
mod matvec_colmajor {
	use super::*;
	use crate::col::ColMut;
	use crate::linalg::temp_mat_uninit;
	use crate::mat::AsMatMut;
	use crate::utils::bound::IdxInc;
	use crate::{unzip, zip};
	use faer_traits::SimdArch;

	/// computes `dst = beta(dst) + alpha * conj?(lhs) * conj?(rhs)`
	/// for a column-major `lhs` and contiguous `dst`
	#[math]
	pub fn matvec<'M, 'K, T: ComplexField>(
		dst: ColMut<'_, T, Dim<'M>, ContiguousFwd>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, ContiguousFwd, isize>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		// this path is only reachable when simd is available for T
		core::assert!(try_const! { T::SIMD_CAPABILITIES.is_simd() });
		// small problems aren't worth the threading overhead
		let size = *lhs.nrows() * *lhs.ncols();
		let par = if size < 256 * 256usize { Par::Seq } else { par };

		match par {
			Par::Seq => {
				pub struct Impl<'a, 'M, 'K, T: ComplexField> {
					dst: ColMut<'a, T, Dim<'M>, ContiguousFwd>,
					beta: Accum,
					lhs: MatRef<'a, T, Dim<'M>, Dim<'K>, ContiguousFwd, isize>,
					conj_lhs: Conj,
					rhs: ColRef<'a, T, Dim<'K>>,
					conj_rhs: Conj,
					alpha: &'a T,
				}

				impl<'a, 'M, 'K, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'K, T> {
					type Output = ();

					#[inline(always)]
					fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
						let Self {
							dst,
							beta,
							lhs,
							conj_lhs,
							rhs,
							conj_rhs,
							alpha,
						} = self;

						let simd = T::simd_ctx(simd);

						// the simd context iterates along the rows of dst/lhs
						let M = lhs.nrows();
						let simd = SimdCtx::<T, S>::new(simd, M);
						let (head, body, tail) = simd.indices();

						let mut dst = dst;
						// Replace semantics: zero dst first, then every column
						// pass below can unconditionally accumulate
						match beta {
							Accum::Add => {},
							Accum::Replace => {
								let mut dst = dst.rb_mut();
								if let Some(i) = head {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
								for i in body.clone() {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
								if let Some(i) = tail {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
							},
						}

						// one axpy pass per column: dst += (alpha * conj?(rhs[j])) * conj?(lhs[:, j])
						for j in lhs.ncols().indices() {
							let mut dst = dst.rb_mut();
							let lhs = lhs.col(j);
							let rhs = &rhs[j];
							// fold the rhs conjugation and alpha into one scalar up front
							let rhs = if conj_rhs == Conj::Yes { conj(*rhs) } else { copy(*rhs) };
							let rhs = rhs * *alpha;

							let vrhs = simd.splat(&rhs);
							// conj_mul_add conjugates its first operand (the lhs column)
							if conj_lhs == Conj::Yes {
								if let Some(i) = head {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
								for i in body.clone() {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
								if let Some(i) = tail {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
							} else {
								if let Some(i) = head {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
								for i in body.clone() {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
								if let Some(i) = tail {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
							}
						}
					}
				}

				dispatch!(
					Impl {
						dst,
						lhs,
						conj_lhs,
						rhs,
						conj_rhs,
						beta,
						alpha,
					},
					Impl,
					T
				)
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				use rayon::prelude::*;
				let nthreads = nthreads.get();
				// split the k dimension across threads: each thread computes its
				// partial product into its own scratch column, then the partials
				// are summed into dst sequentially
				let mut mem = MemBuffer::new(temp_mat_scratch::<T>(dst.nrows().unbound(), nthreads));
				let stack = MemStack::new(&mut mem);

				let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(dst.nrows(), nthreads, stack) };
				let mut tmp = tmp.as_mat_mut().try_as_col_major_mut().unwrap();

				let mut dst = dst;
				// an empty (k = 0) column range, used below to apply `beta` to
				// dst without contributing any product terms
				make_guard!(Z);
				let Z = 0usize.bind(Z);
				let z = IdxInc::new_checked(0, lhs.ncols());

				tmp.rb_mut()
					.par_col_iter_mut()
					.zip_eq(lhs.par_col_partition(nthreads))
					.zip_eq(rhs.par_partition(nthreads))
					.for_each(|((dst, lhs), rhs)| {
						make_guard!(K);
						let K = lhs.ncols().bind(K);
						let lhs = lhs.as_col_shape(K);
						let rhs = rhs.as_row_shape(K);

						// each partial overwrites its scratch column
						matvec(dst, Accum::Replace, lhs, conj_lhs, rhs, conj_rhs, alpha, Par::Seq);
					});

				// zero-width recursion: with k = 0 and alpha = 0 this only
				// applies `beta` (zeroing dst on Replace, leaving it on Add)
				matvec(
					dst.rb_mut(),
					beta,
					lhs.subcols(z, Z),
					conj_lhs,
					rhs.subrows(z, Z),
					conj_rhs,
					&zero(),
					Par::Seq,
				);
				// accumulate the per-thread partials into dst
				for j in 0..nthreads {
					zip!(dst.rb_mut(), tmp.rb().col(j)).for_each(|unzip!(dst, src)| *dst = *dst + *src)
				}
			},
		}
	}
}
1092
mod rank_update {
	use super::*;
	use crate::assert;

	/// simd implementation of the rank-1 (outer product) update
	/// `dst[i, j] = [beta * dst[i, j]] + alpha * lhs[i] * rhs[j]`
	///
	/// `dst` and `lhs` must be column major (`ContiguousFwd` row stride), and `T` must have
	/// usable simd capabilities (asserted below)
	#[math]
	fn rank_update_imp<'M, 'N, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: RowRef<'_, T, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
	) {
		assert!(T::SIMD_CAPABILITIES.is_simd());

		// bundles the arguments so they can be moved into the simd dispatch below
		struct Impl<'a, 'M, 'N, T: ComplexField> {
			dst: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
			beta: Accum,
			lhs: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
			conj_lhs: Conj,
			rhs: RowRef<'a, T, Dim<'N>>,
			conj_rhs: Conj,
			alpha: &'a T,
		}

		impl<T: ComplexField> pulp::WithSimd for Impl<'_, '_, '_, T> {
			type Output = ();

			#[inline(always)]
			fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
				let Self {
					mut dst,
					beta,
					lhs,
					conj_lhs,
					rhs,
					conj_rhs,
					alpha,
				} = self;

				let (m, n) = dst.shape();
				let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), m);

				// a column of length `m` is split into an optional partial head, a run of
				// full simd registers, and an optional partial tail
				let (head, body, tail) = simd.indices();

				// process one column of `dst` at a time
				for j in n.indices() {
					let mut dst = dst.rb_mut().col_mut(j);

					// fold `alpha` and the (possibly conjugated) rhs element into one scalar,
					// then broadcast it across a simd register
					let rhs = *alpha * conj_rhs.apply_rt(&rhs[j]);
					let rhs = simd.splat(&rhs);

					// the four branches below are the same column update specialized for
					// (conjugated lhs?) x (accumulate vs replace)
					if conj_lhs.is_conj() {
						match beta {
							Accum::Add => {
								if let Some(i) = head {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
							Accum::Replace => {
								if let Some(i) = head {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
						}
					} else {
						match beta {
							Accum::Add => {
								if let Some(i) = head {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
							Accum::Replace => {
								if let Some(i) = head {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
						}
					}
				}
			}
		}

		// run the kernel with the best simd instruction set available at runtime
		dispatch!(
			Impl {
				dst,
				lhs,
				conj_lhs,
				rhs,
				conj_rhs,
				beta,
				alpha,
			},
			Impl,
			T
		)
	}

	/// rank-1 update entry point: runs sequentially, or splits `dst`/`rhs` into column
	/// blocks across rayon threads
	///
	/// see [`rank_update_imp`] for the operation performed
	#[math]
	pub fn rank_update<'M, 'N, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: RowRef<'_, T, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		match par {
			Par::Seq => {
				rank_update_imp(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha);
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				let nthreads = nthreads.get();
				use rayon::prelude::*;
				// partition by columns; each thread runs the sequential kernel on its own
				// column block, while `lhs` is shared read-only
				dst.par_col_partition_mut(nthreads)
					.zip(rhs.par_partition(nthreads))
					.for_each(|(dst, rhs)| {
						with_dim!(N, dst.ncols());
						rank_update_imp(dst.as_col_shape_mut(N), beta, lhs, conj_lhs, rhs.as_col_shape(N), conj_rhs, alpha);
					});
			},
		}
	}
}
1262
/// dispatching core of [`matmul`]/[`matmul_with_conj`]: computes
/// `dst = [beta * dst] + alpha * conj?(lhs) * conj?(rhs)`
///
/// tries, in order: empty-size shortcuts, matrix-vector kernels, rank-1 updates, native
/// gemm backends, generic simd kernels, and finally a scalar schoolbook fallback
#[math]
fn matmul_imp<'M, 'N, 'K, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
	beta: Accum,
	lhs: MatRef<'_, T, Dim<'M>, Dim<'K>>,
	conj_lhs: Conj,
	rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
	conj_rhs: Conj,
	alpha: &T,
	par: Par,
) {
	let mut dst = dst;

	let M = dst.nrows();
	let N = dst.ncols();
	let K = lhs.ncols();
	// empty output: nothing to compute
	if *M == 0 || *N == 0 {
		return;
	}
	// empty inner dimension: the product is zero, so only the `beta` part remains
	if *K == 0 {
		if beta == Accum::Replace {
			dst.fill(zero());
		}
		return;
	}

	let mut lhs = lhs;
	let mut rhs = rhs;

	if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
		// normalize the strides of `dst` to be nonnegative, mirroring each reversal on the
		// operand that shares the reversed dimension so the product is unchanged
		if dst.row_stride() < 0 {
			dst = dst.reverse_rows_mut();
			lhs = lhs.reverse_rows();
		}
		if dst.col_stride() < 0 {
			dst = dst.reverse_cols_mut();
			rhs = rhs.reverse_cols();
		}
		// reversing lhs columns together with rhs rows only reorders the K summation
		if lhs.col_stride() < 0 {
			lhs = lhs.reverse_cols();
			rhs = rhs.reverse_rows();
		}

		// single output column: defer to the matrix-vector kernels
		if dst.ncols().unbound() == 1 {
			let first = dst.ncols().check(0);
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matvec_colmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}

			if let (Some(rhs), Some(lhs)) = (rhs.try_as_col_major(), lhs.try_as_row_major()) {
				matvec_rowmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}
		}
		// single output row: transpose the whole product (`(L R)^T = R^T L^T`), which swaps
		// the roles (and conjugation flags) of the operands, then reuse the same kernels
		if dst.nrows().unbound() == 1 {
			let mut dst = dst.rb_mut().transpose_mut();
			let (rhs, lhs) = (lhs.transpose(), rhs.transpose());
			let (conj_rhs, conj_lhs) = (conj_lhs, conj_rhs);

			let first = dst.ncols().check(0);
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matvec_colmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}

			if let (Some(rhs), Some(lhs)) = (rhs.try_as_col_major(), lhs.try_as_row_major()) {
				matvec_rowmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}
		}
		// inner dimension of 1: the product is a rank-1 (outer product) update
		if *K == 1 {
			let z = K.idx(0);

			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				rank_update::rank_update(dst, beta, lhs.col(z), conj_lhs, rhs.row(z), conj_rhs, alpha, par);
				return;
			}

			// row-major case: transpose everything and swap the operands
			if let (Some(dst), Some(rhs)) = (dst.rb_mut().try_as_row_major_mut(), rhs.try_as_row_major()) {
				let dst = dst.transpose_mut();
				let rhs = rhs.row(z).transpose();
				let lhs = lhs.col(z).transpose();
				rank_update::rank_update(dst, beta, rhs, conj_rhs, lhs, conj_lhs, alpha, par);
				return;
			}
		}
		// dispatches to an external gemm backend after reinterpreting `T` as the
		// corresponding native scalar type
		macro_rules! gemm_call {
			($kind: ident, $ty: ty, $nanogemm: ident) => {
				unsafe {
					// SAFETY: each invocation is guarded by the matching `T::IS_NATIVE_*`
					// check below, so `T` and `$ty` are the same type
					let dst = core::mem::transmute_copy::<MatMut<'_, T, Dim<'M>, Dim<'N>>, MatMut<'_, $ty, Dim<'M>, Dim<'N>>>(&dst);
					let lhs = core::mem::transmute_copy::<MatRef<'_, T, Dim<'M>, Dim<'K>>, MatRef<'_, $ty, Dim<'M>, Dim<'K>>>(&lhs);
					let rhs = core::mem::transmute_copy::<MatRef<'_, T, Dim<'K>, Dim<'N>>, MatRef<'_, $ty, Dim<'K>, Dim<'N>>>(&rhs);
					let alpha = *core::mem::transmute_copy::<&T, &$ty>(&alpha);

					// small problems go to nano-gemm
					if (*M).saturating_mul(*N).saturating_mul(*K) <= NANO_GEMM_THRESHOLD {
						nano_gemm::planless::$nanogemm(
							*M,
							*N,
							*K,
							dst.as_ptr_mut(),
							dst.row_stride(),
							dst.col_stride(),
							lhs.as_ptr(),
							lhs.row_stride(),
							lhs.col_stride(),
							rhs.as_ptr(),
							rhs.row_stride(),
							rhs.col_stride(),
							// scaling of the preexisting dst values: 0 for replace, 1 for add
							match beta {
								Accum::Replace => core::mem::zeroed(),
								Accum::Add => 1.0.into(),
							},
							alpha,
							conj_lhs == Conj::Yes,
							conj_rhs == Conj::Yes,
						);
						return;
					} else {
						// larger problems: prefer the x86-specific backend when avx2+fma or
						// avx512f is detected at runtime
						#[cfg(all(target_arch = "x86_64", feature = "std"))]
						{
							use private_gemm_x86::*;

							let feat = if std::arch::is_x86_feature_detected!("avx512f") {
								Some(InstrSet::Avx512)
							} else if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
								Some(InstrSet::Avx256)
							} else {
								None
							};

							if let Some(feat) = feat {
								gemm(
									DType::$kind,
									IType::U64,
									feat,
									*M,
									*N,
									*K,
									dst.as_ptr_mut() as *mut (),
									dst.row_stride(),
									dst.col_stride(),
									// no row/col index arrays: the destination is dense
									core::ptr::null(),
									core::ptr::null(),
									DstKind::Full,
									// translate our `Accum` into the backend's accumulation mode
									match beta {
										$crate::Accum::Replace => Accum::Replace,
										$crate::Accum::Add => Accum::Add,
									},
									lhs.as_ptr() as *const (),
									lhs.row_stride(),
									lhs.col_stride(),
									conj_lhs == Conj::Yes,
									core::ptr::null(),
									0,
									rhs.as_ptr() as *const (),
									rhs.row_stride(),
									rhs.col_stride(),
									conj_rhs == Conj::Yes,
									&raw const alpha as *const (),
									par.degree(),
								);
								return;
							}
						}

						// portable fallback: the `gemm` crate (note that it takes the column
						// stride before the row stride)
						{
							gemm::gemm(
								M.unbound(),
								N.unbound(),
								K.unbound(),
								dst.as_ptr_mut(),
								dst.col_stride(),
								dst.row_stride(),
								// whether the preexisting dst values are read
								beta != Accum::Replace,
								lhs.as_ptr(),
								lhs.col_stride(),
								lhs.row_stride(),
								rhs.as_ptr(),
								rhs.col_stride(),
								rhs.row_stride(),
								// dst scaling factor: 0 for replace, 1 for add
								match beta {
									Accum::Replace => core::mem::zeroed(),
									Accum::Add => 1.0.into(),
								},
								alpha,
								// presumably the dst conjugation flag — see `gemm::gemm`
								false,
								conj_lhs == Conj::Yes,
								conj_rhs == Conj::Yes,
								match par {
									Par::Seq => gemm::Parallelism::None,
									#[cfg(feature = "rayon")]
									Par::Rayon(nthreads) => gemm::Parallelism::Rayon(nthreads.get()),
								},
							);

							return;
						}
					}
				};
			};
		}

		// native scalar types are forwarded to the optimized backends above; every branch
		// of the macro returns, so at most one of these runs to completion
		if try_const! { T::IS_NATIVE_F64 } {
			gemm_call!(F64, f64, execute_f64);
		}
		if try_const! { T::IS_NATIVE_C64 } {
			gemm_call!(C64, num_complex::Complex<f64>, execute_c64);
		}
		if try_const! { T::IS_NATIVE_F32 } {
			gemm_call!(F32, f32, execute_f32);
		}
		if try_const! { T::IS_NATIVE_C32 } {
			gemm_call!(C32, num_complex::Complex<f32>, execute_c32);
		}

		// non-native simd-capable types use the generic simd kernels, picking the variant
		// matching the operands' storage layout
		if const { !(T::IS_NATIVE_F64 || T::IS_NATIVE_F32 || T::IS_NATIVE_C64 || T::IS_NATIVE_C32) } {
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matmul_vertical::matmul_simd(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, par);
				return;
			}
			// row-major dst/rhs: transpose the whole product and reuse the vertical kernel
			if let (Some(dst), Some(rhs)) = (dst.rb_mut().try_as_row_major_mut(), rhs.try_as_row_major()) {
				matmul_vertical::matmul_simd(
					dst.transpose_mut(),
					beta,
					rhs.transpose(),
					conj_rhs,
					lhs.transpose(),
					conj_lhs,
					alpha,
					par,
				);
				return;
			}
			if let (Some(lhs), Some(rhs)) = (lhs.try_as_row_major(), rhs.try_as_col_major()) {
				matmul_horizontal::matmul_simd(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, par);
				return;
			}
		}
	}

	// scalar fallback: schoolbook inner products, one output element at a time
	match par {
		Par::Seq => {
			for j in dst.ncols().indices() {
				for i in dst.nrows().indices() {
					let dst = &mut dst[(i, j)];

					let mut acc = dot::inner_prod_schoolbook(lhs.row(i), conj_lhs, rhs.col(j), conj_rhs);
					acc = *alpha * acc;
					if let Accum::Add = beta {
						acc = *dst + acc;
					}
					*dst = acc;
				}
			}
		},
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();

			// distribute the `m * n` output elements evenly across threads by flat index
			let m = *dst.nrows();
			let n = *dst.ncols();
			let task_count = m * n;
			let task_per_thread = task_count.msrv_div_ceil(nthreads);

			let dst = dst.rb();
			(0..nthreads).into_par_iter().for_each(|tid| {
				let task_idx = tid * task_per_thread;
				if task_idx >= task_count {
					return;
				}
				let ntasks = Ord::min(task_per_thread, task_count - task_idx);

				for ij in 0..ntasks {
					// flat index -> (row, column), column-major order
					let ij = task_idx + ij;
					let i = dst.nrows().check(ij % m);
					let j = dst.ncols().check(ij / m);

					// SAFETY(review): each flat index is visited by exactly one thread, so
					// the mutable accesses are disjoint
					let mut dst = unsafe { dst.const_cast() };
					let dst = &mut dst[(i, j)];

					let mut acc = dot::inner_prod_schoolbook(lhs.row(i), conj_lhs, rhs.col(j), conj_rhs);
					acc = *alpha * acc;

					if let Accum::Add = beta {
						acc = *dst + acc;
					}
					*dst = acc;
				}
			});
		},
	}
}
1557
/// validates that the operand shapes are compatible for `dst = lhs * rhs`,
/// i.e. `dst` is `M×N`, `lhs` is `M×K` and `rhs` is `K×N`
///
/// # panics
///
/// panics (reporting the caller's location) if any of the three dimension
/// equalities fails
#[track_caller]
fn precondition<M: Shape, N: Shape, K: Shape>(dst_nrows: M, dst_ncols: N, lhs_nrows: M, lhs_ncols: K, rhs_nrows: K, rhs_ncols: N) {
	assert!(all(dst_nrows == lhs_nrows, dst_ncols == rhs_ncols, lhs_ncols == rhs_nrows,));
}
1562
/// computes the matrix product `[beta * acc] + alpha * lhs * rhs` and stores the result in `acc`
///
/// performs the operation:
/// - `acc = alpha * lhs * rhs` if `beta` is `Accum::Replace` (in this case, the preexisting
/// values in `acc` are not read)
/// - `acc = acc + alpha * lhs * rhs` if `beta` is `Accum::Add`
///
/// # panics
///
/// panics if the matrix dimensions are not compatible for matrix multiplication.
/// i.e.  
///  - `acc.nrows() == lhs.nrows()`
///  - `acc.ncols() == rhs.ncols()`
///  - `lhs.ncols() == rhs.nrows()`
///
/// # Example
///
/// ```
/// use faer::linalg::matmul::matmul;
/// use faer::{Accum, Conj, Mat, Par, mat, unzip, zip};
///
/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
///
/// let mut acc = Mat::<f64>::zeros(2, 2);
/// let target = mat![
/// 	[
/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 0)] + lhs[(0, 1)] * rhs[(1, 0)]),
/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 1)] + lhs[(0, 1)] * rhs[(1, 1)]),
/// 	],
/// 	[
/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 0)] + lhs[(1, 1)] * rhs[(1, 0)]),
/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 1)] + lhs[(1, 1)] * rhs[(1, 1)]),
/// 	],
/// ];
///
/// matmul(&mut acc, Accum::Replace, &lhs, &rhs, 2.5, Par::Seq);
///
/// zip!(&acc, &target).for_each(|unzip!(acc, target)| assert!((acc - target).abs() < 1e-10));
/// ```
#[track_caller]
#[inline]
pub fn matmul<T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>, M: Shape, N: Shape, K: Shape>(
	dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
	beta: Accum,
	lhs: impl AsMatRef<T = LhsT, Rows = M, Cols = K>,
	rhs: impl AsMatRef<T = RhsT, Rows = K, Cols = N>,
	alpha: T,
	par: Par,
) {
	let mut dst = dst;
	let dst = dst.as_mat_mut();
	let lhs = lhs.as_mat_ref();
	let rhs = rhs.as_mat_ref();

	// validate shape compatibility before binding the dimensions
	precondition(dst.nrows(), dst.ncols(), lhs.nrows(), lhs.ncols(), rhs.nrows(), rhs.ncols());

	// bind the runtime sizes to branded `Dim` lifetimes so that matching dimensions are
	// encoded in the types passed to the implementation
	make_guard!(M);
	make_guard!(N);
	make_guard!(K);
	let M = dst.nrows().bind(M);
	let N = dst.ncols().bind(N);
	let K = lhs.ncols().bind(K);

	// `canonical()` strips the conjugation wrapper from the element type; the conjugation
	// flags are recovered at compile time via `Conj::get`
	matmul_imp(
		dst.as_dyn_stride_mut().as_shape_mut(M, N),
		beta,
		lhs.as_dyn_stride().canonical().as_shape(M, K),
		try_const! { Conj::get::<LhsT>() },
		rhs.as_dyn_stride().canonical().as_shape(K, N),
		try_const! { Conj::get::<RhsT>() },
		&alpha,
		par,
	);
}
1638
/// computes the matrix product `[beta * acc] + alpha * lhs * rhs` (implicitly conjugating the
/// operands if needed) and stores the result in `acc`
///
/// performs the operation:
/// - `acc = alpha * lhs * rhs` if `beta` is `Accum::Replace` (in this case, the preexisting
/// values in `acc` are not read)
/// - `acc = acc + alpha * lhs * rhs` if `beta` is `Accum::Add`
///
/// # panics
///
/// panics if the matrix dimensions are not compatible for matrix multiplication.
/// i.e.  
///  - `acc.nrows() == lhs.nrows()`
///  - `acc.ncols() == rhs.ncols()`
///  - `lhs.ncols() == rhs.nrows()`
///
/// # example
///
/// ```
/// use faer::linalg::matmul::matmul_with_conj;
/// use faer::{Accum, Conj, Mat, Par, mat, unzip, zip};
///
/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
///
/// let mut acc = Mat::<f64>::zeros(2, 2);
/// let target = mat![
/// 	[
/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 0)] + lhs[(0, 1)] * rhs[(1, 0)]),
/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 1)] + lhs[(0, 1)] * rhs[(1, 1)]),
/// 	],
/// 	[
/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 0)] + lhs[(1, 1)] * rhs[(1, 0)]),
/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 1)] + lhs[(1, 1)] * rhs[(1, 1)]),
/// 	],
/// ];
///
/// matmul_with_conj(
/// 	&mut acc,
/// 	Accum::Replace,
/// 	&lhs,
/// 	Conj::No,
/// 	&rhs,
/// 	Conj::No,
/// 	2.5,
/// 	Par::Seq,
/// );
///
/// zip!(&acc, &target).for_each(|unzip!(acc, target)| assert!((acc - target).abs() < 1e-10));
/// ```
#[track_caller]
#[inline]
pub fn matmul_with_conj<T: ComplexField, M: Shape, N: Shape, K: Shape>(
	dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
	beta: Accum,
	lhs: impl AsMatRef<T = T, Rows = M, Cols = K>,
	conj_lhs: Conj,
	rhs: impl AsMatRef<T = T, Rows = K, Cols = N>,
	conj_rhs: Conj,
	alpha: T,
	par: Par,
) {
	let mut dst = dst;
	let dst = dst.as_mat_mut();
	let lhs = lhs.as_mat_ref();
	let rhs = rhs.as_mat_ref();

	// validate shape compatibility before binding the dimensions
	precondition(dst.nrows(), dst.ncols(), lhs.nrows(), lhs.ncols(), rhs.nrows(), rhs.ncols());

	// bind the runtime sizes to branded `Dim` lifetimes; unlike [`matmul`], the
	// conjugation flags here are provided explicitly by the caller
	make_guard!(M);
	make_guard!(N);
	make_guard!(K);
	let M = dst.nrows().bind(M);
	let N = dst.ncols().bind(N);
	let K = lhs.ncols().bind(K);

	matmul_imp(
		dst.as_dyn_stride_mut().as_shape_mut(M, N),
		beta,
		lhs.as_dyn_stride().canonical().as_shape(M, K),
		conj_lhs,
		rhs.as_dyn_stride().canonical().as_shape(K, N),
		conj_rhs,
		&alpha,
		par,
	);
}
1726
1727#[cfg(test)]
1728mod tests {
1729	use crate::c32;
1730	use std::num::NonZeroUsize;
1731
1732	use super::triangular::{BlockStructure, DiagonalKind};
1733	use super::*;
1734	use crate::assert;
1735	use crate::mat::{Mat, MatMut, MatRef};
1736	use crate::stats::prelude::*;
1737
	/// exhaustive test of the matmul entry points over sizes, storage layouts, row/column
	/// reversals, conjugations, parallelism modes, accumulation modes and scaling factors
	#[test]
	#[ignore = "takes too long"]
	fn test_matmul() {
		let rng = &mut StdRng::seed_from_u64(0);

		if option_env!("CI") == Some("true") {
			// too big for CI
			return;
		}

		let betas = [Accum::Replace, Accum::Add];

		// full parameter grid for local runs
		#[cfg(not(miri))]
		let bools = [false, true];
		#[cfg(not(miri))]
		let alphas = [c32::ONE, c32::ZERO, c32::new(21.04, -12.13)];
		#[cfg(not(miri))]
		let par = [Par::Seq, Par::Rayon(NonZeroUsize::new(4).unwrap())];
		#[cfg(not(miri))]
		let conjs = [Conj::Yes, Conj::No];

		// reduced grid under miri, which is far slower
		#[cfg(miri)]
		let bools = [true];
		#[cfg(miri)]
		let alphas = [c32::new(0.3218, -1.217489)];
		#[cfg(miri)]
		let par = [Par::Seq];
		#[cfg(miri)]
		let conjs = [Conj::Yes];

		// sizes straddling 128 and 16 to hit register/block boundary cases
		let big0 = 127;
		let big1 = 128;
		let big2 = 129;

		let mid0 = 15;
		let mid1 = 16;
		let mid2 = 17;
		for (m, n, k) in [
			(big0, big1, 5),
			(big1, big0, 5),
			(big0, big2, 5),
			(big2, big0, 5),
			(mid0, mid0, 5),
			(mid1, mid1, 5),
			(mid2, mid2, 5),
			(mid0, mid1, 5),
			(mid1, mid0, 5),
			(mid0, mid2, 5),
			(mid2, mid0, 5),
			(mid0, 1, 1),
			(1, mid0, 1),
			(1, 1, mid0),
			(1, mid0, mid0),
			(mid0, 1, mid0),
			(mid0, mid0, 1),
			(1, 1, 1),
		] {
			// complex gaussian inputs
			let distribution = ComplexDistribution::new(StandardNormal, StandardNormal);
			let a = CwiseMatDistribution {
				nrows: m,
				ncols: k,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);
			let b = CwiseMatDistribution {
				nrows: k,
				ncols: n,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);
			let mut acc_init = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);

			let a = a.as_ref();
			let b = b.as_ref();

			for reverse_acc_cols in bools {
				for reverse_acc_rows in bools {
					for reverse_b_cols in bools {
						for reverse_b_rows in bools {
							for reverse_a_cols in bools {
								for reverse_a_rows in bools {
									for a_colmajor in bools {
										for b_colmajor in bools {
											for acc_colmajor in bools {
												// NOTE(review): the two conditional transposes cancel out, so the
												// layout of `a`/`b` does not actually depend on `a_colmajor`/
												// `b_colmajor` here — presumably vestigial; confirm intent
												let a = if a_colmajor { a } else { a.transpose() };
												let mut a = if a_colmajor { a } else { a.transpose() };

												let b = if b_colmajor { b } else { b.transpose() };
												let mut b = if b_colmajor { b } else { b.transpose() };

												if reverse_a_rows {
													a = a.reverse_rows();
												}
												if reverse_a_cols {
													a = a.reverse_cols();
												}
												if reverse_b_rows {
													b = b.reverse_rows();
												}
												if reverse_b_cols {
													b = b.reverse_cols();
												}
												for conj_a in conjs {
													for conj_b in conjs {
														for par in par {
															for beta in betas {
																for alpha in alphas {
																	test_matmul_impl(
																		reverse_acc_cols,
																		reverse_acc_rows,
																		acc_colmajor,
																		m,
																		n,
																		conj_a,
																		conj_b,
																		par,
																		beta,
																		alpha,
																		acc_init.as_mut(),
																		a,
																		b,
																	);
																}
															}
														}
													}
												}
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
1881
	/// reference (schoolbook) implementation of
	/// `acc = [beta * acc] + alpha * conj?(a) * conj?(b)`, used as ground truth by the
	/// tests in this module
	#[math]
	fn matmul_with_conj_fallback<T: Copy + ComplexField>(
		acc: MatMut<'_, T>,
		a: MatRef<'_, T>,
		conj_a: Conj,
		b: MatRef<'_, T>,
		conj_b: Conj,
		beta: Accum,
		alpha: T,
	) {
		let m = acc.nrows();
		let n = acc.ncols();
		let k = a.ncols();

		// computes one output element from its flat (column-major) index
		let job = |idx: usize| {
			let i = idx % m;
			let j = idx / m;
			let acc = acc.rb().submatrix(i, j, 1, 1);
			// SAFETY(review): jobs run sequentially below and each writes a distinct 1x1
			// submatrix, so the mutable access is exclusive
			let mut acc = unsafe { acc.const_cast() };

			// inner product over the shared dimension, applying conjugation on the fly
			let mut local_acc = zero::<T>();
			for depth in 0..k {
				let a = &a[(i, depth)];
				let b = &b[(depth, j)];
				local_acc = local_acc
					+ match conj_a {
						Conj::Yes => conj(*a),
						Conj::No => copy(*a),
					} * match conj_b {
						Conj::Yes => conj(*b),
						Conj::No => copy(*b),
					}
			}
			// apply the accumulation mode and the scaling factor
			match beta {
				Accum::Add => acc[(0, 0)] = acc[(0, 0)] + local_acc * alpha,
				Accum::Replace => acc[(0, 0)] = local_acc * alpha,
			}
		};

		for i in 0..m * n {
			job(i);
		}
	}
1925
	/// checks one matmul configuration against [`matmul_with_conj_fallback`]: the vertical
	/// simd kernel, the horizontal simd kernel, and the public [`matmul_with_conj`] entry
	/// point must all agree with the reference within `1e-3` per component
	#[math]
	fn test_matmul_impl(
		reverse_acc_cols: bool,
		reverse_acc_rows: bool,
		acc_colmajor: bool,
		m: usize,
		n: usize,
		conj_a: Conj,
		conj_b: Conj,
		par: Par,
		beta: Accum,
		alpha: c32,
		acc_init: MatMut<c32>,
		a: MatRef<c32>,
		b: MatRef<c32>,
	) {
		// NOTE(review): these two conditional transposes cancel out, so `acc` keeps its
		// original orientation regardless of `acc_colmajor` — presumably vestigial; confirm
		let acc = if acc_colmajor { acc_init } else { acc_init.transpose_mut() };

		let mut acc = if acc_colmajor { acc } else { acc.transpose_mut() };
		if reverse_acc_rows {
			acc = acc.reverse_rows_mut();
		}
		if reverse_acc_cols {
			acc = acc.reverse_cols_mut();
		}

		// compute the expected result with the schoolbook reference
		let mut target = acc.rb().to_owned();
		matmul_with_conj_fallback(target.as_mut(), a, conj_a, b, conj_b, beta, alpha);
		let target = target.rb();

		// 1) vertical simd kernel on fresh column-major copies
		{
			let mut acc = acc.cloned();
			let a = a.cloned();

			{
				with_dim!(M, a.nrows());
				with_dim!(N, b.ncols());
				with_dim!(K, a.ncols());
				let mut acc = acc.rb_mut().as_shape_mut(M, N);
				let a = a.as_shape(M, K);
				let b = b.as_shape(K, N);

				matmul_vertical::matmul_simd(
					acc.rb_mut().try_as_col_major_mut().unwrap(),
					beta,
					a.try_as_col_major().unwrap(),
					conj_a,
					b,
					conj_b,
					&alpha,
					par,
				);
			}
			for j in 0..n {
				for i in 0..m {
					let acc = acc[(i, j)];
					let target = target[(i, j)];
					assert!(abs(acc.re - target.re) < 1e-3);
					assert!(abs(acc.im - target.im) < 1e-3);
				}
			}
		}
		// 2) horizontal simd kernel: `a` cloned into row-major storage, `b` column-major
		{
			let mut acc = acc.cloned();
			let a = a.transpose().cloned();
			let a = a.transpose();

			let b = b.cloned();

			{
				with_dim!(M, a.nrows());
				with_dim!(N, b.ncols());
				with_dim!(K, a.ncols());
				let mut acc = acc.rb_mut().as_shape_mut(M, N);
				let a = a.as_shape(M, K);
				let b = b.as_shape(K, N);

				matmul_horizontal::matmul_simd(
					acc.rb_mut(),
					beta,
					a.try_as_row_major().unwrap(),
					conj_a,
					b.try_as_col_major().unwrap(),
					conj_b,
					&alpha,
					par,
				);
			}
			for j in 0..n {
				for i in 0..m {
					let acc = acc[(i, j)];
					let target = target[(i, j)];
					assert!(abs(acc.re - target.re) < 1e-3);
					assert!(abs(acc.im - target.im) < 1e-3);
				}
			}
		}

		// 3) the public entry point, writing into the caller-provided `acc` view
		matmul_with_conj(acc.rb_mut(), beta, a, conj_a, b, conj_b, alpha, par);
		for j in 0..n {
			for i in 0..m {
				let acc = acc[(i, j)];
				let target = target[(i, j)];
				assert!(abs(acc.re - target.re) < 1e-3);
				assert!(abs(acc.im - target.im) < 1e-3);
			}
		}
	}
2034
2035	fn generate_structured_matrix(is_dst: bool, nrows: usize, ncols: usize, structure: BlockStructure) -> Mat<f64> {
2036		let rng = &mut StdRng::seed_from_u64(0);
2037		let mut mat = CwiseMatDistribution {
2038			nrows,
2039			ncols,
2040			dist: StandardNormal,
2041		}
2042		.rand::<Mat<f64>>(rng);
2043
2044		if !is_dst {
2045			let kind = structure.diag_kind();
2046			if structure.is_lower() {
2047				for j in 0..ncols {
2048					for i in 0..j {
2049						mat[(i, j)] = 0.0;
2050					}
2051				}
2052			} else if structure.is_upper() {
2053				for j in 0..ncols {
2054					for i in j + 1..nrows {
2055						mat[(i, j)] = 0.0;
2056					}
2057				}
2058			}
2059
2060			match kind {
2061				triangular::DiagonalKind::Zero => {
2062					for i in 0..nrows {
2063						mat[(i, i)] = 0.0;
2064					}
2065				},
2066				triangular::DiagonalKind::Unit => {
2067					for i in 0..nrows {
2068						mat[(i, i)] = 1.0;
2069					}
2070				},
2071				triangular::DiagonalKind::Generic => (),
2072			}
2073		}
2074		mat
2075	}
2076
2077	fn run_test_problem(m: usize, n: usize, k: usize, dst_structure: BlockStructure, lhs_structure: BlockStructure, rhs_structure: BlockStructure) {
2078		let mut dst = generate_structured_matrix(true, m, n, dst_structure);
2079		let mut dst_target = dst.as_ref().to_owned();
2080		let dst_orig = dst.as_ref().to_owned();
2081		let lhs = generate_structured_matrix(false, m, k, lhs_structure);
2082		let rhs = generate_structured_matrix(false, k, n, rhs_structure);
2083
2084		for par in [Par::Seq, Par::rayon(8)] {
2085			triangular::matmul_with_conj(
2086				dst.as_mut(),
2087				dst_structure,
2088				Accum::Replace,
2089				lhs.as_ref(),
2090				lhs_structure,
2091				Conj::No,
2092				rhs.as_ref(),
2093				rhs_structure,
2094				Conj::No,
2095				2.5,
2096				par,
2097			);
2098
2099			matmul_with_conj(
2100				dst_target.as_mut(),
2101				Accum::Replace,
2102				lhs.as_ref(),
2103				Conj::No,
2104				rhs.as_ref(),
2105				Conj::No,
2106				2.5,
2107				par,
2108			);
2109
2110			if dst_structure.is_dense() {
2111				for j in 0..n {
2112					for i in 0..m {
2113						assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2114					}
2115				}
2116			} else if dst_structure.is_lower() {
2117				for j in 0..n {
2118					if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
2119						for i in 0..j {
2120							assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2121						}
2122						for i in j..n {
2123							assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2124						}
2125					} else {
2126						for i in 0..=j {
2127							assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2128						}
2129						for i in j + 1..n {
2130							assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2131						}
2132					}
2133				}
2134			} else {
2135				for j in 0..n {
2136					if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
2137						for i in 0..=j {
2138							assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2139						}
2140						for i in j + 1..n {
2141							assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2142						}
2143					} else {
2144						for i in 0..j {
2145							assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2146						}
2147						for i in j..n {
2148							assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2149						}
2150					}
2151				}
2152			}
2153		}
2154	}
2155
2156	#[test]
2157	fn test_triangular() {
2158		use BlockStructure::*;
2159		let structures = [
2160			Rectangular,
2161			TriangularLower,
2162			TriangularUpper,
2163			StrictTriangularLower,
2164			StrictTriangularUpper,
2165			UnitTriangularLower,
2166			UnitTriangularUpper,
2167		];
2168
2169		for dst in structures {
2170			for lhs in structures {
2171				for rhs in structures {
2172					#[cfg(not(miri))]
2173					let big = 100;
2174
2175					#[cfg(miri)]
2176					let big = 31;
2177					for _ in 0..3 {
2178						let m = rand::random::<usize>() % big;
2179						let mut n = rand::random::<usize>() % big;
2180						let mut k = rand::random::<usize>() % big;
2181
2182						match (!dst.is_dense(), !lhs.is_dense(), !rhs.is_dense()) {
2183							(true, true, _) | (true, _, true) | (_, true, true) => {
2184								n = m;
2185								k = m;
2186							},
2187							_ => (),
2188						}
2189
2190						if !dst.is_dense() {
2191							n = m;
2192						}
2193
2194						if !lhs.is_dense() {
2195							k = m;
2196						}
2197
2198						if !rhs.is_dense() {
2199							k = n;
2200						}
2201
2202						run_test_problem(m, n, k, dst, lhs, rhs);
2203					}
2204				}
2205			}
2206		}
2207	}
2208}