faer/linalg/cholesky/bunch_kaufman/
reconstruct.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use linalg::matmul::triangular::BlockStructure;
4
5pub fn reconstruct_scratch<I: Index, T: ComplexField>(dim: usize, par: Par) -> StackReq {
6	_ = par;
7	temp_mat_scratch::<T>(dim, dim)
8}
9
10#[track_caller]
11#[math]
12pub fn reconstruct<I: Index, T: ComplexField>(
13	out: MatMut<'_, T>,
14	L: MatRef<'_, T>,
15	diagonal: DiagRef<'_, T>,
16	subdiagonal: DiagRef<'_, T>,
17	perm: PermRef<'_, I>,
18	par: Par,
19	stack: &mut MemStack,
20) {
21	let n = L.nrows();
22	assert!(all(
23		out.nrows() == n,
24		out.ncols() == n,
25		L.nrows() == n,
26		L.ncols() == n,
27		diagonal.dim() == n,
28		subdiagonal.dim() == n,
29		perm.len() == n,
30	));
31
32	let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(n, n, stack) };
33	let mut tmp = tmp.as_mat_mut();
34	let mut out = out;
35	let s = subdiagonal;
36
37	out.fill(zero());
38	out.rb_mut().diagonal_mut().fill(one());
39	out.copy_from_strict_triangular_lower(L);
40
41	let mut j = 0;
42	while j < n {
43		if s[j] == zero() {
44			let d = real(L[(j, j)]);
45
46			for i in 0..n {
47				out[(i, j)] = mul_real(out[(i, j)], d);
48			}
49
50			j += 1;
51		} else {
52			let akp1k = copy(s[j]);
53			let ak = real(L[(j, j)]);
54			let akp1 = real(L[(j + 1, j + 1)]);
55
56			for i in 0..n {
57				let xk = copy(out[(i, j)]);
58				let xkp1 = copy(out[(i, j + 1)]);
59
60				out[(i, j)] = mul_real(xk, ak) + (xkp1 * akp1k);
61				out[(i, j + 1)] = mul_real(xkp1, akp1) + (xk * conj(akp1k));
62			}
63
64			j += 2;
65		}
66	}
67
68	linalg::matmul::triangular::matmul(
69		tmp.rb_mut(),
70		BlockStructure::TriangularLower,
71		Accum::Replace,
72		L,
73		BlockStructure::UnitTriangularLower,
74		out.rb().adjoint(),
75		BlockStructure::Rectangular,
76		one(),
77		par,
78	);
79
80	let perm_inv = perm.arrays().1;
81	for j in 0..n {
82		let pj = perm_inv[j].zx();
83		for i in j..n {
84			let pi = perm_inv[i].zx();
85
86			out[(i, j)] = if pi >= pj { copy(tmp[(pi, pj)]) } else { conj(tmp[(pj, pi)]) };
87		}
88	}
89
90	for j in 0..n {
91		out[(j, j)] = from_real(real(out[(j, j)]));
92	}
93}
94
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::cholesky::lblt::*;

	/// Factorizes a random Hermitian `50 × 50` complex matrix and checks that `reconstruct`
	/// recovers it up to a tolerance proportional to the dimension.
	#[test]
	fn test_reconstruct() {
		let n = 50;
		let rng = &mut StdRng::seed_from_u64(0);

		// random complex matrix, made Hermitian by adding its adjoint
		let sampler = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		};
		let G = sampler.rand::<Mat<c64>>(rng);
		let A = &G + G.adjoint();

		// in-place factorization of a copy of `A`
		let mut LB = A.to_owned();
		let mut subdiag = Diag::zeros(n);
		let mut fwd = vec![0usize; n];
		let mut bwd = vec![0usize; n];
		let mut factor_mem = MemBuffer::new(factor::cholesky_in_place_scratch::<usize, c64>(n, Par::Seq, default()));

		let (_, perm) = factor::cholesky_in_place(
			LB.as_mut(),
			subdiag.as_mut(),
			fwd.as_mut_slice(),
			bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut factor_mem),
			default(),
		);

		// NOTE(review): the `~` comparison in the final assertion presumably picks up this
		// binding by name — confirm before renaming it
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		let mut reconstruct_mem = MemBuffer::new(reconstruct::reconstruct_scratch::<usize, c64>(n, Par::Seq));
		let mut A_rec = Mat::zeros(n, n);
		reconstruct::reconstruct(
			A_rec.as_mut(),
			LB.as_ref(),
			LB.diagonal(),
			subdiag.as_ref(),
			perm,
			Par::Seq,
			MemStack::new(&mut reconstruct_mem),
		);

		// the strict upper triangle is filled in by mirroring the lower triangle, so that the
		// full matrices can be compared
		for i in 0..n {
			for j in (i + 1)..n {
				A_rec[(i, j)] = A_rec[(j, i)].conj();
			}
		}

		assert!(A_rec ~ A);
	}
}