faer/linalg/cholesky/llt_pivoting/solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::permute_rows;
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, rhs_ncols)
8}
9
10#[math]
11#[track_caller]
12pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
13	L: MatRef<'_, T>,
14	perm: PermRef<'_, I>,
15	conj_lhs: Conj,
16	rhs: MatMut<'_, T>,
17	par: Par,
18	stack: &mut MemStack,
19) {
20	let n = L.nrows();
21	let k = rhs.ncols();
22	assert!(all(L.nrows() == n, L.ncols() == n, rhs.nrows() == n));
23
24	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
25	let mut tmp = tmp.as_mat_mut();
26	let mut rhs = rhs;
27
28	permute_rows(tmp.rb_mut(), rhs.rb(), perm);
29	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(L, conj_lhs, tmp.rb_mut(), par);
30	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), tmp.rb_mut(), par);
31	permute_rows(rhs.rb_mut(), tmp.rb(), perm.inverse());
32}
33
34#[math]
35#[track_caller]
36pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
37	L: MatRef<'_, C>,
38	perm: PermRef<'_, I>,
39	rhs: MatMut<'_, T>,
40	par: Par,
41	stack: &mut MemStack,
42) {
43	solve_in_place_with_conj(L.canonical(), perm, Conj::get::<C>(), rhs, par, stack);
44}
45
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt_pivoting;

	// Factors a random Hermitian matrix `A = G G^H` with the pivoted Cholesky decomposition,
	// then checks that `solve_in_place` yields `X` with `A X ~ B` for the stored factor
	// and `conj(A) X ~ B` for its conjugated view.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		// The sampling order (matrix first, right-hand side second) fixes the random data.
		let G = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// `G G^H` is Hermitian positive semi-definite, as the LL^H factorization requires.
		let A = &G * G.adjoint();
		let mut L = A.to_owned();
		let fwd = &mut *vec![0usize; n];
		let bwd = &mut *vec![0usize; n];

		let mut factor_mem = MemBuffer::new(llt_pivoting::factor::cholesky_in_place_scratch::<usize, c64>(n, Par::Seq, default()));
		let (_, perm) =
			llt_pivoting::factor::cholesky_in_place(L.as_mut(), fwd, bwd, Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		// The `~` comparisons below look up this binding by name.
		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// Both solves run one after the other, so they can share a single workspace.
		let mut solve_mem = MemBuffer::new(llt_pivoting::solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq));

		// Factor as stored: A X = B.
		let mut X = B.to_owned();
		llt_pivoting::solve::solve_in_place(L.as_ref(), perm, X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
		assert!(&A * &X ~ B);

		// Conjugated factor: conj(A) X = B.
		let mut X = B.to_owned();
		llt_pivoting::solve::solve_in_place(L.conjugate(), perm, X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
		assert!(A.conjugate() * &X ~ B);
	}
}