faer/linalg/qr/no_pivoting/
solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3
4pub fn solve_lstsq_in_place_scratch<T: ComplexField>(qr_nrows: usize, qr_ncols: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
5	_ = qr_ncols;
6	_ = par;
7	linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(qr_nrows, qr_blocksize, rhs_ncols)
8}
9
10pub fn solve_in_place_scratch<T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
11	solve_lstsq_in_place_scratch::<T>(qr_dim, qr_dim, qr_blocksize, rhs_ncols, par)
12}
13
14pub fn solve_transpose_in_place_scratch<T: ComplexField>(qr_dim: usize, qr_blocksize: usize, rhs_ncols: usize, par: Par) -> StackReq {
15	_ = par;
16	linalg::householder::apply_block_householder_sequence_on_the_left_in_place_scratch::<T>(qr_dim, qr_blocksize, rhs_ncols)
17}
18
19#[track_caller]
20pub fn solve_lstsq_in_place_with_conj<T: ComplexField>(
21	Q_basis: MatRef<'_, T>,
22	Q_coeff: MatRef<'_, T>,
23	R: MatRef<'_, T>,
24	conj_QR: Conj,
25	rhs: MatMut<'_, T>,
26	par: Par,
27	stack: &mut MemStack,
28) {
29	let m = Q_basis.nrows();
30	let n = Q_basis.ncols();
31	let size = Ord::min(m, n);
32	let blocksize = Q_coeff.nrows();
33	assert!(all(
34		blocksize > 0,
35		rhs.nrows() == m,
36		Q_basis.nrows() >= Q_basis.ncols(),
37		Q_coeff.ncols() == size,
38		R.nrows() >= size,
39		R.ncols() == n,
40	));
41
42	let mut rhs = rhs;
43	let mut stack = stack;
44	linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
45		Q_basis,
46		Q_coeff,
47		conj_QR.compose(Conj::Yes),
48		rhs.rb_mut(),
49		par,
50		stack.rb_mut(),
51	);
52
53	linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(R.get(..size, ..), conj_QR, rhs.subrows_mut(0, size), par);
54}
55
56#[track_caller]
57pub fn solve_lstsq_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
58	Q_basis: MatRef<'_, C>,
59	Q_coeff: MatRef<'_, C>,
60	R: MatRef<'_, C>,
61	rhs: MatMut<'_, T>,
62	par: Par,
63	stack: &mut MemStack,
64) {
65	solve_lstsq_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
66}
67
68#[track_caller]
69pub fn solve_in_place_with_conj<T: ComplexField>(
70	Q_basis: MatRef<'_, T>,
71	Q_coeff: MatRef<'_, T>,
72	R: MatRef<'_, T>,
73	conj_QR: Conj,
74	rhs: MatMut<'_, T>,
75	par: Par,
76	stack: &mut MemStack,
77) {
78	let n = Q_basis.nrows();
79	let blocksize = Q_coeff.nrows();
80	assert!(all(
81		blocksize > 0,
82		rhs.nrows() == n,
83		Q_basis.nrows() == n,
84		Q_basis.ncols() == n,
85		Q_coeff.ncols() == n,
86		R.nrows() == n,
87		R.ncols() == n,
88	));
89
90	solve_lstsq_in_place_with_conj(Q_basis, Q_coeff, R, conj_QR, rhs, par, stack);
91}
92
93#[track_caller]
94pub fn solve_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
95	Q_basis: MatRef<'_, C>,
96	Q_coeff: MatRef<'_, C>,
97	R: MatRef<'_, C>,
98	rhs: MatMut<'_, T>,
99	par: Par,
100	stack: &mut MemStack,
101) {
102	solve_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
103}
104
105#[track_caller]
106pub fn solve_transpose_in_place_with_conj<T: ComplexField>(
107	Q_basis: MatRef<'_, T>,
108	Q_coeff: MatRef<'_, T>,
109	R: MatRef<'_, T>,
110	conj_QR: Conj,
111	rhs: MatMut<'_, T>,
112	par: Par,
113	stack: &mut MemStack,
114) {
115	let n = Q_basis.nrows();
116	let blocksize = Q_coeff.nrows();
117
118	assert!(all(
119		blocksize > 0,
120		rhs.nrows() == n,
121		Q_basis.nrows() == n,
122		Q_basis.ncols() == n,
123		Q_coeff.ncols() == n,
124		R.nrows() == n,
125		R.ncols() == n,
126	));
127
128	let mut rhs = rhs;
129	let mut stack = stack;
130
131	linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(R.transpose(), conj_QR, rhs.rb_mut(), par);
132	linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj(
133		Q_basis,
134		Q_coeff,
135		conj_QR.compose(Conj::Yes),
136		rhs.rb_mut(),
137		par,
138		stack.rb_mut(),
139	);
140}
141
142#[track_caller]
143pub fn solve_transpose_in_place<T: ComplexField, C: Conjugate<Canonical = T>>(
144	Q_basis: MatRef<'_, C>,
145	Q_coeff: MatRef<'_, C>,
146	R: MatRef<'_, C>,
147	rhs: MatMut<'_, T>,
148	par: Par,
149	stack: &mut MemStack,
150) {
151	solve_transpose_in_place_with_conj(Q_basis.canonical(), Q_coeff.canonical(), R.canonical(), Conj::get::<C>(), rhs, par, stack);
152}
153
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use dyn_stack::MemBuffer;
	use linalg::qr::no_pivoting::*;

	/// Number of rows of the block Householder coefficient matrix used by every test.
	const BLOCKSIZE: usize = 4;

	/// Draws an `nrows x ncols` matrix element-wise from a complex distribution with
	/// standard-normal real and imaginary parts.
	fn random_c64(rng: &mut StdRng, nrows: usize, ncols: usize) -> Mat<c64> {
		CwiseMatDistribution {
			nrows,
			ncols,
			dist: ComplexDistribution::new(StandardNormal, StandardNormal),
		}
		.rand::<Mat<c64>>(rng)
	}

	/// Factors `a` with the unpivoted blocked QR. Returns the packed factorization and the
	/// block Householder coefficients.
	fn factorize(a: &Mat<c64>) -> (Mat<c64>, Mat<c64>) {
		let (nrows, ncols) = (a.nrows(), a.ncols());
		let mut qr = a.to_owned();
		let mut coeff = Mat::zeros(BLOCKSIZE, ncols);
		let mut mem = MemBuffer::new(factor::qr_in_place_scratch::<c64>(nrows, ncols, BLOCKSIZE, Par::Seq, default()));
		factor::qr_in_place(qr.as_mut(), coeff.as_mut(), Par::Seq, MemStack::new(&mut mem), default());
		(qr, coeff)
	}

	#[test]
	fn test_lstsq() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (m, n, k) = (100, 50, 3);

		let A = random_c64(rng, m, n);
		let B = random_c64(rng, m, k);
		let (QR, H) = factorize(&A);

		// `approx_eq` is picked up implicitly by the `~` comparisons below.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		let mut mem = MemBuffer::new(solve::solve_lstsq_in_place_scratch::<c64>(m, n, BLOCKSIZE, k, Par::Seq));

		// Factors of `A`: the solution must satisfy `Aᴴ A X = Aᴴ B`.
		{
			let mut X = B.to_owned();
			solve::solve_lstsq_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			let X = X.get(..n, ..);
			assert!(A.adjoint() * &A * &X ~ A.adjoint() * &B);
		}

		// Conjugated factors describe `conj(A)`: the solution must satisfy `Aᵀ conj(A) X = Aᵀ B`.
		{
			let mut X = B.to_owned();
			solve::solve_lstsq_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut mem));

			let X = X.get(..n, ..);
			assert!(A.transpose() * A.conjugate() * &X ~ A.transpose() * &B);
		}
	}

	#[test]
	fn test_solve() {
		let rng = &mut StdRng::seed_from_u64(0);
		let (n, k) = (50, 3);

		let A = random_c64(rng, n, n);
		let B = random_c64(rng, n, k);
		let (QR, H) = factorize(&A);

		// `approx_eq` is picked up implicitly by the `~` comparisons below.
		let approx_eq = CwiseMat(ApproxEq::eps() * (n as f64));
		let mut solve_mem = MemBuffer::new(solve::solve_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));
		let mut transpose_mem = MemBuffer::new(solve::solve_transpose_in_place_scratch::<c64>(n, BLOCKSIZE, k, Par::Seq));

		// `A X = B`
		{
			let mut X = B.to_owned();
			solve::solve_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
			assert!(&A * &X ~ B);
		}

		// `conj(A) X = B`
		{
			let mut X = B.to_owned();
			solve::solve_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut solve_mem));
			assert!(A.conjugate() * &X ~ B);
		}

		// `Aᵀ X = B`
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(QR.as_ref(), H.as_ref(), QR.as_ref(), X.as_mut(), Par::Seq, MemStack::new(&mut transpose_mem));
			assert!(A.transpose() * &X ~ B);
		}

		// `Aᴴ X = B`
		{
			let mut X = B.to_owned();
			solve::solve_transpose_in_place(QR.conjugate(), H.conjugate(), QR.conjugate(), X.as_mut(), Par::Seq, MemStack::new(&mut transpose_mem));
			assert!(A.adjoint() * &X ~ B);
		}
	}
}