faer/linalg/lu/partial_pivoting/
solve.rs1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::{permute_rows_in_place, permute_rows_in_place_scratch};
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6 _ = par;
7 permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
8}
9
10pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
11 _ = par;
12 permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
13}
14
15#[track_caller]
16pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
17 L: MatRef<'_, T>,
18 U: MatRef<'_, T>,
19 row_perm: PermRef<'_, I>,
20 conj_LU: Conj,
21 rhs: MatMut<'_, T>,
22 par: Par,
23 stack: &mut MemStack,
24) {
25 let n = L.nrows();
30
31 assert!(all(
32 L.nrows() == n,
33 L.ncols() == n,
34 U.nrows() == n,
35 U.ncols() == n,
36 row_perm.len() == n,
37 rhs.nrows() == n,
38 ));
39
40 let mut rhs = rhs;
41 permute_rows_in_place(rhs.rb_mut(), row_perm, stack);
42
43 linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_LU, rhs.rb_mut(), par);
44
45 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U, conj_LU, rhs.rb_mut(), par);
46}
47
48#[track_caller]
49pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
50 L: MatRef<'_, T>,
51 U: MatRef<'_, T>,
52 row_perm: PermRef<'_, I>,
53 conj_LU: Conj,
54 rhs: MatMut<'_, T>,
55 par: Par,
56 stack: &mut MemStack,
57) {
58 let n = L.nrows();
64
65 assert!(all(
66 L.nrows() == n,
67 L.ncols() == n,
68 U.nrows() == n,
69 U.ncols() == n,
70 row_perm.len() == n,
71 rhs.nrows() == n,
72 ));
73
74 let mut rhs = rhs;
75
76 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(U.transpose(), conj_LU, rhs.rb_mut(), par);
77 linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_LU, rhs.rb_mut(), par);
78
79 permute_rows_in_place(rhs.rb_mut(), row_perm.inverse(), stack);
80}
81
82#[track_caller]
83pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
84 L: MatRef<'_, C>,
85 U: MatRef<'_, C>,
86 row_perm: PermRef<'_, I>,
87 rhs: MatMut<'_, T>,
88 par: Par,
89 stack: &mut MemStack,
90) {
91 solve_in_place_with_conj(L.canonical(), U.canonical(), row_perm, Conj::get::<C>(), rhs, par, stack)
92}
93
94#[track_caller]
95pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
96 L: MatRef<'_, C>,
97 U: MatRef<'_, C>,
98 row_perm: PermRef<'_, I>,
99 rhs: MatMut<'_, T>,
100 par: Par,
101 stack: &mut MemStack,
102) {
103 solve_transpose_in_place_with_conj(L.canonical(), U.canonical(), row_perm, Conj::get::<C>(), rhs, par, stack)
104}
105
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::partial_pivoting::*;

	/// Factors a random complex matrix and checks all four solve variants
	/// (plain, transpose, conjugate, adjoint) against the original matrix.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut LU = A.to_owned();
		let row_perm_fwd = &mut *vec![0usize; n];
		let row_perm_bwd = &mut *vec![0usize; n];

		let row_perm = factor::lu_in_place(
			LU.as_mut(),
			row_perm_fwd,
			row_perm_bwd,
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default())) }),
			default(),
		)
		.1;

		let approx_eq = CwiseMat(ApproxEq::eps() * 8.0 * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(&A * &X ~ B);
		}
		// A^T X = B: size the workspace with the transpose-specific scratch query.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.transpose() * &X ~ B);
		}
		// conj(A) X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.conjugate() * &X ~ B);
		}
		// A^H X = B: size the workspace with the transpose-specific scratch query.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}