// faer/linalg/cholesky/llt/update.rs

1use crate::internal_prelude::*;
2use pulp::Simd;
3
4use super::factor::LltError;
5
/// SIMD kernel applying up to four chained rank-one updates to contiguous `L` and `W`.
///
/// For every row `i`, and for each column `k` of `W` in increasing order, this computes
/// `W[i, k] <- p[k] * L[i] + W[i, k]` and then
/// `L[i] <- beta[k] * W[i, k] + L[i] * real(gamma[k])`, using the freshly updated `W[i, k]`;
/// the resulting `L[i]` feeds the update for the next column. Both `L` and `W` are written
/// back in place.
///
/// `align_offset` is forwarded to `SimdCtx::new_align`, whose `indices()` yields an optional
/// head chunk, a sequence of body chunks, and an optional tail chunk covering the rows.
///
/// # Panics
/// Panics (via the `_ => panic!()` arm) unless `W` has between one and four columns.
/// NOTE(review): `L`, `p`, `beta` and `gamma` are reshaped to the dimensions of `W` below;
/// presumably those conversions assert matching lengths — confirm.
#[math]
fn rank_update_step_simd<T: ComplexField>(
	L: ColMut<'_, T, usize, ContiguousFwd>,
	W: MatMut<'_, T, usize, usize, ContiguousFwd>,
	p: ColRef<'_, T>,
	beta: ColRef<'_, T>,
	gamma: ColRef<'_, T>,
	align_offset: usize,
) {
	// Shape-branded copies of the arguments: `'N` tags the row count and `'R` the column
	// count of `W`.
	struct Impl<'a, 'N, 'R, T: ComplexField> {
		L: ColMut<'a, T, Dim<'N>, ContiguousFwd>,
		W: MatMut<'a, T, Dim<'N>, Dim<'R>, ContiguousFwd>,
		p: ColRef<'a, T, Dim<'R>>,
		beta: ColRef<'a, T, Dim<'R>>,
		gamma: ColRef<'a, T, Dim<'R>>,
		align_offset: usize,
	}

	impl<'a, 'N, 'R, T: ComplexField> pulp::WithSimd for Impl<'a, 'N, 'R, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) {
			let Self {
				L,
				W,
				p,
				beta,
				gamma,
				align_offset,
			} = self;

			let mut L = L;
			let mut W = W;
			let N = W.nrows();
			let R = W.ncols();

			let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), N, align_offset);
			let (head, body, tail) = simd.indices();

			// Peel off up to four column indices; the match selects the unrolled variant
			// for the exact number of update vectors in this block.
			let mut iter = R.indices();
			let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());

			match (i0, i1, i2, i3) {
				// one update vector
				(Some(i0), None, None, None) => {
					let p0 = simd.splat(&p[i0]);
					let beta0 = simd.splat(&beta[i0]);
					let gamma0 = simd.splat_real(&real(gamma[i0]));

					// `simd!(i)` processes SIMD chunk `i` of the rows: load, update, store.
					macro_rules! simd {
						($i: expr) => {{
							let i = $i;
							let mut l = simd.read(L.rb(), i);
							let mut w0 = simd.read(W.rb().col(i0), i);

							w0 = simd.mul_add(p0, l, w0);
							l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));

							simd.write(L.rb_mut(), i, l);
							simd.write(W.rb_mut().col_mut(i0), i, w0);
						}};
					}

					if let Some(i) = head {
						simd!(i);
					}
					for i in body {
						simd!(i);
					}
					if let Some(i) = tail {
						simd!(i);
					}
				},
				// two update vectors, applied in order to the same `l`
				(Some(i0), Some(i1), None, None) => {
					let (p0, p1) = (simd.splat(&p[i0]), simd.splat(&p[i1]));
					let (beta0, beta1) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]));
					let (gamma0, gamma1) = (simd.splat_real(&real(gamma[i0])), simd.splat_real(&real(gamma[i1])));

					macro_rules! simd {
						($i: expr) => {{
							let i = $i;
							let mut l = simd.read(L.rb(), i);
							let mut w0 = simd.read(W.rb().col(i0), i);
							let mut w1 = simd.read(W.rb().col(i1), i);

							w0 = simd.mul_add(p0, l, w0);
							l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
							w1 = simd.mul_add(p1, l, w1);
							l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));

							simd.write(L.rb_mut(), i, l);
							simd.write(W.rb_mut().col_mut(i0), i, w0);
							simd.write(W.rb_mut().col_mut(i1), i, w1);
						}};
					}

					if let Some(i) = head {
						simd!(i);
					}
					for i in body {
						simd!(i);
					}
					if let Some(i) = tail {
						simd!(i);
					}
				},
				// three update vectors
				(Some(i0), Some(i1), Some(i2), None) => {
					let (p0, p1, p2) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]));
					let (beta0, beta1, beta2) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]));
					let (gamma0, gamma1, gamma2) = (
						simd.splat_real(&real(gamma[i0])),
						simd.splat_real(&real(gamma[i1])),
						simd.splat_real(&real(gamma[i2])),
					);

					macro_rules! simd {
						($i: expr) => {{
							let i = $i;
							let mut l = simd.read(L.rb(), i);
							let mut w0 = simd.read(W.rb().col(i0), i);
							let mut w1 = simd.read(W.rb().col(i1), i);
							let mut w2 = simd.read(W.rb().col(i2), i);

							w0 = simd.mul_add(p0, l, w0);
							l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
							w1 = simd.mul_add(p1, l, w1);
							l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));
							w2 = simd.mul_add(p2, l, w2);
							l = simd.mul_add(beta2, w2, simd.mul_real(l, gamma2));

							simd.write(L.rb_mut(), i, l);
							simd.write(W.rb_mut().col_mut(i0), i, w0);
							simd.write(W.rb_mut().col_mut(i1), i, w1);
							simd.write(W.rb_mut().col_mut(i2), i, w2);
						}};
					}

					if let Some(i) = head {
						simd!(i);
					}
					for i in body {
						simd!(i);
					}
					if let Some(i) = tail {
						simd!(i);
					}
				},
				// four update vectors (the maximum block size used by `RankRUpdate::run`)
				(Some(i0), Some(i1), Some(i2), Some(i3)) => {
					let (p0, p1, p2, p3) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]), simd.splat(&p[i3]));
					let (beta0, beta1, beta2, beta3) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]), simd.splat(&beta[i3]));
					let (gamma0, gamma1, gamma2, gamma3) = (
						simd.splat_real(&real(gamma[i0])),
						simd.splat_real(&real(gamma[i1])),
						simd.splat_real(&real(gamma[i2])),
						simd.splat_real(&real(gamma[i3])),
					);

					macro_rules! simd {
						($i: expr) => {{
							let i = $i;
							let mut l = simd.read(L.rb(), i);
							let mut w0 = simd.read(W.rb().col(i0), i);
							let mut w1 = simd.read(W.rb().col(i1), i);
							let mut w2 = simd.read(W.rb().col(i2), i);
							let mut w3 = simd.read(W.rb().col(i3), i);

							w0 = simd.mul_add(p0, l, w0);
							l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
							w1 = simd.mul_add(p1, l, w1);
							l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));
							w2 = simd.mul_add(p2, l, w2);
							l = simd.mul_add(beta2, w2, simd.mul_real(l, gamma2));
							w3 = simd.mul_add(p3, l, w3);
							l = simd.mul_add(beta3, w3, simd.mul_real(l, gamma3));

							simd.write(L.rb_mut(), i, l);
							simd.write(W.rb_mut().col_mut(i0), i, w0);
							simd.write(W.rb_mut().col_mut(i1), i, w1);
							simd.write(W.rb_mut().col_mut(i2), i, w2);
							simd.write(W.rb_mut().col_mut(i3), i, w3);
						}};
					}

					if let Some(i) = head {
						simd!(i);
					}
					for i in body {
						simd!(i);
					}
					if let Some(i) = tail {
						simd!(i);
					}
				},
				// zero or more than four columns: `RankRUpdate::run` never passes these
				_ => panic!(),
			}
		}
	}

	// `with_dim!` introduces branded dimensions `N` and `R` equal to the runtime sizes of `W`,
	// so the arguments can be viewed with the shape-tagged types `Impl` expects.
	with_dim!(N, W.nrows());
	with_dim!(R, W.ncols());

	// NOTE(review): `dispatch!` presumably selects the SIMD instruction set for `T` at runtime
	// and invokes `Impl::with_simd` — confirm against the macro definition.
	dispatch!(
		Impl {
			L: L.as_row_shape_mut(N),
			W: W.as_shape_mut(N, R),
			p: p.as_row_shape(R),
			beta: beta.as_row_shape(R),
			gamma: gamma.as_row_shape(R),
			align_offset,
		},
		Impl,
		T
	)
}
220
221#[math]
222fn rank_update_step_fallback<T: ComplexField>(L: ColMut<'_, T>, W: MatMut<'_, T>, p: ColRef<'_, T>, beta: ColRef<'_, T>, gamma: ColRef<'_, T>) {
223	let mut L = L;
224	let mut W = W;
225	let N = W.nrows();
226	let R = W.ncols();
227
228	let mut iter = 0..R;
229	let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
230
231	match (i0, i1, i2, i3) {
232		(Some(i0), None, None, None) => {
233			let p0 = &p[i0];
234			let beta0 = &beta[i0];
235			let gamma0 = &gamma[i0];
236
237			for i in 0..N {
238				let mut l = copy(L[i]);
239				let mut w0 = copy(W[(i, i0)]);
240
241				w0 = *p0 * l + w0;
242				l = *beta0 * w0 + l * gamma0;
243
244				L[i] = l;
245				W[(i, i0)] = w0;
246			}
247		},
248		(Some(i0), Some(i1), None, None) => {
249			let (p0, p1) = (&p[i0], &p[i1]);
250			let (beta0, beta1) = (&beta[i0], &beta[i1]);
251			let (gamma0, gamma1) = (&gamma[i0], &gamma[i1]);
252
253			for i in 0..N {
254				let mut l = copy(L[i]);
255				let mut w0 = copy(W[(i, i0)]);
256				let mut w1 = copy(W[(i, i1)]);
257
258				w0 = *p0 * l + w0;
259				l = *beta0 * w0 + l * gamma0;
260				w1 = *p1 * l + w1;
261				l = *beta1 * w1 + l * gamma1;
262
263				L[i] = l;
264				W[(i, i0)] = w0;
265				W[(i, i1)] = w1;
266			}
267		},
268		(Some(i0), Some(i1), Some(i2), None) => {
269			let (p0, p1, p2) = (&p[i0], &p[i1], &p[i2]);
270			let (beta0, beta1, beta2) = (&beta[i0], &beta[i1], &beta[i2]);
271			let (gamma0, gamma1, gamma2) = (&gamma[i0], &gamma[i1], &gamma[i2]);
272
273			for i in 0..N {
274				let mut l = copy(L[i]);
275				let mut w0 = copy(W[(i, i0)]);
276				let mut w1 = copy(W[(i, i1)]);
277				let mut w2 = copy(W[(i, i2)]);
278
279				w0 = *p0 * l + w0;
280				l = *beta0 * w0 + l * gamma0;
281				w1 = *p1 * l + w1;
282				l = *beta1 * w1 + l * gamma1;
283				w2 = *p2 * l + w2;
284				l = *beta2 * w2 + l * gamma2;
285
286				L[i] = l;
287				W[(i, i0)] = w0;
288				W[(i, i1)] = w1;
289				W[(i, i2)] = w2;
290			}
291		},
292		(Some(i0), Some(i1), Some(i2), Some(i3)) => {
293			let (p0, p1, p2, p3) = (&p[i0], &p[i1], &p[i2], &p[i3]);
294			let (beta0, beta1, beta2, beta3) = (&beta[i0], &beta[i1], &beta[i2], &beta[i3]);
295			let (gamma0, gamma1, gamma2, gamma3) = (&gamma[i0], &gamma[i1], &gamma[i2], &gamma[i3]);
296
297			for i in 0..N {
298				let mut l = copy(L[i]);
299				let mut w0 = copy(W[(i, i0)]);
300				let mut w1 = copy(W[(i, i1)]);
301				let mut w2 = copy(W[(i, i2)]);
302				let mut w3 = copy(W[(i, i3)]);
303
304				w0 = *p0 * l + w0;
305				l = *beta0 * w0 + l * gamma0;
306				w1 = *p1 * l + w1;
307				l = *beta1 * w1 + l * gamma1;
308				w2 = *p2 * l + w2;
309				l = *beta2 * w2 + l * gamma2;
310				w3 = *p3 * l + w3;
311				l = *beta3 * w3 + l * gamma3;
312
313				L[i] = l;
314				W[(i, i0)] = w0;
315				W[(i, i1)] = w1;
316				W[(i, i2)] = w2;
317				W[(i, i3)] = w3;
318			}
319		},
320		_ => panic!(),
321	}
322}
323
324struct RankRUpdate<'a, T: ComplexField> {
325	ld: MatMut<'a, T>,
326	w: MatMut<'a, T>,
327	alpha: ColMut<'a, T>,
328	r: &'a mut dyn FnMut() -> usize,
329}
330
331impl<T: ComplexField> RankRUpdate<'_, T> {
332	// On the Modification of LDLT Factorizations
333	// By R. Fletcher and M. J. D. Powell
334	// https://www.ams.org/journals/mcom/1974-28-128/S0025-5718-1974-0359297-1/S0025-5718-1974-0359297-1.pdf
335
336	#[math]
337	fn run(self) -> Result<(), LltError> {
338		let Self { mut ld, mut w, mut alpha, r } = self;
339
340		let n = w.nrows();
341		let k = w.ncols();
342
343		for j in 0..n {
344			let mut L_col = ld.rb_mut().col_mut(j);
345
346			let r = Ord::min(r(), k);
347			let mut W = w.rb_mut().subcols_mut(0, r);
348			let mut alpha = alpha.rb_mut().subrows_mut(0, r);
349			let R = r;
350
351			const BLOCKSIZE: usize = 4;
352
353			let mut r = 0;
354			while r < R {
355				let bs = Ord::min(BLOCKSIZE, R - r);
356
357				stack_mat!(p, bs, 1, BLOCKSIZE, 1, T);
358				stack_mat!(beta, bs, 1, BLOCKSIZE, 1, T);
359				stack_mat!(gamma, bs, 1, BLOCKSIZE, 1, T);
360
361				let mut p = p.rb_mut().col_mut(0);
362				let mut beta = beta.rb_mut().col_mut(0);
363				let mut gamma = gamma.rb_mut().col_mut(0);
364
365				for k in 0..bs {
366					let p = &mut p[k];
367					let beta = &mut beta[k];
368					let gamma = &mut gamma[k];
369
370					let alpha = &mut alpha[r + k];
371					let d = &mut L_col[j];
372
373					let w = W.rb().col(r + k);
374
375					*p = copy(w[j]);
376
377					let alpha_conj_p = *alpha * conj(*p);
378					let new_d = abs2(real(*d)) + real(mul(alpha_conj_p, *p));
379
380					if new_d <= zero() {
381						return Err(LltError::NonPositivePivot { index: j });
382					}
383
384					let new_d = sqrt(new_d);
385					let d_inv = recip(real(*d));
386					let new_d_inv = recip(new_d);
387
388					*gamma = from_real(new_d * d_inv);
389					*beta = mul_real(alpha_conj_p, new_d_inv);
390					*p = mul_real(-*p, d_inv);
391
392					*alpha = from_real(real(*alpha) - abs2(*beta));
393					*d = from_real(new_d);
394				}
395
396				let mut L_col = L_col.rb_mut().get_mut(j + 1..);
397				let mut W_col = W.rb_mut().subcols_mut(r, bs).get_mut(j + 1.., ..);
398
399				if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
400					if let (Some(L_col), Some(W_col)) = (L_col.rb_mut().try_as_col_major_mut(), W_col.rb_mut().try_as_col_major_mut()) {
401						rank_update_step_simd(L_col, W_col, p.rb(), beta.rb(), gamma.rb(), simd_align(j + 1));
402					} else {
403						rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb(), gamma.rb());
404					}
405				} else {
406					rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb(), gamma.rb());
407				}
408				r += bs;
409			}
410		}
411		Ok(())
412	}
413}
414
415#[track_caller]
416pub fn rank_r_update_clobber<T: ComplexField>(cholesky_factors: MatMut<'_, T>, w: MatMut<'_, T>, alpha: DiagMut<'_, T>) -> Result<(), LltError> {
417	let N = cholesky_factors.nrows();
418	let R = w.ncols();
419
420	if N == 0 {
421		return Ok(());
422	}
423
424	RankRUpdate {
425		ld: cholesky_factors,
426		w,
427		alpha: alpha.column_vector_mut(),
428		r: &mut || R,
429	}
430	.run()
431}
432
#[cfg(test)]
mod tests {
	use dyn_stack::MemBuffer;

	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Col, Mat, assert, c64};

	/// Random positive-weight rank-`r` updates must yield a factor of `A + W diag(alpha) W^H`.
	#[test]
	fn test_rank_update() {
		let rng = &mut StdRng::seed_from_u64(0);

		let approx_eq = CwiseMat(ApproxEq {
			abs_tol: 1e-12,
			rel_tol: 1e-12,
		});

		for r in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10] {
			for n in [2, 4, 8, 15] {
				let A = CwiseMatDistribution {
					nrows: n,
					ncols: n,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Mat<c64>>(rng);
				let mut W = CwiseMatDistribution {
					nrows: n,
					ncols: r,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Mat<c64>>(rng);
				let mut alpha = CwiseColDistribution {
					nrows: r,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Col<c64>>(rng)
				.into_diagonal();

				// make the weights real and non-negative so the update stays positive definite
				for j in 0..r {
					alpha.column_vector_mut()[j].re = abs(&alpha.column_vector_mut()[j]);
					alpha.column_vector_mut()[j].im = 0.0;
				}

				let A = &A * &A.adjoint();
				let A_new = &A + &W * &alpha * &W.adjoint();

				let A = A.as_ref();
				let A_new = A_new.as_ref();

				let mut L = A.cloned();
				let mut L = L.as_mut();

				linalg::cholesky::llt::factor::cholesky_in_place(
					L.rb_mut(),
					default(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(linalg::cholesky::llt::factor::cholesky_in_place_scratch::<c64>(
						n,
						Par::Seq,
						default(),
					))),
					default(),
				)
				.unwrap();

				linalg::cholesky::llt::update::rank_r_update_clobber(L.rb_mut(), W.as_mut(), alpha.as_mut()).unwrap();

				// the strict upper triangle is not part of the factor; clear it before comparing
				for j in 0..n {
					for i in 0..j {
						L[(i, j)] = c64::ZERO;
					}
				}
				let L = L.as_ref();

				assert!(A_new ~ L * L.adjoint());
			}
		}
	}

	/// A downdate that makes the matrix indefinite must be reported instead of factored.
	#[test]
	fn test_rank_update_non_positive_pivot() {
		// A = I (so L = I); subtracting 2 * e0 e0^H turns the (0, 0) entry into -1.
		let mut L = Mat::<c64>::zeros(2, 2);
		L[(0, 0)].re = 1.0;
		L[(1, 1)].re = 1.0;

		let mut W = Mat::<c64>::zeros(2, 1);
		W[(0, 0)].re = 1.0;

		let mut alpha = Col::<c64>::zeros(1).into_diagonal();
		alpha.column_vector_mut()[0].re = -2.0;

		let result = linalg::cholesky::llt::update::rank_r_update_clobber(L.as_mut(), W.as_mut(), alpha.as_mut());
		assert!(matches!(result, Err(LltError::NonPositivePivot { index: 0 })));
	}
}