faer/linalg/cholesky/ldlt/
update.rs

use crate::assert;
use crate::internal_prelude::*;
use pulp::Simd;
3
4#[math]
5fn rank_update_step_simd<T: ComplexField>(
6	L: ColMut<'_, T, usize, ContiguousFwd>,
7	W: MatMut<'_, T, usize, usize, ContiguousFwd>,
8	p: ColRef<'_, T>,
9	beta: ColRef<'_, T>,
10	align_offset: usize,
11) {
12	struct Impl<'a, 'N, 'R, T: ComplexField> {
13		L: ColMut<'a, T, Dim<'N>, ContiguousFwd>,
14		W: MatMut<'a, T, Dim<'N>, Dim<'R>, ContiguousFwd>,
15		p: ColRef<'a, T, Dim<'R>>,
16		beta: ColRef<'a, T, Dim<'R>>,
17		align_offset: usize,
18	}
19
20	impl<'a, 'N, 'R, T: ComplexField> pulp::WithSimd for Impl<'a, 'N, 'R, T> {
21		type Output = ();
22
23		#[inline(always)]
24		fn with_simd<S: Simd>(self, simd: S) {
25			let Self { L, W, p, beta, align_offset } = self;
26
27			let mut L = L;
28			let mut W = W;
29			let N = W.nrows();
30			let R = W.ncols();
31
32			let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), N, align_offset);
33			let (head, body, tail) = simd.indices();
34
35			let mut iter = R.indices();
36			let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
37
38			match (i0, i1, i2, i3) {
39				(Some(i0), None, None, None) => {
40					let p0 = simd.splat(&p[i0]);
41					let beta0 = simd.splat(&beta[i0]);
42
43					macro_rules! simd {
44						($i: expr) => {{
45							let i = $i;
46							let mut l = simd.read(L.rb(), i);
47							let mut w0 = simd.read(W.rb().col(i0), i);
48
49							w0 = simd.mul_add(p0, l, w0);
50							l = simd.mul_add(beta0, w0, l);
51
52							simd.write(L.rb_mut(), i, l);
53							simd.write(W.rb_mut().col_mut(i0), i, w0);
54						}};
55					}
56
57					if let Some(i) = head {
58						simd!(i);
59					}
60					for i in body {
61						simd!(i);
62					}
63					if let Some(i) = tail {
64						simd!(i);
65					}
66				},
67				(Some(i0), Some(i1), None, None) => {
68					let (p0, p1) = (simd.splat(&p[i0]), simd.splat(&p[i1]));
69					let (beta0, beta1) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]));
70
71					macro_rules! simd {
72						($i: expr) => {{
73							let i = $i;
74							let mut l = simd.read(L.rb(), i);
75							let mut w0 = simd.read(W.rb().col(i0), i);
76							let mut w1 = simd.read(W.rb().col(i1), i);
77
78							w0 = simd.mul_add(p0, l, w0);
79							l = simd.mul_add(beta0, w0, l);
80							w1 = simd.mul_add(p1, l, w1);
81							l = simd.mul_add(beta1, w1, l);
82
83							simd.write(L.rb_mut(), i, l);
84							simd.write(W.rb_mut().col_mut(i0), i, w0);
85							simd.write(W.rb_mut().col_mut(i1), i, w1);
86						}};
87					}
88
89					if let Some(i) = head {
90						simd!(i);
91					}
92					for i in body {
93						simd!(i);
94					}
95					if let Some(i) = tail {
96						simd!(i);
97					}
98				},
99				(Some(i0), Some(i1), Some(i2), None) => {
100					let (p0, p1, p2) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]));
101					let (beta0, beta1, beta2) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]));
102
103					macro_rules! simd {
104						($i: expr) => {{
105							let i = $i;
106							let mut l = simd.read(L.rb(), i);
107							let mut w0 = simd.read(W.rb().col(i0), i);
108							let mut w1 = simd.read(W.rb().col(i1), i);
109							let mut w2 = simd.read(W.rb().col(i2), i);
110
111							w0 = simd.mul_add(p0, l, w0);
112							l = simd.mul_add(beta0, w0, l);
113							w1 = simd.mul_add(p1, l, w1);
114							l = simd.mul_add(beta1, w1, l);
115							w2 = simd.mul_add(p2, l, w2);
116							l = simd.mul_add(beta2, w2, l);
117
118							simd.write(L.rb_mut(), i, l);
119							simd.write(W.rb_mut().col_mut(i0), i, w0);
120							simd.write(W.rb_mut().col_mut(i1), i, w1);
121							simd.write(W.rb_mut().col_mut(i2), i, w2);
122						}};
123					}
124
125					if let Some(i) = head {
126						simd!(i);
127					}
128					for i in body {
129						simd!(i);
130					}
131					if let Some(i) = tail {
132						simd!(i);
133					}
134				},
135				(Some(i0), Some(i1), Some(i2), Some(i3)) => {
136					let (p0, p1, p2, p3) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]), simd.splat(&p[i3]));
137					let (beta0, beta1, beta2, beta3) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]), simd.splat(&beta[i3]));
138
139					macro_rules! simd {
140						($i: expr) => {{
141							let i = $i;
142							let mut l = simd.read(L.rb(), i);
143							let mut w0 = simd.read(W.rb().col(i0), i);
144							let mut w1 = simd.read(W.rb().col(i1), i);
145							let mut w2 = simd.read(W.rb().col(i2), i);
146							let mut w3 = simd.read(W.rb().col(i3), i);
147
148							w0 = simd.mul_add(p0, l, w0);
149							l = simd.mul_add(beta0, w0, l);
150							w1 = simd.mul_add(p1, l, w1);
151							l = simd.mul_add(beta1, w1, l);
152							w2 = simd.mul_add(p2, l, w2);
153							l = simd.mul_add(beta2, w2, l);
154							w3 = simd.mul_add(p3, l, w3);
155							l = simd.mul_add(beta3, w3, l);
156
157							simd.write(L.rb_mut(), i, l);
158							simd.write(W.rb_mut().col_mut(i0), i, w0);
159							simd.write(W.rb_mut().col_mut(i1), i, w1);
160							simd.write(W.rb_mut().col_mut(i2), i, w2);
161							simd.write(W.rb_mut().col_mut(i3), i, w3);
162						}};
163					}
164
165					if let Some(i) = head {
166						simd!(i);
167					}
168					for i in body {
169						simd!(i);
170					}
171					if let Some(i) = tail {
172						simd!(i);
173					}
174				},
175				_ => panic!(),
176			}
177		}
178	}
179
180	with_dim!(N, W.nrows());
181	with_dim!(R, W.ncols());
182
183	dispatch!(
184		Impl {
185			L: L.as_row_shape_mut(N),
186			W: W.as_shape_mut(N, R),
187			p: p.as_row_shape(R),
188			beta: beta.as_row_shape(R),
189			align_offset,
190		},
191		Impl,
192		T
193	)
194}
195
196#[math]
197fn rank_update_step_fallback<T: ComplexField>(L: ColMut<'_, T>, W: MatMut<'_, T>, p: ColRef<'_, T>, beta: ColRef<'_, T>) {
198	let mut L = L;
199	let mut W = W;
200	let n = W.nrows();
201	let r = W.ncols();
202
203	let mut iter = 0..r;
204	let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
205
206	match (i0, i1, i2, i3) {
207		(Some(i0), None, None, None) => {
208			let p0 = &p[i0];
209			let beta0 = &beta[i0];
210
211			for i in 0..n {
212				let mut l = copy(L[i]);
213				let mut w0 = copy(W[(i, i0)]);
214
215				w0 = *p0 * l + w0;
216				l = *beta0 * w0 + l;
217
218				L[i] = l;
219				W[(i, i0)] = w0;
220			}
221		},
222		(Some(i0), Some(i1), None, None) => {
223			let (p0, p1) = (&p[i0], &p[i1]);
224			let (beta0, beta1) = (&beta[i0], &beta[i1]);
225
226			for i in 0..n {
227				let mut l = copy(L[i]);
228				let mut w0 = copy(W[(i, i0)]);
229				let mut w1 = copy(W[(i, i1)]);
230
231				w0 = *p0 * l + w0;
232				l = *beta0 * w0 + l;
233				w1 = *p1 * l + w1;
234				l = *beta1 * w1 + l;
235
236				L[i] = l;
237				W[(i, i0)] = w0;
238				W[(i, i1)] = w1;
239			}
240		},
241		(Some(i0), Some(i1), Some(i2), None) => {
242			let (p0, p1, p2) = (&p[i0], &p[i1], &p[i2]);
243			let (beta0, beta1, beta2) = (&beta[i0], &beta[i1], &beta[i2]);
244
245			for i in 0..n {
246				let mut l = copy(L[i]);
247				let mut w0 = copy(W[(i, i0)]);
248				let mut w1 = copy(W[(i, i1)]);
249				let mut w2 = copy(W[(i, i2)]);
250
251				w0 = *p0 * l + w0;
252				l = *beta0 * w0 + l;
253				w1 = *p1 * l + w1;
254				l = *beta1 * w1 + l;
255				w2 = *p2 * l + w2;
256				l = *beta2 * w2 + l;
257
258				L[i] = l;
259				W[(i, i0)] = w0;
260				W[(i, i1)] = w1;
261				W[(i, i2)] = w2;
262			}
263		},
264		(Some(i0), Some(i1), Some(i2), Some(i3)) => {
265			let (p0, p1, p2, p3) = (&p[i0], &p[i1], &p[i2], &p[i3]);
266			let (beta0, beta1, beta2, beta3) = (&beta[i0], &beta[i1], &beta[i2], &beta[i3]);
267
268			for i in 0..n {
269				let mut l = copy(L[i]);
270				let mut w0 = copy(W[(i, i0)]);
271				let mut w1 = copy(W[(i, i1)]);
272				let mut w2 = copy(W[(i, i2)]);
273				let mut w3 = copy(W[(i, i3)]);
274
275				w0 = *p0 * l + w0;
276				l = *beta0 * w0 + l;
277				w1 = *p1 * l + w1;
278				l = *beta1 * w1 + l;
279				w2 = *p2 * l + w2;
280				l = *beta2 * w2 + l;
281				w3 = *p3 * l + w3;
282				l = *beta3 * w3 + l;
283
284				L[i] = l;
285				W[(i, i0)] = w0;
286				W[(i, i1)] = w1;
287				W[(i, i2)] = w2;
288				W[(i, i3)] = w3;
289			}
290		},
291		_ => panic!(),
292	}
293}
294
295struct RankRUpdate<'a, T: ComplexField> {
296	ld: MatMut<'a, T>,
297	w: MatMut<'a, T>,
298	alpha: ColMut<'a, T>,
299	r: &'a mut dyn FnMut() -> usize,
300}
301
302impl<T: ComplexField> RankRUpdate<'_, T> {
303	// On the Modification of LDLT Factorizations
304	// By R. Fletcher and M. J. D. Powell
305	// https://www.ams.org/journals/mcom/1974-28-128/S0025-5718-1974-0359297-1/S0025-5718-1974-0359297-1.pdf
306
307	#[math]
308	fn run(self) {
309		let Self { mut ld, mut w, mut alpha, r } = self;
310
311		let n = w.nrows();
312		let k = w.ncols();
313
314		for j in 0..n {
315			let mut L_col = ld.rb_mut().col_mut(j);
316
317			let r = Ord::min(r(), k);
318			let mut W = w.rb_mut().subcols_mut(0, r);
319			let mut alpha = alpha.rb_mut().subrows_mut(0, r);
320			let R = r;
321
322			const BLOCKSIZE: usize = 4;
323
324			let mut r = 0;
325			while r < R {
326				let bs = Ord::min(BLOCKSIZE, R - r);
327
328				stack_mat!(p, bs, 1, BLOCKSIZE, 1, T);
329				stack_mat!(beta, bs, 1, BLOCKSIZE, 1, T);
330
331				let mut p = p.rb_mut().col_mut(0);
332				let mut beta = beta.rb_mut().col_mut(0);
333
334				for k in 0..bs {
335					let p = p.rb_mut().at_mut(k);
336					let beta = beta.rb_mut().at_mut(k);
337					let alpha = alpha.rb_mut().at_mut(r + k);
338					let d = L_col.rb_mut().at_mut(j);
339
340					let w = W.rb().col(r + k);
341
342					*p = copy(w[j]);
343
344					let alpha_conj_p = *alpha * conj(*p);
345					let new_d = real(*d) + real(mul(alpha_conj_p, *p));
346					*beta = mul_real(alpha_conj_p, recip(new_d));
347
348					*alpha = from_real(real(*alpha) - new_d * abs2(*beta));
349					*d = from_real(new_d);
350					*p = -*p;
351				}
352
353				let mut L_col = L_col.rb_mut().get_mut(j + 1..);
354				let mut W_col = W.rb_mut().subcols_mut(r, bs).get_mut(j + 1.., ..);
355
356				if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
357					if let (Some(L_col), Some(W_col)) = (L_col.rb_mut().try_as_col_major_mut(), W_col.rb_mut().try_as_col_major_mut()) {
358						rank_update_step_simd(L_col, W_col, p.rb(), beta.rb(), simd_align(j + 1));
359					} else {
360						rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb());
361					}
362				} else {
363					rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb());
364				}
365				r += bs;
366			}
367		}
368	}
369}
370
371#[track_caller]
372pub fn rank_r_update_clobber<T: ComplexField>(cholesky_factors: MatMut<'_, T>, w: MatMut<'_, T>, alpha: DiagMut<'_, T>) {
373	let n = cholesky_factors.nrows();
374	let r = w.ncols();
375
376	if n == 0 {
377		return;
378	}
379
380	RankRUpdate {
381		ld: cholesky_factors,
382		w,
383		alpha: alpha.column_vector_mut(),
384		r: &mut || r,
385	}
386	.run();
387}
388
#[cfg(test)]
mod tests {
	use dyn_stack::MemBuffer;

	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Col, Mat, assert, c64};

	/// Factors a random Hermitian matrix `A = G * G^H`, applies a random rank-`r` update
	/// with real coefficients to the factors, and checks that `L * D * L^H` rebuilt from the
	/// updated factors matches `A + W * alpha * W^H`.
	#[test]
	fn test_rank_update() {
		let rng = &mut StdRng::seed_from_u64(0);

		// NOTE(review): never referenced by name below; presumably picked up implicitly by
		// the `~` comparison in the final assertion, so the binding name is kept as is.
		let approx_eq = CwiseMat(ApproxEq {
			abs_tol: 1e-12,
			rel_tol: 1e-12,
		});

		let normal = || ComplexDistribution::new(StandardNormal, StandardNormal);

		let ranks = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10];
		let dims = [2, 4, 8, 15];

		for r in ranks {
			for n in dims {
				// Random inputs, drawn in a fixed order: generator of A, then W, then alpha.
				let G = CwiseMatDistribution {
					nrows: n,
					ncols: n,
					dist: normal(),
				}
				.rand::<Mat<c64>>(rng);
				let mut W = CwiseMatDistribution {
					nrows: n,
					ncols: r,
					dist: normal(),
				}
				.rand::<Mat<c64>>(rng);
				let mut alpha = CwiseColDistribution { nrows: r, dist: normal() }
					.rand::<Col<c64>>(rng)
					.into_diagonal();

				// The update coefficients are made real.
				for idx in 0..r {
					alpha.column_vector_mut()[idx].im = 0.0;
				}

				// Reference matrices, computed before `W` and `alpha` get clobbered.
				let A = &G * &G.adjoint();
				let A_new = &A + &W * &alpha * &W.adjoint();

				let A = A.as_ref();
				let A_new = A_new.as_ref();

				let mut factors = A.cloned();
				let mut L = factors.as_mut();

				let mut mem = MemBuffer::new(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<c64>(
					n,
					Par::Seq,
					default(),
				));
				linalg::cholesky::ldlt::factor::cholesky_in_place(L.rb_mut(), default(), Par::Seq, MemStack::new(&mut mem), default()).unwrap();

				rank_r_update_clobber(L.rb_mut(), W.as_mut(), alpha.as_mut());

				// Split the packed storage: copy `D` out of the diagonal, then turn the
				// storage into an explicit unit lower triangular `L`.
				let D = L.as_mut().diagonal().column_vector().as_mat().cloned();
				let D = D.col(0).as_diagonal();

				for col in 0..n {
					for row in 0..col {
						L[(row, col)] = c64::ZERO;
					}
					L[(col, col)] = c64::ONE;
				}
				let L = L.as_ref();

				assert!(A_new ~ L * D * L.adjoint());
			}
		}
	}
}