faer/linalg/cholesky/llt_pivoting/
factor.rs1use crate::assert;
2use crate::internal_prelude::*;
3
4pub use linalg::cholesky::llt::factor::LltError;
5use linalg::matmul::triangular::BlockStructure;
6
7#[derive(Copy, Clone, Debug)]
8pub struct PivLltParams {
9 pub blocksize: usize,
10
11 #[doc(hidden)]
12 pub non_exhaustive: NonExhaustive,
13}
14
15impl Default for PivLltParams {
16 #[inline]
17 fn default() -> Self {
18 Self {
19 blocksize: 128,
20 non_exhaustive: NonExhaustive(()),
21 }
22 }
23}
24
/// Summary returned by a pivoted Cholesky factorization.
#[derive(Clone, Copy, Debug)]
pub struct PivLltInfo {
    /// Number of pivots that were accepted before the factorization stopped;
    /// equal to the matrix dimension when it ran to completion.
    pub rank: usize,
    /// Number of steps at which the selected pivot had to be swapped into
    /// place from a later position.
    pub transposition_count: usize,
}
32
33#[inline]
34pub fn cholesky_in_place_scratch<I: Index, T: ComplexField>(dim: usize, par: Par, params: PivLltParams) -> StackReq {
35 _ = par;
36 _ = params;
37 temp_mat_scratch::<T::Real>(dim, 2)
38}
39
40#[track_caller]
41#[math]
42pub fn cholesky_in_place<'out, I: Index, T: ComplexField>(
43 a: MatMut<'_, T>,
44 perm: &'out mut [I],
45 perm_inv: &'out mut [I],
46 par: Par,
47 stack: &mut MemStack,
48 params: PivLltParams,
49) -> Result<(PivLltInfo, PermRef<'out, I>), LltError> {
50 assert!(a.nrows() == a.ncols());
51 let n = a.nrows();
52 assert!(n <= I::Signed::MAX.zx());
53 let mut rank = n;
54 let mut transposition_count = 0;
55
56 'exit: {
57 if n > 0 {
58 let mut a = a;
59 for (i, p) in perm.iter_mut().enumerate() {
60 *p = I::truncate(i);
61 }
62
63 let (mut work1, stack) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
64 let (mut work2, _) = unsafe { temp_mat_uninit::<T::Real, _, _>(n, 1, stack) };
65 let work1 = work1.as_mat_mut();
66 let work2 = work2.as_mat_mut();
67
68 let mut dot_products = work1.col_mut(0);
69 let mut diagonals = work2.col_mut(0);
70
71 let mut ajj = zero::<T::Real>();
72 let mut pvt = 0usize;
73
74 for i in 0..n {
75 let aii = real(a[(i, i)]);
76 if aii < zero::<T::Real>() || is_nan(aii) {
77 return Err(LltError::NonPositivePivot { index: 0 });
78 }
79 if aii > ajj {
80 ajj = aii;
81 pvt = i;
82 }
83 }
84
85 let tol = eps::<T::Real>() * from_f64::<T::Real>(n as f64) * ajj;
86
87 let mut k = 0usize;
88 while k < n {
89 let bs = Ord::min(n - k, params.blocksize);
90
91 for i in k..n {
92 dot_products[i] = zero::<T::Real>();
93 }
94
95 for j in k..k + bs {
96 if j == k {
97 for i in j..n {
98 diagonals[i] = real(a[(i, i)]);
99 }
100 } else {
101 for i in j..n {
102 dot_products[i] = dot_products[i] + abs2(a[(i, j - 1)]);
103 diagonals[i] = real(a[(i, i)]) - dot_products[i];
104 }
105 }
106
107 if j > 0 {
108 pvt = j;
109 ajj = zero::<T::Real>();
110 for i in j..n {
111 let aii = real(diagonals[i]);
112 if is_nan(aii) {
113 return Err(LltError::NonPositivePivot { index: j });
114 }
115 if aii > ajj {
116 pvt = i;
117 ajj = aii;
118 }
119 }
120 if ajj < tol {
121 rank = j;
122 a[(j, j)] = from_real(ajj);
123 break 'exit;
124 }
125 }
126
127 if pvt != j {
128 transposition_count += 1;
129
130 a[(pvt, pvt)] = copy(a[(j, j)]);
131 crate::perm::swap_rows_idx(a.rb_mut().get_mut(.., ..j), j, pvt);
132 crate::perm::swap_cols_idx(a.rb_mut().get_mut(pvt + 1.., ..), j, pvt);
133 unsafe {
134 z!(
135 a.rb().get(j + 1..pvt, j).const_cast(),
136 a.rb().get(pvt, j + 1..pvt).const_cast().transpose_mut(),
137 )
138 }
139 .for_each(|uz!(a, b)| (*a, *b) = (conj(*b), conj(*a)));
140 a[(pvt, j)] = conj(a[(pvt, j)]);
141
142 let tmp = copy(dot_products[j]);
143 dot_products[j] = copy(dot_products[pvt]);
144 dot_products[pvt] = tmp;
145 perm.swap(j, pvt);
146 }
147
148 ajj = sqrt(ajj);
149 a[(j, j)] = from_real(ajj);
150 unsafe {
151 linalg::matmul::matmul(
152 a.rb().get(j + 1.., j).const_cast(),
153 Accum::Add,
154 a.rb().get(j + 1.., k..j),
155 a.rb().get(j, k..j).adjoint(),
156 -one::<T>(),
157 par,
158 );
159 }
160 let ajj = recip(ajj);
161 z!(a.rb_mut().get_mut(j + 1.., j)).for_each(|uz!(x)| *x = mul_real(*x, ajj));
162 }
163
164 linalg::matmul::triangular::matmul(
165 unsafe { a.rb().get(k + bs.., k + bs..).const_cast() },
166 BlockStructure::TriangularLower,
167 Accum::Add,
168 a.rb().get(k + bs.., k..k + bs),
169 BlockStructure::Rectangular,
170 a.rb().get(k + bs.., k..k + bs).adjoint(),
171 BlockStructure::Rectangular,
172 -one::<T>(),
173 par,
174 );
175
176 k += bs;
177 }
178 rank = n;
179 }
180 }
181
182 for (i, p) in perm.iter().enumerate() {
183 perm_inv[p.zx()] = I::truncate(i);
184 }
185
186 unsafe { Ok((PivLltInfo { rank, transposition_count }, PermRef::new_unchecked(perm, perm_inv, n))) }
187}