1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_lstsq_in_place_scratch<I: Index, T: ComplexField>(
5 qr_nrows: usize,
6 qr_ncols: usize,
7 qr_blocksize: usize,
8 rhs_ncols: usize,
9 par: Par,
10) -> StackReq {
11 StackReq::and(
12 linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(qr_nrows, qr_ncols, qr_blocksize, rhs_ncols, par),
13 crate::perm::permute_rows_in_place_scratch::<I, T>(qr_ncols, rhs_ncols),
14 )
15}
16
17pub fn solve_in_place_scratch<I: Index, T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
18 solve_lstsq_in_place_scratch::<I, T>(qr_dim, qr_dim, qr_blocksize, rhs_ncols, par)
19}
20
21pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
22 StackReq::and(
23 linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(qr_dim, qr_blocksize, rhs_ncols, par),
24 crate::perm::permute_rows_in_place_scratch::<I, T>(qr_dim, rhs_ncols),
25 )
26}
27
28#[track_caller]
29pub fn solve_lstsq_in_place_with_conj<I: Index, T: ComplexField>(
30 Q_basis: MatRef<'_, T>,
31 Q_coeff: MatRef<'_, T>,
32 R: MatRef<'_, T>,
33 col_perm: PermRef<'_, I>,
34 conj_QR: Conj,
35 rhs: MatMut<'_, T>,
36 par: Par,
37 stack: &mut MemStack,
38) {
39 let m = Q_basis.nrows();
40 let n = Q_basis.ncols();
41 let size = Ord::min(m, n);
42 let mut rhs = rhs;
43
44 linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs.rb_mut(), par, stack);
45
46 crate::perm::permute_rows_in_place(rhs.get_mut(..size, ..), col_perm.inverse(), stack);
47}
48
49#[track_caller]
50pub fn solve_lstsq_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
51 Q_basis: MatRef<'_, C>,
52 Q_coeff: MatRef<'_, C>,
53 R: MatRef<'_, C>,
54 col_perm: PermRef<'_, I>,
55 rhs: MatMut<'_, T>,
56 par: Par,
57 stack: &mut MemStack,
58) {
59 solve_lstsq_in_place_with_conj(
60 Q_basis.canonical(),
61 Q_coeff.canonical(),
62 R.canonical(),
63 col_perm,
64 Conj::get::<C>(),
65 rhs,
66 par,
67 stack,
68 );
69}
70
71#[track_caller]
72pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
73 Q_basis: MatRef<'_, T>,
74 Q_coeff: MatRef<'_, T>,
75 R: MatRef<'_, T>,
76 col_perm: PermRef<'_, I>,
77 conj_QR: Conj,
78 rhs: MatMut<'_, T>,
79 par: Par,
80 stack: &mut MemStack,
81) {
82 let n = Q_basis.nrows();
83 let blocksize = Q_coeff.nrows();
84
85 assert!(all(
86 blocksize > 0,
87 rhs.nrows() == n,
88 Q_basis.nrows() == n,
89 Q_basis.ncols() == n,
90 Q_coeff.ncols() == n,
91 R.nrows() == n,
92 R.ncols() == n,
93 ));
94
95 solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, col_perm, conj_QR, rhs, par, stack);
96}
97
98#[track_caller]
99pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
100 Q_basis: MatRef<'_, C>,
101 Q_coeff: MatRef<'_, C>,
102 R: MatRef<'_, C>,
103 col_perm: PermRef<'_, I>,
104 rhs: MatMut<'_, T>,
105 par: Par,
106 stack: &mut MemStack,
107) {
108 solve_in_place_with_conj(
109 Q_basis.canonical(),
110 Q_coeff.canonical(),
111 R.canonical(),
112 col_perm,
113 Conj::get::<C>(),
114 rhs,
115 par,
116 stack,
117 );
118}
119
120#[track_caller]
121pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
122 Q_basis: MatRef<'_, T>,
123 Q_coeff: MatRef<'_, T>,
124 R: MatRef<'_, T>,
125 col_perm: PermRef<'_, I>,
126 conj_QR: Conj,
127 rhs: MatMut<'_, T>,
128 par: Par,
129 stack: &mut MemStack,
130) {
131 let n = Q_basis.nrows();
132 let blocksize = Q_coeff.nrows();
133
134 assert!(all(
135 blocksize > 0,
136 rhs.nrows() == n,
137 Q_basis.nrows() == n,
138 Q_basis.ncols() == n,
139 Q_coeff.ncols() == n,
140 R.nrows() == n,
141 R.ncols() == n,
142 ));
143
144 let mut rhs = rhs;
145
146 crate::perm::permute_rows_in_place(rhs.rb_mut(), col_perm, stack);
147 linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs, par, stack);
148}
149
150#[track_caller]
151pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
152 Q_basis: MatRef<'_, C>,
153 Q_coeff: MatRef<'_, C>,
154 R: MatRef<'_, C>,
155 col_perm: PermRef<'_, I>,
156 rhs: MatMut<'_, T>,
157 par: Par,
158 stack: &mut MemStack,
159) {
160 solve_transpose_in_place_with_conj(
161 Q_basis.canonical(),
162 Q_coeff.canonical(),
163 R.canonical(),
164 col_perm,
165 Conj::get::<C>(),
166 rhs,
167 par,
168 stack,
169 );
170}
171
// Tests for the column-pivoted QR solvers.
// Each test:
// - draws a random complex matrix `A` and right-hand side `B` from an RNG seeded with 0,
//   so runs are reproducible;
// - factors `A` in place with `factor::qr_in_place`, using a block size of 4;
// - checks each solver's output against products with `A`.
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::qr::col_pivoting::*;

	// Over-determined least squares: `A` is `m x n` with `m = 100 > n = 50`, and there are
	// `k = 3` right-hand sides. The least-squares solution satisfies the normal equations,
	// which is what the assertions check.
	#[test]
	fn test_lstsq() {
		let rng = &mut StdRng::seed_from_u64(0);
		let m = 100;
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: m,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: m,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// `QR` is overwritten in place by the factorization and later passed as both the
		// `Q_basis` and `R` arguments. `Q_coeff` has 4 rows, matching the block size `4`
		// passed to the scratch queries below.
		let mut QR = A.to_owned();
		let mut Q_coeff = Mat::zeros(4, n);

		// Backing storage for the forward/backward column permutation indices, borrowed by
		// the returned `col_perm`.
		let col_perm_fwd = &mut *vec![0usize; n];
		let col_perm_bwd = &mut *vec![0usize; n];

		let (_, col_perm) = factor::qr_in_place(
			QR.as_mut(),
			Q_coeff.as_mut(),
			col_perm_fwd,
			col_perm_bwd,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(factor::qr_in_place_scratch::<usize, c64>(
				m,
				n,
				4,
				Par::Seq,
				default(),
			))),
			default(),
		);

		// Comparator used by the `~` comparisons in `assert!` below. The tolerance is
		// `eps * n`.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// Plain factors: the solution satisfies A^H A X = A^H B.
		{
			let mut X = B.to_owned();
			solve::solve_lstsq_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_lstsq_in_place_scratch::<usize, c64>(
					m,
					n,
					4,
					k,
					Par::Seq,
				))),
			);

			// Only the leading `n` rows hold the solution.
			let X = X.get(..n, ..);

			assert!(A.adjoint() * &A * &X ~ A.adjoint() * &B);
		}

		// Conjugated factor views describe conj(A), whose normal equations are
		// A^T conj(A) X = A^T B.
		{
			let mut X = B.to_owned();
			solve::solve_lstsq_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_lstsq_in_place_scratch::<usize, c64>(
					m,
					n,
					4,
					k,
					Par::Seq,
				))),
			);

			let X = X.get(..n, ..);
			assert!(A.transpose() * A.conjugate() * &X ~ A.transpose() * &B);
		}
	}

	// Square systems (`n = 50`, `k = 3` right-hand sides). This covers the plain, conjugated,
	// transposed and adjoint solves, each checked by multiplying back and comparing with `B`.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let n = 50;
		let k = 3;

		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		// Same layout as in `test_lstsq`: packed factors in `QR`, and 4 rows of block
		// coefficients in `Q_coeff`.
		let mut QR = A.to_owned();
		let mut Q_coeff = Mat::zeros(4, n);

		let col_perm_fwd = &mut *vec![0usize; n];
		let col_perm_bwd = &mut *vec![0usize; n];

		let (_, col_perm) = factor::qr_in_place(
			QR.as_mut(),
			Q_coeff.as_mut(),
			col_perm_fwd,
			col_perm_bwd,
			Par::Seq,
			MemStack::new(&mut MemBuffer::new(factor::qr_in_place_scratch::<usize, c64>(
				n,
				n,
				4,
				Par::Seq,
				default(),
			))),
			default(),
		);

		// Comparator for the `~` comparisons below, with tolerance `eps * n`.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// A X = B.
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, 4, k, Par::Seq))),
			);

			assert!(&A * &X ~ B);
		}

		// conj(A) X = B, via conjugated factor views.
		{
			let mut X = B.to_owned();
			solve::solve_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, 4, k, Par::Seq))),
			);

			assert!(A.conjugate() * &X ~ B);
		}

		// A^T X = B.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(
					n,
					4,
					k,
					Par::Seq,
				))),
			);

			assert!(A.transpose() * &X ~ B);
		}

		// A^H X = B: the transposed solve with conjugated factor views.
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(
					n,
					4,
					k,
					Par::Seq,
				))),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}