// faer/linalg/lu/partial_pivoting/inverse.rs
use crate::assert;
use crate::internal_prelude::*;
use linalg::matmul::triangular::BlockStructure;

5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<I: Index, T: ComplexField>(
12 out: MatMut<'_, T>,
13 L: MatRef<'_, T>,
14 U: MatRef<'_, T>,
15 row_perm: PermRef<'_, I>,
16 par: Par,
17 stack: &mut MemStack,
18) {
19 let n = L.ncols();
23 assert!(all(
24 L.nrows() == n,
25 L.ncols() == n,
26 U.nrows() == n,
27 U.ncols() == n,
28 out.nrows() == n,
29 out.ncols() == n,
30 row_perm.len() == n,
31 ));
32
33 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
34 let mut tmp = tmp.as_mat_mut();
35 let mut out = out;
36
37 linalg::triangular_inverse::invert_unit_lower_triangular(out.rb_mut(), L, par);
38 linalg::triangular_inverse::invert_upper_triangular(out.rb_mut(), U, par);
39
40 linalg::matmul::triangular::matmul(
41 tmp.rb_mut(),
42 BlockStructure::Rectangular,
43 Accum::Replace,
44 out.rb(),
45 BlockStructure::TriangularUpper,
46 out.rb(),
47 BlockStructure::UnitTriangularLower,
48 one(),
49 par,
50 );
51 crate::perm::permute_cols(out.rb_mut(), tmp.rb(), row_perm.inverse());
52}
53
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factors a random complex matrix with partial pivoting, inverts it from the packed
	/// factors, and checks that the result is a two-sided inverse.
	#[test]
	fn test_inverse() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// `lu_in_place` overwrites `LU` with the factors; the same matrix is then passed as
		// both `L` and `U` to `inverse`.
		let mut LU = A.to_owned();
		let perm_fwd = &mut *vec![0usize; n];
		let perm_bwd = &mut *vec![0usize; n];

		let (_, perm) = factor::lu_in_place(
			LU.as_mut(),
			perm_fwd,
			perm_bwd,
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default())) }),
			default(),
		);

		// Tolerance scales with the dimension; used by the `~` comparisons below.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut A_inv = Mat::zeros(n, n);
		inverse::inverse(
			A_inv.as_mut(),
			LU.as_ref(),
			LU.as_ref(),
			perm,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, Par::Seq))),
		);

		// A⁻¹ must be both a left and a right inverse of A.
		assert!(&A_inv * &A ~ Mat::identity(n, n));
		assert!(&A * &A_inv ~ Mat::identity(n, n));
	}
}