1use super::*;
2use crate::assert;
3
4pub fn bicgstab_scratch<T: ComplexField>(
6 left_precond: impl Precond<T>,
7 right_precond: impl Precond<T>,
8 mat: impl LinOp<T>,
9 rhs_ncols: usize,
10 par: Par,
11) -> StackReq {
12 fn implementation<T: ComplexField>(K1: &dyn Precond<T>, K2: &dyn Precond<T>, A: &dyn LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
13 let n = A.nrows();
14 let k = rhs_ncols;
15
16 let nk = temp_mat_scratch::<T>(n, k);
17 let kk = temp_mat_scratch::<T>(k, k);
18 let k_usize = StackReq::new::<usize>(k);
19 let lu = crate::linalg::lu::full_pivoting::factor::lu_in_place_scratch::<usize, T>(k, k, par, Default::default());
20 StackReq::all_of(&[
21 k_usize, k_usize, k_usize, k_usize, kk, nk, nk, nk, nk, nk, nk, nk, nk, StackReq::any_of(&[
35 lu,
36 A.apply_scratch(k, par),
37 StackReq::all_of(&[
38 nk, K1.apply_scratch(k, par),
40 K2.apply_scratch(k, par),
41 ]),
42 StackReq::all_of(&[
43 kk, kk, ]),
46 kk, ]),
48 ])
49 }
50 implementation(&left_precond, &right_precond, &mat, rhs_ncols, par)
51}
52
/// Tuning parameters for [`bicgstab`].
#[derive(Copy, Clone, Debug)]
pub struct BicgParams<T> {
    /// Whether the output matrix holds an initial guess that must be taken into account
    /// (`MaybeNonZero`) or the iteration starts from a zero guess (`Zero`).
    pub initial_guess: InitialGuessStatus,
    /// Absolute tolerance on the residual norm.
    ///
    /// The solver stops once the residual norm drops below the larger of this value and
    /// `rel_tolerance * ‖rhs‖`.
    pub abs_tolerance: T,
    /// Tolerance on the residual norm relative to the norm of the right-hand side.
    pub rel_tolerance: T,
    /// Maximum number of iterations before the solver gives up.
    pub max_iters: usize,

    // Hidden field; presumably keeps the struct extensible without breaking callers —
    // construct values through `Default` and then override individual fields.
    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
68
69impl<T: RealField> Default for BicgParams<T> {
70 #[inline]
71 #[math]
72 fn default() -> Self {
73 Self {
74 initial_guess: InitialGuessStatus::MaybeNonZero,
75 abs_tolerance: zero(),
76 rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
77 max_iters: usize::MAX,
78 non_exhaustive: NonExhaustive(()),
79 }
80 }
81}
82
/// Summary returned by [`bicgstab`] when the iteration converged.
#[derive(Copy, Clone, Debug)]
pub struct BicgInfo<T> {
    /// Norm of the residual when the solver stopped.
    pub abs_residual: T,
    /// `abs_residual` divided by the norm of the right-hand side (zero when the right-hand
    /// side itself is zero).
    pub rel_residual: T,
    /// Number of iterations that were started before convergence was detected.
    pub iter_count: usize,

    // Hidden field; presumably keeps the struct extensible without breaking callers.
    #[doc(hidden)]
    pub non_exhaustive: NonExhaustive,
}
96
/// Failure modes of [`bicgstab`].
#[derive(Copy, Clone, Debug)]
pub enum BicgError<T> {
    /// The residual did not drop below the requested threshold within `max_iters`
    /// iterations.
    NoConvergence {
        /// Residual norm reported by the solver when it gave up.
        abs_residual: T,
        /// `abs_residual` divided by the norm of the right-hand side.
        rel_residual: T,
    },
}
108
109#[track_caller]
114pub fn bicgstab<T: ComplexField>(
115 out: MatMut<'_, T>,
116 left_precond: impl Precond<T>,
117 right_precond: impl Precond<T>,
118 mat: impl LinOp<T>,
119 rhs: MatRef<'_, T>,
120 params: BicgParams<T::Real>,
121 callback: impl FnMut(MatRef<'_, T>),
122 par: Par,
123 stack: &mut MemStack,
124) -> Result<BicgInfo<T::Real>, BicgError<T::Real>> {
125 #[track_caller]
126 #[math]
127 fn implementation<T: ComplexField>(
128 out: MatMut<'_, T>,
129 left_precond: &dyn Precond<T>,
130 right_precond: &dyn Precond<T>,
131 mat: &dyn LinOp<T>,
132 rhs: MatRef<'_, T>,
133 params: BicgParams<T::Real>,
134 callback: &mut dyn FnMut(MatRef<'_, T>),
135 par: Par,
136 stack: &mut MemStack,
137 ) -> Result<BicgInfo<T::Real>, BicgError<T::Real>> {
138 let mut x = out;
139 let A = mat;
140 let K1 = left_precond;
141 let K2 = right_precond;
142 let b = rhs;
143
144 assert!(A.nrows() == A.ncols());
145 let n = A.nrows();
146 let k = x.ncols();
147
148 let b_norm = b.norm_l2();
149 if b_norm == zero::<T::Real>() {
150 x.fill(zero());
151 return Ok(BicgInfo {
152 abs_residual: zero::<T::Real>(),
153 rel_residual: zero::<T::Real>(),
154 iter_count: 0,
155 non_exhaustive: NonExhaustive(()),
156 });
157 }
158
159 let rel_threshold = params.rel_tolerance * b_norm;
160 let abs_threshold = params.abs_tolerance;
161 let threshold = if abs_threshold > rel_threshold { abs_threshold } else { rel_threshold };
162
163 let (row_perm, stack) = unsafe { stack.make_raw::<usize>(k) };
164 let (row_perm_inv, stack) = unsafe { stack.make_raw::<usize>(k) };
165 let (col_perm, stack) = unsafe { stack.make_raw::<usize>(k) };
166 let (col_perm_inv, stack) = unsafe { stack.make_raw::<usize>(k) };
167 let (mut rtv, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
168 let mut rtv = rtv.as_mat_mut();
169 let (mut r, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
170 let mut r = r.as_mat_mut();
171 let (mut p, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
172 let mut p = p.as_mat_mut();
173 let (mut r_tilde, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
174 let mut r_tilde = r_tilde.as_mat_mut();
175
176 let abs_residual = if params.initial_guess == InitialGuessStatus::MaybeNonZero {
177 A.apply(r.rb_mut(), x.rb(), par, stack);
178 z!(&mut r, &b).for_each(|uz!(r, b)| *r = *b - *r);
179
180 r.norm_l2()
181 } else {
182 copy(b_norm)
183 };
184
185 if abs_residual < threshold {
186 return Ok(BicgInfo {
187 rel_residual: abs_residual / b_norm,
188 abs_residual,
189 iter_count: 0,
190 non_exhaustive: NonExhaustive(()),
191 });
192 }
193
194 p.copy_from(&r);
195 r_tilde.copy_from(&r);
196
197 for iter in 0..params.max_iters {
198 let (mut v, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
199 let mut v = v.as_mat_mut();
200 let (mut y, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
201 let mut y = y.as_mat_mut();
202 {
203 let (mut y0, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
204 let mut y0 = y0.as_mat_mut();
205 K1.apply(y0.rb_mut(), p.rb(), par, stack);
206 K2.apply(y.rb_mut(), y0.rb(), par, stack);
207 }
208 A.apply(v.rb_mut(), y.rb(), par, stack);
209
210 crate::linalg::matmul::matmul(rtv.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), v.rb(), one::<T>(), par);
211 let (_, row_perm, col_perm) = crate::linalg::lu::full_pivoting::factor::lu_in_place(
212 rtv.rb_mut(),
213 row_perm,
214 row_perm_inv,
215 col_perm,
216 col_perm_inv,
217 par,
218 stack,
219 Default::default(),
220 );
221 let mut rank = k;
222 let tol = eps::<T::Real>() * from_f64::<T::Real>(k as f64) * abs(rtv[(0, 0)]);
223 for i in 0..k {
224 if abs(rtv[(i, i)]) < tol {
225 rank = i;
226 break;
227 }
228 }
229
230 let (mut s, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
231 let mut s = s.as_mat_mut();
232 {
233 let (mut rtr, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
234 let mut rtr = rtr.as_mat_mut();
235 crate::linalg::matmul::matmul(rtr.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), r.rb(), one::<T>(), par);
236 let (mut temp, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
237 let mut temp = temp.as_mat_mut();
238 crate::perm::permute_rows(temp.rb_mut(), rtr.rb(), row_perm);
239 crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
240 rtv.rb().get(..rank, ..rank),
241 temp.rb_mut().get_mut(..rank, ..),
242 par,
243 );
244 crate::linalg::triangular_solve::solve_upper_triangular_in_place(
245 rtv.rb().get(..rank, ..rank),
246 temp.rb_mut().get_mut(..rank, ..),
247 par,
248 );
249 temp.rb_mut().get_mut(rank.., ..).fill(zero());
250 crate::perm::permute_rows(rtr.rb_mut(), temp.rb(), col_perm.inverse());
251 let alpha = rtr.rb();
252
253 s.copy_from(&r);
254 crate::linalg::matmul::matmul(s.rb_mut(), Accum::Add, v.rb(), alpha.rb(), -one::<T>(), par);
255 crate::linalg::matmul::matmul(
256 x.rb_mut(),
257 if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
258 Accum::Replace
259 } else {
260 Accum::Add
261 },
262 y.rb(),
263 alpha.rb(),
264 one::<T>(),
265 par,
266 );
267 }
268 let norm = s.norm_l2();
269 if norm < threshold {
270 return Ok(BicgInfo {
271 rel_residual: norm / b_norm,
272 abs_residual: norm,
273 iter_count: iter + 1,
274 non_exhaustive: NonExhaustive(()),
275 });
276 }
277
278 let (mut t, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
279 let mut t = t.as_mat_mut();
280 let (mut z, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
281 let mut z = z.as_mat_mut();
282 {
283 let (mut z0, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
284 let mut z0 = z0.as_mat_mut();
285 K1.apply(z0.rb_mut(), s.rb(), par, stack);
286 K2.apply(z.rb_mut(), z0.rb(), par, stack);
287 }
288 A.apply(t.rb_mut(), z.rb(), par, stack);
289
290 let compute_w = |kt: MatRef<'_, T>, ks: MatRef<'_, T>| {
291 let mut wt = zero::<T>();
292 let mut ws = zero::<T>();
293 for j in 0..k {
294 let kt = kt.rb().col(j);
295 let ks = ks.rb().col(j);
296 ws = ws + kt.transpose() * ks;
297 wt = wt + kt.transpose() * kt;
298 }
299 recip(wt) * ws
300 };
301
302 let w = {
303 let mut kt = y;
304 let (mut ks, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
305 let mut ks = ks.as_mat_mut();
306 K1.apply(kt.rb_mut(), t.rb(), par, stack);
307 K1.apply(ks.rb_mut(), s.rb(), par, stack);
308 compute_w(kt.rb(), ks.rb())
309 };
310
311 z!(&mut r, &s, &t).for_each(|uz!(r, s, t)| *r = *s - w * *t);
312 z!(&mut x, &z).for_each(|uz!(x, z)| *x = *x + w * *z);
313 z!(&mut p, &v).for_each(|uz!(p, v)| *p = *p - w * *v);
314
315 callback(x.rb());
316
317 let norm = r.norm_l2();
318 if norm < threshold {
319 return Ok(BicgInfo {
320 rel_residual: norm / b_norm,
321 abs_residual: norm,
322 iter_count: iter + 1,
323 non_exhaustive: NonExhaustive(()),
324 });
325 }
326
327 let (mut rtt, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
328 let mut rtt = rtt.as_mat_mut();
329 {
330 crate::linalg::matmul::matmul(rtt.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), t.rb(), one::<T>(), par);
331 let (mut temp, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
332 let mut temp = temp.as_mat_mut();
333 crate::perm::permute_rows(temp.rb_mut(), rtt.rb(), row_perm);
334 crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
335 rtv.rb().get(..rank, ..rank),
336 temp.rb_mut().get_mut(..rank, ..),
337 par,
338 );
339 crate::linalg::triangular_solve::solve_upper_triangular_in_place(
340 rtv.rb().get(..rank, ..rank),
341 temp.rb_mut().get_mut(..rank, ..),
342 par,
343 );
344 temp.rb_mut().get_mut(rank.., ..).fill(zero());
345 crate::perm::permute_rows(rtt.rb_mut(), temp.rb(), col_perm.inverse());
346 }
347
348 let beta = rtt.rb();
349 let mut tmp = v;
350 crate::linalg::matmul::matmul(tmp.rb_mut(), Accum::Replace, p.rb(), beta.rb(), one::<T>(), par);
351 z!(&mut p, &r, &tmp).for_each(|uz!(p, r, tmp)| *p = *r - *tmp);
352 }
353 Err(BicgError::NoConvergence {
354 rel_residual: abs_residual / b_norm,
355 abs_residual,
356 })
357 }
358 implementation(out, &left_precond, &right_precond, &mat, rhs, params, &mut { callback }, par, stack)
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mat;
    use dyn_stack::MemBuffer;
    use equator::assert;
    use rand::prelude::*;

    /// Solves a small 2x2 system with two right-hand-side columns, using a random positive
    /// diagonal matrix as both preconditioners, and checks the residual against the
    /// requested relative tolerance.
    #[test]
    fn test_bicgstab() {
        let mut rng = StdRng::seed_from_u64(0);

        let system = mat![[2.5, -1.0], [1.0, 3.1]];
        let expected = mat![[2.1, 2.1], [4.1, 3.2]];
        let rhs = &system * &expected;

        // Random diagonal preconditioner with strictly positive entries.
        let mut precond = Mat::<f64>::identity(2, 2);
        for i in 0..2 {
            let sample: f64 = rand::distributions::Standard.sample(&mut rng);
            precond[(i, i)] = sample.exp();
        }
        let precond = precond;

        let mut solution = Mat::<f64>::zeros(2, expected.ncols());
        let mut params = BicgParams::default();
        params.max_iters = 10;

        let scratch = bicgstab_scratch(precond.as_ref(), precond.as_ref(), system.as_ref(), expected.ncols(), Par::Seq);
        let mut buffer = MemBuffer::new(scratch);

        let outcome = bicgstab(
            solution.as_mut(),
            precond.as_ref(),
            precond.as_ref(),
            system.as_ref(),
            rhs.as_ref(),
            params,
            |_| {},
            Par::Seq,
            MemStack::new(&mut buffer),
        );

        assert!(outcome.is_ok());

        let residual = &system * &solution - &rhs;
        assert!(residual.norm_l2() <= params.rel_tolerance * rhs.norm_l2());
    }
}