1use crate::internal_prelude::*;
2use pulp::Simd;
3
4use super::factor::LltError;
5
6#[math]
7fn rank_update_step_simd<T: ComplexField>(
8 L: ColMut<'_, T, usize, ContiguousFwd>,
9 W: MatMut<'_, T, usize, usize, ContiguousFwd>,
10 p: ColRef<'_, T>,
11 beta: ColRef<'_, T>,
12 gamma: ColRef<'_, T>,
13 align_offset: usize,
14) {
15 struct Impl<'a, 'N, 'R, T: ComplexField> {
16 L: ColMut<'a, T, Dim<'N>, ContiguousFwd>,
17 W: MatMut<'a, T, Dim<'N>, Dim<'R>, ContiguousFwd>,
18 p: ColRef<'a, T, Dim<'R>>,
19 beta: ColRef<'a, T, Dim<'R>>,
20 gamma: ColRef<'a, T, Dim<'R>>,
21 align_offset: usize,
22 }
23
24 impl<'a, 'N, 'R, T: ComplexField> pulp::WithSimd for Impl<'a, 'N, 'R, T> {
25 type Output = ();
26
27 #[inline(always)]
28 fn with_simd<S: Simd>(self, simd: S) {
29 let Self {
30 L,
31 W,
32 p,
33 beta,
34 gamma,
35 align_offset,
36 } = self;
37
38 let mut L = L;
39 let mut W = W;
40 let N = W.nrows();
41 let R = W.ncols();
42
43 let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), N, align_offset);
44 let (head, body, tail) = simd.indices();
45
46 let mut iter = R.indices();
47 let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
48
49 match (i0, i1, i2, i3) {
50 (Some(i0), None, None, None) => {
51 let p0 = simd.splat(&p[i0]);
52 let beta0 = simd.splat(&beta[i0]);
53 let gamma0 = simd.splat_real(&real(gamma[i0]));
54
55 macro_rules! simd {
56 ($i: expr) => {{
57 let i = $i;
58 let mut l = simd.read(L.rb(), i);
59 let mut w0 = simd.read(W.rb().col(i0), i);
60
61 w0 = simd.mul_add(p0, l, w0);
62 l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
63
64 simd.write(L.rb_mut(), i, l);
65 simd.write(W.rb_mut().col_mut(i0), i, w0);
66 }};
67 }
68
69 if let Some(i) = head {
70 simd!(i);
71 }
72 for i in body {
73 simd!(i);
74 }
75 if let Some(i) = tail {
76 simd!(i);
77 }
78 },
79 (Some(i0), Some(i1), None, None) => {
80 let (p0, p1) = (simd.splat(&p[i0]), simd.splat(&p[i1]));
81 let (beta0, beta1) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]));
82 let (gamma0, gamma1) = (simd.splat_real(&real(gamma[i0])), simd.splat_real(&real(gamma[i1])));
83
84 macro_rules! simd {
85 ($i: expr) => {{
86 let i = $i;
87 let mut l = simd.read(L.rb(), i);
88 let mut w0 = simd.read(W.rb().col(i0), i);
89 let mut w1 = simd.read(W.rb().col(i1), i);
90
91 w0 = simd.mul_add(p0, l, w0);
92 l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
93 w1 = simd.mul_add(p1, l, w1);
94 l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));
95
96 simd.write(L.rb_mut(), i, l);
97 simd.write(W.rb_mut().col_mut(i0), i, w0);
98 simd.write(W.rb_mut().col_mut(i1), i, w1);
99 }};
100 }
101
102 if let Some(i) = head {
103 simd!(i);
104 }
105 for i in body {
106 simd!(i);
107 }
108 if let Some(i) = tail {
109 simd!(i);
110 }
111 },
112 (Some(i0), Some(i1), Some(i2), None) => {
113 let (p0, p1, p2) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]));
114 let (beta0, beta1, beta2) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]));
115 let (gamma0, gamma1, gamma2) = (
116 simd.splat_real(&real(gamma[i0])),
117 simd.splat_real(&real(gamma[i1])),
118 simd.splat_real(&real(gamma[i2])),
119 );
120
121 macro_rules! simd {
122 ($i: expr) => {{
123 let i = $i;
124 let mut l = simd.read(L.rb(), i);
125 let mut w0 = simd.read(W.rb().col(i0), i);
126 let mut w1 = simd.read(W.rb().col(i1), i);
127 let mut w2 = simd.read(W.rb().col(i2), i);
128
129 w0 = simd.mul_add(p0, l, w0);
130 l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
131 w1 = simd.mul_add(p1, l, w1);
132 l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));
133 w2 = simd.mul_add(p2, l, w2);
134 l = simd.mul_add(beta2, w2, simd.mul_real(l, gamma2));
135
136 simd.write(L.rb_mut(), i, l);
137 simd.write(W.rb_mut().col_mut(i0), i, w0);
138 simd.write(W.rb_mut().col_mut(i1), i, w1);
139 simd.write(W.rb_mut().col_mut(i2), i, w2);
140 }};
141 }
142
143 if let Some(i) = head {
144 simd!(i);
145 }
146 for i in body {
147 simd!(i);
148 }
149 if let Some(i) = tail {
150 simd!(i);
151 }
152 },
153 (Some(i0), Some(i1), Some(i2), Some(i3)) => {
154 let (p0, p1, p2, p3) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]), simd.splat(&p[i3]));
155 let (beta0, beta1, beta2, beta3) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]), simd.splat(&beta[i3]));
156 let (gamma0, gamma1, gamma2, gamma3) = (
157 simd.splat_real(&real(gamma[i0])),
158 simd.splat_real(&real(gamma[i1])),
159 simd.splat_real(&real(gamma[i2])),
160 simd.splat_real(&real(gamma[i3])),
161 );
162
163 macro_rules! simd {
164 ($i: expr) => {{
165 let i = $i;
166 let mut l = simd.read(L.rb(), i);
167 let mut w0 = simd.read(W.rb().col(i0), i);
168 let mut w1 = simd.read(W.rb().col(i1), i);
169 let mut w2 = simd.read(W.rb().col(i2), i);
170 let mut w3 = simd.read(W.rb().col(i3), i);
171
172 w0 = simd.mul_add(p0, l, w0);
173 l = simd.mul_add(beta0, w0, simd.mul_real(l, gamma0));
174 w1 = simd.mul_add(p1, l, w1);
175 l = simd.mul_add(beta1, w1, simd.mul_real(l, gamma1));
176 w2 = simd.mul_add(p2, l, w2);
177 l = simd.mul_add(beta2, w2, simd.mul_real(l, gamma2));
178 w3 = simd.mul_add(p3, l, w3);
179 l = simd.mul_add(beta3, w3, simd.mul_real(l, gamma3));
180
181 simd.write(L.rb_mut(), i, l);
182 simd.write(W.rb_mut().col_mut(i0), i, w0);
183 simd.write(W.rb_mut().col_mut(i1), i, w1);
184 simd.write(W.rb_mut().col_mut(i2), i, w2);
185 simd.write(W.rb_mut().col_mut(i3), i, w3);
186 }};
187 }
188
189 if let Some(i) = head {
190 simd!(i);
191 }
192 for i in body {
193 simd!(i);
194 }
195 if let Some(i) = tail {
196 simd!(i);
197 }
198 },
199 _ => panic!(),
200 }
201 }
202 }
203
204 with_dim!(N, W.nrows());
205 with_dim!(R, W.ncols());
206
207 dispatch!(
208 Impl {
209 L: L.as_row_shape_mut(N),
210 W: W.as_shape_mut(N, R),
211 p: p.as_row_shape(R),
212 beta: beta.as_row_shape(R),
213 gamma: gamma.as_row_shape(R),
214 align_offset,
215 },
216 Impl,
217 T
218 )
219}
220
221#[math]
222fn rank_update_step_fallback<T: ComplexField>(L: ColMut<'_, T>, W: MatMut<'_, T>, p: ColRef<'_, T>, beta: ColRef<'_, T>, gamma: ColRef<'_, T>) {
223 let mut L = L;
224 let mut W = W;
225 let N = W.nrows();
226 let R = W.ncols();
227
228 let mut iter = 0..R;
229 let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
230
231 match (i0, i1, i2, i3) {
232 (Some(i0), None, None, None) => {
233 let p0 = &p[i0];
234 let beta0 = &beta[i0];
235 let gamma0 = &gamma[i0];
236
237 for i in 0..N {
238 let mut l = copy(L[i]);
239 let mut w0 = copy(W[(i, i0)]);
240
241 w0 = *p0 * l + w0;
242 l = *beta0 * w0 + l * gamma0;
243
244 L[i] = l;
245 W[(i, i0)] = w0;
246 }
247 },
248 (Some(i0), Some(i1), None, None) => {
249 let (p0, p1) = (&p[i0], &p[i1]);
250 let (beta0, beta1) = (&beta[i0], &beta[i1]);
251 let (gamma0, gamma1) = (&gamma[i0], &gamma[i1]);
252
253 for i in 0..N {
254 let mut l = copy(L[i]);
255 let mut w0 = copy(W[(i, i0)]);
256 let mut w1 = copy(W[(i, i1)]);
257
258 w0 = *p0 * l + w0;
259 l = *beta0 * w0 + l * gamma0;
260 w1 = *p1 * l + w1;
261 l = *beta1 * w1 + l * gamma1;
262
263 L[i] = l;
264 W[(i, i0)] = w0;
265 W[(i, i1)] = w1;
266 }
267 },
268 (Some(i0), Some(i1), Some(i2), None) => {
269 let (p0, p1, p2) = (&p[i0], &p[i1], &p[i2]);
270 let (beta0, beta1, beta2) = (&beta[i0], &beta[i1], &beta[i2]);
271 let (gamma0, gamma1, gamma2) = (&gamma[i0], &gamma[i1], &gamma[i2]);
272
273 for i in 0..N {
274 let mut l = copy(L[i]);
275 let mut w0 = copy(W[(i, i0)]);
276 let mut w1 = copy(W[(i, i1)]);
277 let mut w2 = copy(W[(i, i2)]);
278
279 w0 = *p0 * l + w0;
280 l = *beta0 * w0 + l * gamma0;
281 w1 = *p1 * l + w1;
282 l = *beta1 * w1 + l * gamma1;
283 w2 = *p2 * l + w2;
284 l = *beta2 * w2 + l * gamma2;
285
286 L[i] = l;
287 W[(i, i0)] = w0;
288 W[(i, i1)] = w1;
289 W[(i, i2)] = w2;
290 }
291 },
292 (Some(i0), Some(i1), Some(i2), Some(i3)) => {
293 let (p0, p1, p2, p3) = (&p[i0], &p[i1], &p[i2], &p[i3]);
294 let (beta0, beta1, beta2, beta3) = (&beta[i0], &beta[i1], &beta[i2], &beta[i3]);
295 let (gamma0, gamma1, gamma2, gamma3) = (&gamma[i0], &gamma[i1], &gamma[i2], &gamma[i3]);
296
297 for i in 0..N {
298 let mut l = copy(L[i]);
299 let mut w0 = copy(W[(i, i0)]);
300 let mut w1 = copy(W[(i, i1)]);
301 let mut w2 = copy(W[(i, i2)]);
302 let mut w3 = copy(W[(i, i3)]);
303
304 w0 = *p0 * l + w0;
305 l = *beta0 * w0 + l * gamma0;
306 w1 = *p1 * l + w1;
307 l = *beta1 * w1 + l * gamma1;
308 w2 = *p2 * l + w2;
309 l = *beta2 * w2 + l * gamma2;
310 w3 = *p3 * l + w3;
311 l = *beta3 * w3 + l * gamma3;
312
313 L[i] = l;
314 W[(i, i0)] = w0;
315 W[(i, i1)] = w1;
316 W[(i, i2)] = w2;
317 W[(i, i3)] = w3;
318 }
319 },
320 _ => panic!(),
321 }
322}
323
324struct RankRUpdate<'a, T: ComplexField> {
325 ld: MatMut<'a, T>,
326 w: MatMut<'a, T>,
327 alpha: ColMut<'a, T>,
328 r: &'a mut dyn FnMut() -> usize,
329}
330
331impl<T: ComplexField> RankRUpdate<'_, T> {
332 #[math]
337 fn run(self) -> Result<(), LltError> {
338 let Self { mut ld, mut w, mut alpha, r } = self;
339
340 let n = w.nrows();
341 let k = w.ncols();
342
343 for j in 0..n {
344 let mut L_col = ld.rb_mut().col_mut(j);
345
346 let r = Ord::min(r(), k);
347 let mut W = w.rb_mut().subcols_mut(0, r);
348 let mut alpha = alpha.rb_mut().subrows_mut(0, r);
349 let R = r;
350
351 const BLOCKSIZE: usize = 4;
352
353 let mut r = 0;
354 while r < R {
355 let bs = Ord::min(BLOCKSIZE, R - r);
356
357 stack_mat!(p, bs, 1, BLOCKSIZE, 1, T);
358 stack_mat!(beta, bs, 1, BLOCKSIZE, 1, T);
359 stack_mat!(gamma, bs, 1, BLOCKSIZE, 1, T);
360
361 let mut p = p.rb_mut().col_mut(0);
362 let mut beta = beta.rb_mut().col_mut(0);
363 let mut gamma = gamma.rb_mut().col_mut(0);
364
365 for k in 0..bs {
366 let p = &mut p[k];
367 let beta = &mut beta[k];
368 let gamma = &mut gamma[k];
369
370 let alpha = &mut alpha[r + k];
371 let d = &mut L_col[j];
372
373 let w = W.rb().col(r + k);
374
375 *p = copy(w[j]);
376
377 let alpha_conj_p = *alpha * conj(*p);
378 let new_d = abs2(real(*d)) + real(mul(alpha_conj_p, *p));
379
380 if new_d <= zero() {
381 return Err(LltError::NonPositivePivot { index: j });
382 }
383
384 let new_d = sqrt(new_d);
385 let d_inv = recip(real(*d));
386 let new_d_inv = recip(new_d);
387
388 *gamma = from_real(new_d * d_inv);
389 *beta = mul_real(alpha_conj_p, new_d_inv);
390 *p = mul_real(-*p, d_inv);
391
392 *alpha = from_real(real(*alpha) - abs2(*beta));
393 *d = from_real(new_d);
394 }
395
396 let mut L_col = L_col.rb_mut().get_mut(j + 1..);
397 let mut W_col = W.rb_mut().subcols_mut(r, bs).get_mut(j + 1.., ..);
398
399 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
400 if let (Some(L_col), Some(W_col)) = (L_col.rb_mut().try_as_col_major_mut(), W_col.rb_mut().try_as_col_major_mut()) {
401 rank_update_step_simd(L_col, W_col, p.rb(), beta.rb(), gamma.rb(), simd_align(j + 1));
402 } else {
403 rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb(), gamma.rb());
404 }
405 } else {
406 rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb(), gamma.rb());
407 }
408 r += bs;
409 }
410 }
411 Ok(())
412 }
413}
414
415#[track_caller]
416pub fn rank_r_update_clobber<T: ComplexField>(cholesky_factors: MatMut<'_, T>, w: MatMut<'_, T>, alpha: DiagMut<'_, T>) -> Result<(), LltError> {
417 let N = cholesky_factors.nrows();
418 let R = w.ncols();
419
420 if N == 0 {
421 return Ok(());
422 }
423
424 RankRUpdate {
425 ld: cholesky_factors,
426 w,
427 alpha: alpha.column_vector_mut(),
428 r: &mut || R,
429 }
430 .run()
431}
432
#[cfg(test)]
mod tests {
    use dyn_stack::MemBuffer;

    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Col, Mat, assert, c64};

    /// Factors `A = B * B^H`, applies the update with random `W` and a real non-negative
    /// diagonal `alpha`, and checks the result against `A + W * alpha * W^H`.
    #[test]
    fn test_rank_update() {
        let rng = &mut StdRng::seed_from_u64(0);

        // NOTE(review): presumably picked up by name by the `~` comparison in the final
        // `assert!`; keep this binding's name unchanged.
        let approx_eq = CwiseMat(ApproxEq {
            abs_tol: 1e-12,
            rel_tol: 1e-12,
        });

        let ranks = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10];
        let dims = [2, 4, 8, 15];

        for r in ranks {
            for n in dims {
                let seed_mat = CwiseMatDistribution {
                    nrows: n,
                    ncols: n,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Mat<c64>>(rng);
                let mut W = CwiseMatDistribution {
                    nrows: n,
                    ncols: r,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Mat<c64>>(rng);
                let mut alpha = CwiseColDistribution {
                    nrows: r,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Col<c64>>(rng)
                .into_diagonal();

                // Replace every coefficient by its magnitude so that it is real and non-negative.
                {
                    let mut diag = alpha.column_vector_mut();
                    for j in 0..r {
                        let magnitude = abs(&diag[j]);
                        diag[j].re = magnitude;
                        diag[j].im = 0.0;
                    }
                }

                let A = &seed_mat * &seed_mat.adjoint();
                let A_new = &A + &W * &alpha * &W.adjoint();
                let A_new = A_new.as_ref();

                let mut storage = A.as_ref().cloned();
                let mut L = storage.as_mut();

                let scratch = linalg::cholesky::llt::factor::cholesky_in_place_scratch::<c64>(n, Par::Seq, default());
                let mut buffer = MemBuffer::new(scratch);
                linalg::cholesky::llt::factor::cholesky_in_place(L.rb_mut(), default(), Par::Seq, MemStack::new(&mut buffer), default()).unwrap();

                linalg::cholesky::llt::update::rank_r_update_clobber(L.rb_mut(), W.as_mut(), alpha.as_mut()).unwrap();

                // Zero the strictly upper triangle before forming `L * L^H`.
                for i in 0..n {
                    for j in (i + 1)..n {
                        L[(i, j)] = c64::ZERO;
                    }
                }
                let L = L.as_ref();

                assert!(A_new ~ L * L.adjoint());
            }
        }
    }
}