// faer/linalg/solvers.rs

1use crate::internal_prelude::*;
2use crate::{assert, get_global_parallelism};
3use alloc::vec;
4use alloc::vec::Vec;
5use dyn_stack::MemBuffer;
6use faer_traits::{ComplexConj, math_utils};
7use linalg::svd::ComputeSvdVectors;
8
9pub use linalg::cholesky::ldlt::factor::LdltError;
10pub use linalg::cholesky::llt::factor::LltError;
11pub use linalg::evd::EvdError;
12pub use linalg::svd::SvdError;
13
/// dimensions of the matrix underlying a linear system solver
///
/// solvers report the shape of the matrix they were built from, so callers can
/// size right-hand sides and results accordingly
pub trait ShapeCore {
	/// number of rows of the underlying matrix
	fn nrows(&self) -> usize;

	/// number of columns of the underlying matrix
	fn ncols(&self) -> usize;
}
21
22/// linear system solver implementation
23pub trait SolveCore<T: ComplexField>: ShapeCore {
24	/// solves the equation `self × x = rhs`, implicitly conjugating `self` if needed, and stores
25	/// the result in `rhs`
26	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
27	/// solves the equation `self.transpose() × x = rhs`, implicitly conjugating `self` if needed,
28	/// and stores the result in `rhs`
29	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
30}
31/// least squares linear system solver implementation
32pub trait SolveLstsqCore<T: ComplexField>: ShapeCore {
33	/// solves the equation `self × x = rhs` in the sense of least squares, implicitly conjugating
34	/// `self` if needed, and stores the result in the top rows of `rhs`
35	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>);
36}
37/// dense linear system solver
38pub trait DenseSolveCore<T: ComplexField>: SolveCore<T> {
39	/// returns an approximation of the matrix that was used to create the decomposition
40	fn reconstruct(&self) -> Mat<T>;
41	/// returns an approximation of the inverse of the matrix that was used to create the
42	/// decomposition
43	fn inverse(&self) -> Mat<T>;
44}
45
46impl<S: ?Sized + ShapeCore> ShapeCore for &S {
47	#[inline]
48	fn nrows(&self) -> usize {
49		(**self).nrows()
50	}
51
52	#[inline]
53	fn ncols(&self) -> usize {
54		(**self).ncols()
55	}
56}
57
58impl<T: ComplexField, S: ?Sized + SolveCore<T>> SolveCore<T> for &S {
59	#[inline]
60	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
61		(**self).solve_in_place_with_conj(conj, rhs)
62	}
63
64	#[inline]
65	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
66		(**self).solve_transpose_in_place_with_conj(conj, rhs)
67	}
68}
69
70impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsqCore<T> for &S {
71	#[inline]
72	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
73		(**self).solve_lstsq_in_place_with_conj(conj, rhs)
74	}
75}
76
77impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolveCore<T> for &S {
78	#[inline]
79	fn reconstruct(&self) -> Mat<T> {
80		(**self).reconstruct()
81	}
82
83	#[inline]
84	fn inverse(&self) -> Mat<T> {
85		(**self).inverse()
86	}
87}
88
89/// [`SolveCore`] extension trait
90pub trait Solve<T: ComplexField>: SolveCore<T> {
91	#[track_caller]
92	#[inline]
93	/// solves $A x = b$
94	fn solve_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
95		self.solve_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
96	}
97	#[track_caller]
98	#[inline]
99	/// solves $\bar A x = b$
100	fn solve_conjugate_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
101		self.solve_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
102	}
103
104	#[track_caller]
105	#[inline]
106	/// solves $A^\top x = b$
107	fn solve_transpose_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
108		self.solve_transpose_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
109	}
110	#[track_caller]
111	#[inline]
112	/// solves $A^H x = b$
113	fn solve_adjoint_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
114		self.solve_transpose_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
115	}
116
117	#[track_caller]
118	#[inline]
119	/// solves $x A = b$
120	fn rsolve_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
121		self.solve_transpose_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
122	}
123	#[track_caller]
124	#[inline]
125	/// solves $x \bar A = b$
126	fn rsolve_conjugate_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
127		self.solve_transpose_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
128	}
129
130	#[track_caller]
131	#[inline]
132	/// solves $x A^\top = b$
133	fn rsolve_transpose_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
134		self.solve_in_place_with_conj(Conj::No, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
135	}
136	#[track_caller]
137	#[inline]
138	/// solves $x A^H = b$
139	fn rsolve_adjoint_in_place(&self, lhs: impl AsMatMut<T = T, Cols = usize>) {
140		self.solve_in_place_with_conj(Conj::Yes, { lhs }.as_mat_mut().as_dyn_rows_mut().transpose_mut());
141	}
142
143	#[track_caller]
144	#[inline]
145	/// solves $A x = b$
146	fn solve<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
147		let rhs = rhs.as_mat_ref();
148		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
149		out.as_mat_mut().copy_from(rhs);
150		self.solve_in_place(&mut out);
151		out
152	}
153	#[track_caller]
154	#[inline]
155	/// solves $\bar A x = b$
156	fn solve_conjugate<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
157		let rhs = rhs.as_mat_ref();
158		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
159		out.as_mat_mut().copy_from(rhs);
160		self.solve_conjugate_in_place(&mut out);
161		out
162	}
163
164	#[track_caller]
165	#[inline]
166	/// solves $A^\top x = b$
167	fn solve_transpose<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
168		let rhs = rhs.as_mat_ref();
169		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
170		out.as_mat_mut().copy_from(rhs);
171		self.solve_transpose_in_place(&mut out);
172		out
173	}
174	#[track_caller]
175	#[inline]
176	/// solves $A^H x = b$
177	fn solve_adjoint<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
178		let rhs = rhs.as_mat_ref();
179		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
180		out.as_mat_mut().copy_from(rhs);
181		self.solve_adjoint_in_place(&mut out);
182		out
183	}
184
185	#[track_caller]
186	#[inline]
187	/// solves $x A = b$
188	fn rsolve<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
189		let lhs = lhs.as_mat_ref();
190		let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
191		out.as_mat_mut().copy_from(lhs);
192		self.rsolve_in_place(&mut out);
193		out
194	}
195	#[track_caller]
196	#[inline]
197	/// solves $x \bar A = b$
198	fn rsolve_conjugate<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
199		let lhs = lhs.as_mat_ref();
200		let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
201		out.as_mat_mut().copy_from(lhs);
202		self.rsolve_conjugate_in_place(&mut out);
203		out
204	}
205
206	#[track_caller]
207	#[inline]
208	/// solves $x A^\top = b$
209	fn rsolve_transpose<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
210		let lhs = lhs.as_mat_ref();
211		let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
212		out.as_mat_mut().copy_from(lhs);
213		self.rsolve_transpose_in_place(&mut out);
214		out
215	}
216	#[track_caller]
217	#[inline]
218	/// solves $x A^H = b$
219	fn rsolve_adjoint<Lhs: AsMatRef<T = T, Cols = usize>>(&self, lhs: Lhs) -> Lhs::Owned {
220		let lhs = lhs.as_mat_ref();
221		let mut out = Lhs::Owned::zeros(lhs.nrows(), lhs.ncols());
222		out.as_mat_mut().copy_from(lhs);
223		self.rsolve_adjoint_in_place(&mut out);
224		out
225	}
226}
227
228impl<C: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, C>>> mat::generic::Mat<Inner> {
229	#[track_caller]
230	/// returns the $LU$ decomposition of `self` with partial (row) pivoting
231	pub fn partial_piv_lu(&self) -> PartialPivLu<C::Canonical> {
232		PartialPivLu::new(self.rb())
233	}
234
235	#[track_caller]
236	/// returns the $LU$ decomposition of `self` with full pivoting
237	pub fn full_piv_lu(&self) -> FullPivLu<C::Canonical> {
238		FullPivLu::new(self.rb())
239	}
240
241	#[track_caller]
242	/// returns the $QR$ decomposition of `self`
243	pub fn qr(&self) -> Qr<C::Canonical> {
244		Qr::new(self.rb())
245	}
246
247	#[track_caller]
248	/// returns the $QR$ decomposition of `self` with column pivoting
249	pub fn col_piv_qr(&self) -> ColPivQr<C::Canonical> {
250		ColPivQr::new(self.rb())
251	}
252
253	#[track_caller]
254	/// returns the svd of `self`
255	///
256	/// singular values are nonnegative and sorted in nonincreasing order
257	pub fn svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
258		Svd::new(self.rb())
259	}
260
261	#[track_caller]
262	/// returns the thin svd of `self`
263	///
264	/// singular values are nonnegative and sorted in nonincreasing order
265	pub fn thin_svd(&self) -> Result<Svd<C::Canonical>, SvdError> {
266		Svd::new_thin(self.rb())
267	}
268
269	#[track_caller]
270	/// returns the $L L^\top$ decomposition of `self`
271	pub fn llt(&self, side: Side) -> Result<Llt<C::Canonical>, LltError> {
272		Llt::new(self.rb(), side)
273	}
274
275	#[track_caller]
276	/// returns the $L D L^\top$ decomposition of `self`
277	pub fn ldlt(&self, side: Side) -> Result<Ldlt<C::Canonical>, LdltError> {
278		Ldlt::new(self.rb(), side)
279	}
280
281	#[track_caller]
282	/// returns the $LBL^\top$ decomposition of `self`
283	pub fn lblt(&self, side: Side) -> Lblt<C::Canonical> {
284		Lblt::new(self.rb(), side)
285	}
286
287	#[track_caller]
288	/// returns the eigendecomposition of `self`, assuming it is self-adjoint
289	///
290	/// eigenvalues sorted in nondecreasing order
291	pub fn self_adjoint_eigen(&self, side: Side) -> Result<SelfAdjointEigen<C::Canonical>, EvdError> {
292		SelfAdjointEigen::new(self.rb(), side)
293	}
294
295	#[track_caller]
296	/// returns the eigenvalues of `self`, assuming it is self-adjoint
297	///
298	/// eigenvalues sorted in nondecreasing order
299	pub fn self_adjoint_eigenvalues(&self, side: Side) -> Result<Vec<Real<C>>, EvdError> {
300		#[track_caller]
301		pub fn imp<T: ComplexField>(mut A: MatRef<'_, T>, side: Side) -> Result<Vec<T::Real>, EvdError> {
302			assert!(A.nrows() == A.ncols());
303			if side == Side::Upper {
304				A = A.transpose();
305			}
306			let par = get_global_parallelism();
307			let n = A.nrows();
308
309			let mut s = Diag::<T>::zeros(n);
310
311			linalg::evd::self_adjoint_evd(
312				A,
313				s.as_mut(),
314				None,
315				par,
316				MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
317					n,
318					linalg::evd::ComputeEigenvectors::No,
319					par,
320					default(),
321				))),
322				default(),
323			)?;
324
325			Ok(s.column_vector().iter().map(|x| real(x)).collect())
326		}
327
328		imp(self.rb().canonical(), side)
329	}
330
331	#[track_caller]
332	/// returns the singular values of `self`
333	///
334	/// singular values are nonnegative and sorted in nonincreasing order
335	pub fn singular_values(&self) -> Result<Vec<Real<C>>, SvdError> {
336		pub fn imp<T: ComplexField>(A: MatRef<'_, T>) -> Result<Vec<T::Real>, SvdError> {
337			let par = get_global_parallelism();
338			let m = A.nrows();
339			let n = A.ncols();
340
341			let mut s = Diag::<T>::zeros(Ord::min(m, n));
342
343			linalg::svd::svd(
344				A,
345				s.as_mut(),
346				None,
347				None,
348				par,
349				MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(
350					m,
351					n,
352					linalg::svd::ComputeSvdVectors::No,
353					linalg::svd::ComputeSvdVectors::No,
354					par,
355					default(),
356				))),
357				default(),
358			)?;
359
360			Ok(s.column_vector().iter().map(|x| real(x)).collect())
361		}
362
363		imp(self.rb().canonical())
364	}
365}
366
367impl<C: Conjugate> MatRef<'_, C> {
368	#[track_caller]
369	fn eigen_imp(&self) -> Result<Eigen<Real<C>>, EvdError> {
370		if const { C::Canonical::IS_REAL } {
371			Eigen::new_from_real(unsafe { crate::hacks::coerce(*self) })
372		} else if const { C::IS_CANONICAL } {
373			Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(*self) })
374		} else {
375			Eigen::new(unsafe { crate::hacks::coerce::<_, MatRef<'_, ComplexConj<Real<C>>>>(*self) })
376		}
377	}
378
379	#[track_caller]
380	fn eigenvalues_imp(&self) -> Result<Vec<Complex<Real<C>>>, EvdError> {
381		let par = get_global_parallelism();
382
383		if const { C::Canonical::IS_REAL } {
384			let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Real<C>>>(*self) };
385			assert!(A.nrows() == A.ncols());
386			let n = A.nrows();
387
388			let mut s_re = Diag::<Real<C>>::zeros(n);
389			let mut s_im = Diag::<Real<C>>::zeros(n);
390
391			linalg::evd::evd_real(
392				A,
393				s_re.as_mut(),
394				s_im.as_mut(),
395				None,
396				None,
397				par,
398				MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Real<C>>(
399					n,
400					linalg::evd::ComputeEigenvectors::No,
401					linalg::evd::ComputeEigenvectors::No,
402					par,
403					default(),
404				))),
405				default(),
406			)?;
407
408			Ok(s_re
409				.column_vector()
410				.iter()
411				.zip(s_im.column_vector().iter())
412				.map(|(re, im)| Complex::new(re.clone(), im.clone()))
413				.collect())
414		} else {
415			let A = unsafe { crate::hacks::coerce::<_, MatRef<'_, Complex<Real<C>>>>(self.canonical()) };
416			assert!(A.nrows() == A.ncols());
417			let n = A.nrows();
418
419			let mut s = Diag::<Complex<Real<C>>>::zeros(n);
420
421			linalg::evd::evd_cplx(
422				A,
423				s.as_mut(),
424				None,
425				None,
426				par,
427				MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<Real<C>>>(
428					n,
429					linalg::evd::ComputeEigenvectors::No,
430					linalg::evd::ComputeEigenvectors::No,
431					par,
432					default(),
433				))),
434				default(),
435			)?;
436
437			if const { C::IS_CANONICAL } {
438				Ok(s.column_vector().iter().cloned().collect())
439			} else {
440				Ok(s.column_vector().iter().map(conj).collect())
441			}
442		}
443	}
444}
445
446impl<T: Conjugate, Inner: for<'short> Reborrow<'short, Target = mat::Ref<'short, T>>> mat::generic::Mat<Inner> {
447	/// returns the eigendecomposition of `self`
448	#[track_caller]
449	pub fn eigen(&self) -> Result<Eigen<Real<T>>, EvdError> {
450		self.rb().eigen_imp()
451	}
452
453	/// returns the eigenvalues of `self`
454	#[track_caller]
455	pub fn eigenvalues(&self) -> Result<Vec<Complex<Real<T>>>, EvdError> {
456		self.rb().eigenvalues_imp()
457	}
458}
459
460/// [`SolveLstsqCore`] extension trait
461pub trait SolveLstsq<T: ComplexField>: SolveLstsqCore<T> {
462	#[track_caller]
463	#[inline]
464	/// solves $A x = b$ in the sense of least squares.
465	fn solve_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
466		self.solve_lstsq_in_place_with_conj(Conj::No, { rhs }.as_mat_mut().as_dyn_cols_mut());
467	}
468
469	#[track_caller]
470	#[inline]
471	/// solves $\bar A x = b$ in the sense of least squares.
472	fn solve_conjugate_lstsq_in_place(&self, rhs: impl AsMatMut<T = T, Rows = usize>) {
473		self.solve_lstsq_in_place_with_conj(Conj::Yes, { rhs }.as_mat_mut().as_dyn_cols_mut());
474	}
475
476	#[track_caller]
477	#[inline]
478	/// solves $A x = b$ in the sense of least squares.
479	fn solve_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
480		let rhs = rhs.as_mat_ref();
481		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
482		out.as_mat_mut().copy_from(rhs);
483		self.solve_lstsq_in_place(&mut out);
484		out.truncate(self.ncols(), rhs.ncols());
485		out
486	}
487	#[track_caller]
488	#[inline]
489	/// solves $\bar A x = b$ in the sense of least squares.
490	fn solve_conjugate_lstsq<Rhs: AsMatRef<T = T, Rows = usize>>(&self, rhs: Rhs) -> Rhs::Owned {
491		let rhs = rhs.as_mat_ref();
492		let mut out = Rhs::Owned::zeros(rhs.nrows(), rhs.ncols());
493		out.as_mat_mut().copy_from(rhs);
494		self.solve_conjugate_lstsq_in_place(&mut out);
495		out.truncate(self.ncols(), rhs.ncols());
496		out
497	}
498}
499/// [`DenseSolveCore`] extension trait
500pub trait DenseSolve<T: ComplexField>: DenseSolveCore<T> {}
501
502impl<T: ComplexField, S: ?Sized + SolveCore<T>> Solve<T> for S {}
503impl<T: ComplexField, S: ?Sized + SolveLstsqCore<T>> SolveLstsq<T> for S {}
504impl<T: ComplexField, S: ?Sized + DenseSolveCore<T>> DenseSolve<T> for S {}
505
506/// $L L^\top$ decomposition
507#[derive(Clone, Debug)]
508pub struct Llt<T> {
509	L: Mat<T>,
510}
511
512/// $L D L^\top$ decomposition
513#[derive(Clone, Debug)]
514pub struct Ldlt<T> {
515	L: Mat<T>,
516	D: Diag<T>,
517}
518
519/// $LBL^\top$ decomposition
520#[derive(Clone, Debug)]
521pub struct Lblt<T> {
522	L: Mat<T>,
523	B_diag: Diag<T>,
524	B_subdiag: Diag<T>,
525	P: Perm<usize>,
526}
527
528/// $LU$ decomposition with partial (row) pivoting
529#[derive(Clone, Debug)]
530pub struct PartialPivLu<T> {
531	L: Mat<T>,
532	U: Mat<T>,
533	P: Perm<usize>,
534}
535
536/// $LU$ decomposition with full pivoting
537#[derive(Clone, Debug)]
538pub struct FullPivLu<T> {
539	L: Mat<T>,
540	U: Mat<T>,
541	P: Perm<usize>,
542	Q: Perm<usize>,
543}
544
545/// $QR$ decomposition
546#[derive(Clone, Debug)]
547pub struct Qr<T> {
548	Q_basis: Mat<T>,
549	Q_coeff: Mat<T>,
550	R: Mat<T>,
551}
552
553/// $QR$ decomposition with column pivoting
554#[derive(Clone, Debug)]
555pub struct ColPivQr<T> {
556	Q_basis: Mat<T>,
557	Q_coeff: Mat<T>,
558	R: Mat<T>,
559	P: Perm<usize>,
560}
561
562/// svd decomposition (either full or thin)
563#[derive(Clone, Debug)]
564pub struct Svd<T> {
565	U: Mat<T>,
566	V: Mat<T>,
567	S: Diag<T>,
568}
569
570/// self-adjoint eigendecomposition
571#[derive(Clone, Debug)]
572pub struct SelfAdjointEigen<T> {
573	U: Mat<T>,
574	S: Diag<T>,
575}
576
577/// eigendecomposition
578#[derive(Clone, Debug)]
579pub struct Eigen<T> {
580	U: Mat<Complex<T>>,
581	S: Diag<Complex<T>>,
582}
583
584impl<T: ComplexField> Llt<T> {
585	/// returns the $L L^\top$ decomposition of $A$
586	#[track_caller]
587	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LltError> {
588		assert!(all(A.nrows() == A.ncols()));
589		let n = A.nrows();
590
591		let mut L = Mat::zeros(n, n);
592		match side {
593			Side::Lower => L.copy_from_triangular_lower(A),
594			Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
595		}
596
597		Self::new_imp(L)
598	}
599
600	#[track_caller]
601	fn new_imp(mut L: Mat<T>) -> Result<Self, LltError> {
602		let par = get_global_parallelism();
603
604		let n = L.nrows();
605
606		let mut mem = MemBuffer::new(linalg::cholesky::llt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
607		let stack = MemStack::new(&mut mem);
608
609		linalg::cholesky::llt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;
610		z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
611
612		Ok(Self { L })
613	}
614
615	/// returns the $L$ factor
616	pub fn L(&self) -> MatRef<'_, T> {
617		self.L.as_ref()
618	}
619}
620
621impl<T: ComplexField> Ldlt<T> {
622	/// returns the $L D L^\top$ decomposition of $A$
623	#[track_caller]
624	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, LdltError> {
625		assert!(all(A.nrows() == A.ncols()));
626		let n = A.nrows();
627
628		let mut L = Mat::zeros(n, n);
629		match side {
630			Side::Lower => L.copy_from_triangular_lower(A),
631			Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
632		}
633
634		Self::new_imp(L)
635	}
636
637	#[track_caller]
638	fn new_imp(mut L: Mat<T>) -> Result<Self, LdltError> {
639		let par = get_global_parallelism();
640
641		let n = L.nrows();
642		let mut D = Diag::zeros(n);
643
644		let mut mem = MemBuffer::new(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<T>(n, par, default()));
645		let stack = MemStack::new(&mut mem);
646
647		linalg::cholesky::ldlt::factor::cholesky_in_place(L.as_mut(), Default::default(), par, stack, default())?;
648
649		D.copy_from(L.diagonal());
650		L.diagonal_mut().fill(one());
651		z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
652
653		Ok(Self { L, D })
654	}
655
656	/// returns the $L$ factor
657	pub fn L(&self) -> MatRef<'_, T> {
658		self.L.as_ref()
659	}
660
661	/// returns the $D$ factor
662	pub fn D(&self) -> DiagRef<'_, T> {
663		self.D.as_ref()
664	}
665}
666
667impl<T: ComplexField> Lblt<T> {
668	/// returns the $LBL^\top$ decomposition of $A$
669	#[track_caller]
670	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Self {
671		assert!(all(A.nrows() == A.ncols()));
672		let n = A.nrows();
673
674		let mut L = Mat::zeros(n, n);
675		match side {
676			Side::Lower => L.copy_from_triangular_lower(A),
677			Side::Upper => L.copy_from_triangular_lower(A.adjoint()),
678		}
679		Self::new_imp(L)
680	}
681
682	#[track_caller]
683	fn new_imp(mut L: Mat<T>) -> Self {
684		let par = get_global_parallelism();
685
686		let n = L.nrows();
687
688		let mut diag = Diag::zeros(n);
689		let mut subdiag = Diag::zeros(n);
690		let mut perm_fwd = vec![0usize; n];
691		let mut perm_bwd = vec![0usize; n];
692
693		let mut mem = MemBuffer::new(linalg::cholesky::lblt::factor::cholesky_in_place_scratch::<usize, T>(n, par, default()));
694		let stack = MemStack::new(&mut mem);
695
696		linalg::cholesky::lblt::factor::cholesky_in_place(L.as_mut(), subdiag.as_mut(), &mut perm_fwd, &mut perm_bwd, par, stack, default());
697
698		diag.copy_from(L.diagonal());
699		L.diagonal_mut().fill(one());
700		z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
701
702		Self {
703			L,
704			B_diag: diag,
705			B_subdiag: subdiag,
706			P: unsafe { Perm::new_unchecked(perm_fwd.into_boxed_slice(), perm_bwd.into_boxed_slice()) },
707		}
708	}
709
710	/// returns the $L$ factor
711	pub fn L(&self) -> MatRef<'_, T> {
712		self.L.as_ref()
713	}
714
715	/// returns the diagonal of the $B$ factor
716	pub fn B_diag(&self) -> DiagRef<'_, T> {
717		self.B_diag.as_ref()
718	}
719
720	/// returns the subdiagonal of the $B$ factor
721	pub fn B_subdiag(&self) -> DiagRef<'_, T> {
722		self.B_subdiag.as_ref()
723	}
724
725	/// returns the pivoting permutation $P$
726	pub fn P(&self) -> PermRef<'_, usize> {
727		self.P.as_ref()
728	}
729}
730
731fn split_LU<T: ComplexField>(LU: Mat<T>) -> (Mat<T>, Mat<T>) {
732	let (m, n) = LU.shape();
733	let size = Ord::min(m, n);
734
735	let (L, U) = if m >= n {
736		let mut L = LU;
737		let mut U = Mat::zeros(size, size);
738
739		U.copy_from_triangular_upper(L.get(..size, ..size));
740
741		z!(&mut L).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
742		L.diagonal_mut().fill(one());
743
744		(L, U)
745	} else {
746		let mut U = LU;
747		let mut L = Mat::zeros(size, size);
748
749		L.copy_from_strict_triangular_lower(U.get(..size, ..size));
750
751		z!(&mut U).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(x)| *x = zero());
752		L.diagonal_mut().fill(one());
753
754		(L, U)
755	};
756	(L, U)
757}
758
759impl<T: ComplexField> PartialPivLu<T> {
760	/// returns the $LU$ decomposition of $A$ with partial pivoting
761	#[track_caller]
762	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
763		let LU = A.to_owned();
764		Self::new_imp(LU)
765	}
766
767	#[track_caller]
768	fn new_imp(mut LU: Mat<T>) -> Self {
769		let par = get_global_parallelism();
770
771		let (m, n) = LU.shape();
772		let mut row_perm_fwd = vec![0usize; m];
773		let mut row_perm_bwd = vec![0usize; m];
774
775		linalg::lu::partial_pivoting::factor::lu_in_place(
776			LU.as_mut(),
777			&mut row_perm_fwd,
778			&mut row_perm_bwd,
779			par,
780			MemStack::new(&mut MemBuffer::new(
781				linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(m, n, par, default()),
782			)),
783			default(),
784		);
785
786		let (L, U) = split_LU(LU);
787
788		Self {
789			L,
790			U,
791			P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
792		}
793	}
794
795	/// returns the $L$ factor
796	pub fn L(&self) -> MatRef<'_, T> {
797		self.L.as_ref()
798	}
799
800	/// returns the $U$ factor
801	pub fn U(&self) -> MatRef<'_, T> {
802		self.U.as_ref()
803	}
804
805	/// returns the row pivoting permutation $P$
806	pub fn P(&self) -> PermRef<'_, usize> {
807		self.P.as_ref()
808	}
809}
810
811impl<T: ComplexField> FullPivLu<T> {
812	/// returns the $LU$ decomposition of $A$ with full pivoting
813	#[track_caller]
814	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
815		let LU = A.to_owned();
816		Self::new_imp(LU)
817	}
818
819	#[track_caller]
820	fn new_imp(mut LU: Mat<T>) -> Self {
821		let par = get_global_parallelism();
822
823		let (m, n) = LU.shape();
824		let mut row_perm_fwd = vec![0usize; m];
825		let mut row_perm_bwd = vec![0usize; m];
826		let mut col_perm_fwd = vec![0usize; n];
827		let mut col_perm_bwd = vec![0usize; n];
828
829		linalg::lu::full_pivoting::factor::lu_in_place(
830			LU.as_mut(),
831			&mut row_perm_fwd,
832			&mut row_perm_bwd,
833			&mut col_perm_fwd,
834			&mut col_perm_bwd,
835			par,
836			MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::factor::lu_in_place_scratch::<usize, T>(
837				m,
838				n,
839				par,
840				default(),
841			))),
842			default(),
843		);
844
845		let (L, U) = split_LU(LU);
846
847		Self {
848			L,
849			U,
850			P: unsafe { Perm::new_unchecked(row_perm_fwd.into_boxed_slice(), row_perm_bwd.into_boxed_slice()) },
851			Q: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
852		}
853	}
854
855	/// returns the factor $L$
856	pub fn L(&self) -> MatRef<'_, T> {
857		self.L.as_ref()
858	}
859
860	/// returns the factor $U$
861	pub fn U(&self) -> MatRef<'_, T> {
862		self.U.as_ref()
863	}
864
865	/// returns the row pivoting permutation $P$
866	pub fn P(&self) -> PermRef<'_, usize> {
867		self.P.as_ref()
868	}
869
870	/// returns the column pivoting permutation $P$
871	pub fn Q(&self) -> PermRef<'_, usize> {
872		self.Q.as_ref()
873	}
874}
875
876impl<T: ComplexField> Qr<T> {
877	/// returns the $QR$ decomposition of $A$
878	#[track_caller]
879	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
880		let QR = A.to_owned();
881		Self::new_imp(QR)
882	}
883
884	#[track_caller]
885	fn new_imp(mut QR: Mat<T>) -> Self {
886		let par = get_global_parallelism();
887
888		let (m, n) = QR.shape();
889		let size = Ord::min(m, n);
890
891		let blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
892		let mut Q_coeff = Mat::zeros(blocksize, size);
893
894		linalg::qr::no_pivoting::factor::qr_in_place(
895			QR.as_mut(),
896			Q_coeff.as_mut(),
897			par,
898			MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::factor::qr_in_place_scratch::<T>(
899				m,
900				n,
901				blocksize,
902				par,
903				default(),
904			))),
905			default(),
906		);
907
908		let (Q_basis, R) = split_LU(QR);
909
910		Self { Q_basis, Q_coeff, R }
911	}
912
913	/// returns the householder basis of $Q$
914	pub fn Q_basis(&self) -> MatRef<'_, T> {
915		self.Q_basis.as_ref()
916	}
917
918	/// returns the householder coefficients of $Q$
919	pub fn Q_coeff(&self) -> MatRef<'_, T> {
920		self.Q_coeff.as_ref()
921	}
922
923	/// returns the factor $R$
924	pub fn R(&self) -> MatRef<'_, T> {
925		self.R.as_ref()
926	}
927
928	/// returns the upper trapezoidal part of $R$
929	pub fn thin_R(&self) -> MatRef<'_, T> {
930		let size = Ord::min(self.nrows(), self.ncols());
931		self.R.get(..size, ..)
932	}
933
934	/// computes the factor $Q$
935	pub fn compute_Q(&self) -> Mat<T> {
936		let mut Q = Mat::identity(self.nrows(), self.nrows());
937		let par = get_global_parallelism();
938		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
939			self.Q_basis(),
940			self.Q_coeff(),
941			Conj::No,
942			Q.rb_mut(),
943			par,
944			MemStack::new(&mut MemBuffer::new(
945				linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
946					self.nrows(),
947					self.Q_coeff.nrows(),
948					self.nrows(),
949				),
950			)),
951		);
952		Q
953	}
954
955	/// computes the first $\min(\text{nrows}, \text{ncols})$ columns of the factor $Q$
956	pub fn compute_thin_Q(&self) -> Mat<T> {
957		let size = Ord::min(self.nrows(), self.ncols());
958		let mut Q = Mat::identity(self.nrows(), size);
959		let par = get_global_parallelism();
960		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
961			self.Q_basis(),
962			self.Q_coeff(),
963			Conj::No,
964			Q.rb_mut(),
965			par,
966			MemStack::new(&mut MemBuffer::new(
967				linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
968			)),
969		);
970		Q
971	}
972}
973
974impl<T: ComplexField> ColPivQr<T> {
975	/// returns the $QR$ decomposition of $A$ with column pivoting
976	#[track_caller]
977	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Self {
978		let QR = A.to_owned();
979		Self::new_imp(QR)
980	}
981
982	#[track_caller]
983	fn new_imp(mut QR: Mat<T>) -> Self {
984		let par = get_global_parallelism();
985
986		let (m, n) = QR.shape();
987		let size = Ord::min(m, n);
988
989		let mut col_perm_fwd = vec![0usize; n];
990		let mut col_perm_bwd = vec![0usize; n];
991
992		let blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<T>(m, n);
993		let mut Q_coeff = Mat::zeros(blocksize, size);
994
995		linalg::qr::col_pivoting::factor::qr_in_place(
996			QR.as_mut(),
997			Q_coeff.as_mut(),
998			&mut col_perm_fwd,
999			&mut col_perm_bwd,
1000			par,
1001			MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::factor::qr_in_place_scratch::<usize, T>(
1002				m,
1003				n,
1004				blocksize,
1005				par,
1006				default(),
1007			))),
1008			default(),
1009		);
1010
1011		let (Q_basis, R) = split_LU(QR);
1012
1013		Self {
1014			Q_basis,
1015			Q_coeff,
1016			R,
1017			P: unsafe { Perm::new_unchecked(col_perm_fwd.into_boxed_slice(), col_perm_bwd.into_boxed_slice()) },
1018		}
1019	}
1020
1021	/// returns the householder basis of $Q$
1022	pub fn Q_basis(&self) -> MatRef<'_, T> {
1023		self.Q_basis.as_ref()
1024	}
1025
1026	/// returns the householder coefficients of $Q$
1027	pub fn Q_coeff(&self) -> MatRef<'_, T> {
1028		self.Q_coeff.as_ref()
1029	}
1030
1031	/// returns the factor $R$
1032	pub fn R(&self) -> MatRef<'_, T> {
1033		self.R.as_ref()
1034	}
1035
1036	/// returns the upper trapezoidal part of $R$
1037	pub fn thin_R(&self) -> MatRef<'_, T> {
1038		let size = Ord::min(self.nrows(), self.ncols());
1039		self.R.get(..size, ..)
1040	}
1041
1042	/// computes the factor $Q$
1043	pub fn compute_Q(&self) -> Mat<T> {
1044		let mut Q = Mat::identity(self.nrows(), self.nrows());
1045		let par = get_global_parallelism();
1046		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
1047			self.Q_basis(),
1048			self.Q_coeff(),
1049			Conj::No,
1050			Q.rb_mut(),
1051			par,
1052			MemStack::new(&mut MemBuffer::new(
1053				linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(
1054					self.nrows(),
1055					self.Q_coeff.nrows(),
1056					self.nrows(),
1057				),
1058			)),
1059		);
1060		Q
1061	}
1062
1063	/// computes the first $\min(\text{nrows}, \text{ncols})$ columns of the factor $Q$
1064	pub fn compute_thin_Q(&self) -> Mat<T> {
1065		let size = Ord::min(self.nrows(), self.ncols());
1066		let mut Q = Mat::identity(self.nrows(), size);
1067		let par = get_global_parallelism();
1068		linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
1069			self.Q_basis(),
1070			self.Q_coeff(),
1071			Conj::No,
1072			Q.rb_mut(),
1073			par,
1074			MemStack::new(&mut MemBuffer::new(
1075				linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(self.nrows(), self.Q_coeff.nrows(), size),
1076			)),
1077		);
1078		Q
1079	}
1080
	/// returns the column pivoting permutation $P$
	///
	/// borrowed view of the stored permutation (built from the forward and backward index arrays
	/// produced during factorization)
	pub fn P(&self) -> PermRef<'_, usize> {
		self.P.as_ref()
	}
1085}
1086
1087impl<T: ComplexField> Svd<T> {
1088	/// returns the svd of $A$
1089	#[track_caller]
1090	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
1091		Self::new_imp(A.canonical(), Conj::get::<C>(), false)
1092	}
1093
1094	/// returns the thin svd of $A$
1095	#[track_caller]
1096	pub fn new_thin<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>) -> Result<Self, SvdError> {
1097		Self::new_imp(A.canonical(), Conj::get::<C>(), true)
1098	}
1099
1100	#[track_caller]
1101	fn new_imp(A: MatRef<'_, T>, conj: Conj, thin: bool) -> Result<Self, SvdError> {
1102		let par = get_global_parallelism();
1103
1104		let (m, n) = A.shape();
1105		let size = Ord::min(m, n);
1106
1107		let mut U = Mat::zeros(m, if thin { size } else { m });
1108		let mut V = Mat::zeros(n, if thin { size } else { n });
1109		let mut S = Diag::zeros(size);
1110
1111		let compute = if thin { ComputeSvdVectors::Thin } else { ComputeSvdVectors::Full };
1112
1113		linalg::svd::svd(
1114			A,
1115			S.as_mut(),
1116			Some(U.as_mut()),
1117			Some(V.as_mut()),
1118			par,
1119			MemStack::new(&mut MemBuffer::new(linalg::svd::svd_scratch::<T>(m, n, compute, compute, par, default()))),
1120			default(),
1121		)?;
1122
1123		if conj == Conj::Yes {
1124			for c in U.col_iter_mut() {
1125				for x in c.iter_mut() {
1126					*x = math_utils::conj(x);
1127				}
1128			}
1129			for c in V.col_iter_mut() {
1130				for x in c.iter_mut() {
1131					*x = math_utils::conj(x);
1132				}
1133			}
1134		}
1135
1136		Ok(Self { U, V, S })
1137	}
1138
1139	/// returns the factor $U$
1140	pub fn U(&self) -> MatRef<'_, T> {
1141		self.U.as_ref()
1142	}
1143
1144	/// returns the factor $V$
1145	pub fn V(&self) -> MatRef<'_, T> {
1146		self.V.as_ref()
1147	}
1148
1149	/// returns the factor $S$
1150	pub fn S(&self) -> DiagRef<'_, T> {
1151		self.S.as_ref()
1152	}
1153
1154	/// returns the pseudoinverse of the original matrix $A$.
1155	pub fn pseudoinverse(&self) -> Mat<T> {
1156		let U = self.U();
1157		let V = self.V();
1158		let S = self.S();
1159		let par = get_global_parallelism();
1160		let stack = &mut MemBuffer::new(linalg::svd::pseudoinverse_from_svd_scratch::<T>(self.nrows(), self.ncols(), par));
1161		let mut pinv = Mat::zeros(self.nrows(), self.ncols());
1162		linalg::svd::pseudoinverse_from_svd(pinv.rb_mut(), S, U, V, par, MemStack::new(stack));
1163		pinv
1164	}
1165}
1166
1167impl<T: ComplexField> SelfAdjointEigen<T> {
1168	/// returns the eigendecomposition of $A$, assuming it is self-adjoint
1169	#[track_caller]
1170	pub fn new<C: Conjugate<Canonical = T>>(A: MatRef<'_, C>, side: Side) -> Result<Self, EvdError> {
1171		assert!(A.nrows() == A.ncols());
1172
1173		match side {
1174			Side::Lower => Self::new_imp(A.canonical(), Conj::get::<C>()),
1175			Side::Upper => Self::new_imp(A.adjoint().canonical(), Conj::get::<C::Conj>()),
1176		}
1177	}
1178
1179	#[track_caller]
1180	fn new_imp(A: MatRef<'_, T>, conj: Conj) -> Result<Self, EvdError> {
1181		let par = get_global_parallelism();
1182
1183		let n = A.nrows();
1184
1185		let mut U = Mat::zeros(n, n);
1186		let mut S = Diag::zeros(n);
1187
1188		linalg::evd::self_adjoint_evd(
1189			A,
1190			S.as_mut(),
1191			Some(U.as_mut()),
1192			par,
1193			MemStack::new(&mut MemBuffer::new(linalg::evd::self_adjoint_evd_scratch::<T>(
1194				n,
1195				linalg::evd::ComputeEigenvectors::Yes,
1196				par,
1197				default(),
1198			))),
1199			default(),
1200		)?;
1201
1202		if conj == Conj::Yes {
1203			for c in U.col_iter_mut() {
1204				for x in c.iter_mut() {
1205					*x = math_utils::conj(x);
1206				}
1207			}
1208		}
1209
1210		Ok(Self { U, S })
1211	}
1212
1213	/// returns the factor $U$
1214	pub fn U(&self) -> MatRef<'_, T> {
1215		self.U.as_ref()
1216	}
1217
1218	/// returns the factor $S$
1219	pub fn S(&self) -> DiagRef<'_, T> {
1220		self.S.as_ref()
1221	}
1222}
1223
1224impl<T: RealField> Eigen<T> {
1225	/// returns the eigendecomposition of $A$
1226	#[track_caller]
1227	pub fn new<C: Conjugate<Canonical = Complex<T>>>(A: MatRef<'_, C>) -> Result<Self, EvdError> {
1228		assert!(A.nrows() == A.ncols());
1229		Self::new_imp(A.canonical(), Conj::get::<C>())
1230	}
1231
1232	/// returns the eigendecomposition of $A$
1233	#[track_caller]
1234	pub fn new_from_real(A: MatRef<'_, T>) -> Result<Self, EvdError> {
1235		assert!(A.nrows() == A.ncols());
1236
1237		let par = get_global_parallelism();
1238
1239		let n = A.nrows();
1240
1241		let mut U_real = Mat::zeros(n, n);
1242		let mut S_re = Diag::zeros(n);
1243		let mut S_im = Diag::zeros(n);
1244
1245		linalg::evd::evd_real(
1246			A,
1247			S_re.as_mut(),
1248			S_im.as_mut(),
1249			None,
1250			Some(U_real.as_mut()),
1251			par,
1252			MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<T>(
1253				n,
1254				linalg::evd::ComputeEigenvectors::No,
1255				linalg::evd::ComputeEigenvectors::Yes,
1256				par,
1257				default(),
1258			))),
1259			default(),
1260		)?;
1261
1262		let mut U = Mat::zeros(n, n);
1263		let mut S = Diag::zeros(n);
1264
1265		let mut j = 0;
1266		while j < n {
1267			if S_im[j] == zero() {
1268				S[j] = Complex::new(S_re[j].clone(), zero());
1269
1270				for i in 0..n {
1271					U[(i, j)] = Complex::new(U_real[(i, j)].clone(), zero());
1272				}
1273
1274				j += 1;
1275			} else {
1276				S[j] = Complex::new(S_re[j].clone(), S_im[j].clone());
1277				S[j + 1] = Complex::new(S_re[j].clone(), neg(&S_im[j]));
1278
1279				for i in 0..n {
1280					U[(i, j)] = Complex::new(U_real[(i, j)].clone(), U_real[(i, j + 1)].clone());
1281					U[(i, j + 1)] = Complex::new(U_real[(i, j)].clone(), neg(&U_real[(i, j + 1)]));
1282				}
1283
1284				j += 2;
1285			}
1286		}
1287
1288		Ok(Self { U, S })
1289	}
1290
1291	fn new_imp(A: MatRef<'_, Complex<T>>, conj: Conj) -> Result<Self, EvdError> {
1292		let par = get_global_parallelism();
1293
1294		let n = A.nrows();
1295
1296		let mut U = Mat::zeros(n, n);
1297		let mut S = Diag::zeros(n);
1298
1299		linalg::evd::evd_cplx(
1300			A,
1301			S.as_mut(),
1302			None,
1303			Some(U.as_mut()),
1304			par,
1305			MemStack::new(&mut MemBuffer::new(linalg::evd::evd_scratch::<Complex<T>>(
1306				n,
1307				linalg::evd::ComputeEigenvectors::No,
1308				linalg::evd::ComputeEigenvectors::Yes,
1309				par,
1310				default(),
1311			))),
1312			default(),
1313		)?;
1314
1315		if conj == Conj::Yes {
1316			for c in U.col_iter_mut() {
1317				for x in c.iter_mut() {
1318					*x = math_utils::conj(x);
1319				}
1320			}
1321		}
1322
1323		Ok(Self { U, S })
1324	}
1325
1326	/// returns the factor $U$
1327	pub fn U(&self) -> MatRef<'_, Complex<T>> {
1328		self.U.as_ref()
1329	}
1330
1331	/// returns the factor $S$
1332	pub fn S(&self) -> DiagRef<'_, Complex<T>> {
1333		self.S.as_ref()
1334	}
1335}
1336
1337impl<T: ComplexField> ShapeCore for Llt<T> {
1338	#[inline]
1339	fn nrows(&self) -> usize {
1340		self.L().nrows()
1341	}
1342
1343	#[inline]
1344	fn ncols(&self) -> usize {
1345		self.L().ncols()
1346	}
1347}
1348impl<T: ComplexField> ShapeCore for Ldlt<T> {
1349	#[inline]
1350	fn nrows(&self) -> usize {
1351		self.L().nrows()
1352	}
1353
1354	#[inline]
1355	fn ncols(&self) -> usize {
1356		self.L().ncols()
1357	}
1358}
1359impl<T: ComplexField> ShapeCore for Lblt<T> {
1360	#[inline]
1361	fn nrows(&self) -> usize {
1362		self.L().nrows()
1363	}
1364
1365	#[inline]
1366	fn ncols(&self) -> usize {
1367		self.L().ncols()
1368	}
1369}
1370impl<T: ComplexField> ShapeCore for PartialPivLu<T> {
1371	#[inline]
1372	fn nrows(&self) -> usize {
1373		self.L().nrows()
1374	}
1375
1376	#[inline]
1377	fn ncols(&self) -> usize {
1378		self.U().ncols()
1379	}
1380}
1381impl<T: ComplexField> ShapeCore for FullPivLu<T> {
1382	#[inline]
1383	fn nrows(&self) -> usize {
1384		self.L().nrows()
1385	}
1386
1387	#[inline]
1388	fn ncols(&self) -> usize {
1389		self.U().ncols()
1390	}
1391}
1392impl<T: ComplexField> ShapeCore for Qr<T> {
1393	#[inline]
1394	fn nrows(&self) -> usize {
1395		self.Q_basis().nrows()
1396	}
1397
1398	#[inline]
1399	fn ncols(&self) -> usize {
1400		self.R().ncols()
1401	}
1402}
1403impl<T: ComplexField> ShapeCore for ColPivQr<T> {
1404	#[inline]
1405	fn nrows(&self) -> usize {
1406		self.Q_basis().nrows()
1407	}
1408
1409	#[inline]
1410	fn ncols(&self) -> usize {
1411		self.R().ncols()
1412	}
1413}
1414impl<T: ComplexField> ShapeCore for Svd<T> {
1415	#[inline]
1416	fn nrows(&self) -> usize {
1417		self.U().nrows()
1418	}
1419
1420	#[inline]
1421	fn ncols(&self) -> usize {
1422		self.V().nrows()
1423	}
1424}
1425impl<T: ComplexField> ShapeCore for SelfAdjointEigen<T> {
1426	#[inline]
1427	fn nrows(&self) -> usize {
1428		self.U().nrows()
1429	}
1430
1431	#[inline]
1432	fn ncols(&self) -> usize {
1433		self.U().nrows()
1434	}
1435}
1436impl<T: RealField> ShapeCore for Eigen<T> {
1437	#[inline]
1438	fn nrows(&self) -> usize {
1439		self.U().nrows()
1440	}
1441
1442	#[inline]
1443	fn ncols(&self) -> usize {
1444		self.U().nrows()
1445	}
1446}
1447
1448impl<T: ComplexField> SolveCore<T> for Llt<T> {
1449	#[track_caller]
1450	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1451		let par = get_global_parallelism();
1452
1453		let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
1454			self.L.nrows(),
1455			rhs.ncols(),
1456			par,
1457		));
1458		let stack = MemStack::new(&mut mem);
1459
1460		linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj, rhs, par, stack);
1461	}
1462
1463	#[track_caller]
1464	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1465		let par = get_global_parallelism();
1466
1467		let mut mem = MemBuffer::new(linalg::cholesky::llt::solve::solve_in_place_scratch::<T>(
1468			self.L.nrows(),
1469			rhs.ncols(),
1470			par,
1471		));
1472		let stack = MemStack::new(&mut mem);
1473
1474		linalg::cholesky::llt::solve::solve_in_place_with_conj(self.L.as_ref(), conj.compose(Conj::Yes), rhs, par, stack);
1475	}
1476}
1477
1478#[math]
1479fn make_self_adjoint<T: ComplexField>(mut A: MatMut<'_, T>) {
1480	assert!(A.nrows() == A.ncols());
1481	let n = A.nrows();
1482	for j in 0..n {
1483		A[(j, j)] = from_real(real(A[(j, j)]));
1484		for i in 0..j {
1485			A[(i, j)] = conj(A[(j, i)]);
1486		}
1487	}
1488}
1489
1490impl<T: ComplexField> DenseSolveCore<T> for Llt<T> {
1491	#[track_caller]
1492	fn reconstruct(&self) -> Mat<T> {
1493		let par = get_global_parallelism();
1494
1495		let n = self.L.nrows();
1496		let mut out = Mat::zeros(n, n);
1497
1498		let mut mem = MemBuffer::new(linalg::cholesky::llt::reconstruct::reconstruct_scratch::<T>(n, par));
1499		let stack = MemStack::new(&mut mem);
1500
1501		linalg::cholesky::llt::reconstruct::reconstruct(out.as_mut(), self.L(), par, stack);
1502
1503		make_self_adjoint(out.as_mut());
1504		out
1505	}
1506
1507	#[track_caller]
1508	fn inverse(&self) -> Mat<T> {
1509		let par = get_global_parallelism();
1510
1511		let n = self.L.nrows();
1512		let mut out = Mat::zeros(n, n);
1513
1514		let mut mem = MemBuffer::new(linalg::cholesky::llt::inverse::inverse_scratch::<T>(n, par));
1515		let stack = MemStack::new(&mut mem);
1516
1517		linalg::cholesky::llt::inverse::inverse(out.as_mut(), self.L(), par, stack);
1518
1519		make_self_adjoint(out.as_mut());
1520		out
1521	}
1522}
1523
1524impl<T: ComplexField> SolveCore<T> for Ldlt<T> {
1525	#[track_caller]
1526	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1527		let par = get_global_parallelism();
1528
1529		let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1530			self.L.nrows(),
1531			rhs.ncols(),
1532			par,
1533		));
1534		let stack = MemStack::new(&mut mem);
1535
1536		linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L.as_ref(), self.D.as_ref(), conj, rhs, par, stack);
1537	}
1538
1539	#[track_caller]
1540	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1541		let par = get_global_parallelism();
1542
1543		let mut mem = MemBuffer::new(linalg::cholesky::ldlt::solve::solve_in_place_scratch::<T>(
1544			self.L.nrows(),
1545			rhs.ncols(),
1546			par,
1547		));
1548		let stack = MemStack::new(&mut mem);
1549
1550		linalg::cholesky::ldlt::solve::solve_in_place_with_conj(self.L(), self.D(), conj.compose(Conj::Yes), rhs, par, stack);
1551	}
1552}
1553
1554impl<T: ComplexField> DenseSolveCore<T> for Ldlt<T> {
1555	#[track_caller]
1556	fn reconstruct(&self) -> Mat<T> {
1557		let par = get_global_parallelism();
1558
1559		let n = self.L.nrows();
1560		let mut out = Mat::zeros(n, n);
1561
1562		let mut mem = MemBuffer::new(linalg::cholesky::ldlt::reconstruct::reconstruct_scratch::<T>(n, par));
1563		let stack = MemStack::new(&mut mem);
1564
1565		linalg::cholesky::ldlt::reconstruct::reconstruct(out.as_mut(), self.L(), self.D(), par, stack);
1566
1567		make_self_adjoint(out.as_mut());
1568		out
1569	}
1570
1571	#[track_caller]
1572	fn inverse(&self) -> Mat<T> {
1573		let par = get_global_parallelism();
1574
1575		let n = self.L.nrows();
1576		let mut out = Mat::zeros(n, n);
1577
1578		let mut mem = MemBuffer::new(linalg::cholesky::ldlt::inverse::inverse_scratch::<T>(n, par));
1579		let stack = MemStack::new(&mut mem);
1580
1581		linalg::cholesky::ldlt::inverse::inverse(out.as_mut(), self.L(), self.D(), par, stack);
1582
1583		make_self_adjoint(out.as_mut());
1584		out
1585	}
1586}
1587
1588impl<T: ComplexField> SolveCore<T> for Lblt<T> {
1589	#[track_caller]
1590	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1591		let par = get_global_parallelism();
1592
1593		let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1594			self.L.nrows(),
1595			rhs.ncols(),
1596			par,
1597		));
1598		let stack = MemStack::new(&mut mem);
1599
1600		linalg::cholesky::lblt::solve::solve_in_place_with_conj(self.L.as_ref(), self.B_diag(), self.B_subdiag(), conj, self.P(), rhs, par, stack);
1601	}
1602
1603	#[track_caller]
1604	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1605		let par = get_global_parallelism();
1606
1607		let mut mem = MemBuffer::new(linalg::cholesky::lblt::solve::solve_in_place_scratch::<usize, T>(
1608			self.L.nrows(),
1609			rhs.ncols(),
1610			par,
1611		));
1612		let stack = MemStack::new(&mut mem);
1613
1614		linalg::cholesky::lblt::solve::solve_in_place_with_conj(
1615			self.L(),
1616			self.B_diag(),
1617			self.B_subdiag(),
1618			conj.compose(Conj::Yes),
1619			self.P(),
1620			rhs,
1621			par,
1622			stack,
1623		);
1624	}
1625}
1626
1627impl<T: ComplexField> DenseSolveCore<T> for Lblt<T> {
1628	#[track_caller]
1629	fn reconstruct(&self) -> Mat<T> {
1630		let par = get_global_parallelism();
1631
1632		let n = self.L.nrows();
1633		let mut out = Mat::zeros(n, n);
1634
1635		let mut mem = MemBuffer::new(linalg::cholesky::lblt::reconstruct::reconstruct_scratch::<usize, T>(n, par));
1636		let stack = MemStack::new(&mut mem);
1637
1638		linalg::cholesky::lblt::reconstruct::reconstruct(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);
1639
1640		make_self_adjoint(out.as_mut());
1641		out
1642	}
1643
1644	#[track_caller]
1645	fn inverse(&self) -> Mat<T> {
1646		let par = get_global_parallelism();
1647
1648		let n = self.L.nrows();
1649		let mut out = Mat::zeros(n, n);
1650
1651		let mut mem = MemBuffer::new(linalg::cholesky::lblt::inverse::inverse_scratch::<usize, T>(n, par));
1652		let stack = MemStack::new(&mut mem);
1653
1654		linalg::cholesky::lblt::inverse::inverse(out.as_mut(), self.L(), self.B_diag(), self.B_subdiag(), self.P(), par, stack);
1655
1656		make_self_adjoint(out.as_mut());
1657		out
1658	}
1659}
1660
1661impl<T: ComplexField> SolveCore<T> for PartialPivLu<T> {
1662	#[track_caller]
1663	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1664		let par = get_global_parallelism();
1665
1666		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1667
1668		let k = rhs.ncols();
1669
1670		linalg::lu::partial_pivoting::solve::solve_in_place_with_conj(
1671			self.L(),
1672			self.U(),
1673			self.P(),
1674			conj,
1675			rhs,
1676			par,
1677			MemStack::new(&mut MemBuffer::new(
1678				linalg::lu::partial_pivoting::solve::solve_in_place_scratch::<usize, T>(self.nrows(), k, par),
1679			)),
1680		);
1681	}
1682
1683	#[track_caller]
1684	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1685		let par = get_global_parallelism();
1686
1687		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1688
1689		let k = rhs.ncols();
1690
1691		linalg::lu::partial_pivoting::solve::solve_transpose_in_place_with_conj(
1692			self.L(),
1693			self.U(),
1694			self.P(),
1695			conj,
1696			rhs,
1697			par,
1698			MemStack::new(&mut MemBuffer::new(
1699				linalg::lu::partial_pivoting::solve::solve_transpose_in_place_scratch::<usize, T>(self.nrows(), k, par),
1700			)),
1701		);
1702	}
1703}
1704
1705impl<T: ComplexField> DenseSolveCore<T> for PartialPivLu<T> {
1706	fn reconstruct(&self) -> Mat<T> {
1707		let par = get_global_parallelism();
1708		let m = self.nrows();
1709		let n = self.ncols();
1710
1711		let mut out = Mat::zeros(m, n);
1712
1713		linalg::lu::partial_pivoting::reconstruct::reconstruct(
1714			out.as_mut(),
1715			self.L(),
1716			self.U(),
1717			self.P(),
1718			par,
1719			MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::reconstruct::reconstruct_scratch::<
1720				usize,
1721				T,
1722			>(m, n, par))),
1723		);
1724
1725		out
1726	}
1727
1728	#[track_caller]
1729	fn inverse(&self) -> Mat<T> {
1730		let par = get_global_parallelism();
1731
1732		assert!(self.nrows() == self.ncols());
1733
1734		let n = self.ncols();
1735
1736		let mut out = Mat::zeros(n, n);
1737
1738		linalg::lu::partial_pivoting::inverse::inverse(
1739			out.as_mut(),
1740			self.L(),
1741			self.U(),
1742			self.P(),
1743			par,
1744			MemStack::new(&mut MemBuffer::new(linalg::lu::partial_pivoting::inverse::inverse_scratch::<usize, T>(
1745				n, par,
1746			))),
1747		);
1748
1749		out
1750	}
1751}
1752
1753impl<T: ComplexField> SolveCore<T> for FullPivLu<T> {
1754	#[track_caller]
1755	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1756		let par = get_global_parallelism();
1757
1758		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1759
1760		let k = rhs.ncols();
1761
1762		linalg::lu::full_pivoting::solve::solve_in_place_with_conj(
1763			self.L(),
1764			self.U(),
1765			self.P(),
1766			self.Q(),
1767			conj,
1768			rhs,
1769			par,
1770			MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_in_place_scratch::<usize, T>(
1771				self.nrows(),
1772				k,
1773				par,
1774			))),
1775		);
1776	}
1777
1778	#[track_caller]
1779	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1780		let par = get_global_parallelism();
1781
1782		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1783
1784		let k = rhs.ncols();
1785
1786		linalg::lu::full_pivoting::solve::solve_transpose_in_place_with_conj(
1787			self.L(),
1788			self.U(),
1789			self.P(),
1790			self.Q(),
1791			conj,
1792			rhs,
1793			par,
1794			MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::solve::solve_transpose_in_place_scratch::<
1795				usize,
1796				T,
1797			>(self.nrows(), k, par))),
1798		);
1799	}
1800}
1801
1802impl<T: ComplexField> DenseSolveCore<T> for FullPivLu<T> {
1803	fn reconstruct(&self) -> Mat<T> {
1804		let par = get_global_parallelism();
1805		let m = self.nrows();
1806		let n = self.ncols();
1807
1808		let mut out = Mat::zeros(m, n);
1809
1810		linalg::lu::full_pivoting::reconstruct::reconstruct(
1811			out.as_mut(),
1812			self.L(),
1813			self.U(),
1814			self.P(),
1815			self.Q(),
1816			par,
1817			MemStack::new(&mut MemBuffer::new(
1818				linalg::lu::full_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, par),
1819			)),
1820		);
1821
1822		out
1823	}
1824
1825	#[track_caller]
1826	fn inverse(&self) -> Mat<T> {
1827		let par = get_global_parallelism();
1828
1829		assert!(self.nrows() == self.ncols());
1830
1831		let n = self.ncols();
1832
1833		let mut out = Mat::zeros(n, n);
1834
1835		linalg::lu::full_pivoting::inverse::inverse(
1836			out.as_mut(),
1837			self.L(),
1838			self.U(),
1839			self.P(),
1840			self.Q(),
1841			par,
1842			MemStack::new(&mut MemBuffer::new(linalg::lu::full_pivoting::inverse::inverse_scratch::<usize, T>(
1843				n, par,
1844			))),
1845		);
1846
1847		out
1848	}
1849}
1850
1851impl<T: ComplexField> SolveCore<T> for Qr<T> {
1852	#[track_caller]
1853	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1854		let par = get_global_parallelism();
1855
1856		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1857
1858		let n = self.nrows();
1859		let blocksize = self.Q_coeff().nrows();
1860		let k = rhs.ncols();
1861
1862		linalg::qr::no_pivoting::solve::solve_in_place_with_conj(
1863			self.Q_basis(),
1864			self.Q_coeff(),
1865			self.R(),
1866			conj,
1867			rhs,
1868			par,
1869			MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_in_place_scratch::<T>(
1870				n, blocksize, k, par,
1871			))),
1872		);
1873	}
1874
1875	#[track_caller]
1876	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1877		let par = get_global_parallelism();
1878
1879		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
1880
1881		let n = self.nrows();
1882		let blocksize = self.Q_coeff().nrows();
1883		let k = rhs.ncols();
1884
1885		linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(
1886			self.Q_basis(),
1887			self.Q_coeff(),
1888			self.R(),
1889			conj,
1890			rhs,
1891			par,
1892			MemStack::new(&mut MemBuffer::new(
1893				linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(n, blocksize, k, par),
1894			)),
1895		);
1896	}
1897}
1898
1899impl<T: ComplexField> SolveLstsqCore<T> for Qr<T> {
1900	#[track_caller]
1901	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1902		let par = get_global_parallelism();
1903
1904		assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));
1905
1906		let m = self.nrows();
1907		let n = self.ncols();
1908		let blocksize = self.Q_coeff().nrows();
1909		let k = rhs.ncols();
1910
1911		linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(
1912			self.Q_basis(),
1913			self.Q_coeff(),
1914			self.R(),
1915			conj,
1916			rhs,
1917			par,
1918			MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(
1919				m, n, blocksize, k, par,
1920			))),
1921		);
1922	}
1923}
1924
1925impl<T: ComplexField> DenseSolveCore<T> for Qr<T> {
1926	fn reconstruct(&self) -> Mat<T> {
1927		let par = get_global_parallelism();
1928		let m = self.nrows();
1929		let n = self.ncols();
1930		let blocksize = self.Q_coeff().nrows();
1931
1932		let mut out = Mat::zeros(m, n);
1933
1934		linalg::qr::no_pivoting::reconstruct::reconstruct(
1935			out.as_mut(),
1936			self.Q_basis(),
1937			self.Q_coeff(),
1938			self.R(),
1939			par,
1940			MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::reconstruct::reconstruct_scratch::<T>(
1941				m, n, blocksize, par,
1942			))),
1943		);
1944
1945		out
1946	}
1947
1948	fn inverse(&self) -> Mat<T> {
1949		let par = get_global_parallelism();
1950		assert!(self.nrows() == self.ncols());
1951
1952		let n = self.ncols();
1953		let blocksize = self.Q_coeff().nrows();
1954
1955		let mut out = Mat::zeros(n, n);
1956
1957		linalg::qr::no_pivoting::inverse::inverse(
1958			out.as_mut(),
1959			self.Q_basis(),
1960			self.Q_coeff(),
1961			self.R(),
1962			par,
1963			MemStack::new(&mut MemBuffer::new(linalg::qr::no_pivoting::inverse::inverse_scratch::<T>(
1964				n, blocksize, par,
1965			))),
1966		);
1967
1968		out
1969	}
1970}
1971
1972impl<T: ComplexField> SolveCore<T> for ColPivQr<T> {
1973	#[track_caller]
1974	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1975		let par = get_global_parallelism();
1976
1977		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
1978
1979		let n = self.nrows();
1980		let blocksize = self.Q_coeff().nrows();
1981		let k = rhs.ncols();
1982
1983		linalg::qr::col_pivoting::solve::solve_in_place_with_conj(
1984			self.Q_basis(),
1985			self.Q_coeff(),
1986			self.R(),
1987			self.P(),
1988			conj,
1989			rhs,
1990			par,
1991			MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_in_place_scratch::<usize, T>(
1992				n, blocksize, k, par,
1993			))),
1994		);
1995	}
1996
1997	#[track_caller]
1998	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
1999		let par = get_global_parallelism();
2000
2001		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
2002
2003		let n = self.nrows();
2004		let blocksize = self.Q_coeff().nrows();
2005		let k = rhs.ncols();
2006
2007		linalg::qr::col_pivoting::solve::solve_transpose_in_place_with_conj(
2008			self.Q_basis(),
2009			self.Q_coeff(),
2010			self.R(),
2011			self.P(),
2012			conj,
2013			rhs,
2014			par,
2015			MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_transpose_in_place_scratch::<
2016				usize,
2017				T,
2018			>(n, blocksize, k, par))),
2019		);
2020	}
2021}
2022
2023impl<T: ComplexField> SolveLstsqCore<T> for ColPivQr<T> {
2024	#[track_caller]
2025	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
2026		let par = get_global_parallelism();
2027
2028		assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));
2029
2030		let m = self.nrows();
2031		let n = self.ncols();
2032		let blocksize = self.Q_coeff().nrows();
2033		let k = rhs.ncols();
2034
2035		linalg::qr::col_pivoting::solve::solve_lstsq_in_place_with_conj(
2036			self.Q_basis(),
2037			self.Q_coeff(),
2038			self.R(),
2039			self.P(),
2040			conj,
2041			rhs,
2042			par,
2043			MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::solve::solve_lstsq_in_place_scratch::<
2044				usize,
2045				T,
2046			>(m, n, blocksize, k, par))),
2047		);
2048	}
2049}
2050
2051impl<T: ComplexField> DenseSolveCore<T> for ColPivQr<T> {
2052	fn reconstruct(&self) -> Mat<T> {
2053		let par = get_global_parallelism();
2054		let m = self.nrows();
2055		let n = self.ncols();
2056		let blocksize = self.Q_coeff().nrows();
2057
2058		let mut out = Mat::zeros(m, n);
2059
2060		linalg::qr::col_pivoting::reconstruct::reconstruct(
2061			out.as_mut(),
2062			self.Q_basis(),
2063			self.Q_coeff(),
2064			self.R(),
2065			self.P(),
2066			par,
2067			MemStack::new(&mut MemBuffer::new(
2068				linalg::qr::col_pivoting::reconstruct::reconstruct_scratch::<usize, T>(m, n, blocksize, par),
2069			)),
2070		);
2071
2072		out
2073	}
2074
2075	fn inverse(&self) -> Mat<T> {
2076		let par = get_global_parallelism();
2077		assert!(self.nrows() == self.ncols());
2078
2079		let n = self.ncols();
2080		let blocksize = self.Q_coeff().nrows();
2081
2082		let mut out = Mat::zeros(n, n);
2083
2084		linalg::qr::col_pivoting::inverse::inverse(
2085			out.as_mut(),
2086			self.Q_basis(),
2087			self.Q_coeff(),
2088			self.R(),
2089			self.P(),
2090			par,
2091			MemStack::new(&mut MemBuffer::new(linalg::qr::col_pivoting::inverse::inverse_scratch::<usize, T>(
2092				n, blocksize, par,
2093			))),
2094		);
2095
2096		out
2097	}
2098}
2099
impl<T: ComplexField> SolveCore<T> for Svd<T> {
	/// solves `A × x = rhs` in place using $A = U S V^H$, i.e. $x = V S^{-1} U^H \, rhs$
	///
	/// with `conj == Conj::Yes` the conjugated system is solved instead, using
	/// $x = \bar V S^{-1} U^T \, rhs$. panics unless the matrix is square and `rhs` has as many
	/// rows as it
	#[track_caller]
	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
		let par = get_global_parallelism();

		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));

		let mut rhs = rhs;
		let n = self.nrows();
		let k = rhs.ncols();
		let mut tmp = Mat::zeros(n, k);

		// tmp = U^H rhs (U^T rhs when conjugating): `U.transpose()` with the flag flipped is the
		// adjoint
		linalg::matmul::matmul_with_conj(
			tmp.as_mut(),
			Accum::Replace,
			self.U().transpose(),
			conj.compose(Conj::Yes),
			rhs.as_ref(),
			Conj::No,
			one(),
			par,
		);

		// tmp = S^{-1} tmp, using the real part of each singular value.
		// NOTE(review): no guard against zero singular values (singular matrix)
		for j in 0..k {
			for i in 0..n {
				let s = recip(&real(&self.S()[i]));
				tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
			}
		}

		// rhs = V tmp (conj(V) tmp when conjugating)
		linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.V(), conj, tmp.as_ref(), Conj::No, one(), par);
	}

	/// solves `A.transpose() × x = rhs` in place using $A^T = \bar V S U^T$, i.e.
	/// $x = \bar U S^{-1} V^T \, rhs$ ($x = U S^{-1} V^H \, rhs$ when `conj == Conj::Yes`).
	/// panics unless the matrix is square and `rhs` has as many rows as it has columns
	#[track_caller]
	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
		let par = get_global_parallelism();

		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));

		let mut rhs = rhs;
		let n = self.nrows();
		let k = rhs.ncols();
		let mut tmp = Mat::zeros(n, k);

		// tmp = V^T rhs (V^H rhs when conjugating)
		linalg::matmul::matmul_with_conj(
			tmp.as_mut(),
			Accum::Replace,
			self.V().transpose(),
			conj,
			rhs.as_ref(),
			Conj::No,
			one(),
			par,
		);

		// tmp = S^{-1} tmp, using the real part of each singular value
		for j in 0..k {
			for i in 0..n {
				let s = recip(&real(&self.S()[i]));
				tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
			}
		}

		// rhs = conj(U) tmp (U tmp when conjugating)
		linalg::matmul::matmul_with_conj(
			rhs.as_mut(),
			Accum::Replace,
			self.U(),
			conj.compose(Conj::Yes),
			tmp.as_ref(),
			Conj::No,
			one(),
			par,
		);
	}
}
2174
2175impl<T: ComplexField> SolveLstsqCore<T> for Svd<T> {
2176	#[track_caller]
2177	fn solve_lstsq_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
2178		let par = get_global_parallelism();
2179
2180		assert!(all(self.nrows() == rhs.nrows(), self.nrows() >= self.ncols(),));
2181
2182		let m = self.nrows();
2183		let n = self.ncols();
2184
2185		let size = Ord::min(m, n);
2186
2187		let U = self.U().get(.., ..size);
2188		let V = self.V().get(.., ..size);
2189
2190		let k = rhs.ncols();
2191
2192		let mut tmp = Mat::zeros(size, k);
2193
2194		linalg::matmul::matmul_with_conj(
2195			tmp.as_mut(),
2196			Accum::Replace,
2197			U.transpose(),
2198			conj.compose(Conj::Yes),
2199			rhs.as_ref(),
2200			Conj::No,
2201			one(),
2202			par,
2203		);
2204
2205		for j in 0..k {
2206			for i in 0..size {
2207				let s = recip(&real(&self.S()[i]));
2208				tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
2209			}
2210		}
2211
2212		linalg::matmul::matmul_with_conj(rhs.get_mut(..size, ..), Accum::Replace, V, conj, tmp.as_ref(), Conj::No, one(), par);
2213	}
2214}
2215
2216impl<T: ComplexField> DenseSolveCore<T> for Svd<T> {
2217	fn reconstruct(&self) -> Mat<T> {
2218		let par = get_global_parallelism();
2219		let m = self.nrows();
2220		let n = self.ncols();
2221
2222		let size = Ord::min(m, n);
2223
2224		let U = self.U().get(.., ..size);
2225		let V = self.V().get(.., ..size);
2226		let S = self.S();
2227
2228		let mut UxS = Mat::zeros(m, size);
2229		for j in 0..size {
2230			let s = real(&S[j]);
2231			for i in 0..m {
2232				UxS[(i, j)] = mul_real(&U[(i, j)], &s);
2233			}
2234		}
2235
2236		let mut out = Mat::zeros(m, n);
2237
2238		linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);
2239
2240		out
2241	}
2242
2243	#[track_caller]
2244	fn inverse(&self) -> Mat<T> {
2245		let par = get_global_parallelism();
2246
2247		assert!(self.nrows() == self.ncols());
2248		let n = self.nrows();
2249
2250		let U = self.U();
2251		let V = self.V();
2252		let S = self.S();
2253
2254		let mut VxS = Mat::zeros(n, n);
2255		for j in 0..n {
2256			let s = recip(&real(&S[j]));
2257
2258			for i in 0..n {
2259				VxS[(i, j)] = mul_real(&V[(i, j)], &s);
2260			}
2261		}
2262
2263		let mut out = Mat::zeros(n, n);
2264
2265		linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);
2266
2267		out
2268	}
2269}
2270
2271impl<T: ComplexField> SolveCore<T> for SelfAdjointEigen<T> {
2272	#[track_caller]
2273	fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
2274		let par = get_global_parallelism();
2275
2276		assert!(all(self.nrows() == self.ncols(), self.nrows() == rhs.nrows(),));
2277
2278		let mut rhs = rhs;
2279		let n = self.nrows();
2280		let k = rhs.ncols();
2281		let mut tmp = Mat::zeros(n, k);
2282
2283		linalg::matmul::matmul_with_conj(
2284			tmp.as_mut(),
2285			Accum::Replace,
2286			self.U().transpose(),
2287			conj.compose(Conj::Yes),
2288			rhs.as_ref(),
2289			Conj::No,
2290			one(),
2291			par,
2292		);
2293
2294		for j in 0..k {
2295			for i in 0..n {
2296				let s = recip(&real(&self.S()[i]));
2297				tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
2298			}
2299		}
2300
2301		linalg::matmul::matmul_with_conj(rhs.as_mut(), Accum::Replace, self.U(), conj, tmp.as_ref(), Conj::No, one(), par);
2302	}
2303
2304	#[track_caller]
2305	fn solve_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>) {
2306		let par = get_global_parallelism();
2307
2308		assert!(all(self.nrows() == self.ncols(), self.ncols() == rhs.nrows(),));
2309
2310		let mut rhs = rhs;
2311		let n = self.nrows();
2312		let k = rhs.ncols();
2313		let mut tmp = Mat::zeros(n, k);
2314
2315		linalg::matmul::matmul_with_conj(
2316			tmp.as_mut(),
2317			Accum::Replace,
2318			self.U().transpose(),
2319			conj,
2320			rhs.as_ref(),
2321			Conj::No,
2322			one(),
2323			par,
2324		);
2325
2326		for j in 0..k {
2327			for i in 0..n {
2328				let s = recip(&real(&self.S()[i]));
2329				tmp[(i, j)] = mul_real(&tmp[(i, j)], &s);
2330			}
2331		}
2332
2333		linalg::matmul::matmul_with_conj(
2334			rhs.as_mut(),
2335			Accum::Replace,
2336			self.U(),
2337			conj.compose(Conj::Yes),
2338			tmp.as_ref(),
2339			Conj::No,
2340			one(),
2341			par,
2342		);
2343	}
2344}
2345
2346impl<T: ComplexField> DenseSolveCore<T> for SelfAdjointEigen<T> {
2347	fn reconstruct(&self) -> Mat<T> {
2348		let par = get_global_parallelism();
2349		let m = self.nrows();
2350		let n = self.ncols();
2351
2352		let size = Ord::min(m, n);
2353
2354		let U = self.U().get(.., ..size);
2355		let V = self.U().get(.., ..size);
2356		let S = self.S();
2357
2358		let mut UxS = Mat::zeros(m, size);
2359		for j in 0..size {
2360			let s = real(&S[j]);
2361			for i in 0..m {
2362				UxS[(i, j)] = mul_real(&U[(i, j)], &s);
2363			}
2364		}
2365
2366		let mut out = Mat::zeros(m, n);
2367
2368		linalg::matmul::matmul(out.as_mut(), Accum::Replace, UxS.as_ref(), V.adjoint(), one(), par);
2369
2370		out
2371	}
2372
2373	fn inverse(&self) -> Mat<T> {
2374		let par = get_global_parallelism();
2375
2376		assert!(self.nrows() == self.ncols());
2377		let n = self.nrows();
2378
2379		let U = self.U();
2380		let V = self.U();
2381		let S = self.S();
2382
2383		let mut VxS = Mat::zeros(n, n);
2384		for j in 0..n {
2385			let s = recip(&real(&S[j]));
2386
2387			for i in 0..n {
2388				VxS[(i, j)] = mul_real(&V[(i, j)], &s);
2389			}
2390		}
2391
2392		let mut out = Mat::zeros(n, n);
2393
2394		linalg::matmul::matmul(out.as_mut(), Accum::Replace, VxS.as_ref(), U.adjoint(), one(), par);
2395
2396		out
2397	}
2398}
2399
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;

	/// checks every `solve*` / `rsolve*` variant of `A_dec` against the matrix `A` it was built
	/// from, using random complex left and right hand sides with 3 columns/rows
	#[track_caller]
	fn test_solver(A: MatRef<'_, c64>, A_dec: impl SolveCore<c64>) {
		// the checks live in a single non-generic function taking `&dyn SolveCore`
		#[track_caller]
		fn test_solver_imp(A: MatRef<'_, c64>, A_dec: &dyn SolveCore<c64>) {
			let rng = &mut StdRng::seed_from_u64(0xC0FFEE);

			let n = A.nrows();
			// tolerance scales with the dimension; the binding name is what the `~` comparisons use
			let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

			let k = 3;

			let ref R = CwiseMatDistribution {
				nrows: n,
				ncols: k,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			let ref L = CwiseMatDistribution {
				nrows: k,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			assert!(A * A_dec.solve(R) ~ R);
			assert!(A.conjugate() * A_dec.solve_conjugate(R) ~ R);
			assert!(A.transpose() * A_dec.solve_transpose(R) ~ R);
			assert!(A.adjoint() * A_dec.solve_adjoint(R) ~ R);

			assert!(A_dec.rsolve(L) * A ~ L);
			assert!(A_dec.rsolve_conjugate(L) * A.conjugate() ~ L);
			assert!(A_dec.rsolve_transpose(L) * A.transpose() ~ L);
			assert!(A_dec.rsolve_adjoint(L) * A.adjoint() ~ L);
		}

		test_solver_imp(A, &A_dec)
	}

	/// runs `test_solver` on every decomposition, using a matrix each decomposition accepts:
	/// a general random matrix, `A × Aᴴ` for the cholesky variants, `A + Aᴴ` for the
	/// self-adjoint ones
	#[test]
	fn test_all_solvers() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let ref A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);
		let A = A.rb();

		test_solver(A, A.partial_piv_lu());
		test_solver(A, A.full_piv_lu());
		test_solver(A, A.qr());
		test_solver(A, A.col_piv_qr());
		test_solver(A, A.svd().unwrap());

		{
			let ref A = A * A.adjoint();
			let A = A.rb();
			test_solver(A, A.llt(Side::Lower).unwrap());
			test_solver(A, A.ldlt(Side::Lower).unwrap());
		}

		{
			let ref A = A + A.adjoint();
			let A = A.rb();
			test_solver(A, A.lblt(Side::Lower));
			test_solver(A, A.self_adjoint_eigen(Side::Lower).unwrap());
		}
	}

	/// checks `A × U = U × S` for the eigendecomposition of a random complex matrix, and that
	/// `eigenvalues()` agrees with the diagonal of `S`
	#[test]
	fn test_eigen_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		let evd = A.eigen().unwrap();
		let e = A.eigenvalues().unwrap();
		assert!(&A * evd.U() ~ evd.U() * evd.S());
		assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
	}

	/// same checks as `test_eigen_cplx`, for a random real matrix
	#[test]
	fn test_eigen_real() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: StandardNormal,
		}
		.rand::<Mat<f64>>(rng);

		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (n as f64));

		let evd = A.eigen().unwrap();
		let e = A.eigenvalues().unwrap();

		// lift `A` to complex so it can be multiplied with the eigenvectors returned by `eigen()`
		let A = Mat::from_fn(A.nrows(), A.ncols(), |i, j| c64::from(A[(i, j)]));

		assert!(&A * evd.U() ~ evd.U() * evd.S());
		assert!(evd.S().column_vector() ~ ColRef::from_slice(&e));
	}

	/// least-squares solve of an overdetermined 5×3 system; `B = A × X_true` exactly, so the
	/// least-squares solution (the top 3 rows of the result) must equal `X_true`
	#[test]
	fn test_svd_solver_for_rectangular_matrix() {
		#[rustfmt::skip]
		let A = crate::mat![
			[4.,   5.,   7.],
			[8.,   8.,   2.],
			[4.,   0.,   9.],
			[2.,   6.,   2.],
			[0.,   6.,   0.],
		];
		#[rustfmt::skip]
		let B = crate::mat![
			[105.,    49.],
			[ 98.,    54.],
			[113.,    35.],
			[ 46.,    34.],
			[ 12.,    24.],
		];

		#[rustfmt::skip]
		let X_true = crate::mat![
			[8.,   2.],
			[2.,   4.],
			[9.,   3.],
		];

		let approx_eq = CwiseMat(ApproxEq::eps() * 128.0 * (A.nrows() as f64));
		let svd = A.svd().unwrap();
		let mut X = B.cloned();
		svd.solve_lstsq_in_place_with_conj(crate::Conj::No, X.as_mat_mut());
		assert!(X.get(..X_true.nrows(),..) ~ X_true);
	}
}