faer/linalg/qr/col_pivoting/
factor.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::swap_cols_idx;
4use linalg::householder::{self, HouseholderInfo};
5use pulp::Simd;
6
7pub use super::super::no_pivoting::factor::recommended_blocksize;
8
// for every column of the trailing block, in this order:
// B11 += A10 * dot
// B01 += l * dot
// dot  = -tau_inv * (B01 + B10^H * B11)
// B01 += dot
// norm = sqrt(abs2(norm) - abs2(B01))
/// fused simd kernel used by the delayed-update path of `qr_in_place_unblocked`.
///
/// for every column `j` of `B11`, with `b` the value of `dot[j]` on entry:
/// 1. `B11[:, j] += A10 * b`
/// 2. `acc = sum_i conj(B10[i]) * B11[i, j]`, accumulated on the freshly updated column
///    (assuming `conj_mul_add(x, y, z)` evaluates `conj(x) * y + z`, as the `B10^H * B11` term in
///    the comment above suggests)
/// 3. `tmp = B01[j] + l * b`
/// 4. `dot[j] = -tau_inv * (tmp + acc)`
/// 5. `B01[j] = tmp + dot[j]`
/// 6. `norm[j] = sqrt(abs2(norm[j]) - abs2(B01[j]))`
///
/// `norm`, `dot` and `B01` are indexed by the columns of `B11`, while `A10` and `B10` are indexed
/// by its rows. `align` is forwarded to `SimdCtx::new_align`.
#[math]
fn update_mat_and_dot_simd<T: ComplexField>(
	norm: RowMut<'_, T>,
	dot: RowMut<'_, T>,
	B01: RowMut<'_, T>,
	B11: MatMut<'_, T, usize, usize, ContiguousFwd>,
	A10: ColRef<'_, T, usize, ContiguousFwd>,
	B10: ColRef<'_, T, usize, ContiguousFwd>,
	l: T,
	tau_inv: T::Real,
	align: usize,
) {
	// arguments of the kernel, with the dimensions branded by `'M` (rows) and `'N` (columns) so
	// that row-indexed and column-indexed operands cannot be mixed up
	struct Impl<'a, 'M, 'N, T: ComplexField> {
		norm: RowMut<'a, T, Dim<'N>>,
		dot: RowMut<'a, T, Dim<'N>>,
		B01: RowMut<'a, T, Dim<'N>>,
		B11: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		A10: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		B10: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		l: T,
		tau_inv: T::Real,
		align: usize,
	}
	impl<'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'_, 'M, 'N, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			// note: `B01` is referred to as `u` in the rest of this function
			let Self {
				mut norm,
				mut dot,
				B01: mut u,
				mut B11,
				A10,
				B10,
				l,
				tau_inv,
				align,
			} = self;

			let m = B11.nrows();
			let n = B11.ncols();

			let simd = SimdCtx::<'_, T, S>::new_align(T::simd_ctx(simd), m, align);

			// row index batches: an optional head, groups of four, single leftovers and an optional
			// tail. they are reused for every column, hence the `clone()`s below
			let (head, body4, body1, tail) = simd.batch_indices::<4>();

			let mut j = n.indices();

			// columns are consumed four at a time. once fewer than four remain, the second arm
			// handles the leftovers one by one and exits the loop
			loop {
				match (j.next(), j.next(), j.next(), j.next()) {
					(Some(j0), Some(j1), Some(j2), Some(j3)) => {
						// values of `dot` on entry, they scale the `A10` update of each column
						let b0 = copy(dot[j0]);
						let b1 = copy(dot[j1]);
						let b2 = copy(dot[j2]);
						let b3 = copy(dot[j3]);

						let rhs0 = simd.splat(&b0);
						let rhs1 = simd.splat(&b1);
						let rhs2 = simd.splat(&b2);
						let rhs3 = simd.splat(&b3);

						// per-column accumulators of the `B10`/`B11` products
						let mut acc0 = simd.zero();
						let mut acc1 = simd.zero();
						let mut acc2 = simd.zero();
						let mut acc3 = simd.zero();

						// processes one batch of rows for the four columns: `A10` and `B10` are
						// loaded once and shared, each column is updated, accumulated and stored
						macro_rules! do_it {
							($i: expr) => {{
								let i = $i;

								let lhs0 = simd.read(A10, i);
								let lhs1 = simd.read(B10, i);

								let mut dst0 = simd.read(B11.rb().col(j0), i);
								dst0 = simd.mul_add(lhs0, rhs0, dst0);
								acc0 = simd.conj_mul_add(lhs1, dst0, acc0);
								simd.write(B11.rb_mut().col_mut(j0), i, dst0);

								let mut dst1 = simd.read(B11.rb().col(j1), i);
								dst1 = simd.mul_add(lhs0, rhs1, dst1);
								acc1 = simd.conj_mul_add(lhs1, dst1, acc1);
								simd.write(B11.rb_mut().col_mut(j1), i, dst1);

								let mut dst2 = simd.read(B11.rb().col(j2), i);
								dst2 = simd.mul_add(lhs0, rhs2, dst2);
								acc2 = simd.conj_mul_add(lhs1, dst2, acc2);
								simd.write(B11.rb_mut().col_mut(j2), i, dst2);

								let mut dst3 = simd.read(B11.rb().col(j3), i);
								dst3 = simd.mul_add(lhs0, rhs3, dst3);
								acc3 = simd.conj_mul_add(lhs1, dst3, acc3);
								simd.write(B11.rb_mut().col_mut(j3), i, dst3);
							}};
						}

						if let Some(i) = head {
							do_it!(i);
						}

						for [i0, i1, i2, i3] in body4.clone() {
							do_it!(i0);
							do_it!(i1);
							do_it!(i2);
							do_it!(i3);
						}
						for i in body1.clone() {
							do_it!(i);
						}
						if let Some(i) = tail {
							do_it!(i);
						}

						// scalar epilogue of each column: finish the pending update of `B01`,
						// compute the new `dot`, apply it to `B01` and downdate the column norm
						let tmp = u[j0] + l * b0;
						let d0 = mul_real(tmp + simd.reduce_sum(acc0), -tau_inv);
						u[j0] = tmp + d0;
						dot[j0] = d0;
						norm[j0] = from_real(sqrt(abs2(norm[j0]) - abs2(u[j0])));

						let tmp = u[j1] + l * b1;
						let d1 = mul_real(tmp + simd.reduce_sum(acc1), -tau_inv);
						u[j1] = tmp + d1;
						dot[j1] = d1;
						norm[j1] = from_real(sqrt(abs2(norm[j1]) - abs2(u[j1])));

						let tmp = u[j2] + l * b2;
						let d2 = mul_real(tmp + simd.reduce_sum(acc2), -tau_inv);
						u[j2] = tmp + d2;
						dot[j2] = d2;
						norm[j2] = from_real(sqrt(abs2(norm[j2]) - abs2(u[j2])));

						let tmp = u[j3] + l * b3;
						let d3 = mul_real(tmp + simd.reduce_sum(acc3), -tau_inv);
						u[j3] = tmp + d3;
						dot[j3] = d3;
						norm[j3] = from_real(sqrt(abs2(norm[j3]) - abs2(u[j3])));
					},
					(j0, j1, j2, j3) => {
						// fewer than four columns are left: `flatten` skips the `None`s and the
						// same computation is carried out one column at a time
						for j0 in [j0, j1, j2, j3].into_iter().flatten() {
							let b0 = copy(dot[j0]);
							let rhs0 = simd.splat(&b0);

							let mut acc0 = simd.zero();

							macro_rules! do_it {
								($i: expr) => {{
									let i = $i;

									let lhs0 = simd.read(A10, i);
									let lhs1 = simd.read(B10, i);

									let mut dst0 = simd.read(B11.rb().col(j0), i);
									dst0 = simd.mul_add(lhs0, rhs0, dst0);
									acc0 = simd.conj_mul_add(lhs1, dst0, acc0);
									simd.write(B11.rb_mut().col_mut(j0), i, dst0);
								}};
							}

							if let Some(i) = head {
								do_it!(i);
							}
							for [i0, i1, i2, i3] in body4.clone() {
								do_it!(i0);
								do_it!(i1);
								do_it!(i2);
								do_it!(i3);
							}

							for i in body1.clone() {
								do_it!(i);
							}
							if let Some(i) = tail {
								do_it!(i);
							}

							let tmp = u[j0] + l * b0;
							let d0 = mul_real(tmp + simd.reduce_sum(acc0), -tau_inv);
							u[j0] = tmp + d0;
							dot[j0] = d0;
							norm[j0] = from_real(sqrt(abs2(norm[j0]) - abs2(u[j0])));
						}
						break;
					},
				}
			}
		}
	}

	// brand the runtime dimensions, then dispatch the kernel
	with_dim!(M, B11.nrows());
	with_dim!(N, B11.ncols());
	dispatch!(
		Impl {
			norm: norm.as_col_shape_mut(N),
			dot: dot.as_col_shape_mut(N),
			B01: B01.as_col_shape_mut(N),
			B11: B11.as_shape_mut(M, N),
			A10: A10.as_row_shape(M),
			B10: B10.as_row_shape(M),
			l,
			tau_inv,
			align
		},
		Impl,
		T
	)
}
220
221#[math]
222
223/// $QR$ factorization with column pivoting tuning parameters
224#[derive(Copy, Clone, Debug)]
225pub struct ColPivQrParams {
226	/// threshold at which blocking algorithms should be disabled
227	pub blocking_threshold: usize,
228	/// threshold at which the parallelism should be disabled
229	pub par_threshold: usize,
230
231	#[doc(hidden)]
232	pub non_exhaustive: NonExhaustive,
233}
234
235impl<T: ComplexField> Auto<T> for ColPivQrParams {
236	#[inline]
237	fn auto() -> Self {
238		Self {
239			blocking_threshold: 48 * 48,
240			par_threshold: 192 * 256,
241			non_exhaustive: NonExhaustive(()),
242		}
243	}
244}
245
246#[track_caller]
247#[math]
248fn qr_in_place_unblocked<'out, I: Index, T: ComplexField>(
249	A: MatMut<'_, T>,
250	H: RowMut<'_, T>,
251	col_perm: &'out mut [I],
252	col_perm_inv: &'out mut [I],
253	par: Par,
254	stack: &mut MemStack,
255	params: Spec<ColPivQrParams, T>,
256) -> (ColPivQrInfo, PermRef<'out, I>) {
257	let m = A.nrows();
258	let n = A.ncols();
259	let size = H.ncols();
260
261	let params = params.config;
262	let mut A = A;
263	let mut H = H;
264	let mut par = par;
265
266	assert!(size == Ord::min(m, n));
267	for j in 0..n {
268		col_perm[j] = I::truncate(j);
269	}
270
271	let mut n_trans = 0;
272
273	'main: {
274		if size == 0 {
275			break 'main;
276		}
277
278		let (mut dot, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
279		let (mut norm, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
280		let _ = stack;
281
282		let mut dot = dot.as_mat_mut().col_mut(0).transpose_mut();
283		let mut norm = norm.as_mat_mut().col_mut(0).transpose_mut();
284
285		let mut best = zero();
286
287		let threshold = sqrt(eps::<T::Real>());
288
289		for j in 0..n {
290			let val = A.rb().col(j).norm_l2();
291			norm[j] = from_real(val);
292
293			if val > best {
294				best = val;
295			}
296		}
297
298		let scale_fwd = copy(best);
299		let scale_bwd = recip(best);
300
301		zip!(A.rb_mut()).for_each(|unzip!(a)| *a = mul_real(*a, scale_bwd));
302
303		for j in 0..n {
304			norm[j] = from_real(real(norm[j]) * scale_bwd);
305		}
306		best = best * scale_bwd;
307		let mut best_threshold = best * threshold;
308
309		'unscale: {
310			for k in 0..size {
311				let mut new_best = zero::<T::Real>();
312				let mut best_col = k;
313				for j in k..n {
314					let val = real(norm[j]);
315					if val > new_best {
316						new_best = val;
317						best_col = j;
318					}
319				}
320
321				let delayed_update = T::SIMD_CAPABILITIES.is_simd() && A.row_stride() == 1 && k > 0 && new_best >= best_threshold;
322
323				if k > 0 && !delayed_update {
324					let (_, _, A10, mut A11) = A.rb_mut().split_at_mut(k, k);
325					let dot = dot.rb().get(k..);
326					let A10 = A10.rb().col(k - 1);
327
328					linalg::matmul::matmul(A11.rb_mut(), Accum::Add, A10, dot, one(), par);
329
330					best = zero();
331					for j in k..n {
332						let val = A11.rb().col(j - k).norm_l2();
333
334						norm[j] = from_real(val);
335
336						if val > best {
337							best = val;
338							best_col = j;
339						}
340					}
341					best_threshold = best * threshold;
342				}
343
344				if best_col != k {
345					n_trans += 1;
346					col_perm.as_mut().swap(best_col, k);
347					swap_cols_idx(A.rb_mut(), best_col, k);
348					swap_cols_idx(dot.rb_mut().as_mat_mut(), best_col, k);
349					swap_cols_idx(norm.rb_mut().as_mat_mut(), best_col, k);
350				}
351
352				let (_, _, A10, mut A11) = A.rb_mut().split_at_mut(k, k);
353				let A10 = A10.rb();
354				let dot0 = dot.rb_mut().get_mut(k..);
355
356				let (mut B00, B01, B10, mut B11) = A11.rb_mut().split_at_mut(1, 1);
357				let B00 = &mut B00[(0, 0)];
358				let mut B01 = B01.row_mut(0);
359				let mut B10 = B10.col_mut(0);
360
361				let l = if delayed_update {
362					let A10 = A10.col(k - 1);
363					copy(A10[0])
364				} else {
365					zero()
366				};
367				let r = copy(dot0[0]);
368
369				let mut dot = dot.rb_mut().get_mut(k + 1..);
370				let mut norm = norm.rb_mut().get_mut(k + 1..);
371
372				if delayed_update {
373					let A10 = A10.col(k - 1).get(1..);
374
375					*B00 = *B00 + l * r;
376					zip!(B10.rb_mut(), A10).for_each(|unzip!(x, y)| {
377						*x = *x + r * *y;
378					});
379				}
380
381				let HouseholderInfo { tau, .. } = householder::make_householder_in_place(B00, B10.rb_mut());
382				let tau_inv = recip(tau);
383				H[k] = from_real(tau);
384
385				if k + 1 == size {
386					if delayed_update {
387						zip!(B01.rb_mut(), dot.rb()).for_each(|unzip!(x, y)| {
388							*x = *x + l * *y;
389						});
390					}
391					break 'unscale;
392				}
393
394				if (m - k - 1) * (n - k - 1) < params.par_threshold {
395					par = Par::Seq;
396				}
397
398				if delayed_update {
399					let A10 = A10.col(k - 1).get(1..);
400
401					match par {
402						Par::Seq => {
403							update_mat_and_dot_simd(
404								norm.rb_mut(),
405								dot.rb_mut(),
406								B01.rb_mut(),
407								B11.rb_mut().try_as_col_major_mut().unwrap(),
408								A10.try_as_col_major().unwrap(),
409								B10.rb().try_as_col_major().unwrap(),
410								l,
411								tau_inv,
412								simd_align(k + 1),
413							);
414						},
415						#[cfg(feature = "rayon")]
416						Par::Rayon(nthreads) => {
417							let nthreads = nthreads.get();
418							use rayon::prelude::*;
419							norm.par_partition_mut(nthreads)
420								.zip(dot.par_partition_mut(nthreads))
421								.zip(B01.par_partition_mut(nthreads))
422								.zip(B11.par_col_partition_mut(nthreads))
423								.for_each(|(((norm, dot), B01), B11)| {
424									update_mat_and_dot_simd(
425										norm,
426										dot,
427										B01,
428										B11.try_as_col_major_mut().unwrap(),
429										A10.try_as_col_major().unwrap(),
430										B10.rb().try_as_col_major().unwrap(),
431										copy(l),
432										copy(tau_inv),
433										simd_align(k + 1),
434									);
435								});
436						},
437					}
438				} else {
439					dot.copy_from(B01.rb());
440					linalg::matmul::matmul(dot.rb_mut(), Accum::Add, B10.rb().adjoint(), B11.rb(), one(), par);
441
442					zip!(B01.rb_mut(), dot.rb_mut(), norm.rb_mut()).for_each(|unzip!(a, dot, norm)| {
443						*dot = mul_real(-*dot, tau_inv);
444						*a = *a + *dot;
445						*norm = from_real(sqrt(abs2(*norm) - abs2(*a)));
446					});
447				}
448			}
449		}
450		zip!(A.rb_mut()).for_each_triangular_upper(linalg::zip::Diag::Include, |unzip!(a)| *a = mul_real(*a, scale_fwd));
451	}
452
453	for j in 0..n {
454		col_perm_inv[col_perm[j].zx()] = I::truncate(j);
455	}
456
457	(
458		ColPivQrInfo {
459			transposition_count: n_trans,
460		},
461		unsafe { PermRef::new_unchecked(col_perm, col_perm_inv, n) },
462	)
463}
464
465/// computes the size and alignment of required workspace for performing a qr decomposition
466/// with column pivoting
467pub fn qr_in_place_scratch<I: Index, T: ComplexField>(
468	nrows: usize,
469	ncols: usize,
470	blocksize: usize,
471	par: Par,
472	params: Spec<ColPivQrParams, T>,
473) -> StackReq {
474	let _ = nrows;
475	let _ = ncols;
476	let _ = par;
477	let _ = blocksize;
478	let _ = &params;
479	linalg::temp_mat_scratch::<T>(ncols, 2)
480}
481
/// summary of a computed $QR$ factorization with column pivoting.
#[derive(Clone, Copy, Debug)]
pub struct ColPivQrInfo {
	/// number of column transpositions that were performed while pivoting. its parity gives the
	/// sign of the determinant of the permutation matrix $P$.
	pub transposition_count: usize,
}
489
490#[track_caller]
491#[math]
492pub fn qr_in_place<'out, I: Index, T: ComplexField>(
493	A: MatMut<'_, T>,
494	Q_coeff: MatMut<'_, T>,
495	col_perm: &'out mut [I],
496	col_perm_inv: &'out mut [I],
497	par: Par,
498	stack: &mut MemStack,
499	params: Spec<ColPivQrParams, T>,
500) -> (ColPivQrInfo, PermRef<'out, I>) {
501	let mut A = A;
502	let mut H = Q_coeff;
503	let size = H.ncols();
504	let blocksize = H.nrows();
505
506	let ret = qr_in_place_unblocked(A.rb_mut(), H.rb_mut().row_mut(0), col_perm, col_perm_inv, par, stack, params);
507
508	let mut j = 0;
509	while j < size {
510		let blocksize = Ord::min(blocksize, size - j);
511
512		let mut H = H.rb_mut().subcols_mut(j, blocksize).subrows_mut(0, blocksize);
513
514		for j in 0..blocksize {
515			H[(j, j)] = copy(H[(0, j)]);
516		}
517
518		let A = A.rb().get(j.., j..j + blocksize);
519
520		householder::upgrade_householder_factor(H.rb_mut(), A, blocksize, 1, par);
521		j += blocksize;
522	}
523	ret
524}
525
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, assert, c64};
	use dyn_stack::MemBuffer;

	/// matrix dimensions exercised by the test
	const DIMS: [usize; 9] = [2, 3, 4, 8, 16, 24, 32, 128, 255];

	/// factorizes a random complex `m × n` matrix and checks that `Q * R * P` reproduces it.
	fn check_reconstruction(rng: &mut StdRng, par: Par, m: usize, n: usize) {
		let bs = 15;
		let size = Ord::min(m, n);

		// picked up by the `~` operator of `assert!`
		let approx_eq = CwiseMat(ApproxEq {
			abs_tol: 1e-10,
			rel_tol: 1e-10,
		});

		let A = CwiseMatDistribution {
			nrows: m,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = A.as_ref();

		let mut QR = A.cloned();
		let mut H = Mat::zeros(bs, size);

		let col_perm = &mut *vec![0usize; n];
		let col_perm_inv = &mut *vec![0usize; n];

		let (_, q) = qr_in_place(
			QR.as_mut(),
			H.as_mut(),
			col_perm,
			col_perm_inv,
			par,
			MemStack::new(&mut MemBuffer::new(qr_in_place_scratch::<usize, c64>(m, n, bs, par, default()))),
			default(),
		);

		// Q: the householder sequence applied to the identity
		let mut Q = Mat::<c64, _, _>::zeros(m, m);
		for j in 0..m {
			Q[(j, j)] = c64::ONE;
		}

		householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
			QR.as_ref().subcols(0, size),
			H.as_ref(),
			Conj::No,
			Q.as_mut(),
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(
				householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<c64>(m, bs, m),
			)),
		);

		// R: the factorized matrix with everything below the diagonal cleared
		let mut R = QR.as_ref().cloned();
		for j in 0..n {
			for i in j + 1..m {
				R[(i, j)] = c64::ZERO;
			}
		}

		assert!(Q * R * q ~ A);
	}

	#[test]
	fn test_unblocked_qr() {
		let rng = &mut StdRng::seed_from_u64(0);

		for par in [Par::Seq, Par::rayon(8)] {
			// square matrices
			for n in DIMS {
				check_reconstruction(rng, par, n, n);
			}

			// rectangular matrices with a fixed number of columns
			let n = 20;
			for m in DIMS {
				check_reconstruction(rng, par, m, n);
			}
		}
	}
}