faer/operator/
bicgstab.rs

1use super::*;
2use crate::assert;
3
4/// computes the size and alignment of required workspace for executing the bicgstab algorithm
5pub fn bicgstab_scratch<T: ComplexField>(
6	left_precond: impl Precond<T>,
7	right_precond: impl Precond<T>,
8	mat: impl LinOp<T>,
9	rhs_ncols: usize,
10	par: Par,
11) -> StackReq {
12	fn implementation<T: ComplexField>(K1: &dyn Precond<T>, K2: &dyn Precond<T>, A: &dyn LinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
13		let n = A.nrows();
14		let k = rhs_ncols;
15
16		let nk = temp_mat_scratch::<T>(n, k);
17		let kk = temp_mat_scratch::<T>(k, k);
18		let k_usize = StackReq::new::<usize>(k);
19		let lu = crate::linalg::lu::full_pivoting::factor::lu_in_place_scratch::<usize, T>(k, k, par, Default::default());
20		StackReq::all_of(&[
21			k_usize, // row_perm
22			k_usize, // row_perm_inv
23			k_usize, // col_perm
24			k_usize, // col_perm_inv
25			kk,      // rtv
26			nk,      // r
27			nk,      // p
28			nk,      // r_tilde
29			nk,      // v
30			nk,      // y
31			nk,      // s
32			nk,      // t
33			nk,      // z
34			StackReq::any_of(&[
35				lu,
36				A.apply_scratch(k, par),
37				StackReq::all_of(&[
38					nk, // y0 | z0 | ks
39					K1.apply_scratch(k, par),
40					K2.apply_scratch(k, par),
41				]),
42				StackReq::all_of(&[
43					kk, // rtr | rtt
44					kk, // temp
45				]),
46				kk, // rtr | rtt
47			]),
48		])
49	}
50	implementation(&left_precond, &right_precond, &mat, rhs_ncols, par)
51}
52
53/// algorithm parameters
54#[derive(Copy, Clone, Debug)]
55pub struct BicgParams<T> {
56	/// whether the initial guess is implicitly zero or not
57	pub initial_guess: InitialGuessStatus,
58	/// absolute tolerance for convergence testing
59	pub abs_tolerance: T,
60	/// relative tolerance for convergence testing
61	pub rel_tolerance: T,
62	/// maximum number of iterations
63	pub max_iters: usize,
64
65	#[doc(hidden)]
66	pub non_exhaustive: NonExhaustive,
67}
68
69impl<T: RealField> Default for BicgParams<T> {
70	#[inline]
71	#[math]
72	fn default() -> Self {
73		Self {
74			initial_guess: InitialGuessStatus::MaybeNonZero,
75			abs_tolerance: zero(),
76			rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
77			max_iters: usize::MAX,
78			non_exhaustive: NonExhaustive(()),
79		}
80	}
81}
82
83/// algorithm result
84#[derive(Copy, Clone, Debug)]
85pub struct BicgInfo<T> {
86	/// absolute residual at the final step
87	pub abs_residual: T,
88	/// relative residual at the final step
89	pub rel_residual: T,
90	/// number of iterations executed by the algorithm
91	pub iter_count: usize,
92
93	#[doc(hidden)]
94	pub non_exhaustive: NonExhaustive,
95}
96
/// algorithm error
#[derive(Copy, Clone, Debug)]
pub enum BicgError<T> {
	/// convergence failure: the iteration limit was reached before the residual dropped below
	/// the requested tolerance
	NoConvergence {
		/// absolute residual reported when the solver gave up
		abs_residual: T,
		/// relative residual reported when the solver gave up, scaled by the norm of the rhs
		rel_residual: T,
	},
}
108
109/// executes bicgstab using the provided preconditioners
110///
111/// # note
112/// this function is also optimized for a rhs with multiple columns
113#[track_caller]
114pub fn bicgstab<T: ComplexField>(
115	out: MatMut<'_, T>,
116	left_precond: impl Precond<T>,
117	right_precond: impl Precond<T>,
118	mat: impl LinOp<T>,
119	rhs: MatRef<'_, T>,
120	params: BicgParams<T::Real>,
121	callback: impl FnMut(MatRef<'_, T>),
122	par: Par,
123	stack: &mut MemStack,
124) -> Result<BicgInfo<T::Real>, BicgError<T::Real>> {
125	#[track_caller]
126	#[math]
127	fn implementation<T: ComplexField>(
128		out: MatMut<'_, T>,
129		left_precond: &dyn Precond<T>,
130		right_precond: &dyn Precond<T>,
131		mat: &dyn LinOp<T>,
132		rhs: MatRef<'_, T>,
133		params: BicgParams<T::Real>,
134		callback: &mut dyn FnMut(MatRef<'_, T>),
135		par: Par,
136		stack: &mut MemStack,
137	) -> Result<BicgInfo<T::Real>, BicgError<T::Real>> {
138		let mut x = out;
139		let A = mat;
140		let K1 = left_precond;
141		let K2 = right_precond;
142		let b = rhs;
143
144		assert!(A.nrows() == A.ncols());
145		let n = A.nrows();
146		let k = x.ncols();
147
148		let b_norm = b.norm_l2();
149		if b_norm == zero::<T::Real>() {
150			x.fill(zero());
151			return Ok(BicgInfo {
152				abs_residual: zero::<T::Real>(),
153				rel_residual: zero::<T::Real>(),
154				iter_count: 0,
155				non_exhaustive: NonExhaustive(()),
156			});
157		}
158
159		let rel_threshold = params.rel_tolerance * b_norm;
160		let abs_threshold = params.abs_tolerance;
161		let threshold = if abs_threshold > rel_threshold { abs_threshold } else { rel_threshold };
162
163		let (row_perm, stack) = unsafe { stack.make_raw::<usize>(k) };
164		let (row_perm_inv, stack) = unsafe { stack.make_raw::<usize>(k) };
165		let (col_perm, stack) = unsafe { stack.make_raw::<usize>(k) };
166		let (col_perm_inv, stack) = unsafe { stack.make_raw::<usize>(k) };
167		let (mut rtv, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
168		let mut rtv = rtv.as_mat_mut();
169		let (mut r, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
170		let mut r = r.as_mat_mut();
171		let (mut p, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
172		let mut p = p.as_mat_mut();
173		let (mut r_tilde, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
174		let mut r_tilde = r_tilde.as_mat_mut();
175
176		let abs_residual = if params.initial_guess == InitialGuessStatus::MaybeNonZero {
177			A.apply(r.rb_mut(), x.rb(), par, stack);
178			z!(&mut r, &b).for_each(|uz!(r, b)| *r = *b - *r);
179
180			r.norm_l2()
181		} else {
182			copy(b_norm)
183		};
184
185		if abs_residual < threshold {
186			return Ok(BicgInfo {
187				rel_residual: abs_residual / b_norm,
188				abs_residual,
189				iter_count: 0,
190				non_exhaustive: NonExhaustive(()),
191			});
192		}
193
194		p.copy_from(&r);
195		r_tilde.copy_from(&r);
196
197		for iter in 0..params.max_iters {
198			let (mut v, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
199			let mut v = v.as_mat_mut();
200			let (mut y, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
201			let mut y = y.as_mat_mut();
202			{
203				let (mut y0, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
204				let mut y0 = y0.as_mat_mut();
205				K1.apply(y0.rb_mut(), p.rb(), par, stack);
206				K2.apply(y.rb_mut(), y0.rb(), par, stack);
207			}
208			A.apply(v.rb_mut(), y.rb(), par, stack);
209
210			crate::linalg::matmul::matmul(rtv.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), v.rb(), one::<T>(), par);
211			let (_, row_perm, col_perm) = crate::linalg::lu::full_pivoting::factor::lu_in_place(
212				rtv.rb_mut(),
213				row_perm,
214				row_perm_inv,
215				col_perm,
216				col_perm_inv,
217				par,
218				stack,
219				Default::default(),
220			);
221			let mut rank = k;
222			let tol = eps::<T::Real>() * from_f64::<T::Real>(k as f64) * abs(rtv[(0, 0)]);
223			for i in 0..k {
224				if abs(rtv[(i, i)]) < tol {
225					rank = i;
226					break;
227				}
228			}
229
230			let (mut s, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
231			let mut s = s.as_mat_mut();
232			{
233				let (mut rtr, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
234				let mut rtr = rtr.as_mat_mut();
235				crate::linalg::matmul::matmul(rtr.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), r.rb(), one::<T>(), par);
236				let (mut temp, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
237				let mut temp = temp.as_mat_mut();
238				crate::perm::permute_rows(temp.rb_mut(), rtr.rb(), row_perm);
239				crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
240					rtv.rb().get(..rank, ..rank),
241					temp.rb_mut().get_mut(..rank, ..),
242					par,
243				);
244				crate::linalg::triangular_solve::solve_upper_triangular_in_place(
245					rtv.rb().get(..rank, ..rank),
246					temp.rb_mut().get_mut(..rank, ..),
247					par,
248				);
249				temp.rb_mut().get_mut(rank.., ..).fill(zero());
250				crate::perm::permute_rows(rtr.rb_mut(), temp.rb(), col_perm.inverse());
251				let alpha = rtr.rb();
252
253				s.copy_from(&r);
254				crate::linalg::matmul::matmul(s.rb_mut(), Accum::Add, v.rb(), alpha.rb(), -one::<T>(), par);
255				crate::linalg::matmul::matmul(
256					x.rb_mut(),
257					if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
258						Accum::Replace
259					} else {
260						Accum::Add
261					},
262					y.rb(),
263					alpha.rb(),
264					one::<T>(),
265					par,
266				);
267			}
268			let norm = s.norm_l2();
269			if norm < threshold {
270				return Ok(BicgInfo {
271					rel_residual: norm / b_norm,
272					abs_residual: norm,
273					iter_count: iter + 1,
274					non_exhaustive: NonExhaustive(()),
275				});
276			}
277
278			let (mut t, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
279			let mut t = t.as_mat_mut();
280			let (mut z, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
281			let mut z = z.as_mat_mut();
282			{
283				let (mut z0, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
284				let mut z0 = z0.as_mat_mut();
285				K1.apply(z0.rb_mut(), s.rb(), par, stack);
286				K2.apply(z.rb_mut(), z0.rb(), par, stack);
287			}
288			A.apply(t.rb_mut(), z.rb(), par, stack);
289
290			let compute_w = |kt: MatRef<'_, T>, ks: MatRef<'_, T>| {
291				let mut wt = zero::<T>();
292				let mut ws = zero::<T>();
293				for j in 0..k {
294					let kt = kt.rb().col(j);
295					let ks = ks.rb().col(j);
296					ws = ws + kt.transpose() * ks;
297					wt = wt + kt.transpose() * kt;
298				}
299				recip(wt) * ws
300			};
301
302			let w = {
303				let mut kt = y;
304				let (mut ks, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
305				let mut ks = ks.as_mat_mut();
306				K1.apply(kt.rb_mut(), t.rb(), par, stack);
307				K1.apply(ks.rb_mut(), s.rb(), par, stack);
308				compute_w(kt.rb(), ks.rb())
309			};
310
311			z!(&mut r, &s, &t).for_each(|uz!(r, s, t)| *r = *s - w * *t);
312			z!(&mut x, &z).for_each(|uz!(x, z)| *x = *x + w * *z);
313			z!(&mut p, &v).for_each(|uz!(p, v)| *p = *p - w * *v);
314
315			callback(x.rb());
316
317			let norm = r.norm_l2();
318			if norm < threshold {
319				return Ok(BicgInfo {
320					rel_residual: norm / b_norm,
321					abs_residual: norm,
322					iter_count: iter + 1,
323					non_exhaustive: NonExhaustive(()),
324				});
325			}
326
327			let (mut rtt, stack) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
328			let mut rtt = rtt.as_mat_mut();
329			{
330				crate::linalg::matmul::matmul(rtt.rb_mut(), Accum::Replace, r_tilde.rb().transpose(), t.rb(), one::<T>(), par);
331				let (mut temp, _) = unsafe { temp_mat_uninit::<T, _, _>(k, k, stack) };
332				let mut temp = temp.as_mat_mut();
333				crate::perm::permute_rows(temp.rb_mut(), rtt.rb(), row_perm);
334				crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
335					rtv.rb().get(..rank, ..rank),
336					temp.rb_mut().get_mut(..rank, ..),
337					par,
338				);
339				crate::linalg::triangular_solve::solve_upper_triangular_in_place(
340					rtv.rb().get(..rank, ..rank),
341					temp.rb_mut().get_mut(..rank, ..),
342					par,
343				);
344				temp.rb_mut().get_mut(rank.., ..).fill(zero());
345				crate::perm::permute_rows(rtt.rb_mut(), temp.rb(), col_perm.inverse());
346			}
347
348			let beta = rtt.rb();
349			let mut tmp = v;
350			crate::linalg::matmul::matmul(tmp.rb_mut(), Accum::Replace, p.rb(), beta.rb(), one::<T>(), par);
351			z!(&mut p, &r, &tmp).for_each(|uz!(p, r, tmp)| *p = *r - *tmp);
352		}
353		Err(BicgError::NoConvergence {
354			rel_residual: abs_residual / b_norm,
355			abs_residual,
356		})
357	}
358	implementation(out, &left_precond, &right_precond, &mat, rhs, params, &mut { callback }, par, stack)
359}
360
#[cfg(test)]
mod tests {
	use super::*;
	use crate::mat;
	use dyn_stack::MemBuffer;
	use equator::assert;
	use rand::prelude::*;

	/// solves a 2x2 system with two right-hand sides, using the same random positive diagonal
	/// matrix as both preconditioners, and checks the residual against the relative tolerance
	#[test]
	fn test_bicgstab() {
		let mut rng = StdRng::seed_from_u64(0);

		let A = mat![[2.5, -1.0], [1.0, 3.1]];
		let sol = mat![[2.1, 2.1], [4.1, 3.2]];
		let rhs = &A * &sol;

		// diagonal entries are exponentials of random samples, hence strictly positive
		let mut diag = Mat::<f64>::identity(2, 2);
		for i in 0..2 {
			diag[(i, i)] = f64::exp(rand::distributions::Standard.sample(&mut rng));
		}

		let mut params = BicgParams::default();
		params.max_iters = 10;

		let scratch = bicgstab_scratch(diag.as_ref(), diag.as_ref(), A.as_ref(), sol.ncols(), Par::Seq);
		let mut buffer = MemBuffer::new(scratch);

		let mut out = Mat::<f64>::zeros(2, sol.ncols());
		let result = bicgstab(
			out.as_mut(),
			diag.as_ref(),
			diag.as_ref(),
			A.as_ref(),
			rhs.as_ref(),
			params,
			|_| {},
			Par::Seq,
			MemStack::new(&mut buffer),
		);

		assert!(result.is_ok());
		assert!((&A * &out - &rhs).norm_l2() <= params.rel_tolerance * rhs.norm_l2());
	}
}