// faer/linalg/cholesky/ldlt/inverse.rs
use crate::assert;
use crate::internal_prelude::*;
use linalg::matmul::triangular::BlockStructure;

5pub fn inverse_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn inverse<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, D: DiagRef<'_, T>, par: Par, stack: &mut MemStack) {
13 let mut out = out;
17 let n = out.nrows();
18
19 assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n, D.dim() == n,));
20
21 let (mut L_inv, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
22 let mut L_inv = L_inv.as_mat_mut();
23
24 linalg::triangular_inverse::invert_unit_lower_triangular(L_inv.rb_mut(), L, par);
25
26 {
27 with_dim!(N, n);
28 let mut L_inv = L_inv.rb_mut().as_shape_mut(N, N);
29 let D = D.as_shape(N);
30
31 for j in N.indices() {
32 let d = recip(real(D[j]));
33 L_inv[(j, j)] = from_real(d);
34 }
35
36 for j in N.indices() {
37 for i in j.next().to(N.end()) {
38 let d = real(L_inv[(i, i)]);
39 L_inv[(j, i)] = mul_real(conj(L_inv[(i, j)]), d);
40 }
41 }
42 }
43
44 let L_inv = L_inv.rb();
45
46 linalg::matmul::triangular::matmul(
47 out.rb_mut(),
48 BlockStructure::TriangularLower,
49 Accum::Replace,
50 L_inv,
51 BlockStructure::TriangularUpper,
52 L_inv,
53 BlockStructure::UnitTriangularLower,
54 one(),
55 par,
56 );
57}
58
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::ldlt::*;

	/// Builds `A = B Bᴴ` from a seeded random complex matrix `B` and factors a copy
	/// of it in place. It then inverts `A` from the factors and checks that
	/// `A⁻¹ A` is close to the identity.
	#[test]
	fn test_inverse() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &B * B.adjoint();

		// Factor a copy of `A` in place. The factored matrix is passed to `inverse`
		// as `L`, and its diagonal is passed as `D`.
		let mut LD = A.to_owned();
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		factor::cholesky_in_place(LD.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		let mut A_inv = Mat::zeros(n, n);
		let mut inverse_mem = MemBuffer::new(inverse::inverse_scratch::<c64>(n, Par::Seq));
		inverse::inverse(A_inv.as_mut(), LD.as_ref(), LD.diagonal(), Par::Seq, MemStack::new(&mut inverse_mem));

		// `inverse` fills only the lower triangle; mirror it (conjugated) into the upper one.
		for i in 0..n {
			for j in (i + 1)..n {
				A_inv[(i, j)] = A_inv[(j, i)].conj();
			}
		}

		// The name `approx_eq` matters: the `~` form of `assert!` is presumed to pick
		// up this comparator, since it is not referenced anywhere else.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		assert!(A_inv * A ~ Mat::identity(n, n));
	}
}