// faer/linalg/cholesky/ldlt/inverse.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn inverse<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, D: DiagRef<'_, T>, par: Par, stack: &mut MemStack) {
13	// A = L D L.T
14	// A^-1 = L^-T D^-1 L^-1
15
16	let mut out = out;
17	let n = out.nrows();
18
19	assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n, D.dim() == n,));
20
21	let (mut L_inv, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
22	let mut L_inv = L_inv.as_mat_mut();
23
24	linalg::triangular_inverse::invert_unit_lower_triangular(L_inv.rb_mut(), L, par);
25
26	{
27		with_dim!(N, n);
28		let mut L_inv = L_inv.rb_mut().as_shape_mut(N, N);
29		let D = D.as_shape(N);
30
31		for j in N.indices() {
32			let d = recip(real(D[j]));
33			L_inv[(j, j)] = from_real(d);
34		}
35
36		for j in N.indices() {
37			for i in j.next().to(N.end()) {
38				let d = real(L_inv[(i, i)]);
39				L_inv[(j, i)] = mul_real(conj(L_inv[(i, j)]), d);
40			}
41		}
42	}
43
44	let L_inv = L_inv.rb();
45
46	linalg::matmul::triangular::matmul(
47		out.rb_mut(),
48		BlockStructure::TriangularLower,
49		Accum::Replace,
50		L_inv,
51		BlockStructure::TriangularUpper,
52		L_inv,
53		BlockStructure::UnitTriangularLower,
54		one(),
55		par,
56	);
57}
58
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::ldlt::*;

	/// Checks `inverse` on a random 50 × 50 Hermitian positive-definite complex matrix.
	///
	/// The test factors `A = L D L^H` in place, inverts `A` from those factors, mirrors the
	/// computed lower triangle into the upper one, and compares `A^-1 A` with the identity.
	#[test]
	fn test_inverse() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		// Gaussian complex matrix B; A = B B^H is Hermitian positive semi-definite
		// (positive definite with probability 1).
		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &B * B.adjoint();

		// In-place LDLT: afterwards `L` holds the unit lower factor below its diagonal and D on it.
		let mut L = A.to_owned();
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		factor::cholesky_in_place(L.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		let mut A_inv = Mat::zeros(n, n);
		let mut inverse_mem = MemBuffer::new(inverse::inverse_scratch::<c64>(n, Par::Seq));
		inverse::inverse(A_inv.as_mut(), L.as_ref(), L.diagonal(), Par::Seq, MemStack::new(&mut inverse_mem));

		// `inverse` only computes the lower triangle; rebuild the strict upper one by conjugation.
		for row in 0..n {
			for col in row + 1..n {
				A_inv[(row, col)] = A_inv[(col, row)].conj();
			}
		}

		// Tolerance used by the `~` comparison below.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		assert!(A_inv * A ~ Mat::identity(n, n));
	}
}