faer/linalg/cholesky/bunch_kaufman/
factor.rs

1use crate::internal_prelude::*;
2use crate::{assert, perm};
3use linalg::matmul::triangular::BlockStructure;
4
/// pivoting strategy for choosing the pivots
///
/// [`cholesky_in_place`] dispatches on this value: [`PivotingStrategy::Full`] selects the full
/// pivoting routine, while every other strategy selects the partial/rook routine with the
/// matching search flags
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum PivotingStrategy {
	/// deprecated, corresponds to partial pivoting
	#[deprecated]
	Diagonal,

	/// searches for the k-th pivot in the k-th column
	Partial,
	/// searches for the k-th pivot in the k-th column, as well as the tail of the diagonal of the
	/// matrix
	PartialDiag,
	/// searches for pivots that are locally optimal
	Rook,
	/// searches for pivots that are locally optimal, as well as the tail of the diagonal of the
	/// matrix
	RookDiag,

	/// searches for pivots that are globally optimal
	Full,
}
27
/// tuning parameters for the decomposition
#[derive(Copy, Clone, Debug)]
pub struct LbltParams {
	/// pivoting strategy
	pub pivoting: PivotingStrategy,
	/// block size of the algorithm
	///
	/// a value smaller than 2, or one that is not smaller than the matrix dimension, makes
	/// [`cholesky_in_place`] use the unblocked algorithm
	pub blocksize: usize,

	/// threshold at which size parallelism should be disabled
	///
	/// in the full pivoting routine, the updates run sequentially once the squared dimension of
	/// the trailing submatrix drops below this value
	pub par_threshold: usize,

	#[doc(hidden)]
	pub non_exhaustive: NonExhaustive,
}
42
43#[math]
44fn swap_self_adjoint<T: ComplexField>(A: MatMut<'_, T>, i: usize, j: usize) {
45	assert_ne!(i, j);
46
47	let mut A = A;
48	let (i, j) = (Ord::min(i, j), Ord::max(i, j));
49
50	perm::swap_cols_idx(A.rb_mut().get_mut(j + 1.., ..), i, j);
51	perm::swap_rows_idx(A.rb_mut().get_mut(.., ..i), i, j);
52
53	let tmp = real(A[(i, i)]);
54	A[(i, i)] = from_real(real(A[(j, j)]));
55	A[(j, j)] = from_real(tmp);
56
57	A[(j, i)] = conj(A[(j, i)]);
58
59	let (Ai, Aj) = A.split_at_row_mut(j);
60	let Ai = Ai.get_mut(i + 1..j, i);
61	let Aj = Aj.get_mut(0, i + 1..j).transpose_mut();
62	zip!(Ai, Aj).for_each(|unzip!(x, y)| {
63		let tmp = conj(*x);
64		*x = conj(*y);
65		*y = tmp;
66	});
67}
68
69#[math]
70fn rank_1_update_and_argmax_fallback<'M, 'N, T: ComplexField>(
71	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
72	L: ColRef<'_, T, Dim<'N>>,
73	d: T::Real,
74	start: IdxInc<'N>,
75	end: IdxInc<'N>,
76) -> (usize, usize, T::Real) {
77	let mut A = A;
78	let n = A.nrows();
79
80	let mut max_j = n.idx(0);
81	let mut max_i = n.idx(0);
82	let mut max_offdiag = zero();
83
84	for j in start.to(end) {
85		for i in j.next().to(n.end()) {
86			A[(i, j)] = A[(i, j)] - mul_real(L[i] * conj(L[j]), d);
87			let val = abs2(A[(i, j)]);
88			if val > max_offdiag {
89				max_offdiag = val;
90				max_i = i;
91				max_j = j;
92			}
93		}
94	}
95
96	(*max_i, *max_j, max_offdiag)
97}
98
99#[math]
100fn rank_2_update_and_argmax_fallback<'N, T: ComplexField>(
101	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
102	L0: ColRef<'_, T, Dim<'N>>,
103	L1: ColRef<'_, T, Dim<'N>>,
104	d: T::Real,
105	d00: T::Real,
106	d11: T::Real,
107	d10: T,
108	start: IdxInc<'N>,
109	end: IdxInc<'N>,
110) -> (usize, usize, T::Real) {
111	let mut A = A;
112	let n = A.nrows();
113
114	let mut max_j = n.idx(0);
115	let mut max_i = n.idx(0);
116	let mut max_offdiag = zero();
117
118	for j in start.to(end) {
119		let x0 = copy(L0[j]);
120		let x1 = copy(L1[j]);
121
122		let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
123		let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
124
125		for i in j.next().to(n.end()) {
126			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
127
128			let val = abs2(A[(i, j)]);
129			if val > max_offdiag {
130				max_offdiag = val;
131				max_i = i;
132				max_j = j;
133			}
134		}
135	}
136	(*max_i, *max_j, max_offdiag)
137}
138
/// sequential rank-1 update and argmax over the columns `start..end` of `A`
///
/// currently forwards to [`rank_1_update_and_argmax_fallback`] with the same arguments
#[math]
fn rank_1_update_and_argmax_seq<'M, 'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_1_update_and_argmax_fallback(A, L, d, start, end)
}
149
/// sequential rank-2 update and argmax over the columns `start..end` of `A`
///
/// currently forwards to [`rank_2_update_and_argmax_fallback`] with the same arguments
#[math]
fn rank_2_update_and_argmax_seq<'N, T: ComplexField>(
	A: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	L0: ColRef<'_, T, Dim<'N>>,
	L1: ColRef<'_, T, Dim<'N>>,
	d: T::Real,
	d00: T::Real,
	d11: T::Real,
	d10: T,
	start: IdxInc<'N>,
	end: IdxInc<'N>,
) -> (usize, usize, T::Real) {
	rank_2_update_and_argmax_fallback(A, L0, L1, d, d00, d11, d10, start, end)
}
164
165#[math]
166fn rank_1_update_and_argmax<T: ComplexField>(A: MatMut<'_, T>, L: ColRef<'_, T>, d: T::Real, par: Par) -> (usize, usize, T::Real) {
167	with_dim!(N, A.nrows());
168
169	match par {
170		Par::Seq => rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), d, IdxInc::ZERO, N.end()),
171		#[cfg(feature = "rayon")]
172		Par::Rayon(nthreads) => {
173			use rayon::prelude::*;
174			let nthreads = nthreads.get();
175			let n = *N;
176
177			// to check that integers can be represented exactly as floats
178			assert!((n as u64) < (1u64 << 50));
179
180			let idx_to_col_start = |idx: usize| {
181				let idx_as_percent = idx as f64 / nthreads as f64;
182				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
183				(col_start_percent * n as f64) as usize
184			};
185
186			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];
187
188			r.par_iter_mut().enumerate().for_each(|(idx, out)| {
189				let A = unsafe { A.rb().const_cast() };
190				let start = N.idx_inc(idx_to_col_start(idx));
191				let end = N.idx_inc(idx_to_col_start(idx + 1));
192
193				*out = rank_1_update_and_argmax_seq(A.as_shape_mut(N, N), L.as_row_shape(N), copy(d), start, end);
194			});
195
196			r.into_iter()
197				.max_by(|(_, _, a), (_, _, b)| {
198					if a == b {
199						core::cmp::Ordering::Equal
200					} else if a > b {
201						core::cmp::Ordering::Greater
202					} else {
203						core::cmp::Ordering::Less
204					}
205				})
206				.unwrap()
207		},
208	}
209}
210
211#[math]
212fn rank_2_update_and_argmax<'N, T: ComplexField>(
213	A: MatMut<'_, T>,
214	L0: ColRef<'_, T>,
215	L1: ColRef<'_, T>,
216	d: T::Real,
217	d00: T::Real,
218	d11: T::Real,
219	d10: T,
220	par: Par,
221) -> (usize, usize, T::Real) {
222	with_dim!(N, A.nrows());
223
224	match par {
225		Par::Seq => rank_2_update_and_argmax_seq(
226			A.as_shape_mut(N, N),
227			L0.as_row_shape(N),
228			L1.as_row_shape(N),
229			d,
230			d00,
231			d11,
232			d10,
233			IdxInc::ZERO,
234			N.end(),
235		),
236		#[cfg(feature = "rayon")]
237		Par::Rayon(nthreads) => {
238			use rayon::prelude::*;
239			let nthreads = nthreads.get();
240			let n = *N;
241
242			// to check that integers can be represented exactly as floats
243			assert!((n as u64) < (1u64 << 50));
244
245			let idx_to_col_start = |idx: usize| {
246				let idx_as_percent = idx as f64 / nthreads as f64;
247				let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
248				(col_start_percent * n as f64) as usize
249			};
250
251			let mut r = alloc::vec![(0usize, 0usize, zero::<T::Real>()); nthreads];
252
253			r.par_iter_mut().enumerate().for_each(|(idx, out)| {
254				let A = unsafe { A.rb().const_cast() };
255				let start = N.idx_inc(idx_to_col_start(idx));
256				let end = N.idx_inc(idx_to_col_start(idx + 1));
257
258				*out = rank_2_update_and_argmax_seq(
259					A.as_shape_mut(N, N),
260					L0.as_row_shape(N),
261					L1.as_row_shape(N),
262					copy(d),
263					copy(d00),
264					copy(d11),
265					copy(d10),
266					start,
267					end,
268				);
269			});
270
271			r.into_iter()
272				.max_by(|(_, _, a), (_, _, b)| {
273					if a == b {
274						core::cmp::Ordering::Equal
275					} else if a < b {
276						core::cmp::Ordering::Less
277					} else {
278						core::cmp::Ordering::Greater
279					}
280				})
281				.unwrap()
282		},
283	}
284}
285
286#[math]
287fn lblt_full_piv<T: ComplexField>(A: MatMut<'_, T>, subdiag: DiagMut<'_, T>, pivots: &mut [usize], par: Par, params: LbltParams) {
288	let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
289	let alpha = alpha * alpha;
290
291	let mut A = A;
292	let mut subdiag = subdiag.column_vector_mut();
293	let mut par = par;
294	let n = A.nrows();
295
296	let scale_fwd = A.norm_max();
297	let scale_bwd = recip(scale_fwd);
298	zip!(A.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_bwd));
299
300	let mut max_i = 0;
301	let mut max_j = 0;
302	let mut max_offdiag = zero();
303
304	for j in 0..n {
305		for i in j + 1..n {
306			let val = abs2(A[(i, j)]);
307			if val > max_offdiag {
308				max_offdiag = val;
309				max_i = i;
310				max_j = j;
311			}
312		}
313	}
314
315	let mut k = 0;
316	while k < n {
317		if max_offdiag == zero() {
318			break;
319		}
320
321		let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
322		let mut subdiag = subdiag.rb_mut().get_mut(k..);
323		let pivots = &mut pivots[k..];
324
325		let n = A.nrows();
326		let mut max_s = 0;
327		let mut max_diag = zero();
328
329		for s in 0..n {
330			let val = abs2(A[(s, s)]);
331			if val > max_diag {
332				max_diag = val;
333				max_s = s;
334			}
335		}
336
337		let npiv;
338		let i0;
339		let i1;
340
341		if max_diag >= alpha * max_offdiag {
342			npiv = 1;
343			i0 = max_s;
344			i1 = usize::MAX;
345		} else {
346			npiv = 2;
347			i0 = max_j;
348			i1 = max_i;
349		}
350
351		let rem = n - npiv;
352		if rem * rem < params.par_threshold {
353			par = Par::Seq;
354		}
355
356		// swap pivots to first (and second) column
357		if i0 != 0 {
358			swap_self_adjoint(A.rb_mut(), 0, i0);
359			perm::swap_rows_idx(Aprev.rb_mut(), 0, i0);
360		}
361		if npiv == 2 && i1 != 1 {
362			swap_self_adjoint(A.rb_mut(), 1, i1);
363			perm::swap_rows_idx(Aprev.rb_mut(), 1, i1);
364		}
365
366		if npiv == 1 {
367			let diag = real(A[(0, 0)]);
368			let diag_inv = recip(diag);
369			subdiag[0] = zero();
370
371			let (_, _, L, mut A) = A.rb_mut().split_at_mut(1, 1);
372			let n = A.nrows();
373			let mut L = L.col_mut(0);
374
375			zip!(L.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, diag_inv));
376
377			for i in 0..n {
378				A[(i, i)] = from_real(real(A[(i, i)]) - diag * abs2(L[i]));
379			}
380
381			if n < params.par_threshold {}
382			if n != 0 {
383				(max_i, max_j, max_offdiag) = rank_1_update_and_argmax(A.rb_mut(), L.rb(), diag, par);
384			}
385		} else {
386			let a00 = real(A[(0, 0)]);
387			let a11 = real(A[(1, 1)]);
388			let a10 = copy(A[(1, 0)]);
389
390			subdiag[0] = copy(a10);
391			subdiag[1] = zero();
392			A[(1, 0)] = zero();
393
394			let d10 = abs(a10);
395			let d10_inv = recip(d10);
396			let d00 = a00 * d10_inv;
397			let d11 = a11 * d10_inv;
398
399			// t = (d00/|d10| * d11/|d10| - 1.0)
400			let t = recip(d00 * d11 - one());
401			let d10 = mul_real(a10, d10_inv);
402			let d = t * d10_inv;
403
404			//         [ a00  a01 ]
405			// L_new * [ a10  a11 ] = L
406			let (_, _, L, mut A) = A.rb_mut().split_at_mut(2, 2);
407			let (mut L0, mut L1) = L.two_cols_mut(0, 1);
408			let n = A.nrows();
409
410			if n != 0 {
411				(max_i, max_j, max_offdiag) = rank_2_update_and_argmax(A.rb_mut(), L0.rb(), L1.rb(), copy(d), copy(d00), copy(d11), copy(d10), par);
412			}
413
414			for j in 0..n {
415				let x0 = copy(L0[j]);
416				let x1 = copy(L1[j]);
417
418				let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
419				let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
420
421				A[(j, j)] = from_real(real(A[(j, j)] - L0[j] * conj(w0) - L1[j] * conj(w1)));
422
423				L0[j] = w0;
424				L1[j] = w1;
425			}
426		}
427
428		if npiv == 2 {
429			pivots[0] = !(i0 + k);
430			pivots[1] = !(i1 + k);
431		} else {
432			pivots[0] = i0 + k;
433		}
434		k += npiv;
435	}
436
437	while k < n {
438		let (mut Aprev, mut A) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
439		let mut subdiag = subdiag.rb_mut().get_mut(k..);
440		let pivots = &mut pivots[k..];
441
442		let n = A.nrows();
443		let mut max_s = 0;
444		let mut max_diag = zero();
445
446		for s in 0..n {
447			let val = abs2(A[(s, s)]);
448			if val > max_diag {
449				max_diag = val;
450				max_s = s;
451			}
452		}
453
454		if max_s != 0 {
455			let (mut A0, mut As) = A.rb_mut().two_cols_mut(0, max_s);
456			core::mem::swap(&mut A0[0], &mut As[max_s]);
457
458			perm::swap_rows_idx(Aprev.rb_mut(), 0, max_s);
459		}
460
461		subdiag[0] = zero();
462		pivots[0] = max_s + k;
463
464		k += 1;
465	}
466
467	zip!(A.rb_mut().diagonal_mut().column_vector_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
468	zip!(subdiag.rb_mut()).for_each(|unzip!(x)| *x = mul_real(*x, scale_fwd));
469}
470
471#[math]
472#[track_caller]
473fn l1_argmax<T: ComplexField>(col: ColRef<'_, T>) -> (Option<usize>, T::Real) {
474	let n = col.nrows();
475	if n == 0 {
476		return (None, zero());
477	}
478
479	let mut i = 0;
480	let mut best = zero();
481
482	for j in 0..n {
483		let val = abs1(col[j]);
484		if val > best {
485			best = val;
486			i = j;
487		}
488	}
489
490	(Some(i), best)
491}
492
493#[math]
494#[track_caller]
495fn offdiag_argmax<T: ComplexField>(A: MatRef<'_, T>, idx: usize) -> (Option<usize>, T::Real) {
496	let (mut col_argmax, col_max) = l1_argmax(A.rb().get(idx + 1.., idx));
497	col_argmax.as_mut().map(|col_argmax| *col_argmax += idx + 1);
498	let (row_argmax, row_max) = l1_argmax(A.rb().get(idx, ..idx).transpose());
499
500	if col_max > row_max {
501		(col_argmax, col_max)
502	} else {
503		(row_argmax, row_max)
504	}
505}
506
507#[math]
508fn update_and_offdiag_argmax<T: ComplexField>(
509	mut dst: ColMut<'_, T>,
510	Wl: MatRef<'_, T>,
511	Al: MatRef<'_, T>,
512	Ar: MatRef<'_, T>,
513	i0: usize,
514	par: Par,
515) -> (Option<usize>, T::Real) {
516	let n = Al.nrows();
517	for j in 0..i0 {
518		dst[j] = conj(Ar[(i0, j)]);
519	}
520	dst[i0] = zero();
521	for j in i0 + 1..n {
522		dst[j] = copy(Ar[(j, i0)]);
523	}
524
525	linalg::matmul::matmul(dst.rb_mut(), Accum::Add, Al.rb(), Wl.row(i0).adjoint(), -one::<T>(), par);
526	dst[i0] = zero();
527
528	let ret = l1_argmax(dst.rb());
529	dst[i0] = from_real(real(Ar[(i0, i0)]));
530	if n == 1 { (None, zero()) } else { ret }
531}
532
533#[math]
534fn lblt_blocked_step<T: ComplexField>(
535	alpha: T::Real,
536	W: MatMut<'_, T>,
537	A_left: MatMut<'_, T>,
538	A: MatMut<'_, T>,
539	subdiag: DiagMut<'_, T>,
540	pivots: &mut [usize],
541	rook: bool,
542	diagonal: bool,
543	par: Par,
544) -> usize {
545	let mut A = A;
546	let mut A_left = A_left;
547	let mut subdiag = subdiag;
548	let mut W = W;
549
550	let n = A.nrows();
551	let blocksize = W.ncols();
552
553	assert!(all(A.nrows() == n, A.ncols() == n, W.nrows() == n, subdiag.dim() == n, blocksize >= 2,));
554
555	let kmax = Ord::min(blocksize - 1, n);
556	let mut k = 0usize;
557	while k < kmax {
558		let mut A = A.rb_mut();
559		let mut W = W.rb_mut();
560		let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
561		let mut A_left = A_left.rb_mut().get_mut(k.., ..);
562
563		let (mut Wl, mut Wr) = W.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
564		let (mut Al, mut Ar) = A.rb_mut().get_mut(k.., ..).split_at_col_mut(k);
565		let mut Al = Al.rb_mut();
566		let mut Wr = Wr.rb_mut().get_mut(.., ..2);
567
568		let npiv;
569		let mut i0 = if diagonal {
570			l1_argmax(Ar.rb().diagonal().column_vector()).0.unwrap()
571		} else {
572			0
573		};
574		let mut i1 = usize::MAX;
575
576		let mut nothing_to_do = false;
577
578		let (mut Wr0, mut Wr1) = Wr.rb_mut().two_cols_mut(0, 1);
579
580		let (r, mut gamma_i) = update_and_offdiag_argmax(Wr0.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i0, par);
581
582		if k + 1 == n || gamma_i == zero() {
583			nothing_to_do = true;
584			npiv = 1;
585		} else if abs(real(Ar[(i0, i0)])) >= alpha * gamma_i {
586			npiv = 1;
587		} else {
588			i1 = r.unwrap();
589			if rook {
590				loop {
591					let (s, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
592
593					if abs1(Ar[(i1, i1)]) >= alpha * gamma_r {
594						npiv = 1;
595						i0 = i1;
596						i1 = usize::MAX;
597						Wr0.copy_from(&Wr1);
598						break;
599					} else if s == Some(i0) || gamma_i == gamma_r {
600						npiv = 2;
601						break;
602					} else {
603						i0 = i1;
604						i1 = s.unwrap();
605						gamma_i = gamma_r;
606						Wr0.copy_from(&Wr1);
607					}
608				}
609			} else {
610				let (_, gamma_r) = update_and_offdiag_argmax(Wr1.rb_mut(), Wl.rb(), Al.rb(), Ar.rb(), i1, par);
611
612				if abs(real(Ar[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
613					npiv = 1;
614				} else if abs(real(Ar[(i1, i1)])) >= alpha * gamma_r {
615					npiv = 1;
616					i0 = i1;
617					i1 = usize::MAX;
618					Wr0.copy_from(&Wr1);
619				} else {
620					npiv = 2;
621				}
622			}
623		}
624
625		if npiv == 2 && i0 > i1 {
626			perm::swap_cols_idx(Wr.rb_mut(), 0, 1);
627			(i0, i1) = (i1, i0);
628		}
629
630		let mut Wr = Wr.rb_mut().get_mut(.., ..npiv);
631
632		'next_iter: {
633			// swap pivots to first (and second) column
634			if i0 != 0 {
635				swap_self_adjoint(Ar.rb_mut(), 0, i0);
636				perm::swap_rows_idx(Al.rb_mut(), 0, i0);
637				perm::swap_rows_idx(A_left.rb_mut(), 0, i0);
638				perm::swap_rows_idx(Wl.rb_mut(), 0, i0);
639				perm::swap_rows_idx(Wr.rb_mut(), 0, i0);
640			}
641			if npiv == 2 && i1 != 1 {
642				swap_self_adjoint(Ar.rb_mut(), 1, i1);
643				perm::swap_rows_idx(Al.rb_mut(), 1, i1);
644				perm::swap_rows_idx(A_left.rb_mut(), 1, i1);
645				perm::swap_rows_idx(Wl.rb_mut(), 1, i1);
646				perm::swap_rows_idx(Wr.rb_mut(), 1, i1);
647			}
648
649			if nothing_to_do {
650				break 'next_iter;
651			}
652
653			if npiv == 1 {
654				let W0 = Wr.rb_mut().col_mut(0);
655
656				let diag = real(W0[0]);
657				let diag_inv = recip(diag);
658				subdiag[0] = zero();
659
660				let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(1, 1);
661				let W0 = W0.rb().get(1..);
662				let n = A.nrows();
663
664				let mut L = L.col_mut(0);
665				zip!(W0, L.rb_mut()).for_each(|unzip!(w, a)| *a = mul_real(*w, diag_inv));
666
667				for j in 0..n {
668					A[(j, j)] = from_real(real(A[(j, j)]) - diag * abs2(L[j]));
669				}
670			} else {
671				let a00 = real(Wr[(0, 0)]);
672				let a11 = real(Wr[(1, 1)]);
673				let a10 = copy(Wr[(1, 0)]);
674
675				subdiag[0] = copy(a10);
676				subdiag[1] = zero();
677				Wr[(1, 0)] = zero();
678				Ar[(1, 0)] = zero();
679
680				let d10 = abs(a10);
681				let d10_inv = recip(d10);
682				let d00 = a00 * d10_inv;
683				let d11 = a11 * d10_inv;
684
685				// t = (d00/|d10| * d11/|d10| - 1.0)
686				let t = recip(d00 * d11 - one());
687				let d10 = mul_real(a10, d10_inv);
688				let d = t * d10_inv;
689
690				//         [ a00  a01 ]
691				// L_new * [ a10  a11 ] = L
692				let (_, _, L, mut A) = Ar.rb_mut().split_at_mut(2, 2);
693				let (mut L0, mut L1) = L.two_cols_mut(0, 1);
694				let Wr = Wr.rb().get(2.., ..);
695				let W0 = Wr.col(0);
696				let W1 = Wr.col(1);
697
698				let n = A.nrows();
699				for j in 0..n {
700					let x0 = copy(W0[j]);
701					let x1 = copy(W1[j]);
702
703					let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
704					let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
705
706					A[(j, j)] = from_real(real(A[(j, j)] - W0[j] * conj(w0) - W1[j] * conj(w1)));
707
708					L0[j] = w0;
709					L1[j] = w1;
710				}
711			}
712		}
713
714		let offset = A_left.ncols();
715
716		if npiv == 2 {
717			pivots[k] = !(offset + i0 + k);
718			pivots[k + 1] = !(offset + i1 + k);
719		} else {
720			pivots[k] = offset + i0 + k;
721		}
722		k += npiv;
723	}
724
725	let W = W.rb().get(k.., ..k);
726	let (_, _, Al, mut Ar) = A.rb_mut().split_at_mut(k, k);
727	let Al = Al.rb();
728
729	linalg::matmul::triangular::matmul(
730		Ar.rb_mut(),
731		BlockStructure::StrictTriangularLower,
732		Accum::Add,
733		W,
734		BlockStructure::Rectangular,
735		Al.adjoint(),
736		BlockStructure::Rectangular,
737		-one::<T>(),
738		par,
739	);
740
741	for j in 0..n - k {
742		Ar[(j, j)] = from_real(real(Ar[(j, j)]));
743	}
744
745	k
746}
747
748#[math]
749fn lblt_blocked<T: ComplexField>(
750	A: MatMut<'_, T>,
751	subdiag: DiagMut<'_, T>,
752	pivots: &mut [usize],
753	blocksize: usize,
754	rook: bool,
755	diagonal: bool,
756	par: Par,
757	stack: &mut MemStack,
758) {
759	let alpha = (one::<T::Real>() + sqrt(from_f64::<T::Real>(17.0))) * from_f64::<T::Real>(0.125);
760
761	let mut A = A;
762	let mut subdiag = subdiag.column_vector_mut();
763	let n = A.nrows();
764
765	let mut k = 0;
766	while k < n {
767		let (_, _, A_left, A) = A.rb_mut().split_at_mut(k, k);
768		let (mut W, _) = unsafe { temp_mat_uninit::<T, _, _>(n - k, blocksize, stack) };
769		let W = W.as_mat_mut();
770
771		if blocksize < 2 || n - k <= blocksize {
772			lblt_unblocked(
773				copy(alpha),
774				A_left,
775				A,
776				subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
777				&mut pivots[k..],
778				rook,
779				diagonal,
780				par,
781			);
782
783			k = n;
784		} else {
785			let blocksize = lblt_blocked_step(
786				copy(alpha),
787				W,
788				A_left,
789				A,
790				subdiag.rb_mut().get_mut(k..).as_diagonal_mut(),
791				&mut pivots[k..],
792				rook,
793				diagonal,
794				par,
795			);
796
797			k += blocksize;
798		}
799	}
800}
801
802#[math]
803fn lblt_unblocked<T: ComplexField>(
804	alpha: T::Real,
805	A_left: MatMut<'_, T>,
806	A: MatMut<'_, T>,
807	subdiag: DiagMut<'_, T>,
808	pivots: &mut [usize],
809	rook: bool,
810	diagonal: bool,
811	par: Par,
812) {
813	let _ = par;
814	let mut A = A;
815	let mut A_left = A_left;
816	let mut subdiag = subdiag;
817
818	let n = A.nrows();
819	assert!(all(A.nrows() == n, A.ncols() == n, subdiag.dim() == n));
820
821	let mut k = 0usize;
822	while k < n {
823		let (_, _, mut L_prev, mut A) = A.rb_mut().split_at_mut(k, k);
824		let mut subdiag = subdiag.rb_mut().column_vector_mut().get_mut(k..);
825		let mut A_left = A_left.rb_mut().get_mut(k.., ..);
826
827		let npiv;
828
829		// find the diagonal pivot candidate, if requested
830		let mut i0 = if diagonal {
831			l1_argmax(A.rb().diagonal().column_vector()).0.unwrap()
832		} else {
833			0
834		};
835		let mut i1 = usize::MAX;
836
837		// find the largest off-diagonal in the pivot's column
838		let (r, mut gamma_i) = offdiag_argmax(A.rb(), i0);
839
840		let mut nothing_to_do = false;
841
842		if k + 1 == n || gamma_i == zero() {
843			nothing_to_do = true;
844			npiv = 1;
845		} else if abs(real(A[(i0, i0)])) >= alpha * gamma_i {
846			npiv = 1;
847		} else {
848			i1 = r.unwrap();
849
850			// pivot search
851			if rook {
852				loop {
853					let (s, gamma_r) = offdiag_argmax(A.rb(), i1);
854
855					if abs1(A[(i1, i1)]) >= alpha * gamma_r {
856						npiv = 1;
857						i0 = i1;
858						i1 = usize::MAX;
859						break;
860					} else if gamma_i == gamma_r {
861						npiv = 2;
862						break;
863					} else {
864						i0 = i1;
865						i1 = s.unwrap();
866						gamma_i = gamma_r;
867					}
868				}
869			} else {
870				let (_, gamma_r) = offdiag_argmax(A.rb(), i1);
871				if abs(real(A[(i0, i0)])) >= (alpha * gamma_r) * (gamma_r / gamma_i) {
872					npiv = 1;
873				} else if abs(real(A[(i1, i1)])) >= alpha * gamma_r {
874					npiv = 1;
875					i0 = i1;
876				} else {
877					npiv = 2;
878				}
879			}
880		}
881
882		if npiv == 2 && i0 > i1 {
883			(i0, i1) = (i1, i0);
884		}
885
886		'next_iter: {
887			// swap pivots to first (and second) column
888			if i0 != 0 {
889				swap_self_adjoint(A.rb_mut(), 0, i0);
890				perm::swap_rows_idx(A_left.rb_mut(), 0, i0);
891				perm::swap_rows_idx(L_prev.rb_mut(), 0, i0);
892			}
893			if npiv == 2 && i1 != 1 {
894				swap_self_adjoint(A.rb_mut(), 1, i1);
895				perm::swap_rows_idx(A_left.rb_mut(), 1, i1);
896				perm::swap_rows_idx(L_prev.rb_mut(), 1, i1);
897			}
898
899			if nothing_to_do {
900				break 'next_iter;
901			}
902
903			// rank downdate
904			if npiv == 1 {
905				let diag = real(A[(0, 0)]);
906				let diag_inv = recip(diag);
907				subdiag[0] = zero();
908
909				let (_, _, L, A) = A.rb_mut().split_at_mut(1, 1);
910				let L = L.col_mut(0);
911				rank1_update(A, L, diag_inv);
912			} else {
913				let a00 = real(A[(0, 0)]);
914				let a11 = real(A[(1, 1)]);
915				let a10 = copy(A[(1, 0)]);
916
917				subdiag[0] = copy(a10);
918				subdiag[1] = zero();
919				A[(1, 0)] = zero();
920
921				let d10 = abs(a10);
922				let d10_inv = recip(d10);
923				let d00 = a00 * d10_inv;
924				let d11 = a11 * d10_inv;
925
926				// t = (d00/|d10| * d11/|d10| - 1.0)
927				let t = recip(d00 * d11 - one());
928				let d10 = mul_real(a10, d10_inv);
929				let d = t * d10_inv;
930
931				//         [ a00  a01 ]
932				// L_new * [ a10  a11 ] = L
933				let (_, _, L, A) = A.rb_mut().split_at_mut(2, 2);
934				let (L0, L1) = L.two_cols_mut(0, 1);
935				rank2_update(A, L0, L1, d, d00, d10, d11);
936			}
937		}
938
939		let offset = A_left.ncols();
940		if npiv == 2 {
941			pivots[k] = !(offset + i0 + k);
942			pivots[k + 1] = !(offset + i1 + k);
943		} else {
944			pivots[k] = offset + i0 + k;
945		}
946		k += npiv;
947	}
948}
949
/// default tuning parameters of the decomposition
impl<T: ComplexField> Auto<T> for LbltParams {
	/// returns the default parameters: partial pivoting with diagonal search, a block size of 64
	/// and a parallelism threshold of `256 * 512`
	fn auto() -> Self {
		Self {
			pivoting: PivotingStrategy::PartialDiag,
			blocksize: 64,
			par_threshold: 256 * 512,
			non_exhaustive: NonExhaustive(()),
		}
	}
}
960
961pub fn rank2_update<'a, T: ComplexField>(
962	mut A: MatMut<'a, T>,
963	mut L0: ColMut<'a, T>,
964	mut L1: ColMut<'a, T>,
965	d: T::Real,
966	d00: T::Real,
967	d10: T,
968	d11: T::Real,
969) {
970	if const { T::SIMD_CAPABILITIES.is_simd() } {
971		if let (Some(A), Some(L0), Some(L1)) = (
972			A.rb_mut().try_as_col_major_mut(),
973			L0.rb_mut().try_as_col_major_mut(),
974			L1.rb_mut().try_as_col_major_mut(),
975		) {
976			rank2_update_simd(A, L0, L1, d, d00, d10, d11);
977		} else {
978			rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
979		}
980	} else {
981		rank2_update_fallback(A, L0, L1, d, d00, d10, d11);
982	}
983}
984
/// simd implementation of the rank-2 update for column-major operands
///
/// for every column `j`, the entries `j..` of column `j` of `A` are updated with
/// `-L0[i] * conj(w0) - L1[i] * conj(w1)`, the imaginary part of the diagonal entry is dropped,
/// and `L0[j]`, `L1[j]` are overwritten with `w0`, `w1`
#[math]
pub fn rank2_update_simd<'a, T: ComplexField>(
	A: MatMut<'a, T, usize, usize, ContiguousFwd>,
	L0: ColMut<'a, T, usize, ContiguousFwd>,
	L1: ColMut<'a, T, usize, ContiguousFwd>,
	d: T::Real,
	d00: T::Real,
	d10: T,
	d11: T::Real,
) {
	// bundles the arguments so that the kernel can be dispatched through `pulp::WithSimd`
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		L1: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
		d00: T::Real,
		d10: T,
		d11: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self {
				mut A,
				mut L0,
				mut L1,
				d,
				d00,
				d10,
				d11,
			} = self;
			let n = A.nrows();
			for j in 0..n {
				let x0 = copy(L0[j]);
				let x1 = copy(L1[j]);
				let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
				let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);

				// only the rows `j..` of column `j` are updated
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let L1 = L1.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the factors are negated up front so that the update is a pair of
					// multiply-adds
					let w0_conj = conj(w0);
					let w1_conj = conj(w1);
					let w0_conj_neg = -w0_conj;
					let w1_conj_neg = -w1_conj;
					let w0_splat = simd.splat(&w0_conj_neg);
					let w1_splat = simd.splat(&w1_conj_neg);

					// the same update is applied to each index group returned by
					// `simd.indices()`
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						let l1_val = simd.read(L1, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						acc = simd.mul_add(l1_val, w1_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop the imaginary part of the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// the pivot columns are overwritten only after column `j` has been updated
				L0[j] = w0;
				L1[j] = w1;
			}
		}
	}
	dispatch!(Impl { A, L0, L1, d, d00, d10, d11 }, Impl, T)
}
1079
1080#[math]
1081pub fn rank2_update_fallback<'a, T: ComplexField>(
1082	mut A: MatMut<'a, T>,
1083	mut L0: ColMut<'a, T>,
1084	mut L1: ColMut<'a, T>,
1085	d: T::Real,
1086	d00: T::Real,
1087	d10: T,
1088	d11: T::Real,
1089) {
1090	let n = A.nrows();
1091	for j in 0..n {
1092		let x0 = copy(L0[j]);
1093		let x1 = copy(L1[j]);
1094
1095		let w0 = mul_real(mul_real(x0, d11) - x1 * d10, d);
1096		let w1 = mul_real(mul_real(x1, d00) - x0 * conj(d10), d);
1097
1098		for i in j..n {
1099			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0) - L1[i] * conj(w1);
1100		}
1101		A[(j, j)] = from_real(real(A[(j, j)]));
1102
1103		L0[j] = w0;
1104		L1[j] = w1;
1105	}
1106}
1107
1108pub fn rank1_update<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1109	if const { T::SIMD_CAPABILITIES.is_simd() } {
1110		if let (Some(A), Some(L0)) = (A.rb_mut().try_as_col_major_mut(), L0.rb_mut().try_as_col_major_mut()) {
1111			rank1_update_simd(A, L0, d);
1112		} else {
1113			rank1_update_fallback(A, L0, d);
1114		}
1115	} else {
1116		rank1_update_fallback(A, L0, d);
1117	}
1118}
1119
/// simd implementation of the rank-1 update for column-major operands
///
/// for every column `j`, the entries `j..` of column `j` of `A` are updated with
/// `-L0[i] * conj(w0)` where `w0 = d * L0[j]`, the imaginary part of the diagonal entry is
/// dropped, and `L0[j]` is overwritten with `w0`
#[math]
pub fn rank1_update_simd<'a, T: ComplexField>(A: MatMut<'a, T, usize, usize, ContiguousFwd>, L0: ColMut<'a, T, usize, ContiguousFwd>, d: T::Real) {
	// bundles the arguments so that the kernel can be dispatched through `pulp::WithSimd`
	struct Impl<'a, T: ComplexField> {
		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
		L0: ColMut<'a, T, usize, ContiguousFwd>,
		d: T::Real,
	}

	impl<T: ComplexField> pulp::WithSimd for Impl<'_, T> {
		type Output = ();

		#[inline(always)]
		fn with_simd<S: pulp::Simd>(self, simd: S) {
			let Self { mut A, mut L0, d } = self;

			let n = A.nrows();
			for j in 0..n {
				let x0 = copy(L0[j]);
				let w0 = mul_real(x0, d);

				// only the rows `j..` of column `j` are updated
				with_dim!({
					let subrange_len = n - j;
				});
				{
					let mut A = A.rb_mut().get_mut(j.., j).as_row_shape_mut(subrange_len);
					let L0 = L0.rb().get(j..).as_row_shape(subrange_len);
					let simd = SimdCtx::<T, S>::new(T::simd_ctx(simd), subrange_len);
					let (head, body, tail) = simd.indices();

					// the factor is negated up front so that the update is a multiply-add
					let w0_conj = conj(w0);
					let w0_conj_neg = -w0_conj;
					let w0_splat = simd.splat(&w0_conj_neg);

					// the same update is applied to each index group returned by
					// `simd.indices()`
					if let Some(i) = head {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					for i in body.clone() {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}

					if let Some(i) = tail {
						let mut acc = simd.read(A.rb(), i);
						let l0_val = simd.read(L0, i);
						acc = simd.mul_add(l0_val, w0_splat, acc);
						simd.write(A.rb_mut(), i, acc);
					}
				}
				// drop the imaginary part of the diagonal
				A[(j, j)] = from_real(real(A[(j, j)]));

				// the pivot column is overwritten only after column `j` has been updated
				L0[j] = w0;
			}
		}
	}
	dispatch!(Impl { A, L0, d }, Impl, T)
}
1182
1183#[math]
1184pub fn rank1_update_fallback<'a, T: ComplexField>(mut A: MatMut<'a, T>, mut L0: ColMut<'a, T>, d: T::Real) {
1185	let n = A.nrows();
1186	for j in 0..n {
1187		let x0 = copy(L0[j]);
1188		let w0 = mul_real(x0, d);
1189
1190		for i in j..n {
1191			A[(i, j)] = A[(i, j)] - L0[i] * conj(w0);
1192		}
1193		A[(j, j)] = from_real(real(A[(j, j)]));
1194		L0[j] = w0;
1195	}
1196}
1197/// computes the size and alignment of required workspace for performing an $LBL^\top$
1198/// decomposition
1199pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: Spec<LbltParams, T>) -> StackReq {
1200	let params = params.config;
1201	let _ = par;
1202	let mut bs = params.blocksize;
1203	if bs < 2 || dim <= bs {
1204		bs = 0;
1205	}
1206	StackReq::new::<usize>(dim).and(temp_mat_scratch::<T>(dim, bs))
1207}
1208
/// info about the result of the $LBL^\top$ factorization
#[derive(Copy, Clone, Debug)]
pub struct LbltInfo {
	/// number of pivoting transpositions
	///
	/// counts the indices whose recorded pivot row differs from the index itself
	pub transposition_count: usize,
}
1215
1216/// computes the $LBL^\top$ factorization of $A$ and stores the factorization in `matrix` and
1217/// `subdiag`
1218///
1219/// the diagonal of the block diagonal matrix is stored on the diagonal
1220/// of `matrix`, while the subdiagonal elements of the blocks are stored in `subdiag`
1221///
1222/// # panics
1223///
1224/// panics if the input matrix is not square
1225///
1226/// this can also panic if the provided memory in `stack` is insufficient (see
1227/// [`cholesky_in_place_scratch`]).
1228
1229#[track_caller]
1230#[math]
1231pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
1232	A: MatMut<'_, T>,
1233	subdiag: DiagMut<'_, T>,
1234	perm: &'out mut [I],
1235	perm_inv: &'out mut [I],
1236	par: Par,
1237	stack: &mut MemStack,
1238	params: Spec<LbltParams, T>,
1239) -> (LbltInfo, PermRef<'out, I>) {
1240	let params = params.config;
1241
1242	let truncate = <I::Signed as SignedIndex>::truncate;
1243
1244	let n = A.nrows();
1245	assert!(all(A.nrows() == A.ncols(), subdiag.dim() == n, perm.len() == n, perm_inv.len() == n));
1246
1247	#[cfg(feature = "perf-warn")]
1248	if A.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) {
1249		if A.col_stride().unsigned_abs() == 1 {
1250			log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1251    matrix. Found row-major matrix.");
1252		} else {
1253			log::warn!(target: "faer_perf", "$LBL^\top$ decomposition prefers column-major
1254    matrix. Found matrix with generic strides.");
1255		}
1256	}
1257
1258	let (mut pivots, stack) = stack.make_with::<usize>(n, |_| 0);
1259	let pivots = &mut *pivots;
1260
1261	let mut bs = params.blocksize;
1262	if bs < 2 || n <= bs {
1263		bs = 0;
1264	}
1265
1266	let (rook, diagonal) = match params.pivoting {
1267		PivotingStrategy::Partial => (false, false),
1268		PivotingStrategy::PartialDiag => (false, true),
1269		PivotingStrategy::Rook => (true, false),
1270		PivotingStrategy::RookDiag => (true, true),
1271		_ => (false, false),
1272	};
1273
1274	if params.pivoting == PivotingStrategy::Full {
1275		lblt_full_piv(A, subdiag, pivots, par, params);
1276	} else {
1277		lblt_blocked(A, subdiag, pivots, bs, rook, diagonal, par, stack);
1278	}
1279
1280	for (i, p) in perm.iter_mut().enumerate() {
1281		*p = I::from_signed(truncate(i));
1282	}
1283
1284	let mut transposition_count = 0usize;
1285	for i in 0..n {
1286		let mut p = pivots[i];
1287		if (p as isize) < 0 {
1288			p = !p;
1289		}
1290		if i != p {
1291			transposition_count += 1;
1292		}
1293		perm.swap(i, p);
1294	}
1295	for (i, &p) in perm.iter().enumerate() {
1296		perm_inv[p.to_signed().zx()] = I::from_signed(truncate(i));
1297	}
1298
1299	(LbltInfo { transposition_count }, unsafe { PermRef::new_unchecked(perm, perm_inv, n) })
1300}