faer/linalg/lu/partial_pivoting/
reconstruct.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(nrows, ncols)
8}
9
10#[track_caller]
11pub fn reconstruct<I: Index, T: ComplexField>(
12 out: MatMut<'_, T>,
13 L: MatRef<'_, T>,
14 U: MatRef<'_, T>,
15 row_perm: PermRef<'_, I>,
16 par: Par,
17 stack: &mut MemStack,
18) {
19 let m = L.nrows();
20 let n = U.ncols();
21 let size = Ord::min(m, n);
22 assert!(all(out.nrows() == m, out.ncols() == n, row_perm.len() == m));
23
24 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
25 let mut tmp = tmp.as_mat_mut();
26 let mut out = out;
27
28 linalg::matmul::triangular::matmul(
29 tmp.rb_mut().get_mut(..size, ..size),
30 BlockStructure::Rectangular,
31 Accum::Replace,
32 L.get(..size, ..size),
33 BlockStructure::UnitTriangularLower,
34 U.get(..size, ..size),
35 BlockStructure::TriangularUpper,
36 one(),
37 par,
38 );
39
40 if m > n {
41 linalg::matmul::triangular::matmul(
42 tmp.rb_mut().get_mut(size.., ..size),
43 BlockStructure::Rectangular,
44 Accum::Replace,
45 L.get(size.., ..size),
46 BlockStructure::Rectangular,
47 U.get(..size, ..size),
48 BlockStructure::TriangularUpper,
49 one(),
50 par,
51 );
52 }
53 if m < n {
54 linalg::matmul::triangular::matmul(
55 tmp.rb_mut().get_mut(..size, size..),
56 BlockStructure::Rectangular,
57 Accum::Replace,
58 L.get(..size, ..size),
59 BlockStructure::UnitTriangularLower,
60 U.get(..size, size..),
61 BlockStructure::Rectangular,
62 one(),
63 par,
64 );
65 }
66
67 crate::perm::permute_rows(out.rb_mut(), tmp.rb(), row_perm.inverse());
68}
69
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factors random complex matrices with `lu_in_place`, then checks that `reconstruct`
	/// (fed the packed factors as both `L` and `U`) recovers the original matrix.
	///
	/// `reconstruct` handles three shapes differently, and each gets a case:
	/// - tall (`m > n`): an extra block of rows is computed;
	/// - wide (`m < n`): an extra block of columns is computed;
	/// - square (`m == n`): only the leading block is computed.
	#[test]
	fn test_reconstruct() {
		let rng = &mut StdRng::seed_from_u64(0);
		for (m, n) in [(100, 50), (50, 100), (50, 50)] {
			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			let mut LU = A.to_owned();
			let perm_fwd = &mut *vec![0usize; m];
			let perm_bwd = &mut *vec![0usize; m];

			let (_, perm) = factor::lu_in_place(
				LU.as_mut(),
				perm_fwd,
				perm_bwd,
				Par::Seq,
				MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(m, n, Par::Seq, default())) }),
				default(),
			);

			// Tolerance grows with the inner dimension of the product `L * U`.
			let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

			let mut A_rec = Mat::zeros(m, n);
			reconstruct::reconstruct(
				A_rec.as_mut(),
				LU.as_ref(),
				LU.as_ref(),
				perm,
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(m, n, Par::Seq))),
			);

			assert!(A_rec ~ A);
		}
	}
}