faer/linalg/lu/full_pivoting/
reconstruct.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(nrows, ncols)
8}
9
10#[track_caller]
11pub fn reconstruct<I: Index, T: ComplexField>(
12 out: MatMut<'_, T>,
13 L: MatRef<'_, T>,
14 U: MatRef<'_, T>,
15 row_perm: PermRef<'_, I>,
16 col_perm: PermRef<'_, I>,
17 par: Par,
18 stack: &mut MemStack,
19) {
20 let m = L.nrows();
21 let n = U.ncols();
22 let size = Ord::min(m, n);
23 assert!(all(out.nrows() == m, out.ncols() == n, row_perm.len() == m, col_perm.len() == n,));
24
25 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
26 let mut tmp = tmp.as_mat_mut();
27 let mut out = out;
28
29 linalg::matmul::triangular::matmul(
30 tmp.rb_mut().get_mut(..size, ..size),
31 BlockStructure::Rectangular,
32 Accum::Replace,
33 L.get(..size, ..size),
34 BlockStructure::UnitTriangularLower,
35 U.get(..size, ..size),
36 BlockStructure::TriangularUpper,
37 one(),
38 par,
39 );
40
41 if m > n {
42 linalg::matmul::triangular::matmul(
43 tmp.rb_mut().get_mut(size.., ..size),
44 BlockStructure::Rectangular,
45 Accum::Replace,
46 L.get(size.., ..size),
47 BlockStructure::Rectangular,
48 U.get(..size, ..size),
49 BlockStructure::TriangularUpper,
50 one(),
51 par,
52 );
53 }
54 if m < n {
55 linalg::matmul::triangular::matmul(
56 tmp.rb_mut().get_mut(..size, size..),
57 BlockStructure::Rectangular,
58 Accum::Replace,
59 L.get(..size, ..size),
60 BlockStructure::UnitTriangularLower,
61 U.get(..size, size..),
62 BlockStructure::Rectangular,
63 one(),
64 par,
65 );
66 }
67
68 with_dim!(M, m);
69 with_dim!(N, n);
70
71 let row_perm = row_perm.as_shape(M).bound_arrays().1;
72 let col_perm = col_perm.as_shape(N).bound_arrays().1;
73
74 let tmp = tmp.rb().as_shape(M, N);
75 let mut out = out.rb_mut().as_shape_mut(M, N);
76
77 for j in N.indices() {
78 for i in M.indices() {
79 out[(i, j)] = tmp[(row_perm[i].zx(), col_perm[j].zx())].clone();
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::full_pivoting::*;

	/// Factors random complex matrices with full pivoting, then checks that
	/// `reconstruct` reproduces the input.
	///
	/// Shapes cover the tall branch (`m > n`) and the wide branch (`m < n`). They also cover
	/// the square case (`m == n`), where only the leading triangular product fills the
	/// temporary.
	#[test]
	fn test_reconstruct() {
		let rng = &mut StdRng::seed_from_u64(0);
		for (m, n) in [(100, 50), (50, 100), (50, 50)] {
			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			let mut LU = A.to_owned();
			let row_perm_fwd = &mut *vec![0usize; m];
			let row_perm_bwd = &mut *vec![0usize; m];
			let col_perm_fwd = &mut *vec![0usize; n];
			let col_perm_bwd = &mut *vec![0usize; n];

			let (_, row_perm, col_perm) = factor::lu_in_place(
				LU.as_mut(),
				row_perm_fwd,
				row_perm_bwd,
				col_perm_fwd,
				col_perm_bwd,
				Par::Seq,
				MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(m, n, Par::Seq, default())) }),
				default(),
			);

			// Tolerance consumed by the `~` comparison in `assert!` below.
			let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

			let mut A_rec = Mat::zeros(m, n);
			// The packed LU matrix supplies both factors.
			reconstruct::reconstruct(
				A_rec.as_mut(),
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(m, n, Par::Seq))),
			);

			assert!(A_rec ~ A);
		}
	}
}