1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::swap_cols_idx;
4use linalg::householder::{self, HouseholderInfo};
5use pulp::Simd;
6
7pub use super::super::no_pivoting::factor::recommended_blocksize;
8
/// Fused SIMD kernel used by the delayed-update path of `qr_in_place_unblocked`.
///
/// For every column `j` of `B11`, a single pass over the rows
/// 1. adds `A10 * dot[j]` (with `dot[j]` as it is on entry) to the column and
///    writes it back, and
/// 2. accumulates the updated column against `B10` through `conj_mul_add`
///    (presumably `conj(B10) . B11[:, j]` — confirm against `SimdCtx`).
///
/// After the pass, with `s` the horizontally reduced accumulator:
/// - `tmp     = B01[j] + l * dot[j]`
/// - `dot[j]  = (tmp + s) * (-tau_inv)`
/// - `B01[j]  = tmp + dot[j]`
/// - `norm[j] = sqrt(|norm[j]|^2 - |B01[j]|^2)`
///
/// Columns are processed four at a time; the leftover (at most three) columns
/// are processed one at a time. `align` is forwarded to `SimdCtx::new_align`.
#[math]
fn update_mat_and_dot_simd<T: ComplexField>(
    norm: RowMut<'_, T>,
    dot: RowMut<'_, T>,
    B01: RowMut<'_, T>,
    B11: MatMut<'_, T, usize, usize, ContiguousFwd>,
    A10: ColRef<'_, T, usize, ContiguousFwd>,
    B10: ColRef<'_, T, usize, ContiguousFwd>,
    l: T,
    tau_inv: T::Real,
    align: usize,
) {
    // Carrier for the arguments, re-typed with branded dimensions `'M` (rows)
    // and `'N` (columns), so that the kernel can be run through `dispatch!`.
    struct Impl<'a, 'M, 'N, T: ComplexField> {
        norm: RowMut<'a, T, Dim<'N>>,
        dot: RowMut<'a, T, Dim<'N>>,
        B01: RowMut<'a, T, Dim<'N>>,
        B11: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
        A10: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
        B10: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
        l: T,
        tau_inv: T::Real,
        align: usize,
    }
    impl<'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'_, 'M, 'N, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let Self {
                mut norm,
                mut dot,
                B01: mut u,
                mut B11,
                A10,
                B10,
                l,
                tau_inv,
                align,
            } = self;

            let m = B11.nrows();
            let n = B11.ncols();

            let simd = SimdCtx::<'_, T, S>::new_align(T::simd_ctx(simd), m, align);

            // Row index batches: optional head, groups of four, single
            // batches, optional tail.
            let (head, body4, body1, tail) = simd.batch_indices::<4>();

            let mut j = n.indices();

            loop {
                // Pull up to four column indices per iteration.
                match (j.next(), j.next(), j.next(), j.next()) {
                    // Four columns available: share the loads of `A10` and
                    // `B10` between them.
                    (Some(j0), Some(j1), Some(j2), Some(j3)) => {
                        // Values of `dot` on entry; `dot[j]` is overwritten below.
                        let b0 = copy(dot[j0]);
                        let b1 = copy(dot[j1]);
                        let b2 = copy(dot[j2]);
                        let b3 = copy(dot[j3]);

                        let rhs0 = simd.splat(&b0);
                        let rhs1 = simd.splat(&b1);
                        let rhs2 = simd.splat(&b2);
                        let rhs3 = simd.splat(&b3);

                        let mut acc0 = simd.zero();
                        let mut acc1 = simd.zero();
                        let mut acc2 = simd.zero();
                        let mut acc3 = simd.zero();

                        // One row batch: update the four columns in place and
                        // extend their accumulators.
                        macro_rules! do_it {
                            ($i: expr) => {{
                                let i = $i;

                                let lhs0 = simd.read(A10, i);
                                let lhs1 = simd.read(B10, i);

                                let mut dst0 = simd.read(B11.rb().col(j0), i);
                                dst0 = simd.mul_add(lhs0, rhs0, dst0);
                                acc0 = simd.conj_mul_add(lhs1, dst0, acc0);
                                simd.write(B11.rb_mut().col_mut(j0), i, dst0);

                                let mut dst1 = simd.read(B11.rb().col(j1), i);
                                dst1 = simd.mul_add(lhs0, rhs1, dst1);
                                acc1 = simd.conj_mul_add(lhs1, dst1, acc1);
                                simd.write(B11.rb_mut().col_mut(j1), i, dst1);

                                let mut dst2 = simd.read(B11.rb().col(j2), i);
                                dst2 = simd.mul_add(lhs0, rhs2, dst2);
                                acc2 = simd.conj_mul_add(lhs1, dst2, acc2);
                                simd.write(B11.rb_mut().col_mut(j2), i, dst2);

                                let mut dst3 = simd.read(B11.rb().col(j3), i);
                                dst3 = simd.mul_add(lhs0, rhs3, dst3);
                                acc3 = simd.conj_mul_add(lhs1, dst3, acc3);
                                simd.write(B11.rb_mut().col_mut(j3), i, dst3);
                            }};
                        }

                        if let Some(i) = head {
                            do_it!(i);
                        }

                        for [i0, i1, i2, i3] in body4.clone() {
                            do_it!(i0);
                            do_it!(i1);
                            do_it!(i2);
                            do_it!(i3);
                        }
                        for i in body1.clone() {
                            do_it!(i);
                        }
                        if let Some(i) = tail {
                            do_it!(i);
                        }

                        // Scalar epilogue per column: finish the entry of `u`
                        // (= `B01`), store the new `dot` and downdate `norm`.
                        let tmp = u[j0] + l * b0;
                        let d0 = mul_real(tmp + simd.reduce_sum(acc0), -tau_inv);
                        u[j0] = tmp + d0;
                        dot[j0] = d0;
                        norm[j0] = from_real(sqrt(abs2(norm[j0]) - abs2(u[j0])));

                        let tmp = u[j1] + l * b1;
                        let d1 = mul_real(tmp + simd.reduce_sum(acc1), -tau_inv);
                        u[j1] = tmp + d1;
                        dot[j1] = d1;
                        norm[j1] = from_real(sqrt(abs2(norm[j1]) - abs2(u[j1])));

                        let tmp = u[j2] + l * b2;
                        let d2 = mul_real(tmp + simd.reduce_sum(acc2), -tau_inv);
                        u[j2] = tmp + d2;
                        dot[j2] = d2;
                        norm[j2] = from_real(sqrt(abs2(norm[j2]) - abs2(u[j2])));

                        let tmp = u[j3] + l * b3;
                        let d3 = mul_real(tmp + simd.reduce_sum(acc3), -tau_inv);
                        u[j3] = tmp + d3;
                        dot[j3] = d3;
                        norm[j3] = from_real(sqrt(abs2(norm[j3]) - abs2(u[j3])));
                    },
                    // Fewer than four columns left: handle whichever indices
                    // were produced one at a time, then stop.
                    (j0, j1, j2, j3) => {
                        for j0 in [j0, j1, j2, j3].into_iter().flatten() {
                            let b0 = copy(dot[j0]);
                            let rhs0 = simd.splat(&b0);

                            let mut acc0 = simd.zero();

                            // Single-column version of the row-batch update.
                            macro_rules! do_it {
                                ($i: expr) => {{
                                    let i = $i;

                                    let lhs0 = simd.read(A10, i);
                                    let lhs1 = simd.read(B10, i);

                                    let mut dst0 = simd.read(B11.rb().col(j0), i);
                                    dst0 = simd.mul_add(lhs0, rhs0, dst0);
                                    acc0 = simd.conj_mul_add(lhs1, dst0, acc0);
                                    simd.write(B11.rb_mut().col_mut(j0), i, dst0);
                                }};
                            }

                            if let Some(i) = head {
                                do_it!(i);
                            }
                            for [i0, i1, i2, i3] in body4.clone() {
                                do_it!(i0);
                                do_it!(i1);
                                do_it!(i2);
                                do_it!(i3);
                            }

                            for i in body1.clone() {
                                do_it!(i);
                            }
                            if let Some(i) = tail {
                                do_it!(i);
                            }

                            let tmp = u[j0] + l * b0;
                            let d0 = mul_real(tmp + simd.reduce_sum(acc0), -tau_inv);
                            u[j0] = tmp + d0;
                            dot[j0] = d0;
                            norm[j0] = from_real(sqrt(abs2(norm[j0]) - abs2(u[j0])));
                        }
                        break;
                    },
                }
            }
        }
    }

    // Brand the runtime dimensions and re-shape every operand accordingly
    // before dispatching the kernel.
    with_dim!(M, B11.nrows());
    with_dim!(N, B11.ncols());
    dispatch!(
        Impl {
            norm: norm.as_col_shape_mut(N),
            dot: dot.as_col_shape_mut(N),
            B01: B01.as_col_shape_mut(N),
            B11: B11.as_shape_mut(M, N),
            A10: A10.as_row_shape(M),
            B10: B10.as_row_shape(M),
            l,
            tau_inv,
            align
        },
        Impl,
        T
    )
}
220
// NOTE(review): this `#[math]` is separated from the item by a blank line and
// ends up attached to a plain data struct; it looks like a leftover — confirm
// before removing.
#[math]

#[derive(Copy, Clone, Debug)]
/// Tuning parameters for the column pivoted QR factorization.
pub struct ColPivQrParams {
    /// Blocking threshold.
    ///
    /// NOTE(review): not read by any code in this file — confirm where (and
    /// whether) it is consumed.
    pub blocking_threshold: usize,
    /// Parallelism threshold: `qr_in_place_unblocked` switches to `Par::Seq`
    /// once the trailing block has fewer than this many elements
    /// (`(m - k - 1) * (n - k - 1) < par_threshold`).
    pub par_threshold: usize,

    // Hidden field that prevents construction with a struct literal outside
    // of this crate.
    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
234
235impl<T: ComplexField> Auto<T> for ColPivQrParams {
236 #[inline]
237 fn auto() -> Self {
238 Self {
239 blocking_threshold: 48 * 48,
240 par_threshold: 192 * 256,
241 non_exhaustive: NonExhaustive(()),
242 }
243 }
244}
245
246#[track_caller]
247#[math]
248fn qr_in_place_unblocked<'out, I: Index, T: ComplexField>(
249 A: MatMut<'_, T>,
250 H: RowMut<'_, T>,
251 col_perm: &'out mut [I],
252 col_perm_inv: &'out mut [I],
253 par: Par,
254 stack: &mut MemStack,
255 params: Spec<ColPivQrParams, T>,
256) -> (ColPivQrInfo, PermRef<'out, I>) {
257 let m = A.nrows();
258 let n = A.ncols();
259 let size = H.ncols();
260
261 let params = params.config;
262 let mut A = A;
263 let mut H = H;
264 let mut par = par;
265
266 assert!(size == Ord::min(m, n));
267 for j in 0..n {
268 col_perm[j] = I::truncate(j);
269 }
270
271 let mut n_trans = 0;
272
273 'main: {
274 if size == 0 {
275 break 'main;
276 }
277
278 let (mut dot, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
279 let (mut norm, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
280 let _ = stack;
281
282 let mut dot = dot.as_mat_mut().col_mut(0).transpose_mut();
283 let mut norm = norm.as_mat_mut().col_mut(0).transpose_mut();
284
285 let mut best = zero();
286
287 let threshold = sqrt(eps::<T::Real>());
288
289 for j in 0..n {
290 let val = A.rb().col(j).norm_l2();
291 norm[j] = from_real(val);
292
293 if val > best {
294 best = val;
295 }
296 }
297
298 let scale_fwd = copy(best);
299 let scale_bwd = recip(best);
300
301 zip!(A.rb_mut()).for_each(|unzip!(a)| *a = mul_real(*a, scale_bwd));
302
303 for j in 0..n {
304 norm[j] = from_real(real(norm[j]) * scale_bwd);
305 }
306 best = best * scale_bwd;
307 let mut best_threshold = best * threshold;
308
309 'unscale: {
310 for k in 0..size {
311 let mut new_best = zero::<T::Real>();
312 let mut best_col = k;
313 for j in k..n {
314 let val = real(norm[j]);
315 if val > new_best {
316 new_best = val;
317 best_col = j;
318 }
319 }
320
321 let delayed_update = T::SIMD_CAPABILITIES.is_simd() && A.row_stride() == 1 && k > 0 && new_best >= best_threshold;
322
323 if k > 0 && !delayed_update {
324 let (_, _, A10, mut A11) = A.rb_mut().split_at_mut(k, k);
325 let dot = dot.rb().get(k..);
326 let A10 = A10.rb().col(k - 1);
327
328 linalg::matmul::matmul(A11.rb_mut(), Accum::Add, A10, dot, one(), par);
329
330 best = zero();
331 for j in k..n {
332 let val = A11.rb().col(j - k).norm_l2();
333
334 norm[j] = from_real(val);
335
336 if val > best {
337 best = val;
338 best_col = j;
339 }
340 }
341 best_threshold = best * threshold;
342 }
343
344 if best_col != k {
345 n_trans += 1;
346 col_perm.as_mut().swap(best_col, k);
347 swap_cols_idx(A.rb_mut(), best_col, k);
348 swap_cols_idx(dot.rb_mut().as_mat_mut(), best_col, k);
349 swap_cols_idx(norm.rb_mut().as_mat_mut(), best_col, k);
350 }
351
352 let (_, _, A10, mut A11) = A.rb_mut().split_at_mut(k, k);
353 let A10 = A10.rb();
354 let dot0 = dot.rb_mut().get_mut(k..);
355
356 let (mut B00, B01, B10, mut B11) = A11.rb_mut().split_at_mut(1, 1);
357 let B00 = &mut B00[(0, 0)];
358 let mut B01 = B01.row_mut(0);
359 let mut B10 = B10.col_mut(0);
360
361 let l = if delayed_update {
362 let A10 = A10.col(k - 1);
363 copy(A10[0])
364 } else {
365 zero()
366 };
367 let r = copy(dot0[0]);
368
369 let mut dot = dot.rb_mut().get_mut(k + 1..);
370 let mut norm = norm.rb_mut().get_mut(k + 1..);
371
372 if delayed_update {
373 let A10 = A10.col(k - 1).get(1..);
374
375 *B00 = *B00 + l * r;
376 zip!(B10.rb_mut(), A10).for_each(|unzip!(x, y)| {
377 *x = *x + r * *y;
378 });
379 }
380
381 let HouseholderInfo { tau, .. } = householder::make_householder_in_place(B00, B10.rb_mut());
382 let tau_inv = recip(tau);
383 H[k] = from_real(tau);
384
385 if k + 1 == size {
386 if delayed_update {
387 zip!(B01.rb_mut(), dot.rb()).for_each(|unzip!(x, y)| {
388 *x = *x + l * *y;
389 });
390 }
391 break 'unscale;
392 }
393
394 if (m - k - 1) * (n - k - 1) < params.par_threshold {
395 par = Par::Seq;
396 }
397
398 if delayed_update {
399 let A10 = A10.col(k - 1).get(1..);
400
401 match par {
402 Par::Seq => {
403 update_mat_and_dot_simd(
404 norm.rb_mut(),
405 dot.rb_mut(),
406 B01.rb_mut(),
407 B11.rb_mut().try_as_col_major_mut().unwrap(),
408 A10.try_as_col_major().unwrap(),
409 B10.rb().try_as_col_major().unwrap(),
410 l,
411 tau_inv,
412 simd_align(k + 1),
413 );
414 },
415 #[cfg(feature = "rayon")]
416 Par::Rayon(nthreads) => {
417 let nthreads = nthreads.get();
418 use rayon::prelude::*;
419 norm.par_partition_mut(nthreads)
420 .zip(dot.par_partition_mut(nthreads))
421 .zip(B01.par_partition_mut(nthreads))
422 .zip(B11.par_col_partition_mut(nthreads))
423 .for_each(|(((norm, dot), B01), B11)| {
424 update_mat_and_dot_simd(
425 norm,
426 dot,
427 B01,
428 B11.try_as_col_major_mut().unwrap(),
429 A10.try_as_col_major().unwrap(),
430 B10.rb().try_as_col_major().unwrap(),
431 copy(l),
432 copy(tau_inv),
433 simd_align(k + 1),
434 );
435 });
436 },
437 }
438 } else {
439 dot.copy_from(B01.rb());
440 linalg::matmul::matmul(dot.rb_mut(), Accum::Add, B10.rb().adjoint(), B11.rb(), one(), par);
441
442 zip!(B01.rb_mut(), dot.rb_mut(), norm.rb_mut()).for_each(|unzip!(a, dot, norm)| {
443 *dot = mul_real(-*dot, tau_inv);
444 *a = *a + *dot;
445 *norm = from_real(sqrt(abs2(*norm) - abs2(*a)));
446 });
447 }
448 }
449 }
450 zip!(A.rb_mut()).for_each_triangular_upper(linalg::zip::Diag::Include, |unzip!(a)| *a = mul_real(*a, scale_fwd));
451 }
452
453 for j in 0..n {
454 col_perm_inv[col_perm[j].zx()] = I::truncate(j);
455 }
456
457 (
458 ColPivQrInfo {
459 transposition_count: n_trans,
460 },
461 unsafe { PermRef::new_unchecked(col_perm, col_perm_inv, n) },
462 )
463}
464
465pub fn qr_in_place_scratch<I: Index, T: ComplexField>(
468 nrows: usize,
469 ncols: usize,
470 blocksize: usize,
471 par: Par,
472 params: Spec<ColPivQrParams, T>,
473) -> StackReq {
474 let _ = nrows;
475 let _ = ncols;
476 let _ = par;
477 let _ = blocksize;
478 let _ = ¶ms;
479 linalg::temp_mat_scratch::<T>(ncols, 2)
480}
481
#[derive(Copy, Clone, Debug)]
/// Information about the result of a column pivoted QR factorization.
pub struct ColPivQrInfo {
    /// Number of column swaps (transpositions) performed while pivoting.
    pub transposition_count: usize,
}
489
490#[track_caller]
491#[math]
492pub fn qr_in_place<'out, I: Index, T: ComplexField>(
493 A: MatMut<'_, T>,
494 Q_coeff: MatMut<'_, T>,
495 col_perm: &'out mut [I],
496 col_perm_inv: &'out mut [I],
497 par: Par,
498 stack: &mut MemStack,
499 params: Spec<ColPivQrParams, T>,
500) -> (ColPivQrInfo, PermRef<'out, I>) {
501 let mut A = A;
502 let mut H = Q_coeff;
503 let size = H.ncols();
504 let blocksize = H.nrows();
505
506 let ret = qr_in_place_unblocked(A.rb_mut(), H.rb_mut().row_mut(0), col_perm, col_perm_inv, par, stack, params);
507
508 let mut j = 0;
509 while j < size {
510 let blocksize = Ord::min(blocksize, size - j);
511
512 let mut H = H.rb_mut().subcols_mut(j, blocksize).subrows_mut(0, blocksize);
513
514 for j in 0..blocksize {
515 H[(j, j)] = copy(H[(0, j)]);
516 }
517
518 let A = A.rb().get(j.., j..j + blocksize);
519
520 householder::upgrade_householder_factor(H.rb_mut(), A, blocksize, 1, par);
521 j += blocksize;
522 }
523 ret
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Mat, assert, c64};
    use dyn_stack::MemBuffer;

    /// Factorizes a random complex `m x n` matrix and checks that
    /// `Q * R * P` reproduces it.
    fn check_reconstruction(m: usize, n: usize, par: Par, rng: &mut StdRng) {
        let bs = 15;
        let size = Ord::min(m, n);

        // Picked up by the `~` comparison inside `assert!`.
        let approx_eq = CwiseMat(ApproxEq {
            abs_tol: 1e-10,
            rel_tol: 1e-10,
        });

        let A = CwiseMatDistribution {
            nrows: m,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);
        let A = A.as_ref();
        let mut QR = A.cloned();
        let mut H = Mat::zeros(bs, size);

        let col_perm = &mut *vec![0usize; n];
        let col_perm_inv = &mut *vec![0usize; n];

        let (_, q) = qr_in_place(
            QR.as_mut(),
            H.as_mut(),
            col_perm,
            col_perm_inv,
            par,
            MemStack::new(&mut MemBuffer::new(qr_in_place_scratch::<usize, c64>(m, n, bs, par, default()))),
            default(),
        );

        // Q: apply the block Householder sequence to the identity.
        let mut Q = Mat::<c64, _, _>::zeros(m, m);
        for d in 0..m {
            Q[(d, d)] = c64::ONE;
        }

        householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            QR.as_ref().subcols(0, size),
            H.as_ref(),
            Conj::No,
            Q.as_mut(),
            Par::Seq,
            MemStack::new(&mut MemBuffer::new(
                householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<c64>(m, bs, m),
            )),
        );

        // R: the factorized matrix with everything below the diagonal cleared.
        let mut R = QR.as_ref().cloned();
        for col in 0..n {
            for row in col + 1..m {
                R[(row, col)] = c64::ZERO;
            }
        }

        assert!(Q * R * q ~ A);
    }

    #[test]
    fn test_unblocked_qr() {
        let rng = &mut StdRng::seed_from_u64(0);

        for par in [Par::Seq, Par::rayon(8)] {
            // Square matrices.
            for n in [2, 3, 4, 8, 16, 24, 32, 128, 255] {
                check_reconstruction(n, n, par, rng);
            }

            // Rectangular matrices with a fixed number of columns.
            let n = 20;
            for m in [2, 3, 4, 8, 16, 24, 32, 128, 255] {
                check_reconstruction(m, n, par, rng);
            }
        }
    }
}