1use super::temp_mat_scratch;
4use crate::col::ColRef;
5use crate::internal_prelude::*;
6use crate::mat::{MatMut, MatRef};
7use crate::row::RowRef;
8use crate::utils::bound::Dim;
9use crate::utils::simd::SimdCtx;
10use crate::{Conj, ContiguousFwd, Par, Shape};
11use core::mem::MaybeUninit;
12use dyn_stack::{MemBuffer, MemStack};
13use equator::assert;
14use faer_macros::math;
15use faer_traits::{ByRef, ComplexField, Conjugate};
16use pulp::Simd;
17use reborrow::*;
18
// problem-size threshold (presumably compared against m * n * k) below which a
// nano-gemm path is preferred — NOTE(review): the use site is not visible in this
// chunk; confirm semantics at the caller
const NANO_GEMM_THRESHOLD: usize = 16 * 16 * 16;
20
21pub(crate) mod internal;
22
23pub mod triangular;
26
27mod matmul_shared {
28 use super::*;
29
30 pub const NC: usize = 2048;
31 pub const KC: usize = 128;
32
33 pub struct SimdLaneCount<T: ComplexField> {
34 pub __marker: core::marker::PhantomData<fn() -> T>,
35 }
36 impl<T: ComplexField> pulp::WithSimd for SimdLaneCount<T> {
37 type Output = usize;
38
39 #[inline(always)]
40 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
41 let _ = simd;
42 core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
43 }
44 }
45
46 pub struct MicroKernelShape<T: ComplexField> {
47 pub __marker: core::marker::PhantomData<fn() -> T>,
48 }
49
50 impl<T: ComplexField> MicroKernelShape<T> {
51 pub const IS_1X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
52 pub const IS_2X1: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 1;
53 pub const IS_2X2: bool = Self::MAX_MR_DIV_N == 2 && Self::MAX_NR == 2;
54 pub const MAX_MR_DIV_N: usize = Self::SHAPE.0;
55 pub const MAX_NR: usize = Self::SHAPE.1;
56 pub const SHAPE: (usize, usize) = {
57 if const { size_of::<T>() / size_of::<T::Unit>() <= 2 } {
58 (2, 2)
59 } else if const { size_of::<T>() / size_of::<T::Unit>() == 4 } {
60 (2, 1)
61 } else {
62 (1, 1)
63 }
64 };
65 }
66}
67
68mod matmul_vertical {
69 use super::*;
70 use matmul_shared::*;
71
	/// dispatch payload for the column-major micro-kernel:
	/// `MR_DIV_N` SIMD registers per accumulator column, `NR` accumulator columns
	struct Ukr<'a, const MR_DIV_N: usize, const NR: usize, T: ComplexField> {
		// destination tile, contiguous along rows
		dst: MatMut<'a, T, usize, usize, ContiguousFwd>,
		// lhs tile, contiguous along rows
		a: MatRef<'a, T, usize, usize, ContiguousFwd>,
		// rhs tile, arbitrary strides
		b: MatRef<'a, T, usize, usize>,
		conj_lhs: Conj,
		conj_rhs: Conj,
		// scale applied to the accumulated product
		alpha: &'a T,
		// add into `dst` or overwrite it
		beta: Accum,
	}
81
	// computes dst = beta(dst) + alpha * op(a) * op(b) on a tile of at most
	// MR_DIV_N simd registers by NR columns, with the row tail masked
	impl<const MR_DIV_N: usize, const NR: usize, T: ComplexField> pulp::WithSimd for Ukr<'_, MR_DIV_N, NR, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			let Self {
				dst,
				a,
				b,
				conj_lhs,
				conj_rhs,
				alpha,
				beta,
			} = self;

			// bind runtime dimensions to branded lifetimes for checked indexing
			with_dim!(M, a.nrows());
			with_dim!(N, b.ncols());
			with_dim!(K, a.ncols());
			let a = a.as_shape(M, K);
			let b = b.as_shape(K, N);
			let mut dst = dst.as_shape_mut(M, N);

			// force-masked ctx: the last register load/store is always masked,
			// so `tail` is always present here
			let simd = SimdCtx::<T, S>::new_force_mask(T::simd_ctx(simd), M);
			let (_, body, tail) = simd.indices();
			let tail = tail.unwrap();

			// register accumulator tile: NR columns of MR_DIV_N registers
			let mut local_acc = [[simd.zero(); MR_DIV_N]; NR];

			// when the conjugation flags agree, the plain product is computed and
			// a single conjugation is applied at the end if needed; otherwise the
			// lhs operand of the fma is conjugated per iteration
			if conj_lhs == conj_rhs {
				for depth in K.indices() {
					// load a full column strip of `a`; the last register uses the
					// masked tail index
					let mut a_uninit = [MaybeUninit::<T::SimdVec<S>>::uninit(); MR_DIV_N];

					for (dst, src) in core::iter::zip(&mut a_uninit, body.clone()) {
						*dst = MaybeUninit::new(simd.read(a.col(depth), src));
					}
					a_uninit[MR_DIV_N - 1] = MaybeUninit::new(simd.read(a.col(depth), tail));

					// SAFETY-adjacent note: every element of `a_uninit` was written above
					// (body fills the leading registers, the tail write covers the last)
					let a: [T::SimdVec<S>; MR_DIV_N] =
						unsafe { crate::hacks::transmute::<[MaybeUninit<T::SimdVec<S>>; MR_DIV_N], [T::SimdVec<S>; MR_DIV_N]>(a_uninit) };

					for j in N.indices() {
						// broadcast b[depth, j] and fma into column j of the accumulator
						let b = simd.splat(&b[(depth, j)]);

						for i in 0..MR_DIV_N {
							let local_acc = &mut local_acc[*j][i];
							*local_acc = simd.mul_add(b, a[i], *local_acc);
						}
					}
				}
			} else {
				// NOTE(review): this branch duplicates the one above except for
				// `conj_mul_add`; kept separate, presumably so the conj flag is not
				// re-checked inside the hot loop
				for depth in K.indices() {
					let mut a_uninit = [MaybeUninit::<T::SimdVec<S>>::uninit(); MR_DIV_N];

					for (dst, src) in core::iter::zip(&mut a_uninit, body.clone()) {
						*dst = MaybeUninit::new(simd.read(a.col(depth), src));
					}
					a_uninit[MR_DIV_N - 1] = MaybeUninit::new(simd.read(a.col(depth), tail));

					let a: [T::SimdVec<S>; MR_DIV_N] =
						unsafe { crate::hacks::transmute::<[MaybeUninit<T::SimdVec<S>>; MR_DIV_N], [T::SimdVec<S>; MR_DIV_N]>(a_uninit) };

					for j in N.indices() {
						let b = simd.splat(&b[(depth, j)]);

						for i in 0..MR_DIV_N {
							let local_acc = &mut local_acc[*j][i];
							*local_acc = simd.conj_mul_add(b, a[i], *local_acc);
						}
					}
				}
			}

			// fold the deferred lhs conjugation into the accumulator
			if conj_lhs.is_conj() {
				for x in &mut local_acc {
					for x in x {
						*x = simd.conj(*x);
					}
				}
			}

			let alpha = simd.splat(alpha);

			// write-back: scale by alpha, then either accumulate into dst or replace it;
			// the last register of each column goes through the masked tail index
			match beta {
				Accum::Add => {
					for (result, j) in core::iter::zip(&local_acc, N.indices()) {
						for (result, i) in core::iter::zip(result, body.clone()) {
							let mut val = simd.read(dst.rb().col(j), i);
							val = simd.mul_add(alpha, *result, val);
							simd.write(dst.rb_mut().col_mut(j), i, val);
						}
						let i = tail;
						let result = &result[MR_DIV_N - 1];

						let mut val = simd.read(dst.rb().col(j), i);
						val = simd.mul_add(alpha, *result, val);
						simd.write(dst.rb_mut().col_mut(j), i, val);
					}
				},
				Accum::Replace => {
					for (result, j) in core::iter::zip(&local_acc, N.indices()) {
						for (result, i) in core::iter::zip(result, body.clone()) {
							let val = simd.mul(alpha, *result);
							simd.write(dst.rb_mut().col_mut(j), i, val);
						}

						let i = tail;
						let result = &result[MR_DIV_N - 1];

						let val = simd.mul(alpha, *result);
						simd.write(dst.rb_mut().col_mut(j), i, val);
					}
				},
			}
		}
	}
197
	/// blocked matrix multiply for column-major `dst` and `lhs`:
	/// `dst = beta(dst) + alpha * op(lhs) * op(rhs)`, tiled over columns (NC),
	/// depth (KC), and micro-kernel tiles of `mr x nr`
	#[math]
	pub fn matmul_simd<'M, 'N, 'K, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		let dst = dst.as_dyn_mut();
		let lhs = lhs.as_dyn();
		let rhs = rhs.as_dyn();

		let (m, n) = dst.shape();
		let k = lhs.ncols();

		let arch = T::Arch::default();

		// lane count depends on the dispatched instruction set
		let lane_count = arch.dispatch(SimdLaneCount::<T> {
			__marker: core::marker::PhantomData,
		});

		let nr = MicroKernelShape::<T>::MAX_NR;
		let mr_div_n = MicroKernelShape::<T>::MAX_MR_DIV_N;
		// micro-tile height in scalar elements
		let mr = mr_div_n * lane_count;

		let mut col_outer = 0;
		while col_outer < n {
			let n_chunk = Ord::min(n - col_outer, NC);
			let mut beta = beta;

			let mut depth = 0;
			while depth < k {
				let k_chunk = Ord::min(k - depth, KC);

				// one job = one micro-kernel tile at (row, col_outer + col_inner)
				let job = |row: usize, col_inner: usize| {
					let nrows = Ord::min(m - row, mr);
					// number of simd registers needed for the (possibly partial) tile height
					let ukr_i = nrows.div_ceil(lane_count);
					let ncols = Ord::min(n_chunk - col_inner, nr);
					let ukr_j = ncols;

					// SAFETY-adjacent note: tiles are disjoint across jobs, so the
					// const_cast aliases distinct submatrices — TODO confirm against
					// the job indexing below
					let dst = unsafe { dst.rb().const_cast() }.submatrix_mut(row, col_outer + col_inner, nrows, ncols);
					let a = lhs.submatrix(row, depth, nrows, k_chunk);
					let b = rhs.submatrix(depth, col_outer + col_inner, k_chunk, ncols);

					macro_rules! call {
						($M: expr, $N: expr) => {
							arch.dispatch(Ukr::<'_, $M, $N, T> {
								dst,
								a,
								b,
								conj_lhs,
								conj_rhs,
								alpha,
								beta,
							})
						};
					}
					// monomorphized dispatch on the (registers, columns) tile shape
					if const { MicroKernelShape::<T>::IS_2X2 } {
						match (ukr_i, ukr_j) {
							(2, 2) => call!(2, 2),
							(1, 2) => call!(1, 2),
							(2, 1) => call!(2, 1),
							(1, 1) => call!(1, 1),
							_ => unreachable!(),
						}
					} else if const { MicroKernelShape::<T>::IS_2X1 } {
						match (ukr_i, ukr_j) {
							(2, 1) => call!(2, 1),
							(1, 1) => call!(1, 1),
							_ => unreachable!(),
						}
					} else if const { MicroKernelShape::<T>::IS_1X1 } {
						call!(1, 1)
					} else {
						unreachable!()
					}
				};

				// flat job index decomposed as (row tile, column tile within the panel)
				let job_count = m.div_ceil(mr) * n_chunk.div_ceil(nr);
				let d = n_chunk.div_ceil(nr);
				match par {
					Par::Seq => {
						for job_idx in 0..job_count {
							let col_inner = nr * (job_idx % d);
							let row = mr * (job_idx / d);
							job(row, col_inner);
						}
					},
					#[cfg(feature = "rayon")]
					Par::Rayon(nthreads) => {
						let nthreads = nthreads.get();
						use rayon::prelude::*;

						// work-stealing via a shared atomic counter: each rayon worker
						// pulls the next unclaimed job index until all are done
						let job_idx = core::sync::atomic::AtomicUsize::new(0);

						(0..nthreads).into_par_iter().for_each(|_| {
							loop {
								let job_idx = job_idx.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
								if job_idx < job_count {
									let col_inner = nr * (job_idx % d);
									let row = mr * (job_idx / d);
									job(row, col_inner);
								} else {
									return;
								}
							}
						});
					},
				}

				// after the first depth panel the partial products must accumulate
				beta = Accum::Add;
				depth += k_chunk;
			}
			col_outer += n_chunk;
		}
	}
317}
318
319mod matmul_horizontal {
320 use super::*;
321 use matmul_shared::*;
322
	/// dispatch payload for the dot-product-style micro-kernel:
	/// `MR x NR` scalar accumulators, each reduced along the depth dimension
	struct Ukr<'a, const MR: usize, const NR: usize, T: ComplexField> {
		// destination tile, arbitrary strides
		dst: MatMut<'a, T, usize, usize>,
		// lhs tile, contiguous along columns (rows are simd-readable after transpose)
		a: MatRef<'a, T, usize, usize, isize, ContiguousFwd>,
		// rhs tile, contiguous along rows
		b: MatRef<'a, T, usize, usize, ContiguousFwd, isize>,
		conj_lhs: Conj,
		conj_rhs: Conj,
		// scale applied to the reduced products
		alpha: &'a T,
		// add into `dst` or overwrite it
		beta: Accum,
	}
332
	// computes an MR x NR tile of dst as horizontal (depth-wise) simd reductions:
	// each dst[i, j] is the inner product of row i of `a` with column j of `b`
	impl<const MR: usize, const NR: usize, T: ComplexField> pulp::WithSimd for Ukr<'_, MR, NR, T> {
		type Output = ();

		#[math]
		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			let Self {
				dst,
				a,
				b,
				conj_lhs,
				conj_rhs,
				alpha,
				beta,
			} = self;

			// bind runtime dimensions to branded lifetimes for checked indexing
			with_dim!(M, a.nrows());
			with_dim!(N, b.ncols());
			with_dim!(K, a.ncols());
			let a = a.as_shape(M, K);
			let b = b.as_shape(K, N);
			let mut dst = dst.as_shape_mut(M, N);

			// simd iteration runs along the depth dimension K here
			let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), K);
			let (_, body, tail) = simd.indices();

			// one simd accumulator per (i, j) entry of the tile
			let mut local_acc = [[simd.zero(); MR]; NR];
			let mut is = [M.idx(0usize); MR];
			let mut js = [N.idx(0usize); NR];

			// precompute the branded row/column indices of the tile
			for (idx, i) in is.iter_mut().enumerate() {
				*i = M.idx(idx);
			}
			for (idx, j) in js.iter_mut().enumerate() {
				*j = N.idx(idx);
			}

			// when the conjugation flags agree, compute the plain product and
			// conjugate once at the end if needed; otherwise conjugate the lhs
			// operand inside the fma
			if conj_lhs == conj_rhs {
				macro_rules! do_it {
					($depth: expr) => {{
						let depth = $depth;
						// read one simd register per row of `a` / column of `b`
						let a = is.map(
							#[inline(always)]
							|i| simd.read(a.row(i).transpose(), depth),
						);
						let b = js.map(
							#[inline(always)]
							|j| simd.read(b.col(j), depth),
						);

						for i in 0..MR {
							for j in 0..NR {
								local_acc[j][i] = simd.mul_add(b[j], a[i], local_acc[j][i]);
							}
						}
					}};
				}
				for depth in body {
					do_it!(depth);
				}
				if let Some(depth) = tail {
					do_it!(depth);
				}
			} else {
				// NOTE(review): duplicate of the branch above except for `conj_mul_add`
				macro_rules! do_it {
					($depth: expr) => {{
						let depth = $depth;
						let a = is.map(
							#[inline(always)]
							|i| simd.read(a.row(i).transpose(), depth),
						);
						let b = js.map(
							#[inline(always)]
							|j| simd.read(b.col(j), depth),
						);

						for i in 0..MR {
							for j in 0..NR {
								local_acc[j][i] = simd.conj_mul_add(b[j], a[i], local_acc[j][i]);
							}
						}
					}};
				}
				for depth in body {
					do_it!(depth);
				}
				if let Some(depth) = tail {
					do_it!(depth);
				}
			}

			// fold the deferred lhs conjugation into the accumulators
			if conj_lhs.is_conj() {
				for x in &mut local_acc {
					for x in x {
						*x = simd.conj(*x);
					}
				}
			}
			// horizontal reduction: each simd accumulator collapses to one scalar
			let result = local_acc;
			let result = result.map(
				#[inline(always)]
				|result| {
					result.map(
						#[inline(always)]
						|result| simd.reduce_sum(result),
					)
				},
			);

			// scalar write-back (arithmetic rewritten by #[math])
			let alpha = copy(*alpha);
			match beta {
				Accum::Add => {
					for (result, j) in core::iter::zip(&result, js) {
						for (result, i) in core::iter::zip(result, is) {
							dst[(i, j)] = alpha * *result + dst[(i, j)];
						}
					}
				},
				Accum::Replace => {
					for (result, j) in core::iter::zip(&result, js) {
						for (result, i) in core::iter::zip(result, is) {
							dst[(i, j)] = alpha * *result;
						}
					}
				},
			}
		}
	}
461
462 #[math]
463 pub fn matmul_simd<'M, 'N, 'K, T: ComplexField>(
464 dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
465 beta: Accum,
466 lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
467 conj_lhs: Conj,
468 rhs: MatRef<'_, T, Dim<'K>, Dim<'N>, ContiguousFwd, isize>,
469 conj_rhs: Conj,
470 alpha: &T,
471 par: Par,
472 ) {
473 let dst = dst.as_dyn_mut();
474 let lhs = lhs.as_dyn();
475 let rhs = rhs.as_dyn();
476
477 let (m, n) = dst.shape();
478 let k = lhs.ncols();
479
480 let nr = MicroKernelShape::<T>::MAX_NR;
481 let mr = MicroKernelShape::<T>::MAX_MR_DIV_N;
482
483 let arch = T::Arch::default();
484
485 let lane_count = arch.dispatch(SimdLaneCount::<T> {
486 __marker: core::marker::PhantomData,
487 });
488 let kc = KC * lane_count;
489
490 let mut col_outer = 0;
491 while col_outer < n {
492 let n_chunk = Ord::min(n - col_outer, NC);
493
494 let mut beta = beta;
495 let mut depth = 0;
496 while depth < k {
497 let k_chunk = Ord::min(k - depth, kc);
498
499 let job = |row: usize, col_inner: usize| {
500 let nrows = Ord::min(m - row, mr);
501 let ukr_i = nrows;
502 let ncols = Ord::min(n_chunk - col_inner, nr);
503 let ukr_j = ncols;
504
505 let dst = unsafe { dst.rb().const_cast() }.submatrix_mut(row, col_outer + col_inner, nrows, ncols);
506 let a = lhs.submatrix(row, depth, nrows, k_chunk);
507 let b = rhs.submatrix(depth, col_outer + col_inner, k_chunk, ncols);
508
509 macro_rules! call {
510 ($M: expr, $N: expr) => {
511 arch.dispatch(Ukr::<'_, $M, $N, T> {
512 dst,
513 a,
514 b,
515 conj_lhs,
516 conj_rhs,
517 alpha,
518 beta,
519 })
520 };
521 }
522 if const { MicroKernelShape::<T>::IS_2X2 } {
523 match (ukr_i, ukr_j) {
524 (2, 2) => call!(2, 2),
525 (1, 2) => call!(1, 2),
526 (2, 1) => call!(2, 1),
527 (1, 1) => call!(1, 1),
528 _ => unreachable!(),
529 }
530 } else if const { MicroKernelShape::<T>::IS_2X1 } {
531 match (ukr_i, ukr_j) {
532 (2, 1) => call!(2, 1),
533 (1, 1) => call!(1, 1),
534 _ => unreachable!(),
535 }
536 } else if const { MicroKernelShape::<T>::IS_1X1 } {
537 call!(1, 1)
538 } else {
539 unreachable!()
540 }
541 };
542
543 let job_count = m.div_ceil(mr) * n.div_ceil(nr);
544 let d = n.div_ceil(nr);
545 match par {
546 Par::Seq => {
547 for job_idx in 0..job_count {
548 let col_inner = nr * (job_idx % d);
549 let row = mr * (job_idx / d);
550 job(row, col_inner);
551 }
552 },
553 #[cfg(feature = "rayon")]
554 Par::Rayon(nthreads) => {
555 let nthreads = nthreads.get();
556 use rayon::prelude::*;
557
558 let job_idx = core::sync::atomic::AtomicUsize::new(0);
559
560 (0..nthreads).into_par_iter().for_each(|_| {
561 loop {
562 let job_idx = job_idx.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
563 if job_idx < job_count {
564 let col_inner = nr * (job_idx % d);
565 let row = mr * (job_idx / d);
566 job(row, col_inner);
567 } else {
568 return;
569 }
570 }
571 });
572 },
573 }
574
575 beta = Accum::Add;
576 depth += k_chunk;
577 }
578 col_outer += n_chunk;
579 }
580 }
581}
582
583pub mod dot {
585 use super::*;
586 use faer_traits::SimdArch;
587
	/// inner product of a row vector and a column vector with optional
	/// per-operand conjugation; uses the simd path when both operands are
	/// contiguous, falling back to the scalar schoolbook loop otherwise
	pub fn inner_prod<K: Shape, T: ComplexField>(lhs: RowRef<T, K>, conj_lhs: Conj, rhs: ColRef<T, K>, conj_rhs: Conj) -> T {
		#[math]
		pub fn imp<'K, T: ComplexField>(lhs: RowRef<T, Dim<'K>>, conj_lhs: Conj, rhs: ColRef<T, Dim<'K>>, conj_rhs: Conj) -> T {
			if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
				// the simd path requires unit stride on both operands
				if let (Some(lhs), Some(rhs)) = (lhs.try_as_row_major(), rhs.try_as_col_major()) {
					inner_prod_slice::<T>(lhs.ncols(), lhs.transpose(), conj_lhs, rhs, conj_rhs)
				} else {
					inner_prod_schoolbook(lhs, conj_lhs, rhs, conj_rhs)
				}
			} else {
				inner_prod_schoolbook(lhs, conj_lhs, rhs, conj_rhs)
			}
		}

		// bind the runtime length to a branded dimension before dispatching
		with_dim!(K, lhs.ncols().unbound());

		imp(lhs.as_col_shape(K), conj_lhs, rhs.as_row_shape(K), conj_rhs)
	}
607
	/// simd inner product of two contiguous column vectors; handles the four
	/// conjugation combinations by conjugating the lhs inside the kernel and
	/// the final scalar afterwards
	#[inline(always)]
	#[math]
	fn inner_prod_slice<'K, T: ComplexField>(
		len: Dim<'K>,
		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		conj_rhs: Conj,
	) -> T {
		// dispatch payload carrying the operands into the simd context
		struct Impl<'a, 'K, T: ComplexField> {
			len: Dim<'K>,
			lhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
			conj_lhs: Conj,
			rhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
			conj_rhs: Conj,
		}
		impl<'a, 'K, T: ComplexField> pulp::WithSimd for Impl<'_, '_, T> {
			type Output = T;

			#[inline(always)]
			fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
				let Self {
					len,
					lhs,
					conj_lhs,
					rhs,
					conj_rhs,
				} = self;

				let simd = SimdCtx::new(T::simd_ctx(simd), len);

				// equal flags -> plain product; differing flags -> conjugate lhs
				let mut tmp = if conj_lhs == conj_rhs {
					inner_prod_no_conj_simd::<T, S>(simd, lhs, rhs)
				} else {
					inner_prod_conj_lhs_simd::<T, S>(simd, lhs, rhs)
				};

				// a final conjugation fixes up the remaining combinations
				// (e.g. (No, Yes) was computed as conj(lhs) * rhs above)
				if conj_rhs == Conj::Yes {
					tmp = conj(tmp);
				}
				tmp
			}
		}

		dispatch!(
			Impl {
				len,
				lhs,
				rhs,
				conj_lhs,
				conj_rhs
			},
			Impl,
			T
		)
	}
664
	/// simd reduction of `sum(lhs[i] * rhs[i])` with no conjugation,
	/// using four accumulators to hide fma latency
	#[inline(always)]
	pub(crate) fn inner_prod_no_conj_simd<'K, T: ComplexField, S: Simd>(
		simd: SimdCtx<'K, T, S>,
		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
	) -> T {
		let mut acc0 = simd.zero();
		let mut acc1 = simd.zero();
		let mut acc2 = simd.zero();
		let mut acc3 = simd.zero();

		// head/tail are the (possibly masked) edge registers; idx4 is the
		// 4-way unrolled body, idx the unrolled-loop remainder
		let (head, idx4, idx, tail) = simd.batch_indices::<4>();

		if let Some(i0) = head {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.mul_add(l0, r0, acc0);
		}
		for [i0, i1, i2, i3] in idx4 {
			let l0 = simd.read(lhs, i0);
			let l1 = simd.read(lhs, i1);
			let l2 = simd.read(lhs, i2);
			let l3 = simd.read(lhs, i3);

			let r0 = simd.read(rhs, i0);
			let r1 = simd.read(rhs, i1);
			let r2 = simd.read(rhs, i2);
			let r3 = simd.read(rhs, i3);

			acc0 = simd.mul_add(l0, r0, acc0);
			acc1 = simd.mul_add(l1, r1, acc1);
			acc2 = simd.mul_add(l2, r2, acc2);
			acc3 = simd.mul_add(l3, r3, acc3);
		}
		for i0 in idx {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.mul_add(l0, r0, acc0);
		}
		if let Some(i0) = tail {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.mul_add(l0, r0, acc0);
		}
		// pairwise combine, then horizontal reduction to a scalar
		acc0 = simd.add(acc0, acc1);
		acc2 = simd.add(acc2, acc3);
		acc0 = simd.add(acc0, acc2);

		simd.reduce_sum(acc0)
	}
718
	/// simd reduction of `sum(conj(lhs[i]) * rhs[i])`; structure mirrors
	/// `inner_prod_no_conj_simd` with `conj_mul_add` in place of `mul_add`
	#[inline(always)]
	pub(crate) fn inner_prod_conj_lhs_simd<'K, T: ComplexField, S: Simd>(
		simd: SimdCtx<'K, T, S>,
		lhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
	) -> T {
		let mut acc0 = simd.zero();
		let mut acc1 = simd.zero();
		let mut acc2 = simd.zero();
		let mut acc3 = simd.zero();

		// head/tail edge registers, 4-way unrolled body, and remainder
		let (head, idx4, idx, tail) = simd.batch_indices::<4>();

		if let Some(i0) = head {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.conj_mul_add(l0, r0, acc0);
		}
		for [i0, i1, i2, i3] in idx4 {
			let l0 = simd.read(lhs, i0);
			let l1 = simd.read(lhs, i1);
			let l2 = simd.read(lhs, i2);
			let l3 = simd.read(lhs, i3);

			let r0 = simd.read(rhs, i0);
			let r1 = simd.read(rhs, i1);
			let r2 = simd.read(rhs, i2);
			let r3 = simd.read(rhs, i3);

			acc0 = simd.conj_mul_add(l0, r0, acc0);
			acc1 = simd.conj_mul_add(l1, r1, acc1);
			acc2 = simd.conj_mul_add(l2, r2, acc2);
			acc3 = simd.conj_mul_add(l3, r3, acc3);
		}
		for i0 in idx {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.conj_mul_add(l0, r0, acc0);
		}
		if let Some(i0) = tail {
			let l0 = simd.read(lhs, i0);
			let r0 = simd.read(rhs, i0);

			acc0 = simd.conj_mul_add(l0, r0, acc0);
		}
		// pairwise combine, then horizontal reduction to a scalar
		acc0 = simd.add(acc0, acc1);
		acc2 = simd.add(acc2, acc3);
		acc0 = simd.add(acc0, acc2);

		simd.reduce_sum(acc0)
	}
772
	/// scalar fallback inner product handling all conjugation combinations;
	/// arithmetic operators are rewritten by `#[math]`
	#[math]
	pub(crate) fn inner_prod_schoolbook<'K, T: ComplexField>(
		lhs: RowRef<'_, T, Dim<'K>>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>>,
		conj_rhs: Conj,
	) -> T {
		let mut acc = zero();

		for k in lhs.ncols().indices() {
			// real types: conjugation is a no-op, skip the flag dispatch
			if try_const! { T::IS_REAL } {
				acc = lhs[k] * rhs[k] + acc;
			} else {
				match (conj_lhs, conj_rhs) {
					(Conj::No, Conj::No) => {
						acc = lhs[k] * rhs[k] + acc;
					},
					(Conj::No, Conj::Yes) => {
						acc = lhs[k] * conj(rhs[k]) + acc;
					},
					(Conj::Yes, Conj::No) => {
						acc = conj(lhs[k]) * rhs[k] + acc;
					},
					(Conj::Yes, Conj::Yes) => {
						// conj distributes over the product: conj(a) * conj(b) = conj(a * b)
						acc = conj(lhs[k] * rhs[k]) + acc;
					},
				}
			}
		}

		acc
	}
805}
806
807mod matvec_rowmajor {
808 use super::*;
809 use crate::col::ColMut;
810 use faer_traits::SimdArch;
811
	/// matrix-vector product for a row-major matrix: each dst[i] is computed as
	/// an inner product of row i of `lhs` with `rhs`;
	/// `dst = beta(dst) + alpha * op(lhs) * op(rhs)`
	#[math]
	pub fn matvec<'M, 'K, T: ComplexField>(
		dst: ColMut<'_, T, Dim<'M>>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>, ContiguousFwd>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		core::assert!(try_const! { T::SIMD_CAPABILITIES.is_simd() });
		// small problems are not worth the threading overhead
		let size = *lhs.nrows() * *lhs.ncols();
		let par = if size < 256 * 256usize { Par::Seq } else { par };

		match par {
			Par::Seq => {
				// dispatch payload carrying the operands into the simd context
				pub struct Impl<'a, 'M, 'K, T: ComplexField> {
					dst: ColMut<'a, T, Dim<'M>>,
					beta: Accum,
					lhs: MatRef<'a, T, Dim<'M>, Dim<'K>, isize, ContiguousFwd>,
					conj_lhs: Conj,
					rhs: ColRef<'a, T, Dim<'K>, ContiguousFwd>,
					conj_rhs: Conj,
					alpha: &'a T,
				}

				impl<'a, 'M, 'K, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'K, T> {
					type Output = ();

					#[inline(always)]
					fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
						let Self {
							dst,
							beta,
							lhs,
							conj_lhs,
							rhs,
							conj_rhs,
							alpha,
						} = self;
						let simd = T::simd_ctx(simd);
						let mut dst = dst;

						// simd iteration runs along the depth dimension K
						let K = lhs.ncols();
						let simd = SimdCtx::new(simd, K);
						for i in lhs.nrows().indices() {
							let dst = &mut dst[i];
							let lhs = lhs.row(i);
							let rhs = rhs;
							// same conjugation strategy as inner_prod_slice:
							// conjugate lhs in-kernel, fix up with a final conj
							let mut tmp = if conj_lhs == conj_rhs {
								dot::inner_prod_no_conj_simd::<T, S>(simd, lhs.transpose(), rhs)
							} else {
								dot::inner_prod_conj_lhs_simd::<T, S>(simd, lhs.transpose(), rhs)
							};

							if conj_rhs == Conj::Yes {
								tmp = conj(tmp);
							}
							tmp = *alpha * tmp;
							if let Accum::Add = beta {
								tmp = *dst + tmp;
							}
							*dst = tmp;
						}
					}
				}

				dispatch!(
					Impl {
						dst,
						beta,
						lhs,
						conj_lhs,
						rhs,
						conj_rhs,
						alpha,
					},
					Impl,
					T
				);
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				let nthreads = nthreads.get();

				use rayon::prelude::*;
				// rows are independent, so dst and lhs are partitioned by rows
				// and each partition recurses sequentially
				dst.par_partition_mut(nthreads)
					.zip_eq(lhs.par_row_partition(nthreads))
					.for_each(|(dst, lhs)| {
						make_guard!(M);
						let nrows = dst.nrows().bind(M);
						let dst = dst.as_row_shape_mut(nrows);
						let lhs = lhs.as_row_shape(nrows);

						matvec(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, Par::Seq);
					})
			},
		}
	}
912}
913
914mod matvec_colmajor {
915 use super::*;
916 use crate::col::ColMut;
917 use crate::linalg::temp_mat_uninit;
918 use crate::mat::AsMatMut;
919 use crate::utils::bound::IdxInc;
920 use crate::{unzip, zip};
921 use faer_traits::SimdArch;
922
	/// matrix-vector product for a column-major matrix: dst is updated as an
	/// axpy per column of `lhs`; `dst = beta(dst) + alpha * op(lhs) * op(rhs)`
	#[math]
	pub fn matvec<'M, 'K, T: ComplexField>(
		dst: ColMut<'_, T, Dim<'M>, ContiguousFwd>,
		beta: Accum,
		lhs: MatRef<'_, T, Dim<'M>, Dim<'K>, ContiguousFwd, isize>,
		conj_lhs: Conj,
		rhs: ColRef<'_, T, Dim<'K>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		core::assert!(try_const! { T::SIMD_CAPABILITIES.is_simd() });
		// small problems are not worth the threading overhead
		let size = *lhs.nrows() * *lhs.ncols();
		let par = if size < 256 * 256usize { Par::Seq } else { par };

		match par {
			Par::Seq => {
				// dispatch payload carrying the operands into the simd context
				pub struct Impl<'a, 'M, 'K, T: ComplexField> {
					dst: ColMut<'a, T, Dim<'M>, ContiguousFwd>,
					beta: Accum,
					lhs: MatRef<'a, T, Dim<'M>, Dim<'K>, ContiguousFwd, isize>,
					conj_lhs: Conj,
					rhs: ColRef<'a, T, Dim<'K>>,
					conj_rhs: Conj,
					alpha: &'a T,
				}

				impl<'a, 'M, 'K, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'K, T> {
					type Output = ();

					#[inline(always)]
					fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
						let Self {
							dst,
							beta,
							lhs,
							conj_lhs,
							rhs,
							conj_rhs,
							alpha,
						} = self;

						let simd = T::simd_ctx(simd);

						// simd iteration runs along the row dimension M
						let M = lhs.nrows();
						let simd = SimdCtx::<T, S>::new(simd, M);
						let (head, body, tail) = simd.indices();

						let mut dst = dst;
						// Replace is implemented as "zero dst, then always accumulate"
						match beta {
							Accum::Add => {},
							Accum::Replace => {
								let mut dst = dst.rb_mut();
								if let Some(i) = head {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
								for i in body.clone() {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
								if let Some(i) = tail {
									simd.write(dst.rb_mut(), i, simd.zero());
								}
							},
						}

						// dst += (alpha * rhs[j]) * lhs[:, j] for each column j
						for j in lhs.ncols().indices() {
							let mut dst = dst.rb_mut();
							let lhs = lhs.col(j);
							let rhs = &rhs[j];
							// fold the rhs conjugation and alpha into a single scalar
							let rhs = if conj_rhs == Conj::Yes { conj(*rhs) } else { copy(*rhs) };
							let rhs = rhs * *alpha;

							let vrhs = simd.splat(&rhs);
							if conj_lhs == Conj::Yes {
								if let Some(i) = head {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
								for i in body.clone() {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
								if let Some(i) = tail {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.conj_mul_add(x, vrhs, y));
								}
							} else {
								if let Some(i) = head {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
								for i in body.clone() {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
								if let Some(i) = tail {
									let y = simd.read(dst.rb(), i);
									let x = simd.read(lhs, i);
									simd.write(dst.rb_mut(), i, simd.mul_add(x, vrhs, y));
								}
							}
						}
					}
				}

				dispatch!(
					Impl {
						dst,
						lhs,
						conj_lhs,
						rhs,
						conj_rhs,
						beta,
						alpha,
					},
					Impl,
					T
				)
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				use rayon::prelude::*;
				let nthreads = nthreads.get();
				// each thread accumulates its column partition into a private
				// scratch column, which are summed into dst afterwards
				let mut mem = MemBuffer::new(temp_mat_scratch::<T>(dst.nrows().unbound(), nthreads));
				let stack = MemStack::new(&mut mem);

				let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(dst.nrows(), nthreads, stack) };
				let mut tmp = tmp.as_mat_mut().try_as_col_major_mut().unwrap();

				let mut dst = dst;
				// zero-width subview used to apply `beta` to dst without
				// contributing any product terms
				make_guard!(Z);
				let Z = 0usize.bind(Z);
				let z = IdxInc::new_checked(0, lhs.ncols());

				tmp.rb_mut()
					.par_col_iter_mut()
					.zip_eq(lhs.par_col_partition(nthreads))
					.zip_eq(rhs.par_partition(nthreads))
					.for_each(|((dst, lhs), rhs)| {
						make_guard!(K);
						let K = lhs.ncols().bind(K);
						let lhs = lhs.as_col_shape(K);
						let rhs = rhs.as_row_shape(K);

						matvec(dst, Accum::Replace, lhs, conj_lhs, rhs, conj_rhs, alpha, Par::Seq);
					});

				// apply beta to dst (k = 0, alpha = 0), then sum the scratch columns
				matvec(
					dst.rb_mut(),
					beta,
					lhs.subcols(z, Z),
					conj_lhs,
					rhs.subrows(z, Z),
					conj_rhs,
					&zero(),
					Par::Seq,
				);
				for j in 0..nthreads {
					zip!(dst.rb_mut(), tmp.rb().col(j)).for_each(|unzip!(dst, src)| *dst = *dst + *src)
				}
			},
		}
	}
1091}
1092
1093mod rank_update {
1094 use super::*;
1095 use crate::assert;
1096
	/// sequential rank-1 update: `dst = beta(dst) + alpha * op(lhs) * op(rhs)`
	/// where `lhs` is a column and `rhs` a row; simd iteration runs down each
	/// column of `dst`
	#[math]
	fn rank_update_imp<'M, 'N, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: RowRef<'_, T, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
	) {
		assert!(T::SIMD_CAPABILITIES.is_simd());

		// dispatch payload carrying the operands into the simd context
		struct Impl<'a, 'M, 'N, T: ComplexField> {
			dst: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
			beta: Accum,
			lhs: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
			conj_lhs: Conj,
			rhs: RowRef<'a, T, Dim<'N>>,
			conj_rhs: Conj,
			alpha: &'a T,
		}

		impl<T: ComplexField> pulp::WithSimd for Impl<'_, '_, '_, T> {
			type Output = ();

			#[inline(always)]
			fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
				let Self {
					mut dst,
					beta,
					lhs,
					conj_lhs,
					rhs,
					conj_rhs,
					alpha,
				} = self;

				let (m, n) = dst.shape();
				let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), m);

				let (head, body, tail) = simd.indices();

				// per column: dst[:, j] op= (alpha * op(rhs[j])) * op(lhs)
				for j in n.indices() {
					let mut dst = dst.rb_mut().col_mut(j);

					// fold alpha and the rhs conjugation into a single broadcast scalar
					let rhs = *alpha * conj_rhs.apply_rt(&rhs[j]);
					let rhs = simd.splat(&rhs);

					// four structurally identical loops: {conj, no-conj} x {add, replace}
					if conj_lhs.is_conj() {
						match beta {
							Accum::Add => {
								if let Some(i) = head {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.conj_mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
							Accum::Replace => {
								if let Some(i) = head {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let acc = simd.conj_mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
						}
					} else {
						match beta {
							Accum::Add => {
								if let Some(i) = head {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let mut acc = simd.read(dst.rb(), i);
									acc = simd.mul_add(simd.read(lhs, i), rhs, acc);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
							Accum::Replace => {
								if let Some(i) = head {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								for i in body.clone() {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
								if let Some(i) = tail {
									let acc = simd.mul(simd.read(lhs, i), rhs);
									simd.write(dst.rb_mut(), i, acc);
								}
							},
						}
					}
				}
			}
		}

		dispatch!(
			Impl {
				dst,
				lhs,
				conj_lhs,
				rhs,
				conj_rhs,
				beta,
				alpha,
			},
			Impl,
			T
		)
	}
1232
	/// rank-1 update entry point; the parallel path splits `dst` and `rhs` by
	/// columns (columns are independent) and runs the sequential kernel on each
	/// partition
	#[math]
	pub fn rank_update<'M, 'N, T: ComplexField>(
		dst: MatMut<'_, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		beta: Accum,
		lhs: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
		conj_lhs: Conj,
		rhs: RowRef<'_, T, Dim<'N>>,
		conj_rhs: Conj,
		alpha: &T,
		par: Par,
	) {
		match par {
			Par::Seq => {
				rank_update_imp(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha);
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				let nthreads = nthreads.get();
				use rayon::prelude::*;
				dst.par_col_partition_mut(nthreads)
					.zip(rhs.par_partition(nthreads))
					.for_each(|(dst, rhs)| {
						// rebind the partition width to a fresh branded dimension
						with_dim!(N, dst.ncols());
						rank_update_imp(dst.as_col_shape_mut(N), beta, lhs, conj_lhs, rhs.as_col_shape(N), conj_rhs, alpha);
					});
			},
		}
	}
1261}
1262
/// Core implementation of the general matrix product: computes
/// `alpha · op(lhs) · op(rhs)` (where `op` optionally conjugates) and, per `beta`,
/// adds it to or replaces `dst`.
///
/// Dispatch order:
/// 1. trivial sizes (`M == 0`/`N == 0` → no-op, `K == 0` → zero product),
/// 2. matrix–vector (`N == 1`, `M == 1`) and rank-1 (`K == 1`) special cases,
/// 3. native `f32`/`f64`/`c32`/`c64` elements: `nano_gemm` for small products, the
///    specialized x86 backend when AVX2/AVX-512 is detected, otherwise the `gemm` crate,
/// 4. generic SIMD kernels for other `ComplexField` types with compatible layouts,
/// 5. a scalar schoolbook fallback.
#[math]
fn matmul_imp<'M, 'N, 'K, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
	beta: Accum,
	lhs: MatRef<'_, T, Dim<'M>, Dim<'K>>,
	conj_lhs: Conj,
	rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
	conj_rhs: Conj,
	alpha: &T,
	par: Par,
) {
	let mut dst = dst;

	let M = dst.nrows();
	let N = dst.ncols();
	let K = lhs.ncols();
	// empty output: nothing to write
	if *M == 0 || *N == 0 {
		return;
	}
	// empty accumulation dimension: the product is zero, only `beta` matters
	if *K == 0 {
		if beta == Accum::Replace {
			dst.fill(zero());
		}
		return;
	}

	let mut lhs = lhs;
	let mut rhs = rhs;

	if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
		// normalize dst's strides to be nonnegative, mirroring the operands so the
		// product stays the same
		if dst.row_stride() < 0 {
			dst = dst.reverse_rows_mut();
			lhs = lhs.reverse_rows();
		}
		if dst.col_stride() < 0 {
			dst = dst.reverse_cols_mut();
			rhs = rhs.reverse_cols();
		}
		// also normalize the shared K axis of lhs/rhs
		if lhs.col_stride() < 0 {
			lhs = lhs.reverse_cols();
			rhs = rhs.reverse_rows();
		}

		// matrix × column-vector special case
		if dst.ncols().unbound() == 1 {
			let first = dst.ncols().check(0);
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matvec_colmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}

			if let (Some(rhs), Some(lhs)) = (rhs.try_as_col_major(), lhs.try_as_row_major()) {
				matvec_rowmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}
		}
		// row-vector × matrix: transpose the whole problem ((lhs·rhs)ᵀ = rhsᵀ·lhsᵀ)
		// so the matvec kernels above can be reused
		if dst.nrows().unbound() == 1 {
			let mut dst = dst.rb_mut().transpose_mut();
			let (rhs, lhs) = (lhs.transpose(), rhs.transpose());
			let (conj_rhs, conj_lhs) = (conj_lhs, conj_rhs);

			let first = dst.ncols().check(0);
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matvec_colmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}

			if let (Some(rhs), Some(lhs)) = (rhs.try_as_col_major(), lhs.try_as_row_major()) {
				matvec_rowmajor::matvec(dst.col_mut(first), beta, lhs, conj_lhs, rhs.col(first), conj_rhs, alpha, par);
				return;
			}
		}
		// K == 1: the product is an outer product, i.e. a rank-1 update
		if *K == 1 {
			let z = K.idx(0);

			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				rank_update::rank_update(dst, beta, lhs.col(z), conj_lhs, rhs.row(z), conj_rhs, alpha, par);
				return;
			}

			// row-major dst: transpose the update (dstᵀ = rhsᵀ · lhsᵀ)
			if let (Some(dst), Some(rhs)) = (dst.rb_mut().try_as_row_major_mut(), rhs.try_as_row_major()) {
				let dst = dst.transpose_mut();
				let rhs = rhs.row(z).transpose();
				let lhs = lhs.col(z).transpose();
				rank_update::rank_update(dst, beta, rhs, conj_rhs, lhs, conj_lhs, alpha, par);
				return;
			}
		}
		// dispatch to the external backends for a concrete native scalar type
		macro_rules! gemm_call {
			($kind: ident, $ty: ty, $nanogemm: ident) => {
				unsafe {
					// SAFETY: only expanded under the matching `T::IS_NATIVE_*` guard
					// below, so `T` is layout-identical to `$ty`
					let dst = core::mem::transmute_copy::<MatMut<'_, T, Dim<'M>, Dim<'N>>, MatMut<'_, $ty, Dim<'M>, Dim<'N>>>(&dst);
					let lhs = core::mem::transmute_copy::<MatRef<'_, T, Dim<'M>, Dim<'K>>, MatRef<'_, $ty, Dim<'M>, Dim<'K>>>(&lhs);
					let rhs = core::mem::transmute_copy::<MatRef<'_, T, Dim<'K>, Dim<'N>>, MatRef<'_, $ty, Dim<'K>, Dim<'N>>>(&rhs);
					let alpha = *core::mem::transmute_copy::<&T, &$ty>(&alpha);

					// small products: fully specialized nano kernels
					if (*M).saturating_mul(*N).saturating_mul(*K) <= NANO_GEMM_THRESHOLD {
						nano_gemm::planless::$nanogemm(
							*M,
							*N,
							*K,
							dst.as_ptr_mut(),
							dst.row_stride(),
							dst.col_stride(),
							lhs.as_ptr(),
							lhs.row_stride(),
							lhs.col_stride(),
							rhs.as_ptr(),
							rhs.row_stride(),
							rhs.col_stride(),
							// `beta` as a scalar: 0 for replace, 1 for accumulate
							match beta {
								Accum::Replace => core::mem::zeroed(),
								Accum::Add => 1.0.into(),
							},
							alpha,
							conj_lhs == Conj::Yes,
							conj_rhs == Conj::Yes,
						);
						return;
					} else {
						#[cfg(all(target_arch = "x86_64", feature = "std"))]
						{
							use private_gemm_x86::*;

							// pick the widest instruction set available at runtime
							let feat = if std::arch::is_x86_feature_detected!("avx512f") {
								Some(InstrSet::Avx512)
							} else if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
								Some(InstrSet::Avx256)
							} else {
								None
							};

							if let Some(feat) = feat {
								gemm(
									DType::$kind,
									IType::U64,
									feat,
									*M,
									*N,
									*K,
									dst.as_ptr_mut() as *mut (),
									dst.row_stride(),
									dst.col_stride(),
									core::ptr::null(),
									core::ptr::null(),
									DstKind::Full,
									// translate our accumulation mode into the backend's
									match beta {
										$crate::Accum::Replace => Accum::Replace,
										$crate::Accum::Add => Accum::Add,
									},
									lhs.as_ptr() as *const (),
									lhs.row_stride(),
									lhs.col_stride(),
									conj_lhs == Conj::Yes,
									core::ptr::null(),
									0,
									rhs.as_ptr() as *const (),
									rhs.row_stride(),
									rhs.col_stride(),
									conj_rhs == Conj::Yes,
									&raw const alpha as *const (),
									par.degree(),
								);
								return;
							}
						}

						{
							// portable fallback: the `gemm` crate (note: it takes the
							// column stride before the row stride)
							gemm::gemm(
								M.unbound(),
								N.unbound(),
								K.unbound(),
								dst.as_ptr_mut(),
								dst.col_stride(),
								dst.row_stride(),
								beta != Accum::Replace,
								lhs.as_ptr(),
								lhs.col_stride(),
								lhs.row_stride(),
								rhs.as_ptr(),
								rhs.col_stride(),
								rhs.row_stride(),
								match beta {
									Accum::Replace => core::mem::zeroed(),
									Accum::Add => 1.0.into(),
								},
								alpha,
								false,
								conj_lhs == Conj::Yes,
								conj_rhs == Conj::Yes,
								match par {
									Par::Seq => gemm::Parallelism::None,
									#[cfg(feature = "rayon")]
									Par::Rayon(nthreads) => gemm::Parallelism::Rayon(nthreads.get()),
								},
							);

							return;
						}
					}
				};
			};
		}

		// concrete native element types: route to the optimized backends above
		if try_const! { T::IS_NATIVE_F64 } {
			gemm_call!(F64, f64, execute_f64);
		}
		if try_const! { T::IS_NATIVE_C64 } {
			gemm_call!(C64, num_complex::Complex<f64>, execute_c64);
		}
		if try_const! { T::IS_NATIVE_F32 } {
			gemm_call!(F32, f32, execute_f32);
		}
		if try_const! { T::IS_NATIVE_C32 } {
			gemm_call!(C32, num_complex::Complex<f32>, execute_c32);
		}

		// non-native element types: use the generic SIMD kernels when a compatible
		// storage layout is available
		if const { !(T::IS_NATIVE_F64 || T::IS_NATIVE_F32 || T::IS_NATIVE_C64 || T::IS_NATIVE_C32) } {
			if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
				matmul_vertical::matmul_simd(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, par);
				return;
			}
			// row-major dst: transpose the whole product and reuse the vertical kernel
			if let (Some(dst), Some(rhs)) = (dst.rb_mut().try_as_row_major_mut(), rhs.try_as_row_major()) {
				matmul_vertical::matmul_simd(
					dst.transpose_mut(),
					beta,
					rhs.transpose(),
					conj_rhs,
					lhs.transpose(),
					conj_lhs,
					alpha,
					par,
				);
				return;
			}
			if let (Some(lhs), Some(rhs)) = (lhs.try_as_row_major(), rhs.try_as_col_major()) {
				matmul_horizontal::matmul_simd(dst, beta, lhs, conj_lhs, rhs, conj_rhs, alpha, par);
				return;
			}
		}
	}

	// scalar fallback: one schoolbook inner product per output element
	match par {
		Par::Seq => {
			for j in dst.ncols().indices() {
				for i in dst.nrows().indices() {
					let dst = &mut dst[(i, j)];

					let mut acc = dot::inner_prod_schoolbook(lhs.row(i), conj_lhs, rhs.col(j), conj_rhs);
					acc = *alpha * acc;
					if let Accum::Add = beta {
						acc = *dst + acc;
					}
					*dst = acc;
				}
			}
		},
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();

			// split the m·n output elements into one contiguous range per thread
			let m = *dst.nrows();
			let n = *dst.ncols();
			let task_count = m * n;
			let task_per_thread = task_count.msrv_div_ceil(nthreads);

			let dst = dst.rb();
			(0..nthreads).into_par_iter().for_each(|tid| {
				let task_idx = tid * task_per_thread;
				if task_idx >= task_count {
					return;
				}
				let ntasks = Ord::min(task_per_thread, task_count - task_idx);

				for ij in 0..ntasks {
					let ij = task_idx + ij;
					// tasks enumerate the output in column-major order
					let i = dst.nrows().check(ij % m);
					let j = dst.ncols().check(ij / m);

					// SAFETY-relevant: each (i, j) belongs to exactly one task, so
					// this mutable view never writes the same element twice
					let mut dst = unsafe { dst.const_cast() };
					let dst = &mut dst[(i, j)];

					let mut acc = dot::inner_prod_schoolbook(lhs.row(i), conj_lhs, rhs.col(j), conj_rhs);
					acc = *alpha * acc;

					if let Accum::Add = beta {
						acc = *dst + acc;
					}
					*dst = acc;
				}
			});
		},
	}
}
1557
/// Checks that the operand shapes are compatible for `dst = lhs · rhs`:
/// `dst` is `M×N`, `lhs` is `M×K`, `rhs` is `K×N`.
///
/// # Panics
/// Panics (at the caller's location, via `#[track_caller]`) if any dimension
/// pair disagrees.
#[track_caller]
fn precondition<M: Shape, N: Shape, K: Shape>(dst_nrows: M, dst_ncols: N, lhs_nrows: M, lhs_ncols: K, rhs_nrows: K, rhs_ncols: N) {
	assert!(all(dst_nrows == lhs_nrows, dst_ncols == rhs_ncols, lhs_ncols == rhs_nrows,));
}
1562
1563#[track_caller]
1604#[inline]
1605pub fn matmul<T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>, M: Shape, N: Shape, K: Shape>(
1606 dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
1607 beta: Accum,
1608 lhs: impl AsMatRef<T = LhsT, Rows = M, Cols = K>,
1609 rhs: impl AsMatRef<T = RhsT, Rows = K, Cols = N>,
1610 alpha: T,
1611 par: Par,
1612) {
1613 let mut dst = dst;
1614 let dst = dst.as_mat_mut();
1615 let lhs = lhs.as_mat_ref();
1616 let rhs = rhs.as_mat_ref();
1617
1618 precondition(dst.nrows(), dst.ncols(), lhs.nrows(), lhs.ncols(), rhs.nrows(), rhs.ncols());
1619
1620 make_guard!(M);
1621 make_guard!(N);
1622 make_guard!(K);
1623 let M = dst.nrows().bind(M);
1624 let N = dst.ncols().bind(N);
1625 let K = lhs.ncols().bind(K);
1626
1627 matmul_imp(
1628 dst.as_dyn_stride_mut().as_shape_mut(M, N),
1629 beta,
1630 lhs.as_dyn_stride().canonical().as_shape(M, K),
1631 try_const! { Conj::get::<LhsT>() },
1632 rhs.as_dyn_stride().canonical().as_shape(K, N),
1633 try_const! { Conj::get::<RhsT>() },
1634 &alpha,
1635 par,
1636 );
1637}
1638
1639#[track_caller]
1690#[inline]
1691pub fn matmul_with_conj<T: ComplexField, M: Shape, N: Shape, K: Shape>(
1692 dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
1693 beta: Accum,
1694 lhs: impl AsMatRef<T = T, Rows = M, Cols = K>,
1695 conj_lhs: Conj,
1696 rhs: impl AsMatRef<T = T, Rows = K, Cols = N>,
1697 conj_rhs: Conj,
1698 alpha: T,
1699 par: Par,
1700) {
1701 let mut dst = dst;
1702 let dst = dst.as_mat_mut();
1703 let lhs = lhs.as_mat_ref();
1704 let rhs = rhs.as_mat_ref();
1705
1706 precondition(dst.nrows(), dst.ncols(), lhs.nrows(), lhs.ncols(), rhs.nrows(), rhs.ncols());
1707
1708 make_guard!(M);
1709 make_guard!(N);
1710 make_guard!(K);
1711 let M = dst.nrows().bind(M);
1712 let N = dst.ncols().bind(N);
1713 let K = lhs.ncols().bind(K);
1714
1715 matmul_imp(
1716 dst.as_dyn_stride_mut().as_shape_mut(M, N),
1717 beta,
1718 lhs.as_dyn_stride().canonical().as_shape(M, K),
1719 conj_lhs,
1720 rhs.as_dyn_stride().canonical().as_shape(K, N),
1721 conj_rhs,
1722 &alpha,
1723 par,
1724 );
1725}
1726
1727#[cfg(test)]
1728mod tests {
1729 use crate::c32;
1730 use std::num::NonZeroUsize;
1731
1732 use super::triangular::{BlockStructure, DiagonalKind};
1733 use super::*;
1734 use crate::assert;
1735 use crate::mat::{Mat, MatMut, MatRef};
1736 use crate::stats::prelude::*;
1737
	/// Exhaustive `matmul` correctness sweep: sizes around the blocking boundaries,
	/// reversed row/column strides, operand layouts, conjugation flags, parallelism,
	/// accumulation modes and scaling factors, checked against a naive reference.
	#[test]
	#[ignore = "takes too long"]
	fn test_matmul() {
		let rng = &mut StdRng::seed_from_u64(0);

		// skip entirely on CI, even when ignored tests are force-run
		if option_env!("CI") == Some("true") {
			return;
		}

		let betas = [Accum::Replace, Accum::Add];

		// full parameter sweep natively; reduced sweep under miri
		#[cfg(not(miri))]
		let bools = [false, true];
		#[cfg(not(miri))]
		let alphas = [c32::ONE, c32::ZERO, c32::new(21.04, -12.13)];
		#[cfg(not(miri))]
		let par = [Par::Seq, Par::Rayon(NonZeroUsize::new(4).unwrap())];
		#[cfg(not(miri))]
		let conjs = [Conj::Yes, Conj::No];

		#[cfg(miri)]
		let bools = [true];
		#[cfg(miri)]
		let alphas = [c32::new(0.3218, -1.217489)];
		#[cfg(miri)]
		let par = [Par::Seq];
		#[cfg(miri)]
		let conjs = [Conj::Yes];

		// sizes straddling 128 (blocking) and 16 (SIMD width) boundaries
		let big0 = 127;
		let big1 = 128;
		let big2 = 129;

		let mid0 = 15;
		let mid1 = 16;
		let mid2 = 17;
		for (m, n, k) in [
			(big0, big1, 5),
			(big1, big0, 5),
			(big0, big2, 5),
			(big2, big0, 5),
			(mid0, mid0, 5),
			(mid1, mid1, 5),
			(mid2, mid2, 5),
			(mid0, mid1, 5),
			(mid1, mid0, 5),
			(mid0, mid2, 5),
			(mid2, mid0, 5),
			(mid0, 1, 1),
			(1, mid0, 1),
			(1, 1, mid0),
			(1, mid0, mid0),
			(mid0, 1, mid0),
			(mid0, mid0, 1),
			(1, 1, 1),
		] {
			let distribution = ComplexDistribution::new(StandardNormal, StandardNormal);
			let a = CwiseMatDistribution {
				nrows: m,
				ncols: k,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);
			let b = CwiseMatDistribution {
				nrows: k,
				ncols: n,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);
			let mut acc_init = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: distribution,
			}
			.rand::<Mat<c32>>(rng);

			let a = a.as_ref();
			let b = b.as_ref();

			for reverse_acc_cols in bools {
				for reverse_acc_rows in bools {
					for reverse_b_cols in bools {
						for reverse_b_rows in bools {
							for reverse_a_cols in bools {
								for reverse_a_rows in bools {
									for a_colmajor in bools {
										for b_colmajor in bools {
											for acc_colmajor in bools {
												// NOTE(review): transposing twice is the identity, so these
												// shadows leave `a`/`b` unchanged and `a_colmajor`/`b_colmajor`
												// do not actually vary the storage order — presumably a
												// transposed *copy* was intended; verify
												let a = if a_colmajor { a } else { a.transpose() };
												let mut a = if a_colmajor { a } else { a.transpose() };

												let b = if b_colmajor { b } else { b.transpose() };
												let mut b = if b_colmajor { b } else { b.transpose() };

												if reverse_a_rows {
													a = a.reverse_rows();
												}
												if reverse_a_cols {
													a = a.reverse_cols();
												}
												if reverse_b_rows {
													b = b.reverse_rows();
												}
												if reverse_b_cols {
													b = b.reverse_cols();
												}
												for conj_a in conjs {
													for conj_b in conjs {
														for par in par {
															for beta in betas {
																for alpha in alphas {
																	test_matmul_impl(
																		reverse_acc_cols,
																		reverse_acc_rows,
																		acc_colmajor,
																		m,
																		n,
																		conj_a,
																		conj_b,
																		par,
																		beta,
																		alpha,
																		acc_init.as_mut(),
																		a,
																		b,
																	);
																}
															}
														}
													}
												}
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
1881
	/// Naive O(m·n·k) schoolbook reference for `matmul_with_conj`; serves as the
	/// ground truth that the optimized paths are checked against.
	#[math]
	fn matmul_with_conj_fallback<T: Copy + ComplexField>(
		acc: MatMut<'_, T>,
		a: MatRef<'_, T>,
		conj_a: Conj,
		b: MatRef<'_, T>,
		conj_b: Conj,
		beta: Accum,
		alpha: T,
	) {
		let m = acc.nrows();
		let n = acc.ncols();
		let k = a.ncols();

		// computes the single output element (idx % m, idx / m)
		let job = |idx: usize| {
			let i = idx % m;
			let j = idx / m;
			let acc = acc.rb().submatrix(i, j, 1, 1);
			// each idx maps to a distinct element and the jobs below run
			// sequentially, so this mutable view does not alias
			let mut acc = unsafe { acc.const_cast() };

			let mut local_acc = zero::<T>();
			for depth in 0..k {
				let a = &a[(i, depth)];
				let b = &b[(depth, j)];
				local_acc = local_acc
					+ match conj_a {
						Conj::Yes => conj(*a),
						Conj::No => copy(*a),
					} * match conj_b {
						Conj::Yes => conj(*b),
						Conj::No => copy(*b),
					}
			}
			// apply the accumulation mode and scaling
			match beta {
				Accum::Add => acc[(0, 0)] = acc[(0, 0)] + local_acc * alpha,
				Accum::Replace => acc[(0, 0)] = local_acc * alpha,
			}
		};

		for i in 0..m * n {
			job(i);
		}
	}
1925
	/// Runs one `matmul` configuration against the schoolbook reference, also
	/// exercising the generic vertical/horizontal SIMD kernels directly.
	#[math]
	fn test_matmul_impl(
		reverse_acc_cols: bool,
		reverse_acc_rows: bool,
		acc_colmajor: bool,
		m: usize,
		n: usize,
		conj_a: Conj,
		conj_b: Conj,
		par: Par,
		beta: Accum,
		alpha: c32,
		acc_init: MatMut<c32>,
		a: MatRef<c32>,
		b: MatRef<c32>,
	) {
		// NOTE(review): as in `test_matmul`, this double transpose is a no-op, so
		// `acc_colmajor` does not actually change the accumulator layout — verify intent
		let acc = if acc_colmajor { acc_init } else { acc_init.transpose_mut() };

		let mut acc = if acc_colmajor { acc } else { acc.transpose_mut() };
		if reverse_acc_rows {
			acc = acc.reverse_rows_mut();
		}
		if reverse_acc_cols {
			acc = acc.reverse_cols_mut();
		}

		// ground truth computed by the naive reference implementation
		let mut target = acc.rb().to_owned();
		matmul_with_conj_fallback(target.as_mut(), a, conj_a, b, conj_b, beta, alpha);
		let target = target.rb();

		// 1) vertical SIMD kernel (column-major dst and lhs)
		{
			let mut acc = acc.cloned();
			let a = a.cloned();

			{
				with_dim!(M, a.nrows());
				with_dim!(N, b.ncols());
				with_dim!(K, a.ncols());
				let mut acc = acc.rb_mut().as_shape_mut(M, N);
				let a = a.as_shape(M, K);
				let b = b.as_shape(K, N);

				matmul_vertical::matmul_simd(
					acc.rb_mut().try_as_col_major_mut().unwrap(),
					beta,
					a.try_as_col_major().unwrap(),
					conj_a,
					b,
					conj_b,
					&alpha,
					par,
				);
			}
			for j in 0..n {
				for i in 0..m {
					let acc = acc[(i, j)];
					let target = target[(i, j)];
					assert!(abs(acc.re - target.re) < 1e-3);
					assert!(abs(acc.im - target.im) < 1e-3);
				}
			}
		}
		// 2) horizontal SIMD kernel (row-major lhs, column-major rhs)
		{
			let mut acc = acc.cloned();
			// clone through a transpose to obtain a row-major copy of `a`
			let a = a.transpose().cloned();
			let a = a.transpose();

			let b = b.cloned();

			{
				with_dim!(M, a.nrows());
				with_dim!(N, b.ncols());
				with_dim!(K, a.ncols());
				let mut acc = acc.rb_mut().as_shape_mut(M, N);
				let a = a.as_shape(M, K);
				let b = b.as_shape(K, N);

				matmul_horizontal::matmul_simd(
					acc.rb_mut(),
					beta,
					a.try_as_row_major().unwrap(),
					conj_a,
					b.try_as_col_major().unwrap(),
					conj_b,
					&alpha,
					par,
				);
			}
			for j in 0..n {
				for i in 0..m {
					let acc = acc[(i, j)];
					let target = target[(i, j)];
					assert!(abs(acc.re - target.re) < 1e-3);
					assert!(abs(acc.im - target.im) < 1e-3);
				}
			}
		}

		// 3) public entry point with the original (possibly strided) accumulator
		matmul_with_conj(acc.rb_mut(), beta, a, conj_a, b, conj_b, alpha, par);
		for j in 0..n {
			for i in 0..m {
				let acc = acc[(i, j)];
				let target = target[(i, j)];
				assert!(abs(acc.re - target.re) < 1e-3);
				assert!(abs(acc.im - target.im) < 1e-3);
			}
		}
	}
2034
2035 fn generate_structured_matrix(is_dst: bool, nrows: usize, ncols: usize, structure: BlockStructure) -> Mat<f64> {
2036 let rng = &mut StdRng::seed_from_u64(0);
2037 let mut mat = CwiseMatDistribution {
2038 nrows,
2039 ncols,
2040 dist: StandardNormal,
2041 }
2042 .rand::<Mat<f64>>(rng);
2043
2044 if !is_dst {
2045 let kind = structure.diag_kind();
2046 if structure.is_lower() {
2047 for j in 0..ncols {
2048 for i in 0..j {
2049 mat[(i, j)] = 0.0;
2050 }
2051 }
2052 } else if structure.is_upper() {
2053 for j in 0..ncols {
2054 for i in j + 1..nrows {
2055 mat[(i, j)] = 0.0;
2056 }
2057 }
2058 }
2059
2060 match kind {
2061 triangular::DiagonalKind::Zero => {
2062 for i in 0..nrows {
2063 mat[(i, i)] = 0.0;
2064 }
2065 },
2066 triangular::DiagonalKind::Unit => {
2067 for i in 0..nrows {
2068 mat[(i, i)] = 1.0;
2069 }
2070 },
2071 triangular::DiagonalKind::Generic => (),
2072 }
2073 }
2074 mat
2075 }
2076
2077 fn run_test_problem(m: usize, n: usize, k: usize, dst_structure: BlockStructure, lhs_structure: BlockStructure, rhs_structure: BlockStructure) {
2078 let mut dst = generate_structured_matrix(true, m, n, dst_structure);
2079 let mut dst_target = dst.as_ref().to_owned();
2080 let dst_orig = dst.as_ref().to_owned();
2081 let lhs = generate_structured_matrix(false, m, k, lhs_structure);
2082 let rhs = generate_structured_matrix(false, k, n, rhs_structure);
2083
2084 for par in [Par::Seq, Par::rayon(8)] {
2085 triangular::matmul_with_conj(
2086 dst.as_mut(),
2087 dst_structure,
2088 Accum::Replace,
2089 lhs.as_ref(),
2090 lhs_structure,
2091 Conj::No,
2092 rhs.as_ref(),
2093 rhs_structure,
2094 Conj::No,
2095 2.5,
2096 par,
2097 );
2098
2099 matmul_with_conj(
2100 dst_target.as_mut(),
2101 Accum::Replace,
2102 lhs.as_ref(),
2103 Conj::No,
2104 rhs.as_ref(),
2105 Conj::No,
2106 2.5,
2107 par,
2108 );
2109
2110 if dst_structure.is_dense() {
2111 for j in 0..n {
2112 for i in 0..m {
2113 assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2114 }
2115 }
2116 } else if dst_structure.is_lower() {
2117 for j in 0..n {
2118 if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
2119 for i in 0..j {
2120 assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2121 }
2122 for i in j..n {
2123 assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2124 }
2125 } else {
2126 for i in 0..=j {
2127 assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2128 }
2129 for i in j + 1..n {
2130 assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2131 }
2132 }
2133 }
2134 } else {
2135 for j in 0..n {
2136 if matches!(dst_structure.diag_kind(), DiagonalKind::Generic) {
2137 for i in 0..=j {
2138 assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2139 }
2140 for i in j + 1..n {
2141 assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2142 }
2143 } else {
2144 for i in 0..j {
2145 assert!((dst[(i, j)] - dst_target[(i, j)]).abs() < 1e-10);
2146 }
2147 for i in j..n {
2148 assert!((dst[(i, j)] - dst_orig[(i, j)]).abs() < 1e-10);
2149 }
2150 }
2151 }
2152 }
2153 }
2154 }
2155
2156 #[test]
2157 fn test_triangular() {
2158 use BlockStructure::*;
2159 let structures = [
2160 Rectangular,
2161 TriangularLower,
2162 TriangularUpper,
2163 StrictTriangularLower,
2164 StrictTriangularUpper,
2165 UnitTriangularLower,
2166 UnitTriangularUpper,
2167 ];
2168
2169 for dst in structures {
2170 for lhs in structures {
2171 for rhs in structures {
2172 #[cfg(not(miri))]
2173 let big = 100;
2174
2175 #[cfg(miri)]
2176 let big = 31;
2177 for _ in 0..3 {
2178 let m = rand::random::<usize>() % big;
2179 let mut n = rand::random::<usize>() % big;
2180 let mut k = rand::random::<usize>() % big;
2181
2182 match (!dst.is_dense(), !lhs.is_dense(), !rhs.is_dense()) {
2183 (true, true, _) | (true, _, true) | (_, true, true) => {
2184 n = m;
2185 k = m;
2186 },
2187 _ => (),
2188 }
2189
2190 if !dst.is_dense() {
2191 n = m;
2192 }
2193
2194 if !lhs.is_dense() {
2195 k = m;
2196 }
2197
2198 if !rhs.is_dense() {
2199 k = n;
2200 }
2201
2202 run_test_problem(m, n, k, dst, lhs, rhs);
2203 }
2204 }
2205 }
2206 }
2207 }
2208}