faer/linalg/cholesky/llt_pivoting/
reconstruct.rs1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6 _ = (dim, par);
7 temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn reconstruct<I: Index, T: ComplexField>(out: MatMut<'_, T>, L: MatRef<'_, T>, perm: PermRef<'_, I>, par: Par, stack: &mut MemStack) {
13 let mut out = out;
14 let n = out.nrows();
15
16 assert!(all(out.nrows() == n, out.ncols() == n, L.nrows() == n, L.ncols() == n,));
17
18 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
19 let mut tmp = tmp.as_mat_mut();
20
21 linalg::matmul::triangular::matmul(
22 tmp.rb_mut(),
23 BlockStructure::TriangularLower,
24 Accum::Replace,
25 L,
26 BlockStructure::TriangularLower,
27 L.adjoint(),
28 BlockStructure::TriangularUpper,
29 one(),
30 par,
31 );
32
33 let p = perm.arrays().1;
34
35 for j in 0..n {
36 let jj = p[j].zx();
37 for i in j..n {
38 let ii = p[i].zx();
39
40 if ii >= jj {
41 out[(i, j)] = copy(tmp[(ii, jj)]);
42 } else {
43 out[(i, j)] = conj(tmp[(jj, ii)]);
44 }
45 }
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use dyn_stack::MemBuffer;
    use linalg::cholesky::llt_pivoting::*;

    /// Factorizes a random matrix of the form `B × Bᴴ` with the pivoted Cholesky routine,
    /// reconstructs it from the factor and the permutation, and checks that the result
    /// matches the original matrix up to a tolerance of `n × eps`.
    #[test]
    fn test_reconstruct() {
        let n = 50;
        let rng = &mut StdRng::seed_from_u64(0);

        // Random complex matrix with standard-normal real and imaginary parts.
        let base = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);

        // Hermitian matrix to factorize.
        let A = &base * base.adjoint();

        // The factorization runs in place on a copy of `A`.
        let mut L = A.to_owned();
        let mut fwd = vec![0usize; n];
        let mut bwd = vec![0usize; n];

        let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<usize, c64>(n, Par::Seq, default()));
        let (_, perm) = factor::cholesky_in_place(
            L.as_mut(),
            fwd.as_mut_slice(),
            bwd.as_mut_slice(),
            Par::Seq,
            MemStack::new(&mut factor_mem),
            default(),
        )
        .unwrap();

        // NOTE: this binding must keep the name `approx_eq`; the `~` comparison in the
        // final `assert!` appears to pick it up implicitly.
        let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

        let mut A_rec = Mat::zeros(n, n);
        let mut reconstruct_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(n, Par::Seq));
        reconstruct::reconstruct(A_rec.as_mut(), L.as_ref(), perm, Par::Seq, MemStack::new(&mut reconstruct_mem));

        // `reconstruct` only fills the lower triangle; mirror it into the strictly upper
        // triangle by conjugation before comparing against the full matrix.
        for col in 0..n {
            for row in 0..col {
                let mirrored = A_rec[(col, row)].conj();
                A_rec[(row, col)] = mirrored;
            }
        }

        assert!(A_rec ~ A);
    }
}
104}