1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_lstsq_in_place_scratch<T: ComplexField>(qr_nrows: usize, qr_ncols: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
5 _ = qr_ncols;
6 _ = par;
7 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(qr_nrows, qr_blocksize, rhs_ncols)
8}
9
10pub fn solve_in_place_scratch<T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
11 solve_lstsq_in_place_scratch::<T>(qr_dim, qr_dim, qr_blocksize, rhs_ncols, par)
12}
13
14pub fn solve_transpose_in_place_scratch<T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
15 _ = par;
16 linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(qr_dim, qr_blocksize, rhs_ncols)
17}
18
19#[track_caller]
20pub fn solve_lstsq_in_place_with_conj<T: ComplexField>(
21 Q_basis: MatRef<'_, T>,
22 Q_coeff: MatRef<'_, T>,
23 R: MatRef<'_, T>,
24 conj_QR: Conj,
25 rhs: MatMut<'_, T>,
26 par: Par,
27 stack: &mut MemStack,
28) {
29 let m = Q_basis.nrows();
30 let n = Q_basis.ncols();
31 let size = Ord::min(m, n);
32 let blocksize = Q_coeff.nrows();
33 assert!(all(
34 blocksize > 0,
35 rhs.nrows() == m,
36 Q_basis.nrows() >= Q_basis.ncols(),
37 Q_coeff.ncols() == size,
38 R.nrows() >= size,
39 R.ncols() == n,
40 ));
41
42 let mut rhs = rhs;
43 let mut stack = stack;
44 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
45 Q_basis,
46 Q_coeff,
47 conj_QR.compose(Conj::Yes),
48 rhs.rb_mut(),
49 par,
50 stack.rb_mut(),
51 );
52
53 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(R.get(..size, ..), conj_QR, rhs.subrows_mut(0, size), par);
54}
55
56#[track_caller]
57pub fn solve_lstsq_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
58 Q_basis: MatRef<'_, C>,
59 Q_coeff: MatRef<'_, C>,
60 R: MatRef<'_, C>,
61 rhs: MatMut<'_, T>,
62 par: Par,
63 stack: &mut MemStack,
64) {
65 solve_lstsq_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
66}
67
68#[track_caller]
69pub fn solve_in_place_with_conj<T: ComplexField>(
70 Q_basis: MatRef<'_, T>,
71 Q_coeff: MatRef<'_, T>,
72 R: MatRef<'_, T>,
73 conj_QR: Conj,
74 rhs: MatMut<'_, T>,
75 par: Par,
76 stack: &mut MemStack,
77) {
78 let n = Q_basis.nrows();
79 let blocksize = Q_coeff.nrows();
80 assert!(all(
81 blocksize > 0,
82 rhs.nrows() == n,
83 Q_basis.nrows() == n,
84 Q_basis.ncols() == n,
85 Q_coeff.ncols() == n,
86 R.nrows() == n,
87 R.ncols() == n,
88 ));
89
90 solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs, par, stack);
91}
92
93#[track_caller]
94pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
95 Q_basis: MatRef<'_, C>,
96 Q_coeff: MatRef<'_, C>,
97 R: MatRef<'_, C>,
98 rhs: MatMut<'_, T>,
99 par: Par,
100 stack: &mut MemStack,
101) {
102 solve_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
103}
104
105#[track_caller]
106pub fn solve_transpose_in_place_with_conj<T: ComplexField>(
107 Q_basis: MatRef<'_, T>,
108 Q_coeff: MatRef<'_, T>,
109 R: MatRef<'_, T>,
110 conj_QR: Conj,
111 rhs: MatMut<'_, T>,
112 par: Par,
113 stack: &mut MemStack,
114) {
115 let n = Q_basis.nrows();
116 let blocksize = Q_coeff.nrows();
117
118 assert!(all(
119 blocksize > 0,
120 rhs.nrows() == n,
121 Q_basis.nrows() == n,
122 Q_basis.ncols() == n,
123 Q_coeff.ncols() == n,
124 R.nrows() == n,
125 R.ncols() == n,
126 ));
127
128 let mut rhs = rhs;
129 let mut stack = stack;
130
131 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(R.transpose(), conj_QR, rhs.rb_mut(), par);
132 linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
133 Q_basis,
134 Q_coeff,
135 conj_QR.compose(Conj::Yes),
136 rhs.rb_mut(),
137 par,
138 stack.rb_mut(),
139 );
140}
141
142#[track_caller]
143pub fn solve_transpose_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
144 Q_basis: MatRef<'_, C>,
145 Q_coeff: MatRef<'_, C>,
146 R: MatRef<'_, C>,
147 rhs: MatMut<'_, T>,
148 par: Par,
149 stack: &mut MemStack,
150) {
151 solve_transpose_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
152}
153
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::qr::no_pivoting::*;

	/// Householder block size used by every factorization in this module.
	const BLOCKSIZE: usize = 4;

	/// Samples an `nrows × ncols` matrix with i.i.d. standard complex normal entries.
	fn random_mat(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng)
	}

	/// Factors `A` with the unpivoted blocked QR, returning the packed factors and the
	/// Householder block coefficients.
	fn qr_factor(A: &Mat<c64>) -> (Mat<c64>, Mat<c64>) {
		let (m, n) = (A.nrows(), A.ncols());
		let mut QR = A.to_owned();
		let mut H = Mat::zeros(BLOCKSIZE, n);

		let mut mem = MemBuffer::new(factor::qr_in_place_scratch::<c64>(m, n, BLOCKSIZE, Par::Seq, default()));
		factor::qr_in_place(QR.as_mut(), H.as_mut(), Par::Seq, MemStack::new(&mut mem), default());

		(QR, H)
	}

	#[test]
	fn test_lstsq() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (m, n, k) = (100, 50, 3);

		let A = random_mat(rng, m, n);
		let B = random_mat(rng, m, k);
		let (QR, H) = qr_factor(&A);

		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// The least-squares solution of A X ≈ B satisfies Aᴴ A X = Aᴴ B.
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_lstsq_in_place_scratch::<c64>(m, n, BLOCKSIZE, k, Par::Seq));
			solve::solve_lstsq_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			let X = X.get(..n, ..);

			assert!(A.adjoint() * &A * &X ~ A.adjoint() * &B);
		}

		// Conjugated factors solve conj(A) X ≈ B, whose normal equations are Aᵀ conj(A) X = Aᵀ B.
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_lstsq_in_place_scratch::<c64>(m, n, BLOCKSIZE, k, Par::Seq));
			solve::solve_lstsq_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			let X = X.get(..n, ..);
			assert!(A.transpose() * A.conjugate() * &X ~ A.transpose() * &B);
		}
	}

	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);

		let A = random_mat(rng, n, n);
		let B = random_mat(rng, n, k);
		let (QR, H) = qr_factor(&A);

		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));
			solve::solve_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			assert!(&A * &X ~ B);
		}

		// conj(A) X = B
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));
			solve::solve_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			assert!(A.conjugate() * &X ~ B);
		}

		// Aᵀ X = B
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_transpose_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));
			solve::solve_transpose_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			assert!(A.transpose() * &X ~ B);
		}

		// Aᴴ X = B
		{
			let mut X = B.to_owned();
			let mut mem = MemBuffer::new(solve::solve_transpose_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));
			solve::solve_transpose_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			assert!(A.adjoint() * &X ~ B);
		}
	}
}