faer/linalg/lu/full_pivoting/
inverse.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<I: Index, T: ComplexField>(
12 out: MatMut<'_, T>,
13 L: MatRef<'_, T>,
14 U: MatRef<'_, T>,
15 row_perm: PermRef<'_, I>,
16 col_perm: PermRef<'_, I>,
17 par: Par,
18 stack: &mut MemStack,
19) {
20 let n = L.ncols();
24 assert!(all(
25 L.nrows() == n,
26 L.ncols() == n,
27 U.nrows() == n,
28 U.ncols() == n,
29 out.nrows() == n,
30 out.ncols() == n,
31 row_perm.len() == n,
32 ));
33
34 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
35 let mut tmp = tmp.as_mat_mut();
36 let mut out = out;
37
38 linalg::triangular_inverse::invert_unit_lower_triangular(out.rb_mut(), L, par);
39 linalg::triangular_inverse::invert_upper_triangular(out.rb_mut(), U, par);
40
41 linalg::matmul::triangular::matmul(
42 tmp.rb_mut(),
43 BlockStructure::Rectangular,
44 Accum::Replace,
45 out.rb(),
46 BlockStructure::TriangularUpper,
47 out.rb(),
48 BlockStructure::UnitTriangularLower,
49 one(),
50 par,
51 );
52 with_dim!(N, n);
53
54 let (row_perm, col_perm) = (col_perm.as_shape(N).bound_arrays().1, row_perm.as_shape(N).bound_arrays().1);
55
56 let tmp = tmp.rb().as_shape(N, N);
57 let mut out = out.rb_mut().as_shape_mut(N, N);
58
59 for j in N.indices() {
60 for i in N.indices() {
61 out[(i, j)] = tmp[(row_perm[i].zx(), col_perm[j].zx())].clone();
62 }
63 }
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use dyn_stack::MemBuffer;
    use linalg::lu::full_pivoting::*;

    /// Factorizes a random 50 × 50 complex matrix with full pivoting, inverts it from the
    /// factors, and checks that the result times the original matrix is approximately the
    /// identity.
    #[test]
    fn test_inverse() {
        let n = 50;
        let rng = &mut StdRng::seed_from_u64(0);

        let distribution = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        };
        let A = distribution.rand::<Mat<c64>>(rng);

        // Storage for the forward and backward index arrays of both permutations.
        let mut row_fwd = vec![0usize; n];
        let mut row_bwd = vec![0usize; n];
        let mut col_fwd = vec![0usize; n];
        let mut col_bwd = vec![0usize; n];

        // The factorization is performed in place on a copy of `A`.
        let mut factors = A.to_owned();
        let mut factor_mem = MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default()));
        let (_, row_perm, col_perm) = factor::lu_in_place(
            factors.as_mut(),
            row_fwd.as_mut_slice(),
            row_bwd.as_mut_slice(),
            col_fwd.as_mut_slice(),
            col_bwd.as_mut_slice(),
            Par::Seq,
            MemStack::new(&mut factor_mem),
            default(),
        );

        // NOTE(review): this binding looks unused but is presumably picked up by name by the `~`
        // comparison inside `assert!` below — do not rename without confirming.
        let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

        let mut A_inv = Mat::zeros(n, n);
        let mut inverse_mem = MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, Par::Seq));
        // The same packed matrix is passed as both the `L` and the `U` argument.
        inverse::inverse(
            A_inv.as_mut(),
            factors.as_ref(),
            factors.as_ref(),
            row_perm,
            col_perm,
            Par::Seq,
            MemStack::new(&mut inverse_mem),
        );

        assert!(&A_inv * &A ~ Mat::identity(n, n));
    }
}