faer/linalg/
triangular_solve.rs

1//! Triangular solve module.
2
3use crate::internal_prelude::*;
4use crate::utils::thread::join_raw;
5use crate::{assert, debug_assert};
6use faer_macros::math;
7use faer_traits::{Conjugate, SimdArch};
8use reborrow::*;
9
10#[inline(always)]
11#[math]
12fn identity<T: ComplexField>(x: T) -> T {
13	copy(x)
14}
15
16#[inline(always)]
17#[math]
18fn conjugate<T: ComplexField>(x: T) -> T {
19	conj(x)
20}
21
22#[inline(always)]
23#[math]
24fn solve_unit_lower_triangular_in_place_base_case_generic_imp<'N, 'K, T: ComplexField>(
25	tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
26	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
27	maybe_conj_lhs: impl Fn(T) -> T,
28) {
29	let N = tril.nrows();
30	let n = N.unbound();
31
32	match n {
33		0 | 1 => (),
34		2 => {
35			let i0 = N.check(0);
36			let i1 = N.check(1);
37
38			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
39
40			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
41			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
42			_ = rhs;
43
44			zip!(x0, x1).for_each(|unzip!(x0, x1)| *x1 = *x1 + nl10_div_l11 * *x0);
45		},
46		3 => {
47			let i0 = N.check(0);
48			let i1 = N.check(1);
49			let i2 = N.check(2);
50
51			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
52			let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]);
53			let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]);
54
55			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
56			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
57			let (x2, rhs) = rhs.split_first_row_mut().unwrap();
58			_ = rhs;
59
60			zip!(x0, x1, x2).for_each(|unzip!(x0, x1, x2)| {
61				let y0 = copy(*x0);
62				let mut y1 = copy(*x1);
63				let mut y2 = copy(*x2);
64				y1 = y1 + nl10_div_l11 * y0;
65				y2 = y2 + nl20_div_l22 * y0 + nl21_div_l22 * y1;
66				*x0 = y0;
67				*x1 = y1;
68				*x2 = y2;
69			});
70		},
71		4 => {
72			let i0 = N.check(0);
73			let i1 = N.check(1);
74			let i2 = N.check(2);
75			let i3 = N.check(3);
76			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
77			let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]);
78			let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]);
79			let nl30_div_l33 = maybe_conj_lhs(-tril[(i3, i0)]);
80			let nl31_div_l33 = maybe_conj_lhs(-tril[(i3, i1)]);
81			let nl32_div_l33 = maybe_conj_lhs(-tril[(i3, i2)]);
82
83			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
84			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
85			let (x2, rhs) = rhs.split_first_row_mut().unwrap();
86			let (x3, rhs) = rhs.split_first_row_mut().unwrap();
87			_ = rhs;
88
89			zip!(x0, x1, x2, x3).for_each(|unzip!(x0, x1, x2, x3)| {
90				let y0 = copy(*x0);
91				let mut y1 = copy(*x1);
92				let mut y2 = copy(*x2);
93				let mut y3 = copy(*x3);
94				y1 = y1 + nl10_div_l11 * y0;
95				y2 = y2 + nl20_div_l22 * y0 + nl21_div_l22 * y1;
96				y3 = y3 + nl30_div_l33 * y0 + nl31_div_l33 * y1 + nl32_div_l33 * y2;
97				*x0 = y0;
98				*x1 = y1;
99				*x2 = y2;
100				*x3 = y3;
101			});
102		},
103		_ => unreachable!(),
104	}
105}
106
107#[inline(always)]
108#[math]
109fn solve_lower_triangular_in_place_base_case_generic_imp<'N, 'K, T: ComplexField>(
110	tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
111	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
112	maybe_conj_lhs: impl Fn(T) -> T,
113) {
114	let N = tril.nrows();
115	let n = N.unbound();
116
117	match n {
118		0 => (),
119		1 => {
120			let i0 = N.check(0);
121
122			let inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
123
124			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
125			_ = rhs;
126
127			zip!(x0).for_each(|unzip!(x0)| *x0 = *x0 * inv);
128		},
129		2 => {
130			let i0 = N.check(0);
131			let i1 = N.check(1);
132
133			let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
134			let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
135			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
136
137			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
138			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
139			_ = rhs;
140
141			zip!(x0, x1).for_each(|unzip!(x0, x1)| {
142				*x0 = *x0 * l00_inv;
143				*x1 = *x1 * l11_inv + nl10_div_l11 * x0;
144			});
145		},
146		3 => {
147			let i0 = N.check(0);
148			let i1 = N.check(1);
149			let i2 = N.check(2);
150
151			let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
152			let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
153			let l22_inv = maybe_conj_lhs(recip(tril[(i2, i2)]));
154			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
155			let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]) * l22_inv;
156			let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]) * l22_inv;
157
158			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
159			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
160			let (x2, rhs) = rhs.split_first_row_mut().unwrap();
161			_ = rhs;
162
163			zip!(x0, x1, x2).for_each(|unzip!(x0, x1, x2)| {
164				let mut y0 = copy(*x0);
165				let mut y1 = copy(*x1);
166				let mut y2 = copy(*x2);
167				y0 = y0 * l00_inv;
168				y1 = y1 * l11_inv + nl10_div_l11 * y0;
169				y2 = y2 * l22_inv + nl20_div_l22 * y0 + nl21_div_l22 * y1;
170				*x0 = y0;
171				*x1 = y1;
172				*x2 = y2;
173			});
174		},
175		4 => {
176			let i0 = N.check(0);
177			let i1 = N.check(1);
178			let i2 = N.check(2);
179			let i3 = N.check(3);
180
181			let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
182			let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
183			let l22_inv = maybe_conj_lhs(recip(tril[(i2, i2)]));
184			let l33_inv = maybe_conj_lhs(recip(tril[(i3, i3)]));
185			let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
186			let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]) * l22_inv;
187			let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]) * l22_inv;
188			let nl30_div_l33 = maybe_conj_lhs(-tril[(i3, i0)]) * l33_inv;
189			let nl31_div_l33 = maybe_conj_lhs(-tril[(i3, i1)]) * l33_inv;
190			let nl32_div_l33 = maybe_conj_lhs(-tril[(i3, i2)]) * l33_inv;
191
192			let (x0, rhs) = rhs.split_first_row_mut().unwrap();
193			let (x1, rhs) = rhs.split_first_row_mut().unwrap();
194			let (x2, rhs) = rhs.split_first_row_mut().unwrap();
195			let (x3, rhs) = rhs.split_first_row_mut().unwrap();
196			_ = rhs;
197
198			zip!(x0, x1, x2, x3).for_each(|unzip!(x0, x1, x2, x3)| {
199				let mut y0 = copy(*x0);
200				let mut y1 = copy(*x1);
201				let mut y2 = copy(*x2);
202				let mut y3 = copy(*x3);
203				y0 = y0 * l00_inv;
204				y1 = y1 * l11_inv + nl10_div_l11 * y0;
205				y2 = y2 * l22_inv + nl20_div_l22 * y0 + nl21_div_l22 * y1;
206				y3 = y3 * l33_inv + nl30_div_l33 * y0 + nl31_div_l33 * y1 + nl32_div_l33 * y2;
207				*x0 = y0;
208				*x1 = y1;
209				*x2 = y2;
210				*x3 = y3;
211			});
212		},
213		_ => unreachable!(),
214	}
215}
216
/// size of the leading block used when a system of dimension `n` is split in two
///
/// the trailing block starts out as `n / 2` and is rounded up to a multiple of 16, 8 or 4 once
/// `n` reaches 32, 16 or 8 respectively (no rounding below 8), so that the remainder is a
/// multiple of the register size. the return value is `n` minus that trailing block size.
#[inline]
fn blocksize(n: usize) -> usize {
	let half = n / 2;

	// granularity that the trailing block is rounded up to
	let align = match n {
		32.. => 16,
		16..=31 => 8,
		8..=15 => 4,
		_ => 1,
	};

	let tail = (half + (align - 1)) / align * align;
	n - tail
}
231
/// largest system dimension that is handed to the unrolled base case kernels instead of being
/// split recursively
#[inline]
fn recursion_threshold() -> usize {
	const MAX_BASE_CASE_DIM: usize = 4;
	MAX_BASE_CASE_DIM
}
236
237/// solves $L x = b$, implicitly conjugating $L$ if needed, and stores the result in `rhs`
238#[track_caller]
239#[inline]
240pub fn solve_lower_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
241	triangular_lower: MatRef<'_, T, N, N, impl Stride, impl Stride>,
242	conj_lhs: Conj,
243	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
244	par: Par,
245) {
246	assert!(all(
247		triangular_lower.nrows() == triangular_lower.ncols(),
248		rhs.nrows() == triangular_lower.ncols(),
249	));
250
251	make_guard!(N);
252	make_guard!(K);
253	let N = rhs.nrows().bind(N);
254	let K = rhs.ncols().bind(K);
255
256	solve_lower_triangular_in_place_imp(
257		triangular_lower.as_dyn_stride().as_shape(N, N),
258		conj_lhs,
259		rhs.as_dyn_stride_mut().as_shape_mut(N, K),
260		par,
261	);
262}
263
264/// solves $L x = b$, implicitly conjugating $L$ if needed, and stores the result in `rhs`
265#[inline]
266#[track_caller]
267pub fn solve_lower_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
268	triangular_lower: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
269	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
270	par: Par,
271) {
272	let tri = triangular_lower.canonical();
273	solve_lower_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
274}
275
276/// solves $L x = b$, replacing the diagonal of $L$ with ones, and implicitly conjugating $L$ if
277/// needed, and stores the result in `rhs`
278#[track_caller]
279#[inline]
280pub fn solve_unit_lower_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
281	triangular_unit_lower: MatRef<'_, T, N, N, impl Stride, impl Stride>,
282	conj_lhs: Conj,
283	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
284	par: Par,
285) {
286	assert!(all(
287		triangular_unit_lower.nrows() == triangular_unit_lower.ncols(),
288		rhs.nrows() == triangular_unit_lower.ncols(),
289	));
290
291	make_guard!(N);
292	make_guard!(K);
293	let N = rhs.nrows().bind(N);
294	let K = rhs.ncols().bind(K);
295
296	solve_unit_lower_triangular_in_place_imp(
297		triangular_unit_lower.as_dyn_stride().as_shape(N, N),
298		conj_lhs,
299		rhs.as_dyn_stride_mut().as_shape_mut(N, K),
300		par,
301	);
302}
303
304/// solves $L x = b$, replacing the diagonal of $L$ with ones, and implicitly conjugating $L$ if
305/// needed, and stores the result in `rhs`
306#[inline]
307#[track_caller]
308pub fn solve_unit_lower_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
309	triangular_unit_lower: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
310	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
311	par: Par,
312) {
313	let tri = triangular_unit_lower.canonical();
314	solve_unit_lower_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
315}
316
317/// solves $U x = b$, implicitly conjugating $U$ if needed, and stores the result in `rhs`
318#[track_caller]
319#[inline]
320pub fn solve_upper_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
321	triangular_upper: MatRef<'_, T, N, N, impl Stride, impl Stride>,
322	conj_lhs: Conj,
323	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
324	par: Par,
325) {
326	assert!(all(
327		triangular_upper.nrows() == triangular_upper.ncols(),
328		rhs.nrows() == triangular_upper.ncols(),
329	));
330
331	make_guard!(N);
332	make_guard!(K);
333	let N = rhs.nrows().bind(N);
334	let K = rhs.ncols().bind(K);
335
336	solve_upper_triangular_in_place_imp(
337		triangular_upper.as_dyn_stride().as_shape(N, N),
338		conj_lhs,
339		rhs.as_dyn_stride_mut().as_shape_mut(N, K),
340		par,
341	);
342}
343
344/// solves $U x = b$, implicitly conjugating $U$ if needed, and stores the result in `rhs`
345#[inline]
346#[track_caller]
347pub fn solve_upper_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
348	triangular_upper: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
349	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
350	par: Par,
351) {
352	let tri = triangular_upper.canonical();
353	solve_upper_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
354}
355
356/// solves $U x = b$, replacing the diagonal of $U$ with ones, and implicitly conjugating $U$ if
357/// needed, and stores the result in `rhs`
358#[track_caller]
359#[inline]
360pub fn solve_unit_upper_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
361	triangular_unit_upper: MatRef<'_, T, N, N, impl Stride, impl Stride>,
362	conj_lhs: Conj,
363	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
364	par: Par,
365) {
366	assert!(all(
367		triangular_unit_upper.nrows() == triangular_unit_upper.ncols(),
368		rhs.nrows() == triangular_unit_upper.ncols(),
369	));
370
371	make_guard!(N);
372	make_guard!(K);
373	let N = rhs.nrows().bind(N);
374	let K = rhs.ncols().bind(K);
375
376	solve_unit_upper_triangular_in_place_imp(
377		triangular_unit_upper.as_dyn_stride().as_shape(N, N),
378		conj_lhs,
379		rhs.as_dyn_stride_mut().as_shape_mut(N, K),
380		par,
381	);
382}
383
384/// solves $U x = b$, replacing the diagonal of $U$ with ones, and implicitly conjugating $U$ if
385/// needed, and stores the result in `rhs`
386#[inline]
387#[track_caller]
388pub fn solve_unit_upper_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
389	triangular_unit_upper: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
390	rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
391	par: Par,
392) {
393	let tri = triangular_unit_upper.canonical();
394	solve_unit_upper_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
395}
396
397#[math]
398fn solve_unit_lower_triangular_in_place_imp<'N, 'K, T: ComplexField>(
399	tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
400	conj_lhs: Conj,
401	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
402	par: Par,
403) {
404	let N = tril.nrows();
405	let K = rhs.ncols();
406	let n = N.unbound();
407	let k = K.unbound();
408
409	if k > 64 && n <= 128 {
410		make_guard!(LEFT);
411		make_guard!(RIGHT);
412
413		let mid = K.partition(IdxInc::new_checked(k / 2, K), LEFT, RIGHT);
414
415		let (rhs_left, rhs_right) = rhs.split_cols_with_mut(mid);
416		join_raw(
417			|_| solve_unit_lower_triangular_in_place_imp(tril, conj_lhs, rhs_left, par),
418			|_| solve_unit_lower_triangular_in_place_imp(tril, conj_lhs, rhs_right, par),
419			par,
420		);
421		return;
422	}
423
424	debug_assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.ncols(),));
425
426	if n <= recursion_threshold() {
427		T::Arch::default().dispatch(
428			#[inline(always)]
429			|| match conj_lhs {
430				Conj::Yes => solve_unit_lower_triangular_in_place_base_case_generic_imp(tril, rhs, conjugate),
431				Conj::No => solve_unit_lower_triangular_in_place_base_case_generic_imp(tril, rhs, identity),
432			},
433		);
434		return;
435	}
436
437	make_guard!(HEAD);
438	make_guard!(TAIL);
439	let bs = N.partition(IdxInc::new_checked(blocksize(n), N), HEAD, TAIL);
440
441	let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_with(bs, bs);
442	let (mut rhs_top, mut rhs_bot) = rhs.split_rows_with_mut(bs);
443
444	//       (A00    )   X0         (B0)
445	// ConjA?(A10 A11)   X1 = ConjB?(B1)
446	//
447	//
448	// 1. ConjA?(A00) X0 = ConjB?(B0)
449	//
450	// 2. ConjA?(A10) X0 + ConjA?(A11) X1 = ConjB?(B1)
451	// => ConjA?(A11) X1 = ConjB?(B1) - ConjA?(A10) X0
452
453	solve_unit_lower_triangular_in_place_imp(tril_top_left, conj_lhs, rhs_top.rb_mut(), par);
454
455	crate::linalg::matmul::matmul_with_conj(
456		rhs_bot.rb_mut(),
457		Accum::Add,
458		tril_bot_left,
459		conj_lhs,
460		rhs_top.into_const(),
461		Conj::No,
462		-one::<T>(),
463		par,
464	);
465
466	solve_unit_lower_triangular_in_place_imp(tril_bot_right, conj_lhs, rhs_bot, par);
467}
468
469#[math]
470fn solve_lower_triangular_in_place_imp<'N, 'K, T: ComplexField>(
471	tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
472	conj_lhs: Conj,
473	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
474	par: Par,
475) {
476	let N = tril.nrows();
477	let K = rhs.ncols();
478	let n = N.unbound();
479	let k = K.unbound();
480
481	if k > 64 && n <= 128 {
482		make_guard!(LEFT);
483		make_guard!(RIGHT);
484
485		let mid = K.partition(IdxInc::new_checked(k / 2, K), LEFT, RIGHT);
486
487		let (rhs_left, rhs_right) = rhs.split_cols_with_mut(mid);
488		join_raw(
489			|_| solve_lower_triangular_in_place_imp(tril, conj_lhs, rhs_left, par),
490			|_| solve_lower_triangular_in_place_imp(tril, conj_lhs, rhs_right, par),
491			par,
492		);
493		return;
494	}
495
496	debug_assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.ncols(),));
497
498	if n <= recursion_threshold() {
499		T::Arch::default().dispatch(
500			#[inline(always)]
501			|| match conj_lhs {
502				Conj::Yes => solve_lower_triangular_in_place_base_case_generic_imp(tril, rhs, conjugate),
503				Conj::No => solve_lower_triangular_in_place_base_case_generic_imp(tril, rhs, identity),
504			},
505		);
506		return;
507	}
508
509	make_guard!(HEAD);
510	make_guard!(TAIL);
511	let bs = N.partition(IdxInc::new_checked(blocksize(n), N), HEAD, TAIL);
512
513	let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_with(bs, bs);
514	let (mut rhs_top, mut rhs_bot) = rhs.split_rows_with_mut(bs);
515
516	//       (A00    )   X0         (B0)
517	// ConjA?(A10 A11)   X1 = ConjB?(B1)
518	//
519	//
520	// 1. ConjA?(A00) X0 = ConjB?(B0)
521	//
522	// 2. ConjA?(A10) X0 + ConjA?(A11) X1 = ConjB?(B1)
523	// => ConjA?(A11) X1 = ConjB?(B1) - ConjA?(A10) X0
524
525	solve_lower_triangular_in_place_imp(tril_top_left, conj_lhs, rhs_top.rb_mut(), par);
526
527	crate::linalg::matmul::matmul_with_conj(
528		rhs_bot.rb_mut(),
529		Accum::Add,
530		tril_bot_left,
531		conj_lhs,
532		rhs_top.into_const(),
533		Conj::No,
534		-one::<T>(),
535		par,
536	);
537
538	solve_lower_triangular_in_place_imp(tril_bot_right, conj_lhs, rhs_bot, par);
539}
540
541#[inline]
542fn solve_unit_upper_triangular_in_place_imp<'N, 'K, T: ComplexField>(
543	triu: MatRef<'_, T, Dim<'N>, Dim<'N>>,
544	conj_lhs: Conj,
545	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
546	par: Par,
547) {
548	solve_unit_lower_triangular_in_place_imp(triu.reverse_rows_and_cols(), conj_lhs, rhs.reverse_rows_mut(), par);
549}
550
551#[inline]
552fn solve_upper_triangular_in_place_imp<'N, 'K, T: ComplexField>(
553	triu: MatRef<'_, T, Dim<'N>, Dim<'N>>,
554	conj_lhs: Conj,
555	rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
556	par: Par,
557) {
558	solve_lower_triangular_in_place_imp(triu.reverse_rows_and_cols(), conj_lhs, rhs.reverse_rows_mut(), par);
559}