1use crate::internal_prelude::*;
2use crate::{assert, get_global_parallelism};
3use alloc::vec;
4use alloc::vec::Vec;
5use dyn_stack::MemBuffer;
6use faer_traits::{ComplexConj, math_utils};
7use linalg::svd::ComputeSvdVectors;
8
9pub use linalg::cholesky::ldlt::factor::LdltError;
10pub use linalg::cholesky::llt::factor::LltError;
11pub use linalg::evd::EvdError;
12pub use linalg::svd::SvdError;
13
/// Dimensions of the matrix that a decomposition object represents.
pub trait ShapeCore {
    /// Returns the number of rows of the decomposed matrix.
    fn nrows(&self) -> usize;
    /// Returns the number of columns of the decomposed matrix.
    fn ncols(&self) -> usize;
}
21
/// Core linear-system solve operations provided by a decomposition.
pub trait SolveCore<T: ComplexField>: ShapeCore {
    /// Solves the system in place, overwriting `rhs` with the solution.
    /// `conj` selects whether the decomposed matrix is implicitly conjugated.
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
    /// Same as [`Self::solve_in_place_with_conj`], but for the transposed system.
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
}
/// Core least-squares solve operation provided by a decomposition.
pub trait SolveLstsqCore<T: ComplexField>: ShapeCore {
    /// Solves the least-squares system in place; the solution is written to the
    /// top rows of `rhs`. `conj` selects implicit conjugation of the matrix.
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
}
/// Dense-matrix operations derivable from a decomposition.
pub trait DenseSolveCore<T: ComplexField>: SolveCore<T> {
    /// Returns a matrix assembled from the stored factors (approximately the input matrix).
    fn reconstruct(&self) -> Mat<T>;
    /// Returns the inverse computed from the stored factors.
    fn inverse(&self) -> Mat<T>;
}
45
46impl<S: ?Sized + ShapeCore> ShapeCore for &S {
47 #[inline]
48 fn nrows(&self) -> usize {
49 (**self).nrows()
50 }
51
52 #[inline]
53 fn ncols(&self) -> usize {
54 (**self).ncols()
55 }
56}
57
58impl<T: ComplexField, S: ?Sized + SolveCore<T>> SolveCore<T> for &S {
59 #[inline]
60 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
61 (**self).solve_in_place_with_conj(conj, rhs)
62 }
63
64 #[inline]
65 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
66 (**self).solve_transpose_in_place_with_conj(conj, rhs)
67 }
68}
69
70impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsqCore<T> for &S {
71 #[inline]
72 fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
73 (**self).solve_lstsq_in_place_with_conj(conj, rhs)
74 }
75}
76
77impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolveCore<T> for &S {
78 #[inline]
79 fn reconstruct(&self) -> Mat<T> {
80 (**self).reconstruct()
81 }
82
83 #[inline]
84 fn inverse(&self) -> Mat<T> {
85 (**self).inverse()
86 }
87}
88
89pub trait Solve<T: ComplexField>: SolveCore<T> {
91 #[track_caller]
92 #[inline]
93 fn solve_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
95 self.solve_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
96 }
97 #[track_caller]
98 #[inline]
99 fn solve_conjugate_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
101 self.solve_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
102 }
103
104 #[track_caller]
105 #[inline]
106 fn solve_transpose_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
108 self.solve_transpose_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
109 }
110 #[track_caller]
111 #[inline]
112 fn solve_adjoint_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
114 self.solve_transpose_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
115 }
116
117 #[track_caller]
118 #[inline]
119 fn rsolve_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
121 self.solve_transpose_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
122 }
123 #[track_caller]
124 #[inline]
125 fn rsolve_conjugate_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
127 self.solve_transpose_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
128 }
129
130 #[track_caller]
131 #[inline]
132 fn rsolve_transpose_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
134 self.solve_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
135 }
136 #[track_caller]
137 #[inline]
138 fn rsolve_adjoint_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
140 self.solve_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
141 }
142
143 #[track_caller]
144 #[inline]
145 fn solve<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
147 let rhs = rhs.as_mat_ref();
148 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
149 out.as_mat_mut().copy_from(rhs);
150 self.solve_in_place(&mut out);
151 out
152 }
153 #[track_caller]
154 #[inline]
155 fn solve_conjugate<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
157 let rhs = rhs.as_mat_ref();
158 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
159 out.as_mat_mut().copy_from(rhs);
160 self.solve_conjugate_in_place(&mut out);
161 out
162 }
163
164 #[track_caller]
165 #[inline]
166 fn solve_transpose<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
168 let rhs = rhs.as_mat_ref();
169 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
170 out.as_mat_mut().copy_from(rhs);
171 self.solve_transpose_in_place(&mut out);
172 out
173 }
174 #[track_caller]
175 #[inline]
176 fn solve_adjoint<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
178 let rhs = rhs.as_mat_ref();
179 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
180 out.as_mat_mut().copy_from(rhs);
181 self.solve_adjoint_in_place(&mut out);
182 out
183 }
184
185 #[track_caller]
186 #[inline]
187 fn rsolve<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
189 let lhs = lhs.as_mat_ref();
190 let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
191 out.as_mat_mut().copy_from(lhs);
192 self.rsolve_in_place(&mut out);
193 out
194 }
195 #[track_caller]
196 #[inline]
197 fn rsolve_conjugate<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
199 let lhs = lhs.as_mat_ref();
200 let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
201 out.as_mat_mut().copy_from(lhs);
202 self.rsolve_conjugate_in_place(&mut out);
203 out
204 }
205
206 #[track_caller]
207 #[inline]
208 fn rsolve_transpose<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
210 let lhs = lhs.as_mat_ref();
211 let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
212 out.as_mat_mut().copy_from(lhs);
213 self.rsolve_transpose_in_place(&mut out);
214 out
215 }
216 #[track_caller]
217 #[inline]
218 fn rsolve_adjoint<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
220 let lhs = lhs.as_mat_ref();
221 let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
222 out.as_mat_mut().copy_from(lhs);
223 self.rsolve_adjoint_in_place(&mut out);
224 out
225 }
226}
227
impl<C: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, C>>> mat::generic::Mat<Inner> {
    /// Returns the partial-pivoting LU decomposition of `self`.
    #[track_caller]
    pub fn partial_piv_lu(&self) -> PartialPivLu<C::Canonical> {
        PartialPivLu::new(self.rb())
    }

    /// Returns the full-pivoting LU decomposition of `self`.
    #[track_caller]
    pub fn full_piv_lu(&self) -> FullPivLu<C::Canonical> {
        FullPivLu::new(self.rb())
    }

    /// Returns the QR decomposition of `self`.
    #[track_caller]
    pub fn qr(&self) -> Qr<C::Canonical> {
        Qr::new(self.rb())
    }

    /// Returns the column-pivoting QR decomposition of `self`.
    #[track_caller]
    pub fn col_piv_qr(&self) -> ColPivQr<C::Canonical> {
        ColPivQr::new(self.rb())
    }

    /// Returns the full singular value decomposition of `self`.
    #[track_caller]
    pub fn svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
        Svd::new(self.rb())
    }

    /// Returns the thin singular value decomposition of `self`.
    #[track_caller]
    pub fn thin_svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
        Svd::new_thin(self.rb())
    }

    /// Returns the LLᴴ Cholesky decomposition, reading only the triangle given by `side`.
    #[track_caller]
    pub fn llt(&self, side: Side) -> Result<Llt<C::Canonical>, LltError> {
        Llt::new(self.rb(), side)
    }

    /// Returns the LDLᴴ decomposition, reading only the triangle given by `side`.
    #[track_caller]
    pub fn ldlt(&self, side: Side) -> Result<Ldlt<C::Canonical>, LdltError> {
        Ldlt::new(self.rb(), side)
    }

    /// Returns the pivoted LBLᴴ decomposition, reading only the triangle given by `side`.
    #[track_caller]
    pub fn lblt(&self, side: Side) -> Lblt<C::Canonical> {
        Lblt::new(self.rb(), side)
    }

    /// Returns the eigendecomposition of `self`, assumed self-adjoint; only the
    /// triangle given by `side` is read.
    #[track_caller]
    pub fn self_adjoint_eigen(&self, side: Side) -> Result<SelfAdjointEigen<C::Canonical>, EvdError> {
        SelfAdjointEigen::new(self.rb(), side)
    }

    /// Returns the eigenvalues of `self`, assumed self-adjoint; only the
    /// triangle given by `side` is read. Requires a square matrix.
    #[track_caller]
    pub fn self_adjoint_eigenvalues(&self, side: Side) -> Result<Vec<Real<C>>, EvdError> {
        // monomorphic inner function, instantiated once per canonical scalar type
        #[track_caller]
        pub fn imp<T: ComplexField>(mut A: MatRef<'_, T>, side: Side) -> Result<Vec<T::Real>, EvdError> {
            assert!(A.nrows() == A.ncols());
            if side == Side::Upper {
                // transpose (not adjoint) moves the upper triangle into the lower
                // position read below; for a self-adjoint input this conjugates the
                // matrix, which leaves the (real) eigenvalues unchanged
                A = A.transpose();
            }
            let par = get_global_parallelism();
            let n = A.nrows();

            let mut s = Diag::<T>::zeros(n);

            // eigenvalues only: no eigenvector output, scratch sized accordingly
            linalg::evd::self_adjoint_evd(
                A,
                s.as_mut(),
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            Ok(s.column_vector().iter().map(|x| real(x)).collect())
        }

        imp(self.rb().canonical(), side)
    }

    /// Returns the singular values of `self` in a [`Vec`].
    #[track_caller]
    pub fn singular_values(&self) -> Result<Vec<Real<C>>, SvdError> {
        // monomorphic inner function, instantiated once per canonical scalar type
        pub fn imp<T: ComplexField>(A: MatRef<'_, T>) -> Result<Vec<T::Real>, SvdError> {
            let par = get_global_parallelism();
            let m = A.nrows();
            let n = A.ncols();

            let mut s = Diag::<T>::zeros(Ord::min(m, n));

            // singular values only: both U and V outputs are skipped
            linalg::svd::svd(
                A,
                s.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(
                    m,
                    n,
                    linalg::svd::ComputeSvdVectors::No,
                    linalg::svd::ComputeSvdVectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            Ok(s.column_vector().iter().map(|x| real(x)).collect())
        }

        imp(self.rb().canonical())
    }
}
366
impl<C: Conjugate> MatRef<'_, C> {
    /// Dispatches eigendecomposition construction on the scalar kind of `C`:
    /// real, canonical complex, or conjugated complex.
    #[track_caller]
    fn eigen_imp(&self) -> Result<Eigen<Real<C>>, EvdError> {
        if const { C::Canonical::IS_REAL } {
            // SAFETY(review): coercion assumes `C` has the layout of `Real<C>` when its
            // canonical form is real — TODO confirm against `hacks::coerce`'s contract
            Eigen::new_from_real(unsafe { crate::hacks::coerce(*self) })
        } else if const { C::IS_CANONICAL } {
            // canonical complex scalar: reinterpret as `Complex<Real<C>>`
            Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(*self) })
        } else {
            // non-canonical (conjugated) complex scalar
            Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(*self) })
        }
    }

    /// Computes the eigenvalues only, branching on whether the canonical scalar
    /// type is real (paired real/imaginary outputs) or complex.
    #[track_caller]
    fn eigenvalues_imp(&self) -> Result<Vec<Complex<Real<C>>>, EvdError> {
        let par = get_global_parallelism();

        if const { C::Canonical::IS_REAL } {
            let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Real<C>>>(*self) };
            assert!(A.nrows() == A.ncols());
            let n = A.nrows();

            // the real EVD reports eigenvalues as separate real/imaginary parts
            let mut s_re = Diag::<Real<C>>::zeros(n);
            let mut s_im = Diag::<Real<C>>::zeros(n);

            linalg::evd::evd_real(
                A,
                s_re.as_mut(),
                s_im.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Real<C>>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            // zip the two real diagonals back into complex eigenvalues
            Ok(s_re
                .column_vector()
                .iter()
                .zip(s_im.column_vector().iter())
                .map(|(re, im)| Complex::new(re.clone(), im.clone()))
                .collect())
        } else {
            let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(self.canonical()) };
            assert!(A.nrows() == A.ncols());
            let n = A.nrows();

            let mut s = Diag::<Complex<Real<C>>>::zeros(n);

            linalg::evd::evd_cplx(
                A,
                s.as_mut(),
                None,
                None,
                par,
                MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<Real<C>>>(
                    n,
                    linalg::evd::ComputeEigenvectors::No,
                    linalg::evd::ComputeEigenvectors::No,
                    par,
                    default(),
                ))),
                default(),
            )?;

            if const { C::IS_CANONICAL } {
                Ok(s.column_vector().iter().cloned().collect())
            } else {
                // the input was implicitly conjugated, so conjugate the eigenvalues back
                Ok(s.column_vector().iter().map(conj).collect())
            }
        }
    }
}
445
446impl<T: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, T>>> mat::generic::Mat<Inner> {
447 #[track_caller]
449 pub fn eigen(&self) -> Result<Eigen<Real<T>>, EvdError> {
450 self.rb().eigen_imp()
451 }
452
453 #[track_caller]
455 pub fn eigenvalues(&self) -> Result<Vec<Complex<Real<T>>>, EvdError> {
456 self.rb().eigenvalues_imp()
457 }
458}
459
460pub trait SolveLstsq<T: ComplexField>: SolveLstsqCore<T> {
462 #[track_caller]
463 #[inline]
464 fn solve_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
466 self.solve_lstsq_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
467 }
468
469 #[track_caller]
470 #[inline]
471 fn solve_conjugate_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
473 self.solve_lstsq_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
474 }
475
476 #[track_caller]
477 #[inline]
478 fn solve_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
480 let rhs = rhs.as_mat_ref();
481 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
482 out.as_mat_mut().copy_from(rhs);
483 self.solve_lstsq_in_place(&mut out);
484 out.truncate(self.ncols(), rhs.ncols());
485 out
486 }
487 #[track_caller]
488 #[inline]
489 fn solve_conjugate_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
491 let rhs = rhs.as_mat_ref();
492 let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
493 out.as_mat_mut().copy_from(rhs);
494 self.solve_conjugate_lstsq_in_place(&mut out);
495 out.truncate(self.ncols(), rhs.ncols());
496 out
497 }
498}
/// Extension alias for [`DenseSolveCore`]; no additional methods.
pub trait DenseSolve<T: ComplexField>: DenseSolveCore<T> {}

// blanket impls: every `*Core` implementor gets the matching extension trait for free
impl<T: ComplexField, S: ?Sized + SolveCore<T>> Solve<T> for S {}
impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsq<T> for S {}
impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolve<T> for S {}
505
/// LLᴴ Cholesky decomposition.
#[derive(Clone, Debug)]
pub struct Llt<T> {
    // lower triangular factor (strict upper triangle zeroed)
    L: Mat<T>,
}
511
/// LDLᴴ decomposition.
#[derive(Clone, Debug)]
pub struct Ldlt<T> {
    // unit lower triangular factor
    L: Mat<T>,
    // diagonal factor
    D: Diag<T>,
}
518
/// Pivoted LBLᴴ decomposition with a block-diagonal middle factor.
#[derive(Clone, Debug)]
pub struct Lblt<T> {
    // unit lower triangular factor
    L: Mat<T>,
    // diagonal of the block-diagonal factor B
    B_diag: Diag<T>,
    // subdiagonal of the block-diagonal factor B
    B_subdiag: Diag<T>,
    // row/column permutation
    P: Perm<usize>,
}
527
/// LU decomposition with partial (row) pivoting.
#[derive(Clone, Debug)]
pub struct PartialPivLu<T> {
    // unit lower triangular factor
    L: Mat<T>,
    // upper triangular factor
    U: Mat<T>,
    // row permutation
    P: Perm<usize>,
}
535
/// LU decomposition with full (row and column) pivoting.
#[derive(Clone, Debug)]
pub struct FullPivLu<T> {
    // unit lower triangular factor
    L: Mat<T>,
    // upper triangular factor
    U: Mat<T>,
    // row permutation
    P: Perm<usize>,
    // column permutation
    Q: Perm<usize>,
}
544
/// QR decomposition with Q stored as a block Householder sequence.
#[derive(Clone, Debug)]
pub struct Qr<T> {
    // Householder vectors defining Q
    Q_basis: Mat<T>,
    // block Householder coefficients
    Q_coeff: Mat<T>,
    // upper triangular factor
    R: Mat<T>,
}
552
/// QR decomposition with column pivoting; Q is stored as a block Householder sequence.
#[derive(Clone, Debug)]
pub struct ColPivQr<T> {
    // Householder vectors defining Q
    Q_basis: Mat<T>,
    // block Householder coefficients
    Q_coeff: Mat<T>,
    // upper triangular factor
    R: Mat<T>,
    // column permutation
    P: Perm<usize>,
}
561
/// Singular value decomposition.
#[derive(Clone, Debug)]
pub struct Svd<T> {
    // left singular vectors
    U: Mat<T>,
    // right singular vectors
    V: Mat<T>,
    // singular values
    S: Diag<T>,
}
569
/// Eigendecomposition of a self-adjoint matrix.
#[derive(Clone, Debug)]
pub struct SelfAdjointEigen<T> {
    // eigenvectors, one per column
    U: Mat<T>,
    // eigenvalues
    S: Diag<T>,
}
576
/// Eigendecomposition of a general square matrix; vectors and values are
/// stored as complex numbers even when the input is real.
#[derive(Clone, Debug)]
pub struct Eigen<T> {
    // eigenvectors, one per column
    U: Mat<Complex<T>>,
    // eigenvalues
    S: Diag<Complex<T>>,
}
583
impl<T: ComplexField> Llt<T> {
    /// Computes the LLᴴ Cholesky decomposition of the square matrix `A`,
    /// reading only the triangle indicated by `side`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LltError> {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            // copy the selected triangle into the lower triangle of the working matrix
            Side::Lower => L.copy_from_triangular_lower(A),
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }

        Self::new_imp(L)
    }

    /// Factorizes `L` in place, then clears the strict upper triangle so only
    /// the factor remains.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Result<Self, LltError> {
        let par = get_global_parallelism();

        let n = L.nrows();

        let mut mem = MemBuffer::new(linalg::cholesky::llt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::llt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;
        // zero everything strictly above the diagonal
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Ok(Self { L })
    }

    /// Returns the lower triangular factor L.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }
}
620
impl<T: ComplexField> Ldlt<T> {
    /// Computes the LDLᴴ decomposition of the square matrix `A`, reading only
    /// the triangle indicated by `side`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LdltError> {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            // copy the selected triangle into the lower triangle of the working matrix
            Side::Lower => L.copy_from_triangular_lower(A),
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }

        Self::new_imp(L)
    }

    /// Factorizes `L` in place, then splits the packed result: the diagonal is
    /// moved into `D` and replaced with ones, and the strict upper triangle is cleared.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Result<Self, LdltError> {
        let par = get_global_parallelism();

        let n = L.nrows();
        let mut D = Diag::zeros(n);

        let mut mem = MemBuffer::new(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::ldlt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;

        // the factorization packs D on the diagonal of L; unpack it
        D.copy_from(L.diagonal());
        L.diagonal_mut().fill(one());
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Ok(Self { L, D })
    }

    /// Returns the unit lower triangular factor L.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the diagonal factor D.
    pub fn D(&self) -> DiagRef<'_, T> {
        self.D.as_ref()
    }
}
666
impl<T: ComplexField> Lblt<T> {
    /// Computes the pivoted LBLᴴ decomposition of the square matrix `A`,
    /// reading only the triangle indicated by `side`. This factorization does
    /// not fail, so no `Result` is returned.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Self {
        assert!(all(A.nrows() == A.ncols()));
        let n = A.nrows();

        let mut L = Mat::zeros(n, n);
        match side {
            // copy the selected triangle into the lower triangle of the working matrix
            Side::Lower => L.copy_from_triangular_lower(A),
            Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
        }
        Self::new_imp(L)
    }

    /// Factorizes `L` in place, collecting the block-diagonal factor B
    /// (diagonal + subdiagonal) and the pivot permutation.
    #[track_caller]
    fn new_imp(mut L: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let n = L.nrows();

        let mut diag = Diag::zeros(n);
        let mut subdiag = Diag::zeros(n);
        let mut perm_fwd = vec![0usize; n];
        let mut perm_bwd = vec![0usize; n];

        let mut mem = MemBuffer::new(linalg::cholesky::lblt::factor::cholesky_in_place_scratch::<usize, T>(n, par, default()));
        let stack = MemStack::new(&mut mem);

        linalg::cholesky::lblt::factor::cholesky_in_place(L.as_mut(), subdiag.as_mut(), &mut perm_fwd, &mut perm_bwd, par, stack, default());

        // the factorization packs B's diagonal on the diagonal of L; unpack it
        diag.copy_from(L.diagonal());
        L.diagonal_mut().fill(one());
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());

        Self {
            L,
            B_diag: diag,
            B_subdiag: subdiag,
            // SAFETY: `perm_fwd`/`perm_bwd` were filled by the factorization as
            // inverse permutations of each other
            P: unsafe { Perm::new_unchecked(perm_fwd.into_boxed_slice(), perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor L.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the diagonal of the block-diagonal factor B.
    pub fn B_diag(&self) -> DiagRef<'_, T> {
        self.B_diag.as_ref()
    }

    /// Returns the subdiagonal of the block-diagonal factor B.
    pub fn B_subdiag(&self) -> DiagRef<'_, T> {
        self.B_subdiag.as_ref()
    }

    /// Returns the pivot permutation P.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
730
/// Splits a packed in-place LU (or QR) factorization into a unit lower
/// trapezoidal/triangular factor and an upper factor. The larger factor reuses
/// the `LU` allocation; the smaller one is `size × size` where
/// `size = min(m, n)`.
fn split_LU<T: ComplexField>(LU: Mat<T>) -> (Mat<T>, Mat<T>) {
    let (m, n) = LU.shape();
    let size = Ord::min(m, n);

    let (L, U) = if m >= n {
        // tall case: L is m×n (reuses LU), U is size×size
        let mut L = LU;
        let mut U = Mat::zeros(size, size);

        U.copy_from_triangular_upper(L.get(..size, ..size));

        // clear the strict upper part of L and set a unit diagonal
        z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
        L.diagonal_mut().fill(one());

        (L, U)
    } else {
        // wide case: U is m×n (reuses LU), L is size×size
        let mut U = LU;
        let mut L = Mat::zeros(size, size);

        L.copy_from_strict_triangular_lower(U.get(..size, ..size));

        // clear the strict lower part of U and set L's unit diagonal
        z!(&mut U).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
        L.diagonal_mut().fill(one());

        (L, U)
    };
    (L, U)
}
758
impl<T: ComplexField> PartialPivLu<T> {
    /// Computes the partial-pivoting LU decomposition of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let LU = A.to_owned();
        Self::new_imp(LU)
    }

    /// Factorizes the packed `LU` matrix in place, then splits it into L and U.
    #[track_caller]
    fn new_imp(mut LU: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = LU.shape();
        let mut row_perm_fwd = vec![0usize; m];
        let mut row_perm_bwd = vec![0usize; m];

        linalg::lu::partial_pivoting::factor::lu_in_place(
            LU.as_mut(),
            &mut row_perm_fwd,
            &mut row_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(m, n, par, default()),
            )),
            default(),
        );

        let (L, U) = split_LU(LU);

        Self {
            L,
            U,
            // SAFETY: the factorization filled these as inverse permutations of each other
            P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor L.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the upper triangular factor U.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the row permutation P.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
810
impl<T: ComplexField> FullPivLu<T> {
    /// Computes the full-pivoting LU decomposition of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let LU = A.to_owned();
        Self::new_imp(LU)
    }

    /// Factorizes the packed `LU` matrix in place with row and column pivoting,
    /// then splits it into L and U.
    #[track_caller]
    fn new_imp(mut LU: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = LU.shape();
        let mut row_perm_fwd = vec![0usize; m];
        let mut row_perm_bwd = vec![0usize; m];
        let mut col_perm_fwd = vec![0usize; n];
        let mut col_perm_bwd = vec![0usize; n];

        linalg::lu::full_pivoting::factor::lu_in_place(
            LU.as_mut(),
            &mut row_perm_fwd,
            &mut row_perm_bwd,
            &mut col_perm_fwd,
            &mut col_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::factor::lu_in_place_scratch::<usize, T>(
                m,
                n,
                par,
                default(),
            ))),
            default(),
        );

        let (L, U) = split_LU(LU);

        Self {
            L,
            U,
            // SAFETY: the factorization filled each pair as inverse permutations of each other
            P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
            Q: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the unit lower triangular factor L.
    pub fn L(&self) -> MatRef<'_, T> {
        self.L.as_ref()
    }

    /// Returns the upper triangular factor U.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the row permutation P.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }

    /// Returns the column permutation Q.
    pub fn Q(&self) -> PermRef<'_, usize> {
        self.Q.as_ref()
    }
}
875
impl<T: ComplexField> Qr<T> {
    /// Computes the QR decomposition of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let QR = A.to_owned();
        Self::new_imp(QR)
    }

    /// Factorizes the packed `QR` matrix in place; Q is kept in factored
    /// (block Householder) form as `Q_basis` + `Q_coeff`.
    #[track_caller]
    fn new_imp(mut QR: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = QR.shape();
        let size = Ord::min(m, n);

        let blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
        let mut Q_coeff = Mat::zeros(blocksize, size);

        linalg::qr::no_pivoting::factor::qr_in_place(
            QR.as_mut(),
            Q_coeff.as_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::factor::qr_in_place_scratch::<T>(
                m,
                n,
                blocksize,
                par,
                default(),
            ))),
            default(),
        );

        // reuse the LU splitter: the "L" part holds the Householder basis, the "U" part holds R
        let (Q_basis, R) = split_LU(QR);

        Self { Q_basis, Q_coeff, R }
    }

    /// Returns the Householder basis of Q.
    pub fn Q_basis(&self) -> MatRef<'_, T> {
        self.Q_basis.as_ref()
    }

    /// Returns the block Householder coefficients of Q.
    pub fn Q_coeff(&self) -> MatRef<'_, T> {
        self.Q_coeff.as_ref()
    }

    /// Returns the upper triangular factor R.
    pub fn R(&self) -> MatRef<'_, T> {
        self.R.as_ref()
    }

    /// Returns the top `min(nrows, ncols)` rows of R.
    pub fn thin_R(&self) -> MatRef<'_, T> {
        let size = Ord::min(self.nrows(), self.ncols());
        self.R.get(..size, ..)
    }

    /// Materializes the full square Q by applying the Householder sequence to
    /// the identity.
    pub fn compute_Q(&self) -> Mat<T> {
        let mut Q = Mat::identity(self.nrows(), self.nrows());
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
                    self.nrows(),
                    self.Q_coeff.nrows(),
                    self.nrows(),
                ),
            )),
        );
        Q
    }

    /// Materializes the thin Q (first `min(nrows, ncols)` columns).
    pub fn compute_thin_Q(&self) -> Mat<T> {
        let size = Ord::min(self.nrows(), self.ncols());
        let mut Q = Mat::identity(self.nrows(), size);
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
            )),
        );
        Q
    }
}
973
impl<T: ComplexField> ColPivQr<T> {
    /// Computes the column-pivoting QR decomposition of `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
        let QR = A.to_owned();
        Self::new_imp(QR)
    }

    /// Factorizes the packed `QR` matrix in place with column pivoting; Q is
    /// kept in factored (block Householder) form.
    #[track_caller]
    fn new_imp(mut QR: Mat<T>) -> Self {
        let par = get_global_parallelism();

        let (m, n) = QR.shape();
        let size = Ord::min(m, n);

        let mut col_perm_fwd = vec![0usize; n];
        let mut col_perm_bwd = vec![0usize; n];

        // NOTE(review): blocksize is taken from the `no_pivoting` module even though
        // the factorization below is `col_pivoting` — confirm this is intentional
        let blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
        let mut Q_coeff = Mat::zeros(blocksize, size);

        linalg::qr::col_pivoting::factor::qr_in_place(
            QR.as_mut(),
            Q_coeff.as_mut(),
            &mut col_perm_fwd,
            &mut col_perm_bwd,
            par,
            MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::factor::qr_in_place_scratch::<usize, T>(
                m,
                n,
                blocksize,
                par,
                default(),
            ))),
            default(),
        );

        // reuse the LU splitter: the "L" part holds the Householder basis, the "U" part holds R
        let (Q_basis, R) = split_LU(QR);

        Self {
            Q_basis,
            Q_coeff,
            R,
            // SAFETY: the factorization filled these as inverse permutations of each other
            P: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
        }
    }

    /// Returns the Householder basis of Q.
    pub fn Q_basis(&self) -> MatRef<'_, T> {
        self.Q_basis.as_ref()
    }

    /// Returns the block Householder coefficients of Q.
    pub fn Q_coeff(&self) -> MatRef<'_, T> {
        self.Q_coeff.as_ref()
    }

    /// Returns the upper triangular factor R.
    pub fn R(&self) -> MatRef<'_, T> {
        self.R.as_ref()
    }

    /// Returns the top `min(nrows, ncols)` rows of R.
    pub fn thin_R(&self) -> MatRef<'_, T> {
        let size = Ord::min(self.nrows(), self.ncols());
        self.R.get(..size, ..)
    }

    /// Materializes the full square Q by applying the Householder sequence to
    /// the identity.
    pub fn compute_Q(&self) -> Mat<T> {
        let mut Q = Mat::identity(self.nrows(), self.nrows());
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
                    self.nrows(),
                    self.Q_coeff.nrows(),
                    self.nrows(),
                ),
            )),
        );
        Q
    }

    /// Materializes the thin Q (first `min(nrows, ncols)` columns).
    pub fn compute_thin_Q(&self) -> Mat<T> {
        let size = Ord::min(self.nrows(), self.ncols());
        let mut Q = Mat::identity(self.nrows(), size);
        let par = get_global_parallelism();
        linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
            self.Q_basis(),
            self.Q_coeff(),
            Conj::No,
            Q.rb_mut(),
            par,
            MemStack::new(&mut MemBuffer::new(
                linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
            )),
        );
        Q
    }

    /// Returns the column permutation P.
    pub fn P(&self) -> PermRef<'_, usize> {
        self.P.as_ref()
    }
}
1086
impl<T: ComplexField> Svd<T> {
    /// Computes the full SVD of `A` (square U and V).
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
        Self::new_imp(A.canonical(), Conj::get::<C>(), false)
    }

    /// Computes the thin SVD of `A` (U and V truncated to `min(m, n)` columns).
    #[track_caller]
    pub fn new_thin<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
        Self::new_imp(A.canonical(), Conj::get::<C>(), true)
    }

    /// Shared implementation: computes the SVD of the canonical matrix, then,
    /// if the input was implicitly conjugated, conjugates U and V to match.
    #[track_caller]
    fn new_imp(A: MatRef<'_, T>, conj: Conj, thin: bool) -> Result<Self, SvdError> {
        let par = get_global_parallelism();

        let (m, n) = A.shape();
        let size = Ord::min(m, n);

        let mut U = Mat::zeros(m, if thin { size } else { m });
        let mut V = Mat::zeros(n, if thin { size } else { n });
        let mut S = Diag::zeros(size);

        let compute = if thin { ComputeSvdVectors::Thin } else { ComputeSvdVectors::Full };

        linalg::svd::svd(
            A,
            S.as_mut(),
            Some(U.as_mut()),
            Some(V.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(m, n, compute, compute, par, default()))),
            default(),
        )?;

        if conj == Conj::Yes {
            // conj(A) = conj(U) S conj(V)^H, so conjugating U and V gives the SVD of the original
            for c in U.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
            for c in V.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
        }

        Ok(Self { U, V, S })
    }

    /// Returns the left singular vectors U.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the right singular vectors V.
    pub fn V(&self) -> MatRef<'_, T> {
        self.V.as_ref()
    }

    /// Returns the singular values S.
    pub fn S(&self) -> DiagRef<'_, T> {
        self.S.as_ref()
    }

    /// Computes the Moore–Penrose pseudoinverse from the stored factors.
    pub fn pseudoinverse(&self) -> Mat<T> {
        let U = self.U();
        let V = self.V();
        let S = self.S();
        let par = get_global_parallelism();
        let stack = &mut MemBuffer::new(linalg::svd::pseudoinverse_from_svd_scratch::<T>(self.nrows(), self.ncols(), par));
        let mut pinv = Mat::zeros(self.nrows(), self.ncols());
        linalg::svd::pseudoinverse_from_svd(pinv.rb_mut(), S, U, V, par, MemStack::new(stack));
        pinv
    }
}
1166
impl<T: ComplexField> SelfAdjointEigen<T> {
    /// Computes the eigendecomposition of the square self-adjoint matrix `A`,
    /// reading only the triangle indicated by `side`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());

        match side {
            Side::Lower => Self::new_imp(A.canonical(), Conj::get::<C>()),
            // the adjoint moves the upper triangle into the lower position read below;
            // its conjugation state flips accordingly
            Side::Upper => Self::new_imp(A.adjoint().canonical(), Conj::get::<C::Conj>()),
        }
    }

    /// Shared implementation: computes the EVD of the canonical matrix, then,
    /// if the input was implicitly conjugated, conjugates the eigenvectors.
    #[track_caller]
    fn new_imp(A: MatRef<'_, T>, conj: Conj) -> Result<Self, EvdError> {
        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        linalg::evd::self_adjoint_evd(
            A,
            S.as_mut(),
            Some(U.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
                n,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        if conj == Conj::Yes {
            // eigenvalues are unaffected by conjugation; only the vectors need adjusting
            for c in U.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
        }

        Ok(Self { U, S })
    }

    /// Returns the eigenvectors U, one per column.
    pub fn U(&self) -> MatRef<'_, T> {
        self.U.as_ref()
    }

    /// Returns the eigenvalues S.
    pub fn S(&self) -> DiagRef<'_, T> {
        self.S.as_ref()
    }
}
1223
impl<T: RealField> Eigen<T> {
    /// Computes the eigendecomposition of the square complex matrix `A`.
    #[track_caller]
    pub fn new<C: Conjugate<Canonical = Complex<T>>>(A: MatRef<'_, C>) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());
        Self::new_imp(A.canonical(), Conj::get::<C>())
    }

    /// Computes the eigendecomposition of the square real matrix `A`,
    /// converting the packed real EVD output into complex eigenpairs.
    #[track_caller]
    pub fn new_from_real(A: MatRef<'_, T>) -> Result<Self, EvdError> {
        assert!(A.nrows() == A.ncols());

        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U_real = Mat::zeros(n, n);
        let mut S_re = Diag::zeros(n);
        let mut S_im = Diag::zeros(n);

        // right eigenvectors only; eigenvalues come out as separate re/im parts
        linalg::evd::evd_real(
            A,
            S_re.as_mut(),
            S_im.as_mut(),
            None,
            Some(U_real.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<T>(
                n,
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        // unpack: a zero imaginary part marks a real eigenvalue in one column;
        // a nonzero one marks a conjugate pair packed into columns j and j+1
        // (real part in column j, imaginary part in column j+1)
        let mut j = 0;
        while j < n {
            if S_im[j] == zero() {
                S[j] = Complex::new(S_re[j].clone(), zero());

                for i in 0..n {
                    U[(i, j)] = Complex::new(U_real[(i, j)].clone(), zero());
                }

                j += 1;
            } else {
                // the pair's eigenvalues and eigenvectors are complex conjugates of each other
                S[j] = Complex::new(S_re[j].clone(), S_im[j].clone());
                S[j + 1] = Complex::new(S_re[j].clone(), neg(&S_im[j]));

                for i in 0..n {
                    U[(i, j)] = Complex::new(U_real[(i, j)].clone(), U_real[(i, j + 1)].clone());
                    U[(i, j + 1)] = Complex::new(U_real[(i, j)].clone(), neg(&U_real[(i, j + 1)]));
                }

                j += 2;
            }
        }

        Ok(Self { U, S })
    }

    /// Complex implementation: computes the EVD, then, if the input was
    /// implicitly conjugated, conjugates the eigenvectors.
    fn new_imp(A: MatRef<'_, Complex<T>>, conj: Conj) -> Result<Self, EvdError> {
        let par = get_global_parallelism();

        let n = A.nrows();

        let mut U = Mat::zeros(n, n);
        let mut S = Diag::zeros(n);

        linalg::evd::evd_cplx(
            A,
            S.as_mut(),
            None,
            Some(U.as_mut()),
            par,
            MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<T>>(
                n,
                linalg::evd::ComputeEigenvectors::No,
                linalg::evd::ComputeEigenvectors::Yes,
                par,
                default(),
            ))),
            default(),
        )?;

        if conj == Conj::Yes {
            // NOTE(review): only the vectors are conjugated here, not the eigenvalues —
            // confirm this matches the intended semantics for conjugated inputs
            for c in U.col_iter_mut() {
                for x in c.iter_mut() {
                    *x = math_utils::conj(x);
                }
            }
        }

        Ok(Self { U, S })
    }

    /// Returns the eigenvectors U, one per column.
    pub fn U(&self) -> MatRef<'_, Complex<T>> {
        self.U.as_ref()
    }

    /// Returns the eigenvalues S.
    pub fn S(&self) -> DiagRef<'_, Complex<T>> {
        self.S.as_ref()
    }
}
1336
1337impl<T: ComplexField> ShapeCore for Llt<T> {
1338 #[inline]
1339 fn nrows(&self) -> usize {
1340 self.L().nrows()
1341 }
1342
1343 #[inline]
1344 fn ncols(&self) -> usize {
1345 self.L().ncols()
1346 }
1347}
1348impl<T: ComplexField> ShapeCore for Ldlt<T> {
1349 #[inline]
1350 fn nrows(&self) -> usize {
1351 self.L().nrows()
1352 }
1353
1354 #[inline]
1355 fn ncols(&self) -> usize {
1356 self.L().ncols()
1357 }
1358}
1359impl<T: ComplexField> ShapeCore for Lblt<T> {
1360 #[inline]
1361 fn nrows(&self) -> usize {
1362 self.L().nrows()
1363 }
1364
1365 #[inline]
1366 fn ncols(&self) -> usize {
1367 self.L().ncols()
1368 }
1369}
1370impl<T: ComplexField> ShapeCore for PartialPivLu<T> {
1371 #[inline]
1372 fn nrows(&self) -> usize {
1373 self.L().nrows()
1374 }
1375
1376 #[inline]
1377 fn ncols(&self) -> usize {
1378 self.U().ncols()
1379 }
1380}
1381impl<T: ComplexField> ShapeCore for FullPivLu<T> {
1382 #[inline]
1383 fn nrows(&self) -> usize {
1384 self.L().nrows()
1385 }
1386
1387 #[inline]
1388 fn ncols(&self) -> usize {
1389 self.U().ncols()
1390 }
1391}
1392impl<T: ComplexField> ShapeCore for Qr<T> {
1393 #[inline]
1394 fn nrows(&self) -> usize {
1395 self.Q_basis().nrows()
1396 }
1397
1398 #[inline]
1399 fn ncols(&self) -> usize {
1400 self.R().ncols()
1401 }
1402}
1403impl<T: ComplexField> ShapeCore for ColPivQr<T> {
1404 #[inline]
1405 fn nrows(&self) -> usize {
1406 self.Q_basis().nrows()
1407 }
1408
1409 #[inline]
1410 fn ncols(&self) -> usize {
1411 self.R().ncols()
1412 }
1413}
1414impl<T: ComplexField> ShapeCore for Svd<T> {
1415 #[inline]
1416 fn nrows(&self) -> usize {
1417 self.U().nrows()
1418 }
1419
1420 #[inline]
1421 fn ncols(&self) -> usize {
1422 self.V().nrows()
1423 }
1424}
1425impl<T: ComplexField> ShapeCore for SelfAdjointEigen<T> {
1426 #[inline]
1427 fn nrows(&self) -> usize {
1428 self.U().nrows()
1429 }
1430
1431 #[inline]
1432 fn ncols(&self) -> usize {
1433 self.U().nrows()
1434 }
1435}
1436impl<T: RealField> ShapeCore for Eigen<T> {
1437 #[inline]
1438 fn nrows(&self) -> usize {
1439 self.U().nrows()
1440 }
1441
1442 #[inline]
1443 fn ncols(&self) -> usize {
1444 self.U().nrows()
1445 }
1446}
1447
1448impl<T: ComplexField> SolveCore<T> for Llt<T> {
1449 #[track_caller]
1450 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1451 let par = get_global_parallelism();
1452
1453 let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
1454 self.L.nrows(),
1455 rhs.ncols(),
1456 par,
1457 ));
1458 let stack = MemStack::new(&mut mem);
1459
1460 linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj, rhs, par, stack);
1461 }
1462
1463 #[track_caller]
1464 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1465 let par = get_global_parallelism();
1466
1467 let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
1468 self.L.nrows(),
1469 rhs.ncols(),
1470 par,
1471 ));
1472 let stack = MemStack::new(&mut mem);
1473
1474 linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj.compose(Conj::Yes), rhs, par, stack);
1475 }
1476}
1477
/// Symmetrizes `A` in place: the lower triangle is treated as the source of truth,
/// its conjugate is copied into the strict upper triangle, and the diagonal is
/// forced to be real, so that `A[(i, j)] == conj(A[(j, i)])` holds exactly.
///
/// # Panics
/// Panics if `A` is not square.
#[math]
fn make_self_adjoint<T: ComplexField>(mut A: MatMut<'_, T>) {
    assert!(A.nrows() == A.ncols());
    let n = A.nrows();
    for j in 0..n {
        // NOTE: inside `#[math]` the calls `real`/`from_real`/`conj` are rewritten
        // by the proc-macro; do not "simplify" these expressions manually.
        A[(j, j)] = from_real(real(A[(j, j)]));
        for i in 0..j {
            // (i, j) is in the upper triangle, (j, i) in the lower one.
            A[(i, j)] = conj(A[(j, i)]);
        }
    }
}
1489
1490impl<T: ComplexField> DenseSolveCore<T> for Llt<T> {
1491 #[track_caller]
1492 fn reconstruct(&self) -> Mat<T> {
1493 let par = get_global_parallelism();
1494
1495 let n = self.L.nrows();
1496 let mut out = Mat::zeros(n, n);
1497
1498 let mut mem = MemBuffer::new(linalg::cholesky::llt::reconstruct::reconstruct_scratch::<T>(n, par));
1499 let stack = MemStack::new(&mut mem);
1500
1501 linalg::cholesky::llt::reconstruct::reconstruct(out.as_mut(), self.L(), par, stack);
1502
1503 make_self_adjoint(out.as_mut());
1504 out
1505 }
1506
1507 #[track_caller]
1508 fn inverse(&self) -> Mat<T> {
1509 let par = get_global_parallelism();
1510
1511 let n = self.L.nrows();
1512 let mut out = Mat::zeros(n, n);
1513
1514 let mut mem = MemBuffer::new(linalg::cholesky::llt::inverse::inverse_scratch::<T>(n, par));
1515 let stack = MemStack::new(&mut mem);
1516
1517 linalg::cholesky::llt::inverse::inverse(out.as_mut(), self.L(), par, stack);
1518
1519 make_self_adjoint(out.as_mut());
1520 out
1521 }
1522}
1523
1524impl<T: ComplexField> SolveCore<T> for Ldlt<T> {
1525 #[track_caller]
1526 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1527 let par = get_global_parallelism();
1528
1529 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1530 self.L.nrows(),
1531 rhs.ncols(),
1532 par,
1533 ));
1534 let stack = MemStack::new(&mut mem);
1535
1536 linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L.as_ref(), self.D.as_ref(), conj, rhs, par, stack);
1537 }
1538
1539 #[track_caller]
1540 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1541 let par = get_global_parallelism();
1542
1543 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1544 self.L.nrows(),
1545 rhs.ncols(),
1546 par,
1547 ));
1548 let stack = MemStack::new(&mut mem);
1549
1550 linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L(), self.D(), conj.compose(Conj::Yes), rhs, par, stack);
1551 }
1552}
1553
1554impl<T: ComplexField> DenseSolveCore<T> for Ldlt<T> {
1555 #[track_caller]
1556 fn reconstruct(&self) -> Mat<T> {
1557 let par = get_global_parallelism();
1558
1559 let n = self.L.nrows();
1560 let mut out = Mat::zeros(n, n);
1561
1562 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::reconstruct::reconstruct_scratch::<T>(n, par));
1563 let stack = MemStack::new(&mut mem);
1564
1565 linalg::cholesky::ldlt::reconstruct::reconstruct(out.as_mut(), self.L(), self.D(), par, stack);
1566
1567 make_self_adjoint(out.as_mut());
1568 out
1569 }
1570
1571 #[track_caller]
1572 fn inverse(&self) -> Mat<T> {
1573 let par = get_global_parallelism();
1574
1575 let n = self.L.nrows();
1576 let mut out = Mat::zeros(n, n);
1577
1578 let mut mem = MemBuffer::new(linalg::cholesky::ldlt::inverse::inverse_scratch::<T>(n, par));
1579 let stack = MemStack::new(&mut mem);
1580
1581 linalg::cholesky::ldlt::inverse::inverse(out.as_mut(), self.L(), self.D(), par, stack);
1582
1583 make_self_adjoint(out.as_mut());
1584 out
1585 }
1586}
1587
1588impl<T: ComplexField> SolveCore<T> for Lblt<T> {
1589 #[track_caller]
1590 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1591 let par = get_global_parallelism();
1592
1593 let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1594 self.L.nrows(),
1595 rhs.ncols(),
1596 par,
1597 ));
1598 let stack = MemStack::new(&mut mem);
1599
1600 linalg::cholesky::lblt::solve::solve_in_place_with_conj(self.L.as_ref(), self.B_diag(), self.B_subdiag(), conj, self.P(), rhs, par, stack);
1601 }
1602
1603 #[track_caller]
1604 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1605 let par = get_global_parallelism();
1606
1607 let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1608 self.L.nrows(),
1609 rhs.ncols(),
1610 par,
1611 ));
1612 let stack = MemStack::new(&mut mem);
1613
1614 linalg::cholesky::lblt::solve::solve_in_place_with_conj(
1615 self.L(),
1616 self.B_diag(),
1617 self.B_subdiag(),
1618 conj.compose(Conj::Yes),
1619 self.P(),
1620 rhs,
1621 par,
1622 stack,
1623 );
1624 }
1625}
1626
1627impl<T: ComplexField> DenseSolveCore<T> for Lblt<T> {
1628 #[track_caller]
1629 fn reconstruct(&self) -> Mat<T> {
1630 let par = get_global_parallelism();
1631
1632 let n = self.L.nrows();
1633 let mut out = Mat::zeros(n, n);
1634
1635 let mut mem = MemBuffer::new(linalg::cholesky::lblt::reconstruct::reconstruct_scratch::<usize, T>(n, par));
1636 let stack = MemStack::new(&mut mem);
1637
1638 linalg::cholesky::lblt::reconstruct::reconstruct(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);
1639
1640 make_self_adjoint(out.as_mut());
1641 out
1642 }
1643
1644 #[track_caller]
1645 fn inverse(&self) -> Mat<T> {
1646 let par = get_global_parallelism();
1647
1648 let n = self.L.nrows();
1649 let mut out = Mat::zeros(n, n);
1650
1651 let mut mem = MemBuffer::new(linalg::cholesky::lblt::inverse::inverse_scratch::<usize, T>(n, par));
1652 let stack = MemStack::new(&mut mem);
1653
1654 linalg::cholesky::lblt::inverse::inverse(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);
1655
1656 make_self_adjoint(out.as_mut());
1657 out
1658 }
1659}
1660
1661impl<T: ComplexField> SolveCore<T> for PartialPivLu<T> {
1662 #[track_caller]
1663 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1664 let par = get_global_parallelism();
1665
1666 assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1667
1668 let k = rhs.ncols();
1669
1670 linalg::lu::partial_pivoting::solve::solve_in_place_with_conj(
1671 self.L(),
1672 self.U(),
1673 self.P(),
1674 conj,
1675 rhs,
1676 par,
1677 MemStack::new(&mut MemBuffer::new(
1678 linalg::lu::partial_pivoting::solve::solve_in_place_scratch::<usize, T>(self.nrows(), k, par),
1679 )),
1680 );
1681 }
1682
1683 #[track_caller]
1684 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1685 let par = get_global_parallelism();
1686
1687 assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1688
1689 let k = rhs.ncols();
1690
1691 linalg::lu::partial_pivoting::solve::solve_transpose_in_place_with_conj(
1692 self.L(),
1693 self.U(),
1694 self.P(),
1695 conj,
1696 rhs,
1697 par,
1698 MemStack::new(&mut MemBuffer::new(
1699 linalg::lu::partial_pivoting::solve::solve_transpose_in_place_scratch::<usize, T>(self.nrows(), k, par),
1700 )),
1701 );
1702 }
1703}
1704
1705impl<T: ComplexField> DenseSolveCore<T> for PartialPivLu<T> {
1706 fn reconstruct(&self) -> Mat<T> {
1707 let par = get_global_parallelism();
1708 let m = self.nrows();
1709 let n = self.ncols();
1710
1711 let mut out = Mat::zeros(m, n);
1712
1713 linalg::lu::partial_pivoting::reconstruct::reconstruct(
1714 out.as_mut(),
1715 self.L(),
1716 self.U(),
1717 self.P(),
1718 par,
1719 MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::reconstruct::reconstruct_scratch::<
1720 usize,
1721 T,
1722 >(m, n, par))),
1723 );
1724
1725 out
1726 }
1727
1728 #[track_caller]
1729 fn inverse(&self) -> Mat<T> {
1730 let par = get_global_parallelism();
1731
1732 assert!(self.nrows() == self.ncols());
1733
1734 let n = self.ncols();
1735
1736 let mut out = Mat::zeros(n, n);
1737
1738 linalg::lu::partial_pivoting::inverse::inverse(
1739 out.as_mut(),
1740 self.L(),
1741 self.U(),
1742 self.P(),
1743 par,
1744 MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(
1745 n, par,
1746 ))),
1747 );
1748
1749 out
1750 }
1751}
1752
1753impl<T: ComplexField> SolveCore<T> for FullPivLu<T> {
1754 #[track_caller]
1755 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1756 let par = get_global_parallelism();
1757
1758 assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1759
1760 let k = rhs.ncols();
1761
1762 linalg::lu::full_pivoting::solve::solve_in_place_with_conj(
1763 self.L(),
1764 self.U(),
1765 self.P(),
1766 self.Q(),
1767 conj,
1768 rhs,
1769 par,
1770 MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_in_place_scratch::<usize, T>(
1771 self.nrows(),
1772 k,
1773 par,
1774 ))),
1775 );
1776 }
1777
1778 #[track_caller]
1779 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1780 let par = get_global_parallelism();
1781
1782 assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1783
1784 let k = rhs.ncols();
1785
1786 linalg::lu::full_pivoting::solve::solve_transpose_in_place_with_conj(
1787 self.L(),
1788 self.U(),
1789 self.P(),
1790 self.Q(),
1791 conj,
1792 rhs,
1793 par,
1794 MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_transpose_in_place_scratch::<
1795 usize,
1796 T,
1797 >(self.nrows(), k, par))),
1798 );
1799 }
1800}
1801
1802impl<T: ComplexField> DenseSolveCore<T> for FullPivLu<T> {
1803 fn reconstruct(&self) -> Mat<T> {
1804 let par = get_global_parallelism();
1805 let m = self.nrows();
1806 let n = self.ncols();
1807
1808 let mut out = Mat::zeros(m, n);
1809
1810 linalg::lu::full_pivoting::reconstruct::reconstruct(
1811 out.as_mut(),
1812 self.L(),
1813 self.U(),
1814 self.P(),
1815 self.Q(),
1816 par,
1817 MemStack::new(&mut MemBuffer::new(
1818 linalg::lu::full_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, par),
1819 )),
1820 );
1821
1822 out
1823 }
1824
1825 #[track_caller]
1826 fn inverse(&self) -> Mat<T> {
1827 let par = get_global_parallelism();
1828
1829 assert!(self.nrows() == self.ncols());
1830
1831 let n = self.ncols();
1832
1833 let mut out = Mat::zeros(n, n);
1834
1835 linalg::lu::full_pivoting::inverse::inverse(
1836 out.as_mut(),
1837 self.L(),
1838 self.U(),
1839 self.P(),
1840 self.Q(),
1841 par,
1842 MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::inverse::inverse_scratch::<usize, T>(
1843 n, par,
1844 ))),
1845 );
1846
1847 out
1848 }
1849}
1850
1851impl<T: ComplexField> SolveCore<T> for Qr<T> {
1852 #[track_caller]
1853 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1854 let par = get_global_parallelism();
1855
1856 assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1857
1858 let n = self.nrows();
1859 let blocksize = self.Q_coeff().nrows();
1860 let k = rhs.ncols();
1861
1862 linalg::qr::no_pivoting::solve::solve_in_place_with_conj(
1863 self.Q_basis(),
1864 self.Q_coeff(),
1865 self.R(),
1866 conj,
1867 rhs,
1868 par,
1869 MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_in_place_scratch::<T>(
1870 n, blocksize, k, par,
1871 ))),
1872 );
1873 }
1874
1875 #[track_caller]
1876 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1877 let par = get_global_parallelism();
1878
1879 assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1880
1881 let n = self.nrows();
1882 let blocksize = self.Q_coeff().nrows();
1883 let k = rhs.ncols();
1884
1885 linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(
1886 self.Q_basis(),
1887 self.Q_coeff(),
1888 self.R(),
1889 conj,
1890 rhs,
1891 par,
1892 MemStack::new(&mut MemBuffer::new(
1893 linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(n, blocksize, k, par),
1894 )),
1895 );
1896 }
1897}
1898
1899impl<T: ComplexField> SolveLstsqCore<T> for Qr<T> {
1900 #[track_caller]
1901 fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1902 let par = get_global_parallelism();
1903
1904 assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));
1905
1906 let m = self.nrows();
1907 let n = self.ncols();
1908 let blocksize = self.Q_coeff().nrows();
1909 let k = rhs.ncols();
1910
1911 linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(
1912 self.Q_basis(),
1913 self.Q_coeff(),
1914 self.R(),
1915 conj,
1916 rhs,
1917 par,
1918 MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(
1919 m, n, blocksize, k, par,
1920 ))),
1921 );
1922 }
1923}
1924
1925impl<T: ComplexField> DenseSolveCore<T> for Qr<T> {
1926 fn reconstruct(&self) -> Mat<T> {
1927 let par = get_global_parallelism();
1928 let m = self.nrows();
1929 let n = self.ncols();
1930 let blocksize = self.Q_coeff().nrows();
1931
1932 let mut out = Mat::zeros(m, n);
1933
1934 linalg::qr::no_pivoting::reconstruct::reconstruct(
1935 out.as_mut(),
1936 self.Q_basis(),
1937 self.Q_coeff(),
1938 self.R(),
1939 par,
1940 MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::reconstruct::reconstruct_scratch::<T>(
1941 m, n, blocksize, par,
1942 ))),
1943 );
1944
1945 out
1946 }
1947
1948 fn inverse(&self) -> Mat<T> {
1949 let par = get_global_parallelism();
1950 assert!(self.nrows() == self.ncols());
1951
1952 let n = self.ncols();
1953 let blocksize = self.Q_coeff().nrows();
1954
1955 let mut out = Mat::zeros(n, n);
1956
1957 linalg::qr::no_pivoting::inverse::inverse(
1958 out.as_mut(),
1959 self.Q_basis(),
1960 self.Q_coeff(),
1961 self.R(),
1962 par,
1963 MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::inverse::inverse_scratch::<T>(
1964 n, blocksize, par,
1965 ))),
1966 );
1967
1968 out
1969 }
1970}
1971
1972impl<T: ComplexField> SolveCore<T> for ColPivQr<T> {
1973 #[track_caller]
1974 fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1975 let par = get_global_parallelism();
1976
1977 assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1978
1979 let n = self.nrows();
1980 let blocksize = self.Q_coeff().nrows();
1981 let k = rhs.ncols();
1982
1983 linalg::qr::col_pivoting::solve::solve_in_place_with_conj(
1984 self.Q_basis(),
1985 self.Q_coeff(),
1986 self.R(),
1987 self.P(),
1988 conj,
1989 rhs,
1990 par,
1991 MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_in_place_scratch::<usize, T>(
1992 n, blocksize, k, par,
1993 ))),
1994 );
1995 }
1996
1997 #[track_caller]
1998 fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1999 let par = get_global_parallelism();
2000
2001 assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
2002
2003 let n = self.nrows();
2004 let blocksize = self.Q_coeff().nrows();
2005 let k = rhs.ncols();
2006
2007 linalg::qr::col_pivoting::solve::solve_transpose_in_place_with_conj(
2008 self.Q_basis(),
2009 self.Q_coeff(),
2010 self.R(),
2011 self.P(),
2012 conj,
2013 rhs,
2014 par,
2015 MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_transpose_in_place_scratch::<
2016 usize,
2017 T,
2018 >(n, blocksize, k, par))),
2019 );
2020 }
2021}
2022
2023impl<T: ComplexField> SolveLstsqCore<T> for ColPivQr<T> {
2024 #[track_caller]
2025 fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
2026 let par = get_global_parallelism();
2027
2028 assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));
2029
2030 let m = self.nrows();
2031 let n = self.ncols();
2032 let blocksize = self.Q_coeff().nrows();
2033 let k = rhs.ncols();
2034
2035 linalg::qr::col_pivoting::solve::solve_lstsq_in_place_with_conj(
2036 self.Q_basis(),
2037 self.Q_coeff(),
2038 self.R(),
2039 self.P(),
2040 conj,
2041 rhs,
2042 par,
2043 MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_lstsq_in_place_scratch::<
2044 usize,
2045 T,
2046 >(m, n, blocksize, k, par))),
2047 );
2048 }
2049}
2050
2051impl<T: ComplexField> DenseSolveCore<T> for ColPivQr<T> {
2052 fn reconstruct(&self) -> Mat<T> {
2053 let par = get_global_parallelism();
2054 let m = self.nrows();
2055 let n = self.ncols();
2056 let blocksize = self.Q_coeff().nrows();
2057
2058 let mut out = Mat::zeros(m, n);
2059
2060 linalg::qr::col_pivoting::reconstruct::reconstruct(
2061 out.as_mut(),
2062 self.Q_basis(),
2063 self.Q_coeff(),
2064 self.R(),
2065 self.P(),
2066 par,
2067 MemStack::new(&mut MemBuffer::new(
2068 linalg::qr::col_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, blocksize, par),
2069 )),
2070 );
2071
2072 out
2073 }
2074
2075 fn inverse(&self) -> Mat<T> {
2076 let par = get_global_parallelism();
2077 assert!(self.nrows() == self.ncols());
2078
2079 let n = self.ncols();
2080 let blocksize = self.Q_coeff().nrows();
2081
2082 let mut out = Mat::zeros(n, n);
2083
2084 linalg::qr::col_pivoting::inverse::inverse(
2085 out.as_mut(),
2086 self.Q_basis(),
2087 self.Q_coeff(),
2088 self.R(),
2089 self.P(),
2090 par,
2091 MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::inverse::inverse_scratch::<usize, T>(
2092 n, blocksize, par,
2093 ))),
2094 );
2095
2096 out
2097 }
2098}
2099
impl<T: ComplexField> SolveCore<T> for Svd<T> {
    /// Solves `A x = rhs` in place via `x = V S⁻¹ Uᴴ rhs`, using `A = U S Vᴴ`.
    ///
    /// # Panics
    /// Panics if the factorized matrix is not square, or if `rhs` has a mismatched
    /// row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = Uᴴ rhs: the extra `Conj::Yes` composed onto `conj` turns the plain
        // transpose into an adjoint (and tracks the caller's conjugation request).
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S⁻¹ tmp: singular values are real, so scale by the real reciprocal.
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = V tmp.
        linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.V(), conj, tmp.as_ref(), Conj::No, one(), par);
    }

    /// Solves `Aᵀ x = rhs` in place via `x = conj(U) S⁻¹ Vᵀ rhs`, using
    /// `Aᵀ = conj(V) S Uᵀ`.
    ///
    /// # Panics
    /// Panics if the factorized matrix is not square, or if `rhs` has a mismatched
    /// row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = Vᵀ rhs (with the caller's conjugation request applied).
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.V().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S⁻¹ tmp.
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = conj(U) tmp.
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
2174
impl<T: ComplexField> SolveLstsqCore<T> for Svd<T> {
    /// Solves the least-squares problem `min ‖A x − rhs‖` in place using the thin
    /// SVD; the solution is written to the top `min(nrows, ncols)` rows of `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs` has a mismatched row count, or if the factorized matrix has
    /// fewer rows than columns.
    #[track_caller]
    fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));

        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        // Thin factors: only the leading `size` columns of U and V participate.
        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);

        let k = rhs.ncols();

        let mut tmp = Mat::zeros(size, k);

        // tmp = Uᴴ rhs (the composed `Conj::Yes` turns the transpose into an adjoint).
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            U.transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S⁻¹ tmp: singular values are real, so scale by the real reciprocal.
        for j in 0..k {
            for i in 0..size {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // Top `size` rows of rhs = V tmp; rows below are left as-is.
        linalg::matmul::matmul_with_conj(rhs.get_mut(..size, ..), Accum::Replace, V, conj, tmp.as_ref(), Conj::No, one(), par);
    }
}
2215
impl<T: ComplexField> DenseSolveCore<T> for Svd<T> {
    /// Recomputes the original matrix as the thin product `U S Vᴴ`, using only the
    /// leading `min(nrows, ncols)` columns of `U` and `V`.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        let U = self.U().get(.., ..size);
        let V = self.V().get(.., ..size);
        let S = self.S();

        // UxS = U * S (singular values are real, so scale column j by real(S[j])).
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(m, n);

        // out = (U S) Vᴴ.
        linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);

        out
    }

    /// Computes the inverse of the factorized matrix as `V S⁻¹ Uᴴ`.
    ///
    /// # Panics
    /// Panics if the factorized matrix is not square.
    #[track_caller]
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());
        let n = self.nrows();

        let U = self.U();
        let V = self.V();
        let S = self.S();

        // VxS = V * S⁻¹ (scale column j by the real reciprocal of S[j]).
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));

            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(n, n);

        // out = (V S⁻¹) Uᴴ.
        linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);

        out
    }
}
2270
impl<T: ComplexField> SolveCore<T> for SelfAdjointEigen<T> {
    /// Solves `A x = rhs` in place via `x = U S⁻¹ Uᴴ rhs`, using `A = U S Uᴴ`.
    ///
    /// # Panics
    /// Panics if the factorized matrix is not square, or if `rhs` has a mismatched
    /// row count.
    #[track_caller]
    fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = Uᴴ rhs: the composed `Conj::Yes` turns the transpose into an adjoint.
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj.compose(Conj::Yes),
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S⁻¹ tmp: eigenvalues of a self-adjoint matrix are real, so scale
        // by the real reciprocal.
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = U tmp.
        linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.U(), conj, tmp.as_ref(), Conj::No, one(), par);
    }

    /// Solves `Aᵀ x = rhs` in place via `x = conj(U) S⁻¹ Uᵀ rhs`, using
    /// `Aᵀ = conj(U) S Uᵀ`.
    ///
    /// # Panics
    /// Panics if the factorized matrix is not square, or if `rhs` has a mismatched
    /// row count.
    #[track_caller]
    fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
        let par = get_global_parallelism();

        assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

        let mut rhs = rhs;
        let n = self.nrows();
        let k = rhs.ncols();
        let mut tmp = Mat::zeros(n, k);

        // tmp = Uᵀ rhs (with the caller's conjugation request applied).
        linalg::matmul::matmul_with_conj(
            tmp.as_mut(),
            Accum::Replace,
            self.U().transpose(),
            conj,
            rhs.as_ref(),
            Conj::No,
            one(),
            par,
        );

        // tmp = S⁻¹ tmp.
        for j in 0..k {
            for i in 0..n {
                let s = recip(&real(&self.S()[i]));
                tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
            }
        }

        // rhs = conj(U) tmp.
        linalg::matmul::matmul_with_conj(
            rhs.as_mut(),
            Accum::Replace,
            self.U(),
            conj.compose(Conj::Yes),
            tmp.as_ref(),
            Conj::No,
            one(),
            par,
        );
    }
}
2345
impl<T: ComplexField> DenseSolveCore<T> for SelfAdjointEigen<T> {
    /// Recomputes the original matrix as `U S Uᴴ`. The code is shaped like the SVD
    /// version, with `V` aliased to `U`.
    fn reconstruct(&self) -> Mat<T> {
        let par = get_global_parallelism();
        let m = self.nrows();
        let n = self.ncols();

        let size = Ord::min(m, n);

        // `V` is an alias of `U` here: the decomposition is `U S Uᴴ`.
        let U = self.U().get(.., ..size);
        let V = self.U().get(.., ..size);
        let S = self.S();

        // UxS = U * S (eigenvalues of a self-adjoint matrix are real).
        let mut UxS = Mat::zeros(m, size);
        for j in 0..size {
            let s = real(&S[j]);
            for i in 0..m {
                UxS[(i, j)] = mul_real(&U[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(m, n);

        // out = (U S) Uᴴ.
        linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);

        out
    }

    /// Computes the inverse of the factorized matrix as `U S⁻¹ Uᴴ`.
    fn inverse(&self) -> Mat<T> {
        let par = get_global_parallelism();

        assert!(self.nrows() == self.ncols());
        let n = self.nrows();

        // `V` is an alias of `U` here: the decomposition is `U S Uᴴ`.
        let U = self.U();
        let V = self.U();
        let S = self.S();

        // VxS = U * S⁻¹ (scale column j by the real reciprocal of S[j]).
        let mut VxS = Mat::zeros(n, n);
        for j in 0..n {
            let s = recip(&real(&S[j]));

            for i in 0..n {
                VxS[(i, j)] = mul_real(&V[(i, j)], &s);
            }
        }

        let mut out = Mat::zeros(n, n);

        // out = (U S⁻¹) Uᴴ.
        linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);

        out
    }
}
2399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert;
    use crate::stats::prelude::*;
    use crate::utils::approx::*;

    // Round-trips a decomposition against the matrix it was computed from, in all
    // four solve directions (plain, conjugate, transpose, adjoint), from both sides.
    #[track_caller]
    fn test_solver(A: MatRef<'_, c64>, A_dec: impl SolveCore<c64>) {
        #[track_caller]
        fn test_solver_imp(A: MatRef<'_, c64>, A_dec: &dyn SolveCore<c64>) {
            let rng = &mut StdRng::seed_from_u64(0xC0FFEE);

            let n = A.nrows();
            // NOTE: `approx_eq` looks unused but is picked up by name by the `~`
            // comparisons inside the `assert!` macro below — do not remove.
            let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

            let k = 3;

            // Right-hand sides for the left solves (`A x = R`).
            let ref R = CwiseMatDistribution {
                nrows: n,
                ncols: k,
                dist: ComplexDistribution::new(StandardNormal, StandardNormal),
            }
            .rand::<Mat<c64>>(rng);

            // Left-hand sides for the right solves (`x A = L`).
            let ref L = CwiseMatDistribution {
                nrows: k,
                ncols: n,
                dist: ComplexDistribution::new(StandardNormal, StandardNormal),
            }
            .rand::<Mat<c64>>(rng);

            assert!(A * A_dec.solve(R) ~ R);
            assert!(A.conjugate() * A_dec.solve_conjugate(R) ~ R);
            assert!(A.transpose() * A_dec.solve_transpose(R) ~ R);
            assert!(A.adjoint() * A_dec.solve_adjoint(R) ~ R);

            assert!(A_dec.rsolve(L) * A ~ L);
            assert!(A_dec.rsolve_conjugate(L) * A.conjugate() ~ L);
            assert!(A_dec.rsolve_transpose(L) * A.transpose() ~ L);
            assert!(A_dec.rsolve_adjoint(L) * A.adjoint() ~ L);
        }

        test_solver_imp(A, &A_dec)
    }

    // Runs the solver round-trip test over every dense decomposition, constructing
    // SPD / self-adjoint variants of the matrix where the factorization requires it.
    #[test]
    fn test_all_solvers() {
        let rng = &mut StdRng::seed_from_u64(0);
        let n = 50;

        let ref A = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);
        let A = A.rb();

        test_solver(A, A.partial_piv_lu());
        test_solver(A, A.full_piv_lu());
        test_solver(A, A.qr());
        test_solver(A, A.col_piv_qr());
        test_solver(A, A.svd().unwrap());

        {
            // `A Aᴴ` is positive (semi)definite, as required by LLT/LDLT.
            let ref A = A * A.adjoint();
            let A = A.rb();
            test_solver(A, A.llt(Side::Lower).unwrap());
            test_solver(A, A.ldlt(Side::Lower).unwrap());
        }

        {
            // `A + Aᴴ` is self-adjoint (possibly indefinite), as required by LBLT
            // and the self-adjoint eigendecomposition.
            let ref A = A + A.adjoint();
            let A = A.rb();
            test_solver(A, A.lblt(Side::Lower));
            test_solver(A, A.self_adjoint_eigen(Side::Lower).unwrap());
        }
    }

    #[test]
    fn test_eigen_cplx() {
        let rng = &mut StdRng::seed_from_u64(0);
        let n = 50;

        let A = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: ComplexDistribution::new(StandardNormal, StandardNormal),
        }
        .rand::<Mat<c64>>(rng);

        let n = A.nrows();
        // NOTE: `approx_eq` is referenced by name by the `~` comparisons below.
        let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

        let evd = A.eigen().unwrap();
        let e = A.eigenvalues().unwrap();
        // Defining property of the eigendecomposition: A U = U S.
        assert!(&A * evd.U() ~ evd.U() * evd.S());
        // The diagonal of S agrees with the standalone eigenvalue computation.
        assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
    }

    #[test]
    fn test_eigen_real() {
        let rng = &mut StdRng::seed_from_u64(0);
        let n = 50;

        let A = CwiseMatDistribution {
            nrows: n,
            ncols: n,
            dist: StandardNormal,
        }
        .rand::<Mat<f64>>(rng);

        let n = A.nrows();
        // NOTE: `approx_eq` is referenced by name by the `~` comparisons below.
        let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

        let evd = A.eigen().unwrap();
        let e = A.eigenvalues().unwrap();

        // Promote the real input to complex storage so it can be multiplied
        // against the complex eigenvectors.
        let A = Mat::from_fn(A.nrows(), A.ncols(), |i, j| c64::from(A[(i, j)]));

        assert!(&A * evd.U() ~ evd.U() * evd.S());
        assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
    }

    // Least-squares solve through the SVD for a tall system with a known exact
    // solution (B = A · X_true by construction).
    #[test]
    fn test_svd_solver_for_rectangular_matrix() {
        #[rustfmt::skip]
        let A = crate::mat![
            [4., 5., 7.],
            [8., 8., 2.],
            [4., 0., 9.],
            [2., 6., 2.],
            [0., 6., 0.],
        ];
        #[rustfmt::skip]
        let B = crate::mat![
            [105., 49.],
            [ 98., 54.],
            [113., 35.],
            [ 46., 34.],
            [ 12., 24.],
        ];

        #[rustfmt::skip]
        let X_true= crate::mat![
            [8., 2.],
            [2., 4.],
            [9., 3.],
        ];

        let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (A.nrows() as f64));
        let svd = A.svd().unwrap();
        let mut X = B.cloned();
        svd.solve_lstsq_in_place_with_conj(crate::Conj::No, X.as_mat_mut());
        // The in-place lstsq solve writes the solution to the top `ncols` rows.
        assert!(X.get(..X_true.nrows(),..) ~ X_true);
    }
}
2557}