faer/linalg/qr/col_pivoting/
solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_lstsq_in_place_scratch<I: Index, T: ComplexField>(
5	qr_nrows: usize,
6	qr_ncols: usize,
7	qr_blocksize: usize,
8	rhs_ncols: usize,
9	par: Par,
10) -> StackReq {
11	StackReq::and(
12		linalg::qr::no_pivoting::solve::solve_lstsq_in_place_scratch::<T>(qr_nrows, qr_ncols, qr_blocksize, rhs_ncols, par),
13		crate::perm::permute_rows_in_place_scratch::<I, T>(qr_ncols, rhs_ncols),
14	)
15}
16
17pub fn solve_in_place_scratch<I: Index, T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
18	solve_lstsq_in_place_scratch::<I, T>(qr_dim, qr_dim, qr_blocksize, rhs_ncols, par)
19}
20
21pub fn solve_transpose_in_place_scratch<I: Index, T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
22	StackReq::and(
23		linalg::qr::no_pivoting::solve::solve_transpose_in_place_scratch::<T>(qr_dim, qr_blocksize, rhs_ncols, par),
24		crate::perm::permute_rows_in_place_scratch::<I, T>(qr_dim, rhs_ncols),
25	)
26}
27
28#[track_caller]
29pub fn solve_lstsq_in_place_with_conj<I: Index, T: ComplexField>(
30	Q_basis: MatRef<'_, T>,
31	Q_coeff: MatRef<'_, T>,
32	R: MatRef<'_, T>,
33	col_perm: PermRef<'_, I>,
34	conj_QR: Conj,
35	rhs: MatMut<'_, T>,
36	par: Par,
37	stack: &mut MemStack,
38) {
39	let m = Q_basis.nrows();
40	let n = Q_basis.ncols();
41	let size = Ord::min(m, n);
42	let mut rhs = rhs;
43
44	linalg::qr::no_pivoting::solve::solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs.rb_mut(), par, stack);
45
46	crate::perm::permute_rows_in_place(rhs.get_mut(..size, ..), col_perm.inverse(), stack);
47}
48
49#[track_caller]
50pub fn solve_lstsq_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
51	Q_basis: MatRef<'_, C>,
52	Q_coeff: MatRef<'_, C>,
53	R: MatRef<'_, C>,
54	col_perm: PermRef<'_, I>,
55	rhs: MatMut<'_, T>,
56	par: Par,
57	stack: &mut MemStack,
58) {
59	solve_lstsq_in_place_with_conj(
60		Q_basis.canonical(),
61		Q_coeff.canonical(),
62		R.canonical(),
63		col_perm,
64		Conj::get::<C>(),
65		rhs,
66		par,
67		stack,
68	);
69}
70
71#[track_caller]
72pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
73	Q_basis: MatRef<'_, T>,
74	Q_coeff: MatRef<'_, T>,
75	R: MatRef<'_, T>,
76	col_perm: PermRef<'_, I>,
77	conj_QR: Conj,
78	rhs: MatMut<'_, T>,
79	par: Par,
80	stack: &mut MemStack,
81) {
82	let n = Q_basis.nrows();
83	let blocksize = Q_coeff.nrows();
84
85	assert!(all(
86		blocksize > 0,
87		rhs.nrows() == n,
88		Q_basis.nrows() == n,
89		Q_basis.ncols() == n,
90		Q_coeff.ncols() == n,
91		R.nrows() == n,
92		R.ncols() == n,
93	));
94
95	solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, col_perm, conj_QR, rhs, par, stack);
96}
97
98#[track_caller]
99pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
100	Q_basis: MatRef<'_, C>,
101	Q_coeff: MatRef<'_, C>,
102	R: MatRef<'_, C>,
103	col_perm: PermRef<'_, I>,
104	rhs: MatMut<'_, T>,
105	par: Par,
106	stack: &mut MemStack,
107) {
108	solve_in_place_with_conj(
109		Q_basis.canonical(),
110		Q_coeff.canonical(),
111		R.canonical(),
112		col_perm,
113		Conj::get::<C>(),
114		rhs,
115		par,
116		stack,
117	);
118}
119
120#[track_caller]
121pub fn solve_transpose_in_place_with_conj<I: Index, T: ComplexField>(
122	Q_basis: MatRef<'_, T>,
123	Q_coeff: MatRef<'_, T>,
124	R: MatRef<'_, T>,
125	col_perm: PermRef<'_, I>,
126	conj_QR: Conj,
127	rhs: MatMut<'_, T>,
128	par: Par,
129	stack: &mut MemStack,
130) {
131	let n = Q_basis.nrows();
132	let blocksize = Q_coeff.nrows();
133
134	assert!(all(
135		blocksize > 0,
136		rhs.nrows() == n,
137		Q_basis.nrows() == n,
138		Q_basis.ncols() == n,
139		Q_coeff.ncols() == n,
140		R.nrows() == n,
141		R.ncols() == n,
142	));
143
144	let mut rhs = rhs;
145
146	crate::perm::permute_rows_in_place(rhs.rb_mut(), col_perm, stack);
147	linalg::qr::no_pivoting::solve::solve_transpose_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs, par, stack);
148}
149
150#[track_caller]
151pub fn solve_transpose_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
152	Q_basis: MatRef<'_, C>,
153	Q_coeff: MatRef<'_, C>,
154	R: MatRef<'_, C>,
155	col_perm: PermRef<'_, I>,
156	rhs: MatMut<'_, T>,
157	par: Par,
158	stack: &mut MemStack,
159) {
160	solve_transpose_in_place_with_conj(
161		Q_basis.canonical(),
162		Q_coeff.canonical(),
163		R.canonical(),
164		col_perm,
165		Conj::get::<C>(),
166		rhs,
167		par,
168		stack,
169	);
170}
171
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::qr::col_pivoting::*;

	/// Least-squares solve on a random tall complex matrix, verified through the normal
	/// equations for both the plain and the conjugated factors.
	#[test]
	fn test_lstsq() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (m, n, k) = (100, 50, 3);
		let blocksize = 4;

		// `A` is drawn before `B` so the random stream is consumed in a fixed order
		let A = CwiseMatDistribution {
			nrows: m,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: m,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut QR = A.to_owned();
		let mut Q_coeff = Mat::zeros(blocksize, n);

		let mut perm_fwd = vec![0usize; n];
		let mut perm_bwd = vec![0usize; n];

		let mut factor_mem = MemBuffer::new(factor::qr_in_place_scratch::<usize, c64>(m, n, blocksize, Par::Seq, default()));
		let (_, col_perm) = factor::qr_in_place(
			QR.as_mut(),
			Q_coeff.as_mut(),
			perm_fwd.as_mut_slice(),
			perm_bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut factor_mem),
			default(),
		);

		// tolerance picked up by the `~` comparisons below
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// factors used as stored
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_lstsq_in_place_scratch::<usize, c64>(m, n, blocksize, k, Par::Seq));
			solve::solve_lstsq_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			// the solution occupies the first `n` rows
			let X = X.get(..n, ..);

			assert!(A.adjoint() * &A * &X ~ A.adjoint() * &B);
		}

		// conjugated factors
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_lstsq_in_place_scratch::<usize, c64>(m, n, blocksize, k, Par::Seq));
			solve::solve_lstsq_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			// the solution occupies the first `n` rows
			let X = X.get(..n, ..);
			assert!(A.transpose() * A.conjugate() * &X ~ A.transpose() * &B);
		}
	}

	/// Square solves on a random complex matrix: plain, conjugated, transposed and adjoint.
	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);
		let blocksize = 4;

		// `A` is drawn before `B` so the random stream is consumed in a fixed order
		let A = CwiseMatDistribution {
			nrows: n,
			ncols: n,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let B = CwiseMatDistribution {
			nrows: n,
			ncols: k,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng);

		let mut QR = A.to_owned();
		let mut Q_coeff = Mat::zeros(blocksize, n);

		let mut perm_fwd = vec![0usize; n];
		let mut perm_bwd = vec![0usize; n];

		let mut factor_mem = MemBuffer::new(factor::qr_in_place_scratch::<usize, c64>(n, n, blocksize, Par::Seq, default()));
		let (_, col_perm) = factor::qr_in_place(
			QR.as_mut(),
			Q_coeff.as_mut(),
			perm_fwd.as_mut_slice(),
			perm_bwd.as_mut_slice(),
			Par::Seq,
			MemStack::new(&mut factor_mem),
			default(),
		);

		// tolerance picked up by the `~` comparisons below
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));

		// A X = B
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, blocksize, k, Par::Seq));
			solve::solve_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			assert!(&A * &X ~ B);
		}

		// conj(A) X = B
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_in_place_scratch::<usize, c64>(n, blocksize, k, Par::Seq));
			solve::solve_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			assert!(A.conjugate() * &X ~ B);
		}

		// Aᵀ X = B
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, blocksize, k, Par::Seq));
			solve::solve_transpose_in_place(
				QR.as_ref(),
				Q_coeff.as_ref(),
				QR.as_ref(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			assert!(A.transpose() * &X ~ B);
		}

		// Aᴴ X = B
		{
			let mut X = B.to_owned();
			let mut solve_mem = MemBuffer::new(solve::solve_transpose_in_place_scratch::<usize, c64>(n, blocksize, k, Par::Seq));
			solve::solve_transpose_in_place(
				QR.conjugate(),
				Q_coeff.conjugate(),
				QR.conjugate(),
				col_perm,
				X.as_mut(),
				Par::Seq,
				MemStack::new(&mut solve_mem),
			);

			assert!(A.adjoint() * &X ~ B);
		}
	}
}
386}