faer/linalg/lu/partial_pivoting/
reconstruct.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(nrows, ncols)
8}
9
10#[track_caller]
11pub fn reconstruct<I: Index, T: ComplexField>(
12	out: MatMut<'_, T>,
13	L: MatRef<'_, T>,
14	U: MatRef<'_, T>,
15	row_perm: PermRef<'_, I>,
16	par: Par,
17	stack: &mut MemStack,
18) {
19	let m = L.nrows();
20	let n = U.ncols();
21	let size = Ord::min(m, n);
22	assert!(all(out.nrows() == m, out.ncols() == n, row_perm.len() == m));
23
24	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
25	let mut tmp = tmp.as_mat_mut();
26	let mut out = out;
27
28	linalg::matmul::triangular::matmul(
29		tmp.rb_mut().get_mut(..size, ..size),
30		BlockStructure::Rectangular,
31		Accum::Replace,
32		L.get(..size, ..size),
33		BlockStructure::UnitTriangularLower,
34		U.get(..size, ..size),
35		BlockStructure::TriangularUpper,
36		one(),
37		par,
38	);
39
40	if m > n {
41		linalg::matmul::triangular::matmul(
42			tmp.rb_mut().get_mut(size.., ..size),
43			BlockStructure::Rectangular,
44			Accum::Replace,
45			L.get(size.., ..size),
46			BlockStructure::Rectangular,
47			U.get(..size, ..size),
48			BlockStructure::TriangularUpper,
49			one(),
50			par,
51		);
52	}
53	if m < n {
54		linalg::matmul::triangular::matmul(
55			tmp.rb_mut().get_mut(..size, size..),
56			BlockStructure::Rectangular,
57			Accum::Replace,
58			L.get(..size, ..size),
59			BlockStructure::UnitTriangularLower,
60			U.get(..size, size..),
61			BlockStructure::Rectangular,
62			one(),
63			par,
64		);
65	}
66
67	crate::perm::permute_rows(out.rb_mut(), tmp.rb(), row_perm.inverse());
68}
69
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factorizes a random complex matrix (one tall, one wide), reconstructs it from
	/// the packed factors and checks that the result approximately equals the input.
	#[test]
	fn test_reconstruct() {
		let rng = &mut StdRng::seed_from_u64(0);

		for (m, n) in [(100, 50), (50, 100)] {
			let A = CwiseMatDistribution {
				nrows: m,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			// Factorize a copy in place so that `A` stays available for comparison.
			let mut LU = A.to_owned();
			let mut fwd_storage = vec![0usize; m];
			let mut bwd_storage = vec![0usize; m];
			let mut lu_mem = MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(m, n, Par::Seq, default()));

			let (_, perm) = factor::lu_in_place(
				LU.as_mut(),
				&mut fwd_storage[..],
				&mut bwd_storage[..],
				Par::Seq,
				MemStack::new(&mut lu_mem),
				default(),
			);

			// NOTE(review): `approx_eq` is not referenced explicitly below; it is
			// presumably picked up by name by the `~` comparison inside `assert!`,
			// so it must keep this name.
			let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

			// The packed `LU` matrix is passed as both the `L` and the `U` factor.
			let mut A_rec = Mat::zeros(m, n);
			let mut rec_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(m, n, Par::Seq));
			reconstruct::reconstruct(
				A_rec.as_mut(),
				LU.as_ref(),
				LU.as_ref(),
				perm,
				Par::Seq,
				MemStack::new(&mut rec_mem),
			);

			assert!(A_rec ~ A);
		}
	}
}