1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::householder::*;
4use linalg::matmul::{dot, matmul};
5
6pub fn bidiag_in_place_scratch<T: ComplexField>(nrows: usize, ncols: usize, par: Par, params: Spec<BidiagParams, T>) -> StackReq {
9 _ = par;
10 _ = params;
11 StackReq::all_of(&[temp_mat_scratch::<T>(nrows, 1), temp_mat_scratch::<T>(ncols, 1)])
12}
13
/// Tuning parameters for [`bidiag_in_place`].
#[derive(Debug, Copy, Clone)]
pub struct BidiagParams {
    /// Problem size under which parallelism is switched off.
    ///
    /// At step `k` of the reduction of an `m × n` matrix, if
    /// `(m - k - 1) * (n - k - 1)` is smaller than this value, the remaining work runs with
    /// `Par::Seq`.
    pub par_threshold: usize,
    // Hidden field; presumably present so that fields can be added later without breaking
    // callers that build this struct from `auto()`.
    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
22
23impl<T: ComplexField> Auto<T> for BidiagParams {
24 fn auto() -> Self {
25 Self {
26 par_threshold: 192 * 256,
27 non_exhaustive: NonExhaustive(()),
28 }
29 }
30}
31
32#[math]
42pub fn bidiag_in_place<T: ComplexField>(
43 A: MatMut<'_, T>,
44 H_left: MatMut<'_, T>,
45 H_right: MatMut<'_, T>,
46 par: Par,
47 stack: &mut MemStack,
48 params: Spec<BidiagParams, T>,
49) {
50 let params = params.config;
51 let m = A.nrows();
52 let n = A.ncols();
53 let size = Ord::min(m, n);
54 let bl = H_left.nrows();
55 let br = H_right.nrows();
56
57 assert!(H_left.ncols() == size);
58 assert!(H_right.ncols() == size.saturating_sub(1));
59
60 let (mut y, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
61 let (mut z, _) = unsafe { temp_mat_uninit(m, 1, stack) };
62
63 let mut y = y.as_mat_mut().col_mut(0).transpose_mut();
64 let mut z = z.as_mat_mut().col_mut(0);
65
66 let mut A = A;
67 let mut Hl = H_left;
68 let mut Hr = H_right;
69 let mut par = par;
70
71 {
72 let mut Hl = Hl.rb_mut().row_mut(0);
73 let mut Hr = Hr.rb_mut().row_mut(0);
74
75 for k in 0..size {
76 let mut A = A.rb_mut();
77
78 let (_, A01, A10, A11) = A.rb_mut().split_at_mut(k, k);
79
80 let (_, A02) = A01.split_first_col().unwrap();
81 let (A10, A20) = A10.split_first_row_mut().unwrap();
82 let (mut A11, A12, A21, mut A22) = A11.split_at_mut(1, 1);
83
84 let mut A12 = A12.row_mut(0);
85 let mut A21 = A21.col_mut(0);
86
87 let a11 = &mut A11[(0, 0)];
88
89 let (y1, mut y2) = y.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
90 let (z1, mut z2) = z.rb_mut().split_at_row_mut(k).1.split_at_row_mut(1);
91
92 let y1 = copy(y1[0]);
93 let z1 = copy(z1[0]);
94
95 if k > 0 {
96 let k1 = k - 1;
97
98 let up0 = copy(A10[k1]);
99 let up = A20.rb().col(k1);
100 let vp = A02.rb().row(k1);
101
102 *a11 = *a11 - up0 * y1 - z1;
103 z!(A21.rb_mut(), up.rb(), z2.rb()).for_each(|uz!(a, u, z)| *a = *a - *u * y1 - *z);
104 z!(A12.rb_mut(), y2.rb(), vp.rb()).for_each(|uz!(a, y, v)| *a = *a - up0 * *y - z1 * *v);
105 }
106
107 let HouseholderInfo { tau: tl, .. } = make_householder_in_place(a11, A21.rb_mut());
108 let tl_inv = recip(tl);
109 Hl[k] = from_real(tl);
110
111 if (m - k - 1) * (n - k - 1) < params.par_threshold {
112 par = Par::Seq;
113 }
114
115 if k > 0 {
116 let k1 = k - 1;
117
118 let up = A20.rb().col(k1);
119 let vp = A02.row(k1);
120
121 match par {
122 Par::Seq => bidiag_fused_op(A22.rb_mut(), A21.rb(), up.rb(), z2.rb(), y2.rb_mut(), vp.rb(), simd_align(k + 1)),
123 #[cfg(feature = "rayon")]
124 Par::Rayon(nthreads) => {
125 use rayon::prelude::*;
126 let nthreads = nthreads.get();
127
128 A22.rb_mut()
129 .par_col_partition_mut(nthreads)
130 .zip_eq(y2.rb_mut().par_partition_mut(nthreads))
131 .zip_eq(vp.par_partition(nthreads))
132 .for_each(|((A22, y2), vp)| {
133 bidiag_fused_op(A22, A21.rb(), up.rb(), z2.rb(), y2, vp.rb(), simd_align(k + 1));
134 });
135 },
136 }
137 } else {
138 matmul(y2.rb_mut(), Accum::Replace, A21.rb().adjoint(), A22.rb(), one(), par);
139 }
140
141 z!(y2.rb_mut(), A12.rb_mut()).for_each(|uz!(y, a)| {
142 *y = mul_real(*y + *a, tl_inv);
143 *a = *a - *y;
144 });
145 let norm = A12.rb().norm_l2();
146 let norm_inv = recip(norm);
147 if norm != zero() {
148 z!(A12.rb_mut()).for_each(|uz!(a)| *a = mul_real(a, norm_inv));
149 }
150 matmul(z2.rb_mut(), Accum::Replace, A22.rb(), A12.rb().adjoint(), one(), par);
151
152 if k + 1 == size {
153 break;
154 }
155
156 let (mut A12_a, mut A12_b) = A12.rb_mut().split_at_col_mut(1);
157 let A22_a = A22.rb().col(0);
158 let (y2_a, y2_b) = y2.rb().split_at_col(1);
159 let y2_a = &y2_a[0];
160
161 let a12_a = &mut A12_a[0];
162
163 let HouseholderInfo {
164 tau: tr,
165 head_with_beta_inv: m,
166 ..
167 } = make_householder_in_place(a12_a, A12_b.rb_mut().transpose_mut());
168 let tr_inv = recip(tr);
169 Hr[k] = from_real(tr);
170 let beta = copy(*a12_a);
171 *a12_a = mul_real(*a12_a, norm);
172
173 let b = *y2_a + dot::inner_prod(y2_b, Conj::No, A12_b.rb().transpose(), Conj::Yes);
174
175 if m != infinity() {
176 z!(z2.rb_mut(), A21.rb(), A22_a.rb()).for_each(|uz!(z, u, a)| {
177 let w = *z - *a * conj(beta);
178 let w = w * conj(m);
179 let w = w - *u * b;
180 *z = mul_real(w, tr_inv);
181 });
182 } else {
183 z!(z2.rb_mut(), A21.rb(), A22_a.rb()).for_each(|uz!(z, u, a)| {
184 let w = *a - *u * b;
185 *z = mul_real(w, tr_inv);
186 });
187 }
188 }
189 }
190
191 let mut j = 0;
192 while j < size {
193 let bl = Ord::min(bl, size - j);
194
195 let mut Hl = Hl.rb_mut().get_mut(..bl, j..j + bl);
196 for k in 0..bl {
197 Hl[(k, k)] = copy(Hl[(0, k)]);
198 }
199
200 upgrade_householder_factor(Hl.rb_mut(), A.rb().get(j.., j..j + bl), bl, 1, par);
201
202 j += bl;
203 }
204
205 if size > 0 {
206 let size = size - 1;
207 let A = A.rb().get(..size, 1..);
208
209 let mut Hr = Hr.rb_mut().get_mut(.., ..size);
210
211 let mut j = 0;
212 while j < size {
213 let br = Ord::min(br, size - j);
214
215 let mut Hr = Hr.rb_mut().get_mut(..br, j..j + br);
216
217 for k in 0..br {
218 Hr[(k, k)] = copy(Hr[(0, k)]);
219 }
220
221 upgrade_householder_factor(Hr.rb_mut(), A.transpose().get(j.., j..j + br), br, 1, par);
222 j += br;
223 }
224 }
225}
226
227#[math]
228fn bidiag_fused_op<T: ComplexField>(
229 A22: MatMut<'_, T>,
230 u: ColRef<'_, T>,
231 up: ColRef<'_, T>,
232 z: ColRef<'_, T>,
233 y: RowMut<'_, T>,
234 vp: RowRef<'_, T>,
235 align: usize,
236) {
237 let mut A22 = A22;
238
239 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
240 if let (Some(A22), Some(u), Some(up), Some(z)) = (
241 A22.rb_mut().try_as_col_major_mut(),
242 u.try_as_col_major(),
243 up.try_as_col_major(),
244 z.try_as_col_major(),
245 ) {
246 bidiag_fused_op_simd(A22, u, up, z, y, vp, align);
247 } else {
248 bidiag_fused_op_fallback(A22, u, up, z, y, vp);
249 }
250 } else {
251 bidiag_fused_op_fallback(A22, u, up, z, y, vp);
252 }
253}
254
255#[math]
256fn bidiag_fused_op_fallback<T: ComplexField>(
257 A22: MatMut<'_, T>,
258 u: ColRef<'_, T>,
259 up: ColRef<'_, T>,
260 z: ColRef<'_, T>,
261 y: RowMut<'_, T>,
262 vp: RowRef<'_, T>,
263) {
264 let mut A22 = A22;
265 let mut y = y;
266
267 matmul(A22.rb_mut(), Accum::Add, up, y.rb(), -one::<T>(), Par::Seq);
268 matmul(A22.rb_mut(), Accum::Add, z, vp, -one::<T>(), Par::Seq);
269 matmul(y.rb_mut(), Accum::Replace, u.adjoint(), A22.rb(), one(), Par::Seq);
270}
271
272#[math]
273fn bidiag_fused_op_simd<'M, 'N, T: ComplexField>(
274 A22: MatMut<'_, T, usize, usize, ContiguousFwd>,
275 u: ColRef<'_, T, usize, ContiguousFwd>,
276 up: ColRef<'_, T, usize, ContiguousFwd>,
277 z: ColRef<'_, T, usize, ContiguousFwd>,
278
279 y: RowMut<'_, T, usize>,
280 vp: RowRef<'_, T, usize>,
281
282 align: usize,
283) {
284 struct Impl<'a, 'M, 'N, T: ComplexField> {
285 A22: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
286 u: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
287 up: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
288 z: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
289
290 y: RowMut<'a, T, Dim<'N>>,
291 vp: RowRef<'a, T, Dim<'N>>,
292
293 align: usize,
294 }
295
296 impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
297 type Output = ();
298
299 #[inline(always)]
300 fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
301 let Self {
302 mut A22,
303 u,
304 up,
305 z,
306 mut y,
307 vp,
308 align,
309 } = self;
310
311 let m = A22.nrows();
312 let n = A22.ncols();
313 let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), m, align);
314 let (head, body4, body1, tail) = simd.batch_indices::<4>();
315
316 for j in n.indices() {
317 let mut a = A22.rb_mut().col_mut(j);
318
319 let mut acc0 = simd.zero();
320 let mut acc1 = simd.zero();
321 let mut acc2 = simd.zero();
322 let mut acc3 = simd.zero();
323
324 let yj = simd.splat(&-y[j]);
325 let vj = simd.splat(&-vp[j]);
326
327 if let Some(i0) = head {
328 let mut a0 = simd.read(a.rb(), i0);
329 a0 = simd.mul_add(simd.read(up, i0), yj, a0);
330 a0 = simd.mul_add(simd.read(z, i0), vj, a0);
331 simd.write(a.rb_mut(), i0, a0);
332
333 acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
334 }
335
336 for [i0, i1, i2, i3] in body4.clone() {
337 {
338 let mut a0 = simd.read(a.rb(), i0);
339 a0 = simd.mul_add(simd.read(up, i0), yj, a0);
340 a0 = simd.mul_add(simd.read(z, i0), vj, a0);
341 simd.write(a.rb_mut(), i0, a0);
342
343 acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
344 }
345 {
346 let mut a1 = simd.read(a.rb(), i1);
347 a1 = simd.mul_add(simd.read(up, i1), yj, a1);
348 a1 = simd.mul_add(simd.read(z, i1), vj, a1);
349 simd.write(a.rb_mut(), i1, a1);
350
351 acc1 = simd.conj_mul_add(simd.read(u, i1), a1, acc1);
352 }
353 {
354 let mut a2 = simd.read(a.rb(), i2);
355 a2 = simd.mul_add(simd.read(up, i2), yj, a2);
356 a2 = simd.mul_add(simd.read(z, i2), vj, a2);
357 simd.write(a.rb_mut(), i2, a2);
358
359 acc2 = simd.conj_mul_add(simd.read(u, i2), a2, acc2);
360 }
361 {
362 let mut a3 = simd.read(a.rb(), i3);
363 a3 = simd.mul_add(simd.read(up, i3), yj, a3);
364 a3 = simd.mul_add(simd.read(z, i3), vj, a3);
365 simd.write(a.rb_mut(), i3, a3);
366
367 acc3 = simd.conj_mul_add(simd.read(u, i3), a3, acc3);
368 }
369 }
370
371 for i0 in body1.clone() {
372 let mut a0 = simd.read(a.rb(), i0);
373 a0 = simd.mul_add(simd.read(up, i0), yj, a0);
374 a0 = simd.mul_add(simd.read(z, i0), vj, a0);
375 simd.write(a.rb_mut(), i0, a0);
376
377 acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
378 }
379 if let Some(i0) = tail {
380 let mut a0 = simd.read(a.rb(), i0);
381 a0 = simd.mul_add(simd.read(up, i0), yj, a0);
382 a0 = simd.mul_add(simd.read(z, i0), vj, a0);
383 simd.write(a.rb_mut(), i0, a0);
384
385 acc0 = simd.conj_mul_add(simd.read(u, i0), a0, acc0);
386 }
387
388 acc0 = simd.add(acc0, acc1);
389 acc2 = simd.add(acc2, acc3);
390 acc0 = simd.add(acc0, acc2);
391
392 y[j] = simd.reduce_sum(acc0);
393 }
394 }
395 }
396
397 with_dim!(M, A22.nrows());
398 with_dim!(N, A22.ncols());
399
400 dispatch!(
401 Impl {
402 A22: A22.as_shape_mut(M, N),
403 u: u.as_row_shape(M),
404 up: up.as_row_shape(M),
405 z: z.as_row_shape(M),
406 y: y.as_col_shape_mut(N),
407 vp: vp.as_col_shape(N),
408 align,
409 },
410 Impl,
411 T
412 )
413}
414
// Round-trip tests: bidiagonalize a random matrix, apply the stored left and right block
// reflectors to a copy of the original matrix, and compare the result with the bidiagonal
// part of the factorized matrix.
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use dyn_stack::MemBuffer;

    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Mat, assert, c64};

    // Real (`f64`) case, for a tall (8 × 4) and a square (8 × 8) matrix.
    #[test]
    fn test_bidiag_real() {
        let rng = &mut StdRng::seed_from_u64(0);

        for (m, n) in [(8, 4), (8, 8)] {
            let size = Ord::min(m, n);

            let A = CwiseMatDistribution {
                nrows: m,
                ncols: n,
                dist: StandardNormal,
            }
            .rand::<Mat<f64>>(rng);

            // Block sizes of the left and right Householder factors.
            let bl = 4;
            let br = 3;
            let mut Hl = Mat::zeros(bl, size);
            let mut Hr = Mat::zeros(br, size - 1);

            // `UV` receives the bidiagonal entries together with the reflector vectors.
            let mut UV = A.clone();
            bidiag_in_place(
                UV.rb_mut(),
                Hl.rb_mut(),
                Hr.rb_mut(),
                Par::Seq,
                MemStack::new(&mut [MaybeUninit::uninit(); 1024]),
                default(),
            );

            let mut A = A.clone();
            let mut A = A.as_mut();

            // Apply the left reflectors to the original matrix.
            // NOTE(review): the scratch request passes `(n - 1, 1, m)` although the factor
            // uses block size `bl`; confirm the argument meaning of the scratch function.
            apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
                UV.rb().get(.., ..size),
                Hl.rb(),
                Conj::Yes,
                A.rb_mut(),
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(
                    apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(n - 1, 1, m),
                )),
            );

            // The right reflectors are stored in the rows of `UV`, shifted by one column.
            let V = UV.rb().get(..size - 1, 1..size);
            let A1 = A.rb_mut().get_mut(.., 1..size);
            let Hr = Hr.as_ref();

            // NOTE(review): same remark as above, here with block size `br`.
            apply_block_householder_sequence_on_the_right_in_place_with_conj(
                V.transpose(),
                Hr.as_ref(),
                Conj::Yes,
                A1,
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(
                    apply_block_householder_sequence_on_the_right_in_place_scratch::<f64>(n - 1, 1, m),
                )),
            );

            // Comparator, presumably picked up by name by the `~` form of `assert!` below.
            let approx_eq = CwiseMat(ApproxEq::<f64>::eps());
            // Keep only the diagonal and superdiagonal of `UV`; everything else holds
            // reflector data.
            for j in 0..n {
                for i in 0..m {
                    if i > j || j > i + 1 {
                        UV[(i, j)] = 0.0;
                    }
                }
            }

            assert!(UV ~ A);
        }
    }

    // Complex (`c64`) case; same procedure as the real test.
    #[test]
    fn test_bidiag_cplx() {
        let rng = &mut StdRng::seed_from_u64(0);

        for (m, n) in [(8, 4), (8, 8)] {
            let size = Ord::min(m, n);
            let A = CwiseMatDistribution {
                nrows: m,
                ncols: n,
                dist: ComplexDistribution::new(StandardNormal, StandardNormal),
            }
            .rand::<Mat<c64>>(rng);

            // Block sizes of the left and right Householder factors.
            let bl = 4;
            let br = 3;
            let mut Hl = Mat::zeros(bl, size);
            let mut Hr = Mat::zeros(br, size - 1);

            let mut UV = A.clone();
            let mut UV = UV.as_mut();
            bidiag_in_place(
                UV.rb_mut(),
                Hl.rb_mut(),
                Hr.rb_mut(),
                Par::Seq,
                MemStack::new(&mut [MaybeUninit::uninit(); 1024]),
                default(),
            );

            let mut A = A.clone();
            let mut A = A.as_mut();

            // Apply the left reflectors to the original matrix.
            // NOTE(review): the scratch request passes `(n - 1, 1, m)` although the factor
            // uses block size `bl`; confirm the argument meaning of the scratch function.
            apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
                UV.rb().subcols(0, size),
                Hl.rb(),
                Conj::Yes,
                A.rb_mut(),
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(
                    apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n - 1, 1, m),
                )),
            );

            // The right reflectors are stored in the rows of `UV`, shifted by one column.
            let V = UV.rb().get(..size - 1, 1..size);
            let A1 = A.rb_mut().get_mut(.., 1..size);
            let Hr = Hr.rb();

            // NOTE(review): same remark as above, here with block size `br`.
            apply_block_householder_sequence_on_the_right_in_place_with_conj(
                V.transpose(),
                Hr,
                Conj::Yes,
                A1,
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(
                    apply_block_householder_sequence_on_the_right_in_place_scratch::<c64>(n - 1, 1, m),
                )),
            );

            // Comparator, presumably picked up by name by the `~` form of `assert!` below.
            let approx_eq = CwiseMat(ApproxEq::eps());
            // Keep only the diagonal and superdiagonal of `UV`; everything else holds
            // reflector data.
            for j in 0..n {
                for i in 0..m {
                    if i > j || j > i + 1 {
                        UV[(i, j)] = c64::ZERO;
                    }
                }
            }

            assert!(UV ~ A);
        }
    }
}