faer/linalg/svd/
bidiag.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::householder::*;
4use linalg::matmul::{dot, matmul};
5
6/// computes the size and alignment of the workspace required to compute a matrix's
7/// bidiagonalization
8pub fn bidiag_in_place_scratch<T: ComplexField>(nrows: usize, ncols: usize, par: Par, params: Spec<BidiagParams, T>) -> StackReq {
9	_ = par;
10	_ = params;
11	StackReq::all_of(&[temp_mat_scratch::<T>(nrows, 1), temp_mat_scratch::<T>(ncols, 1)])
12}
13
/// bidiagonalization tuning parameters.
///
/// the default values are provided through the `Auto` implementation of this type
#[derive(Debug, Copy, Clone)]
pub struct BidiagParams {
	/// threshold at which parallelism should be disabled
	///
	/// at step `k`, [`bidiag_in_place`] compares `(nrows - k - 1) * (ncols - k - 1)` against
	/// this value, and switches to sequential execution for the remaining steps once the
	/// product is smaller than the threshold
	pub par_threshold: usize,
	// NOTE(review): hidden marker field; presumably keeps the struct from being built
	// exhaustively by downstream code so that parameters can be added later - confirm
	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
22
23impl<T: ComplexField> Auto<T> for BidiagParams {
24	fn auto() -> Self {
25		Self {
26			par_threshold: 192 * 256,
27			non_exhaustive: NonExhaustive(()),
28		}
29	}
30}
31
32/// computes a matrix $A$'s bidiagonalization such that $A = U B V^H$
33///
34/// $B$ is a bidiagonal matrix stored in $A$'s diagonal and superdiagonal
35///
36/// $U$ is a sequence of householder reflections stored in the unit lower triangular half of $A$,
37/// with the householder coefficients being stored in `H_left`
38///
39/// $V$ is a sequence of householder reflections stored in the unit upper triangular half of $A$
40/// (excluding the diagonal), with the householder coefficients being stored in `H_right`
41#[math]
42pub fn bidiag_in_place<T: ComplexField>(
43	A: MatMut<'_, T>,
44	H_left: MatMut<'_, T>,
45	H_right: MatMut<'_, T>,
46	par: Par,
47	stack: &mut MemStack,
48	params: Spec<BidiagParams, T>,
49) {
50	let params = params.config;
51	let m = A.nrows();
52	let n = A.ncols();
53	let size = Ord::min(m, n);
54	let bl = H_left.nrows();
55	let br = H_right.nrows();
56
57	assert!(H_left.ncols() == size);
58	assert!(H_right.ncols() == size.saturating_sub(1));
59
60	let (mut y, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
61	let (mut z, _) = unsafe { temp_mat_uninit(m, 1, stack) };
62
63	let mut y = y.as_mat_mut().col_mut(0).transpose_mut();
64	let mut z = z.as_mat_mut().col_mut(0);
65
66	let mut A = A;
67	let mut Hl = H_left;
68	let mut Hr = H_right;
69	let mut par = par;
70
71	{
72		let mut Hl = Hl.rb_mut().row_mut(0);
73		let mut Hr = Hr.rb_mut().row_mut(0);
74
75		for k in 0..size {
76			let mut A = A.rb_mut();
77
78			let (_, A01, A10, A11) = A.rb_mut().split_at_mut(k, k);
79
80			let (_, A02) = A01.split_first_col().unwrap();
81			let (A10, A20) = A10.split_first_row_mut().unwrap();
82			let (mut A11, A12, A21, mut A22) = A11.split_at_mut(1, 1);
83
84			let mut A12 = A12.row_mut(0);
85			let mut A21 = A21.col_mut(0);
86
87			let a11 = &mut A11[(0, 0)];
88
89			let (y1, mut y2) = y.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
90			let (z1, mut z2) = z.rb_mut().split_at_row_mut(k).1.split_at_row_mut(1);
91
92			let y1 = copy(y1[0]);
93			let z1 = copy(z1[0]);
94
95			if k > 0 {
96				let k1 = k - 1;
97
98				let up0 = copy(A10[k1]);
99				let up = A20.rb().col(k1);
100				let vp = A02.rb().row(k1);
101
102				*a11 = *a11 - up0 * y1 - z1;
103				z!(A21.rb_mut(), up.rb(), z2.rb()).for_each(|uz!(a, u, z)| *a = *a - *u * y1 - *z);
104				z!(A12.rb_mut(), y2.rb(), vp.rb()).for_each(|uz!(a, y, v)| *a = *a - up0 * *y - z1 * *v);
105			}
106
107			let HouseholderInfo { tau: tl, .. } = make_householder_in_place(a11, A21.rb_mut());
108			let tl_inv = recip(tl);
109			Hl[k] = from_real(tl);
110
111			if (m - k - 1) * (n - k - 1) < params.par_threshold {
112				par = Par::Seq;
113			}
114
115			if k > 0 {
116				let k1 = k - 1;
117
118				let up = A20.rb().col(k1);
119				let vp = A02.row(k1);
120
121				match par {
122					Par::Seq => bidiag_fused_op(A22.rb_mut(), A21.rb(), up.rb(), z2.rb(), y2.rb_mut(), vp.rb(), simd_align(k + 1)),
123					#[cfg(feature = "rayon")]
124					Par::Rayon(nthreads) => {
125						use rayon::prelude::*;
126						let nthreads = nthreads.get();
127
128						A22.rb_mut()
129							.par_col_partition_mut(nthreads)
130							.zip_eq(y2.rb_mut().par_partition_mut(nthreads))
131							.zip_eq(vp.par_partition(nthreads))
132							.for_each(|((A22, y2), vp)| {
133								bidiag_fused_op(A22, A21.rb(), up.rb(), z2.rb(), y2, vp.rb(), simd_align(k + 1));
134							});
135					},
136				}
137			} else {
138				matmul(y2.rb_mut(), Accum::Replace, A21.rb().adjoint(), A22.rb(), one(), par);
139			}
140
141			z!(y2.rb_mut(), A12.rb_mut()).for_each(|uz!(y, a)| {
142				*y = mul_real(*y + *a, tl_inv);
143				*a = *a - *y;
144			});
145			let norm = A12.rb().norm_l2();
146			let norm_inv = recip(norm);
147			if norm != zero() {
148				z!(A12.rb_mut()).for_each(|uz!(a)| *a = mul_real(a, norm_inv));
149			}
150			matmul(z2.rb_mut(), Accum::Replace, A22.rb(), A12.rb().adjoint(), one(), par);
151
152			if k + 1 == size {
153				break;
154			}
155
156			let (mut A12_a, mut A12_b) = A12.rb_mut().split_at_col_mut(1);
157			let A22_a = A22.rb().col(0);
158			let (y2_a, y2_b) = y2.rb().split_at_col(1);
159			let y2_a = &y2_a[0];
160
161			let a12_a = &mut A12_a[0];
162
163			let HouseholderInfo {
164				tau: tr,
165				head_with_beta_inv: m,
166				..
167			} = make_householder_in_place(a12_a, A12_b.rb_mut().transpose_mut());
168			let tr_inv = recip(tr);
169			Hr[k] = from_real(tr);
170			let beta = copy(*a12_a);
171			*a12_a = mul_real(*a12_a, norm);
172
173			let b = *y2_a + dot::inner_prod(y2_b, Conj::No, A12_b.rb().transpose(), Conj::Yes);
174
175			if m != infinity() {
176				z!(z2.rb_mut(), A21.rb(), A22_a.rb()).for_each(|uz!(z, u, a)| {
177					let w = *z - *a * conj(beta);
178					let w = w * conj(m);
179					let w = w - *u * b;
180					*z = mul_real(w, tr_inv);
181				});
182			} else {
183				z!(z2.rb_mut(), A21.rb(), A22_a.rb()).for_each(|uz!(z, u, a)| {
184					let w = *a - *u * b;
185					*z = mul_real(w, tr_inv);
186				});
187			}
188		}
189	}
190
191	let mut j = 0;
192	while j < size {
193		let bl = Ord::min(bl, size - j);
194
195		let mut Hl = Hl.rb_mut().get_mut(..bl, j..j + bl);
196		for k in 0..bl {
197			Hl[(k, k)] = copy(Hl[(0, k)]);
198		}
199
200		upgrade_householder_factor(Hl.rb_mut(), A.rb().get(j.., j..j + bl), bl, 1, par);
201
202		j += bl;
203	}
204
205	if size > 0 {
206		let size = size - 1;
207		let A = A.rb().get(..size, 1..);
208
209		let mut Hr = Hr.rb_mut().get_mut(.., ..size);
210
211		let mut j = 0;
212		while j < size {
213			let br = Ord::min(br, size - j);
214
215			let mut Hr = Hr.rb_mut().get_mut(..br, j..j + br);
216
217			for k in 0..br {
218				Hr[(k, k)] = copy(Hr[(0, k)]);
219			}
220
221			upgrade_householder_factor(Hr.rb_mut(), A.transpose().get(j.., j..j + br), br, 1, par);
222			j += br;
223		}
224	}
225}
226
227#[math]
228fn bidiag_fused_op<T: ComplexField>(
229	A22: MatMut<'_, T>,
230	u: ColRef<'_, T>,
231	up: ColRef<'_, T>,
232	z: ColRef<'_, T>,
233	y: RowMut<'_, T>,
234	vp: RowRef<'_, T>,
235	align: usize,
236) {
237	let mut A22 = A22;
238
239	if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
240		if let (Some(A22), Some(u), Some(up), Some(z)) = (
241			A22.rb_mut().try_as_col_major_mut(),
242			u.try_as_col_major(),
243			up.try_as_col_major(),
244			z.try_as_col_major(),
245		) {
246			bidiag_fused_op_simd(A22, u, up, z, y, vp, align);
247		} else {
248			bidiag_fused_op_fallback(A22, u, up, z, y, vp);
249		}
250	} else {
251		bidiag_fused_op_fallback(A22, u, up, z, y, vp);
252	}
253}
254
255#[math]
256fn bidiag_fused_op_fallback<T: ComplexField>(
257	A22: MatMut<'_, T>,
258	u: ColRef<'_, T>,
259	up: ColRef<'_, T>,
260	z: ColRef<'_, T>,
261	y: RowMut<'_, T>,
262	vp: RowRef<'_, T>,
263) {
264	let mut A22 = A22;
265	let mut y = y;
266
267	matmul(A22.rb_mut(), Accum::Add, up, y.rb(), -one::<T>(), Par::Seq);
268	matmul(A22.rb_mut(), Accum::Add, z, vp, -one::<T>(), Par::Seq);
269	matmul(y.rb_mut(), Accum::Replace, u.adjoint(), A22.rb(), one(), Par::Seq);
270}
271
/// simd version of the fused trailing update, for column major operands
///
/// for every column `j` of `A22`, in a single pass over the column:
/// - the column is updated with `up * (-y[j])` and `z * (-vp[j])`
/// - the product of the conjugated `u` with the updated column is accumulated and written
///   back to `y[j]`
///
/// this mirrors the three products performed by `bidiag_fused_op_fallback`
///
/// NOTE(review): `align` is forwarded to `SimdCtx::new_align`; the caller passes
/// `simd_align(k + 1)`, presumably the row offset of `A22` inside the full matrix - confirm
#[math]
fn bidiag_fused_op_simd<'M, 'N, T: ComplexField>(
	A22: MatMut<'_, T, usize, usize, ContiguousFwd>,
	u: ColRef<'_, T, usize, ContiguousFwd>,
	up: ColRef<'_, T, usize, ContiguousFwd>,
	z: ColRef<'_, T, usize, ContiguousFwd>,

	y: RowMut<'_, T, usize>,
	vp: RowRef<'_, T, usize>,

	align: usize,
) {
	// operands of the kernel, with the row and column counts branded as `Dim<'M>` / `Dim<'N>`
	struct Impl<'a, 'M, 'N, T: ComplexField> {
		A22: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
		u: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		up: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		z: ColRef<'a, T, Dim<'M>, ContiguousFwd>,

		y: RowMut<'a, T, Dim<'N>>,
		vp: RowRef<'a, T, Dim<'N>>,

		align: usize,
	}

	impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self {
				mut A22,
				u,
				up,
				z,
				mut y,
				vp,
				align,
			} = self;

			let m = A22.nrows();
			let n = A22.ncols();
			let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), m, align);
			// row indices split into an optional head, groups of four registers, single
			// registers, and an optional tail
			let (head, body4, body1, tail) = simd.batch_indices::<4>();

			for j in n.indices() {
				let mut a = A22.rb_mut().col_mut(j);

				// four independent accumulators, one per register of the unrolled body
				let mut acc0 = simd.zero();
				let mut acc1 = simd.zero();
				let mut acc2 = simd.zero();
				let mut acc3 = simd.zero();

				// negated so that the updates below can be expressed as multiply-adds
				let yj = simd.splat(&-y[j]);
				let vj = simd.splat(&-vp[j]);

				// every section below performs the same three steps on its rows: update the
				// entries of the column, store them, accumulate against `u`
				if let Some(i0) = head {
					let mut a0 = simd.read(a.rb(), i0);
					a0 = simd.mul_add(simd.read(up, i0), yj, a0);
					a0 = simd.mul_add(simd.read(z, i0), vj, a0);
					simd.write(a.rb_mut(), i0, a0);

					acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
				}

				// main body, unrolled four times
				for [i0, i1, i2, i3] in body4.clone() {
					{
						let mut a0 = simd.read(a.rb(), i0);
						a0 = simd.mul_add(simd.read(up, i0), yj, a0);
						a0 = simd.mul_add(simd.read(z, i0), vj, a0);
						simd.write(a.rb_mut(), i0, a0);

						acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
					}
					{
						let mut a1 = simd.read(a.rb(), i1);
						a1 = simd.mul_add(simd.read(up, i1), yj, a1);
						a1 = simd.mul_add(simd.read(z, i1), vj, a1);
						simd.write(a.rb_mut(), i1, a1);

						acc1 = simd.conj_mul_add(simd.read(u, i1), a1, acc1);
					}
					{
						let mut a2 = simd.read(a.rb(), i2);
						a2 = simd.mul_add(simd.read(up, i2), yj, a2);
						a2 = simd.mul_add(simd.read(z, i2), vj, a2);
						simd.write(a.rb_mut(), i2, a2);

						acc2 = simd.conj_mul_add(simd.read(u, i2), a2, acc2);
					}
					{
						let mut a3 = simd.read(a.rb(), i3);
						a3 = simd.mul_add(simd.read(up, i3), yj, a3);
						a3 = simd.mul_add(simd.read(z, i3), vj, a3);
						simd.write(a.rb_mut(), i3, a3);

						acc3 = simd.conj_mul_add(simd.read(u, i3), a3, acc3);
					}
				}

				// registers left over after the unrolled body
				for i0 in body1.clone() {
					let mut a0 = simd.read(a.rb(), i0);
					a0 = simd.mul_add(simd.read(up, i0), yj, a0);
					a0 = simd.mul_add(simd.read(z, i0), vj, a0);
					simd.write(a.rb_mut(), i0, a0);

					acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
				}
				if let Some(i0) = tail {
					let mut a0 = simd.read(a.rb(), i0);
					a0 = simd.mul_add(simd.read(up, i0), yj, a0);
					a0 = simd.mul_add(simd.read(z, i0), vj, a0);
					simd.write(a.rb_mut(), i0, a0);

					acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
				}

				// combine the four accumulators pairwise, then reduce across lanes
				acc0 = simd.add(acc0, acc1);
				acc2 = simd.add(acc2, acc3);
				acc0 = simd.add(acc0, acc2);

				y[j] = simd.reduce_sum(acc0);
			}
		}
	}

	// brand the runtime dimensions so that all operands share the same `Dim` types
	with_dim!(M, A22.nrows());
	with_dim!(N, A22.ncols());

	dispatch!(
		Impl {
			A22: A22.as_shape_mut(M, N),
			u: u.as_row_shape(M),
			up: up.as_row_shape(M),
			z: z.as_row_shape(M),
			y: y.as_col_shape_mut(N),
			vp: vp.as_col_shape(N),
			align,
		},
		Impl,
		T
	)
}
414
#[cfg(test)]
mod tests {
	use dyn_stack::MemBuffer;

	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, assert, c64};

	/// bidiagonalizes random real matrices, applies the stored left and right householder
	/// sequences to the original matrix, and checks that the result matches the bidiagonal
	/// part of the factorization
	#[test]
	fn test_bidiag_real() {
		let rng = &mut StdRng::seed_from_u64(0);

		for (m, n) in [(8, 4), (8, 8)] {
			let size = Ord::min(m, n);

			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);

			// block sizes of the left and right householder factors
			let bl = 4;
			let br = 3;
			let mut Hl = Mat::zeros(bl, size);
			let mut Hr = Mat::zeros(br, size - 1);

			let mut UV = A.clone();
			bidiag_in_place(
				UV.rb_mut(),
				Hl.rb_mut(),
				Hr.rb_mut(),
				Par::Seq,
				// workspace sized by the matching scratch function rather than a fixed buffer
				MemStack::new(&mut MemBuffer::new(bidiag_in_place_scratch::<f64>(
					m,
					n,
					Par::Seq,
					default(),
				))),
				default(),
			);

			let mut A = A.clone();
			let mut A = A.as_mut();

			// left reflectors: basis with `m` rows, block size `bl`, applied to `n` columns
			apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
				UV.rb().get(.., ..size),
				Hl.rb(),
				Conj::Yes,
				A.rb_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(m, bl, n),
				)),
			);

			let V = UV.rb().get(..size - 1, 1..size);
			let A1 = A.rb_mut().get_mut(.., 1..size);
			let Hr = Hr.as_ref();

			// right reflectors: basis with `size - 1` rows, block size `br`, applied to `m` rows
			apply_block_householder_sequence_on_the_right_in_place_with_conj(
				V.transpose(),
				Hr.as_ref(),
				Conj::Yes,
				A1,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					apply_block_householder_sequence_on_the_right_in_place_scratch::<f64>(size - 1, br, m),
				)),
			);

			let approx_eq = CwiseMat(ApproxEq::<f64>::eps());
			// keep only the diagonal and superdiagonal, the rest of `UV` stores the reflectors
			for j in 0..n {
				for i in 0..m {
					if i > j || j > i + 1 {
						UV[(i, j)] = 0.0;
					}
				}
			}

			assert!(UV ~ A);
		}
	}

	/// same check as `test_bidiag_real`, for complex matrices
	#[test]
	fn test_bidiag_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);

		for (m, n) in [(8, 4), (8, 8)] {
			let size = Ord::min(m, n);
			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			// block sizes of the left and right householder factors
			let bl = 4;
			let br = 3;
			let mut Hl = Mat::zeros(bl, size);
			let mut Hr = Mat::zeros(br, size - 1);

			let mut UV = A.clone();
			let mut UV = UV.as_mut();
			bidiag_in_place(
				UV.rb_mut(),
				Hl.rb_mut(),
				Hr.rb_mut(),
				Par::Seq,
				// workspace sized by the matching scratch function rather than a fixed buffer
				MemStack::new(&mut MemBuffer::new(bidiag_in_place_scratch::<c64>(
					m,
					n,
					Par::Seq,
					default(),
				))),
				default(),
			);

			let mut A = A.clone();
			let mut A = A.as_mut();

			// left reflectors: basis with `m` rows, block size `bl`, applied to `n` columns
			apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
				UV.rb().subcols(0, size),
				Hl.rb(),
				Conj::Yes,
				A.rb_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(m, bl, n),
				)),
			);

			let V = UV.rb().get(..size - 1, 1..size);
			let A1 = A.rb_mut().get_mut(.., 1..size);
			let Hr = Hr.rb();

			// right reflectors: basis with `size - 1` rows, block size `br`, applied to `m` rows
			apply_block_householder_sequence_on_the_right_in_place_with_conj(
				V.transpose(),
				Hr,
				Conj::Yes,
				A1,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(
					apply_block_householder_sequence_on_the_right_in_place_scratch::<c64>(size - 1, br, m),
				)),
			);

			let approx_eq = CwiseMat(ApproxEq::eps());
			// keep only the diagonal and superdiagonal, the rest of `UV` stores the reflectors
			for j in 0..n {
				for i in 0..m {
					if i > j || j > i + 1 {
						UV[(i, j)] = c64::ZERO;
					}
				}
			}

			assert!(UV ~ A);
		}
	}
}
567}