faer/linalg/lu/full_pivoting/reconstruct.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(nrows, ncols)
8}
9
10#[track_caller]
11pub fn reconstruct<I: Index, T: ComplexField>(
12	out: MatMut<'_, T>,
13	L: MatRef<'_, T>,
14	U: MatRef<'_, T>,
15	row_perm: PermRef<'_, I>,
16	col_perm: PermRef<'_, I>,
17	par: Par,
18	stack: &mut MemStack,
19) {
20	let m = L.nrows();
21	let n = U.ncols();
22	let size = Ord::min(m, n);
23	assert!(all(out.nrows() == m, out.ncols() == n, row_perm.len() == m, col_perm.len() == n,));
24
25	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(m, n, stack) };
26	let mut tmp = tmp.as_mat_mut();
27	let mut out = out;
28
29	linalg::matmul::triangular::matmul(
30		tmp.rb_mut().get_mut(..size, ..size),
31		BlockStructure::Rectangular,
32		Accum::Replace,
33		L.get(..size, ..size),
34		BlockStructure::UnitTriangularLower,
35		U.get(..size, ..size),
36		BlockStructure::TriangularUpper,
37		one(),
38		par,
39	);
40
41	if m > n {
42		linalg::matmul::triangular::matmul(
43			tmp.rb_mut().get_mut(size.., ..size),
44			BlockStructure::Rectangular,
45			Accum::Replace,
46			L.get(size.., ..size),
47			BlockStructure::Rectangular,
48			U.get(..size, ..size),
49			BlockStructure::TriangularUpper,
50			one(),
51			par,
52		);
53	}
54	if m < n {
55		linalg::matmul::triangular::matmul(
56			tmp.rb_mut().get_mut(..size, size..),
57			BlockStructure::Rectangular,
58			Accum::Replace,
59			L.get(..size, ..size),
60			BlockStructure::UnitTriangularLower,
61			U.get(..size, size..),
62			BlockStructure::Rectangular,
63			one(),
64			par,
65		);
66	}
67
68	with_dim!(M, m);
69	with_dim!(N, n);
70
71	let row_perm = row_perm.as_shape(M).bound_arrays().1;
72	let col_perm = col_perm.as_shape(N).bound_arrays().1;
73
74	let tmp = tmp.rb().as_shape(M, N);
75	let mut out = out.rb_mut().as_shape_mut(M, N);
76
77	for j in N.indices() {
78		for i in M.indices() {
79			out[(i, j)] = tmp[(row_perm[i].zx(), col_perm[j].zx())].clone();
80		}
81	}
82}
83
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::full_pivoting::*;

	/// Round-trip check on one tall and one wide random complex matrix.
	///
	/// Each matrix is factored in place with full pivoting and rebuilt from the packed
	/// factors (passed as both `L` and `U`). The result is compared against the input.
	#[test]
	fn test_reconstruct() {
		let rng = &mut StdRng::seed_from_u64(0);
		let shapes = [(100, 50), (50, 100)];

		for &(m, n) in &shapes {
			let dist = ComplexDistribution::new(StandardNormal, StandardNormal);
			let A = CwiseMatDistribution { nrows: m, ncols: n, dist }.rand::<Mat<c64>>(rng);

			// In-place factorization; the permutation buffers must outlive the returned views.
			let mut LU = A.to_owned();
			let mut row_fwd = vec![0usize; m];
			let mut row_bwd = vec![0usize; m];
			let mut col_fwd = vec![0usize; n];
			let mut col_bwd = vec![0usize; n];

			let mut factor_mem = MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(m, n, Par::Seq, default()));
			let (_, row_perm, col_perm) = factor::lu_in_place(
				LU.as_mut(),
				&mut *row_fwd,
				&mut *row_bwd,
				&mut *col_fwd,
				&mut *col_bwd,
				Par::Seq,
				MemStack::new(&mut factor_mem),
				default(),
			);

			let mut rebuild_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(m, n, Par::Seq));
			let mut A_rebuilt = Mat::zeros(m, n);
			reconstruct::reconstruct(
				A_rebuilt.as_mut(),
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				Par::Seq,
				MemStack::new(&mut rebuild_mem),
			);

			// `approx_eq` is the comparator picked up by the `~` operator below.
			let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
			assert!(A_rebuilt ~ A);
		}
	}
}