1use crate::assert;
2use crate::internal_prelude::*;
3use crate::linalg::matmul::internal::*;
4use linalg::matmul::triangular::BlockStructure;
5use pulp::Simd;
6
7#[inline(always)]
8#[math]
9fn simd_cholesky_row_batch<'N, T: ComplexField, S: Simd>(
10 simd: T::SimdCtx<S>,
11 A: MatMut<'_, T, Dim<'N>, Dim<'N>, ContiguousFwd>,
12 D: RowMut<'_, T, Dim<'N>>,
13
14 start: IdxInc<'N>,
15
16 is_llt: bool,
17 regularize: bool,
18 eps: T::Real,
19 delta: T::Real,
20 signs: Option<&Array<'N, i8>>,
21) -> Result<usize, usize> {
22 let mut A = A;
23 let mut D = D;
24
25 let n = A.ncols();
26
27 with_dim!(TAIL, *n - *start);
28
29 let simd = SimdCtx::<T, S>::new_force_mask(simd, TAIL);
30 let (idx_head, indices, idx_tail) = simd.indices();
31 assert!(idx_head.is_none());
32 let Some(idx_tail) = idx_tail else { panic!() };
33
34 let mut count = 0usize;
35
36 for j in n.indices() {
37 with_dim!(LEFT, *j);
38
39 let (A_0, Aj) = A.rb_mut().split_at_col_mut(j.into());
40 let A_0 = A_0.as_col_shape(LEFT);
41 let A10 = A_0.subrows(start, TAIL);
42
43 let mut Aj = Aj.col_mut(0).subrows_mut(start, TAIL);
44
45 {
46 let D = D.rb().subcols(IdxInc::ZERO, LEFT);
47 let mut Aj = Aj.rb_mut();
48 let mut iter = indices.clone();
49 let i0 = iter.next();
50 let i1 = iter.next();
51 let i2 = iter.next();
52
53 match (i0, i1, i2) {
54 (None, None, None) => {
55 let mut Aij = simd.read(Aj.rb(), idx_tail);
56
57 for k in LEFT.indices() {
58 let Ak = A10.col(k);
59
60 let D = real(D[k]);
61 let D = if is_llt { one() } else { D };
62
63 let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
64
65 let Aik = simd.read(Ak, idx_tail);
66 Aij = simd.mul_add(Ajk, Aik, Aij);
67 }
68 simd.write(Aj.rb_mut(), idx_tail, Aij);
69 },
70 (Some(i0), None, None) => {
71 let mut A0j = simd.read(Aj.rb(), i0);
72 let mut Aij = simd.read(Aj.rb(), idx_tail);
73
74 for k in LEFT.indices() {
75 let Ak = A10.col(k);
76
77 let D = real(D[k]);
78 let D = if is_llt { one() } else { D };
79
80 let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
81
82 let A0k = simd.read(Ak, i0);
83 let Aik = simd.read(Ak, idx_tail);
84 A0j = simd.mul_add(Ajk, A0k, A0j);
85 Aij = simd.mul_add(Ajk, Aik, Aij);
86 }
87 simd.write(Aj.rb_mut(), i0, A0j);
88 simd.write(Aj.rb_mut(), idx_tail, Aij);
89 },
90 (Some(i0), Some(i1), None) => {
91 let mut A0j = simd.read(Aj.rb(), i0);
92 let mut A1j = simd.read(Aj.rb(), i1);
93 let mut Aij = simd.read(Aj.rb(), idx_tail);
94
95 for k in LEFT.indices() {
96 let Ak = A10.col(k);
97
98 let D = real(D[k]);
99 let D = if is_llt { one() } else { D };
100
101 let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
102
103 let A0k = simd.read(Ak, i0);
104 let A1k = simd.read(Ak, i1);
105 let Aik = simd.read(Ak, idx_tail);
106 A0j = simd.mul_add(Ajk, A0k, A0j);
107 A1j = simd.mul_add(Ajk, A1k, A1j);
108 Aij = simd.mul_add(Ajk, Aik, Aij);
109 }
110 simd.write(Aj.rb_mut(), i0, A0j);
111 simd.write(Aj.rb_mut(), i1, A1j);
112 simd.write(Aj.rb_mut(), idx_tail, Aij);
113 },
114 (Some(i0), Some(i1), Some(i2)) => {
115 let mut A0j = simd.read(Aj.rb(), i0);
116 let mut A1j = simd.read(Aj.rb(), i1);
117 let mut A2j = simd.read(Aj.rb(), i2);
118 let mut Aij = simd.read(Aj.rb(), idx_tail);
119
120 for k in LEFT.indices() {
121 let Ak = A10.col(k);
122
123 let D = real(D[k]);
124 let D = if is_llt { one() } else { D };
125
126 let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
127
128 let A0k = simd.read(Ak, i0);
129 let A1k = simd.read(Ak, i1);
130 let A2k = simd.read(Ak, i2);
131 let Aik = simd.read(Ak, idx_tail);
132 A0j = simd.mul_add(Ajk, A0k, A0j);
133 A1j = simd.mul_add(Ajk, A1k, A1j);
134 A2j = simd.mul_add(Ajk, A2k, A2j);
135 Aij = simd.mul_add(Ajk, Aik, Aij);
136 }
137 simd.write(Aj.rb_mut(), i0, A0j);
138 simd.write(Aj.rb_mut(), i1, A1j);
139 simd.write(Aj.rb_mut(), i2, A2j);
140 simd.write(Aj.rb_mut(), idx_tail, Aij);
141 },
142 _ => {
143 unreachable!();
144 },
145 }
146 }
147
148 let D = D.rb_mut().at_mut(j);
149
150 if *j >= *start {
151 let j_row = TAIL.idx(*j - *start);
152
153 let mut diag = real(Aj[j_row]);
154
155 if regularize {
156 let sign = if is_llt { 1 } else { if let Some(signs) = signs { signs[j] } else { 0 } };
157
158 let small_or_negative = diag <= eps;
159 let minus_small_or_positive = diag >= -eps;
160
161 if sign == 1 && small_or_negative {
162 diag = copy(delta);
163 count += 1;
164 } else if sign == -1i8 && minus_small_or_positive {
165 diag = neg(delta);
166 } else {
167 if small_or_negative && minus_small_or_positive {
168 if diag < zero() {
169 diag = neg(delta);
170 } else {
171 diag = copy(delta);
172 }
173 }
174 }
175 }
176
177 let j = j;
178 let diag = if is_llt {
179 if !(diag > zero()) {
180 *D = from_real(diag);
181 return Err(*j);
182 }
183 sqrt(diag)
184 } else {
185 copy(diag)
186 };
187
188 *D = from_real(diag);
189
190 if diag == zero() || !is_finite(diag) {
191 return Err(*j);
192 }
193 }
194
195 let diag = real(*D);
196
197 {
198 let mut Aj = Aj.rb_mut();
199 let inv = simd.splat_real(&recip(diag));
200
201 for i in indices.clone() {
202 let mut Aij = simd.read(Aj.rb(), i);
203 Aij = simd.mul_real(Aij, inv);
204 simd.write(Aj.rb_mut(), i, Aij);
205 }
206 {
207 let mut Aij = simd.read(Aj.rb(), idx_tail);
208 Aij = simd.mul_real(Aij, inv);
209 simd.write(Aj.rb_mut(), idx_tail, Aij);
210 }
211 }
212 }
213
214 Ok(count)
215}
216
217#[inline(always)]
218#[math]
219fn simd_cholesky_matrix<T: ComplexField, S: Simd>(
220 simd: T::SimdCtx<S>,
221 A: MatMut<'_, T, usize, usize, ContiguousFwd>,
222 D: RowMut<'_, T, usize>,
223
224 is_llt: bool,
225 regularize: bool,
226 eps: T::Real,
227 delta: T::Real,
228 signs: Option<&[i8]>,
229) -> Result<usize, usize> {
230 let N = A.ncols();
231
232 let blocksize = 4 * (core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>());
233
234 let mut A = A;
235 let mut D = D;
236
237 let mut count = 0;
238
239 let mut j = 0;
240 while j < N {
241 let blocksize = Ord::min(blocksize, N - j);
242 let j_next = j + blocksize;
243
244 with_dim!(HEAD, j_next);
245 let A = A.rb_mut().submatrix_mut(0, 0, HEAD, HEAD);
246 let D = D.rb_mut().subcols_mut(0, HEAD);
247
248 let signs = signs.map(|signs| Array::from_ref(&signs[..*HEAD], HEAD));
249
250 count += simd_cholesky_row_batch(simd, A, D, HEAD.idx_inc(j), is_llt, regularize, eps.clone(), delta.clone(), signs)?;
251 j += blocksize;
252 }
253
254 Ok(count)
255}
256
257fn simd_cholesky<T: ComplexField>(
258 A: MatMut<'_, T>,
259 D: RowMut<'_, T>,
260 is_llt: bool,
261 regularize: bool,
262 eps: T::Real,
263 delta: T::Real,
264 signs: Option<&[i8]>,
265) -> Result<usize, usize> {
266 struct Impl<'a, T: ComplexField> {
267 A: MatMut<'a, T, usize, usize, ContiguousFwd>,
268 D: RowMut<'a, T>,
269 is_llt: bool,
270 regularize: bool,
271 eps: T::Real,
272 delta: T::Real,
273 signs: Option<&'a [i8]>,
274 }
275
276 impl<'a, T: ComplexField> pulp::WithSimd for Impl<'a, T> {
277 type Output = Result<usize, usize>;
278
279 #[inline(always)]
280 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
281 let Self {
282 A,
283 D,
284 is_llt,
285 regularize,
286 eps,
287 delta,
288 signs,
289 } = self;
290 let simd = T::simd_ctx(simd);
291 if A.nrows() > 0 {
292 simd_cholesky_matrix(simd, A, D, is_llt, regularize, eps, delta, signs)
293 } else {
294 Ok(0)
295 }
296 }
297 }
298
299 let mut A = A;
300 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
301 if let Some(A) = A.rb_mut().try_as_col_major_mut() {
302 dispatch!(
303 Impl {
304 A,
305 D,
306 is_llt,
307 regularize,
308 eps,
309 delta,
310 signs,
311 },
312 Impl,
313 T
314 )
315 } else {
316 cholesky_fallback(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
317 }
318 } else {
319 cholesky_fallback(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
320 }
321}
322
323#[math]
324fn cholesky_fallback<T: ComplexField>(
325 A: MatMut<'_, T>,
326 D: RowMut<'_, T>,
327 is_llt: bool,
328 regularize: bool,
329 eps: T::Real,
330 delta: T::Real,
331 signs: Option<&[i8]>,
332) -> Result<usize, usize> {
333 let n = A.nrows();
334 let mut count = 0;
335 let mut A = A;
336 let mut D = D;
337
338 for j in 0..n {
339 for i in j..n {
340 let mut sum = zero();
341 for k in 0..j {
342 let D = real(D[k]);
343 let D = if is_llt { one() } else { D };
344
345 sum = sum + mul_real(conj(A[(j, k)]) * A[(i, k)], D);
346 }
347 A[(i, j)] = A[(i, j)] - sum;
348 }
349
350 let D = D.rb_mut().at_mut(j);
351 let mut diag = real(A[(j, j)]);
352
353 if regularize {
354 let sign = if is_llt { 1 } else { if let Some(signs) = signs { signs[j] } else { 0 } };
355
356 let small_or_negative = diag <= eps;
357 let minus_small_or_positive = diag >= -eps;
358
359 if sign == 1 && small_or_negative {
360 diag = copy(delta);
361 count += 1;
362 } else if sign == -1i8 && minus_small_or_positive {
363 diag = neg(delta);
364 } else {
365 if small_or_negative && minus_small_or_positive {
366 if diag < zero() {
367 diag = neg(delta);
368 } else {
369 diag = copy(delta);
370 }
371 }
372 }
373 }
374
375 let diag = if is_llt {
376 if !(diag > zero()) {
377 *D = from_real(diag);
378 return Err(j);
379 }
380 sqrt(diag)
381 } else {
382 copy(diag)
383 };
384 *D = from_real(diag);
385
386 if diag == zero() || !is_finite(diag) {
387 return Err(j);
388 }
389
390 let inv = recip(diag);
391
392 for i in j..n {
393 A[(i, j)] = mul_real(A[(i, j)], inv);
394 }
395 }
396
397 Ok(count)
398}
399
400#[math]
401pub(crate) fn cholesky_recursion<T: ComplexField>(
402 A: MatMut<'_, T>,
403 D: RowMut<'_, T>,
404
405 recursion_threshold: usize,
406 blocksize: usize,
407 is_llt: bool,
408 regularize: bool,
409 eps: &T::Real,
410 delta: &T::Real,
411 signs: Option<&[i8]>,
412 par: Par,
413) -> Result<usize, usize> {
414 let n = A.ncols();
415 if n <= recursion_threshold {
416 simd_cholesky(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
417 } else {
418 let mut count = 0;
419 let blocksize = Ord::min(n.next_power_of_two() / 2, blocksize);
420 let mut A = A;
421 let mut D = D;
422
423 let mut j = 0;
424 while j < n {
425 let blocksize = Ord::min(blocksize, n - j);
426
427 let (mut A00, A01, mut A10, mut A11) = A.rb_mut().get_mut(j.., j..).split_at_mut(blocksize, blocksize);
428
429 let mut D0 = D.rb_mut().subcols_mut(j, blocksize);
430
431 let mut L10xD0 = A01.transpose_mut();
432
433 let signs = signs.map(|signs| &signs[j..][..blocksize]);
434
435 match cholesky_recursion(
436 A00.rb_mut(),
437 D0.rb_mut(),
438 recursion_threshold,
439 blocksize,
440 is_llt,
441 regularize,
442 eps,
443 delta,
444 signs,
445 par,
446 ) {
447 Ok(local_count) => count += local_count,
448 Err(fail_idx) => return Err(j + fail_idx),
449 }
450 let A00 = A00.rb();
451
452 if is_llt {
453 linalg::triangular_solve::solve_lower_triangular_in_place(A00.conjugate(), A10.rb_mut().transpose_mut(), par)
454 } else {
455 linalg::triangular_solve::solve_unit_lower_triangular_in_place(A00.conjugate(), A10.rb_mut().transpose_mut(), par)
456 }
457 let mut A10 = A10.rb_mut();
458
459 if is_llt {
460 linalg::matmul::triangular::matmul(
461 A11.rb_mut(),
462 BlockStructure::TriangularLower,
463 Accum::Add,
464 A10.rb(),
465 BlockStructure::Rectangular,
466 A10.rb().adjoint(),
467 BlockStructure::Rectangular,
468 -one::<T>(),
469 par,
470 );
471 } else {
472 if has_spicy_matmul::<T>() {
473 for k in 0..blocksize {
474 let d = real(D0[k]);
475 let d = recip(d);
476
477 for i in j + blocksize..n {
478 let i = i - (j + blocksize);
479 A10[(i, k)] = mul_real(A10[(i, k)], d);
480 }
481 }
482 spicy_matmul::<usize, T>(
483 A11.rb_mut(),
484 BlockStructure::TriangularLower,
485 None,
486 None,
487 Accum::Add,
488 A10.rb(),
489 Conj::No,
490 A10.rb().transpose(),
491 Conj::Yes,
492 Some(D0.rb().transpose().as_diagonal()),
493 -one::<T>(),
494 par,
495 MemStack::new(&mut []),
496 );
497 } else {
498 for k in 0..blocksize {
499 let d = real(D0[k]);
500 let d = recip(d);
501
502 for i in j + blocksize..n {
503 let i = i - (j + blocksize);
504 let a = copy(A10[(i, k)]);
505 A10[(i, k)] = mul_real(A10[(i, k)], d);
506 L10xD0[(i, k)] = a;
507 }
508 }
509 linalg::matmul::triangular::matmul(
510 A11.rb_mut(),
511 BlockStructure::TriangularLower,
512 Accum::Add,
513 A10,
514 BlockStructure::Rectangular,
515 L10xD0.adjoint(),
516 BlockStructure::Rectangular,
517 -one::<T>(),
518 par,
519 );
520 }
521 };
522
523 j += blocksize;
524 }
525
526 Ok(count)
527 }
528}
529
/// Dynamic regularization parameters for the LDLᴴ factorization.
///
/// `cholesky_in_place` only enables regularization when both `dynamic_regularization_delta`
/// and `dynamic_regularization_epsilon` are strictly positive.
#[derive(Copy, Clone, Debug)]
pub struct LdltRegularization<'a, T> {
    /// Optional expected sign of each pivot: `1` for positive, `-1` for negative, any other
    /// value when the sign is unknown. When present, it is indexed by column, so it needs at
    /// least one entry per column of the matrix.
    pub dynamic_regularization_signs: Option<&'a [i8]>,
    /// Magnitude used to replace a pivot that is too small, or on the wrong side of zero
    /// when its expected sign is known.
    pub dynamic_regularization_delta: T,
    /// Threshold that decides whether a pivot is too small, or (with a known sign) on the
    /// wrong side of zero.
    pub dynamic_regularization_epsilon: T,
}
542
/// Information about a successful LDLᴴ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LdltInfo {
    /// Count of pivot corrections reported by the dynamic regularization during the
    /// factorization.
    pub dynamic_regularization_count: usize,
}
549
/// Error returned by the LDLᴴ factorization.
#[derive(Copy, Clone, Debug)]
pub enum LdltError {
    /// The pivot at column `index` was zero or not finite, even after any regularization.
    ZeroPivot { index: usize },
}
555
556impl core::fmt::Display for LdltError {
557 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
558 core::fmt::Debug::fmt(self, f)
559 }
560}
/// `LdltError` carries no source error; its message comes from the `Display` impl.
impl core::error::Error for LdltError {}
562
563impl<T: RealField> Default for LdltRegularization<'_, T> {
564 fn default() -> Self {
565 Self {
566 dynamic_regularization_signs: None,
567 dynamic_regularization_delta: zero(),
568 dynamic_regularization_epsilon: zero(),
569 }
570 }
571}
572
/// Tuning parameters for the blocked LDLᴴ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LdltParams {
    /// Matrices with at most this many columns are handed to the unblocked kernel instead of
    /// being split further.
    pub recursion_threshold: usize,
    /// Upper bound on the width of the diagonal blocks used by the recursive algorithm.
    pub blocksize: usize,
    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
580
581impl<T: ComplexField> Auto<T> for LdltParams {
582 #[inline]
583 fn auto() -> Self {
584 Self {
585 recursion_threshold: 64,
586 blocksize: 128,
587 non_exhaustive: NonExhaustive(()),
588 }
589 }
590}
591
592#[inline]
593pub fn cholesky_in_place_scratch<T: ComplexField>(dim: usize, par: Par, params: Spec<LdltParams, T>) -> StackReq {
594 _ = par;
595 _ = params;
596 temp_mat_scratch::<T>(dim, 1)
597}
598
599#[math]
600pub fn cholesky_in_place<T: ComplexField>(
601 A: MatMut<'_, T>,
602 regularization: LdltRegularization<'_, T::Real>,
603 par: Par,
604 stack: &mut MemStack,
605 params: Spec<LdltParams, T>,
606) -> Result<LdltInfo, LdltError> {
607 let params = params.config;
608
609 let n = A.nrows();
610 let mut D = unsafe { temp_mat_uninit(n, 1, stack).0 };
611 let D = D.as_mat_mut();
612 let mut D = D.col_mut(0).transpose_mut();
613 let mut A = A;
614
615 let ret = match cholesky_recursion(
616 A.rb_mut(),
617 D.rb_mut(),
618 params.recursion_threshold,
619 params.blocksize,
620 false,
621 regularization.dynamic_regularization_delta > zero() && regularization.dynamic_regularization_epsilon > zero(),
622 ®ularization.dynamic_regularization_epsilon,
623 ®ularization.dynamic_regularization_delta,
624 regularization.dynamic_regularization_signs.map(|signs| signs),
625 par,
626 ) {
627 Ok(count) => Ok(LdltInfo {
628 dynamic_regularization_count: count,
629 }),
630 Err(index) => Err(LdltError::ZeroPivot { index }),
631 };
632 let init = if let Err(LdltError::ZeroPivot { index }) = ret { index + 1 } else { n };
633
634 for i in 0..init {
635 A[(i, i)] = copy(D[i]);
636 }
637
638 ret
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;
    use crate::{Mat, Row, assert, c64};

    // Checks both unblocked kernels (scalar fallback and SIMD dispatcher), in LLT and LDLT
    // mode, for every size 0..=64 by reconstructing A from the computed factors.
    #[test]
    fn test_simd_cholesky() {
        let rng = &mut StdRng::seed_from_u64(0);

        type T = c64;

        for n in 0..=64 {
            for f in [cholesky_fallback::<T>, simd_cholesky::<T>] {
                for llt in [true, false] {
                    // NOTE(review): `approx_eq` is never referenced explicitly. Presumably the
                    // `~` comparisons in `assert!` below pick it up by name, so do not rename it.
                    let approx_eq = CwiseMat(ApproxEq {
                        abs_tol: 1e-12,
                        rel_tol: 1e-12,
                    });

                    let A = CwiseMatDistribution {
                        nrows: n,
                        ncols: n,
                        dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                    }
                    .rand::<Mat<c64>>(rng);

                    // B * B^H with Gaussian B: Hermitian and (almost surely) positive definite.
                    let A = &A * &A.adjoint();
                    let A = A.as_ref().as_shape(n, n);

                    let mut L = A.cloned();
                    let mut L = L.as_mut();
                    let mut D = Row::zeros(n);
                    let mut D = D.as_mut();

                    // No regularization: eps = delta = 0, no signs.
                    f(L.rb_mut(), D.rb_mut(), llt, false, 0.0, 0.0, None).unwrap();

                    // The kernels leave scratch data above the diagonal; clear it before
                    // reconstructing.
                    for j in 0..n {
                        for i in 0..j {
                            L[(i, j)] = c64::ZERO;
                        }
                    }
                    let L = L.rb().as_dyn_stride();

                    if llt {
                        assert!(L * L.adjoint() ~ A);
                    } else {
                        assert!(L * D.as_diagonal() * L.adjoint() ~ A);
                    };
                }
            }
        }
    }

    // Exercises the blocked recursive driver with a small threshold/blocksize (32), so that
    // sizes such as 127 and 240 go through several levels of blocking.
    #[test]
    fn test_cholesky() {
        let rng = &mut StdRng::seed_from_u64(0);

        for n in [2, 4, 8, 31, 127, 240] {
            for llt in [false, true] {
                // NOTE(review): consumed implicitly by the `~` comparisons below (see above).
                let approx_eq = CwiseMat(ApproxEq {
                    abs_tol: 1e-12,
                    rel_tol: 1e-12,
                });

                let A = CwiseMatDistribution {
                    nrows: n,
                    ncols: n,
                    dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                }
                .rand::<Mat<c64>>(rng);

                let A = &A * &A.adjoint();
                let A = A.as_ref();

                let mut L = A.cloned();
                let mut L = L.as_mut();
                let mut D = Row::zeros(n);
                let mut D = D.as_mut();

                cholesky_recursion(L.rb_mut(), D.rb_mut(), 32, 32, llt, false, &0.0, &0.0, None, Par::Seq).unwrap();

                // Discard the scratch data left in the strictly upper triangle.
                for j in 0..n {
                    for i in 0..j {
                        L[(i, j)] = c64::ZERO;
                    }
                }
                let L = L.rb().as_dyn_stride();

                if llt {
                    assert!(L * L.adjoint() ~ A);
                } else {
                    assert!(L * D.as_diagonal() * L.adjoint() ~ A);
                };
            }
        }
    }
}