1use crate::assert;
32use crate::internal_prelude::*;
33use crate::linalg::matmul::triangular::{self, BlockStructure};
34use crate::linalg::matmul::{dot, matmul, matmul_with_conj};
35use crate::linalg::triangular_solve;
36use crate::utils::simd::SimdCtx;
37use crate::utils::thread::join_raw;
38
#[derive(Clone, Debug)]
/// Description of a single Householder reflector, as returned by [`make_householder_in_place`].
///
/// Writing `β = sign(head) · ‖[head; tail]‖` (with `sign(head) = head / |head|`, or `1` when
/// `head` is zero), the reflector maps `[head; tail]` to `[-β; 0]`.
pub struct HouseholderInfo<T: ComplexField> {
	/// Scaling factor of the reflector, computed as `(1 + (‖tail‖ · |1 / (head + β)|)²) / 2`.
	///
	/// Infinite when `‖tail‖` is below `min_positive()`, in which case the reflector is treated
	/// as the identity.
	pub tau: T::Real,

	/// Reciprocal of `head + β`; the essential part of the reflector is `tail` scaled by this
	/// value. Infinite when `‖tail‖` is below `min_positive()`.
	pub head_with_beta_inv: T,

	/// Euclidean norm of `[head; tail]`. When `‖tail‖` is below `min_positive()`, this is `|head|`
	/// instead (zero if `|head|` itself was below `min_positive()`).
	pub norm: T::Real,
}
51
#[math]
/// Shared implementation of [`make_householder_in_place`] and [`make_householder_out_of_place`].
///
/// The input vector is `[head; tail]`, where `tail` is `input` when provided and `out` otherwise.
/// With `β = sign(head) · ‖[head; tail]‖` (`sign(head) = head / |head|`, or `1` when `head` is
/// zero), on return:
/// - `*head` holds `-β`,
/// - `out` holds the essential part of the reflector, `tail / (head + β)`,
/// - the returned [`HouseholderInfo`] holds `tau`, `1 / (head + β)` and `‖[head; tail]‖`.
///
/// A `head` whose magnitude is below `min_positive()` is first flushed to zero. If `‖tail‖` is
/// below `min_positive()`, the function returns early with infinite `tau` and
/// `head_with_beta_inv`, and does not write `out`.
///
/// NOTE(review): in that early-return branch the out-of-place variant leaves `out` with whatever
/// it held before (it is *not* a copy of `input`), unlike the in-place variant where the tail is
/// simply left as is. Confirm that callers never read `out` in this case.
fn make_householder_imp<T: ComplexField>(head: &mut T, out: ColMut<'_, T>, input: Option<ColRef<'_, T>>) -> HouseholderInfo<T> {
	// Out-of-place: the tail is read from `input`; in-place: it is read from `out` itself.
	let tail = input.unwrap_or(out.rb());
	let tail_norm = tail.norm_l2();

	// Flush a head of negligible magnitude to exactly zero.
	let mut head_norm = abs(*head);
	if head_norm < min_positive() {
		*head = zero();
		head_norm = zero();
	}

	// Negligible tail: there is nothing to annihilate, report an identity reflector.
	if tail_norm < min_positive() {
		return HouseholderInfo {
			tau: infinity::<T::Real>(),
			head_with_beta_inv: infinity(),
			norm: head_norm,
		};
	}

	let one_half = from_f64::<T::Real>(0.5);

	// ‖[head; tail]‖
	let norm = hypot(head_norm, tail_norm);

	// sign(head) = head / |head|, or 1 when the head is zero.
	let sign = if head_norm != zero() { mul_real(*head, recip(head_norm)) } else { one() };

	// β carries the phase of `head`, so `head + β` adds two terms with the same phase and its
	// magnitude is |head| + ‖[head; tail]‖ (no cancellation).
	let signed_norm = sign * from_real(norm);
	let head_with_beta = *head + signed_norm;
	let head_with_beta_inv = recip(head_with_beta);

	// Essential part of the reflector: tail / (head + β).
	match input {
		None => zip!(out).for_each(|unzip!(e)| {
			*e = *e * head_with_beta_inv;
		}),
		Some(input) => zip!(out, input).for_each(|unzip!(o, e)| {
			*o = *e * head_with_beta_inv;
		}),
	}

	// The head of the reflected vector.
	*head = -signed_norm;

	// tau = (1 + ‖essential‖²) / 2, with ‖essential‖ = ‖tail‖ · |1 / (head + β)|.
	let tau = one_half * (one::<T::Real>() + abs2(tail_norm * abs(head_with_beta_inv)));
	HouseholderInfo {
		tau,
		head_with_beta_inv,
		norm,
	}
}
105
106#[inline]
113pub fn make_householder_in_place<T: ComplexField>(head: &mut T, tail: ColMut<'_, T>) -> HouseholderInfo<T> {
114 make_householder_imp(head, tail, None)
115}
116
117#[inline]
118pub(crate) fn make_householder_out_of_place<T: ComplexField>(head: &mut T, out: ColMut<'_, T>, tail: ColRef<'_, T>) -> HouseholderInfo<T> {
119 make_householder_imp(head, out, Some(tail))
120}
121
122#[doc(hidden)]
123#[math]
124pub fn upgrade_householder_factor<T: ComplexField>(
125 householder_factor: MatMut<'_, T>,
126 essentials: MatRef<'_, T>,
127 blocksize: usize,
128 prev_blocksize: usize,
129 par: Par,
130) {
131 assert!(all(
132 householder_factor.nrows() == householder_factor.ncols(),
133 essentials.ncols() == householder_factor.ncols(),
134 ));
135
136 if blocksize == prev_blocksize || householder_factor.nrows().unbound() <= prev_blocksize {
137 return;
138 }
139
140 let n = essentials.ncols();
141 let mut householder_factor = householder_factor;
142 let essentials = essentials;
143
144 assert!(householder_factor.nrows() == householder_factor.ncols());
145
146 let block_count = householder_factor.nrows().msrv_div_ceil(blocksize);
147 if block_count > 1 {
148 assert!(all(blocksize > prev_blocksize, blocksize % prev_blocksize == 0,));
149 let mid = block_count / 2;
150
151 let (tau_tl, _, _, tau_br) = householder_factor.split_at_mut(mid, mid);
152 let (basis_left, basis_right) = essentials.split_at_col(mid);
153 let basis_right = basis_right.split_at_row(mid).1;
154 join_raw(
155 |parallelism| upgrade_householder_factor(tau_tl, basis_left, blocksize, prev_blocksize, parallelism),
156 |parallelism| upgrade_householder_factor(tau_br, basis_right, blocksize, prev_blocksize, parallelism),
157 par,
158 );
159 return;
160 }
161
162 if prev_blocksize < 8 {
163 let (basis_top, basis_bot) = essentials.split_at_row(n);
166 let acc_structure = BlockStructure::UnitTriangularUpper;
167
168 triangular::matmul(
169 householder_factor.rb_mut(),
170 acc_structure,
171 Accum::Replace,
172 basis_top.adjoint(),
173 BlockStructure::UnitTriangularUpper,
174 basis_top,
175 BlockStructure::UnitTriangularLower,
176 one(),
177 par,
178 );
179 triangular::matmul(
180 householder_factor.rb_mut(),
181 acc_structure,
182 Accum::Add,
183 basis_bot.adjoint(),
184 BlockStructure::Rectangular,
185 basis_bot,
186 BlockStructure::Rectangular,
187 one(),
188 par,
189 );
190 } else {
191 let prev_block_count = householder_factor.nrows().msrv_div_ceil(prev_blocksize);
192
193 let mid = (prev_block_count / 2) * prev_blocksize;
194
195 let (tau_tl, mut tau_tr, _, tau_br) = householder_factor.split_at_mut(mid, mid);
196 let (basis_left, basis_right) = essentials.split_at_col(mid);
197 let basis_right = basis_right.split_at_row(mid).1;
198
199 join_raw(
200 |parallelism| {
201 join_raw(
202 |parallelism| upgrade_householder_factor(tau_tl, basis_left, blocksize, prev_blocksize, parallelism),
203 |parallelism| upgrade_householder_factor(tau_br, basis_right, blocksize, prev_blocksize, parallelism),
204 parallelism,
205 );
206 },
207 |parallelism| {
208 let basis_left = basis_left.split_at_row(mid).1;
209 let row_mid = basis_right.ncols();
210
211 let (basis_left_top, basis_left_bot) = basis_left.split_at_row(row_mid);
212 let (basis_right_top, basis_right_bot) = basis_right.split_at_row(row_mid);
213
214 triangular::matmul(
215 tau_tr.rb_mut(),
216 BlockStructure::Rectangular,
217 Accum::Replace,
218 basis_left_top.adjoint(),
219 BlockStructure::Rectangular,
220 basis_right_top,
221 BlockStructure::UnitTriangularLower,
222 one(),
223 parallelism,
224 );
225 matmul(tau_tr.rb_mut(), Accum::Add, basis_left_bot.adjoint(), basis_right_bot, one(), parallelism);
226 },
227 par,
228 );
229 }
230}
231
232pub fn apply_block_householder_on_the_left_in_place_scratch<T: ComplexField>(
235 householder_basis_nrows: usize,
236 blocksize: usize,
237 rhs_ncols: usize,
238) -> StackReq {
239 let _ = householder_basis_nrows;
240 temp_mat_scratch::<T>(blocksize, rhs_ncols)
241}
242
243pub fn apply_block_householder_transpose_on_the_left_in_place_scratch<T: ComplexField>(
246 householder_basis_nrows: usize,
247 blocksize: usize,
248 rhs_ncols: usize,
249) -> StackReq {
250 let _ = householder_basis_nrows;
251 temp_mat_scratch::<T>(blocksize, rhs_ncols)
252}
253
254pub fn apply_block_householder_on_the_right_in_place_scratch<T: ComplexField>(
257 householder_basis_nrows: usize,
258 blocksize: usize,
259 lhs_nrows: usize,
260) -> StackReq {
261 let _ = householder_basis_nrows;
262 temp_mat_scratch::<T>(blocksize, lhs_nrows)
263}
264
265pub fn apply_block_householder_transpose_on_the_right_in_place_scratch<T: ComplexField>(
268 householder_basis_nrows: usize,
269 blocksize: usize,
270 lhs_nrows: usize,
271) -> StackReq {
272 let _ = householder_basis_nrows;
273 temp_mat_scratch::<T>(blocksize, lhs_nrows)
274}
275
276pub fn apply_block_householder_sequence_transpose_on_the_left_in_place_scratch<T: ComplexField>(
279 householder_basis_nrows: usize,
280 blocksize: usize,
281 rhs_ncols: usize,
282) -> StackReq {
283 let _ = householder_basis_nrows;
284 temp_mat_scratch::<T>(blocksize, rhs_ncols)
285}
286
287pub fn apply_block_householder_sequence_on_the_left_in_place_scratch<T: ComplexField>(
290 householder_basis_nrows: usize,
291 blocksize: usize,
292 rhs_ncols: usize,
293) -> StackReq {
294 let _ = householder_basis_nrows;
295 temp_mat_scratch::<T>(blocksize, rhs_ncols)
296}
297
298pub fn apply_block_householder_sequence_transpose_on_the_right_in_place_scratch<T: ComplexField>(
301 householder_basis_nrows: usize,
302 blocksize: usize,
303 lhs_nrows: usize,
304) -> StackReq {
305 let _ = householder_basis_nrows;
306 temp_mat_scratch::<T>(blocksize, lhs_nrows)
307}
308
309pub fn apply_block_householder_sequence_on_the_right_in_place_scratch<T: ComplexField>(
312 householder_basis_nrows: usize,
313 blocksize: usize,
314 lhs_nrows: usize,
315) -> StackReq {
316 let _ = householder_basis_nrows;
317 temp_mat_scratch::<T>(blocksize, lhs_nrows)
318}
319
/// Applies a block Householder reflector to `matrix` from the left, in place.
///
/// The reflector is described by two inputs:
/// - `householder_basis` (`V`), whose top `N × N` part is read as unit lower triangular (implicit
///   unit diagonal, upper part ignored);
/// - the upper triangular `householder_factor` (`T`).
///
/// Let `V'` and `T'` be `V` and `T` conjugated when `conj_lhs` is [`Conj::Yes`] (unchanged
/// otherwise). The update is:
/// - `forward == false`: `matrix ← (I − V' T'⁻¹ V'ᴴ) × matrix`,
/// - `forward == true`:  `matrix ← (I − V' T'⁻ᴴ V'ᴴ) × matrix`.
///
/// `stack` must provide at least the `N × K` temporary reported by
/// [`apply_block_householder_on_the_left_in_place_scratch`]. Only the general (non-SIMD) path
/// uses it.
///
/// # Panics
/// Panics if the factor is not square, if its size differs from the number of basis columns, or
/// if the row count of `matrix` differs from that of the basis.
#[track_caller]
#[math]
fn apply_block_householder_on_the_left_in_place_generic<'M, 'N, 'K, T: ComplexField>(
	householder_basis: MatRef<'_, T, Dim<'M>, Dim<'N>>,
	householder_factor: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	conj_lhs: Conj,
	matrix: MatMut<'_, T, Dim<'M>, Dim<'K>>,
	forward: bool,
	par: Par,
	stack: &mut MemStack,
) {
	assert!(all(
		householder_factor.nrows() == householder_factor.ncols(),
		householder_basis.ncols() == householder_factor.nrows(),
		matrix.nrows() == householder_basis.nrows(),
	));

	let mut matrix = matrix;

	let M = householder_basis.nrows();
	let N = householder_basis.ncols();

	// Split the rows into the first `N` (where the basis is unit lower triangular) and the rest.
	// Presumably this requires N <= M.
	make_guard!(TAIL);
	let midpoint = M.head_partition(N, TAIL);

	// Fast path: a single reflector (N == 1), column-major storage and a SIMD-capable scalar
	// type. The reflector is then applied column by column as a rank-1 update.
	if let (Some(householder_basis), Some(matrix), 1, true) = (
		householder_basis.try_as_col_major(),
		matrix.rb_mut().try_as_col_major_mut(),
		N.unbound(),
		T::SIMD_CAPABILITIES.is_simd(),
	) {
		// SIMD kernel applying `x ← x − v (v'ᴴ x) / tau` to every column `x` of the target, with
		// `v = [1; essential]`. When `CONJ` is true, the essential vector is used conjugated.
		struct ApplyOnLeft<'a, 'TAIL, 'K, T: ComplexField, const CONJ: bool> {
			// 1 / tau, as a scalar of type `T`.
			tau_inv: &'a T,
			// Basis column below its implicit unit head.
			essential: ColRef<'a, T, Dim<'TAIL>, ContiguousFwd>,
			// Row of the target aligned with the implicit unit head.
			rhs0: RowMut<'a, T, Dim<'K>>,
			// Remaining rows of the target, aligned with `essential`.
			rhs: MatMut<'a, T, Dim<'TAIL>, Dim<'K>, ContiguousFwd>,
		}

		impl<'TAIL, 'K, T: ComplexField, const CONJ: bool> pulp::WithSimd for ApplyOnLeft<'_, 'TAIL, 'K, T, CONJ> {
			type Output = ();

			#[inline(always)]
			fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
				let Self {
					tau_inv,
					essential,
					mut rhs,
					mut rhs0,
				} = self;

				// NOTE(review): with an empty tail this returns without updating `rhs0`, whereas
				// the general path would scale that row by `1 − 1/tau`. The two only agree when
				// `1/tau == 0`, which is what `make_householder_*` produces for an empty tail.
				// Confirm that callers never pass a finite tau here.
				if rhs.nrows().unbound() == 0 {
					return;
				}

				let N = rhs.nrows();
				let K = rhs.ncols();
				let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), N);

				// Split the tail rows into an optional leading partial vector, full vectors and
				// an optional trailing partial vector.
				let (head, indices, tail) = simd.indices();

				for idx in K.indices() {
					let col0 = rhs0.rb_mut().at_mut(idx);
					let mut col = rhs.rb_mut().col_mut(idx);
					let essential = essential;

					// dot = x0 + essentialᴴ × x_tail (essentialᵀ × x_tail when CONJ).
					let dot = if try_const! { CONJ } {
						*col0 + dot::inner_prod_no_conj_simd(simd, essential.rb(), col.rb())
					} else {
						*col0 + dot::inner_prod_conj_lhs_simd(simd, essential.rb(), col.rb())
					};

					// k = −dot / tau; the head row gets `x0 + k`, since the head of `v` is 1.
					let k = -dot * tau_inv;
					*col0 = *col0 + k;

					// x_tail ← x_tail + essential × k (conj(essential) × k when CONJ).
					let k = simd.splat(&k);
					macro_rules! simd {
						($i: expr) => {{
							let i = $i;
							let mut a = simd.read(col.rb(), i);
							let b = simd.read(essential.rb(), i);

							if try_const! { CONJ } {
								a = simd.conj_mul_add(b, k, a);
							} else {
								a = simd.mul_add(b, k, a);
							}

							simd.write(col.rb_mut(), i, a);
						}};
					}

					if let Some(i) = head {
						simd!(i);
					}
					for i in indices.clone() {
						simd!(i);
					}
					if let Some(i) = tail {
						simd!(i);
					}
				}
			}
		}

		let N0 = N.check(0);

		let essential = householder_basis.col(N0).split_rows_with(midpoint).1;
		let (rhs0, rhs) = matrix.split_rows_with_mut(midpoint);
		let rhs0 = rhs0.row_mut(N0);

		// Only the real part of the 1 × 1 factor is used. Since it is treated as real, T⁻¹ and
		// T⁻ᴴ coincide, and `forward` is not needed on this path.
		let tau_inv: T = from_real(recip(real(householder_factor[(N0, N0)])));

		if try_const! { T::IS_REAL } {
			// Conjugation is a no-op for real scalars.
			type Apply<'a, 'TAIL, 'K, T> = ApplyOnLeft<'a, 'TAIL, 'K, T, false>;

			dispatch!(
				Apply {
					tau_inv: &tau_inv,
					essential,
					rhs,
					rhs0,
				},
				Apply,
				T
			);
		} else if matches!(conj_lhs, Conj::No) {
			type Apply<'a, 'TAIL, 'K, T> = ApplyOnLeft<'a, 'TAIL, 'K, T, false>;

			dispatch!(
				Apply {
					tau_inv: &tau_inv,
					essential,
					rhs,
					rhs0,
				},
				Apply,
				T
			);
		} else {
			type Apply<'a, 'TAIL, 'K, T> = ApplyOnLeft<'a, 'TAIL, 'K, T, true>;

			dispatch!(
				Apply {
					tau_inv: &tau_inv,
					essential,
					rhs,
					rhs0,
				},
				Apply,
				T
			);
		}
	} else {
		// General path: blocked update `matrix -= V' × op(T')⁻¹ × (V'ᴴ × matrix)`.
		let (essentials_top, essentials_bot) = householder_basis.split_rows_with(midpoint);
		let M = matrix.nrows();
		let K = matrix.ncols();

		// SAFETY: every column panel of `tmp` is fully overwritten (`Accum::Replace` over a
		// rectangular destination) before it is read in `func` below.
		let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(N, K, stack) };
		let mut tmp = tmp.as_mat_mut();

		// Split the columns of `matrix` into at most 4 independent tasks, and stay sequential
		// when the amount of work is below the gemm threading threshold.
		// NOTE(review): `M * K` itself is not saturating; only the second product is.
		let mut n_tasks = Ord::min(Ord::min(crate::utils::thread::parallelism_degree(par), K.unbound()), 4);
		if (M.unbound() * K.unbound()).saturating_mul(4 * M.unbound()) < gemm::get_threading_threshold() {
			n_tasks = 1;
		}

		// Divide the available threads between the tasks.
		let inner_parallelism = match par {
			Par::Seq => Par::Seq,
			#[cfg(feature = "rayon")]
			Par::Rayon(par) => {
				let par = par.get();

				if par >= 2 * n_tasks { Par::rayon(par / n_tasks) } else { Par::Seq }
			},
		};

		// Processes one column panel: `tmp` is the N × k part of the workspace matching the k
		// columns of `matrix` handled by this task.
		let func = |(mut tmp, mut matrix): (MatMut<'_, T, Dim<'N>>, MatMut<'_, T, Dim<'M>>)| {
			let (mut top, mut bot) = matrix.rb_mut().split_rows_with_mut(midpoint);

			// tmp = V'ᴴ × matrix (top part: unit lower triangular basis)
			triangular::matmul_with_conj(
				tmp.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Replace,
				essentials_top.transpose(),
				BlockStructure::UnitTriangularUpper,
				Conj::Yes.compose(conj_lhs),
				top.rb(),
				BlockStructure::Rectangular,
				Conj::No,
				one(),
				inner_parallelism,
			);

			// tmp += V'ᴴ × matrix (bottom part: dense basis)
			matmul_with_conj(
				tmp.rb_mut(),
				Accum::Add,
				essentials_bot.transpose(),
				Conj::Yes.compose(conj_lhs),
				bot.rb(),
				Conj::No,
				one(),
				inner_parallelism,
			);

			if forward {
				// tmp ← T'⁻ᴴ × tmp, solved as a lower triangular system with T'ᴴ.
				triangular_solve::solve_lower_triangular_in_place_with_conj(
					householder_factor.transpose(),
					Conj::Yes.compose(conj_lhs),
					tmp.rb_mut(),
					inner_parallelism,
				);
			} else {
				// tmp ← T'⁻¹ × tmp
				triangular_solve::solve_upper_triangular_in_place_with_conj(
					householder_factor,
					Conj::No.compose(conj_lhs),
					tmp.rb_mut(),
					inner_parallelism,
				);
			}

			// matrix -= V' × tmp (top part: unit lower triangular basis)
			triangular::matmul_with_conj(
				top.rb_mut(),
				BlockStructure::Rectangular,
				Accum::Add,
				essentials_top,
				BlockStructure::UnitTriangularLower,
				Conj::No.compose(conj_lhs),
				tmp.rb(),
				BlockStructure::Rectangular,
				Conj::No,
				-one::<T>(),
				inner_parallelism,
			);
			// matrix -= V' × tmp (bottom part: dense basis)
			matmul_with_conj(
				bot.rb_mut(),
				Accum::Add,
				essentials_bot,
				Conj::No.compose(conj_lhs),
				tmp.rb(),
				Conj::No,
				-one::<T>(),
				inner_parallelism,
			);
		};

		if n_tasks <= 1 {
			func((tmp.as_dyn_cols_mut(), matrix.as_dyn_cols_mut()));
			return;
		} else {
			// NOTE(review): without the `rayon` feature this branch does nothing. Presumably
			// `parallelism_degree` then never exceeds 1, so it is unreachable; verify.
			#[cfg(feature = "rayon")]
			{
				use rayon::prelude::*;
				tmp.rb_mut()
					.par_col_partition_mut(n_tasks)
					.zip_eq(matrix.rb_mut().par_col_partition_mut(n_tasks))
					.for_each(func);
			}
		}
	}
}
582
583#[track_caller]
586pub fn apply_block_householder_on_the_right_in_place_with_conj<T: ComplexField>(
587 householder_basis: MatRef<'_, T>,
588 householder_factor: MatRef<'_, T>,
589 conj_rhs: Conj,
590 matrix: MatMut<'_, T>,
591 par: Par,
592 stack: &mut MemStack,
593) {
594 apply_block_householder_transpose_on_the_left_in_place_with_conj(
595 householder_basis,
596 householder_factor,
597 conj_rhs,
598 matrix.transpose_mut(),
599 par,
600 stack,
601 )
602}
603
604#[track_caller]
607pub fn apply_block_householder_transpose_on_the_right_in_place_with_conj<T: ComplexField>(
608 householder_basis: MatRef<'_, T>,
609 householder_factor: MatRef<'_, T>,
610 conj_rhs: Conj,
611 matrix: MatMut<'_, T>,
612 par: Par,
613 stack: &mut MemStack,
614) {
615 apply_block_householder_on_the_left_in_place_with_conj(householder_basis, householder_factor, conj_rhs, matrix.transpose_mut(), par, stack)
616}
617
618#[track_caller]
621pub fn apply_block_householder_on_the_left_in_place_with_conj<T: ComplexField>(
622 householder_basis: MatRef<'_, T>,
623 householder_factor: MatRef<'_, T>,
624 conj_lhs: Conj,
625 matrix: MatMut<'_, T>,
626 par: Par,
627 stack: &mut MemStack,
628) {
629 make_guard!(M);
630 make_guard!(N);
631 make_guard!(K);
632 let M = householder_basis.nrows().bind(M);
633 let N = householder_basis.ncols().bind(N);
634 let K = matrix.ncols().bind(K);
635
636 apply_block_householder_on_the_left_in_place_generic(
637 householder_basis.as_shape(M, N).as_dyn_stride(),
638 householder_factor.as_shape(N, N).as_dyn_stride(),
639 conj_lhs,
640 matrix.as_shape_mut(M, K).as_dyn_stride_mut(),
641 false,
642 par,
643 stack,
644 )
645}
646
647#[track_caller]
650pub fn apply_block_householder_transpose_on_the_left_in_place_with_conj<T: ComplexField>(
651 householder_basis: MatRef<'_, T>,
652 householder_factor: MatRef<'_, T>,
653 conj_lhs: Conj,
654 matrix: MatMut<'_, T>,
655 par: Par,
656 stack: &mut MemStack,
657) {
658 with_dim!(M, householder_basis.nrows());
659 with_dim!(N, householder_basis.ncols());
660 with_dim!(K, matrix.ncols());
661
662 apply_block_householder_on_the_left_in_place_generic(
663 householder_basis.as_shape(M, N).as_dyn_stride(),
664 householder_factor.as_shape(N, N).as_dyn_stride(),
665 conj_lhs.compose(Conj::Yes),
666 matrix.as_shape_mut(M, K).as_dyn_stride_mut(),
667 true,
668 par,
669 stack,
670 )
671}
672
673#[track_caller]
677pub fn apply_block_householder_sequence_on_the_left_in_place_with_conj<T: ComplexField>(
678 householder_basis: MatRef<'_, T>,
679 householder_factor: MatRef<'_, T>,
680 conj_lhs: Conj,
681 matrix: MatMut<'_, T>,
682 par: Par,
683 stack: &mut MemStack,
684) {
685 let mut matrix = matrix;
686 let mut stack = stack;
687 let m = householder_basis.nrows();
688 let n = householder_basis.ncols();
689
690 assert!(all(householder_factor.nrows() > 0, householder_factor.ncols() == Ord::min(m, n),));
691
692 let size = householder_factor.ncols();
693
694 let mut j = size;
695
696 let mut blocksize = size % householder_factor.nrows();
697 if blocksize == 0 {
698 blocksize = householder_factor.nrows();
699 }
700
701 while j > 0 {
702 let j_prev = j - blocksize;
703 blocksize = householder_factor.nrows();
704
705 let essentials = householder_basis.get(j_prev.., j_prev..j);
706 let householder = householder_factor.get(.., j_prev..j).subrows(0, j - j_prev);
707 let matrix = matrix.rb_mut().get_mut(j_prev.., ..);
708
709 apply_block_householder_on_the_left_in_place_with_conj(essentials, householder, conj_lhs, matrix, par, stack.rb_mut());
710
711 j = j_prev;
712 }
713}
714
715#[track_caller]
719pub fn apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj<T: ComplexField>(
720 householder_basis: MatRef<'_, T>,
721 householder_factor: MatRef<'_, T>,
722 conj_lhs: Conj,
723 matrix: MatMut<'_, T>,
724 par: Par,
725 stack: &mut MemStack,
726) {
727 let mut matrix = matrix;
728 let mut stack = stack;
729
730 let blocksize = householder_factor.nrows();
731
732 let m = householder_basis.nrows();
733 let n = householder_basis.ncols();
734
735 assert!(all(householder_factor.nrows() > 0, householder_factor.ncols() == Ord::min(m, n),));
736
737 let size = householder_factor.ncols();
738
739 let mut j = 0;
740 while j < size {
741 let blocksize = Ord::min(blocksize, size - j);
742
743 let essentials = householder_basis.get(j.., j..j + blocksize);
744 let householder = householder_factor.get(.., j..j + blocksize).subrows(0, blocksize);
745
746 let matrix = matrix.rb_mut().get_mut(j.., ..);
747
748 apply_block_householder_transpose_on_the_left_in_place_with_conj(essentials, householder, conj_lhs, matrix, par, stack.rb_mut());
749
750 j += blocksize;
751 }
752}
753
754#[track_caller]
757pub fn apply_block_householder_sequence_on_the_right_in_place_with_conj<T: ComplexField>(
758 householder_basis: MatRef<'_, T>,
759 householder_factor: MatRef<'_, T>,
760 conj_rhs: Conj,
761 matrix: MatMut<'_, T>,
762 par: Par,
763 stack: &mut MemStack,
764) {
765 apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
766 householder_basis,
767 householder_factor,
768 conj_rhs,
769 matrix.transpose_mut(),
770 par,
771 stack,
772 )
773}
774
775#[track_caller]
779pub fn apply_block_householder_sequence_transpose_on_the_right_in_place_with_conj<T: ComplexField>(
780 householder_basis: MatRef<'_, T>,
781 householder_factor: MatRef<'_, T>,
782 conj_rhs: Conj,
783 matrix: MatMut<'_, T>,
784 par: Par,
785 stack: &mut MemStack,
786) {
787 apply_block_householder_sequence_on_the_left_in_place_with_conj(
788 householder_basis,
789 householder_factor,
790 conj_rhs,
791 matrix.transpose_mut(),
792 par,
793 stack,
794 )
795}