faer/linalg/lu/full_pivoting/solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::{permute_rows_in_place, permute_rows_in_place_scratch};
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6	_ = par;
7	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
8}
9
10pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
11	_ = par;
12	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
13}
14
15#[track_caller]
16pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
17	L: MatRef<'_, T>,
18	U: MatRef<'_, T>,
19	row_perm: PermRef<'_, I>,
20	col_perm: PermRef<'_, I>,
21	conj_LU: Conj,
22	rhs: MatMut<'_, T>,
23	par: Par,
24	stack: &mut MemStack,
25) {
26	// LU = P A Q^-1
27	// P^-1 LU Q = A
28	// A^-1 = Q^-1 U^-1 L^-1 P
29
30	let n = L.nrows();
31
32	assert!(all(
33		L.nrows() == n,
34		L.ncols() == n,
35		U.nrows() == n,
36		U.ncols() == n,
37		row_perm.len() == n,
38		col_perm.len() == n,
39		rhs.nrows() == n,
40	));
41
42	let mut rhs = rhs;
43	permute_rows_in_place(rhs.rb_mut(), row_perm, stack);
44	linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_LU, rhs.rb_mut(), par);
45	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U, conj_LU, rhs.rb_mut(), par);
46	permute_rows_in_place(rhs.rb_mut(), col_perm.inverse(), stack);
47}
48
49#[track_caller]
50pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
51	L: MatRef<'_, T>,
52	U: MatRef<'_, T>,
53	row_perm: PermRef<'_, I>,
54	col_perm: PermRef<'_, I>,
55	conj_LU: Conj,
56	rhs: MatMut<'_, T>,
57	par: Par,
58	stack: &mut MemStack,
59) {
60	// LU = P A Q^-1
61	// A^-1 = Q^-1 U^-1 L^-1 P
62	// A^-T = P^-1 L^-T U^-T Q
63
64	let n = L.nrows();
65
66	assert!(all(
67		L.nrows() == n,
68		L.ncols() == n,
69		U.nrows() == n,
70		U.ncols() == n,
71		row_perm.len() == n,
72		col_perm.len() == n,
73		rhs.nrows() == n,
74	));
75
76	let mut rhs = rhs;
77	permute_rows_in_place(rhs.rb_mut(), col_perm, stack);
78	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(U.transpose(), conj_LU, rhs.rb_mut(), par);
79	linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_LU, rhs.rb_mut(), par);
80	permute_rows_in_place(rhs.rb_mut(), row_perm.inverse(), stack);
81}
82
83#[track_caller]
84pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
85	L: MatRef<'_, C>,
86	U: MatRef<'_, C>,
87	row_perm: PermRef<'_, I>,
88	col_perm: PermRef<'_, I>,
89	rhs: MatMut<'_, T>,
90	par: Par,
91	stack: &mut MemStack,
92) {
93	solve_in_place_with_conj(L.canonical(), U.canonical(), row_perm, col_perm, Conj::get::<C>(), rhs, par, stack)
94}
95
96#[track_caller]
97pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
98	L: MatRef<'_, C>,
99	U: MatRef<'_, C>,
100	row_perm: PermRef<'_, I>,
101	col_perm: PermRef<'_, I>,
102	rhs: MatMut<'_, T>,
103	par: Par,
104	stack: &mut MemStack,
105) {
106	solve_transpose_in_place_with_conj(L.canonical(), U.canonical(), row_perm, col_perm, Conj::get::<C>(), rhs, par, stack)
107}
108
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::full_pivoting::*;

	/// Factors a random complex matrix with full pivoting, then checks that the plain,
	/// transposed, conjugated and adjoint solves each reproduce the right-hand side.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut LU = A.to_owned();
		let row_perm_fwd = &mut *vec![0usize; n];
		let row_perm_bwd = &mut *vec![0usize; n];
		let col_perm_fwd = &mut *vec![0usize; n];
		let col_perm_bwd = &mut *vec![0usize; n];

		let (_, row_perm, col_perm) = factor::lu_in_place(
			LU.as_mut(),
			row_perm_fwd,
			row_perm_bwd,
			col_perm_fwd,
			col_perm_bwd,
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default())) }),
			default(),
		);

		// Tolerance for the approximate comparisons below; presumably picked up by the `~`
		// operator of `crate::assert!`.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(&A * &X ~ B);
		}

		// A^T X = B: the workspace is sized with the transpose-specific scratch query.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.transpose() * &X ~ B);
		}

		// conj(A) X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.conjugate() * &X ~ B);
		}

		// A^H X = B: the workspace is sized with the transpose-specific scratch query.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}