faer/linalg/cholesky/llt_pivoting/
inverse.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn inverse<I: Index, T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, perm: PermRef<'_, I>, par: Par, stack: &mut MemStack) {
13	// A = L L.T
14	// A^-1 = L^-T L^-1
15
16	let mut out = out;
17	let n = out.nrows();
18
19	assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
20
21	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
22	let mut tmp = tmp.as_mat_mut();
23
24	linalg::triangular_inverse::invert_lower_triangular(out.rb_mut(), L, par);
25	let L_inv = out.rb();
26
27	linalg::matmul::triangular::matmul(
28		tmp.rb_mut(),
29		BlockStructure::TriangularLower,
30		Accum::Replace,
31		L_inv.adjoint(),
32		BlockStructure::TriangularUpper,
33		L_inv,
34		BlockStructure::TriangularLower,
35		one(),
36		par,
37	);
38
39	let p = perm.arrays().1;
40
41	for j in 0..n {
42		let jj = p[j].zx();
43		for i in j..n {
44			let ii = p[i].zx();
45
46			if ii >= jj {
47				out[(i, j)] = copy(tmp[(ii, jj)]);
48			} else {
49				out[(i, j)] = conj(tmp[(jj, ii)]);
50			}
51		}
52	}
53}
54
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt_pivoting::*;

	/// Factorizes a random `50 x 50` self-adjoint matrix `G G^H`, inverts it through the
	/// pivoted Cholesky factor and checks that the product with the original matrix is
	/// approximately the identity.
	#[test]
	fn test_inverse() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		// random complex matrix, turned into a self-adjoint one
		let G = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &G * G.adjoint();

		// pivoted Cholesky factorization, performed in place on a copy of `A`
		let mut L = A.to_owned();
		let mut fwd = vec![0usize; n];
		let mut bwd = vec![0usize; n];
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<usize, c64>(n, Par::Seq, default()));

		let (_, perm) = factor::cholesky_in_place(
			L.as_mut(),
			fwd.as_mut_slice(),
			bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut factor_mem),
			default(),
		)
		.unwrap();

		// tolerance used by the `~` comparison in the assertion below
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut A_inv = Mat::zeros(n, n);
		let mut inverse_mem = MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, Par::Seq));
		inverse::inverse(A_inv.as_mut(), L.as_ref(), perm, Par::Seq, MemStack::new(&mut inverse_mem));

		// the strictly upper triangle is filled in here by mirroring the lower one
		for row in 0..n {
			for col in row + 1..n {
				A_inv[(row, col)] = A_inv[(col, row)].conj();
			}
		}

		assert!(A_inv * A ~ Mat::identity(n, n));
	}
}