faer/operator/
lsmr.rs

1use super::*;
2use crate::{assert, debug_assert};
3use linalg::matmul::matmul;
4use linalg::{householder, qr};
5
6/// algorithm parameters
7#[derive(Copy, Clone, Debug)]
8pub struct LsmrParams<T> {
9	/// whether the initial guess is implicitly zero or not
10	pub initial_guess: InitialGuessStatus,
11	/// absolute tolerance for convergence testing
12	pub abs_tolerance: T,
13	/// relative tolerance for convergence testing
14	pub rel_tolerance: T,
15	/// maximum number of iterations
16	pub max_iters: usize,
17
18	#[doc(hidden)]
19	pub non_exhaustive: NonExhaustive,
20}
21
22impl<T: RealField> Default for LsmrParams<T> {
23	#[inline]
24	fn default() -> Self {
25		Self {
26			initial_guess: InitialGuessStatus::MaybeNonZero,
27			abs_tolerance: zero(),
28			rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
29			max_iters: usize::MAX,
30			non_exhaustive: NonExhaustive(()),
31		}
32	}
33}
34
35/// algorithm result
36#[derive(Copy, Clone, Debug)]
37pub struct LsmrInfo<T> {
38	/// absolute residual at the final step
39	pub abs_residual: T,
40	/// relative residual at the final step
41	pub rel_residual: T,
42	/// number of iterations executed by the algorithm
43	pub iter_count: usize,
44
45	#[doc(hidden)]
46	#[doc(hidden)]
47	pub non_exhaustive: NonExhaustive,
48}
49
/// failure modes of [`lsmr`]
#[derive(Copy, Clone, Debug)]
pub enum LsmrError<T> {
	/// the iteration limit was reached before the convergence test succeeded
	NoConvergence {
		/// absolute residual estimate after the last iteration
		abs_residual: T,
		/// residual estimate after the last iteration, relative to the reference norm
		rel_residual: T,
	},
}
61
62/// computes the size and alignment of required workspace for executing the lsmr
63/// algorithm
64pub fn lsmr_scratch<T: ComplexField>(right_precond: impl BiPrecond<T>, mat: impl BiLinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
65	fn implementation<T: ComplexField>(M: &dyn BiPrecond<T>, A: &dyn BiLinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
66		let m = A.nrows();
67		let n = A.ncols();
68		let mut k = rhs_ncols;
69
70		assert!(k < isize::MAX as usize);
71		if k > n {
72			k = k.msrv_checked_next_multiple_of(n).unwrap();
73		}
74		assert!(k < isize::MAX as usize);
75
76		let s = Ord::min(k, Ord::min(n, m));
77
78		let mk = temp_mat_scratch::<T>(m, k);
79		let nk = temp_mat_scratch::<T>(n, k);
80		let ss = temp_mat_scratch::<T>(s, s);
81		let ss2 = temp_mat_scratch::<T>(2 * s, 2 * s);
82		let sk = temp_mat_scratch::<T>(s, k);
83		let sk2 = temp_mat_scratch::<T>(2 * s, 2 * k);
84
85		let ms_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(m, s);
86		let ns_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(n, s);
87		let ss_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(2 * s, 2 * s);
88
89		let AT = A.transpose_apply_scratch(k, par);
90		let A = A.apply_scratch(k, par);
91		let MT = M.transpose_apply_in_place_scratch(k, par);
92		let M = M.apply_in_place_scratch(k, par);
93
94		let m_qr = StackReq::any_of(&[
95			temp_mat_scratch::<T>(ms_bs, s),
96			qr::no_pivoting::factor::qr_in_place_scratch::<T>(m, s, ms_bs, par, Default::default()),
97			householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(m, ms_bs, s),
98		]);
99
100		let n_qr = StackReq::any_of(&[
101			temp_mat_scratch::<T>(ns_bs, s),
102			qr::no_pivoting::factor::qr_in_place_scratch::<T>(n, s, ns_bs, par, Default::default()),
103			householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(n, ns_bs, s),
104		]);
105
106		let s_qr = StackReq::any_of(&[
107			temp_mat_scratch::<T>(ss_bs, s),
108			qr::no_pivoting::factor::qr_in_place_scratch::<T>(2 * s, 2 * s, ss_bs, par, Default::default()),
109			householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(2 * s, ss_bs, 2 * s),
110		]);
111
112		StackReq::all_of(&[
113			mk,  // u
114			nk,  // v
115			sk,  // beta
116			sk,  // alpha
117			sk,  // zetabar
118			sk,  // alphabar
119			sk,  // theta
120			sk2, // pbar_adjoint
121			nk,  // vold
122			StackReq::any_of(&[StackReq::all_of(&[mk, StackReq::any_of(&[A, M, m_qr])])]),
123			StackReq::any_of(&[StackReq::all_of(&[nk, StackReq::any_of(&[AT, MT, n_qr])])]),
124			ss2, // p_adjoint
125			ss,  // rho
126			ss,  // thetaold
127			ss,  // rhobar
128			ss,  // thetabar
129			ss,  // zeta
130			ss,  // zetabar
131			StackReq::all_of(&[temp_mat_scratch::<T>(2 * s, 2 * s), s_qr]),
132		])
133	}
134
135	implementation(&right_precond, &mat, rhs_ncols, par)
136}
137
138/// executes lsmr using the provided preconditioner
139///
140/// # note
141/// this function is also optimized for a rhs with multiple columns
142#[track_caller]
143pub fn lsmr<T: ComplexField>(
144	out: MatMut<'_, T>,
145	right_precond: impl BiPrecond<T>,
146	mat: impl BiLinOp<T>,
147	rhs: MatRef<'_, T>,
148	params: LsmrParams<T::Real>,
149	callback: impl FnMut(MatRef<'_, T>),
150	par: Par,
151	stack: &mut MemStack,
152) -> Result<LsmrInfo<T::Real>, LsmrError<T::Real>> {
153	#[track_caller]
154	#[math]
155	fn implementation<T: ComplexField>(
156		mut x: MatMut<'_, T>,
157		M: &impl BiPrecond<T>,
158		A: &impl BiLinOp<T>,
159		b: MatRef<'_, T>,
160		params: LsmrParams<T::Real>,
161		callback: &mut dyn FnMut(MatRef<'_, T>),
162		par: Par,
163		stack: &mut MemStack,
164	) -> Result<LsmrInfo<T::Real>, LsmrError<T::Real>> {
165		fn thin_qr<T: ComplexField>(mut Q: MatMut<'_, T>, mut R: MatMut<'_, T>, mut mat: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
166			let k = R.nrows();
167			let bs = qr::no_pivoting::factor::recommended_blocksize::<T>(mat.nrows(), mat.ncols());
168			let (mut house, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, Ord::min(mat.nrows(), mat.ncols()), stack) };
169			let mut house = house.as_mat_mut();
170
171			qr::no_pivoting::factor::qr_in_place(mat.rb_mut(), house.rb_mut(), par, stack.rb_mut(), Default::default());
172
173			R.fill(zero());
174			R.copy_from_triangular_upper(mat.rb().get(..k, ..k));
175			Q.fill(zero());
176			Q.rb_mut().diagonal_mut().column_vector_mut().fill(one::<T>());
177			householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
178				mat.rb(),
179				house.rb(),
180				Conj::No,
181				Q.rb_mut(),
182				par,
183				stack.rb_mut(),
184			);
185		}
186
187		let m = A.nrows();
188		let n = A.ncols();
189		let mut k = b.ncols();
190		{
191			let out = x.rb();
192			let mat = A;
193			let right_precond = M;
194			let rhs = b;
195			assert!(all(
196				right_precond.nrows() == mat.ncols(),
197				right_precond.ncols() == mat.ncols(),
198				rhs.nrows() == mat.nrows(),
199				out.nrows() == mat.ncols(),
200				out.ncols() == rhs.ncols(),
201			));
202		}
203
204		if m == 0 || n == 0 || k == 0 || core::mem::size_of::<T::Unit>() == 0 {
205			x.fill(zero());
206			return Ok(LsmrInfo {
207				abs_residual: zero::<T::Real>(),
208				rel_residual: zero::<T::Real>(),
209				iter_count: 0,
210				non_exhaustive: NonExhaustive(()),
211			});
212		}
213
214		debug_assert!(all(m < isize::MAX as usize, n < isize::MAX as usize, k < isize::MAX as usize));
215		let actual_k = k;
216		if k > n {
217			// pad to avoid last block slowing down the rest
218			k = k.msrv_checked_next_multiple_of(n).unwrap();
219		}
220		debug_assert!(k < isize::MAX as usize);
221
222		let s = Ord::min(k, Ord::min(n, m));
223
224		let mut stack = stack;
225
226		let (mut u, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
227		let mut u = u.as_mat_mut();
228		let (mut beta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
229		let mut beta = beta.as_mat_mut();
230
231		let (mut v, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
232		let mut v = v.as_mat_mut();
233		let (mut alpha, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
234		let mut alpha = alpha.as_mat_mut();
235
236		let (mut zetabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
237		let mut zetabar = zetabar.as_mat_mut();
238		let (mut alphabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
239		let mut alphabar = alphabar.as_mat_mut();
240		let (mut theta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
241		let mut theta = theta.as_mat_mut();
242		let (mut pbar_adjoint, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, 2 * k, stack.rb_mut()) };
243		let mut pbar_adjoint = pbar_adjoint.as_mat_mut();
244
245		let (mut w, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
246		let mut w = w.as_mat_mut();
247		let (mut wbar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
248		let mut wbar = wbar.as_mat_mut();
249
250		{
251			let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
252			let mut qr = qr.as_mat_mut();
253			if params.initial_guess == InitialGuessStatus::Zero {
254				qr.rb_mut().get_mut(.., ..actual_k).copy_from(b);
255				qr.rb_mut().get_mut(.., actual_k..).fill(zero());
256			} else {
257				A.apply(qr.rb_mut().rb_mut().get_mut(.., ..actual_k), x.rb(), par, stack.rb_mut());
258				z!(qr.rb_mut().get_mut(.., ..actual_k), &b).for_each(|uz!(ax, b)| *ax = *b - *ax);
259				qr.rb_mut().get_mut(.., actual_k..).fill(zero());
260			}
261			let mut start = 0;
262			while start < k {
263				let end = Ord::min(k - start, s) + start;
264				let len = end - start;
265				thin_qr(
266					u.rb_mut().get_mut(.., start..end),
267					beta.rb_mut().get_mut(..len, start..end),
268					qr.rb_mut().get_mut(.., start..end),
269					par,
270					stack.rb_mut(),
271				);
272				start = end;
273			}
274		}
275
276		{
277			let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
278			let mut qr = qr.as_mat_mut();
279			A.adjoint_apply(qr.rb_mut(), u.rb(), par, stack.rb_mut());
280			M.adjoint_apply_in_place(qr.rb_mut(), par, stack.rb_mut());
281			let mut start = 0;
282			while start < k {
283				let end = Ord::min(k - start, s) + start;
284				let len = end - start;
285				thin_qr(
286					v.rb_mut().get_mut(.., start..end),
287					alpha.rb_mut().get_mut(..len, start..end),
288					qr.rb_mut().get_mut(.., start..end),
289					par,
290					stack.rb_mut(),
291				);
292				start = end;
293			}
294		}
295
296		zetabar.fill(zero());
297		let mut start = 0;
298		while start < k {
299			let end = Ord::min(k - start, s) + start;
300			let len = end - start;
301			matmul(
302				zetabar.rb_mut().get_mut(..len, start..end),
303				Accum::Replace,
304				alpha.rb().get(..len, start..end),
305				beta.rb().get(..len, start..end),
306				one::<T>(),
307				par,
308			);
309			start = end;
310		}
311		alphabar.copy_from(&alpha);
312		pbar_adjoint.fill(zero());
313		let mut start = 0;
314		while start < k {
315			let end = Ord::min(k - start, s) + start;
316			let len = end - start;
317			pbar_adjoint
318				.rb_mut()
319				.get_mut(..2 * len, 2 * start..2 * end)
320				.diagonal_mut()
321				.column_vector_mut()
322				.fill(one());
323			start = end;
324		}
325		theta.fill(zero());
326		w.fill(zero());
327		wbar.fill(zero());
328
329		let mut norm;
330		let norm_ref = if params.initial_guess == InitialGuessStatus::Zero {
331			norm = zetabar.norm_l2();
332			copy(norm)
333		} else {
334			norm = zetabar.norm_l2();
335			let (mut tmp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, actual_k, stack.rb_mut()) };
336			let mut tmp = tmp.as_mat_mut();
337			A.adjoint_apply(tmp.rb_mut(), b, par, stack.rb_mut());
338			M.adjoint_apply_in_place(tmp.rb_mut(), par, stack.rb_mut());
339			tmp.norm_l2()
340		};
341		let threshold = norm_ref * params.rel_tolerance;
342
343		if norm_ref == zero::<T::Real>() {
344			x.fill(zero());
345			return Ok(LsmrInfo {
346				abs_residual: zero::<T::Real>(),
347				rel_residual: zero::<T::Real>(),
348				iter_count: 0,
349				non_exhaustive: NonExhaustive(()),
350			});
351		}
352
353		if norm <= threshold {
354			return Ok(LsmrInfo {
355				abs_residual: zero::<T::Real>(),
356				rel_residual: zero::<T::Real>(),
357				iter_count: 0,
358				non_exhaustive: NonExhaustive(()),
359			});
360		}
361
362		for iter in 0..params.max_iters {
363			let (mut vold, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
364			let mut vold = vold.as_mat_mut();
365			{
366				let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
367				let mut qr = qr.as_mat_mut();
368				vold.copy_from(&v);
369				M.apply_in_place(v.rb_mut(), par, stack.rb_mut());
370				A.apply(qr.rb_mut(), v.rb(), par, stack.rb_mut());
371
372				let mut start = 0;
373				while start < k {
374					let s = Ord::min(k - start, s);
375					let end = start + s;
376					matmul(
377						qr.rb_mut().get_mut(.., start..end),
378						Accum::Add,
379						u.rb().get(.., start..end),
380						alpha.rb().get(..s, start..end).adjoint(),
381						-one::<T>(),
382						par,
383					);
384					thin_qr(
385						u.rb_mut().get_mut(.., start..end),
386						beta.rb_mut().get_mut(..s, start..end),
387						qr.rb_mut().get_mut(.., start..end),
388						par,
389						stack.rb_mut(),
390					);
391					start = end;
392				}
393			}
394
395			{
396				let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
397				let mut qr = qr.as_mat_mut();
398				A.adjoint_apply(qr.rb_mut(), u.rb(), par, stack.rb_mut());
399				M.adjoint_apply_in_place(qr.rb_mut(), par, stack.rb_mut());
400
401				let mut start = 0;
402				while start < k {
403					let s = Ord::min(k - start, s);
404					let end = start + s;
405					matmul(
406						qr.rb_mut().get_mut(.., start..end),
407						Accum::Add,
408						vold.rb().get(.., start..end),
409						beta.rb().get(..s, start..end).adjoint(),
410						-one::<T>(),
411						par,
412					);
413
414					// now contains M v_old
415					vold.rb_mut().get_mut(.., start..end).copy_from(v.rb().get(.., start..end));
416
417					thin_qr(
418						v.rb_mut().get_mut(.., start..end),
419						alpha.rb_mut().get_mut(..s, start..end),
420						qr.rb_mut().get_mut(.., start..end),
421						par,
422						stack.rb_mut(),
423					);
424					start = end;
425				}
426			}
427
428			let mut Mvold = vold;
429
430			let mut start = 0;
431			while start < k {
432				let s = Ord::min(k - start, s);
433				let end = start + s;
434
435				let mut x = x.rb_mut().get_mut(.., start..Ord::min(actual_k, end));
436				let mut Mvold = Mvold.rb_mut().get_mut(.., start..end);
437				let mut w = w.rb_mut().get_mut(.., start..end);
438				let mut wbar = wbar.rb_mut().get_mut(.., start..end);
439
440				let alpha = alpha.rb_mut().get_mut(..s, start..end);
441				let beta = beta.rb_mut().get_mut(..s, start..end);
442				let mut zetabar = zetabar.rb_mut().get_mut(..s, start..end);
443				let mut alphabar = alphabar.rb_mut().get_mut(..s, start..end);
444				let mut theta = theta.rb_mut().get_mut(..s, start..end);
445				let mut pbar_adjoint = pbar_adjoint.rb_mut().get_mut(..2 * s, 2 * start..2 * end);
446
447				let (mut p_adjoint, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, 2 * s, stack.rb_mut()) };
448				let mut p_adjoint = p_adjoint.as_mat_mut();
449
450				let (mut rho, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
451				let mut rho = rho.as_mat_mut();
452				let (mut thetaold, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
453				let mut thetaold = thetaold.as_mat_mut();
454				let (mut rhobar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
455				let mut rhobar = rhobar.as_mat_mut();
456				let (mut thetabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
457				let mut thetabar = thetabar.as_mat_mut();
458				let (mut zeta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
459				let mut zeta = zeta.as_mat_mut();
460				let (mut zetabar_tmp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
461				let mut zetabar_tmp = zetabar_tmp.as_mat_mut();
462
463				{
464					let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, s, stack.rb_mut()) };
465					let mut qr = qr.as_mat_mut();
466					qr.rb_mut().get_mut(..s, ..).copy_from(alphabar.rb().adjoint());
467					qr.rb_mut().get_mut(s.., ..).copy_from(&beta);
468					thin_qr(p_adjoint.rb_mut(), rho.rb_mut(), qr.rb_mut(), par, stack.rb_mut());
469				}
470
471				thetaold.copy_from(&theta);
472				matmul(theta.rb_mut(), Accum::Replace, alpha.rb(), p_adjoint.rb().get(s.., ..s), one::<T>(), par);
473				matmul(
474					alphabar.rb_mut(),
475					Accum::Replace,
476					alpha.rb(),
477					p_adjoint.rb().get(s.., s..),
478					one::<T>(),
479					par,
480				);
481
482				matmul(
483					thetabar.rb_mut(),
484					Accum::Replace,
485					rho.rb(),
486					pbar_adjoint.rb().get(s.., ..s),
487					one::<T>(),
488					par,
489				);
490				{
491					let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, s, stack.rb_mut()) };
492					let mut qr = qr.as_mat_mut();
493					matmul(
494						qr.rb_mut().get_mut(..s, ..),
495						Accum::Replace,
496						pbar_adjoint.rb().adjoint().get(s.., s..),
497						rho.rb().adjoint(),
498						one::<T>(),
499						par,
500					);
501					qr.rb_mut().get_mut(s.., ..).copy_from(&theta);
502					thin_qr(pbar_adjoint.rb_mut(), rhobar.rb_mut(), qr.rb_mut(), par, stack.rb_mut());
503				}
504
505				matmul(
506					zeta.rb_mut(),
507					Accum::Replace,
508					pbar_adjoint.rb().adjoint().get(..s, ..s),
509					zetabar.rb(),
510					one::<T>(),
511					par,
512				);
513				matmul(
514					zetabar_tmp.rb_mut(),
515					Accum::Replace,
516					pbar_adjoint.rb().adjoint().get(s.., ..s),
517					zetabar.rb(),
518					one::<T>(),
519					par,
520				);
521				zetabar.copy_from(&zetabar_tmp);
522
523				matmul(Mvold.rb_mut(), Accum::Add, w.rb(), thetaold.rb().adjoint(), -one::<T>(), par);
524				crate::linalg::triangular_solve::solve_lower_triangular_in_place(rho.rb().transpose(), Mvold.rb_mut().transpose_mut(), par);
525				w.copy_from(&Mvold);
526
527				matmul(Mvold.rb_mut(), Accum::Add, wbar.rb(), thetabar.rb().adjoint(), -one::<T>(), par);
528				crate::linalg::triangular_solve::solve_lower_triangular_in_place(rhobar.rb().transpose(), Mvold.rb_mut().transpose_mut(), par);
529				wbar.copy_from(&Mvold);
530
531				let actual_s = x.ncols();
532				matmul(
533					x.rb_mut(),
534					if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
535						Accum::Replace
536					} else {
537						Accum::Add
538					},
539					wbar.rb(),
540					zeta.rb().get(.., ..actual_s),
541					one::<T>(),
542					par,
543				);
544				start = end;
545			}
546			norm = zetabar.norm_l2();
547			callback(x.rb());
548			if norm <= threshold {
549				return Ok(LsmrInfo {
550					rel_residual: norm / norm_ref,
551					abs_residual: norm,
552					iter_count: iter + 1,
553					non_exhaustive: NonExhaustive(()),
554				});
555			}
556		}
557
558		Err(LsmrError::NoConvergence {
559			rel_residual: norm / norm_ref,
560			abs_residual: norm,
561		})
562	}
563	implementation(out, &right_precond, &mat, rhs, params, &mut { callback }, par, stack)
564}
565
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use dyn_stack::MemBuffer;
	use equator::assert;

	/// samples an `nrows × ncols` matrix with independent standard complex normal entries
	fn gaussian_mat(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.sample(rng)
	}

	/// builds a random lower bidiagonal `n × n` preconditioner with a dominant diagonal
	fn bidiagonal_precond(rng: &mut StdRng, n: usize) -> Mat<c64> {
		let mut precond = Scale(c64::new(2.0, 0.0)) * Mat::<c64>::identity(n, n);
		for i in 0..n {
			precond[(i, i)] = (128.0 * f64::exp(rand::distributions::Standard.sample(rng))).into();
		}
		for i in 1..n {
			precond[(i, i - 1)] = f64::exp(rand::distributions::Standard.sample(rng)).into();
		}
		precond
	}

	/// runs lsmr on a random `m × n` problem with `k` right-hand sides (duplicated side by side
	/// when `duplicate_rhs` is set, which makes the block process break down) and checks that
	/// it converges within the expected number of block iterations
	fn check_lsmr(rng: &mut StdRng, m: usize, n: usize, k: usize, duplicate_rhs: bool) {
		let A = gaussian_mat(rng, m, n);
		let b = gaussian_mat(rng, m, k);
		let b = if duplicate_rhs { crate::concat![[b, b]] } else { b };
		let k = b.ncols();

		let precond = bidiagonal_precond(rng, n);
		let mut out = gaussian_mat(rng, n, k);

		let scratch = lsmr_scratch(precond.as_ref(), A.as_ref(), k, Par::Seq);
		let result = lsmr(
			out.as_mut(),
			precond.as_ref(),
			A.as_ref(),
			b.as_ref(),
			LsmrParams::default(),
			|_| {},
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(scratch)),
		);
		assert!(result.is_ok());
		let info = result.unwrap();
		assert!(info.iter_count <= (4 * n).msrv_div_ceil(Ord::min(k, n)));
	}

	#[test]
	fn test_lsmr() {
		let rng = &mut StdRng::seed_from_u64(0);
		for k in [1, 2, 4, 7, 10, 40, 80, 100] {
			check_lsmr(rng, 100, 80, k, false);
		}
	}

	#[test]
	fn test_breakdown() {
		let rng = &mut StdRng::seed_from_u64(0);
		for k in [1, 2, 4, 7, 10, 40, 80, 100] {
			check_lsmr(rng, 100, 80, k, true);
		}
	}
}