faer/linalg/cholesky/bunch_kaufman/
solve.rs

1use crate::assert;
2use crate::internal_prelude::*;
3use crate::perm::permute_rows;
4use linalg::triangular_solve::{solve_unit_lower_triangular_in_place_with_conj, solve_unit_upper_triangular_in_place_with_conj};
5
6/// computes the size and alignment of required workspace for solving a linear system defined by
7/// a matrix in place, given its bunch-kaufman decomposition
8#[track_caller]
9pub fn solve_in_place_scratch<I: Index, T: ComplexField>(dim: usize, rhs_ncols: usize, par: Par) -> StackReq {
10	let _ = par;
11	temp_mat_scratch::<T>(dim, rhs_ncols)
12}
13
14/// given the bunch-kaufman factors of a matrix $a$ and a matrix $b$ stored in `rhs`, this
15/// function computes the solution of the linear system $A x = b$, implicitly conjugating $A$ if
16/// needed
17///
18/// the solution of the linear system is stored in `rhs`
19///
20/// # panics
21///
22/// - panics if `lb_factors` is not a square matrix
23/// - panics if `subdiag` is not a column vector with the same number of rows as the dimension of
24///   `lb_factors`
25/// - panics if `rhs` doesn't have the same number of rows as the dimension of `lb_factors`
26/// - panics if the provided memory in `stack` is insufficient (see [`solve_in_place_scratch`])
27#[track_caller]
28#[math]
29pub fn solve_in_place_with_conj<I: Index, T: ComplexField>(
30	L: MatRef<'_, T>,
31	diagonal: DiagRef<'_, T>,
32	subdiagonal: DiagRef<'_, T>,
33	conj_A: Conj,
34	perm: PermRef<'_, I>,
35	rhs: MatMut<'_, T>,
36	par: Par,
37	stack: &mut MemStack,
38) {
39	let n = L.nrows();
40	let k = rhs.ncols();
41
42	assert!(all(
43		L.nrows() == n,
44		L.ncols() == n,
45		rhs.nrows() == n,
46		diagonal.dim() == n,
47		subdiagonal.dim() == n,
48		perm.len() == n
49	));
50
51	let a = L;
52	let par = par;
53	let not_conj = conj_A.compose(Conj::Yes);
54
55	let mut rhs = rhs;
56	let mut x = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack).0 };
57	let mut x = x.as_mat_mut();
58
59	permute_rows(x.rb_mut(), rhs.rb(), perm);
60	solve_unit_lower_triangular_in_place_with_conj(a, conj_A, x.rb_mut(), par);
61
62	let mut i = 0;
63	while i < n {
64		let i0 = i;
65		let i1 = i + 1;
66
67		if subdiagonal[i] == zero() {
68			let d_inv = recip(real(diagonal[i]));
69			for j in 0..k {
70				x[(i, j)] = mul_real(x[(i, j)], d_inv);
71			}
72			i += 1;
73		} else {
74			let mut akp1k = copy(subdiagonal[i0]);
75			if matches!(conj_A, Conj::Yes) {
76				akp1k = conj(akp1k);
77			}
78			akp1k = recip(akp1k);
79			let (ak, akp1) = (mul_real(conj(akp1k), real(diagonal[i0])), mul_real(akp1k, real(diagonal[i1])));
80
81			let denom = real(recip(ak * akp1 - one()));
82
83			for j in 0..k {
84				let (xk, xkp1) = (
85					//
86					x[(i0, j)] * conj(akp1k),
87					x[(i1, j)] * akp1k,
88				);
89
90				let (xk, xkp1) = (mul_real((akp1 * xk - xkp1), denom), mul_real((ak * xkp1 - xk), denom));
91
92				x[(i, j)] = xk;
93				x[(i + 1, j)] = xkp1;
94			}
95
96			i += 2;
97		}
98	}
99
100	solve_unit_upper_triangular_in_place_with_conj(a.transpose(), not_conj, x.rb_mut(), par);
101	permute_rows(rhs.rb_mut(), x.rb(), perm.inverse());
102}
103
104#[track_caller]
105#[math]
106pub fn solve_in_place<I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
107	L: MatRef<'_, C>,
108	diagonal: DiagRef<'_, C>,
109	subdiagonal: DiagRef<'_, C>,
110	perm: PermRef<'_, I>,
111	rhs: MatMut<'_, T>,
112	par: Par,
113	stack: &mut MemStack,
114) {
115	solve_in_place_with_conj(
116		L.canonical(),
117		diagonal.canonical(),
118		subdiagonal.canonical(),
119		Conj::get::<C>(),
120		perm,
121		rhs,
122		par,
123		stack,
124	);
125}