faer/linalg/lu/partial_pivoting/inverse.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn inverse_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11pub fn inverse<I: Index, T: ComplexField>(
12	out: MatMut<'_, T>,
13	L: MatRef<'_, T>,
14	U: MatRef<'_, T>,
15	row_perm: PermRef<'_, I>,
16	par: Par,
17	stack: &mut MemStack,
18) {
19	// A = P^-1 L U
20	// A^-1 = U^-1 L^-1 P
21
22	let n = L.ncols();
23	assert!(all(
24		L.nrows() == n,
25		L.ncols() == n,
26		U.nrows() == n,
27		U.ncols() == n,
28		out.nrows() == n,
29		out.ncols() == n,
30		row_perm.len() == n,
31	));
32
33	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
34	let mut tmp = tmp.as_mat_mut();
35	let mut out = out;
36
37	linalg::triangular_inverse::invert_unit_lower_triangular(out.rb_mut(), L, par);
38	linalg::triangular_inverse::invert_upper_triangular(out.rb_mut(), U, par);
39
40	linalg::matmul::triangular::matmul(
41		tmp.rb_mut(),
42		BlockStructure::Rectangular,
43		Accum::Replace,
44		out.rb(),
45		BlockStructure::TriangularUpper,
46		out.rb(),
47		BlockStructure::UnitTriangularLower,
48		one(),
49		par,
50	);
51	crate::perm::permute_cols(out.rb_mut(), tmp.rb(), row_perm.inverse());
52}
53
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factors a seeded random complex `50 × 50` matrix with partial pivoting, inverts
	/// it from its LU factors, and checks that `A^-1 * A` approximates the identity.
	#[test]
	fn test_inverse() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// In-place LU factorization; `row_perm` borrows the two index buffers.
		let mut factors = A.to_owned();
		let mut fwd = vec![0usize; n];
		let mut bwd = vec![0usize; n];
		let mut lu_mem = MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default()));
		let (_, row_perm) = factor::lu_in_place(
			factors.as_mut(),
			fwd.as_mut_slice(),
			bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut lu_mem),
			default(),
		);

		// Both triangular factors live in the same packed matrix.
		let mut A_inv = Mat::zeros(n, n);
		let mut inv_mem = MemBuffer::new(inverse::inverse_scratch::<usize, c64>(n, Par::Seq));
		inverse::inverse(
			A_inv.as_mut(),
			factors.as_ref(),
			factors.as_ref(),
			row_perm,
			Par::Seq,
			MemStack::new(&mut inv_mem),
		);

		// `approx_eq` is the comparator picked up by the `~` assertion syntax.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		assert!(&A_inv * &A ~ Mat::identity(n, n));
	}
}