faer/linalg/
triangular_inverse.rs

1use linalg::matmul::triangular::BlockStructure;
2
3use crate::internal_prelude::*;
4use crate::utils::thread::join_raw;
5
6#[math]
7fn invert_lower_triangular_impl_small<'N, T: ComplexField>(mut dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>) {
8	let N = dst.nrows();
9	match *N {
10		0 => {},
11		1 => {
12			let i0 = N.check(0);
13			*dst.rb_mut().at_mut(i0, i0) = recip(src[(i0, i0)])
14		},
15		2 => {
16			let i0 = N.check(0);
17			let i1 = N.check(1);
18			let dst00 = recip(src[(i0, i0)]);
19			let dst11 = recip(src[(i1, i1)]);
20			let dst10 = -dst11 * src[(i1, i0)] * dst00;
21
22			*dst.rb_mut().at_mut(i0, i0) = dst00;
23			*dst.rb_mut().at_mut(i1, i1) = dst11;
24			*dst.rb_mut().at_mut(i1, i0) = dst10;
25		},
26		_ => unreachable!(),
27	}
28}
29
30#[math]
31fn invert_unit_lower_triangular_impl_small<'N, T: ComplexField>(mut dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>) {
32	let N = dst.nrows();
33	match *N {
34		0 | 1 => {},
35		2 => {
36			let i0 = N.check(0);
37			let i1 = N.check(1);
38			*dst.rb_mut().at_mut(i1, i0) = -src[(i1, i0)];
39		},
40		_ => unreachable!(),
41	}
42}
43
44#[math]
45fn invert_lower_triangular_impl<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, par: Par) {
46	// m must be equal to n
47	let N = dst.ncols();
48
49	if *N <= 2 {
50		invert_lower_triangular_impl_small(dst, src);
51		return;
52	}
53
54	make_guard!(HEAD);
55	make_guard!(TAIL);
56	let mid = N.partition(N.checked_idx_inc(*N / 2), HEAD, TAIL);
57
58	let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_with_mut(mid, mid) };
59	let (src_tl, _, src_bl, src_br) = { src.split_with(mid, mid) };
60
61	join_raw(
62		|par| invert_lower_triangular_impl(dst_tl.rb_mut(), src_tl, par),
63		|par| invert_lower_triangular_impl(dst_br.rb_mut(), src_br, par),
64		par,
65	);
66
67	linalg::matmul::triangular::matmul(
68		dst_bl.rb_mut(),
69		BlockStructure::Rectangular,
70		Accum::Replace,
71		src_bl,
72		BlockStructure::Rectangular,
73		dst_tl.rb(),
74		BlockStructure::TriangularLower,
75		-one::<T>(),
76		par,
77	);
78	linalg::triangular_solve::solve_lower_triangular_in_place(src_br, dst_bl, par);
79}
80
81#[math]
82fn invert_unit_lower_triangular_impl<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, par: Par) {
83	// m must be equal to n
84	let N = dst.ncols();
85
86	if *N <= 2 {
87		invert_unit_lower_triangular_impl_small(dst, src);
88		return;
89	}
90
91	make_guard!(HEAD);
92	make_guard!(TAIL);
93	let mid = N.partition(N.checked_idx_inc(*N / 2), HEAD, TAIL);
94
95	let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_with_mut(mid, mid) };
96	let (src_tl, _, src_bl, src_br) = { src.split_with(mid, mid) };
97
98	join_raw(
99		|par| invert_unit_lower_triangular_impl(dst_tl.rb_mut(), src_tl, par),
100		|par| invert_unit_lower_triangular_impl(dst_br.rb_mut(), src_br, par),
101		par,
102	);
103
104	linalg::matmul::triangular::matmul(
105		dst_bl.rb_mut(),
106		BlockStructure::Rectangular,
107		Accum::Replace,
108		src_bl,
109		BlockStructure::Rectangular,
110		dst_tl.rb(),
111		BlockStructure::UnitTriangularLower,
112		-one::<T>(),
113		par,
114	);
115	linalg::triangular_solve::solve_unit_lower_triangular_in_place(src_br, dst_bl, par);
116}
117
118/// computes the inverse of the lower triangular matrix `src` (with implicit unit
119/// diagonal) and stores the strictly lower triangular part of the result to `dst`.
120///
121/// # panics
122///
123/// panics if `src` and `dst` have mismatching dimensions, or if they are not square.
124#[track_caller]
125pub fn invert_unit_lower_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
126	Assert!(all(dst.nrows() == src.nrows(), dst.ncols() == src.ncols(), dst.nrows() == dst.ncols()));
127
128	with_dim!(N, dst.nrows().unbound());
129
130	invert_unit_lower_triangular_impl(dst.as_shape_mut(N, N).as_dyn_stride_mut(), src.as_shape(N, N).as_dyn_stride(), par)
131}
132
133/// computes the inverse of the lower triangular matrix `src` and stores the
134/// lower triangular part of the result to `dst`.
135///
136/// # panics
137///
138/// panics if `src` and `dst` have mismatching dimensions, or if they are not square.
139#[track_caller]
140pub fn invert_lower_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
141	Assert!(all(dst.nrows() == src.nrows(), dst.ncols() == src.ncols(), dst.nrows() == dst.ncols()));
142
143	with_dim!(N, dst.nrows().unbound());
144
145	invert_lower_triangular_impl(dst.as_shape_mut(N, N).as_dyn_stride_mut(), src.as_shape(N, N).as_dyn_stride(), par)
146}
147
148/// computes the inverse of the upper triangular matrix `src` (with implicit unit
149/// diagonal) and stores the strictly upper triangular part of the result to `dst`.
150///
151/// # panics
152///
153/// panics if `src` and `dst` have mismatching dimensions, or if they are not square.
154#[track_caller]
155pub fn invert_unit_upper_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
156	invert_unit_lower_triangular(dst.reverse_rows_and_cols_mut(), src.reverse_rows_and_cols(), par)
157}
158
159/// computes the inverse of the upper triangular matrix `src` and stores the
160/// upper triangular part of the result to `dst`.
161///
162/// # panics
163///
164/// panics if `src` and `dst` have mismatching dimensions, or if they are not square.
165#[track_caller]
166pub fn invert_upper_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
167	invert_lower_triangular(dst.reverse_rows_and_cols_mut(), src.reverse_rows_and_cols(), par)
168}
169
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{Mat, MatRef, assert};
	use assert_approx_eq::assert_approx_eq;
	use linalg::matmul::triangular;
	use rand::SeedableRng;
	use rand::rngs::StdRng;
	use rand_distr::{Distribution, StandardNormal};

	/// shared body of the four inversion tests.
	///
	/// for every size `n` in `0..32`, samples an `n × n` matrix `a` from a generator seeded
	/// with 0, runs `invert` on it, multiplies `a` by the result with both operands
	/// interpreted according to `structure`, and checks that every entry of the product is
	/// within `1e-4` of the matching entry of the identity matrix.
	///
	/// `invert` is one of the public inversion routines instantiated for `f64`, and
	/// `structure` is the triangular structure that routine operates on.
	fn check_product_is_identity(invert: for<'a, 'b> fn(MatMut<'a, f64>, MatRef<'b, f64>, Par), structure: BlockStructure) {
		let mut rng = StdRng::seed_from_u64(0);

		for n in 0..32 {
			let mut a: Mat<f64> = crate::stats::CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: StandardNormal,
			}
			.sample(&mut rng);
			a += MatRef::from_repeated_ref(&2.0, n, n);

			// zero initialized, so the entries `invert` does not write stay at zero
			let mut inv: Mat<f64> = Mat::zeros(n, n);
			invert(inv.as_mut(), a.as_ref(), Par::rayon(0));

			let mut prod: Mat<f64> = Mat::zeros(n, n);
			triangular::matmul(
				prod.as_mut(),
				BlockStructure::Rectangular,
				Accum::Replace,
				a.as_ref(),
				structure,
				inv.as_ref(),
				structure,
				1.0,
				Par::rayon(0),
			);

			for i in 0..n {
				for j in 0..n {
					let target = if i == j { 1.0 } else { 0.0 };
					assert_approx_eq!(prod[(i, j)], target, 1e-4);
				}
			}
		}
	}

	#[test]
	fn test_invert_lower() {
		check_product_is_identity(invert_lower_triangular::<f64>, BlockStructure::TriangularLower);
	}

	#[test]
	fn test_invert_unit_lower() {
		check_product_is_identity(invert_unit_lower_triangular::<f64>, BlockStructure::UnitTriangularLower);
	}

	#[test]
	fn test_invert_upper() {
		check_product_is_identity(invert_upper_triangular::<f64>, BlockStructure::TriangularUpper);
	}

	#[test]
	fn test_invert_unit_upper() {
		check_product_is_identity(invert_unit_upper_triangular::<f64>, BlockStructure::UnitTriangularUpper);
	}
}