faer/linalg/cholesky/llt_pivoting/reconstruct.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = (dim, par);
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn reconstruct<I: Index, T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, perm: PermRef<'_, I>, par: Par, stack: &mut MemStack) {
13	let mut out = out;
14	let n = out.nrows();
15
16	assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
17
18	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
19	let mut tmp = tmp.as_mat_mut();
20
21	linalg::matmul::triangular::matmul(
22		tmp.rb_mut(),
23		BlockStructure::TriangularLower,
24		Accum::Replace,
25		L,
26		BlockStructure::TriangularLower,
27		L.adjoint(),
28		BlockStructure::TriangularUpper,
29		one(),
30		par,
31	);
32
33	let p = perm.arrays().1;
34
35	for j in 0..n {
36		let jj = p[j].zx();
37		for i in j..n {
38			let ii = p[i].zx();
39
40			if ii >= jj {
41				out[(i, j)] = copy(tmp[(ii, jj)]);
42			} else {
43				out[(i, j)] = conj(tmp[(jj, ii)]);
44			}
45		}
46	}
47}
48
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::llt_pivoting::*;

	/// Factors a random Hermitian positive semi-definite matrix with pivoted Cholesky,
	/// rebuilds it from the factor and permutation, and checks it matches the input.
	#[test]
	fn test_reconstruct() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		// `A = B * Bᴴ` is Hermitian positive semi-definite by construction.
		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &B * B.adjoint();

		// The factorization works in place, so hand it a copy of `A`.
		let mut L = A.to_owned();
		let mut fwd = vec![0usize; n];
		let mut bwd = vec![0usize; n];
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<usize, c64>(n, Par::Seq, default()));
		let (_, perm) = factor::cholesky_in_place(
			L.as_mut(),
			fwd.as_mut_slice(),
			bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut factor_mem),
			default(),
		)
		.unwrap();

		let mut A_rec = Mat::zeros(n, n);
		let mut rec_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(n, Par::Seq));
		reconstruct::reconstruct(A_rec.as_mut(), L.as_ref(), perm, Par::Seq, MemStack::new(&mut rec_mem));

		// `reconstruct` only fills the lower triangle; mirror it to obtain the full matrix.
		for i in 0..n {
			for j in i + 1..n {
				A_rec[(i, j)] = A_rec[(j, i)].conj();
			}
		}

		// `approx_eq` is the comparator picked up by the `~` operator in `assert!`.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		assert!(A_rec ~ A);
	}
}
104}