faer/linalg/evd/
hessenberg.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::householder::{self, HouseholderInfo};
4use linalg::matmul::triangular::BlockStructure;
5use linalg::matmul::{self, dot, matmul};
6use linalg::triangular_solve;
7
/// hessenberg factorization tuning parameters
#[derive(Copy, Clone, Debug)]
pub struct HessenbergParams {
	/// threshold at which parallelism should be disabled
	///
	/// the unblocked algorithm switches to sequential execution once the squared dimension of the
	/// remaining trailing block drops below this value
	pub par_threshold: usize,
	/// threshold at which blocking should be disabled
	///
	/// matrices whose squared dimension is below this value are factorized with the unblocked
	/// algorithm instead of the blocked one
	pub blocking_threshold: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
19
20impl<T: ComplexField> Auto<T> for HessenbergParams {
21	fn auto() -> Self {
22		Self {
23			par_threshold: 192 * 256,
24			blocking_threshold: 256 * 256,
25			non_exhaustive: NonExhaustive(()),
26		}
27	}
28}
29
30/// computes the size and alignment of the workspace required to compute a matrix's hessenberg
31/// decomposition
32pub fn hessenberg_in_place_scratch<T: ComplexField>(dim: usize, blocksize: usize, par: Par, params: Spec<HessenbergParams, T>) -> StackReq {
33	let params = params.config;
34	let _ = par;
35	let n = dim;
36	if n * n < params.blocking_threshold {
37		StackReq::any_of(&[StackReq::all_of(&[
38			temp_mat_scratch::<T>(n, 1).array(3),
39			temp_mat_scratch::<T>(n, par.degree()),
40		])])
41	} else {
42		StackReq::all_of(&[
43			temp_mat_scratch::<T>(n, blocksize),
44			temp_mat_scratch::<T>(blocksize, 1),
45			StackReq::any_of(&[
46				StackReq::all_of(&[temp_mat_scratch::<T>(n, 1), temp_mat_scratch::<T>(n, par.degree())]),
47				temp_mat_scratch::<T>(n, blocksize),
48			]),
49		])
50	}
51}
52
/// simd implementation of `hessenberg_fused_op`, used when the relevant operands are contiguous
///
/// performs the same computation as `hessenberg_fused_op_fallback`, fused into a single sweep over
/// each column `j` of `A`:
/// - `A[:, j] -= l0 * r0[j] + l1 * conj(r1[j])`
/// - `r_out += A[:, j] * r_in[j]`, using the updated column (`r_out` is zeroed beforehand)
/// - `l_out[j] = sum_i conj(l_in[i]) * A[i, j]`, using the updated column
///
/// `align` is forwarded to `SimdCtx::new_align`
// NOTE(review): `align` is presumably the offset of the first row relative to a simd-aligned
// boundary — confirm against `SimdCtx::new_align`
#[math]
fn hessenberg_fused_op_simd<T: ComplexField>(
	A: MatMut<'_, T, usize, usize, ContiguousFwd>,

	l_out: RowMut<'_, T, usize>,
	r_out: ColMut<'_, T, usize, ContiguousFwd>,
	l_in: RowRef<'_, T, usize, ContiguousFwd>,
	r_in: ColRef<'_, T, usize>,

	l0: ColRef<'_, T, usize, ContiguousFwd>,
	l1: ColRef<'_, T, usize, ContiguousFwd>,
	r0: RowRef<'_, T, usize>,
	r1: RowRef<'_, T, usize>,
	align: usize,
) {
	// same operands as the outer function, with the dimensions branded as `Dim<'M>` / `Dim<'N>`
	struct Impl<'a, 'M, 'N, T: ComplexField> {
		A: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,

		l_out: RowMut<'a, T, Dim<'N>>,
		r_out: ColMut<'a, T, Dim<'M>, ContiguousFwd>,
		l_in: RowRef<'a, T, Dim<'M>, ContiguousFwd>,
		r_in: ColRef<'a, T, Dim<'N>>,

		l0: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		l1: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
		r0: RowRef<'a, T, Dim<'N>>,
		r1: RowRef<'a, T, Dim<'N>>,
		align: usize,
	}

	impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
			let Self {
				mut A,
				mut l_out,
				mut r_out,
				l_in,
				r_in,
				l0,
				l1,
				r0,
				r1,
				align,
			} = self;

			let (m, n) = A.shape();

			let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), m, align);

			// `r_out` is accumulated into across all columns, so it is cleared up front
			{
				let (head, body, tail) = simd.indices();
				if let Some(i) = head {
					simd.write(r_out.rb_mut(), i, simd.zero());
				}
				for i in body {
					simd.write(r_out.rb_mut(), i, simd.zero());
				}
				if let Some(i) = tail {
					simd.write(r_out.rb_mut(), i, simd.zero());
				}
			}

			// row indices split into an optional head, groups of four, leftover single indices
			// and an optional tail
			let (head, body4, body1, tail) = simd.batch_indices::<4>();

			// view the row as a column so that it is read with the same row indices as `A`
			let l_in = l_in.transpose();

			for j in n.indices() {
				let mut A = A.rb_mut().col_mut(j);
				let r_in = simd.splat(r_in.at(j));
				// negated once here so that the subtraction below becomes a multiply-add
				let r0 = simd.splat(&(-r0[j]));
				let r1 = simd.splat(&(-r1[j]));

				// one accumulator per lane of the 4-way unrolled loop, combined after the sweep
				let mut acc0 = simd.zero();
				let mut acc1 = simd.zero();
				let mut acc2 = simd.zero();
				let mut acc3 = simd.zero();

				// every chunk below runs the same sequence: update the entries of column `j`,
				// store them back, then feed the updated values into both products
				if let Some(i0) = head {
					let mut a0 = simd.read(A.rb(), i0);
					a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
					a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
					simd.write(A.rb_mut(), i0, a0);
					acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
					let tmp = simd.read(r_out.rb(), i0);
					simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
				}
				for [i0, i1, i2, i3] in body4.clone() {
					{
						let mut a0 = simd.read(A.rb(), i0);
						a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
						a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
						simd.write(A.rb_mut(), i0, a0);
						acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
						let tmp = simd.read(r_out.rb(), i0);
						simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
					}
					{
						let mut a1 = simd.read(A.rb(), i1);
						a1 = simd.mul_add(simd.read(l0, i1), r0, a1);
						a1 = simd.conj_mul_add(r1, simd.read(l1, i1), a1);
						simd.write(A.rb_mut(), i1, a1);
						acc1 = simd.conj_mul_add(simd.read(l_in, i1), a1, acc1);
						let tmp = simd.read(r_out.rb(), i1);
						simd.write(r_out.rb_mut(), i1, simd.mul_add(a1, r_in, tmp));
					}
					{
						let mut a2 = simd.read(A.rb(), i2);
						a2 = simd.mul_add(simd.read(l0, i2), r0, a2);
						a2 = simd.conj_mul_add(r1, simd.read(l1, i2), a2);
						simd.write(A.rb_mut(), i2, a2);
						acc2 = simd.conj_mul_add(simd.read(l_in, i2), a2, acc2);
						let tmp = simd.read(r_out.rb(), i2);
						simd.write(r_out.rb_mut(), i2, simd.mul_add(a2, r_in, tmp));
					}
					{
						let mut a3 = simd.read(A.rb(), i3);
						a3 = simd.mul_add(simd.read(l0, i3), r0, a3);
						a3 = simd.conj_mul_add(r1, simd.read(l1, i3), a3);
						simd.write(A.rb_mut(), i3, a3);
						acc3 = simd.conj_mul_add(simd.read(l_in, i3), a3, acc3);
						let tmp = simd.read(r_out.rb(), i3);
						simd.write(r_out.rb_mut(), i3, simd.mul_add(a3, r_in, tmp));
					}
				}
				for i0 in body1.clone() {
					let mut a0 = simd.read(A.rb(), i0);
					a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
					a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
					simd.write(A.rb_mut(), i0, a0);
					acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
					let tmp = simd.read(r_out.rb(), i0);
					simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
				}
				if let Some(i0) = tail {
					let mut a0 = simd.read(A.rb(), i0);
					a0 = simd.mul_add(simd.read(l0, i0), r0, a0);
					a0 = simd.conj_mul_add(r1, simd.read(l1, i0), a0);
					simd.write(A.rb_mut(), i0, a0);
					acc0 = simd.conj_mul_add(simd.read(l_in, i0), a0, acc0);
					let tmp = simd.read(r_out.rb(), i0);
					simd.write(r_out.rb_mut(), i0, simd.mul_add(a0, r_in, tmp));
				}

				// combine the partial sums and reduce them to the scalar `l_out[j]`
				acc0 = simd.add(acc0, acc1);
				acc2 = simd.add(acc2, acc3);
				acc0 = simd.add(acc0, acc2);

				let l_out = l_out.rb_mut().at_mut(j);
				*l_out = simd.reduce_sum(acc0);
			}
		}
	}

	// brand the runtime dimensions so that every operand shares the same `Dim` types
	with_dim!(M, A.nrows());
	with_dim!(N, A.ncols());

	dispatch!(
		Impl {
			A: A.as_shape_mut(M, N),
			l_out: l_out.as_col_shape_mut(N),
			r_out: r_out.as_row_shape_mut(M),
			l_in: l_in.as_col_shape(M),
			r_in: r_in.as_row_shape(N),
			l0: l0.as_row_shape(M),
			l1: l1.as_row_shape(M),
			r0: r0.as_col_shape(N),
			r1: r1.as_col_shape(N),
			align,
		},
		Impl,
		T
	)
}
229
230#[math]
231fn hessenberg_fused_op_fallback<T: ComplexField>(
232	A: MatMut<'_, T>,
233
234	l_out: RowMut<'_, T>,
235	r_out: ColMut<'_, T>,
236	l_in: RowRef<'_, T>,
237	r_in: ColRef<'_, T>,
238
239	l0: ColRef<'_, T>,
240	l1: ColRef<'_, T>,
241	r0: RowRef<'_, T>,
242	r1: RowRef<'_, T>,
243) {
244	let mut A = A;
245
246	matmul(A.rb_mut(), Accum::Add, l0.as_mat(), r0.as_mat(), -one::<T>(), Par::Seq);
247	matmul(A.rb_mut(), Accum::Add, l1.as_mat(), r1.as_mat().conjugate(), -one::<T>(), Par::Seq);
248
249	matmul(r_out.as_mat_mut(), Accum::Replace, A.rb(), r_in.as_mat(), one(), Par::Seq);
250	matmul(l_out.as_mat_mut(), Accum::Replace, l_in.as_mat().conjugate(), A.rb(), one(), Par::Seq);
251}
252
253fn hessenberg_fused_op<T: ComplexField>(
254	A: MatMut<'_, T>,
255
256	l_out: RowMut<'_, T>,
257	r_out: ColMut<'_, T>,
258	l_in: RowRef<'_, T>,
259	r_in: ColRef<'_, T>,
260
261	l0: ColRef<'_, T>,
262	l1: ColRef<'_, T>,
263	r0: RowRef<'_, T>,
264	r1: RowRef<'_, T>,
265	align: usize,
266) {
267	let mut A = A;
268	let mut r_out = r_out;
269
270	if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
271		if let (Some(A), Some(r_out), Some(l_in), Some(l0), Some(l1)) = (
272			A.rb_mut().try_as_col_major_mut(),
273			r_out.rb_mut().try_as_col_major_mut(),
274			l_in.try_as_row_major(),
275			l0.try_as_col_major(),
276			l1.try_as_col_major(),
277		) {
278			hessenberg_fused_op_simd(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1, align);
279		} else {
280			hessenberg_fused_op_fallback(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1);
281		}
282	} else {
283		hessenberg_fused_op_fallback(A, l_out, r_out, l_in, r_in, l0, l1, r0, r1);
284	}
285}
286
287#[math]
288fn hessenberg_rearranged_unblocked<T: ComplexField>(A: MatMut<'_, T>, H: MatMut<'_, T>, par: Par, stack: &mut MemStack, params: HessenbergParams) {
289	assert!(all(A.nrows() == A.ncols(), H.ncols() == A.ncols().saturating_sub(1)));
290
291	let n = A.nrows();
292	let b = H.nrows();
293
294	if n == 0 {
295		return;
296	}
297
298	let mut A = A;
299	let mut H = H;
300	let mut par = par;
301
302	{
303		let mut H = H.rb_mut().row_mut(0);
304		let (mut y, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
305		let (mut z, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
306		let (mut v, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
307		let (mut w, _) = unsafe { temp_mat_uninit(n, par.degree(), stack) };
308
309		let mut y = y.as_mat_mut().col_mut(0).transpose_mut();
310		let mut z = z.as_mat_mut().col_mut(0);
311		let mut v = v.as_mat_mut().col_mut(0).transpose_mut();
312		let mut w = w.as_mat_mut();
313
314		for k in 0..n {
315			let (_, A01, A10, A11) = A.rb_mut().split_at_mut(k, k);
316
317			let (_, mut A02) = A01.split_first_col_mut().unwrap();
318			let (_, A20) = A10.split_first_row_mut().unwrap();
319			let (mut A11, A12, A21, mut A22) = A11.split_at_mut(1, 1);
320
321			let mut A12 = A12.row_mut(0);
322			let mut A21 = A21.col_mut(0);
323
324			let A11 = &mut A11[(0, 0)];
325
326			let (y1, mut y2) = y.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
327			let y1 = copy(y1[0]);
328
329			let (z1, mut z2) = z.rb_mut().split_at_row_mut(k).1.split_at_row_mut(1);
330			let z1 = copy(z1[0]);
331
332			let (_, mut v2) = v.rb_mut().split_at_col_mut(k).1.split_at_col_mut(1);
333			let (mut w0, w12) = w.rb_mut().split_at_row_mut(k);
334			let (_, mut w2) = w12.split_at_row_mut(1);
335
336			if k > 0 {
337				let p = k - 1;
338				let u2 = A20.rb().col(p);
339
340				*A11 = *A11 - y1 - z1;
341				z!(&mut A12, &y2, u2.rb().transpose()).for_each(|uz!(a, y, u)| *a = *a - *y - z1 * conj(*u));
342				z!(&mut A21, &u2, &z2).for_each(|uz!(a, u, z)| *a = *a - *u * y1 - *z);
343			}
344
345			{
346				let n = n - k - 1;
347				if n * n < params.par_threshold {
348					par = Par::Seq;
349				}
350			}
351
352			if k + 1 == n {
353				break;
354			}
355
356			let beta;
357			let tau_inv;
358			{
359				let (mut A11, mut A21) = A21.rb_mut().split_at_row_mut(1);
360				let A11 = &mut A11[0];
361
362				let HouseholderInfo { tau, .. } = householder::make_householder_in_place(A11, A21.rb_mut());
363				tau_inv = recip(tau);
364				beta = copy(*A11);
365				*A11 = one();
366
367				H[k] = from_real(tau);
368			}
369
370			let x2 = A21.rb();
371
372			if k > 0 {
373				let p = k - 1;
374				let u2 = A20.rb().col(p);
375				hessenberg_fused_op(
376					A22.rb_mut(),
377					v2.rb_mut(),
378					w2.rb_mut().col_mut(0),
379					x2.transpose(),
380					x2,
381					u2,
382					z2.rb(),
383					y2.rb(),
384					u2.transpose(),
385					simd_align(k + 1),
386				);
387				y2.copy_from(v2.rb());
388				z2.copy_from(w2.rb().col(0));
389			} else {
390				matmul(z2.rb_mut().as_mat_mut(), Accum::Replace, A22.rb(), x2.as_mat(), one(), par);
391				matmul(y2.rb_mut().as_mat_mut(), Accum::Replace, x2.adjoint().as_mat(), A22.rb(), one(), par);
392			}
393
394			let u2 = x2;
395
396			let b = mul_real(
397				mul_pow2(dot::inner_prod(u2.rb().transpose(), Conj::Yes, z2.rb(), Conj::No), from_f64(0.5)),
398				tau_inv,
399			);
400			z!(&mut y2, u2.transpose()).for_each(|uz!(y, u)| *y = mul_real(*y - b * conj(*u), tau_inv));
401			z!(&mut z2, u2).for_each(|uz!(z, u)| *z = mul_real(*z - b * *u, tau_inv));
402
403			let dot = mul_real(dot::inner_prod(A12.rb(), Conj::No, u2.rb(), Conj::No), tau_inv);
404			z!(&mut A12, u2.transpose()).for_each(|uz!(a, u)| *a = *a - dot * conj(u));
405
406			matmul(w0.rb_mut().col_mut(0).as_mat_mut(), Accum::Replace, A02.rb(), u2.as_mat(), one(), par);
407			matmul(
408				A02.rb_mut(),
409				Accum::Add,
410				w0.rb().col(0).as_mat(),
411				u2.adjoint().as_mat(),
412				-from_real::<T>(&tau_inv),
413				par,
414			);
415
416			A21[0] = beta;
417		}
418	}
419
420	if n > 0 {
421		let n = n - 1;
422		let A = A.rb().submatrix(1, 0, n, n);
423		let mut H = H.rb_mut().subcols_mut(0, n);
424
425		let mut j = 0;
426		while j < n {
427			let b = Ord::min(b, n - j);
428
429			let mut H = H.rb_mut().submatrix_mut(0, j, b, b);
430
431			for k in 0..b {
432				H[(k, k)] = copy(H[(0, k)]);
433			}
434
435			householder::upgrade_householder_factor(H.rb_mut(), A.submatrix(j, j, n - j, b), b, 1, par);
436			j += b;
437		}
438	}
439}
440
/// factorizes a panel of `H.nrows()` columns of `A` for the blocked hessenberg reduction
///
/// for each column `k` of the panel:
/// - the column is brought up to date with the reflectors already computed in this panel, using
///   the first `k` columns of `Z` and the leading `k × k` block of `H`
/// - a new householder reflector is generated from the entries below the subdiagonal. the
///   subdiagonal value is saved in `beta[k]` and the head of the reflector stored in `A` is set to
///   one
/// - column `k` of `Z` is set to the product of the columns of `A` to the right of column `k` with
///   the new reflector. column `k` of `H` receives the reflector's coefficient on the diagonal
///   and, above it, the products of the previous reflectors' tails with the new reflector
///
/// `stack` must have room for one column of length `A.nrows()`. `params` is currently unused
// NOTE(review): the name presumably refers to the Quintana-Ortí / van de Geijn formulation of the
// blocked hessenberg reduction — confirm
#[math]
fn hessenberg_gqvdg_unblocked<T: ComplexField>(
	A: MatMut<'_, T>,
	Z: MatMut<'_, T>,
	H: MatMut<'_, T>,
	beta: ColMut<'_, T>,
	par: Par,
	stack: &mut MemStack,
	params: HessenbergParams,
) {
	let n = A.nrows();
	let b = H.nrows();
	let mut A = A;
	let mut H = H;
	let mut Z = Z;
	_ = params;

	// work vector; only its first `k` entries are used at step `k`, and they are overwritten
	// before being read
	let (mut x, _) = unsafe { temp_mat_uninit(n, 1, stack) };
	let mut x = x.as_mat_mut().col_mut(0);
	let mut beta = beta;

	for k in 0..b {
		let mut x0 = x.rb_mut().subrows_mut(0, k);
		// `T00`: factor of the previous reflectors, `T01` / `T11`: column being built
		let (T00, T01, _, T11) = H.rb_mut().split_at_mut(k, k);
		let (mut T01, _) = T01.split_first_col_mut().unwrap();
		let (mut T11, _, _, _) = T11.split_at_mut(1, 1);

		let T11 = &mut T11[(0, 0)];

		// `U0`: previous columns of the panel, `A1`: current column, `A2`: columns to the right
		let (U0, A12) = A.rb_mut().split_at_col_mut(k);
		let (mut A1, A2) = A12.split_first_col_mut().unwrap();

		let (Z0, Z12) = Z.rb_mut().split_at_col_mut(k);
		let (mut Z1, _) = Z12.split_first_col_mut().unwrap();

		let U0 = U0.rb();
		let Z0 = Z0.rb();
		let T00 = T00.rb();

		// split the previous columns at row `k`: rows above it, row `k` itself, rows below it
		let (U00, U10) = U0.split_at_row(k);
		let (U10, U20) = U10.split_first_row().unwrap();

		// A1 -= Z0 * T00^-1 * U10^H
		x0.copy_from(U10.adjoint());
		triangular_solve::solve_upper_triangular_in_place(T00, x0.rb_mut().as_mat_mut(), par);
		matmul::matmul(A1.rb_mut().as_mat_mut(), Accum::Add, Z0, x0.rb().as_mat(), -one::<T>(), par);

		let (mut A01, A11) = A1.rb_mut().split_at_row_mut(k);
		let (mut A11, mut A21) = A11.split_at_row_mut(1);
		let A11 = &mut A11[0];

		// x0 = U^H * A1, where `U` is made of the strictly lower triangular part of `U00`,
		// followed by `U10` and `U20`
		{
			matmul::triangular::matmul(
				x0.rb_mut().as_mat_mut(),
				BlockStructure::Rectangular,
				Accum::Replace,
				U00.adjoint(),
				BlockStructure::StrictTriangularUpper,
				A01.rb().as_mat(),
				BlockStructure::Rectangular,
				one(),
				par,
			);
			z!(x0.rb_mut(), U10.transpose()).for_each(|uz!(x, u)| *x = *x + *A11 * conj(*u));
			matmul::matmul(x0.rb_mut().as_mat_mut(), Accum::Add, U20.adjoint(), A21.rb().as_mat(), one(), par);
		}
		// x0 = T00^-H * x0
		{
			triangular_solve::solve_lower_triangular_in_place(T00.adjoint(), x0.rb_mut().as_mat_mut(), par);
		}
		// A1 -= U * x0, applied piecewise to the three row ranges of the column
		{
			matmul::triangular::matmul(
				A01.rb_mut().as_mat_mut(),
				BlockStructure::Rectangular,
				Accum::Add,
				U00,
				BlockStructure::StrictTriangularLower,
				x0.rb().as_mat(),
				BlockStructure::Rectangular,
				-one::<T>(),
				par,
			);
			*A11 = *A11 - dot::inner_prod(U10, Conj::No, x0.rb(), Conj::No);
			matmul::matmul(A21.rb_mut().as_mat_mut(), Accum::Add, U20, x0.rb().as_mat(), -one::<T>(), par);
		}

		if k + 1 < n {
			// generate the reflector from the entries below the diagonal of the updated column
			let (mut A11, mut A21) = A21.rb_mut().split_at_row_mut(1);
			let A11 = &mut A11[0];

			let HouseholderInfo { tau, .. } = householder::make_householder_in_place(A11, A21.rb_mut());

			// the subdiagonal entry is saved in `beta` and replaced by the reflector's unit head
			beta[k] = copy(A11);
			*A11 = one();
			*T11 = from_real(tau);
		} else {
			*T11 = infinity();
		}

		// Z1 = A2 * u, where `u` is the new reflector (including its unit head)
		matmul::matmul(Z1.rb_mut().as_mat_mut(), Accum::Replace, A2.rb(), A21.rb().as_mat(), one(), par);

		// T01 = U20^H * u
		matmul::matmul(T01.rb_mut().as_mat_mut(), Accum::Replace, U20.adjoint(), A21.rb().as_mat(), one(), par);
	}
}
543
544/// computes a matrix $A$'s hessenberg decomposition such that $A = Q H Q^H$
545///
546/// $H$ is a hessenberg matrix stored in the upper triangular half of $A$ (plus the subdiagonal)
547///
548/// $Q$ is a sequence of householder reflections stored in the unit lower triangular half of $A$
549/// (excluding the diagonal), with the householder coefficients being stored in `householder`
550#[track_caller]
551pub fn hessenberg_in_place<T: ComplexField>(
552	A: MatMut<'_, T>,
553	householder: MatMut<'_, T>,
554	par: Par,
555	stack: &mut MemStack,
556	params: Spec<HessenbergParams, T>,
557) {
558	let params = params.config;
559	assert!(all(A.nrows() == A.ncols(), householder.ncols() == A.ncols().saturating_sub(1)));
560
561	let n = A.nrows().unbound();
562
563	if n * n < params.blocking_threshold {
564		hessenberg_rearranged_unblocked(A, householder, par, stack, params);
565	} else {
566		hessenberg_gqvdg_blocked(A, householder, par, stack, params);
567	}
568}
569
/// blocked hessenberg reduction of the square matrix `A`, performed in place
///
/// the matrix is processed in panels of at most `H.nrows()` columns. each panel is factorized by
/// `hessenberg_gqvdg_unblocked`, which also produces the block householder factor (stored in `H`)
/// and the matrix `Z`. the panel's reflectors are then applied to the rest of the matrix using
/// matrix products and triangular solves, and finally the subdiagonal entries of the panel are
/// written back to `A`
///
/// `stack` must have room for an `n × H.nrows()` matrix, a column of length `H.nrows()` and the
/// larger of the panel factorization workspace and an `n × H.nrows()` matrix
#[math]
fn hessenberg_gqvdg_blocked<T: ComplexField>(A: MatMut<'_, T>, H: MatMut<'_, T>, par: Par, stack: &mut MemStack, params: HessenbergParams) {
	let n = A.nrows();
	let b = H.nrows();
	let mut A = A;
	let mut H = H;
	let (mut Z, stack) = unsafe { temp_mat_uninit(n, b, stack) };
	let mut Z = Z.as_mat_mut();

	let mut j = 0;
	while j < n {
		// `bs`: number of columns in this panel, `bs_u`: number of reflectors it generates (the
		// last column of the matrix has nothing to eliminate)
		let bs = Ord::min(b, n - j);
		let bs_u = Ord::min(bs, n - j - 1);

		// subdiagonal entries of the panel, written back to `A` at the end of the iteration
		let (mut beta, stack) = unsafe { temp_mat_uninit(bs, 1, stack) };
		let mut beta = beta.as_mat_mut().col_mut(0);

		{
			let mut T11 = H.rb_mut().submatrix_mut(0, j, bs_u, bs_u);
			{
				// factorize the panel, operating on the trailing block of `A`
				let A11 = A.rb_mut().submatrix_mut(j, j, n - j, n - j);
				let Z1 = Z.rb_mut().submatrix_mut(j, 0, n - j, bs);

				hessenberg_gqvdg_unblocked(A11, Z1, T11.rb_mut(), beta.rb_mut(), par, stack, params);
			}

			// workspace for the products with the reflectors; its first `j` rows are used for the
			// rows of `A` above the panel, and its rows past `j + bs_u` for the trailing block
			let (mut X, _) = unsafe { temp_mat_uninit(n, bs_u, stack) };
			let mut X = X.as_mat_mut();

			let (mut X0, X12) = X.rb_mut().split_at_row_mut(j);
			let (_, mut X2) = X12.split_at_row_mut(bs_u);

			let (_, Z12) = Z.rb_mut().subcols_mut(0, bs_u).split_at_row_mut(j);
			let (mut Z1, mut Z2) = Z12.split_at_row_mut(bs_u);

			let (_, A01, _, A11) = A.rb_mut().split_at_mut(j, j);
			let (mut A01, mut A02) = A01.split_at_col_mut(bs_u);
			let (A11, mut A12, A21, mut A22) = A11.split_at_mut(bs_u, bs_u);

			// reflectors of the panel: `U1` is used through its strictly lower triangular part,
			// `U2` holds the rows below it
			let U1 = A11.rb();
			let U2 = A21.rb();

			let T1 = T11.rb();

			// rows above the panel: X0 = [A01 A02] * U
			matmul::triangular::matmul(
				X0.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Replace,
				A01.rb(),
				BlockStructure::Rectangular,
				U1,
				BlockStructure::StrictTriangularLower,
				one(),
				par,
			);
			matmul::matmul(X0.rb_mut(), Accum::Add, A02.rb(), U2, one(), par);

			// X0 = X0 * T1^-1, solved on the transposed system
			triangular_solve::solve_lower_triangular_in_place(T1.transpose(), X0.rb_mut().transpose_mut(), par);

			// [A01 A02] -= X0 * U^H
			matmul::triangular::matmul(
				A01.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Add,
				X0.rb(),
				BlockStructure::Rectangular,
				U1.adjoint(),
				BlockStructure::StrictTriangularUpper,
				-one::<T>(),
				par,
			);
			matmul::matmul(A02.rb_mut(), Accum::Add, X0.rb(), U2.adjoint(), -one::<T>(), par);

			// Z = Z * T1^-1 for the rows from `j` onwards
			triangular_solve::solve_lower_triangular_in_place(T1.transpose(), Z1.rb_mut().transpose_mut(), par);
			triangular_solve::solve_lower_triangular_in_place(T1.transpose(), Z2.rb_mut().transpose_mut(), par);

			// columns to the right of the panel: [A12; A22] -= Z * U2^H
			matmul::matmul(A12.rb_mut(), Accum::Add, Z1.rb(), U2.adjoint(), -one::<T>(), par);
			matmul::matmul(A22.rb_mut(), Accum::Add, Z2.rb(), U2.adjoint(), -one::<T>(), par);

			let mut X = X2.rb_mut().transpose_mut();

			// X = U^H * [A12; A22]
			matmul::triangular::matmul(
				X.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Replace,
				U1.adjoint(),
				BlockStructure::StrictTriangularUpper,
				A12.rb(),
				BlockStructure::Rectangular,
				one(),
				par,
			);
			matmul::matmul(X.rb_mut(), Accum::Add, U2.adjoint(), A22.rb(), one(), par);

			// X = T1^-H * X
			triangular_solve::solve_lower_triangular_in_place(T1.adjoint(), X.rb_mut(), par);

			// [A12; A22] -= U * X
			matmul::triangular::matmul(
				A12.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Add,
				U1,
				BlockStructure::StrictTriangularLower,
				X.rb(),
				BlockStructure::Rectangular,
				-one::<T>(),
				par,
			);
			matmul::matmul(A22.rb_mut(), Accum::Add, U2, X.rb(), -one::<T>(), par);
		}

		// the reflector heads were set to one during the panel factorization: put the saved
		// subdiagonal entries back
		let n = n - j;
		let mut A = A.rb_mut().submatrix_mut(j, j, n, bs);
		for k in 0..bs {
			if k + 1 < n {
				A[(k + 1, k)] = copy(beta[k]);
			}
		}

		j += bs;
	}
}
690
691#[cfg(test)]
692mod tests {
693	use dyn_stack::MemBuffer;
694	use std::mem::MaybeUninit;
695
696	use super::*;
697	use crate::stats::prelude::*;
698	use crate::utils::approx::*;
699	use crate::{Mat, assert, c64};
700
	// checks the unblocked algorithm on random real matrices: applying the computed householder
	// sequence to both sides of the original matrix must reproduce the computed hessenberg matrix
	#[test]
	fn test_hessenberg_real() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [3, 4, 8, 16] {
			let A = CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);

			// block size of the householder factor
			let b = 3;
			let mut H = Mat::zeros(b, n - 1);

			// `V` receives the factorization, `A` keeps the original matrix
			let mut V = A.clone();
			let mut V = V.as_mut();
			hessenberg_rearranged_unblocked(
				V.rb_mut(),
				H.as_mut(),
				Par::Seq,
				MemStack::new(&mut [MaybeUninit::uninit(); 1024]),
				auto!(f64),
			);

			let mut A = A.clone();
			let mut A = A.as_mut();

			// first pass applies the reflectors to `A` from the left; the second pass does the
			// same on the transposed view with the opposite conjugation
			for iter in 0..2 {
				let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

				let n = n - 1;

				// the reflectors start one row below the diagonal and act on rows `1..`
				let V = V.rb().submatrix(1, 0, n, n);
				let mut A = A.rb_mut().subrows_mut(1, n);
				let H = H.as_ref();

				householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
					V,
					H.as_ref(),
					if iter == 0 { Conj::Yes } else { Conj::No },
					A.rb_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(
						householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(n, b, n + 1),
					)),
				);
			}

			let approx_eq = CwiseMat(ApproxEq::<f64>::eps());
			// the entries below the subdiagonal of `V` hold the reflectors, not the hessenberg
			// matrix: clear them before comparing
			for j in 0..n {
				for i in 0..n {
					if i > j + 1 {
						V[(i, j)] = 0.0;
					}
				}
			}

			assert!(V ~ A);
		}
	}
762
	// checks the unblocked algorithm on random complex matrices, sequentially and in parallel:
	// applying the computed householder sequence to both sides of the original matrix must
	// reproduce the computed hessenberg matrix
	#[test]
	fn test_hessenberg_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [1, 2, 3, 4, 8, 16] {
			for par in [Par::Seq, Par::rayon(4)] {
				let A = CwiseMatDistribution {
					nrows: n,
					ncols: n,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Mat<c64>>(rng);

				// block size of the householder factor
				let b = 3;
				let mut H = Mat::zeros(b, n - 1);

				// `V` receives the factorization, `A` keeps the original matrix
				let mut V = A.clone();
				let mut V = V.as_mut();
				hessenberg_rearranged_unblocked(
					V.rb_mut(),
					H.as_mut(),
					par,
					MemStack::new(&mut [MaybeUninit::uninit(); 8 * 1024]),
					// a threshold of zero never switches the algorithm to sequential execution
					HessenbergParams {
						par_threshold: 0,
						..auto!(c64)
					}
					.into(),
				);

				let mut A = A.clone();
				let mut A = A.as_mut();

				// first pass applies the reflectors to `A` from the left; the second pass does
				// the same on the transposed view with the opposite conjugation
				for iter in 0..2 {
					let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

					let n = n - 1;

					// the reflectors start one row below the diagonal and act on rows `1..`
					let V = V.rb().submatrix(1, 0, n, n);
					let mut A = A.rb_mut().subrows_mut(1, n);
					let H = H.as_ref();

					householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
						V,
						H.as_ref(),
						if iter == 0 { Conj::Yes } else { Conj::No },
						A.rb_mut(),
						Par::Seq,
						MemStack::new(&mut MemBuffer::new(
							householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n, b, n + 1),
						)),
					);
				}

				let approx_eq = CwiseMat(ApproxEq::eps());
				// the entries below the subdiagonal of `V` hold the reflectors, not the
				// hessenberg matrix: clear them before comparing
				for j in 0..n {
					for i in 0..n {
						if i > j + 1 {
							V[(i, j)] = c64::ZERO;
						}
					}
				}

				assert!(V ~ A);
			}
		}
	}
830
	// checks the blocked algorithm on random complex matrices, sequentially and in parallel,
	// including sizes that are not a multiple of the block size: applying the computed
	// householder sequence to both sides of the original matrix must reproduce the computed
	// hessenberg matrix
	#[test]
	fn test_hessenberg_cplx_gqvdg() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [2, 3, 4, 8, 16, 21] {
			for par in [Par::Seq, Par::rayon(4)] {
				// block size of the householder factor
				let b = 4;

				let A = CwiseMatDistribution {
					nrows: n,
					ncols: n,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Mat<c64, _, _>>(rng);

				let mut H = Mat::zeros(b, n - 1);

				// `V` receives the factorization, `A` keeps the original matrix
				let mut V = A.clone();
				let mut V = V.as_mut();
				hessenberg_gqvdg_blocked(
					V.rb_mut(),
					H.as_mut(),
					par,
					MemStack::new(&mut [MaybeUninit::uninit(); 16 * 1024]),
					HessenbergParams {
						par_threshold: 0,
						..auto!(c64)
					}
					.into(),
				);

				let mut A = A.clone();
				let mut A = A.as_mut();

				// first pass applies the reflectors to `A` from the left; the second pass does
				// the same on the transposed view with the opposite conjugation
				for iter in 0..2 {
					let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

					let n = n - 1;

					// the reflectors start one row below the diagonal and act on rows `1..`
					let V = V.rb().submatrix(1, 0, n, n);
					let mut A = A.rb_mut().subrows_mut(1, n);
					let H = H.as_ref().subcols(0, n);

					householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
						V,
						H.as_ref(),
						if iter == 0 { Conj::Yes } else { Conj::No },
						A.rb_mut(),
						Par::Seq,
						MemStack::new(&mut MemBuffer::new(
							householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n, b, n + 1),
						)),
					);
				}

				let approx_eq = CwiseMat(ApproxEq::eps());
				// the entries below the subdiagonal of `V` hold the reflectors, not the
				// hessenberg matrix: clear them before comparing
				for j in 0..n {
					for i in 0..n {
						if i > j + 1 {
							V[(i, j)] = c64::ZERO;
						}
					}
				}

				assert!(V ~ A);
			}
		}
	}
899}