faer/linalg/cholesky/llt/inverse.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
12	// A = L L.T
13	// A^-1 = L^-T L^-1
14
15	let mut out = out;
16	let n = out.nrows();
17
18	assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
19
20	let (mut L_inv, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
21	let mut L_inv = L_inv.as_mat_mut();
22
23	linalg::triangular_inverse::invert_lower_triangular(L_inv.rb_mut(), L, par);
24	let L_inv = L_inv.rb();
25
26	linalg::matmul::triangular::matmul(
27		out.rb_mut(),
28		BlockStructure::TriangularLower,
29		Accum::Replace,
30		L_inv.adjoint(),
31		BlockStructure::TriangularUpper,
32		L_inv,
33		BlockStructure::TriangularLower,
34		one(),
35		par,
36	);
37}
38
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt::*;

	/// Builds a random Hermitian matrix `A = B B^H`, factors it as `A = L L^H`,
	/// inverts it from `L`, mirrors the lower-triangular result into the upper
	/// triangle, and checks that `A^{-1} A` is close to the identity.
	#[test]
	fn test_inverse() {
		let dim = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		// Random complex Gaussian matrix; `B B^H` is Hermitian positive (semi)definite.
		let B = CwiseMatDistribution {
			nrows: dim,
			ncols: dim,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &B * B.adjoint();

		// Factor a copy of `A` in place to obtain the Cholesky factor `L`.
		let mut L = A.to_owned();
		let mut factor_buf = MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(dim, Par::Seq, default()));
		factor::cholesky_in_place(
			L.as_mut(),
			Default::default(),
			Par::Seq,
			MemStack::new(&mut factor_buf),
			default(),
		)
		.unwrap();

		// Compute the lower triangle of A^{-1} from `L`.
		let mut A_inv = Mat::zeros(dim, dim);
		let mut inverse_buf = MemBuffer::new(inverse::inverse_scratch::<c64>(dim, Par::Seq));
		inverse::inverse(A_inv.as_mut(), L.as_ref(), Par::Seq, MemStack::new(&mut inverse_buf));

		// `inverse` only fills the lower triangle; reconstruct the Hermitian upper half.
		for col in 0..dim {
			for row in 0..col {
				A_inv[(row, col)] = A_inv[(col, row)].conj();
			}
		}

		// NOTE: must stay named `approx_eq`; the `~` comparison in `assert!` refers to it.
		let approx_eq = CwiseMat(ApproxEq::eps() * (dim as f64));
		assert!(A_inv * A ~ Mat::identity(dim, dim));
	}
}