// faer/linalg/cholesky/llt/solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_in_place_scratch<T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
5	_ = (dim, rhs_ncols, par);
6	StackReq::EMPTY
7}
8
9#[math]
10#[track_caller]
11pub fn solve_in_place_with_conj<T: ComplexField>(L: MatRef<'_, T>, conj_lhs: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
12	let n = L.nrows();
13	assert!(all(L.nrows() == n, L.ncols() == n, rhs.nrows() == n));
14
15	_ = stack;
16	let mut rhs = rhs;
17	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(L, conj_lhs, rhs.rb_mut(), par);
18
19	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), rhs.rb_mut(), par);
20}
21
22#[math]
23#[track_caller]
24pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(L: MatRef<'_, C>, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
25	solve_in_place_with_conj(L.canonical(), Conj::get::<C>(), rhs, par, stack);
26}
27
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt;

	/// Draws an `nrows x ncols` matrix whose entries are complex numbers with independent
	/// standard-normal real and imaginary parts.
	fn random_complex(nrows: usize, ncols: usize, rng: &mut StdRng) -> Mat<c64> {
		let dist = ComplexDistribution::new(StandardNormal, StandardNormal);
		CwiseMatDistribution { nrows, ncols, dist }.rand::<Mat<c64>>(rng)
	}

	/// Factors a random Hermitian matrix `A = M M^H`, then checks that `llt::solve::solve_in_place`
	/// solves `A X = B` with the factor itself and `conj(A) X = B` with its conjugated view.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);

		// draw order matters for reproducibility: the base matrix first, then the right-hand side
		let A = random_complex(n, n, rng);
		let B = random_complex(n, k, rng);

		// M M^H is Hermitian positive semi-definite (definite with probability one for a Gaussian M)
		let A = &A * A.adjoint();

		let mut L = A.to_owned();
		let mut factor_mem = MemBuffer::new(llt::factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		llt::factor::cholesky_in_place(L.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// a single workspace buffer is shared by both solves below
		let mut solve_mem = MemBuffer::new(llt::solve::solve_in_place_scratch::<c64>(n, k, Par::Seq));

		// plain factor: L L^H = A
		let mut X = B.to_owned();
		llt::solve::solve_in_place(L.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
		assert!(&A * &X ~ B);

		// conjugated factor: conj(L) conj(L)^H = conj(A)
		let mut X = B.to_owned();
		llt::solve::solve_in_place(L.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
		assert!(A.conjugate() * &X ~ B);
	}
}