faer/linalg/cholesky/llt_pivoting/
factor.rs

1use crate::assert;
2use crate::internal_prelude::*;
3
4pub use linalg::cholesky::llt::factor::LltError;
5use linalg::matmul::triangular::BlockStructure;
6
7#[derive(Copy, Clone, Debug)]
8pub struct PivLltParams {
9	pub blocksize: usize,
10
11	#[doc(hidden)]
12	pub non_exhaustive: NonExhaustive,
13}
14
15impl Default for PivLltParams {
16	#[inline]
17	fn default() -> Self {
18		Self {
19			blocksize: 128,
20			non_exhaustive: NonExhaustive(()),
21		}
22	}
23}
24
/// Summary of a pivoted Cholesky factorization, as returned by [`cholesky_in_place`].
#[derive(Clone, Copy, Debug)]
pub struct PivLltInfo {
	/// numerical rank detected during the factorization: the index of the first pivot that fell
	/// below the tolerance, or the matrix dimension if no pivot did
	pub rank: usize,
	/// number of row/column interchanges (transpositions) that were applied; their composition
	/// is the returned permutation
	pub transposition_count: usize,
}
32
33#[inline]
34pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: PivLltParams) -> StackReq {
35	_ = par;
36	_ = params;
37	temp_mat_scratch::<T::Real>(dim, 2)
38}
39
40#[track_caller]
41#[math]
42pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
43	a: MatMut<'_, T>,
44	perm: &'out mut [I],
45	perm_inv: &'out mut [I],
46	par: Par,
47	stack: &mut MemStack,
48	params: PivLltParams,
49) -> Result<(PivLltInfo, PermRef<'out, I>), LltError> {
50	assert!(a.nrows() == a.ncols());
51	let n = a.nrows();
52	assert!(n <= I::Signed::MAX.zx());
53	let mut rank = n;
54	let mut transposition_count = 0;
55
56	'exit: {
57		if n > 0 {
58			let mut a = a;
59			for (i, p) in perm.iter_mut().enumerate() {
60				*p = I::truncate(i);
61			}
62
63			let (mut work1, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
64			let (mut work2, _) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
65			let work1 = work1.as_mat_mut();
66			let work2 = work2.as_mat_mut();
67
68			let mut dot_products = work1.col_mut(0);
69			let mut diagonals = work2.col_mut(0);
70
71			let mut ajj = zero::<T::Real>();
72			let mut pvt = 0usize;
73
74			for i in 0..n {
75				let aii = real(a[(i, i)]);
76				if aii < zero::<T::Real>() || is_nan(aii) {
77					return Err(LltError::NonPositivePivot { index: 0 });
78				}
79				if aii > ajj {
80					ajj = aii;
81					pvt = i;
82				}
83			}
84
85			let tol = eps::<T::Real>() * from_f64::<T::Real>(n as f64) * ajj;
86
87			let mut k = 0usize;
88			while k < n {
89				let bs = Ord::min(n - k, params.blocksize);
90
91				for i in k..n {
92					dot_products[i] = zero::<T::Real>();
93				}
94
95				for j in k..k + bs {
96					if j == k {
97						for i in j..n {
98							diagonals[i] = real(a[(i, i)]);
99						}
100					} else {
101						for i in j..n {
102							dot_products[i] = dot_products[i] + abs2(a[(i, j - 1)]);
103							diagonals[i] = real(a[(i, i)]) - dot_products[i];
104						}
105					}
106
107					if j > 0 {
108						pvt = j;
109						ajj = zero::<T::Real>();
110						for i in j..n {
111							let aii = real(diagonals[i]);
112							if is_nan(aii) {
113								return Err(LltError::NonPositivePivot { index: j });
114							}
115							if aii > ajj {
116								pvt = i;
117								ajj = aii;
118							}
119						}
120						if ajj < tol {
121							rank = j;
122							a[(j, j)] = from_real(ajj);
123							break 'exit;
124						}
125					}
126
127					if pvt != j {
128						transposition_count += 1;
129
130						a[(pvt, pvt)] = copy(a[(j, j)]);
131						crate::perm::swap_rows_idx(a.rb_mut().get_mut(.., ..j), j, pvt);
132						crate::perm::swap_cols_idx(a.rb_mut().get_mut(pvt + 1.., ..), j, pvt);
133						unsafe {
134							z!(
135								a.rb().get(j + 1..pvt, j).const_cast(),
136								a.rb().get(pvt, j + 1..pvt).const_cast().transpose_mut(),
137							)
138						}
139						.for_each(|uz!(a, b)| (*a, *b) = (conj(*b), conj(*a)));
140						a[(pvt, j)] = conj(a[(pvt, j)]);
141
142						let tmp = copy(dot_products[j]);
143						dot_products[j] = copy(dot_products[pvt]);
144						dot_products[pvt] = tmp;
145						perm.swap(j, pvt);
146					}
147
148					ajj = sqrt(ajj);
149					a[(j, j)] = from_real(ajj);
150					unsafe {
151						linalg::matmul::matmul(
152							a.rb().get(j + 1.., j).const_cast(),
153							Accum::Add,
154							a.rb().get(j + 1.., k..j),
155							a.rb().get(j, k..j).adjoint(),
156							-one::<T>(),
157							par,
158						);
159					}
160					let ajj = recip(ajj);
161					z!(a.rb_mut().get_mut(j + 1.., j)).for_each(|uz!(x)| *x = mul_real(*x, ajj));
162				}
163
164				linalg::matmul::triangular::matmul(
165					unsafe { a.rb().get(k + bs.., k + bs..).const_cast() },
166					BlockStructure::TriangularLower,
167					Accum::Add,
168					a.rb().get(k + bs.., k..k + bs),
169					BlockStructure::Rectangular,
170					a.rb().get(k + bs.., k..k + bs).adjoint(),
171					BlockStructure::Rectangular,
172					-one::<T>(),
173					par,
174				);
175
176				k += bs;
177			}
178			rank = n;
179		}
180	}
181
182	for (i, p) in perm.iter().enumerate() {
183		perm_inv[p.zx()] = I::truncate(i);
184	}
185
186	unsafe { Ok((PivLltInfo { rank, transposition_count }, PermRef::new_unchecked(perm, perm_inv, n))) }
187}