1use linalg::matmul::triangular::BlockStructure;
2
3use crate::internal_prelude::*;
4use crate::utils::thread::join_raw;
5
6#[math]
7fn invert_lower_triangular_impl_small<'N, T: ComplexField>(mut dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>) {
8 let N = dst.nrows();
9 match *N {
10 0 => {},
11 1 => {
12 let i0 = N.check(0);
13 *dst.rb_mut().at_mut(i0, i0) = recip(src[(i0, i0)])
14 },
15 2 => {
16 let i0 = N.check(0);
17 let i1 = N.check(1);
18 let dst00 = recip(src[(i0, i0)]);
19 let dst11 = recip(src[(i1, i1)]);
20 let dst10 = -dst11 * src[(i1, i0)] * dst00;
21
22 *dst.rb_mut().at_mut(i0, i0) = dst00;
23 *dst.rb_mut().at_mut(i1, i1) = dst11;
24 *dst.rb_mut().at_mut(i1, i0) = dst10;
25 },
26 _ => unreachable!(),
27 }
28}
29
30#[math]
31fn invert_unit_lower_triangular_impl_small<'N, T: ComplexField>(mut dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>) {
32 let N = dst.nrows();
33 match *N {
34 0 | 1 => {},
35 2 => {
36 let i0 = N.check(0);
37 let i1 = N.check(1);
38 *dst.rb_mut().at_mut(i1, i0) = -src[(i1, i0)];
39 },
40 _ => unreachable!(),
41 }
42}
43
44#[math]
45fn invert_lower_triangular_impl<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, par: Par) {
46 let N = dst.ncols();
48
49 if *N <= 2 {
50 invert_lower_triangular_impl_small(dst, src);
51 return;
52 }
53
54 make_guard!(HEAD);
55 make_guard!(TAIL);
56 let mid = N.partition(N.checked_idx_inc(*N / 2), HEAD, TAIL);
57
58 let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_with_mut(mid, mid) };
59 let (src_tl, _, src_bl, src_br) = { src.split_with(mid, mid) };
60
61 join_raw(
62 |par| invert_lower_triangular_impl(dst_tl.rb_mut(), src_tl, par),
63 |par| invert_lower_triangular_impl(dst_br.rb_mut(), src_br, par),
64 par,
65 );
66
67 linalg::matmul::triangular::matmul(
68 dst_bl.rb_mut(),
69 BlockStructure::Rectangular,
70 Accum::Replace,
71 src_bl,
72 BlockStructure::Rectangular,
73 dst_tl.rb(),
74 BlockStructure::TriangularLower,
75 -one::<T>(),
76 par,
77 );
78 linalg::triangular_solve::solve_lower_triangular_in_place(src_br, dst_bl, par);
79}
80
81#[math]
82fn invert_unit_lower_triangular_impl<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, par: Par) {
83 let N = dst.ncols();
85
86 if *N <= 2 {
87 invert_unit_lower_triangular_impl_small(dst, src);
88 return;
89 }
90
91 make_guard!(HEAD);
92 make_guard!(TAIL);
93 let mid = N.partition(N.checked_idx_inc(*N / 2), HEAD, TAIL);
94
95 let (mut dst_tl, _, mut dst_bl, mut dst_br) = { dst.split_with_mut(mid, mid) };
96 let (src_tl, _, src_bl, src_br) = { src.split_with(mid, mid) };
97
98 join_raw(
99 |par| invert_unit_lower_triangular_impl(dst_tl.rb_mut(), src_tl, par),
100 |par| invert_unit_lower_triangular_impl(dst_br.rb_mut(), src_br, par),
101 par,
102 );
103
104 linalg::matmul::triangular::matmul(
105 dst_bl.rb_mut(),
106 BlockStructure::Rectangular,
107 Accum::Replace,
108 src_bl,
109 BlockStructure::Rectangular,
110 dst_tl.rb(),
111 BlockStructure::UnitTriangularLower,
112 -one::<T>(),
113 par,
114 );
115 linalg::triangular_solve::solve_unit_lower_triangular_in_place(src_br, dst_bl, par);
116}
117
118#[track_caller]
125pub fn invert_unit_lower_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
126 Assert!(all(dst.nrows() == src.nrows(), dst.ncols() == src.ncols(), dst.nrows() == dst.ncols()));
127
128 with_dim!(N, dst.nrows().unbound());
129
130 invert_unit_lower_triangular_impl(dst.as_shape_mut(N, N).as_dyn_stride_mut(), src.as_shape(N, N).as_dyn_stride(), par)
131}
132
133#[track_caller]
140pub fn invert_lower_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
141 Assert!(all(dst.nrows() == src.nrows(), dst.ncols() == src.ncols(), dst.nrows() == dst.ncols()));
142
143 with_dim!(N, dst.nrows().unbound());
144
145 invert_lower_triangular_impl(dst.as_shape_mut(N, N).as_dyn_stride_mut(), src.as_shape(N, N).as_dyn_stride(), par)
146}
147
148#[track_caller]
155pub fn invert_unit_upper_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
156 invert_unit_lower_triangular(dst.reverse_rows_and_cols_mut(), src.reverse_rows_and_cols(), par)
157}
158
159#[track_caller]
166pub fn invert_upper_triangular<T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, par: Par) {
167 invert_lower_triangular(dst.reverse_rows_and_cols_mut(), src.reverse_rows_and_cols(), par)
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Mat, MatRef, assert};
    use assert_approx_eq::assert_approx_eq;
    use linalg::matmul::triangular;
    use rand::SeedableRng;
    use rand::rngs::StdRng;
    use rand_distr::{Distribution, StandardNormal};

    /// Shared driver for the four inversion tests.
    ///
    /// For every size `n` in `0..32`, samples an `n x n` matrix with standard normal entries,
    /// adds `2.0` to every entry, inverts it with `invert`, multiplies the input by the
    /// computed inverse while interpreting both operands with `structure`, and checks that
    /// the product matches the identity to within `1e-4`.
    fn check_inverse(invert: for<'a, 'b> fn(MatMut<'a, f64>, MatRef<'b, f64>, Par), structure: BlockStructure) {
        // fixed seed so that every run sees the same sequence of matrices
        let mut rng = StdRng::seed_from_u64(0);

        for n in 0..32 {
            let mut a: Mat<f64> = crate::stats::CwiseMatDistribution {
                nrows: n,
                ncols: n,
                dist: StandardNormal,
            }
            .sample(&mut rng);
            a += MatRef::from_repeated_ref(&2.0, n, n);

            let mut inv: Mat<f64> = Mat::zeros(n, n);
            invert(inv.as_mut(), a.as_ref(), Par::rayon(0));

            // `structure` makes the product ignore whatever lies outside the relevant
            // triangle of both the input and the computed inverse
            let mut prod: Mat<f64> = Mat::zeros(n, n);
            triangular::matmul(
                prod.as_mut(),
                BlockStructure::Rectangular,
                Accum::Replace,
                a.as_ref(),
                structure,
                inv.as_ref(),
                structure,
                1.0,
                Par::rayon(0),
            );

            for i in 0..n {
                for j in 0..n {
                    let target = if i == j { 1.0 } else { 0.0 };
                    assert_approx_eq!(prod[(i, j)], target, 1e-4);
                }
            }
        }
    }

    #[test]
    fn test_invert_lower() {
        check_inverse(invert_lower_triangular::<f64>, BlockStructure::TriangularLower);
    }

    #[test]
    fn test_invert_unit_lower() {
        check_inverse(invert_unit_lower_triangular::<f64>, BlockStructure::UnitTriangularLower);
    }

    #[test]
    fn test_invert_upper() {
        check_inverse(invert_upper_triangular::<f64>, BlockStructure::TriangularUpper);
    }

    #[test]
    fn test_invert_unit_upper() {
        check_inverse(invert_unit_upper_triangular::<f64>, BlockStructure::UnitTriangularUpper);
    }
}