faer/linalg/cholesky/llt/
solve.rs1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_in_place_scratch<T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
5 _ = (dim, rhs_ncols, par);
6 StackReq::EMPTY
7}
8
9#[math]
10#[track_caller]
11pub fn solve_in_place_with_conj<T: ComplexField>(L: MatRef<'_, T>, conj_lhs: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
12 let n = L.nrows();
13 assert!(all(L.nrows() == n, L.ncols() == n, rhs.nrows() == n));
14
15 _ = stack;
16 let mut rhs = rhs;
17 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(L, conj_lhs, rhs.rb_mut(), par);
18
19 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), rhs.rb_mut(), par);
20}
21
22#[math]
23#[track_caller]
24pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(L: MatRef<'_, C>, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
25 solve_in_place_with_conj(L.canonical(), Conj::get::<C>(), rhs, par, stack);
26}
27
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt;

	/// Draws an `nrows × ncols` matrix whose entries have independent standard-normal
	/// real and imaginary parts.
	fn sample(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng)
	}

	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		// Same draw order as before: the square matrix first, then the right-hand side.
		let A = sample(rng, n, n);
		let B = sample(rng, n, k);

		// `A Aᴴ` is Hermitian positive (semi-)definite; factor it in place as `L Lᴴ`.
		let A = &A * A.adjoint();
		let mut L = A.to_owned();
		{
			let mut factor_mem = MemBuffer::new(llt::factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
			llt::factor::cholesky_in_place(L.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();
		}

		// The `~` comparisons below use the binding named `approx_eq`.
		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));
		let solve_mem = || MemBuffer::new(llt::solve::solve_in_place_scratch::<c64>(n, k, Par::Seq));

		// Factor used as-is: expect `A X = B`.
		let mut X = B.to_owned();
		llt::solve::solve_in_place(L.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem()));
		assert!(&A * &X ~ B);

		// Conjugated view of the factor: expect `conj(A) X = B`.
		let mut X = B.to_owned();
		llt::solve::solve_in_place(L.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem()));
		assert!(A.conjugate() * &X ~ B);
	}
}