1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::{permute_rows_in_place, permute_rows_in_place_scratch};
4
5pub fn solve_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
6 _ = par;
7 permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
8}
9
10pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(LU_dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
11 _ = par;
12 permute_rows_in_place_scratch::<I, T>(LU_dim, rhs_ncols)
13}
14
15#[track_caller]
16pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
17 L: MatRef<'_, T>,
18 U: MatRef<'_, T>,
19 row_perm: PermRef<'_, I>,
20 col_perm: PermRef<'_, I>,
21 conj_LU: Conj,
22 rhs: MatMut<'_, T>,
23 par: Par,
24 stack: &mut MemStack,
25) {
26 let n = L.nrows();
31
32 assert!(all(
33 L.nrows() == n,
34 L.ncols() == n,
35 U.nrows() == n,
36 U.ncols() == n,
37 row_perm.len() == n,
38 col_perm.len() == n,
39 rhs.nrows() == n,
40 ));
41
42 let mut rhs = rhs;
43 permute_rows_in_place(rhs.rb_mut(), row_perm, stack);
44 linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(L, conj_LU, rhs.rb_mut(), par);
45 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(U, conj_LU, rhs.rb_mut(), par);
46 permute_rows_in_place(rhs.rb_mut(), col_perm.inverse(), stack);
47}
48
49#[track_caller]
50pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
51 L: MatRef<'_, T>,
52 U: MatRef<'_, T>,
53 row_perm: PermRef<'_, I>,
54 col_perm: PermRef<'_, I>,
55 conj_LU: Conj,
56 rhs: MatMut<'_, T>,
57 par: Par,
58 stack: &mut MemStack,
59) {
60 let n = L.nrows();
65
66 assert!(all(
67 L.nrows() == n,
68 L.ncols() == n,
69 U.nrows() == n,
70 U.ncols() == n,
71 row_perm.len() == n,
72 col_perm.len() == n,
73 rhs.nrows() == n,
74 ));
75
76 let mut rhs = rhs;
77 permute_rows_in_place(rhs.rb_mut(), col_perm, stack);
78 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(U.transpose(), conj_LU, rhs.rb_mut(), par);
79 linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(L.transpose(), conj_LU, rhs.rb_mut(), par);
80 permute_rows_in_place(rhs.rb_mut(), row_perm.inverse(), stack);
81}
82
83#[track_caller]
84pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
85 L: MatRef<'_, C>,
86 U: MatRef<'_, C>,
87 row_perm: PermRef<'_, I>,
88 col_perm: PermRef<'_, I>,
89 rhs: MatMut<'_, T>,
90 par: Par,
91 stack: &mut MemStack,
92) {
93 solve_in_place_with_conj(L.canonical(), U.canonical(), row_perm, col_perm, Conj::get::<C>(), rhs, par, stack)
94}
95
96#[track_caller]
97pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
98 L: MatRef<'_, C>,
99 U: MatRef<'_, C>,
100 row_perm: PermRef<'_, I>,
101 col_perm: PermRef<'_, I>,
102 rhs: MatMut<'_, T>,
103 par: Par,
104 stack: &mut MemStack,
105) {
106 solve_transpose_in_place_with_conj(L.canonical(), U.canonical(), row_perm, col_perm, Conj::get::<C>(), rhs, par, stack)
107}
108
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::lu::full_pivoting::*;

	/// Factors a random complex `n × n` matrix with full pivoting, then checks all four solve
	/// variants (plain, transpose, conjugate, adjoint) against the original matrix.
	///
	/// Each variant allocates its workspace from the scratch query that matches the solver it
	/// calls, so a future divergence between the two queries cannot cause under-allocation.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut LU = A.to_owned();
		let row_perm_fwd = &mut *vec![0usize; n];
		let row_perm_bwd = &mut *vec![0usize; n];
		let col_perm_fwd = &mut *vec![0usize; n];
		let col_perm_bwd = &mut *vec![0usize; n];

		let (_, row_perm, col_perm) = factor::lu_in_place(
			LU.as_mut(),
			row_perm_fwd,
			row_perm_bwd,
			col_perm_fwd,
			col_perm_bwd,
			Par::Seq,
			MemStack::new(&mut { MemBuffer::new(factor::lu_in_place_scratch::<usize, c64>(n, n, Par::Seq, default())) }),
			default(),
		);

		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(&A * &X ~ B);
		}

		// Aᵀ X = B
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.as_ref(),
				LU.as_ref(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.transpose() * &X ~ B);
		}

		// conj(A) X = B
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.conjugate() * &X ~ B);
		}

		// Aᴴ X = B
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				LU.conjugate(),
				LU.conjugate(),
				row_perm,
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, k, Par::Seq))),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}