faer/linalg/cholesky/bunch_kaufman/
solve.rs1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::permute_rows;
4use linalg::triangular_solve::{solve_unit_lower_triangular_in_place_with_conj, solve_unit_upper_triangular_in_place_with_conj};
5
6#[track_caller]
9pub fn solve_in_place_scratch<I: Index, T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
10 let _ = par;
11 temp_mat_scratch::<T>(dim, rhs_ncols)
12}
13
14#[track_caller]
28#[math]
29pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
30 L: MatRef<'_, T>,
31 diagonal: DiagRef<'_, T>,
32 subdiagonal: DiagRef<'_, T>,
33 conj_A: Conj,
34 perm: PermRef<'_, I>,
35 rhs: MatMut<'_, T>,
36 par: Par,
37 stack: &mut MemStack,
38) {
39 let n = L.nrows();
40 let k = rhs.ncols();
41
42 assert!(all(
43 L.nrows() == n,
44 L.ncols() == n,
45 rhs.nrows() == n,
46 diagonal.dim() == n,
47 subdiagonal.dim() == n,
48 perm.len() == n
49 ));
50
51 let a = L;
52 let par = par;
53 let not_conj = conj_A.compose(Conj::Yes);
54
55 let mut rhs = rhs;
56 let mut x = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack).0 };
57 let mut x = x.as_mat_mut();
58
59 permute_rows(x.rb_mut(), rhs.rb(), perm);
60 solve_unit_lower_triangular_in_place_with_conj(a, conj_A, x.rb_mut(), par);
61
62 let mut i = 0;
63 while i < n {
64 let i0 = i;
65 let i1 = i + 1;
66
67 if subdiagonal[i] == zero() {
68 let d_inv = recip(real(diagonal[i]));
69 for j in 0..k {
70 x[(i, j)] = mul_real(x[(i, j)], d_inv);
71 }
72 i += 1;
73 } else {
74 let mut akp1k = copy(subdiagonal[i0]);
75 if matches!(conj_A, Conj::Yes) {
76 akp1k = conj(akp1k);
77 }
78 akp1k = recip(akp1k);
79 let (ak, akp1) = (mul_real(conj(akp1k), real(diagonal[i0])), mul_real(akp1k, real(diagonal[i1])));
80
81 let denom = real(recip(ak * akp1 - one()));
82
83 for j in 0..k {
84 let (xk, xkp1) = (
85 x[(i0, j)] * conj(akp1k),
87 x[(i1, j)] * akp1k,
88 );
89
90 let (xk, xkp1) = (mul_real((akp1 * xk - xkp1), denom), mul_real((ak * xkp1 - xk), denom));
91
92 x[(i, j)] = xk;
93 x[(i + 1, j)] = xkp1;
94 }
95
96 i += 2;
97 }
98 }
99
100 solve_unit_upper_triangular_in_place_with_conj(a.transpose(), not_conj, x.rb_mut(), par);
101 permute_rows(rhs.rb_mut(), x.rb(), perm.inverse());
102}
103
104#[track_caller]
105#[math]
106pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
107 L: MatRef<'_, C>,
108 diagonal: DiagRef<'_, C>,
109 subdiagonal: DiagRef<'_, C>,
110 perm: PermRef<'_, I>,
111 rhs: MatMut<'_, T>,
112 par: Par,
113 stack: &mut MemStack,
114) {
115 solve_in_place_with_conj(
116 L.canonical(),
117 diagonal.canonical(),
118 subdiagonal.canonical(),
119 Conj::get::<C>(),
120 perm,
121 rhs,
122 par,
123 stack,
124 );
125}