1use crate::internal_prelude::*;
2use pulp::Simd;
3
4#[math]
5fn rank_update_step_simd<T: ComplexField>(
6 L: ColMut<'_, T, usize, ContiguousFwd>,
7 W: MatMut<'_, T, usize, usize, ContiguousFwd>,
8 p: ColRef<'_, T>,
9 beta: ColRef<'_, T>,
10 align_offset: usize,
11) {
12 struct Impl<'a, 'N, 'R, T: ComplexField> {
13 L: ColMut<'a, T, Dim<'N>, ContiguousFwd>,
14 W: MatMut<'a, T, Dim<'N>, Dim<'R>, ContiguousFwd>,
15 p: ColRef<'a, T, Dim<'R>>,
16 beta: ColRef<'a, T, Dim<'R>>,
17 align_offset: usize,
18 }
19
20 impl<'a, 'N, 'R, T: ComplexField> pulp::WithSimd for Impl<'a, 'N, 'R, T> {
21 type Output = ();
22
23 #[inline(always)]
24 fn with_simd<S: Simd>(self, simd: S) {
25 let Self { L, W, p, beta, align_offset } = self;
26
27 let mut L = L;
28 let mut W = W;
29 let N = W.nrows();
30 let R = W.ncols();
31
32 let simd = SimdCtx::<T, S>::new_align(T::simd_ctx(simd), N, align_offset);
33 let (head, body, tail) = simd.indices();
34
35 let mut iter = R.indices();
36 let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
37
38 match (i0, i1, i2, i3) {
39 (Some(i0), None, None, None) => {
40 let p0 = simd.splat(&p[i0]);
41 let beta0 = simd.splat(&beta[i0]);
42
43 macro_rules! simd {
44 ($i: expr) => {{
45 let i = $i;
46 let mut l = simd.read(L.rb(), i);
47 let mut w0 = simd.read(W.rb().col(i0), i);
48
49 w0 = simd.mul_add(p0, l, w0);
50 l = simd.mul_add(beta0, w0, l);
51
52 simd.write(L.rb_mut(), i, l);
53 simd.write(W.rb_mut().col_mut(i0), i, w0);
54 }};
55 }
56
57 if let Some(i) = head {
58 simd!(i);
59 }
60 for i in body {
61 simd!(i);
62 }
63 if let Some(i) = tail {
64 simd!(i);
65 }
66 },
67 (Some(i0), Some(i1), None, None) => {
68 let (p0, p1) = (simd.splat(&p[i0]), simd.splat(&p[i1]));
69 let (beta0, beta1) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]));
70
71 macro_rules! simd {
72 ($i: expr) => {{
73 let i = $i;
74 let mut l = simd.read(L.rb(), i);
75 let mut w0 = simd.read(W.rb().col(i0), i);
76 let mut w1 = simd.read(W.rb().col(i1), i);
77
78 w0 = simd.mul_add(p0, l, w0);
79 l = simd.mul_add(beta0, w0, l);
80 w1 = simd.mul_add(p1, l, w1);
81 l = simd.mul_add(beta1, w1, l);
82
83 simd.write(L.rb_mut(), i, l);
84 simd.write(W.rb_mut().col_mut(i0), i, w0);
85 simd.write(W.rb_mut().col_mut(i1), i, w1);
86 }};
87 }
88
89 if let Some(i) = head {
90 simd!(i);
91 }
92 for i in body {
93 simd!(i);
94 }
95 if let Some(i) = tail {
96 simd!(i);
97 }
98 },
99 (Some(i0), Some(i1), Some(i2), None) => {
100 let (p0, p1, p2) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]));
101 let (beta0, beta1, beta2) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]));
102
103 macro_rules! simd {
104 ($i: expr) => {{
105 let i = $i;
106 let mut l = simd.read(L.rb(), i);
107 let mut w0 = simd.read(W.rb().col(i0), i);
108 let mut w1 = simd.read(W.rb().col(i1), i);
109 let mut w2 = simd.read(W.rb().col(i2), i);
110
111 w0 = simd.mul_add(p0, l, w0);
112 l = simd.mul_add(beta0, w0, l);
113 w1 = simd.mul_add(p1, l, w1);
114 l = simd.mul_add(beta1, w1, l);
115 w2 = simd.mul_add(p2, l, w2);
116 l = simd.mul_add(beta2, w2, l);
117
118 simd.write(L.rb_mut(), i, l);
119 simd.write(W.rb_mut().col_mut(i0), i, w0);
120 simd.write(W.rb_mut().col_mut(i1), i, w1);
121 simd.write(W.rb_mut().col_mut(i2), i, w2);
122 }};
123 }
124
125 if let Some(i) = head {
126 simd!(i);
127 }
128 for i in body {
129 simd!(i);
130 }
131 if let Some(i) = tail {
132 simd!(i);
133 }
134 },
135 (Some(i0), Some(i1), Some(i2), Some(i3)) => {
136 let (p0, p1, p2, p3) = (simd.splat(&p[i0]), simd.splat(&p[i1]), simd.splat(&p[i2]), simd.splat(&p[i3]));
137 let (beta0, beta1, beta2, beta3) = (simd.splat(&beta[i0]), simd.splat(&beta[i1]), simd.splat(&beta[i2]), simd.splat(&beta[i3]));
138
139 macro_rules! simd {
140 ($i: expr) => {{
141 let i = $i;
142 let mut l = simd.read(L.rb(), i);
143 let mut w0 = simd.read(W.rb().col(i0), i);
144 let mut w1 = simd.read(W.rb().col(i1), i);
145 let mut w2 = simd.read(W.rb().col(i2), i);
146 let mut w3 = simd.read(W.rb().col(i3), i);
147
148 w0 = simd.mul_add(p0, l, w0);
149 l = simd.mul_add(beta0, w0, l);
150 w1 = simd.mul_add(p1, l, w1);
151 l = simd.mul_add(beta1, w1, l);
152 w2 = simd.mul_add(p2, l, w2);
153 l = simd.mul_add(beta2, w2, l);
154 w3 = simd.mul_add(p3, l, w3);
155 l = simd.mul_add(beta3, w3, l);
156
157 simd.write(L.rb_mut(), i, l);
158 simd.write(W.rb_mut().col_mut(i0), i, w0);
159 simd.write(W.rb_mut().col_mut(i1), i, w1);
160 simd.write(W.rb_mut().col_mut(i2), i, w2);
161 simd.write(W.rb_mut().col_mut(i3), i, w3);
162 }};
163 }
164
165 if let Some(i) = head {
166 simd!(i);
167 }
168 for i in body {
169 simd!(i);
170 }
171 if let Some(i) = tail {
172 simd!(i);
173 }
174 },
175 _ => panic!(),
176 }
177 }
178 }
179
180 with_dim!(N, W.nrows());
181 with_dim!(R, W.ncols());
182
183 dispatch!(
184 Impl {
185 L: L.as_row_shape_mut(N),
186 W: W.as_shape_mut(N, R),
187 p: p.as_row_shape(R),
188 beta: beta.as_row_shape(R),
189 align_offset,
190 },
191 Impl,
192 T
193 )
194}
195
196#[math]
197fn rank_update_step_fallback<T: ComplexField>(L: ColMut<'_, T>, W: MatMut<'_, T>, p: ColRef<'_, T>, beta: ColRef<'_, T>) {
198 let mut L = L;
199 let mut W = W;
200 let n = W.nrows();
201 let r = W.ncols();
202
203 let mut iter = 0..r;
204 let (i0, i1, i2, i3) = (iter.next(), iter.next(), iter.next(), iter.next());
205
206 match (i0, i1, i2, i3) {
207 (Some(i0), None, None, None) => {
208 let p0 = &p[i0];
209 let beta0 = &beta[i0];
210
211 for i in 0..n {
212 let mut l = copy(L[i]);
213 let mut w0 = copy(W[(i, i0)]);
214
215 w0 = *p0 * l + w0;
216 l = *beta0 * w0 + l;
217
218 L[i] = l;
219 W[(i, i0)] = w0;
220 }
221 },
222 (Some(i0), Some(i1), None, None) => {
223 let (p0, p1) = (&p[i0], &p[i1]);
224 let (beta0, beta1) = (&beta[i0], &beta[i1]);
225
226 for i in 0..n {
227 let mut l = copy(L[i]);
228 let mut w0 = copy(W[(i, i0)]);
229 let mut w1 = copy(W[(i, i1)]);
230
231 w0 = *p0 * l + w0;
232 l = *beta0 * w0 + l;
233 w1 = *p1 * l + w1;
234 l = *beta1 * w1 + l;
235
236 L[i] = l;
237 W[(i, i0)] = w0;
238 W[(i, i1)] = w1;
239 }
240 },
241 (Some(i0), Some(i1), Some(i2), None) => {
242 let (p0, p1, p2) = (&p[i0], &p[i1], &p[i2]);
243 let (beta0, beta1, beta2) = (&beta[i0], &beta[i1], &beta[i2]);
244
245 for i in 0..n {
246 let mut l = copy(L[i]);
247 let mut w0 = copy(W[(i, i0)]);
248 let mut w1 = copy(W[(i, i1)]);
249 let mut w2 = copy(W[(i, i2)]);
250
251 w0 = *p0 * l + w0;
252 l = *beta0 * w0 + l;
253 w1 = *p1 * l + w1;
254 l = *beta1 * w1 + l;
255 w2 = *p2 * l + w2;
256 l = *beta2 * w2 + l;
257
258 L[i] = l;
259 W[(i, i0)] = w0;
260 W[(i, i1)] = w1;
261 W[(i, i2)] = w2;
262 }
263 },
264 (Some(i0), Some(i1), Some(i2), Some(i3)) => {
265 let (p0, p1, p2, p3) = (&p[i0], &p[i1], &p[i2], &p[i3]);
266 let (beta0, beta1, beta2, beta3) = (&beta[i0], &beta[i1], &beta[i2], &beta[i3]);
267
268 for i in 0..n {
269 let mut l = copy(L[i]);
270 let mut w0 = copy(W[(i, i0)]);
271 let mut w1 = copy(W[(i, i1)]);
272 let mut w2 = copy(W[(i, i2)]);
273 let mut w3 = copy(W[(i, i3)]);
274
275 w0 = *p0 * l + w0;
276 l = *beta0 * w0 + l;
277 w1 = *p1 * l + w1;
278 l = *beta1 * w1 + l;
279 w2 = *p2 * l + w2;
280 l = *beta2 * w2 + l;
281 w3 = *p3 * l + w3;
282 l = *beta3 * w3 + l;
283
284 L[i] = l;
285 W[(i, i0)] = w0;
286 W[(i, i1)] = w1;
287 W[(i, i2)] = w2;
288 W[(i, i3)] = w3;
289 }
290 },
291 _ => panic!(),
292 }
293}
294
295struct RankRUpdate<'a, T: ComplexField> {
296 ld: MatMut<'a, T>,
297 w: MatMut<'a, T>,
298 alpha: ColMut<'a, T>,
299 r: &'a mut dyn FnMut() -> usize,
300}
301
impl<T: ComplexField> RankRUpdate<'_, T> {
    /// Applies the rank update column by column, clobbering `w` and `alpha`.
    ///
    /// For each column `j` of `ld` (one per row of `w`):
    ///
    /// 1. `r()` is queried and clamped to `w.ncols()`. Only the first `r` columns of
    ///    `w` and the first `r` entries of `alpha` take part for this column.
    /// 2. Those update columns are processed in blocks of at most `BLOCKSIZE = 4`,
    ///    the largest column count the step kernels are unrolled for.
    /// 3. Within a block, update `k` sequentially modifies the diagonal entry
    ///    `d = ld[(j, j)]` and its own weight `alpha_k`:
    ///
    ///    ```text
    ///    p        = w[(j, k)]
    ///    d_new    = Re(d) + Re(alpha_k · conj(p) · p)
    ///    beta_k   = alpha_k · conj(p) / d_new
    ///    alpha_k ← Re(alpha_k) − d_new · |beta_k|²
    ///    d       ← d_new
    ///    ```
    ///
    ///    It then stores `-p` and `beta_k` as the kernel coefficients.
    /// 4. A step kernel updates rows `j + 1..` of column `j` of `ld` together with
    ///    the same rows of the block's columns of `w`.
    ///
    /// NOTE(review): these formulas appear to follow Fletcher & Powell, "On the
    /// Modification of LDLᵀ Factorizations" (Math. Comp., 1974) — confirm.
    /// If `d_new` is zero, `recip(new_d)` divides by zero; nothing here guards
    /// against that.
    #[math]
    fn run(self) {
        let Self { mut ld, mut w, mut alpha, r } = self;

        let n = w.nrows();
        let k = w.ncols();

        for j in 0..n {
            // Column `j` of the combined factor: `D[j]` at row `j`, `L` below it.
            let mut L_col = ld.rb_mut().col_mut(j);

            // Number of update columns active for this step, capped at `w.ncols()`.
            let r = Ord::min(r(), k);
            let mut W = w.rb_mut().subcols_mut(0, r);
            let mut alpha = alpha.rb_mut().subrows_mut(0, r);
            let R = r;

            // Largest column count the step kernels are unrolled for.
            const BLOCKSIZE: usize = 4;

            let mut r = 0;
            while r < R {
                let bs = Ord::min(BLOCKSIZE, R - r);

                // Scratch columns for this block's kernel coefficients
                // (presumably stack-backed with room for BLOCKSIZE entries).
                stack_mat!(p, bs, 1, BLOCKSIZE, 1, T);
                stack_mat!(beta, bs, 1, BLOCKSIZE, 1, T);

                let mut p = p.rb_mut().col_mut(0);
                let mut beta = beta.rb_mut().col_mut(0);

                for k in 0..bs {
                    let p = p.rb_mut().at_mut(k);
                    let beta = beta.rb_mut().at_mut(k);
                    let alpha = alpha.rb_mut().at_mut(r + k);
                    // Diagonal entry D[j], updated once per update column.
                    let d = L_col.rb_mut().at_mut(j);

                    let w = W.rb().col(r + k);

                    *p = copy(w[j]);

                    let alpha_conj_p = *alpha * conj(*p);
                    let new_d = real(*d) + real(mul(alpha_conj_p, *p));
                    *beta = mul_real(alpha_conj_p, recip(new_d));

                    *alpha = from_real(real(*alpha) - new_d * abs2(*beta));
                    *d = from_real(new_d);
                    // Negated so the kernel's `w ← p·l + w` subtracts `p·l`.
                    *p = -*p;
                }

                // Rows strictly below the diagonal of column `j`, and the same rows
                // of the block's columns of `W`.
                let mut L_col = L_col.rb_mut().get_mut(j + 1..);
                let mut W_col = W.rb_mut().subcols_mut(r, bs).get_mut(j + 1.., ..);

                // The SIMD kernel needs a SIMD-capable scalar type and column-major
                // views; every other case goes through the scalar fallback.
                if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
                    if let (Some(L_col), Some(W_col)) = (L_col.rb_mut().try_as_col_major_mut(), W_col.rb_mut().try_as_col_major_mut()) {
                        rank_update_step_simd(L_col, W_col, p.rb(), beta.rb(), simd_align(j + 1));
                    } else {
                        rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb());
                    }
                } else {
                    rank_update_step_fallback(L_col, W_col, p.rb(), beta.rb());
                }
                r += bs;
            }
        }
    }
}
370
371#[track_caller]
372pub fn rank_r_update_clobber<T: ComplexField>(cholesky_factors: MatMut<'_, T>, w: MatMut<'_, T>, alpha: DiagMut<'_, T>) {
373 let n = cholesky_factors.nrows();
374 let r = w.ncols();
375
376 if n == 0 {
377 return;
378 }
379
380 RankRUpdate {
381 ld: cholesky_factors,
382 w,
383 alpha: alpha.column_vector_mut(),
384 r: &mut || r,
385 }
386 .run();
387}
388
#[cfg(test)]
mod tests {
    use dyn_stack::MemBuffer;

    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Col, Mat, assert, c64};

    /// Update ranks exercised: zero, every partial block size, and several
    /// values spanning more than one block of four.
    const RANKS: [usize; 10] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10];
    /// Matrix dimensions exercised.
    const DIMS: [usize; 4] = [2, 4, 8, 15];

    /// Builds a random Hermitian `A`, factors it as `L D Lᴴ`, applies
    /// `rank_r_update_clobber` with a random `W` and real `alpha`, and checks
    /// the updated factors against `A + W alpha Wᴴ` computed densely.
    ///
    /// With `row_major_w`, `W` is handed to the update through a row-major view.
    /// Its column blocks are then generally not column-major, which exercises
    /// the scalar fallback kernel instead of the SIMD one.
    fn check_rank_update(rng: &mut StdRng, n: usize, r: usize, row_major_w: bool) {
        let approx_eq = CwiseMat(ApproxEq {
            abs_tol: 1e-12,
            rel_tol: 1e-12,
        });

        let A = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);
        let mut W = CwiseMatDistribution {
            nrows: n,
            ncols: r,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);
        let mut alpha = CwiseColDistribution {
            nrows: r,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Col<c64>>(rng)
        .into_diagonal();

        // The update treats the weights as real.
        for j in 0..r {
            alpha.column_vector_mut()[j].im = 0.0;
        }

        let A = &A * &A.adjoint();
        let A_new = &A + &W * &alpha * &W.adjoint();

        let A = A.as_ref();
        let A_new = A_new.as_ref();

        let mut L = A.cloned();
        let mut L = L.as_mut();

        linalg::cholesky::ldlt::factor::cholesky_in_place(
            L.rb_mut(),
            default(),
            Par::Seq,
            MemStack::new(&mut MemBuffer::new(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<c64>(
                n,
                Par::Seq,
                default(),
            ))),
            default(),
        )
        .unwrap();

        if row_major_w {
            // Store Wᵀ column-major and pass its transpose: an n×r view of the same
            // values whose rows, not columns, are contiguous.
            let mut W_t = W.as_ref().transpose().cloned();
            linalg::cholesky::ldlt::update::rank_r_update_clobber(L.rb_mut(), W_t.as_mut().transpose_mut(), alpha.as_mut());
        } else {
            linalg::cholesky::ldlt::update::rank_r_update_clobber(L.rb_mut(), W.as_mut(), alpha.as_mut());
        }

        let D = L.as_mut().diagonal().column_vector().as_mat().cloned();
        let D = D.col(0).as_diagonal();

        // Turn the combined storage into the unit lower-triangular L.
        for j in 0..n {
            for i in 0..j {
                L[(i, j)] = c64::ZERO;
            }
            L[(j, j)] = c64::ONE;
        }
        let L = L.as_ref();

        assert!(A_new ~ L * D * L.adjoint());
    }

    #[test]
    fn test_rank_update() {
        let rng = &mut StdRng::seed_from_u64(0);

        for r in RANKS {
            for n in DIMS {
                check_rank_update(rng, n, r, false);
            }
        }
    }

    #[test]
    fn test_rank_update_row_major_w() {
        let rng = &mut StdRng::seed_from_u64(1);

        for r in RANKS {
            for n in DIMS {
                check_rank_update(rng, n, r, true);
            }
        }
    }
}