1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::householder::{self, HouseholderInfo};
4use linalg::matmul::triangular::BlockStructure;
5use linalg::matmul::{self, dot, matmul};
6use linalg::triangular_solve;
7
/// Tuning parameters for the Hessenberg reduction routines in this module.
#[derive(Copy, Clone, Debug)]
pub struct HessenbergParams {
    /// Size threshold for parallelism: in the rearranged unblocked kernel, once the squared
    /// dimension of the remaining trailing block drops below this value, the rest of the
    /// reduction runs with `Par::Seq`.
    pub par_threshold: usize,
    /// Size threshold for blocking: when `n * n` is below this value the unblocked kernel is
    /// used, otherwise the blocked kernel is used.
    pub blocking_threshold: usize,

    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
19
20impl<T: ComplexField> Auto<T> for HessenbergParams {
21 fn auto() -> Self {
22 Self {
23 par_threshold: 192 * 256,
24 blocking_threshold: 256 * 256,
25 non_exhaustive: NonExhaustive(()),
26 }
27 }
28}
29
30pub fn hessenberg_in_place_scratch<T: ComplexField>(dim: usize, blocksize: usize, par: Par, params: Spec<HessenbergParams, T>) -> StackReq {
33 let params = params.config;
34 let _ = par;
35 let n = dim;
36 if n * n < params.blocking_threshold {
37 StackReq::any_of(&[StackReq::all_of(&[
38 temp_mat_scratch::<T>(n, 1).array(3),
39 temp_mat_scratch::<T>(n, par.degree()),
40 ])])
41 } else {
42 StackReq::all_of(&[
43 temp_mat_scratch::<T>(n, blocksize),
44 temp_mat_scratch::<T>(blocksize, 1),
45 StackReq::any_of(&[
46 StackReq::all_of(&[temp_mat_scratch::<T>(n, 1), temp_mat_scratch::<T>(n, par.degree())]),
47 temp_mat_scratch::<T>(n, blocksize),
48 ]),
49 ])
50 }
51}
52
/// Vectorized variant of the fused update used by the rearranged unblocked kernel.
///
/// For every column `j` of `A`, in a single pass over the column, it:
/// 1. updates the column in place with `l0 * (-r0[j])` and the `conj_mul_add` of `-r1[j]`
///    with `l1` (presumably `conj(-r1[j]) * l1` — confirm against `SimdCtx::conj_mul_add`),
/// 2. accumulates the `conj_mul_add` of `l_in` with the updated column into `l_out[j]`,
/// 3. accumulates `updated column * r_in[j]` into `r_out`.
///
/// `r_out` is zeroed before the column loop, so on return it holds the full accumulation.
/// The intent appears to be the same result as `hessenberg_fused_op_fallback`.
///
/// `align` is forwarded to `SimdCtx::new_align`
/// (NOTE(review): presumably the alignment offset of the first row — confirm).
#[math]
fn hessenberg_fused_op_simd<T: ComplexField>(
    A: MatMut<'_, T, usize, usize, ContiguousFwd>,

    l_out: RowMut<'_, T, usize>,
    r_out: ColMut<'_, T, usize, ContiguousFwd>,
    l_in: RowRef<'_, T, usize, ContiguousFwd>,
    r_in: ColRef<'_, T, usize>,

    l0: ColRef<'_, T, usize, ContiguousFwd>,
    l1: ColRef<'_, T, usize, ContiguousFwd>,
    r0: RowRef<'_, T, usize>,
    r1: RowRef<'_, T, usize>,
    align: usize,
) {
    // Carrier for the operands, re-typed with branded dimensions `'M` (rows) and `'N`
    // (columns) so it can be handed to the SIMD dispatcher.
    struct Impl<'a, 'M, 'N, T: ComplexField> {
        A: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,

        l_out: RowMut<'a, T, Dim<'N>>,
        r_out: ColMut<'a, T, Dim<'M>, ContiguousFwd>,
        l_in: RowRef<'a, T, Dim<'M>, ContiguousFwd>,
        r_in: ColRef<'a, T, Dim<'N>>,

        l0: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
        l1: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
        r0: RowRef<'a, T, Dim<'N>>,
        r1: RowRef<'a, T, Dim<'N>>,
        align: usize,
    }

    impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
            let Self {
                mut A,
                mut l_out,
                mut r_out,
                l_in,
                r_in,
                l0,
                l1,
                r0,
                r1,
                align,
            } = self;

            let (m, n) = A.shape();

            let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), m, align);

            // zero `r_out` (head, body and tail chunks) so the column loop can accumulate into it
            {
                let (head, body, tail) = simd.indices();
                if let Some(i) = head {
                    simd.write(r_out.rb_mut(), i, simd.zero());
                }
                for i in body {
                    simd.write(r_out.rb_mut(), i, simd.zero());
                }
                if let Some(i) = tail {
                    simd.write(r_out.rb_mut(), i, simd.zero());
                }
            }

            // row chunks: optional head, groups of four body chunks, leftover single body
            // chunks, optional tail
            let (head, body4, body1, tail) = simd.batch_indices::<4>();

            // view `l_in` as a column so it can be read with the same row indices as `A`
            let l_in = l_in.transpose();

            for j in n.indices() {
                let mut A = A.rb_mut().col_mut(j);
                // per-column scalars broadcast to full registers; `r0`/`r1` are negated once
                // here so the updates below are plain multiply-adds
                let r_in = simd.splat(r_in.at(j));
                let r0 = simd.splat(&(-r0[j]));
                let r1 = simd.splat(&(-r1[j]));

                // four independent accumulators for `l_out[j]`; `acc1..acc3` are only used by
                // the 4-way unrolled loop
                let mut acc0 = simd.zero();
                let mut acc1 = simd.zero();
                let mut acc2 = simd.zero();
                let mut acc3 = simd.zero();

                // every chunk below performs the same steps: update `A`, store it, accumulate
                // into `l_out[j]`, accumulate into `r_out`
                if let Some(i0) = head {
                    let mut a0 = simd.read(A.rb(), i0);
                    a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
                    a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
                    simd.write(A.rb_mut(), i0, a0);
                    acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
                    let tmp = simd.read(r_out.rb(), i0);
                    simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
                }
                for [i0, i1, i2, i3] in body4.clone() {
                    {
                        let mut a0 = simd.read(A.rb(), i0);
                        a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
                        a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
                        simd.write(A.rb_mut(), i0, a0);
                        acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
                        let tmp = simd.read(r_out.rb(), i0);
                        simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
                    }
                    {
                        let mut a1 = simd.read(A.rb(), i1);
                        a1 = simd.mul_add(simd.read(l0, i1), r0, a1);
                        a1 = simd.conj_mul_add(r1, simd.read(l1, i1), a1);
                        simd.write(A.rb_mut(), i1, a1);
                        acc1 = simd.conj_mul_add(simd.read(l_in, i1), a1, acc1);
                        let tmp = simd.read(r_out.rb(), i1);
                        simd.write(r_out.rb_mut(), i1, simd.mul_add(a1, r_in, tmp));
                    }
                    {
                        let mut a2 = simd.read(A.rb(), i2);
                        a2 = simd.mul_add(simd.read(l0, i2), r0, a2);
                        a2 = simd.conj_mul_add(r1, simd.read(l1, i2), a2);
                        simd.write(A.rb_mut(), i2, a2);
                        acc2 = simd.conj_mul_add(simd.read(l_in, i2), a2, acc2);
                        let tmp = simd.read(r_out.rb(), i2);
                        simd.write(r_out.rb_mut(), i2, simd.mul_add(a2, r_in, tmp));
                    }
                    {
                        let mut a3 = simd.read(A.rb(), i3);
                        a3 = simd.mul_add(simd.read(l0, i3), r0, a3);
                        a3 = simd.conj_mul_add(r1, simd.read(l1, i3), a3);
                        simd.write(A.rb_mut(), i3, a3);
                        acc3 = simd.conj_mul_add(simd.read(l_in, i3), a3, acc3);
                        let tmp = simd.read(r_out.rb(), i3);
                        simd.write(r_out.rb_mut(), i3, simd.mul_add(a3, r_in, tmp));
                    }
                }
                for i0 in body1.clone() {
                    let mut a0 = simd.read(A.rb(), i0);
                    a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
                    a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
                    simd.write(A.rb_mut(), i0, a0);
                    acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
                    let tmp = simd.read(r_out.rb(), i0);
                    simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
                }
                if let Some(i0) = tail {
                    let mut a0 = simd.read(A.rb(), i0);
                    a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
                    a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
                    simd.write(A.rb_mut(), i0, a0);
                    acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
                    let tmp = simd.read(r_out.rb(), i0);
                    simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
                }

                // pairwise combination of the partial accumulators, then a horizontal sum
                acc0 = simd.add(acc0, acc1);
                acc2 = simd.add(acc2, acc3);
                acc0 = simd.add(acc0, acc2);

                let l_out = l_out.rb_mut().at_mut(j);
                *l_out = simd.reduce_sum(acc0);
            }
        }
    }

    // brand the runtime dimensions so all operands can be checked against the same `M`/`N`
    with_dim!(M, A.nrows());
    with_dim!(N, A.ncols());

    dispatch!(
        Impl {
            A: A.as_shape_mut(M, N),
            l_out: l_out.as_col_shape_mut(N),
            r_out: r_out.as_row_shape_mut(M),
            l_in: l_in.as_col_shape(M),
            r_in: r_in.as_row_shape(N),
            l0: l0.as_row_shape(M),
            l1: l1.as_row_shape(M),
            r0: r0.as_col_shape(N),
            r1: r1.as_col_shape(N),
            align,
        },
        Impl,
        T
    )
}
229
230#[math]
231fn hessenberg_fused_op_fallback<T: ComplexField>(
232 A: MatMut<'_, T>,
233
234 l_out: RowMut<'_, T>,
235 r_out: ColMut<'_, T>,
236 l_in: RowRef<'_, T>,
237 r_in: ColRef<'_, T>,
238
239 l0: ColRef<'_, T>,
240 l1: ColRef<'_, T>,
241 r0: RowRef<'_, T>,
242 r1: RowRef<'_, T>,
243) {
244 let mut A = A;
245
246 matmul(A.rb_mut(), Accum::Add, l0.as_mat(), r0.as_mat(), -one::<T>(), Par::Seq);
247 matmul(A.rb_mut(), Accum::Add, l1.as_mat(), r1.as_mat().conjugate(), -one::<T>(), Par::Seq);
248
249 matmul(r_out.as_mat_mut(), Accum::Replace, A.rb(), r_in.as_mat(), one(), Par::Seq);
250 matmul(l_out.as_mat_mut(), Accum::Replace, l_in.as_mat().conjugate(), A.rb(), one(), Par::Seq);
251}
252
253fn hessenberg_fused_op<T: ComplexField>(
254 A: MatMut<'_, T>,
255
256 l_out: RowMut<'_, T>,
257 r_out: ColMut<'_, T>,
258 l_in: RowRef<'_, T>,
259 r_in: ColRef<'_, T>,
260
261 l0: ColRef<'_, T>,
262 l1: ColRef<'_, T>,
263 r0: RowRef<'_, T>,
264 r1: RowRef<'_, T>,
265 align: usize,
266) {
267 let mut A = A;
268 let mut r_out = r_out;
269
270 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
271 if let (Some(A), Some(r_out), Some(l_in), Some(l0), Some(l1)) = (
272 A.rb_mut().try_as_col_major_mut(),
273 r_out.rb_mut().try_as_col_major_mut(),
274 l_in.try_as_row_major(),
275 l0.try_as_col_major(),
276 l1.try_as_col_major(),
277 ) {
278 hessenberg_fused_op_simd(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1, align);
279 } else {
280 hessenberg_fused_op_fallback(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1);
281 }
282 } else {
283 hessenberg_fused_op_fallback(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1);
284 }
285}
286
/// Unblocked Hessenberg reduction of the square matrix `A`, performed in place.
///
/// One Householder reflector is generated per column. Its `tau` value is first stored in
/// row 0 of `H`; after the main loop, `H` is converted block by block (block size
/// `H.nrows()`) into block Householder factors via `householder::upgrade_householder_factor`.
///
/// The two-sided update of the trailing block by reflector `k - 1` is not applied
/// immediately: it is kept in the vectors `y`/`z` and applied during iteration `k`, with the
/// trailing block update fused with the products needed for the next reflector
/// (see `hessenberg_fused_op`).
///
/// # Panics
/// Panics if `A` is not square or if `H.ncols() != A.ncols().saturating_sub(1)`.
#[math]
fn hessenberg_rearranged_unblocked<T: ComplexField>(A: MatMut<'_, T>, H: MatMut<'_, T>, par: Par, stack: &mut MemStack, params: HessenbergParams) {
    assert!(all(A.nrows() == A.ncols(), H.ncols() == A.ncols().saturating_sub(1)));

    let n = A.nrows();
    let b = H.nrows();

    if n == 0 {
        return;
    }

    let mut A = A;
    let mut H = H;
    let mut par = par;

    {
        // during the main loop only the first row of `H` is written (one `tau` per column)
        let mut H = H.rb_mut().row_mut(0);
        // uninitialized scratch: `y`, `z`, `v` are vectors of length `n`, `w` is
        // `n × par.degree()`
        let (mut y, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
        let (mut z, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
        let (mut v, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
        let (mut w, _) = unsafe { temp_mat_uninit(n, par.degree(), stack) };

        // `y` and `v` are used as rows, `z` as a column
        let mut y = y.as_mat_mut().col_mut(0).transpose_mut();
        let mut z = z.as_mat_mut().col_mut(0);
        let mut v = v.as_mat_mut().col_mut(0).transpose_mut();
        let mut w = w.as_mat_mut();

        for k in 0..n {
            // partition `A` around the diagonal entry `(k, k)`
            let (_, A01, A10, A11) = A.rb_mut().split_at_mut(k, k);

            // `A02`: rows above `k`, columns after `k`; `A20`: rows below `k`, columns before `k`
            let (_, mut A02) = A01.split_first_col_mut().unwrap();
            let (_, A20) = A10.split_first_row_mut().unwrap();
            let (mut A11, A12, A21, mut A22) = A11.split_at_mut(1, 1);

            // row to the right of, and column below, the diagonal entry
            let mut A12 = A12.row_mut(0);
            let mut A21 = A21.col_mut(0);

            let A11 = &mut A11[(0, 0)];

            // `y1`/`z1`: entry `k` of `y`/`z`; `y2`/`z2`: the entries after `k`
            let (y1, mut y2) = y.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
            let y1 = copy(y1[0]);

            let (z1, mut z2) = z.rb_mut().split_at_row_mut(k).1.split_at_row_mut(1);
            let z1 = copy(z1[0]);

            let (_, mut v2) = v.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
            let (mut w0, w12) = w.rb_mut().split_at_row_mut(k);
            let (_, mut w2) = w12.split_at_row_mut(1);

            if k > 0 {
                // apply the pending update from the previous reflector (column `k - 1` of
                // `A20`) to the diagonal entry, the row `A12` and the column `A21`; the
                // trailing block `A22` is updated later inside `hessenberg_fused_op`
                let p = k - 1;
                let u2 = A20.rb().col(p);

                *A11 = *A11 - y1 - z1;
                z!(&mut A12, &y2, u2.rb().transpose()).for_each(|uz!(a, y, u)| *a = *a - *y - z1 * conj(*u));
                z!(&mut A21, &u2, &z2).for_each(|uz!(a, u, z)| *a = *a - *u * y1 - *z);
            }

            {
                // once the remaining trailing block is small, switch to sequential execution
                // for the rest of the reduction (`par` is never switched back)
                let n = n - k - 1;
                if n * n < params.par_threshold {
                    par = Par::Seq;
                }
            }

            // the last column has nothing below the diagonal to eliminate
            if k + 1 == n {
                break;
            }

            let beta;
            let tau_inv;
            {
                // build the reflector from the column below the diagonal: head entry + tail
                let (mut A11, mut A21) = A21.rb_mut().split_at_row_mut(1);
                let A11 = &mut A11[0];

                let HouseholderInfo { tau, .. } = householder::make_householder_in_place(A11, A21.rb_mut());
                tau_inv = recip(tau);
                // save the head entry and temporarily overwrite it with one so that `A21` can
                // be used as the full reflector vector; it is restored at the end of the
                // iteration
                beta = copy(*A11);
                *A11 = one();

                H[k] = from_real(tau);
            }

            let x2 = A21.rb();

            if k > 0 {
                // fused step: finish the previous reflector's update of `A22` and compute
                // the products of the updated `A22` with the new reflector `x2` into `v2`
                // (row) and column 0 of `w2`
                let p = k - 1;
                let u2 = A20.rb().col(p);
                hessenberg_fused_op(
                    A22.rb_mut(),
                    v2.rb_mut(),
                    w2.rb_mut().col_mut(0),
                    x2.transpose(),
                    x2,
                    u2,
                    z2.rb(),
                    y2.rb(),
                    u2.transpose(),
                    simd_align(k + 1),
                );
                y2.copy_from(v2.rb());
                z2.copy_from(w2.rb().col(0));
            } else {
                // first iteration: no pending update, compute the two products directly
                matmul(z2.rb_mut().as_mat_mut(), Accum::Replace, A22.rb(), x2.as_mat(), one(), par);
                matmul(y2.rb_mut().as_mat_mut(), Accum::Replace, x2.adjoint().as_mat(), A22.rb(), one(), par);
            }

            let u2 = x2;

            // correct and scale `y2`/`z2` using the inner product of `u2` and `z2`
            // (first operand conjugated), halved and scaled by `1 / tau`
            let b = mul_real(
                mul_pow2(dot::inner_prod(u2.rb().transpose(), Conj::Yes, z2.rb(), Conj::No), from_f64(0.5)),
                tau_inv,
            );
            z!(&mut y2, u2.transpose()).for_each(|uz!(y, u)| *y = mul_real(*y - b * conj(*u), tau_inv));
            z!(&mut z2, u2).for_each(|uz!(z, u)| *z = mul_real(*z - b * *u, tau_inv));

            // update the row `A12` with the new reflector
            let dot = mul_real(dot::inner_prod(A12.rb(), Conj::No, u2.rb(), Conj::No), tau_inv);
            z!(&mut A12, u2.transpose()).for_each(|uz!(a, u)| *a = *a - dot * conj(u));

            // update the rows above the diagonal: `w0 = A02 * u2`, then
            // `A02 -= (1 / tau) * w0 * u2^H`
            matmul(w0.rb_mut().col_mut(0).as_mat_mut(), Accum::Replace, A02.rb(), u2.as_mat(), one(), par);
            matmul(
                A02.rb_mut(),
                Accum::Add,
                w0.rb().col(0).as_mat(),
                u2.adjoint().as_mat(),
                -from_real::<T>(&tau_inv),
                par,
            );

            // restore the head entry that was overwritten with one
            A21[0] = beta;
        }
    }

    if n > 0 {
        // convert the per-column `tau` values into block Householder factors, `b` columns at
        // a time; the reflectors are stored below the first subdiagonal of `A`
        let n = n - 1;
        let A = A.rb().submatrix(1, 0, n, n);
        let mut H = H.rb_mut().subcols_mut(0, n);

        let mut j = 0;
        while j < n {
            let b = Ord::min(b, n - j);

            let mut H = H.rb_mut().submatrix_mut(0, j, b, b);

            // move each `tau` from row 0 onto the diagonal of its block
            for k in 0..b {
                H[(k, k)] = copy(H[(0, k)]);
            }

            householder::upgrade_householder_factor(H.rb_mut(), A.submatrix(j, j, n - j, b), b, 1, par);
            j += b;
        }
    }
}
440
/// Panel kernel of the blocked Hessenberg reduction.
///
/// Processes the first `H.nrows()` columns of `A` one at a time. For each column `k` it
/// applies the previously generated reflectors of the panel to that column, generates a new
/// reflector from the part of the column below the diagonal, and fills in column `k` of `Z`
/// and of the upper triangular factor stored in `H`.
///
/// - `A`: matrix whose leading columns are reduced; the reflectors overwrite the entries
///   below the diagonal of the processed columns.
/// - `Z`: receives, per column, the product of the remaining columns of `A` with the new
///   reflector.
/// - `H`: receives the triangular Householder factor of the panel (`tau` on the diagonal).
/// - `beta`: receives the head entry produced by `make_householder_in_place` for each
///   column, since that entry is overwritten with one in `A`.
/// - `params`: unused.
#[math]
fn hessenberg_gqvdg_unblocked<T: ComplexField>(
    A: MatMut<'_, T>,
    Z: MatMut<'_, T>,
    H: MatMut<'_, T>,
    beta: ColMut<'_, T>,
    par: Par,
    stack: &mut MemStack,
    params: HessenbergParams,
) {
    let n = A.nrows();
    let b = H.nrows();
    let mut A = A;
    let mut H = H;
    let mut Z = Z;
    _ = params;

    // uninitialized scratch column; only its first `k` entries are used in iteration `k`
    let (mut x, _) = unsafe { temp_mat_uninit(n, 1, stack) };
    let mut x = x.as_mat_mut().col_mut(0);
    let mut beta = beta;

    for k in 0..b {
        let mut x0 = x.rb_mut().subrows_mut(0, k);
        // `T00`: leading `k × k` block of the factor; `T01`: the part of column `k` above
        // the diagonal; `T11`: diagonal entry `(k, k)`
        let (T00, T01, _, T11) = H.rb_mut().split_at_mut(k, k);
        let (mut T01, _) = T01.split_first_col_mut().unwrap();
        let (mut T11, _, _, _) = T11.split_at_mut(1, 1);

        let T11 = &mut T11[(0, 0)];

        // `U0`: the `k` columns already processed; `A1`: current column; `A2`: the rest
        let (U0, A12) = A.rb_mut().split_at_col_mut(k);
        let (mut A1, A2) = A12.split_first_col_mut().unwrap();

        let (Z0, Z12) = Z.rb_mut().split_at_col_mut(k);
        let (mut Z1, _) = Z12.split_first_col_mut().unwrap();

        let U0 = U0.rb();
        let Z0 = Z0.rb();
        let T00 = T00.rb();

        // row split of `U0`: rows above `k`, row `k`, rows below `k`
        let (U00, U10) = U0.split_at_row(k);
        let (U10, U20) = U10.split_first_row().unwrap();

        // `A1 -= Z0 * (T00^{-1} * U10^H)`
        x0.copy_from(U10.adjoint());
        triangular_solve::solve_upper_triangular_in_place(T00, x0.rb_mut().as_mat_mut(), par);
        matmul::matmul(A1.rb_mut().as_mat_mut(), Accum::Add, Z0, x0.rb().as_mat(), -one::<T>(), par);

        let (mut A01, A11) = A1.rb_mut().split_at_row_mut(k);
        let (mut A11, mut A21) = A11.split_at_row_mut(1);
        let A11 = &mut A11[0];

        {
            // `x0 = U0^H * A1`, assembled from the three row blocks (`U00` contributes only
            // its strictly triangular part)
            matmul::triangular::matmul(
                x0.rb_mut().as_mat_mut(),
                BlockStructure::Rectangular,
                Accum::Replace,
                U00.adjoint(),
                BlockStructure::StrictTriangularUpper,
                A01.rb().as_mat(),
                BlockStructure::Rectangular,
                one(),
                par,
            );
            z!(x0.rb_mut(), U10.transpose()).for_each(|uz!(x, u)| *x = *x + *A11 * conj(*u));
            matmul::matmul(x0.rb_mut().as_mat_mut(), Accum::Add, U20.adjoint(), A21.rb().as_mat(), one(), par);
        }
        {
            // `x0 = T00^{-H} * x0`
            triangular_solve::solve_lower_triangular_in_place(T00.adjoint(), x0.rb_mut().as_mat_mut(), par);
        }
        {
            // `A1 -= U0 * x0`, again block by block
            matmul::triangular::matmul(
                A01.rb_mut().as_mat_mut(),
                BlockStructure::Rectangular,
                Accum::Add,
                U00,
                BlockStructure::StrictTriangularLower,
                x0.rb().as_mat(),
                BlockStructure::Rectangular,
                -one::<T>(),
                par,
            );
            *A11 = *A11 - dot::inner_prod(U10, Conj::No, x0.rb(), Conj::No);
            matmul::matmul(A21.rb_mut().as_mat_mut(), Accum::Add, U20, x0.rb().as_mat(), -one::<T>(), par);
        }

        if k + 1 < n {
            // generate the reflector for this column from the entries below the diagonal
            let (mut A11, mut A21) = A21.rb_mut().split_at_row_mut(1);
            let A11 = &mut A11[0];

            let HouseholderInfo { tau, .. } = householder::make_householder_in_place(A11, A21.rb_mut());

            // keep the head entry in `beta` and overwrite it with one so `A21` holds the
            // full reflector vector
            beta[k] = copy(A11);
            *A11 = one();
            *T11 = from_real(tau);
        } else {
            // nothing below the diagonal: no reflector for this column
            *T11 = infinity();
        }

        // `Z1 = A2 * A21`
        matmul::matmul(Z1.rb_mut().as_mat_mut(), Accum::Replace, A2.rb(), A21.rb().as_mat(), one(), par);

        // `T01 = U20^H * A21`: off-diagonal part of column `k` of the factor
        matmul::matmul(T01.rb_mut().as_mat_mut(), Accum::Replace, U20.adjoint(), A21.rb().as_mat(), one(), par);
    }
}
543
544#[track_caller]
551pub fn hessenberg_in_place<T: ComplexField>(
552 A: MatMut<'_, T>,
553 householder: MatMut<'_, T>,
554 par: Par,
555 stack: &mut MemStack,
556 params: Spec<HessenbergParams, T>,
557) {
558 let params = params.config;
559 assert!(all(A.nrows() == A.ncols(), householder.ncols() == A.ncols().saturating_sub(1)));
560
561 let n = A.nrows().unbound();
562
563 if n * n < params.blocking_threshold {
564 hessenberg_rearranged_unblocked(A, householder, par, stack, params);
565 } else {
566 hessenberg_gqvdg_blocked(A, householder, par, stack, params);
567 }
568}
569
/// Blocked Hessenberg reduction of the square matrix `A`, performed in place.
///
/// The matrix is processed in column panels of width at most `H.nrows()`. Each panel is
/// reduced by `hessenberg_gqvdg_unblocked`, which also produces the panel's triangular
/// Householder factor (stored in the matching block of `H`) and the matrix `Z`. The rest of
/// the matrix is then updated with matrix products and triangular solves.
///
/// NOTE(review): no shape assertions are made here; callers are presumably expected to pass
/// a square `A` and an `H` with `A.ncols() - 1` columns, as checked by `hessenberg_in_place`.
#[math]
fn hessenberg_gqvdg_blocked<T: ComplexField>(A: MatMut<'_, T>, H: MatMut<'_, T>, par: Par, stack: &mut MemStack, params: HessenbergParams) {
    let n = A.nrows();
    let b = H.nrows();
    let mut A = A;
    let mut H = H;
    // uninitialized `n × b` scratch, filled panel by panel by the unblocked kernel
    let (mut Z, stack) = unsafe { temp_mat_uninit(n, b, stack) };
    let mut Z = Z.as_mat_mut();

    let mut j = 0;
    while j < n {
        // `bs`: panel width; `bs_u`: number of reflectors in the panel (the last column of
        // the matrix has none)
        let bs = Ord::min(b, n - j);
        let bs_u = Ord::min(bs, n - j - 1);

        // head entries of the reflectors, restored onto the subdiagonal at the end of the
        // iteration
        let (mut beta, stack) = unsafe { temp_mat_uninit(bs, 1, stack) };
        let mut beta = beta.as_mat_mut().col_mut(0);

        {
            let mut T11 = H.rb_mut().submatrix_mut(0, j, bs_u, bs_u);
            {
                // reduce the panel: operates on the trailing `(n - j) × (n - j)` block
                let A11 = A.rb_mut().submatrix_mut(j, j, n - j, n - j);
                let Z1 = Z.rb_mut().submatrix_mut(j, 0, n - j, bs);

                hessenberg_gqvdg_unblocked(A11, Z1, T11.rb_mut(), beta.rb_mut(), par, stack, params);
            }

            // uninitialized `n × bs_u` scratch; rows `0..j` and rows `j + bs_u..` are used
            // for two separate products below
            let (mut X, _) = unsafe { temp_mat_uninit(n, bs_u, stack) };
            let mut X = X.as_mat_mut();

            let (mut X0, X12) = X.rb_mut().split_at_row_mut(j);
            let (_, mut X2) = X12.split_at_row_mut(bs_u);

            let (_, Z12) = Z.rb_mut().subcols_mut(0, bs_u).split_at_row_mut(j);
            let (mut Z1, mut Z2) = Z12.split_at_row_mut(bs_u);

            // partition `A` at `(j, j)`, then split the panel columns from the remaining ones
            let (_, A01, _, A11) = A.rb_mut().split_at_mut(j, j);
            let (mut A01, mut A02) = A01.split_at_col_mut(bs_u);
            let (A11, mut A12, A21, mut A22) = A11.split_at_mut(bs_u, bs_u);

            // reflectors of the panel: `U1` is the square top block (used through its
            // strictly lower triangular part), `U2` the rows below it
            let U1 = A11.rb();
            let U2 = A21.rb();

            let T1 = T11.rb();

            // rows above the panel: `X0 = [A01 A02] * U`, `X0 <- X0 * T1^{-1}`,
            // `[A01 A02] -= X0 * U^H`
            matmul::triangular::matmul(
                X0.rb_mut(),
                BlockStructure::Rectangular,
                Accum::Replace,
                A01.rb(),
                BlockStructure::Rectangular,
                U1,
                BlockStructure::StrictTriangularLower,
                one(),
                par,
            );
            matmul::matmul(X0.rb_mut(), Accum::Add, A02.rb(), U2, one(), par);

            // solving with the transposed system on the transposed view multiplies `X0` by
            // the inverse of `T1` from the right
            triangular_solve::solve_lower_triangular_in_place(T1.transpose(), X0.rb_mut().transpose_mut(), par);

            matmul::triangular::matmul(
                A01.rb_mut(),
                BlockStructure::Rectangular,
                Accum::Add,
                X0.rb(),
                BlockStructure::Rectangular,
                U1.adjoint(),
                BlockStructure::StrictTriangularUpper,
                -one::<T>(),
                par,
            );
            matmul::matmul(A02.rb_mut(), Accum::Add, X0.rb(), U2.adjoint(), -one::<T>(), par);

            // columns to the right of the panel: `Z <- Z * T1^{-1}`, then
            // `[A12; A22] -= Z * U2^H`
            triangular_solve::solve_lower_triangular_in_place(T1.transpose(), Z1.rb_mut().transpose_mut(), par);
            triangular_solve::solve_lower_triangular_in_place(T1.transpose(), Z2.rb_mut().transpose_mut(), par);

            matmul::matmul(A12.rb_mut(), Accum::Add, Z1.rb(), U2.adjoint(), -one::<T>(), par);
            matmul::matmul(A22.rb_mut(), Accum::Add, Z2.rb(), U2.adjoint(), -one::<T>(), par);

            // same columns, update from the left: `X = U^H * [A12; A22]`,
            // `X <- T1^{-H} * X`, `[A12; A22] -= U * X`
            let mut X = X2.rb_mut().transpose_mut();

            matmul::triangular::matmul(
                X.rb_mut(),
                BlockStructure::Rectangular,
                Accum::Replace,
                U1.adjoint(),
                BlockStructure::StrictTriangularUpper,
                A12.rb(),
                BlockStructure::Rectangular,
                one(),
                par,
            );
            matmul::matmul(X.rb_mut(), Accum::Add, U2.adjoint(), A22.rb(), one(), par);

            triangular_solve::solve_lower_triangular_in_place(T1.adjoint(), X.rb_mut(), par);

            matmul::triangular::matmul(
                A12.rb_mut(),
                BlockStructure::Rectangular,
                Accum::Add,
                U1,
                BlockStructure::StrictTriangularLower,
                X.rb(),
                BlockStructure::Rectangular,
                -one::<T>(),
                par,
            );
            matmul::matmul(A22.rb_mut(), Accum::Add, U2, X.rb(), -one::<T>(), par);
        }

        // write the saved head entries back onto the first subdiagonal of the panel, where
        // the unblocked kernel had stored ones
        let n = n - j;
        let mut A = A.rb_mut().submatrix_mut(j, j, n, bs);
        for k in 0..bs {
            if k + 1 < n {
                A[(k + 1, k)] = copy(beta[k]);
            }
        }

        j += bs;
    }
}
690
// Tests for the Hessenberg kernels. Each test reduces a random matrix, then applies the
// stored Householder sequence to a copy of the input from both sides and checks that the
// result matches the Hessenberg part of the reduced matrix.
#[cfg(test)]
mod tests {
    use dyn_stack::MemBuffer;
    use std::mem::MaybeUninit;

    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Mat, assert, c64};

    /// Unblocked kernel on real matrices, sequential execution, default parameters.
    #[test]
    fn test_hessenberg_real() {
        let rng = &mut StdRng::seed_from_u64(0);

        for n in [3, 4, 8, 16] {
            let A = CwiseMatDistribution {
                nrows: n,
                ncols: n,
                dist: StandardNormal,
            }
            .rand::<Mat<f64>>(rng);

            // block size of the Householder factors
            let b = 3;
            let mut H = Mat::zeros(b, n - 1);

            // `V` is reduced in place; `A` keeps the original input
            let mut V = A.clone();
            let mut V = V.as_mut();
            hessenberg_rearranged_unblocked(
                V.rb_mut(),
                H.as_mut(),
                Par::Seq,
                MemStack::new(&mut [MaybeUninit::uninit(); 1024]),
                auto!(f64),
            );

            let mut A = A.clone();
            let mut A = A.as_mut();

            // first pass applies the reflectors to `A` from the left; the second pass works
            // on the transposed view, which corresponds to applying them from the right
            for iter in 0..2 {
                let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

                let n = n - 1;

                // the reflectors are stored below the first subdiagonal of `V` and act on
                // rows `1..` of the matrix
                let V = V.rb().submatrix(1, 0, n, n);
                let mut A = A.rb_mut().subrows_mut(1, n);
                let H = H.as_ref();

                householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
                    V,
                    H.as_ref(),
                    if iter == 0 { Conj::Yes } else { Conj::No },
                    A.rb_mut(),
                    Par::Seq,
                    MemStack::new(&mut MemBuffer::new(
                        householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(n, b, n + 1),
                    )),
                );
            }

            // NOTE(review): `approx_eq` is not referenced explicitly; presumably the `~`
            // comparison in `assert!` picks it up by name — confirm against the macro
            let approx_eq = CwiseMat(ApproxEq::<f64>::eps());
            // clear the reflector storage below the first subdiagonal so only the
            // Hessenberg part of `V` is compared
            for j in 0..n {
                for i in 0..n {
                    if i > j + 1 {
                        V[(i, j)] = 0.0;
                    }
                }
            }

            assert!(V ~ A);
        }
    }

    /// Unblocked kernel on complex matrices, sequential and parallel, with
    /// `par_threshold` set to zero so the kernel never switches to sequential execution.
    #[test]
    fn test_hessenberg_cplx() {
        let rng = &mut StdRng::seed_from_u64(0);

        for n in [1, 2, 3, 4, 8, 16] {
            for par in [Par::Seq, Par::rayon(4)] {
                let A = CwiseMatDistribution {
                    nrows: n,
                    ncols: n,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Mat<c64>>(rng);

                // block size of the Householder factors
                let b = 3;
                let mut H = Mat::zeros(b, n - 1);

                // `V` is reduced in place; `A` keeps the original input
                let mut V = A.clone();
                let mut V = V.as_mut();
                hessenberg_rearranged_unblocked(
                    V.rb_mut(),
                    H.as_mut(),
                    par,
                    MemStack::new(&mut [MaybeUninit::uninit(); 8 * 1024]),
                    HessenbergParams {
                        par_threshold: 0,
                        ..auto!(c64)
                    }
                    .into(),
                );

                let mut A = A.clone();
                let mut A = A.as_mut();

                // left application first, then the transposed view for the right application
                for iter in 0..2 {
                    let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

                    let n = n - 1;

                    let V = V.rb().submatrix(1, 0, n, n);
                    let mut A = A.rb_mut().subrows_mut(1, n);
                    let H = H.as_ref();

                    householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
                        V,
                        H.as_ref(),
                        if iter == 0 { Conj::Yes } else { Conj::No },
                        A.rb_mut(),
                        Par::Seq,
                        MemStack::new(&mut MemBuffer::new(
                            householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n, b, n + 1),
                        )),
                    );
                }

                // NOTE(review): presumably consumed by the `~` comparison below — confirm
                let approx_eq = CwiseMat(ApproxEq::eps());
                // keep only the Hessenberg part of `V` for the comparison
                for j in 0..n {
                    for i in 0..n {
                        if i > j + 1 {
                            V[(i, j)] = c64::ZERO;
                        }
                    }
                }

                assert!(V ~ A);
            }
        }
    }

    /// Blocked kernel on complex matrices, sequential and parallel, including a size (21)
    /// that is not a multiple of the block size.
    #[test]
    fn test_hessenberg_cplx_gqvdg() {
        let rng = &mut StdRng::seed_from_u64(0);

        for n in [2, 3, 4, 8, 16, 21] {
            for par in [Par::Seq, Par::rayon(4)] {
                // block size of the Householder factors and panel width
                let b = 4;

                let A = CwiseMatDistribution {
                    nrows: n,
                    ncols: n,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Mat<c64, _, _>>(rng);

                let mut H = Mat::zeros(b, n - 1);

                // `V` is reduced in place; `A` keeps the original input
                let mut V = A.clone();
                let mut V = V.as_mut();
                hessenberg_gqvdg_blocked(
                    V.rb_mut(),
                    H.as_mut(),
                    par,
                    MemStack::new(&mut [MaybeUninit::uninit(); 16 * 1024]),
                    HessenbergParams {
                        par_threshold: 0,
                        ..auto!(c64)
                    }
                    .into(),
                );

                let mut A = A.clone();
                let mut A = A.as_mut();

                // left application first, then the transposed view for the right application
                for iter in 0..2 {
                    let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

                    let n = n - 1;

                    let V = V.rb().submatrix(1, 0, n, n);
                    let mut A = A.rb_mut().subrows_mut(1, n);
                    let H = H.as_ref().subcols(0, n);

                    householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
                        V,
                        H.as_ref(),
                        if iter == 0 { Conj::Yes } else { Conj::No },
                        A.rb_mut(),
                        Par::Seq,
                        MemStack::new(&mut MemBuffer::new(
                            householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n, b, n + 1),
                        )),
                    );
                }

                // NOTE(review): presumably consumed by the `~` comparison below — confirm
                let approx_eq = CwiseMat(ApproxEq::eps());
                // keep only the Hessenberg part of `V` for the comparison
                for j in 0..n {
                    for i in 0..n {
                        if i > j + 1 {
                            V[(i, j)] = c64::ZERO;
                        }
                    }
                }

                assert!(V ~ A);
            }
        }
    }
}
899}