faer/linalg/cholesky/ldlt/
reconstruct.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = par;
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn reconstruct<T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, D: DiagRef<'_, T>, par: Par, stack: &mut MemStack) {
13 let mut out = out;
14 _ = stack;
15
16 let n = out.nrows();
17 assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n, D.dim() == n,));
18
19 let (mut LxD, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
20 let mut LxD = LxD.as_mat_mut();
21 {
22 with_dim!(N, n);
23 let mut LxD = LxD.rb_mut().as_shape_mut(N, N);
24 let L = L.as_shape(N, N);
25 let D = D.as_shape(N);
26
27 for j in N.indices() {
28 let d = copy(D[j]);
29
30 LxD[(j, j)] = copy(d);
31 for i in j.next().to(N.end()) {
32 LxD[(i, j)] = L[(i, j)] * d;
33 }
34 }
35 }
36
37 let LxD = LxD.rb();
38
39 linalg::matmul::triangular::matmul(
40 out.rb_mut(),
41 BlockStructure::TriangularLower,
42 Accum::Replace,
43 LxD,
44 BlockStructure::TriangularLower,
45 L.adjoint(),
46 BlockStructure::UnitTriangularUpper,
47 one(),
48 par,
49 );
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use dyn_stack::MemBuffer;
    use linalg::cholesky::ldlt::*;

    /// Factors a random Hermitian positive semi-definite matrix with LDLT, reconstructs it,
    /// and checks that the reconstruction matches the original up to a size-scaled epsilon.
    #[test]
    fn test_reconstruct() {
        let n = 50;
        let mut rng = StdRng::seed_from_u64(0);

        // Random complex Gaussian matrix G; A = G * G^H is Hermitian.
        let dist = ComplexDistribution::new(StandardNormal, StandardNormal);
        let G = CwiseMatDistribution { nrows: n, ncols: n, dist }.rand::<Mat<c64>>(&mut rng);
        let A = &G * G.adjoint();

        // In-place LDLT: the packed factor stores D on its diagonal and L strictly below it.
        let mut LD = A.clone();
        let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default()));
        factor::cholesky_in_place(LD.as_mut(), default(), Par::Seq, MemStack::new(&mut factor_mem), default()).unwrap();

        // Only the lower triangle is produced by `reconstruct`.
        let mut A_rec = Mat::zeros(n, n);
        let mut rec_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<c64>(n, Par::Seq));
        reconstruct::reconstruct(A_rec.as_mut(), LD.as_ref(), LD.diagonal(), Par::Seq, MemStack::new(&mut rec_mem));

        // Fill the strictly upper triangle from the lower one (Hermitian symmetry).
        for i in 0..n {
            for j in i + 1..n {
                A_rec[(i, j)] = A_rec[(j, i)].conj();
            }
        }

        // NOTE: keep this binding name; the `~` comparison in faer's `assert!` presumably
        // picks up `approx_eq` from the enclosing scope.
        let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
        assert!(A_rec ~ A);
    }
}