faer/linalg/lu/full_pivoting/
inverse.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<I: Index, T: ComplexField>(
12	out: MatMut<'_, T>,
13	L: MatRef<'_, T>,
14	U: MatRef<'_, T>,
15	row_perm: PermRef<'_, I>,
16	col_perm: PermRef<'_, I>,
17	par: Par,
18	stack: &mut MemStack,
19) {
20	// A = P^-1 L U Q
21	// A^-1 = Q^-1 U^-1 L^-1 P
22
23	let n = L.ncols();
24	assert!(all(
25		L.nrows() == n,
26		L.ncols() == n,
27		U.nrows() == n,
28		U.ncols() == n,
29		out.nrows() == n,
30		out.ncols() == n,
31		row_perm.len() == n,
32	));
33
34	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
35	let mut tmp = tmp.as_mat_mut();
36	let mut out = out;
37
38	linalg::triangular_inverse::invert_unit_lower_triangular(out.rb_mut(), L, par);
39	linalg::triangular_inverse::invert_upper_triangular(out.rb_mut(), U, par);
40
41	linalg::matmul::triangular::matmul(
42		tmp.rb_mut(),
43		BlockStructure::Rectangular,
44		Accum::Replace,
45		out.rb(),
46		BlockStructure::TriangularUpper,
47		out.rb(),
48		BlockStructure::UnitTriangularLower,
49		one(),
50		par,
51	);
52	with_dim!(N, n);
53
54	let (row_perm, col_perm) = (col_perm.as_shape(N).bound_arrays().1, row_perm.as_shape(N).bound_arrays().1);
55
56	let tmp = tmp.rb().as_shape(N, N);
57	let mut out = out.rb_mut().as_shape_mut(N, N);
58
59	for j in N.indices() {
60		for i in N.indices() {
61			out[(i, j)] = tmp[(row_perm[i].zx(), col_perm[j].zx())].clone();
62		}
63	}
64}
65
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::full_pivoting::*;

	/// Inverts a random complex `50 × 50` matrix through its full pivoting LU factorization and
	/// checks that the computed inverse multiplied by the original matrix is approximately the
	/// identity.
	#[test]
	fn test_inverse() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let dist = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		};
		let A = dist.rand::<Mat<c64>>(rng);

		// the factorization is computed in place on a copy, `A` is needed for the final check
		let mut LU = A.to_owned();

		// storage for the forward and inverse arrays of both permutations
		let mut row_fwd = vec![0usize; n];
		let mut row_bwd = vec![0usize; n];
		let mut col_fwd = vec![0usize; n];
		let mut col_bwd = vec![0usize; n];

		let mut lu_mem = MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default()));
		let (_, row_perm, col_perm) = factor::lu_in_place(
			LU.as_mut(),
			row_fwd.as_mut_slice(),
			row_bwd.as_mut_slice(),
			col_fwd.as_mut_slice(),
			col_bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut lu_mem),
			default(),
		);

		// NOTE(review): presumably picked up by name by the `~` comparison in the assertion
		// below, do not rename.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut inv_mem = MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, Par::Seq));
		let mut A_inv = Mat::zeros(n, n);
		// the same matrix is passed as both `L` and `U`
		inverse::inverse(
			A_inv.as_mut(),
			LU.as_ref(),
			LU.as_ref(),
			row_perm,
			col_perm,
			Par::Seq,
			MemStack::new(&mut inv_mem),
		);

		assert!(&A_inv * &A ~ Mat::identity(n, n));
	}
}