faer/linalg/cholesky/llt/
inverse.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
12 let mut out = out;
16 let n = out.nrows();
17
18 assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
19
20 let (mut L_inv, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
21 let mut L_inv = L_inv.as_mat_mut();
22
23 linalg::triangular_inverse::invert_lower_triangular(L_inv.rb_mut(), L, par);
24 let L_inv = L_inv.rb();
25
26 linalg::matmul::triangular::matmul(
27 out.rb_mut(),
28 BlockStructure::TriangularLower,
29 Accum::Replace,
30 L_inv.adjoint(),
31 BlockStructure::TriangularUpper,
32 L_inv,
33 BlockStructure::TriangularLower,
34 one(),
35 par,
36 );
37}
38
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt::*;

	/// Factors a random Hermitian matrix `A` with `cholesky_in_place`, inverts it
	/// through `inverse`, and checks that `A⁻¹ * A` is approximately the identity.
	#[test]
	fn test_inverse() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// `A * Aᴴ` is Hermitian positive semi-definite, and positive definite for a
		// full-rank random `A`, so its Cholesky factorization is expected to succeed.
		let A = &A * A.adjoint();
		let mut L = A.to_owned();

		factor::cholesky_in_place(
			L.as_mut(),
			Default::default(),
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default())) }),
			default(),
		)
		.unwrap();

		// Tolerance grows with the dimension to absorb accumulated rounding error.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut A_inv = Mat::zeros(n, n);
		inverse::inverse(
			A_inv.as_mut(),
			L.as_ref(),
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(inverse::inverse_scratch::<c64>(n, Par::Seq))),
		);

		// `inverse` only fills the lower triangle. Rebuild the strictly upper part
		// from the Hermitian symmetry: A_inv[(i, j)] = conj(A_inv[(j, i)]).
		for j in 0..n {
			for i in 0..j {
				A_inv[(i, j)] = A_inv[(j, i)].conj();
			}
		}

		assert!(A_inv * A ~ Mat::identity(n, n));
	}
}