faer/linalg/cholesky/ldlt/solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_in_place_scratch<T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
5	_ = (dim, rhs_ncols, par);
6	StackReq::EMPTY
7}
8
9#[math]
10#[track_caller]
11pub fn solve_in_place_with_conj<T: ComplexField>(
12	L: MatRef<'_, T>,
13	D: DiagRef<'_, T>,
14	conj_lhs: Conj,
15	rhs: MatMut<'_, T>,
16	par: Par,
17	stack: &mut MemStack,
18) {
19	let n = L.nrows();
20	_ = stack;
21	assert!(all(L.nrows() == n, L.ncols() == n, D.dim() == n, rhs.nrows() == n,));
22
23	let mut rhs = rhs;
24	linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_lhs, rhs.rb_mut(), par);
25
26	{
27		with_dim!(N, rhs.nrows());
28		with_dim!(K, rhs.ncols());
29
30		let D = D.as_shape(N);
31		let mut rhs = rhs.rb_mut().as_shape_mut(N, K);
32
33		for j in K.indices() {
34			for i in N.indices() {
35				let d = recip(real(D[i]));
36				rhs[(i, j)] = mul_real(rhs[(i, j)], d);
37			}
38		}
39	}
40
41	linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_lhs.compose(Conj::Yes), rhs.rb_mut(), par);
42}
43
44#[math]
45#[track_caller]
46pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
47	L: MatRef<'_, C>,
48	D: DiagRef<'_, C>,
49	rhs: MatMut<'_, T>,
50	par: Par,
51	stack: &mut MemStack,
52) {
53	solve_in_place_with_conj(L.canonical(), D.canonical(), Conj::get::<C>(), rhs, par, stack);
54}
55
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::ldlt;

	/// Samples an `nrows × ncols` complex matrix whose entries have independent standard-normal
	/// real and imaginary parts.
	fn sample_c64(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		let distribution = CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		};
		distribution.rand::<Mat<c64>>(rng)
	}

	/// Factors a random Hermitian matrix and checks that `solve_in_place` solves `A X = B`.
	/// Passing conjugated views of the factors must instead solve `conj(A) X = B`.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);

		// The RNG is seeded, so the draw order (A first, then B) fixes the test data.
		let A = sample_c64(rng, n, n);
		let B = sample_c64(rng, n, k);

		// A Aᴴ is Hermitian, so it can be handed to the LDLᴴ factorization.
		let A = &A * A.adjoint();

		// Factor a copy of A in place. Afterwards the same matrix is passed both as `L` and as
		// the source of `D` (its diagonal).
		let mut L = A.to_owned();
		let mut factor_mem = MemBuffer::new(ldlt::factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		ldlt::factor::cholesky_in_place(L.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		// Consumed by the `~` comparisons below.
		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// The solve needs no workspace, so one buffer is enough for both calls.
		let mut solve_mem = MemBuffer::new(ldlt::solve::solve_in_place_scratch::<c64>(n, k, Par::Seq));

		// Plain factors: solves A X = B.
		{
			let mut X = B.to_owned();
			ldlt::solve::solve_in_place(L.as_ref(), L.diagonal(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
			assert!(&A * &X ~ B);
		}

		// Conjugated views of the factors: solves conj(A) X = B.
		{
			let mut X = B.to_owned();
			ldlt::solve::solve_in_place(L.conjugate(), L.conjugate().diagonal(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
			assert!(A.conjugate() * &X ~ B);
		}
	}
}