// faer/linalg/lu/partial_pivoting/solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::{permute_rows_in_place, permute_rows_in_place_scratch};
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6	_ = par;
7	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
8}
9
10pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
11	_ = par;
12	permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
13}
14
15#[track_caller]
16pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
17	L: MatRef<'_, T>,
18	U: MatRef<'_, T>,
19	row_perm: PermRef<'_, I>,
20	conj_LU: Conj,
21	rhs: MatMut<'_, T>,
22	par: Par,
23	stack: &mut MemStack,
24) {
25	// LU = PA
26	// P^-1 LU = A
27	// A^-1 = U^-1 L^-1 P
28
29	let n = L.nrows();
30
31	assert!(all(
32		L.nrows() == n,
33		L.ncols() == n,
34		U.nrows() == n,
35		U.ncols() == n,
36		row_perm.len() == n,
37		rhs.nrows() == n,
38	));
39
40	let mut rhs = rhs;
41	permute_rows_in_place(rhs.rb_mut(), row_perm, stack);
42
43	linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_LU, rhs.rb_mut(), par);
44
45	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U, conj_LU, rhs.rb_mut(), par);
46}
47
48#[track_caller]
49pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
50	L: MatRef<'_, T>,
51	U: MatRef<'_, T>,
52	row_perm: PermRef<'_, I>,
53	conj_LU: Conj,
54	rhs: MatMut<'_, T>,
55	par: Par,
56	stack: &mut MemStack,
57) {
58	// LU = PA
59	// P^-1 LU = A
60	// A^-T = (U^-1 L^-1 P).T
61	// A^-T = P^-1 L^-T U^-T
62
63	let n = L.nrows();
64
65	assert!(all(
66		L.nrows() == n,
67		L.ncols() == n,
68		U.nrows() == n,
69		U.ncols() == n,
70		row_perm.len() == n,
71		rhs.nrows() == n,
72	));
73
74	let mut rhs = rhs;
75
76	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(U.transpose(), conj_LU, rhs.rb_mut(), par);
77	linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_LU, rhs.rb_mut(), par);
78
79	permute_rows_in_place(rhs.rb_mut(), row_perm.inverse(), stack);
80}
81
82#[track_caller]
83pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
84	L: MatRef<'_, C>,
85	U: MatRef<'_, C>,
86	row_perm: PermRef<'_, I>,
87	rhs: MatMut<'_, T>,
88	par: Par,
89	stack: &mut MemStack,
90) {
91	solve_in_place_with_conj(L.canonical(), U.canonical(), row_perm, Conj::get::<C>(), rhs, par, stack)
92}
93
94#[track_caller]
95pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
96	L: MatRef<'_, C>,
97	U: MatRef<'_, C>,
98	row_perm: PermRef<'_, I>,
99	rhs: MatMut<'_, T>,
100	par: Par,
101	stack: &mut MemStack,
102) {
103	solve_transpose_in_place_with_conj(L.canonical(), U.canonical(), row_perm, Conj::get::<C>(), rhs, par, stack)
104}
105
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factors a random complex `n × n` matrix with partial pivoting, then checks the four solve
	/// variants (plain, transpose, conjugate, adjoint) against the original matrix.
	///
	/// Each variant sizes its scratch with the scratch query that matches the solver it calls.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut LU = A.to_owned();
		let row_perm_fwd = &mut *vec![0usize; n];
		let row_perm_bwd = &mut *vec![0usize; n];

		let row_perm = factor::lu_in_place(
			LU.as_mut(),
			row_perm_fwd,
			row_perm_bwd,
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default())) }),
			default(),
		)
		.1;

		// Picked up implicitly by the `~` form of `assert!`.
		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(&A * &X ~ B);
		}
		// Aᵀ X = B
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.transpose() * &X ~ B);
		}
		// conj(A) X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.conjugate() * &X ~ B);
		}
		// Aᴴ X = B
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}