faer/linalg/cholesky/ldlt/reconstruct.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn reconstruct<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, D: DiagRef<'_, T>, par: Par, stack: &mut MemStack) {
13	let mut out = out;
14	_ = stack;
15
16	let n = out.nrows();
17	assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n, D.dim() == n,));
18
19	let (mut LxD, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
20	let mut LxD = LxD.as_mat_mut();
21	{
22		with_dim!(N, n);
23		let mut LxD = LxD.rb_mut().as_shape_mut(N, N);
24		let L = L.as_shape(N, N);
25		let D = D.as_shape(N);
26
27		for j in N.indices() {
28			let d = copy(D[j]);
29
30			LxD[(j, j)] = copy(d);
31			for i in j.next().to(N.end()) {
32				LxD[(i, j)] = L[(i, j)] * d;
33			}
34		}
35	}
36
37	let LxD = LxD.rb();
38
39	linalg::matmul::triangular::matmul(
40		out.rb_mut(),
41		BlockStructure::TriangularLower,
42		Accum::Replace,
43		LxD,
44		BlockStructure::TriangularLower,
45		L.adjoint(),
46		BlockStructure::UnitTriangularUpper,
47		one(),
48		par,
49	);
50}
51
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::ldlt::*;

	/// Builds a random Hermitian matrix `A = B Bᴴ`, factors it in place, rebuilds it with
	/// `reconstruct`, and checks the rebuilt matrix against `A`.
	#[test]
	fn test_reconstruct() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = &B * B.adjoint();

		// Factor a copy of `A` in place. `reconstruct` below reads the unit-lower factor from
		// `L` and the diagonal factor from `L`'s own diagonal.
		let mut L = A.to_owned();
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
		factor::cholesky_in_place(L.as_mut(), Default::default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

		let mut rebuild_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<c64>(n, Par::Seq));
		let mut A_rec = Mat::zeros(n, n);
		reconstruct::reconstruct(A_rec.as_mut(), L.as_ref(), L.diagonal(), Par::Seq, MemStack::new(&mut rebuild_mem));

		// Only the lower triangle is rebuilt, so fill the upper one by Hermitian symmetry.
		for i in 0..n {
			for j in i + 1..n {
				A_rec[(i, j)] = A_rec[(j, i)].conj();
			}
		}

		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		assert!(A_rec ~ A);
	}
}
104}