// faer/linalg/cholesky/ldlt/factor.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::linalg::matmul::internal::*;
4use linalg::matmul::triangular::BlockStructure;
5use pulp::Simd;
6
7#[inline(always)]
8#[math]
9fn simd_cholesky_row_batch<'N, T: ComplexField, S: Simd>(
10	simd: T::SimdCtx<S>,
11	A: MatMut<'_, T, Dim<'N>, Dim<'N>, ContiguousFwd>,
12	D: RowMut<'_, T, Dim<'N>>,
13
14	start: IdxInc<'N>,
15
16	is_llt: bool,
17	regularize: bool,
18	eps: T::Real,
19	delta: T::Real,
20	signs: Option<&Array<'N, i8>>,
21) -> Result<usize, usize> {
22	let mut A = A;
23	let mut D = D;
24
25	let n = A.ncols();
26
27	with_dim!(TAIL, *n - *start);
28
29	let simd = SimdCtx::<T, S>::new_force_mask(simd, TAIL);
30	let (idx_head, indices, idx_tail) = simd.indices();
31	assert!(idx_head.is_none());
32	let Some(idx_tail) = idx_tail else { panic!() };
33
34	let mut count = 0usize;
35
36	for j in n.indices() {
37		with_dim!(LEFT, *j);
38
39		let (A_0, Aj) = A.rb_mut().split_at_col_mut(j.into());
40		let A_0 = A_0.as_col_shape(LEFT);
41		let A10 = A_0.subrows(start, TAIL);
42
43		let mut Aj = Aj.col_mut(0).subrows_mut(start, TAIL);
44
45		{
46			let D = D.rb().subcols(IdxInc::ZERO, LEFT);
47			let mut Aj = Aj.rb_mut();
48			let mut iter = indices.clone();
49			let i0 = iter.next();
50			let i1 = iter.next();
51			let i2 = iter.next();
52
53			match (i0, i1, i2) {
54				(None, None, None) => {
55					let mut Aij = simd.read(Aj.rb(), idx_tail);
56
57					for k in LEFT.indices() {
58						let Ak = A10.col(k);
59
60						let D = real(D[k]);
61						let D = if is_llt { one() } else { D };
62
63						let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
64
65						let Aik = simd.read(Ak, idx_tail);
66						Aij = simd.mul_add(Ajk, Aik, Aij);
67					}
68					simd.write(Aj.rb_mut(), idx_tail, Aij);
69				},
70				(Some(i0), None, None) => {
71					let mut A0j = simd.read(Aj.rb(), i0);
72					let mut Aij = simd.read(Aj.rb(), idx_tail);
73
74					for k in LEFT.indices() {
75						let Ak = A10.col(k);
76
77						let D = real(D[k]);
78						let D = if is_llt { one() } else { D };
79
80						let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
81
82						let A0k = simd.read(Ak, i0);
83						let Aik = simd.read(Ak, idx_tail);
84						A0j = simd.mul_add(Ajk, A0k, A0j);
85						Aij = simd.mul_add(Ajk, Aik, Aij);
86					}
87					simd.write(Aj.rb_mut(), i0, A0j);
88					simd.write(Aj.rb_mut(), idx_tail, Aij);
89				},
90				(Some(i0), Some(i1), None) => {
91					let mut A0j = simd.read(Aj.rb(), i0);
92					let mut A1j = simd.read(Aj.rb(), i1);
93					let mut Aij = simd.read(Aj.rb(), idx_tail);
94
95					for k in LEFT.indices() {
96						let Ak = A10.col(k);
97
98						let D = real(D[k]);
99						let D = if is_llt { one() } else { D };
100
101						let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
102
103						let A0k = simd.read(Ak, i0);
104						let A1k = simd.read(Ak, i1);
105						let Aik = simd.read(Ak, idx_tail);
106						A0j = simd.mul_add(Ajk, A0k, A0j);
107						A1j = simd.mul_add(Ajk, A1k, A1j);
108						Aij = simd.mul_add(Ajk, Aik, Aij);
109					}
110					simd.write(Aj.rb_mut(), i0, A0j);
111					simd.write(Aj.rb_mut(), i1, A1j);
112					simd.write(Aj.rb_mut(), idx_tail, Aij);
113				},
114				(Some(i0), Some(i1), Some(i2)) => {
115					let mut A0j = simd.read(Aj.rb(), i0);
116					let mut A1j = simd.read(Aj.rb(), i1);
117					let mut A2j = simd.read(Aj.rb(), i2);
118					let mut Aij = simd.read(Aj.rb(), idx_tail);
119
120					for k in LEFT.indices() {
121						let Ak = A10.col(k);
122
123						let D = real(D[k]);
124						let D = if is_llt { one() } else { D };
125
126						let Ajk = simd.splat(&mul_real(conj(A_0[(j, k)]), -D));
127
128						let A0k = simd.read(Ak, i0);
129						let A1k = simd.read(Ak, i1);
130						let A2k = simd.read(Ak, i2);
131						let Aik = simd.read(Ak, idx_tail);
132						A0j = simd.mul_add(Ajk, A0k, A0j);
133						A1j = simd.mul_add(Ajk, A1k, A1j);
134						A2j = simd.mul_add(Ajk, A2k, A2j);
135						Aij = simd.mul_add(Ajk, Aik, Aij);
136					}
137					simd.write(Aj.rb_mut(), i0, A0j);
138					simd.write(Aj.rb_mut(), i1, A1j);
139					simd.write(Aj.rb_mut(), i2, A2j);
140					simd.write(Aj.rb_mut(), idx_tail, Aij);
141				},
142				_ => {
143					unreachable!();
144				},
145			}
146		}
147
148		let D = D.rb_mut().at_mut(j);
149
150		if *j >= *start {
151			let j_row = TAIL.idx(*j - *start);
152
153			let mut diag = real(Aj[j_row]);
154
155			if regularize {
156				let sign = if is_llt { 1 } else { if let Some(signs) = signs { signs[j] } else { 0 } };
157
158				let small_or_negative = diag <= eps;
159				let minus_small_or_positive = diag >= -eps;
160
161				if sign == 1 && small_or_negative {
162					diag = copy(delta);
163					count += 1;
164				} else if sign == -1i8 && minus_small_or_positive {
165					diag = neg(delta);
166				} else {
167					if small_or_negative && minus_small_or_positive {
168						if diag < zero() {
169							diag = neg(delta);
170						} else {
171							diag = copy(delta);
172						}
173					}
174				}
175			}
176
177			let j = j;
178			let diag = if is_llt {
179				if !(diag > zero()) {
180					*D = from_real(diag);
181					return Err(*j);
182				}
183				sqrt(diag)
184			} else {
185				copy(diag)
186			};
187
188			*D = from_real(diag);
189
190			if diag == zero() || !is_finite(diag) {
191				return Err(*j);
192			}
193		}
194
195		let diag = real(*D);
196
197		{
198			let mut Aj = Aj.rb_mut();
199			let inv = simd.splat_real(&recip(diag));
200
201			for i in indices.clone() {
202				let mut Aij = simd.read(Aj.rb(), i);
203				Aij = simd.mul_real(Aij, inv);
204				simd.write(Aj.rb_mut(), i, Aij);
205			}
206			{
207				let mut Aij = simd.read(Aj.rb(), idx_tail);
208				Aij = simd.mul_real(Aij, inv);
209				simd.write(Aj.rb_mut(), idx_tail, Aij);
210			}
211		}
212	}
213
214	Ok(count)
215}
216
217#[inline(always)]
218#[math]
219fn simd_cholesky_matrix<T: ComplexField, S: Simd>(
220	simd: T::SimdCtx<S>,
221	A: MatMut<'_, T, usize, usize, ContiguousFwd>,
222	D: RowMut<'_, T, usize>,
223
224	is_llt: bool,
225	regularize: bool,
226	eps: T::Real,
227	delta: T::Real,
228	signs: Option<&[i8]>,
229) -> Result<usize, usize> {
230	let N = A.ncols();
231
232	let blocksize = 4 * (core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>());
233
234	let mut A = A;
235	let mut D = D;
236
237	let mut count = 0;
238
239	let mut j = 0;
240	while j < N {
241		let blocksize = Ord::min(blocksize, N - j);
242		let j_next = j + blocksize;
243
244		with_dim!(HEAD, j_next);
245		let A = A.rb_mut().submatrix_mut(0, 0, HEAD, HEAD);
246		let D = D.rb_mut().subcols_mut(0, HEAD);
247
248		let signs = signs.map(|signs| Array::from_ref(&signs[..*HEAD], HEAD));
249
250		count += simd_cholesky_row_batch(simd, A, D, HEAD.idx_inc(j), is_llt, regularize, eps.clone(), delta.clone(), signs)?;
251		j += blocksize;
252	}
253
254	Ok(count)
255}
256
257fn simd_cholesky<T: ComplexField>(
258	A: MatMut<'_, T>,
259	D: RowMut<'_, T>,
260	is_llt: bool,
261	regularize: bool,
262	eps: T::Real,
263	delta: T::Real,
264	signs: Option<&[i8]>,
265) -> Result<usize, usize> {
266	struct Impl<'a, T: ComplexField> {
267		A: MatMut<'a, T, usize, usize, ContiguousFwd>,
268		D: RowMut<'a, T>,
269		is_llt: bool,
270		regularize: bool,
271		eps: T::Real,
272		delta: T::Real,
273		signs: Option<&'a [i8]>,
274	}
275
276	impl<'a, T: ComplexField> pulp::WithSimd for Impl<'a, T> {
277		type Output = Result<usize, usize>;
278
279		#[inline(always)]
280		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
281			let Self {
282				A,
283				D,
284				is_llt,
285				regularize,
286				eps,
287				delta,
288				signs,
289			} = self;
290			let simd = T::simd_ctx(simd);
291			if A.nrows() > 0 {
292				simd_cholesky_matrix(simd, A, D, is_llt, regularize, eps, delta, signs)
293			} else {
294				Ok(0)
295			}
296		}
297	}
298
299	let mut A = A;
300	if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
301		if let Some(A) = A.rb_mut().try_as_col_major_mut() {
302			dispatch!(
303				Impl {
304					A,
305					D,
306					is_llt,
307					regularize,
308					eps,
309					delta,
310					signs,
311				},
312				Impl,
313				T
314			)
315		} else {
316			cholesky_fallback(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
317		}
318	} else {
319		cholesky_fallback(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
320	}
321}
322
323#[math]
324fn cholesky_fallback<T: ComplexField>(
325	A: MatMut<'_, T>,
326	D: RowMut<'_, T>,
327	is_llt: bool,
328	regularize: bool,
329	eps: T::Real,
330	delta: T::Real,
331	signs: Option<&[i8]>,
332) -> Result<usize, usize> {
333	let n = A.nrows();
334	let mut count = 0;
335	let mut A = A;
336	let mut D = D;
337
338	for j in 0..n {
339		for i in j..n {
340			let mut sum = zero();
341			for k in 0..j {
342				let D = real(D[k]);
343				let D = if is_llt { one() } else { D };
344
345				sum = sum + mul_real(conj(A[(j, k)]) * A[(i, k)], D);
346			}
347			A[(i, j)] = A[(i, j)] - sum;
348		}
349
350		let D = D.rb_mut().at_mut(j);
351		let mut diag = real(A[(j, j)]);
352
353		if regularize {
354			let sign = if is_llt { 1 } else { if let Some(signs) = signs { signs[j] } else { 0 } };
355
356			let small_or_negative = diag <= eps;
357			let minus_small_or_positive = diag >= -eps;
358
359			if sign == 1 && small_or_negative {
360				diag = copy(delta);
361				count += 1;
362			} else if sign == -1i8 && minus_small_or_positive {
363				diag = neg(delta);
364			} else {
365				if small_or_negative && minus_small_or_positive {
366					if diag < zero() {
367						diag = neg(delta);
368					} else {
369						diag = copy(delta);
370					}
371				}
372			}
373		}
374
375		let diag = if is_llt {
376			if !(diag > zero()) {
377				*D = from_real(diag);
378				return Err(j);
379			}
380			sqrt(diag)
381		} else {
382			copy(diag)
383		};
384		*D = from_real(diag);
385
386		if diag == zero() || !is_finite(diag) {
387			return Err(j);
388		}
389
390		let inv = recip(diag);
391
392		for i in j..n {
393			A[(i, j)] = mul_real(A[(i, j)], inv);
394		}
395	}
396
397	Ok(count)
398}
399
400#[math]
401pub(crate) fn cholesky_recursion<T: ComplexField>(
402	A: MatMut<'_, T>,
403	D: RowMut<'_, T>,
404
405	recursion_threshold: usize,
406	blocksize: usize,
407	is_llt: bool,
408	regularize: bool,
409	eps: &T::Real,
410	delta: &T::Real,
411	signs: Option<&[i8]>,
412	par: Par,
413) -> Result<usize, usize> {
414	let n = A.ncols();
415	if n <= recursion_threshold {
416		simd_cholesky(A, D, is_llt, regularize, eps.clone(), delta.clone(), signs)
417	} else {
418		let mut count = 0;
419		let blocksize = Ord::min(n.next_power_of_two() / 2, blocksize);
420		let mut A = A;
421		let mut D = D;
422
423		let mut j = 0;
424		while j < n {
425			let blocksize = Ord::min(blocksize, n - j);
426
427			let (mut A00, A01, mut A10, mut A11) = A.rb_mut().get_mut(j.., j..).split_at_mut(blocksize, blocksize);
428
429			let mut D0 = D.rb_mut().subcols_mut(j, blocksize);
430
431			let mut L10xD0 = A01.transpose_mut();
432
433			let signs = signs.map(|signs| &signs[j..][..blocksize]);
434
435			match cholesky_recursion(
436				A00.rb_mut(),
437				D0.rb_mut(),
438				recursion_threshold,
439				blocksize,
440				is_llt,
441				regularize,
442				eps,
443				delta,
444				signs,
445				par,
446			) {
447				Ok(local_count) => count += local_count,
448				Err(fail_idx) => return Err(j + fail_idx),
449			}
450			let A00 = A00.rb();
451
452			if is_llt {
453				linalg::triangular_solve::solve_lower_triangular_in_place(A00.conjugate(), A10.rb_mut().transpose_mut(), par)
454			} else {
455				linalg::triangular_solve::solve_unit_lower_triangular_in_place(A00.conjugate(), A10.rb_mut().transpose_mut(), par)
456			}
457			let mut A10 = A10.rb_mut();
458
459			if is_llt {
460				linalg::matmul::triangular::matmul(
461					A11.rb_mut(),
462					BlockStructure::TriangularLower,
463					Accum::Add,
464					A10.rb(),
465					BlockStructure::Rectangular,
466					A10.rb().adjoint(),
467					BlockStructure::Rectangular,
468					-one::<T>(),
469					par,
470				);
471			} else {
472				if has_spicy_matmul::<T>() {
473					for k in 0..blocksize {
474						let d = real(D0[k]);
475						let d = recip(d);
476
477						for i in j + blocksize..n {
478							let i = i - (j + blocksize);
479							A10[(i, k)] = mul_real(A10[(i, k)], d);
480						}
481					}
482					spicy_matmul::<usize, T>(
483						A11.rb_mut(),
484						BlockStructure::TriangularLower,
485						None,
486						None,
487						Accum::Add,
488						A10.rb(),
489						Conj::No,
490						A10.rb().transpose(),
491						Conj::Yes,
492						Some(D0.rb().transpose().as_diagonal()),
493						-one::<T>(),
494						par,
495						MemStack::new(&mut []),
496					);
497				} else {
498					for k in 0..blocksize {
499						let d = real(D0[k]);
500						let d = recip(d);
501
502						for i in j + blocksize..n {
503							let i = i - (j + blocksize);
504							let a = copy(A10[(i, k)]);
505							A10[(i, k)] = mul_real(A10[(i, k)], d);
506							L10xD0[(i, k)] = a;
507						}
508					}
509					linalg::matmul::triangular::matmul(
510						A11.rb_mut(),
511						BlockStructure::TriangularLower,
512						Accum::Add,
513						A10,
514						BlockStructure::Rectangular,
515						L10xD0.adjoint(),
516						BlockStructure::Rectangular,
517						-one::<T>(),
518						par,
519					);
520				}
521			};
522
523			j += blocksize;
524		}
525
526		Ok(count)
527	}
528}
529
/// dynamic $LDL^\top$ regularization.
///
/// when enabled, each pivot `d` produced during the factorization is compared with
/// `dynamic_regularization_epsilon` and with the expected sign at its index:
/// - expected sign `1`: if `d <= epsilon`, it is replaced by `delta`;
/// - expected sign `-1`: if `d >= -epsilon`, it is replaced by `-delta`;
/// - any other value, or no signs provided: if `|d| <= epsilon`, it is replaced by
///   `delta` or `-delta`, keeping the sign of `d` (zero becomes `delta`).
///
/// [`cholesky_in_place`] applies regularization only when both
/// `dynamic_regularization_delta` and `dynamic_regularization_epsilon` are strictly
/// positive. The [`Default`] value (both zero) therefore disables it.
#[derive(Copy, Clone, Debug)]
pub struct LdltRegularization<'a, T> {
	/// expected sign of the pivot at each step of the decomposition: `1` for positive,
	/// `-1` for negative, any other value for no expectation.
	///
	/// it must contain at least one entry per column of the factorized matrix; a shorter
	/// slice causes an out-of-bounds panic.
	pub dynamic_regularization_signs: Option<&'a [i8]>,
	/// magnitude assigned to a pivot that is regularized.
	pub dynamic_regularization_delta: T,
	/// regularization threshold: pivots within `epsilon` of zero, or on the wrong side of
	/// it, are regularized.
	pub dynamic_regularization_epsilon: T,
}
542
/// summary of a successful $LDL^\top$ factorization.
#[derive(Clone, Copy, Debug)]
pub struct LdltInfo {
	/// count of pivots whose value or sign was corrected by dynamic regularization.
	pub dynamic_regularization_count: usize,
}
549
/// error in the $LDL^\top$ factorization.
#[derive(Copy, Clone, Debug)]
pub enum LdltError {
	/// the pivot at `index` was zero or non-finite (after regularization, if enabled), so
	/// the factorization could not proceed past that column.
	ZeroPivot {
		/// zero-based index of the offending pivot.
		index: usize,
	},
}

impl core::fmt::Display for LdltError {
	// user-facing message; `Debug` remains available for diagnostics.
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		match self {
			Self::ZeroPivot { index } => write!(f, "LDLT factorization failed: zero or non-finite pivot at index {index}"),
		}
	}
}
impl core::error::Error for LdltError {}
562
563impl<T: RealField> Default for LdltRegularization<'_, T> {
564	fn default() -> Self {
565		Self {
566			dynamic_regularization_signs: None,
567			dynamic_regularization_delta: zero(),
568			dynamic_regularization_epsilon: zero(),
569		}
570	}
571}
572
573#[derive(Copy, Clone, Debug)]
574pub struct LdltParams {
575	pub recursion_threshold: usize,
576	pub blocksize: usize,
577	#[doc(hidden)]
578	pub non_exhaustive: NonExhaustive,
579}
580
581impl<T: ComplexField> Auto<T> for LdltParams {
582	#[inline]
583	fn auto() -> Self {
584		Self {
585			recursion_threshold: 64,
586			blocksize: 128,
587			non_exhaustive: NonExhaustive(()),
588		}
589	}
590}
591
592#[inline]
593pub fn cholesky_in_place_scratch<T: ComplexField>(dim: usize, par: Par, params: Spec<LdltParams, T>) -> StackReq {
594	_ = par;
595	_ = params;
596	temp_mat_scratch::<T>(dim, 1)
597}
598
599#[math]
600pub fn cholesky_in_place<T: ComplexField>(
601	A: MatMut<'_, T>,
602	regularization: LdltRegularization<'_, T::Real>,
603	par: Par,
604	stack: &mut MemStack,
605	params: Spec<LdltParams, T>,
606) -> Result<LdltInfo, LdltError> {
607	let params = params.config;
608
609	let n = A.nrows();
610	let mut D = unsafe { temp_mat_uninit(n, 1, stack).0 };
611	let D = D.as_mat_mut();
612	let mut D = D.col_mut(0).transpose_mut();
613	let mut A = A;
614
615	let ret = match cholesky_recursion(
616		A.rb_mut(),
617		D.rb_mut(),
618		params.recursion_threshold,
619		params.blocksize,
620		false,
621		regularization.dynamic_regularization_delta > zero() && regularization.dynamic_regularization_epsilon > zero(),
622		&regularization.dynamic_regularization_epsilon,
623		&regularization.dynamic_regularization_delta,
624		regularization.dynamic_regularization_signs.map(|signs| signs),
625		par,
626	) {
627		Ok(count) => Ok(LdltInfo {
628			dynamic_regularization_count: count,
629		}),
630		Err(index) => Err(LdltError::ZeroPivot { index }),
631	};
632	let init = if let Err(LdltError::ZeroPivot { index }) = ret { index + 1 } else { n };
633
634	for i in 0..init {
635		A[(i, i)] = copy(D[i]);
636	}
637
638	ret
639}
640
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, Row, assert, c64};

	/// draws an `n × n` complex gaussian matrix `G` and returns the Hermitian positive
	/// semi-definite product `G * G^H`.
	fn random_gram_matrix(n: usize, rng: &mut StdRng) -> Mat<c64> {
		let G = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		&G * &G.adjoint()
	}

	/// zeroes the strictly upper triangle, which the factorization may use as scratch.
	fn clear_strict_upper(mut L: MatMut<'_, c64>) {
		let ncols = L.ncols();
		for col in 0..ncols {
			for row in 0..col {
				L[(row, col)] = c64::ZERO;
			}
		}
	}

	#[test]
	fn test_simd_cholesky() {
		type T = c64;

		let rng = &mut StdRng::seed_from_u64(0);
		let kernels = [cholesky_fallback::<T>, simd_cholesky::<T>];

		for n in 0..=64 {
			for f in kernels {
				for llt in [true, false] {
					let approx_eq = CwiseMat(ApproxEq {
						abs_tol: 1e-12,
						rel_tol: 1e-12,
					});

					let A = random_gram_matrix(n, rng);
					let A = A.as_ref().as_shape(n, n);

					let mut L = A.cloned();
					let mut L = L.as_mut();
					let mut D = Row::zeros(n);
					let mut D = D.as_mut();

					f(L.rb_mut(), D.rb_mut(), llt, false, 0.0, 0.0, None).unwrap();

					clear_strict_upper(L.rb_mut());
					let L = L.rb().as_dyn_stride();

					if llt {
						assert!(L * L.adjoint() ~ A);
					} else {
						assert!(L * D.as_diagonal() * L.adjoint() ~ A);
					}
				}
			}
		}
	}

	#[test]
	fn test_cholesky() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [2, 4, 8, 31, 127, 240] {
			for llt in [false, true] {
				let approx_eq = CwiseMat(ApproxEq {
					abs_tol: 1e-12,
					rel_tol: 1e-12,
				});

				let A = random_gram_matrix(n, rng);
				let A = A.as_ref();

				let mut L = A.cloned();
				let mut L = L.as_mut();
				let mut D = Row::zeros(n);
				let mut D = D.as_mut();

				cholesky_recursion(L.rb_mut(), D.rb_mut(), 32, 32, llt, false, &0.0, &0.0, None, Par::Seq).unwrap();

				clear_strict_upper(L.rb_mut());
				let L = L.rb().as_dyn_stride();

				if llt {
					assert!(L * L.adjoint() ~ A);
				} else {
					assert!(L * D.as_diagonal() * L.adjoint() ~ A);
				}
			}
		}
	}
}