faer/linalg/cholesky/ldlt/
solve.rs1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_in_place_scratch<T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
5 _ = (dim, rhs_ncols, par);
6 StackReq::EMPTY
7}
8
9#[math]
10#[track_caller]
11pub fn solve_in_place_with_conj<T: ComplexField>(
12 L: MatRef<'_, T>,
13 D: DiagRef<'_, T>,
14 conj_lhs: Conj,
15 rhs: MatMut<'_, T>,
16 par: Par,
17 stack: &mut MemStack,
18) {
19 let n = L.nrows();
20 _ = stack;
21 assert!(all(L.nrows() == n, L.ncols() == n, D.dim() == n, rhs.nrows() == n,));
22
23 let mut rhs = rhs;
24 linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_lhs, rhs.rb_mut(), par);
25
26 {
27 with_dim!(N, rhs.nrows());
28 with_dim!(K, rhs.ncols());
29
30 let D = D.as_shape(N);
31 let mut rhs = rhs.rb_mut().as_shape_mut(N, K);
32
33 for j in K.indices() {
34 for i in N.indices() {
35 let d = recip(real(D[i]));
36 rhs[(i, j)] = mul_real(rhs[(i, j)], d);
37 }
38 }
39 }
40
41 linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), rhs.rb_mut(), par);
42}
43
44#[math]
45#[track_caller]
46pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
47 L: MatRef<'_, C>,
48 D: DiagRef<'_, C>,
49 rhs: MatMut<'_, T>,
50 par: Par,
51 stack: &mut MemStack,
52) {
53 solve_in_place_with_conj(L.canonical(), D.canonical(), Conj::get::<C>(), rhs, par, stack);
54}
55
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::ldlt;

	/// Samples an `nrows × ncols` matrix whose entries are drawn coefficient-wise from a
	/// complex standard-normal distribution.
	fn random_c64(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng)
	}

	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		// Draw order matters for reproducibility: the n×n sample is drawn before B.
		let G = random_c64(rng, n, n);
		let B = random_c64(rng, n, k);

		// G Gᴴ is Hermitian, the structure an LDLT factorization relies on.
		let A = &G * G.adjoint();

		// Factor a copy of A in place. Afterwards `LD` serves as the L factor, and its
		// diagonal serves as D.
		let mut LD = A.to_owned();
		let mut factor_mem = MemBuffer::new(ldlt::factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		ldlt::factor::cholesky_in_place(LD.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default())
			.unwrap();

		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// The workspace requested by the solver is shared by both solves below.
		let mut solve_mem = MemBuffer::new(ldlt::solve::solve_in_place_scratch::<c64>(n, k, Par::Seq));

		// Plain factors: the solution must satisfy A X ≈ B.
		let mut X = B.to_owned();
		ldlt::solve::solve_in_place(LD.as_ref(), LD.diagonal(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
		assert!(&A * &X ~ B);

		// Conjugated factor views: the solution must satisfy conj(A) X ≈ B.
		let mut X = B.to_owned();
		ldlt::solve::solve_in_place(
			LD.conjugate(),
			LD.conjugate().diagonal(),
			X.as_mut(),
			Par::Seq,
			MemStack::new(&mut solve_mem),
		);
		assert!(A.conjugate() * &X ~ B);
	}
}