1use crate::internal_prelude::*;
2use crate::{assert, perm};
3use linalg::matmul::triangular::BlockStructure;
4
/// Pivoting strategy used to select the pivots of the $LBL^\top$ factorization.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum PivotingStrategy {
	/// Deprecated. `cholesky_in_place` has no dedicated handling for this variant: it falls into
	/// the catch-all arm of its strategy selection and therefore behaves like
	/// [`PivotingStrategy::Partial`].
	#[deprecated]
	Diagonal,

	/// Searches for the pivot in the current column only (no rook search, no diagonal scan).
	Partial,
	/// Like [`PivotingStrategy::Partial`], but the candidate column is first chosen as the one
	/// holding the largest entry of the remaining diagonal.
	PartialDiag,
	/// Rook pivoting: keeps alternating between candidate columns until an acceptable pivot is
	/// found, starting from the current column.
	Rook,
	/// Like [`PivotingStrategy::Rook`], but the starting column is the one holding the largest
	/// entry of the remaining diagonal.
	RookDiag,

	/// Full pivoting: dispatches to a dedicated kernel that tracks the largest off-diagonal entry
	/// of the whole trailing submatrix.
	Full,
}
27
/// Tuning parameters of the $LBL^\top$ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LbltParams {
	/// Strategy used to pick the pivots.
	pub pivoting: PivotingStrategy,
	/// Block size of the blocked algorithm. `cholesky_in_place` switches to the unblocked
	/// algorithm when this is smaller than 2 or not smaller than the matrix dimension.
	pub blocksize: usize,

	/// Full pivoting only: once the squared dimension of the trailing submatrix drops below this
	/// value, the remaining updates run sequentially.
	pub par_threshold: usize,

	// private-by-convention marker field; presumably prevents exhaustive struct literals
	// outside the crate — confirm against `NonExhaustive`'s definition
	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
42
/// Applies the symmetric permutation that exchanges indices `i` and `j` to a self-adjoint matrix,
/// touching only entries of its lower triangle (every access below has `row >= col`).
///
/// # Panics
/// Panics if `i == j`.
#[math]
fn swap_self_adjoint<T: ComplexField>(A: MatMut<'_, T>, i: usize, j: usize) {
	assert_ne!(i, j);

	let mut A = A;
	// from here on `i < j`
	let (i, j) = (Ord::min(i, j), Ord::max(i, j));

	// rows below `j`: exchange the entries stored in columns `i` and `j`
	perm::swap_cols_idx(A.rb_mut().get_mut(j + 1.., ..), i, j);
	// columns left of `i`: exchange the entries stored in rows `i` and `j`
	perm::swap_rows_idx(A.rb_mut().get_mut(.., ..i), i, j);

	// exchange the two diagonal entries, keeping only their real parts
	let tmp = real(A[(i, i)]);
	A[(i, i)] = from_real(real(A[(j, j)]));
	A[(j, j)] = from_real(tmp);

	// the entry coupling `i` and `j` stays in place but is conjugated
	A[(j, i)] = conj(A[(j, i)]);

	// for `i < k < j`: `A[k, i]` (column `i`) and `A[j, k]` (row `j`) trade places, each being
	// conjugated on the way
	let (Ai, Aj) = A.split_at_row_mut(j);
	let Ai = Ai.get_mut(i + 1..j, i);
	let Aj = Aj.get_mut(0, i + 1..j).transpose_mut();
	zip!(Ai, Aj).for_each(|unzip!(x, y)| {
		let tmp = conj(*x);
		*x = conj(*y);
		*y = tmp;
	});
}
68
69#[math]
70fn rank_1_update_and_argmax_fallback<'M, 'N, T: ComplexField>(
71 A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
72 L: ColRef<'_, T, Dim<'N>>,
73 d: T::Real,
74 start: IdxInc<'N>,
75 end: IdxInc<'N>,
76) -> (usize, usize, T::Real) {
77 let mut A = A;
78 let n = A.nrows();
79
80 let mut max_j = n.idx(0);
81 let mut max_i = n.idx(0);
82 let mut max_offdiag = zero();
83
84 for j in start.to(end) {
85 for i in j.next().to(n.end()) {
86 A[(i, j)] = A[(i, j)] - mul_real(L[i] * conj(L[j]), d);
87 let val = abs2(A[(i, j)]);
88 if val > max_offdiag {
89 max_offdiag = val;
90 max_i = i;
91 max_j = j;
92 }
93 }
94 }
95
96 (*max_i, *max_j, max_offdiag)
97}
98
99#[math]
100fn rank_2_update_and_argmax_fallback<'N, T: ComplexField>(
101 A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
102 L0: ColRef<'_, T, Dim<'N>>,
103 L1: ColRef<'_, T, Dim<'N>>,
104 d: T::Real,
105 d00: T::Real,
106 d11: T::Real,
107 d10: T,
108 start: IdxInc<'N>,
109 end: IdxInc<'N>,
110) -> (usize, usize, T::Real) {
111 let mut A = A;
112 let n = A.nrows();
113
114 let mut max_j = n.idx(0);
115 let mut max_i = n.idx(0);
116 let mut max_offdiag = zero();
117
118 for j in start.to(end) {
119 let x0 = copy(L0[j]);
120 let x1 = copy(L1[j]);
121
122 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
123 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
124
125 for i in j.next().to(n.end()) {
126 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
127
128 let val = abs2(A[(i, j)]);
129 if val > max_offdiag {
130 max_offdiag = val;
131 max_i = i;
132 max_j = j;
133 }
134 }
135 }
136 (*max_i, *max_j, max_offdiag)
137}
138
/// Sequential rank-1 update + argmax over the columns `start..end`.
///
/// Currently a plain forwarder to [`rank_1_update_and_argmax_fallback`].
#[math]
fn rank_1_update_and_argmax_seq<'M, 'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_1_update_and_argmax_fallback(A, L, d, start, end)
}
149
/// Sequential rank-2 update + argmax over the columns `start..end`.
///
/// Currently a plain forwarder to [`rank_2_update_and_argmax_fallback`].
#[math]
fn rank_2_update_and_argmax_seq<'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L0: ColRef<'_, T, Dim<'N>>,
	L1: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	d00: T::Real,
	d11: T::Real,
	d10: T,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_2_update_and_argmax_fallback(A, L0, L1, d, d00, d11, d10, start, end)
}
164
/// Applies the rank-1 update to the strictly lower triangle of the square matrix `A` and returns
/// `(row, column, squared magnitude)` of the largest updated entry.
///
/// With `Par::Seq` the whole column range is handled by one call to the sequential kernel. With
/// `Par::Rayon` the columns are split into `nthreads` contiguous chunks that are processed in
/// parallel, and the per-chunk maxima are reduced afterwards.
#[math]
fn rank_1_update_and_argmax<T: ComplexField>(A: MatMut<'_, T>, L: ColRef<'_, T>, d: T::Real, par: Par) -> (usize, usize, T::Real) {
	with_dim!(N, A.nrows());

	match par {
		Par::Seq => rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), d, IdxInc::ZERO, N.end()),
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();
			let n = *N;

			// `n < 2^50` is exactly representable as an `f64`, which the chunking below relies on
			assert!((n as u64) < (1u64 << 50));

			// Maps a chunk index to its first column. The fraction of the lower triangle lying
			// left of the column fraction `c` is `1 - (1 - c)^2`; inverting that for the target
			// fraction `idx / nthreads` gives chunks of roughly equal work.
			let idx_to_col_start = |idx: usize| {
				let idx_as_percent = idx as f64 / nthreads as f64;
				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
				(col_start_percent * n as f64) as usize
			};

			// one result slot per chunk; an empty chunk leaves its slot at `(0, 0, 0)`
			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];

			r.par_iter_mut().enumerate().for_each(|(idx, out)| {
				// SAFETY (as far as this file shows): the sequential kernel only writes entries
				// whose column lies in `start..end`, and the chunks produced by
				// `idx_to_col_start` do not overlap, so no two tasks write the same entry.
				let A = unsafe { A.rb().const_cast() };
				let start = N.idx_inc(idx_to_col_start(idx));
				let end = N.idx_inc(idx_to_col_start(idx + 1));

				*out = rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), copy(d), start, end);
			});

			// pick the chunk result with the largest magnitude; values that compare neither
			// equal nor greater (e.g. NaN) are ordered as `Less`
			r.into_iter()
				.max_by(|(_, _, a), (_, _, b)| {
					if a == b {
						core::cmp::Ordering::Equal
					} else if a > b {
						core::cmp::Ordering::Greater
					} else {
						core::cmp::Ordering::Less
					}
				})
				.unwrap()
		},
	}
}
210
/// Applies the rank-2 update to the strictly lower triangle of the square matrix `A` and returns
/// `(row, column, squared magnitude)` of the largest updated entry.
///
/// With `Par::Seq` the whole column range is handled by one call to the sequential kernel. With
/// `Par::Rayon` the columns are split into `nthreads` contiguous chunks that are processed in
/// parallel, and the per-chunk maxima are reduced afterwards.
#[math]
fn rank_2_update_and_argmax<'N, T: ComplexField>(
	A: MatMut<'_, T>,
	L0: ColRef<'_, T>,
	L1: ColRef<'_, T>,
	d: T::Real,
	d00: T::Real,
	d11: T::Real,
	d10: T,
	par: Par,
) -> (usize, usize, T::Real) {
	with_dim!(N, A.nrows());

	match par {
		Par::Seq => rank_2_update_and_argmax_seq(
			A.as_shape_mut(N, N),
			L0.as_row_shape(N),
			L1.as_row_shape(N),
			d,
			d00,
			d11,
			d10,
			IdxInc::ZERO,
			N.end(),
		),
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;
			let nthreads = nthreads.get();
			let n = *N;

			// `n < 2^50` is exactly representable as an `f64`, which the chunking below relies on
			assert!((n as u64) < (1u64 << 50));

			// Maps a chunk index to its first column. The fraction of the lower triangle lying
			// left of the column fraction `c` is `1 - (1 - c)^2`; inverting that for the target
			// fraction `idx / nthreads` gives chunks of roughly equal work.
			let idx_to_col_start = |idx: usize| {
				let idx_as_percent = idx as f64 / nthreads as f64;
				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
				(col_start_percent * n as f64) as usize
			};

			// one result slot per chunk; an empty chunk leaves its slot at `(0, 0, 0)`
			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];

			r.par_iter_mut().enumerate().for_each(|(idx, out)| {
				// SAFETY (as far as this file shows): the sequential kernel only writes entries
				// whose column lies in `start..end`, and the chunks produced by
				// `idx_to_col_start` do not overlap, so no two tasks write the same entry.
				let A = unsafe { A.rb().const_cast() };
				let start = N.idx_inc(idx_to_col_start(idx));
				let end = N.idx_inc(idx_to_col_start(idx + 1));

				*out = rank_2_update_and_argmax_seq(
					A.as_shape_mut(N, N),
					L0.as_row_shape(N),
					L1.as_row_shape(N),
					copy(d),
					copy(d00),
					copy(d11),
					copy(d10),
					start,
					end,
				);
			});

			// pick the chunk result with the largest magnitude; values that compare neither
			// equal nor less (e.g. NaN) are ordered as `Greater`
			r.into_iter()
				.max_by(|(_, _, a), (_, _, b)| {
					if a == b {
						core::cmp::Ordering::Equal
					} else if a < b {
						core::cmp::Ordering::Less
					} else {
						core::cmp::Ordering::Greater
					}
				})
				.unwrap()
		},
	}
}
285
286#[math]
287fn lblt_full_piv<T: ComplexField>(A: MatMut<'_, T>, subdiag: DiagMut<'_, T>, pivots: &mut [usize], par: Par, params: LbltParams) {
288 let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
289 let alpha = alpha * alpha;
290
291 let mut A = A;
292 let mut subdiag = subdiag.column_vector_mut();
293 let mut par = par;
294 let n = A.nrows();
295
296 let scale_fwd = A.norm_max();
297 let scale_bwd = recip(scale_fwd);
298 zip!(A.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_bwd));
299
300 let mut max_i = 0;
301 let mut max_j = 0;
302 let mut max_offdiag = zero();
303
304 for j in 0..n {
305 for i in j + 1..n {
306 let val = abs2(A[(i, j)]);
307 if val > max_offdiag {
308 max_offdiag = val;
309 max_i = i;
310 max_j = j;
311 }
312 }
313 }
314
315 let mut k = 0;
316 while k < n {
317 if max_offdiag == zero() {
318 break;
319 }
320
321 let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
322 let mut subdiag = subdiag.rb_mut().get_mut(k..);
323 let pivots = &mut pivots[k..];
324
325 let n = A.nrows();
326 let mut max_s = 0;
327 let mut max_diag = zero();
328
329 for s in 0..n {
330 let val = abs2(A[(s, s)]);
331 if val > max_diag {
332 max_diag = val;
333 max_s = s;
334 }
335 }
336
337 let npiv;
338 let i0;
339 let i1;
340
341 if max_diag >= alpha * max_offdiag {
342 npiv = 1;
343 i0 = max_s;
344 i1 = usize::MAX;
345 } else {
346 npiv = 2;
347 i0 = max_j;
348 i1 = max_i;
349 }
350
351 let rem = n - npiv;
352 if rem * rem < params.par_threshold {
353 par = Par::Seq;
354 }
355
356 if i0 != 0 {
358 swap_self_adjoint(A.rb_mut(), 0, i0);
359 perm::swap_rows_idx(Aprev.rb_mut(), 0, i0);
360 }
361 if npiv == 2 && i1 != 1 {
362 swap_self_adjoint(A.rb_mut(), 1, i1);
363 perm::swap_rows_idx(Aprev.rb_mut(), 1, i1);
364 }
365
366 if npiv == 1 {
367 let diag = real(A[(0, 0)]);
368 let diag_inv = recip(diag);
369 subdiag[0] = zero();
370
371 let (_, _, L, mut A) = A.rb_mut().split_at_mut(1, 1);
372 let n = A.nrows();
373 let mut L = L.col_mut(0);
374
375 zip!(L.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, diag_inv));
376
377 for i in 0..n {
378 A[(i, i)] = from_real(real(A[(i, i)]) - diag * abs2(L[i]));
379 }
380
381 if n < params.par_threshold {}
382 if n != 0 {
383 (max_i, max_j, max_offdiag) = rank_1_update_and_argmax(A.rb_mut(), L.rb(), diag, par);
384 }
385 } else {
386 let a00 = real(A[(0, 0)]);
387 let a11 = real(A[(1, 1)]);
388 let a10 = copy(A[(1, 0)]);
389
390 subdiag[0] = copy(a10);
391 subdiag[1] = zero();
392 A[(1, 0)] = zero();
393
394 let d10 = abs(a10);
395 let d10_inv = recip(d10);
396 let d00 = a00 * d10_inv;
397 let d11 = a11 * d10_inv;
398
399 let t = recip(d00 * d11 - one());
401 let d10 = mul_real(a10, d10_inv);
402 let d = t * d10_inv;
403
404 let (_, _, L, mut A) = A.rb_mut().split_at_mut(2, 2);
407 let (mut L0, mut L1) = L.two_cols_mut(0, 1);
408 let n = A.nrows();
409
410 if n != 0 {
411 (max_i, max_j, max_offdiag) = rank_2_update_and_argmax(A.rb_mut(), L0.rb(), L1.rb(), copy(d), copy(d00), copy(d11), copy(d10), par);
412 }
413
414 for j in 0..n {
415 let x0 = copy(L0[j]);
416 let x1 = copy(L1[j]);
417
418 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
419 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
420
421 A[(j, j)] = from_real(real(A[(j, j)] - L0[j] * conj(w0) - L1[j] * conj(w1)));
422
423 L0[j] = w0;
424 L1[j] = w1;
425 }
426 }
427
428 if npiv == 2 {
429 pivots[0] = !(i0 + k);
430 pivots[1] = !(i1 + k);
431 } else {
432 pivots[0] = i0 + k;
433 }
434 k += npiv;
435 }
436
437 while k < n {
438 let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
439 let mut subdiag = subdiag.rb_mut().get_mut(k..);
440 let pivots = &mut pivots[k..];
441
442 let n = A.nrows();
443 let mut max_s = 0;
444 let mut max_diag = zero();
445
446 for s in 0..n {
447 let val = abs2(A[(s, s)]);
448 if val > max_diag {
449 max_diag = val;
450 max_s = s;
451 }
452 }
453
454 if max_s != 0 {
455 let (mut A0, mut As) = A.rb_mut().two_cols_mut(0, max_s);
456 core::mem::swap(&mut A0[0], &mut As[max_s]);
457
458 perm::swap_rows_idx(Aprev.rb_mut(), 0, max_s);
459 }
460
461 subdiag[0] = zero();
462 pivots[0] = max_s + k;
463
464 k += 1;
465 }
466
467 zip!(A.rb_mut().diagonal_mut().column_vector_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
468 zip!(subdiag.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
469}
470
471#[math]
472#[track_caller]
473fn l1_argmax<T: ComplexField>(col: ColRef<'_, T>) -> (Option<usize>, T::Real) {
474 let n = col.nrows();
475 if n == 0 {
476 return (None, zero());
477 }
478
479 let mut i = 0;
480 let mut best = zero();
481
482 for j in 0..n {
483 let val = abs1(col[j]);
484 if val > best {
485 best = val;
486 i = j;
487 }
488 }
489
490 (Some(i), best)
491}
492
493#[math]
494#[track_caller]
495fn offdiag_argmax<T: ComplexField>(A: MatRef<'_, T>, idx: usize) -> (Option<usize>, T::Real) {
496 let (mut col_argmax, col_max) = l1_argmax(A.rb().get(idx + 1.., idx));
497 col_argmax.as_mut().map(|col_argmax| *col_argmax += idx + 1);
498 let (row_argmax, row_max) = l1_argmax(A.rb().get(idx, ..idx).transpose());
499
500 if col_max > row_max {
501 (col_argmax, col_max)
502 } else {
503 (row_argmax, row_max)
504 }
505}
506
507#[math]
508fn update_and_offdiag_argmax<T: ComplexField>(
509 mut dst: ColMut<'_, T>,
510 Wl: MatRef<'_, T>,
511 Al: MatRef<'_, T>,
512 Ar: MatRef<'_, T>,
513 i0: usize,
514 par: Par,
515) -> (Option<usize>, T::Real) {
516 let n = Al.nrows();
517 for j in 0..i0 {
518 dst[j] = conj(Ar[(i0, j)]);
519 }
520 dst[i0] = zero();
521 for j in i0 + 1..n {
522 dst[j] = copy(Ar[(j, i0)]);
523 }
524
525 linalg::matmul::matmul(dst.rb_mut(), Accum::Add, Al.rb(), Wl.row(i0).adjoint(), -one::<T>(), par);
526 dst[i0] = zero();
527
528 let ret = l1_argmax(dst.rb());
529 dst[i0] = from_real(real(Ar[(i0, i0)]));
530 if n == 1 { (None, zero()) } else { ret }
531}
532
533#[math]
534fn lblt_blocked_step<T: ComplexField>(
535 alpha: T::Real,
536 W: MatMut<'_, T>,
537 A_left: MatMut<'_, T>,
538 A: MatMut<'_, T>,
539 subdiag: DiagMut<'_, T>,
540 pivots: &mut [usize],
541 rook: bool,
542 diagonal: bool,
543 par: Par,
544) -> usize {
545 let mut A = A;
546 let mut A_left = A_left;
547 let mut subdiag = subdiag;
548 let mut W = W;
549
550 let n = A.nrows();
551 let blocksize = W.ncols();
552
553 assert!(all(A.nrows() == n, A.ncols() == n, W.nrows() == n, subdiag.dim() == n, blocksize >= 2,));
554
555 let kmax = Ord::min(blocksize - 1, n);
556 let mut k = 0usize;
557 while k < kmax {
558 let mut A = A.rb_mut();
559 let mut W = W.rb_mut();
560 let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
561 let mut A_left = A_left.rb_mut().get_mut(k.., ..);
562
563 let (mut Wl, mut Wr) = W.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
564 let (mut Al, mut Ar) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
565 let mut Al = Al.rb_mut();
566 let mut Wr = Wr.rb_mut().get_mut(.., ..2);
567
568 let npiv;
569 let mut i0 = if diagonal {
570 l1_argmax(Ar.rb().diagonal().column_vector()).0.unwrap()
571 } else {
572 0
573 };
574 let mut i1 = usize::MAX;
575
576 let mut nothing_to_do = false;
577
578 let (mut Wr0, mut Wr1) = Wr.rb_mut().two_cols_mut(0, 1);
579
580 let (r, mut gamma_i) = update_and_offdiag_argmax(Wr0.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i0, par);
581
582 if k + 1 == n || gamma_i == zero() {
583 nothing_to_do = true;
584 npiv = 1;
585 } else if abs(real(Ar[(i0, i0)])) >= alpha * gamma_i {
586 npiv = 1;
587 } else {
588 i1 = r.unwrap();
589 if rook {
590 loop {
591 let (s, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
592
593 if abs1(Ar[(i1, i1)]) >= alpha * gamma_r {
594 npiv = 1;
595 i0 = i1;
596 i1 = usize::MAX;
597 Wr0.copy_from(&Wr1);
598 break;
599 } else if s == Some(i0) || gamma_i == gamma_r {
600 npiv = 2;
601 break;
602 } else {
603 i0 = i1;
604 i1 = s.unwrap();
605 gamma_i = gamma_r;
606 Wr0.copy_from(&Wr1);
607 }
608 }
609 } else {
610 let (_, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
611
612 if abs(real(Ar[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
613 npiv = 1;
614 } else if abs(real(Ar[(i1, i1)])) >= alpha * gamma_r {
615 npiv = 1;
616 i0 = i1;
617 i1 = usize::MAX;
618 Wr0.copy_from(&Wr1);
619 } else {
620 npiv = 2;
621 }
622 }
623 }
624
625 if npiv == 2 && i0 > i1 {
626 perm::swap_cols_idx(Wr.rb_mut(), 0, 1);
627 (i0, i1) = (i1, i0);
628 }
629
630 let mut Wr = Wr.rb_mut().get_mut(.., ..npiv);
631
632 'next_iter: {
633 if i0 != 0 {
635 swap_self_adjoint(Ar.rb_mut(), 0, i0);
636 perm::swap_rows_idx(Al.rb_mut(), 0, i0);
637 perm::swap_rows_idx(A_left.rb_mut(), 0, i0);
638 perm::swap_rows_idx(Wl.rb_mut(), 0, i0);
639 perm::swap_rows_idx(Wr.rb_mut(), 0, i0);
640 }
641 if npiv == 2 && i1 != 1 {
642 swap_self_adjoint(Ar.rb_mut(), 1, i1);
643 perm::swap_rows_idx(Al.rb_mut(), 1, i1);
644 perm::swap_rows_idx(A_left.rb_mut(), 1, i1);
645 perm::swap_rows_idx(Wl.rb_mut(), 1, i1);
646 perm::swap_rows_idx(Wr.rb_mut(), 1, i1);
647 }
648
649 if nothing_to_do {
650 break 'next_iter;
651 }
652
653 if npiv == 1 {
654 let W0 = Wr.rb_mut().col_mut(0);
655
656 let diag = real(W0[0]);
657 let diag_inv = recip(diag);
658 subdiag[0] = zero();
659
660 let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(1, 1);
661 let W0 = W0.rb().get(1..);
662 let n = A.nrows();
663
664 let mut L = L.col_mut(0);
665 zip!(W0, L.rb_mut()).for_each(|unzip!(w, a)| *a = mul_real(*w, diag_inv));
666
667 for j in 0..n {
668 A[(j, j)] = from_real(real(A[(j, j)]) - diag * abs2(L[j]));
669 }
670 } else {
671 let a00 = real(Wr[(0, 0)]);
672 let a11 = real(Wr[(1, 1)]);
673 let a10 = copy(Wr[(1, 0)]);
674
675 subdiag[0] = copy(a10);
676 subdiag[1] = zero();
677 Wr[(1, 0)] = zero();
678 Ar[(1, 0)] = zero();
679
680 let d10 = abs(a10);
681 let d10_inv = recip(d10);
682 let d00 = a00 * d10_inv;
683 let d11 = a11 * d10_inv;
684
685 let t = recip(d00 * d11 - one());
687 let d10 = mul_real(a10, d10_inv);
688 let d = t * d10_inv;
689
690 let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(2, 2);
693 let (mut L0, mut L1) = L.two_cols_mut(0, 1);
694 let Wr = Wr.rb().get(2.., ..);
695 let W0 = Wr.col(0);
696 let W1 = Wr.col(1);
697
698 let n = A.nrows();
699 for j in 0..n {
700 let x0 = copy(W0[j]);
701 let x1 = copy(W1[j]);
702
703 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
704 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
705
706 A[(j, j)] = from_real(real(A[(j, j)] - W0[j] * conj(w0) - W1[j] * conj(w1)));
707
708 L0[j] = w0;
709 L1[j] = w1;
710 }
711 }
712 }
713
714 let offset = A_left.ncols();
715
716 if npiv == 2 {
717 pivots[k] = !(offset + i0 + k);
718 pivots[k + 1] = !(offset + i1 + k);
719 } else {
720 pivots[k] = offset + i0 + k;
721 }
722 k += npiv;
723 }
724
725 let W = W.rb().get(k.., ..k);
726 let (_, _, Al, mut Ar) = A.rb_mut().split_at_mut(k, k);
727 let Al = Al.rb();
728
729 linalg::matmul::triangular::matmul(
730 Ar.rb_mut(),
731 BlockStructure::StrictTriangularLower,
732 Accum::Add,
733 W,
734 BlockStructure::Rectangular,
735 Al.adjoint(),
736 BlockStructure::Rectangular,
737 -one::<T>(),
738 par,
739 );
740
741 for j in 0..n - k {
742 Ar[(j, j)] = from_real(real(Ar[(j, j)]));
743 }
744
745 k
746}
747
748#[math]
749fn lblt_blocked<T: ComplexField>(
750 A: MatMut<'_, T>,
751 subdiag: DiagMut<'_, T>,
752 pivots: &mut [usize],
753 blocksize: usize,
754 rook: bool,
755 diagonal: bool,
756 par: Par,
757 stack: &mut MemStack,
758) {
759 let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
760
761 let mut A = A;
762 let mut subdiag = subdiag.column_vector_mut();
763 let n = A.nrows();
764
765 let mut k = 0;
766 while k < n {
767 let (_, _, A_left, A) = A.rb_mut().split_at_mut(k, k);
768 let (mut W, _) = unsafe { temp_mat_uninit::<T, _, _>(n - k, blocksize, stack) };
769 let W = W.as_mat_mut();
770
771 if blocksize < 2 || n - k <= blocksize {
772 lblt_unblocked(
773 copy(alpha),
774 A_left,
775 A,
776 subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
777 &mut pivots[k..],
778 rook,
779 diagonal,
780 par,
781 );
782
783 k = n;
784 } else {
785 let blocksize = lblt_blocked_step(
786 copy(alpha),
787 W,
788 A_left,
789 A,
790 subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
791 &mut pivots[k..],
792 rook,
793 diagonal,
794 par,
795 );
796
797 k += blocksize;
798 }
799 }
800}
801
802#[math]
803fn lblt_unblocked<T: ComplexField>(
804 alpha: T::Real,
805 A_left: MatMut<'_, T>,
806 A: MatMut<'_, T>,
807 subdiag: DiagMut<'_, T>,
808 pivots: &mut [usize],
809 rook: bool,
810 diagonal: bool,
811 par: Par,
812) {
813 let _ = par;
814 let mut A = A;
815 let mut A_left = A_left;
816 let mut subdiag = subdiag;
817
818 let n = A.nrows();
819 assert!(all(A.nrows() == n, A.ncols() == n, subdiag.dim() == n));
820
821 let mut k = 0usize;
822 while k < n {
823 let (_, _, mut L_prev, mut A) = A.rb_mut().split_at_mut(k, k);
824 let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
825 let mut A_left = A_left.rb_mut().get_mut(k.., ..);
826
827 let npiv;
828
829 let mut i0 = if diagonal {
831 l1_argmax(A.rb().diagonal().column_vector()).0.unwrap()
832 } else {
833 0
834 };
835 let mut i1 = usize::MAX;
836
837 let (r, mut gamma_i) = offdiag_argmax(A.rb(), i0);
839
840 let mut nothing_to_do = false;
841
842 if k + 1 == n || gamma_i == zero() {
843 nothing_to_do = true;
844 npiv = 1;
845 } else if abs(real(A[(i0, i0)])) >= alpha * gamma_i {
846 npiv = 1;
847 } else {
848 i1 = r.unwrap();
849
850 if rook {
852 loop {
853 let (s, gamma_r) = offdiag_argmax(A.rb(), i1);
854
855 if abs1(A[(i1, i1)]) >= alpha * gamma_r {
856 npiv = 1;
857 i0 = i1;
858 i1 = usize::MAX;
859 break;
860 } else if gamma_i == gamma_r {
861 npiv = 2;
862 break;
863 } else {
864 i0 = i1;
865 i1 = s.unwrap();
866 gamma_i = gamma_r;
867 }
868 }
869 } else {
870 let (_, gamma_r) = offdiag_argmax(A.rb(), i1);
871 if abs(real(A[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
872 npiv = 1;
873 } else if abs(real(A[(i1, i1)])) >= alpha * gamma_r {
874 npiv = 1;
875 i0 = i1;
876 } else {
877 npiv = 2;
878 }
879 }
880 }
881
882 if npiv == 2 && i0 > i1 {
883 (i0, i1) = (i1, i0);
884 }
885
886 'next_iter: {
887 if i0 != 0 {
889 swap_self_adjoint(A.rb_mut(), 0, i0);
890 perm::swap_rows_idx(A_left.rb_mut(), 0, i0);
891 perm::swap_rows_idx(L_prev.rb_mut(), 0, i0);
892 }
893 if npiv == 2 && i1 != 1 {
894 swap_self_adjoint(A.rb_mut(), 1, i1);
895 perm::swap_rows_idx(A_left.rb_mut(), 1, i1);
896 perm::swap_rows_idx(L_prev.rb_mut(), 1, i1);
897 }
898
899 if nothing_to_do {
900 break 'next_iter;
901 }
902
903 if npiv == 1 {
905 let diag = real(A[(0, 0)]);
906 let diag_inv = recip(diag);
907 subdiag[0] = zero();
908
909 let (_, _, L, A) = A.rb_mut().split_at_mut(1, 1);
910 let L = L.col_mut(0);
911 rank1_update(A, L, diag_inv);
912 } else {
913 let a00 = real(A[(0, 0)]);
914 let a11 = real(A[(1, 1)]);
915 let a10 = copy(A[(1, 0)]);
916
917 subdiag[0] = copy(a10);
918 subdiag[1] = zero();
919 A[(1, 0)] = zero();
920
921 let d10 = abs(a10);
922 let d10_inv = recip(d10);
923 let d00 = a00 * d10_inv;
924 let d11 = a11 * d10_inv;
925
926 let t = recip(d00 * d11 - one());
928 let d10 = mul_real(a10, d10_inv);
929 let d = t * d10_inv;
930
931 let (_, _, L, A) = A.rb_mut().split_at_mut(2, 2);
934 let (L0, L1) = L.two_cols_mut(0, 1);
935 rank2_update(A, L0, L1, d, d00, d10, d11);
936 }
937 }
938
939 let offset = A_left.ncols();
940 if npiv == 2 {
941 pivots[k] = !(offset + i0 + k);
942 pivots[k + 1] = !(offset + i1 + k);
943 } else {
944 pivots[k] = offset + i0 + k;
945 }
946 k += npiv;
947 }
948}
949
950impl<T: ComplexField> Auto<T> for LbltParams {
951 fn auto() -> Self {
952 Self {
953 pivoting: PivotingStrategy::PartialDiag,
954 blocksize: 64,
955 par_threshold: 256 * 512,
956 non_exhaustive: NonExhaustive(()),
957 }
958 }
959}
960
961pub fn rank2_update<'a, T: ComplexField>(
962 mut A: MatMut<'a, T>,
963 mut L0: ColMut<'a, T>,
964 mut L1: ColMut<'a, T>,
965 d: T::Real,
966 d00: T::Real,
967 d10: T,
968 d11: T::Real,
969) {
970 if const { T::SIMD_CAPABILITIES.is_simd() } {
971 if let (Some(A), Some(L0), Some(L1)) = (
972 A.rb_mut().try_as_col_major_mut(),
973 L0.rb_mut().try_as_col_major_mut(),
974 L1.rb_mut().try_as_col_major_mut(),
975 ) {
976 rank2_update_simd(A, L0, L1, d, d00, d10, d11);
977 } else {
978 rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
979 }
980 } else {
981 rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
982 }
983}
984
/// SIMD kernel of the rank-2 update for column-major operands.
///
/// For every column `j`, computes the coefficients `(w0, w1)` from `(L0[j], L1[j])` and the
/// scaled 2x2 pivot data, subtracts `L0[i] * conj(w0) + L1[i] * conj(w1)` from `A[i, j]` for all
/// `i >= j`, drops the imaginary part of `A[j, j]` and finally overwrites `L0[j]`, `L1[j]` with
/// `w0`, `w1`.
#[math]
pub fn rank2_update_simd<'a, T: ComplexField>(
	A: MatMut<'a, T, usize, usize, ContiguousFwd>,
	L0: ColMut<'a, T, usize, ContiguousFwd>,
	L1: ColMut<'a, T, usize, ContiguousFwd>,
	d: T::Real,
	d00: T::Real,
	d10: T,
	d11: T::Real,
) {
	// carries the arguments through the SIMD dispatch
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		L1: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
		d00: T::Real,
		d10: T,
		d11: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self {
				mut A,
				mut L0,
				mut L1,
				d,
				d00,
				d10,
				d11,
			} = self;
			let n = A.nrows();
			for j in 0..n {
				let x0 = copy(L0[j]);
				let x1 = copy(L1[j]);
				let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
				let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);

				// length of the part of column `j` on and below the diagonal
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let L1 = L1.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the coefficients are negated up front so that the subtraction can be
					// expressed with `mul_add`
					let w0_conj = conj(w0);
					let w1_conj = conj(w1);
					let w0_conj_neg = -w0_conj;
					let w1_conj_neg = -w1_conj;
					let w0_splat = simd.splat(&w0_conj_neg);
					let w1_splat = simd.splat(&w1_conj_neg);

					// the same update is applied to the three index groups returned by
					// `simd.indices()`
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop any imaginary part from the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// only now may the columns be overwritten: entry `j` was still needed above
				L0[j] = w0;
				L1[j] = w1;
			}
		}
	}
	dispatch!(Impl { A, L0, L1, d, d00, d10, d11 }, Impl, T)
}
1079
1080#[math]
1081pub fn rank2_update_fallback<'a, T: ComplexField>(
1082 mut A: MatMut<'a, T>,
1083 mut L0: ColMut<'a, T>,
1084 mut L1: ColMut<'a, T>,
1085 d: T::Real,
1086 d00: T::Real,
1087 d10: T,
1088 d11: T::Real,
1089) {
1090 let n = A.nrows();
1091 for j in 0..n {
1092 let x0 = copy(L0[j]);
1093 let x1 = copy(L1[j]);
1094
1095 let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
1096 let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
1097
1098 for i in j..n {
1099 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
1100 }
1101 A[(j, j)] = from_real(real(A[(j, j)]));
1102
1103 L0[j] = w0;
1104 L1[j] = w1;
1105 }
1106}
1107
1108pub fn rank1_update<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1109 if const { T::SIMD_CAPABILITIES.is_simd() } {
1110 if let (Some(A), Some(L0)) = (A.rb_mut().try_as_col_major_mut(), L0.rb_mut().try_as_col_major_mut()) {
1111 rank1_update_simd(A, L0, d);
1112 } else {
1113 rank1_update_fallback(A, L0, d);
1114 }
1115 } else {
1116 rank1_update_fallback(A, L0, d);
1117 }
1118}
1119
/// SIMD kernel of the rank-1 update for column-major operands.
///
/// For every column `j`, computes `w0 = d * L0[j]`, subtracts `L0[i] * conj(w0)` from `A[i, j]`
/// for all `i >= j`, drops the imaginary part of `A[j, j]` and finally overwrites `L0[j]` with
/// `w0`.
#[math]
pub fn rank1_update_simd<'a, T: ComplexField>(A: MatMut<'a, T, usize, usize, ContiguousFwd>, L0: ColMut<'a, T, usize, ContiguousFwd>, d: T::Real) {
	// carries the arguments through the SIMD dispatch
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self { mut A, mut L0, d } = self;

			let n = A.nrows();
			for j in 0..n {
				let x0 = copy(L0[j]);
				let w0 = mul_real(x0, d);

				// length of the part of column `j` on and below the diagonal
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the coefficient is negated up front so that the subtraction can be
					// expressed with `mul_add`
					let w0_conj = conj(w0);
					let w0_conj_neg = -w0_conj;
					let w0_splat = simd.splat(&w0_conj_neg);

					// the same update is applied to the three index groups returned by
					// `simd.indices()`
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop any imaginary part from the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// only now may the column be overwritten: entry `j` was still needed above
				L0[j] = w0;
			}
		}
	}
	dispatch!(Impl { A, L0, d }, Impl, T)
}
1182
1183#[math]
1184pub fn rank1_update_fallback<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1185 let n = A.nrows();
1186 for j in 0..n {
1187 let x0 = copy(L0[j]);
1188 let w0 = mul_real(x0, d);
1189
1190 for i in j..n {
1191 A[(i, j)] = A[(i, j)] - L0[i] * conj(w0);
1192 }
1193 A[(j, j)] = from_real(real(A[(j, j)]));
1194 L0[j] = w0;
1195 }
1196}
1197pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: Spec<LbltParams, T>) -> StackReq {
1200 let params = params.config;
1201 let _ = par;
1202 let mut bs = params.blocksize;
1203 if bs < 2 || dim <= bs {
1204 bs = 0;
1205 }
1206 StackReq::new::<usize>(dim).and(temp_mat_scratch::<T>(dim, bs))
1207}
1208
/// Information about the result of the $LBL^\top$ factorization.
#[derive(Copy, Clone, Debug)]
pub struct LbltInfo {
	/// Number of positions whose recorded pivot index differs from the position itself, i.e. the
	/// number of non-trivial exchanges applied while building the permutation.
	pub transposition_count: usize,
}
1215
1216#[track_caller]
1230#[math]
1231pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
1232 A: MatMut<'_, T>,
1233 subdiag: DiagMut<'_, T>,
1234 perm: &'out mut [I],
1235 perm_inv: &'out mut [I],
1236 par: Par,
1237 stack: &mut MemStack,
1238 params: Spec<LbltParams, T>,
1239) -> (LbltInfo, PermRef<'out, I>) {
1240 let params = params.config;
1241
1242 let truncate = <I::Signed as SignedIndex>::truncate;
1243
1244 let n = A.nrows();
1245 assert!(all(A.nrows() == A.ncols(), subdiag.dim() == n, perm.len() == n, perm_inv.len() == n));
1246
1247 #[cfg(feature = "perf-warn")]
1248 if A.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) {
1249 if A.col_stride().unsigned_abs() == 1 {
1250 log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1251 matrix. Found row-major matrix.");
1252 } else {
1253 log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1254 matrix. Found matrix with generic strides.");
1255 }
1256 }
1257
1258 let (mut pivots, stack) = stack.make_with::<usize>(n, |_| 0);
1259 let pivots = &mut *pivots;
1260
1261 let mut bs = params.blocksize;
1262 if bs < 2 || n <= bs {
1263 bs = 0;
1264 }
1265
1266 let (rook, diagonal) = match params.pivoting {
1267 PivotingStrategy::Partial => (false, false),
1268 PivotingStrategy::PartialDiag => (false, true),
1269 PivotingStrategy::Rook => (true, false),
1270 PivotingStrategy::RookDiag => (true, true),
1271 _ => (false, false),
1272 };
1273
1274 if params.pivoting == PivotingStrategy::Full {
1275 lblt_full_piv(A, subdiag, pivots, par, params);
1276 } else {
1277 lblt_blocked(A, subdiag, pivots, bs, rook, diagonal, par, stack);
1278 }
1279
1280 for (i, p) in perm.iter_mut().enumerate() {
1281 *p = I::from_signed(truncate(i));
1282 }
1283
1284 let mut transposition_count = 0usize;
1285 for i in 0..n {
1286 let mut p = pivots[i];
1287 if (p as isize) < 0 {
1288 p = !p;
1289 }
1290 if i != p {
1291 transposition_count += 1;
1292 }
1293 perm.swap(i, p);
1294 }
1295 for (i, &p) in perm.iter().enumerate() {
1296 perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
1297 }
1298
1299 (LbltInfo { transposition_count }, unsafe { PermRef::new_unchecked(perm, perm_inv, n) })
1300}