faer/linalg/cholesky/llt_pivoting/
inverse.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn inverse<I: Index, T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, perm: PermRef<'_, I>, par: Par, stack: &mut MemStack) {
13 let mut out = out;
17 let n = out.nrows();
18
19 assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
20
21 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
22 let mut tmp = tmp.as_mat_mut();
23
24 linalg::triangular_inverse::invert_lower_triangular(out.rb_mut(), L, par);
25 let L_inv = out.rb();
26
27 linalg::matmul::triangular::matmul(
28 tmp.rb_mut(),
29 BlockStructure::TriangularLower,
30 Accum::Replace,
31 L_inv.adjoint(),
32 BlockStructure::TriangularUpper,
33 L_inv,
34 BlockStructure::TriangularLower,
35 one(),
36 par,
37 );
38
39 let p = perm.arrays().1;
40
41 for j in 0..n {
42 let jj = p[j].zx();
43 for i in j..n {
44 let ii = p[i].zx();
45
46 if ii >= jj {
47 out[(i, j)] = copy(tmp[(ii, jj)]);
48 } else {
49 out[(i, j)] = conj(tmp[(jj, ii)]);
50 }
51 }
52 }
53}
54
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt_pivoting::*;

	/// Factors a random Hermitian positive-definite matrix with pivoted Cholesky, then
	/// checks that `inverse` reproduces `A^{-1}` once the lower triangle it writes has
	/// been mirrored into the upper triangle.
	#[test]
	fn test_inverse() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let par = Par::Seq;

		// G has i.i.d. complex Gaussian entries. A = G G^H is Hermitian.
		let G = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &G * G.adjoint();

		// Factor a copy of A in place. The returned permutation borrows these arrays.
		let mut L = A.to_owned();
		let perm_fwd = &mut *vec![0usize; n];
		let perm_bwd = &mut *vec![0usize; n];
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<usize, c64>(n, par, default()));
		let (_, perm) = factor::cholesky_in_place(
			L.as_mut(),
			perm_fwd,
			perm_bwd,
			par,
			MemStack::new(&mut factor_mem),
			default(),
		)
		.unwrap();

		// Entry-wise tolerance picked up by the `~` comparison in the final assert.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut A_inv = Mat::zeros(n, n);
		let mut inverse_mem = MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, par));
		inverse::inverse(A_inv.as_mut(), L.as_ref(), perm, par, MemStack::new(&mut inverse_mem));

		// Only the lower triangle was produced: copy its conjugate into the upper one.
		for i in 0..n {
			for j in i + 1..n {
				A_inv[(i, j)] = A_inv[(j, i)].conj();
			}
		}

		assert!(A_inv * A ~ Mat::identity(n, n));
	}
}