1use super::*;
2use crate::{assert, debug_assert};
3use linalg::matmul::matmul;
4use linalg::{householder, qr};
5
6#[derive(Copy, Clone, Debug)]
8pub struct LsmrParams<T> {
9 pub initial_guess: InitialGuessStatus,
11 pub abs_tolerance: T,
13 pub rel_tolerance: T,
15 pub max_iters: usize,
17
18 #[doc(hidden)]
19 pub non_exhaustive: NonExhaustive,
20}
21
22impl<T: RealField> Default for LsmrParams<T> {
23 #[inline]
24 fn default() -> Self {
25 Self {
26 initial_guess: InitialGuessStatus::MaybeNonZero,
27 abs_tolerance: zero(),
28 rel_tolerance: eps::<T>() * from_f64::<T>(128.0),
29 max_iters: usize::MAX,
30 non_exhaustive: NonExhaustive(()),
31 }
32 }
33}
34
35#[derive(Copy, Clone, Debug)]
37pub struct LsmrInfo<T> {
38 pub abs_residual: T,
40 pub rel_residual: T,
42 pub iter_count: usize,
44
45 #[doc(hidden)]
46 #[doc(hidden)]
47 pub non_exhaustive: NonExhaustive,
48}
49
/// Failure modes of [`lsmr`].
#[derive(Copy, Clone, Debug)]
pub enum LsmrError<T> {
    /// The iteration limit was reached before the residual estimate met the threshold.
    NoConvergence {
        /// Residual estimate when the solver gave up.
        abs_residual: T,
        /// Residual estimate divided by the reference norm when the solver gave up.
        rel_residual: T,
    },
}
61
62pub fn lsmr_scratch<T: ComplexField>(right_precond: impl BiPrecond<T>, mat: impl BiLinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
65 fn implementation<T: ComplexField>(M: &dyn BiPrecond<T>, A: &dyn BiLinOp<T>, rhs_ncols: usize, par: Par) -> StackReq {
66 let m = A.nrows();
67 let n = A.ncols();
68 let mut k = rhs_ncols;
69
70 assert!(k < isize::MAX as usize);
71 if k > n {
72 k = k.msrv_checked_next_multiple_of(n).unwrap();
73 }
74 assert!(k < isize::MAX as usize);
75
76 let s = Ord::min(k, Ord::min(n, m));
77
78 let mk = temp_mat_scratch::<T>(m, k);
79 let nk = temp_mat_scratch::<T>(n, k);
80 let ss = temp_mat_scratch::<T>(s, s);
81 let ss2 = temp_mat_scratch::<T>(2 * s, 2 * s);
82 let sk = temp_mat_scratch::<T>(s, k);
83 let sk2 = temp_mat_scratch::<T>(2 * s, 2 * k);
84
85 let ms_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(m, s);
86 let ns_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(n, s);
87 let ss_bs = qr::no_pivoting::factor::recommended_blocksize::<T>(2 * s, 2 * s);
88
89 let AT = A.transpose_apply_scratch(k, par);
90 let A = A.apply_scratch(k, par);
91 let MT = M.transpose_apply_in_place_scratch(k, par);
92 let M = M.apply_in_place_scratch(k, par);
93
94 let m_qr = StackReq::any_of(&[
95 temp_mat_scratch::<T>(ms_bs, s),
96 qr::no_pivoting::factor::qr_in_place_scratch::<T>(m, s, ms_bs, par, Default::default()),
97 householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(m, ms_bs, s),
98 ]);
99
100 let n_qr = StackReq::any_of(&[
101 temp_mat_scratch::<T>(ns_bs, s),
102 qr::no_pivoting::factor::qr_in_place_scratch::<T>(n, s, ns_bs, par, Default::default()),
103 householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(n, ns_bs, s),
104 ]);
105
106 let s_qr = StackReq::any_of(&[
107 temp_mat_scratch::<T>(ss_bs, s),
108 qr::no_pivoting::factor::qr_in_place_scratch::<T>(2 * s, 2 * s, ss_bs, par, Default::default()),
109 householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(2 * s, ss_bs, 2 * s),
110 ]);
111
112 StackReq::all_of(&[
113 mk, nk, sk, sk, sk, sk, sk, sk2, nk, StackReq::any_of(&[StackReq::all_of(&[mk, StackReq::any_of(&[A, M, m_qr])])]),
123 StackReq::any_of(&[StackReq::all_of(&[nk, StackReq::any_of(&[AT, MT, n_qr])])]),
124 ss2, ss, ss, ss, ss, ss, ss, StackReq::all_of(&[temp_mat_scratch::<T>(2 * s, 2 * s), s_qr]),
132 ])
133 }
134
135 implementation(&right_precond, &mat, rhs_ncols, par)
136}
137
138#[track_caller]
143pub fn lsmr<T: ComplexField>(
144 out: MatMut<'_, T>,
145 right_precond: impl BiPrecond<T>,
146 mat: impl BiLinOp<T>,
147 rhs: MatRef<'_, T>,
148 params: LsmrParams<T::Real>,
149 callback: impl FnMut(MatRef<'_, T>),
150 par: Par,
151 stack: &mut MemStack,
152) -> Result<LsmrInfo<T::Real>, LsmrError<T::Real>> {
153 #[track_caller]
154 #[math]
155 fn implementation<T: ComplexField>(
156 mut x: MatMut<'_, T>,
157 M: &impl BiPrecond<T>,
158 A: &impl BiLinOp<T>,
159 b: MatRef<'_, T>,
160 params: LsmrParams<T::Real>,
161 callback: &mut dyn FnMut(MatRef<'_, T>),
162 par: Par,
163 stack: &mut MemStack,
164 ) -> Result<LsmrInfo<T::Real>, LsmrError<T::Real>> {
165 fn thin_qr<T: ComplexField>(mut Q: MatMut<'_, T>, mut R: MatMut<'_, T>, mut mat: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
166 let k = R.nrows();
167 let bs = qr::no_pivoting::factor::recommended_blocksize::<T>(mat.nrows(), mat.ncols());
168 let (mut house, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(bs, Ord::min(mat.nrows(), mat.ncols()), stack) };
169 let mut house = house.as_mat_mut();
170
171 qr::no_pivoting::factor::qr_in_place(mat.rb_mut(), house.rb_mut(), par, stack.rb_mut(), Default::default());
172
173 R.fill(zero());
174 R.copy_from_triangular_upper(mat.rb().get(..k, ..k));
175 Q.fill(zero());
176 Q.rb_mut().diagonal_mut().column_vector_mut().fill(one::<T>());
177 householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
178 mat.rb(),
179 house.rb(),
180 Conj::No,
181 Q.rb_mut(),
182 par,
183 stack.rb_mut(),
184 );
185 }
186
187 let m = A.nrows();
188 let n = A.ncols();
189 let mut k = b.ncols();
190 {
191 let out = x.rb();
192 let mat = A;
193 let right_precond = M;
194 let rhs = b;
195 assert!(all(
196 right_precond.nrows() == mat.ncols(),
197 right_precond.ncols() == mat.ncols(),
198 rhs.nrows() == mat.nrows(),
199 out.nrows() == mat.ncols(),
200 out.ncols() == rhs.ncols(),
201 ));
202 }
203
204 if m == 0 || n == 0 || k == 0 || core::mem::size_of::<T::Unit>() == 0 {
205 x.fill(zero());
206 return Ok(LsmrInfo {
207 abs_residual: zero::<T::Real>(),
208 rel_residual: zero::<T::Real>(),
209 iter_count: 0,
210 non_exhaustive: NonExhaustive(()),
211 });
212 }
213
214 debug_assert!(all(m < isize::MAX as usize, n < isize::MAX as usize, k < isize::MAX as usize));
215 let actual_k = k;
216 if k > n {
217 k = k.msrv_checked_next_multiple_of(n).unwrap();
219 }
220 debug_assert!(k < isize::MAX as usize);
221
222 let s = Ord::min(k, Ord::min(n, m));
223
224 let mut stack = stack;
225
226 let (mut u, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
227 let mut u = u.as_mat_mut();
228 let (mut beta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
229 let mut beta = beta.as_mat_mut();
230
231 let (mut v, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
232 let mut v = v.as_mat_mut();
233 let (mut alpha, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
234 let mut alpha = alpha.as_mat_mut();
235
236 let (mut zetabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
237 let mut zetabar = zetabar.as_mat_mut();
238 let (mut alphabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
239 let mut alphabar = alphabar.as_mat_mut();
240 let (mut theta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, k, stack.rb_mut()) };
241 let mut theta = theta.as_mat_mut();
242 let (mut pbar_adjoint, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, 2 * k, stack.rb_mut()) };
243 let mut pbar_adjoint = pbar_adjoint.as_mat_mut();
244
245 let (mut w, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
246 let mut w = w.as_mat_mut();
247 let (mut wbar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
248 let mut wbar = wbar.as_mat_mut();
249
250 {
251 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
252 let mut qr = qr.as_mat_mut();
253 if params.initial_guess == InitialGuessStatus::Zero {
254 qr.rb_mut().get_mut(.., ..actual_k).copy_from(b);
255 qr.rb_mut().get_mut(.., actual_k..).fill(zero());
256 } else {
257 A.apply(qr.rb_mut().rb_mut().get_mut(.., ..actual_k), x.rb(), par, stack.rb_mut());
258 z!(qr.rb_mut().get_mut(.., ..actual_k), &b).for_each(|uz!(ax, b)| *ax = *b - *ax);
259 qr.rb_mut().get_mut(.., actual_k..).fill(zero());
260 }
261 let mut start = 0;
262 while start < k {
263 let end = Ord::min(k - start, s) + start;
264 let len = end - start;
265 thin_qr(
266 u.rb_mut().get_mut(.., start..end),
267 beta.rb_mut().get_mut(..len, start..end),
268 qr.rb_mut().get_mut(.., start..end),
269 par,
270 stack.rb_mut(),
271 );
272 start = end;
273 }
274 }
275
276 {
277 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
278 let mut qr = qr.as_mat_mut();
279 A.adjoint_apply(qr.rb_mut(), u.rb(), par, stack.rb_mut());
280 M.adjoint_apply_in_place(qr.rb_mut(), par, stack.rb_mut());
281 let mut start = 0;
282 while start < k {
283 let end = Ord::min(k - start, s) + start;
284 let len = end - start;
285 thin_qr(
286 v.rb_mut().get_mut(.., start..end),
287 alpha.rb_mut().get_mut(..len, start..end),
288 qr.rb_mut().get_mut(.., start..end),
289 par,
290 stack.rb_mut(),
291 );
292 start = end;
293 }
294 }
295
296 zetabar.fill(zero());
297 let mut start = 0;
298 while start < k {
299 let end = Ord::min(k - start, s) + start;
300 let len = end - start;
301 matmul(
302 zetabar.rb_mut().get_mut(..len, start..end),
303 Accum::Replace,
304 alpha.rb().get(..len, start..end),
305 beta.rb().get(..len, start..end),
306 one::<T>(),
307 par,
308 );
309 start = end;
310 }
311 alphabar.copy_from(&alpha);
312 pbar_adjoint.fill(zero());
313 let mut start = 0;
314 while start < k {
315 let end = Ord::min(k - start, s) + start;
316 let len = end - start;
317 pbar_adjoint
318 .rb_mut()
319 .get_mut(..2 * len, 2 * start..2 * end)
320 .diagonal_mut()
321 .column_vector_mut()
322 .fill(one());
323 start = end;
324 }
325 theta.fill(zero());
326 w.fill(zero());
327 wbar.fill(zero());
328
329 let mut norm;
330 let norm_ref = if params.initial_guess == InitialGuessStatus::Zero {
331 norm = zetabar.norm_l2();
332 copy(norm)
333 } else {
334 norm = zetabar.norm_l2();
335 let (mut tmp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, actual_k, stack.rb_mut()) };
336 let mut tmp = tmp.as_mat_mut();
337 A.adjoint_apply(tmp.rb_mut(), b, par, stack.rb_mut());
338 M.adjoint_apply_in_place(tmp.rb_mut(), par, stack.rb_mut());
339 tmp.norm_l2()
340 };
341 let threshold = norm_ref * params.rel_tolerance;
342
343 if norm_ref == zero::<T::Real>() {
344 x.fill(zero());
345 return Ok(LsmrInfo {
346 abs_residual: zero::<T::Real>(),
347 rel_residual: zero::<T::Real>(),
348 iter_count: 0,
349 non_exhaustive: NonExhaustive(()),
350 });
351 }
352
353 if norm <= threshold {
354 return Ok(LsmrInfo {
355 abs_residual: zero::<T::Real>(),
356 rel_residual: zero::<T::Real>(),
357 iter_count: 0,
358 non_exhaustive: NonExhaustive(()),
359 });
360 }
361
362 for iter in 0..params.max_iters {
363 let (mut vold, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
364 let mut vold = vold.as_mat_mut();
365 {
366 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack.rb_mut()) };
367 let mut qr = qr.as_mat_mut();
368 vold.copy_from(&v);
369 M.apply_in_place(v.rb_mut(), par, stack.rb_mut());
370 A.apply(qr.rb_mut(), v.rb(), par, stack.rb_mut());
371
372 let mut start = 0;
373 while start < k {
374 let s = Ord::min(k - start, s);
375 let end = start + s;
376 matmul(
377 qr.rb_mut().get_mut(.., start..end),
378 Accum::Add,
379 u.rb().get(.., start..end),
380 alpha.rb().get(..s, start..end).adjoint(),
381 -one::<T>(),
382 par,
383 );
384 thin_qr(
385 u.rb_mut().get_mut(.., start..end),
386 beta.rb_mut().get_mut(..s, start..end),
387 qr.rb_mut().get_mut(.., start..end),
388 par,
389 stack.rb_mut(),
390 );
391 start = end;
392 }
393 }
394
395 {
396 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack.rb_mut()) };
397 let mut qr = qr.as_mat_mut();
398 A.adjoint_apply(qr.rb_mut(), u.rb(), par, stack.rb_mut());
399 M.adjoint_apply_in_place(qr.rb_mut(), par, stack.rb_mut());
400
401 let mut start = 0;
402 while start < k {
403 let s = Ord::min(k - start, s);
404 let end = start + s;
405 matmul(
406 qr.rb_mut().get_mut(.., start..end),
407 Accum::Add,
408 vold.rb().get(.., start..end),
409 beta.rb().get(..s, start..end).adjoint(),
410 -one::<T>(),
411 par,
412 );
413
414 vold.rb_mut().get_mut(.., start..end).copy_from(v.rb().get(.., start..end));
416
417 thin_qr(
418 v.rb_mut().get_mut(.., start..end),
419 alpha.rb_mut().get_mut(..s, start..end),
420 qr.rb_mut().get_mut(.., start..end),
421 par,
422 stack.rb_mut(),
423 );
424 start = end;
425 }
426 }
427
428 let mut Mvold = vold;
429
430 let mut start = 0;
431 while start < k {
432 let s = Ord::min(k - start, s);
433 let end = start + s;
434
435 let mut x = x.rb_mut().get_mut(.., start..Ord::min(actual_k, end));
436 let mut Mvold = Mvold.rb_mut().get_mut(.., start..end);
437 let mut w = w.rb_mut().get_mut(.., start..end);
438 let mut wbar = wbar.rb_mut().get_mut(.., start..end);
439
440 let alpha = alpha.rb_mut().get_mut(..s, start..end);
441 let beta = beta.rb_mut().get_mut(..s, start..end);
442 let mut zetabar = zetabar.rb_mut().get_mut(..s, start..end);
443 let mut alphabar = alphabar.rb_mut().get_mut(..s, start..end);
444 let mut theta = theta.rb_mut().get_mut(..s, start..end);
445 let mut pbar_adjoint = pbar_adjoint.rb_mut().get_mut(..2 * s, 2 * start..2 * end);
446
447 let (mut p_adjoint, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, 2 * s, stack.rb_mut()) };
448 let mut p_adjoint = p_adjoint.as_mat_mut();
449
450 let (mut rho, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
451 let mut rho = rho.as_mat_mut();
452 let (mut thetaold, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
453 let mut thetaold = thetaold.as_mat_mut();
454 let (mut rhobar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
455 let mut rhobar = rhobar.as_mat_mut();
456 let (mut thetabar, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
457 let mut thetabar = thetabar.as_mat_mut();
458 let (mut zeta, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
459 let mut zeta = zeta.as_mat_mut();
460 let (mut zetabar_tmp, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(s, s, stack.rb_mut()) };
461 let mut zetabar_tmp = zetabar_tmp.as_mat_mut();
462
463 {
464 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, s, stack.rb_mut()) };
465 let mut qr = qr.as_mat_mut();
466 qr.rb_mut().get_mut(..s, ..).copy_from(alphabar.rb().adjoint());
467 qr.rb_mut().get_mut(s.., ..).copy_from(&beta);
468 thin_qr(p_adjoint.rb_mut(), rho.rb_mut(), qr.rb_mut(), par, stack.rb_mut());
469 }
470
471 thetaold.copy_from(&theta);
472 matmul(theta.rb_mut(), Accum::Replace, alpha.rb(), p_adjoint.rb().get(s.., ..s), one::<T>(), par);
473 matmul(
474 alphabar.rb_mut(),
475 Accum::Replace,
476 alpha.rb(),
477 p_adjoint.rb().get(s.., s..),
478 one::<T>(),
479 par,
480 );
481
482 matmul(
483 thetabar.rb_mut(),
484 Accum::Replace,
485 rho.rb(),
486 pbar_adjoint.rb().get(s.., ..s),
487 one::<T>(),
488 par,
489 );
490 {
491 let (mut qr, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(2 * s, s, stack.rb_mut()) };
492 let mut qr = qr.as_mat_mut();
493 matmul(
494 qr.rb_mut().get_mut(..s, ..),
495 Accum::Replace,
496 pbar_adjoint.rb().adjoint().get(s.., s..),
497 rho.rb().adjoint(),
498 one::<T>(),
499 par,
500 );
501 qr.rb_mut().get_mut(s.., ..).copy_from(&theta);
502 thin_qr(pbar_adjoint.rb_mut(), rhobar.rb_mut(), qr.rb_mut(), par, stack.rb_mut());
503 }
504
505 matmul(
506 zeta.rb_mut(),
507 Accum::Replace,
508 pbar_adjoint.rb().adjoint().get(..s, ..s),
509 zetabar.rb(),
510 one::<T>(),
511 par,
512 );
513 matmul(
514 zetabar_tmp.rb_mut(),
515 Accum::Replace,
516 pbar_adjoint.rb().adjoint().get(s.., ..s),
517 zetabar.rb(),
518 one::<T>(),
519 par,
520 );
521 zetabar.copy_from(&zetabar_tmp);
522
523 matmul(Mvold.rb_mut(), Accum::Add, w.rb(), thetaold.rb().adjoint(), -one::<T>(), par);
524 crate::linalg::triangular_solve::solve_lower_triangular_in_place(rho.rb().transpose(), Mvold.rb_mut().transpose_mut(), par);
525 w.copy_from(&Mvold);
526
527 matmul(Mvold.rb_mut(), Accum::Add, wbar.rb(), thetabar.rb().adjoint(), -one::<T>(), par);
528 crate::linalg::triangular_solve::solve_lower_triangular_in_place(rhobar.rb().transpose(), Mvold.rb_mut().transpose_mut(), par);
529 wbar.copy_from(&Mvold);
530
531 let actual_s = x.ncols();
532 matmul(
533 x.rb_mut(),
534 if iter == 0 && params.initial_guess == InitialGuessStatus::Zero {
535 Accum::Replace
536 } else {
537 Accum::Add
538 },
539 wbar.rb(),
540 zeta.rb().get(.., ..actual_s),
541 one::<T>(),
542 par,
543 );
544 start = end;
545 }
546 norm = zetabar.norm_l2();
547 callback(x.rb());
548 if norm <= threshold {
549 return Ok(LsmrInfo {
550 rel_residual: norm / norm_ref,
551 abs_residual: norm,
552 iter_count: iter + 1,
553 non_exhaustive: NonExhaustive(()),
554 });
555 }
556 }
557
558 Err(LsmrError::NoConvergence {
559 rel_residual: norm / norm_ref,
560 abs_residual: norm,
561 })
562 }
563 implementation(out, &right_precond, &mat, rhs, params, &mut { callback }, par, stack)
564}
565
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::prelude::*;
    use dyn_stack::MemBuffer;
    use equator::assert;

    /// Solves one random `m × n` problem with `k` right-hand sides and checks that `lsmr`
    /// converges within the expected iteration budget.
    ///
    /// When `duplicate_rhs` is set, the right-hand sides are concatenated with themselves, so
    /// the column block is rank-deficient. The random draws happen in a fixed order: operator,
    /// right-hand sides, preconditioner diagonal, preconditioner sub-diagonal, initial guess.
    fn check_random_problem(rng: &mut StdRng, m: usize, n: usize, k: usize, duplicate_rhs: bool) {
        let complex_normal = || ComplexDistribution::new(StandardNormal, StandardNormal);

        let mat: Mat<c64> = CwiseMatDistribution {
            nrows: m,
            ncols: n,
            dist: complex_normal(),
        }
        .sample(rng);
        let b: Mat<c64> = CwiseMatDistribution {
            nrows: m,
            ncols: k,
            dist: complex_normal(),
        }
        .sample(rng);
        let b = if duplicate_rhs { crate::concat![[b, b]] } else { b };
        let nrhs = b.ncols();

        // Lower-bidiagonal right preconditioner with a large, positive diagonal.
        let mut precond = Scale(c64::new(2.0, 0.0)) * Mat::<c64>::identity(n, n);
        for row in 0..n {
            precond[(row, row)] = (128.0 * f64::exp(rand::distributions::Standard.sample(rng))).into();
        }
        for col in 0..n - 1 {
            precond[(col + 1, col)] = f64::exp(rand::distributions::Standard.sample(rng)).into();
        }

        let mut out: Mat<c64> = CwiseMatDistribution {
            nrows: n,
            ncols: nrhs,
            dist: complex_normal(),
        }
        .sample(rng);

        let mut buffer = MemBuffer::new(lsmr_scratch(precond.as_ref(), mat.as_ref(), nrhs, Par::Seq));
        let result = lsmr(
            out.as_mut(),
            precond.as_ref(),
            mat.as_ref(),
            b.as_ref(),
            LsmrParams::default(),
            |_| {},
            Par::Seq,
            MemStack::new(&mut buffer),
        );

        assert!(result.is_ok());
        let info = result.unwrap();
        assert!(info.iter_count <= (4 * n).msrv_div_ceil(Ord::min(nrhs, n)));
    }

    #[test]
    fn test_lsmr() {
        let mut rng = StdRng::seed_from_u64(0);
        for k in [1, 2, 4, 7, 10, 40, 80, 100] {
            check_random_problem(&mut rng, 100, 80, k, false);
        }
    }

    #[test]
    fn test_breakdown() {
        let mut rng = StdRng::seed_from_u64(0);
        for k in [1, 2, 4, 7, 10, 40, 80, 100] {
            check_random_problem(&mut rng, 100, 80, k, true);
        }
    }
}