faer/linalg/cholesky/llt_pivoting/
solve.rs1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::permute_rows;
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, rhs_ncols)
8}
9
10#[math]
11#[track_caller]
12pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
13 L: MatRef<'_, T>,
14 perm: PermRef<'_, I>,
15 conj_lhs: Conj,
16 rhs: MatMut<'_, T>,
17 par: Par,
18 stack: &mut MemStack,
19) {
20 let n = L.nrows();
21 let k = rhs.ncols();
22 assert!(all(L.nrows() == n, L.ncols() == n, rhs.nrows() == n));
23
24 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
25 let mut tmp = tmp.as_mat_mut();
26 let mut rhs = rhs;
27
28 permute_rows(tmp.rb_mut(), rhs.rb(), perm);
29 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(L, conj_lhs, tmp.rb_mut(), par);
30 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), tmp.rb_mut(), par);
31 permute_rows(rhs.rb_mut(), tmp.rb(), perm.inverse());
32}
33
34#[math]
35#[track_caller]
36pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
37 L: MatRef<'_, C>,
38 perm: PermRef<'_, I>,
39 rhs: MatMut<'_, T>,
40 par: Par,
41 stack: &mut MemStack,
42) {
43 solve_in_place_with_conj(L.canonical(), perm, Conj::get::<C>(), rhs, par, stack);
44}
45
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt_pivoting;

	/// Factors a random Hermitian matrix `A * A^H` with pivoted Cholesky. It then checks that
	/// solving with the factor reproduces the right-hand side, both for the factor as stored
	/// (`A X = B`) and for its conjugated view (`conj(A) X = B`).
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);
		let par = Par::Seq;

		// draw the square matrix first and the right-hand side second (same rng stream as before)
		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// `A * A^H` is Hermitian; for a random `A` it is positive definite
		let A = &A * A.adjoint();

		let mut L = A.to_owned();
		let fwd = &mut *vec![0usize; n];
		let bwd = &mut *vec![0usize; n];
		let mut factor_mem = MemBuffer::new(llt_pivoting::factor::cholesky_in_place_scratch::<usize, c64>(n, par, default()));
		let (_, perm) = llt_pivoting::factor::cholesky_in_place(L.as_mut(), fwd, bwd, par, MemStack::new(&mut factor_mem), default()).unwrap();

		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));
		// each solve gets a freshly allocated workspace of the required size
		let solve_mem = || MemBuffer::new(llt_pivoting::solve::solve_in_place_scratch::<usize, c64>(n, k, par));

		// factor used as stored
		let mut X = B.to_owned();
		llt_pivoting::solve::solve_in_place(L.as_ref(), perm, X.as_mut(), par, MemStack::new(&mut solve_mem()));
		assert!(&A * &X ~ B);

		// conjugated view of the factor
		let mut X = B.to_owned();
		llt_pivoting::solve::solve_in_place(L.conjugate(), perm, X.as_mut(), par, MemStack::new(&mut solve_mem()));
		assert!(A.conjugate() * &X ~ B);
	}
}