1use crate::internal_prelude::*;
4use crate::utils::thread::join_raw;
5use crate::{assert, debug_assert};
6use faer_macros::math;
7use faer_traits::{Conjugate, SimdArch};
8use reborrow::*;
9
10#[inline(always)]
11#[math]
12fn identity<T: ComplexField>(x: T) -> T {
13 copy(x)
14}
15
16#[inline(always)]
17#[math]
18fn conjugate<T: ComplexField>(x: T) -> T {
19 conj(x)
20}
21
22#[inline(always)]
23#[math]
24fn solve_unit_lower_triangular_in_place_base_case_generic_imp<'N, 'K, T: ComplexField>(
25 tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
26 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
27 maybe_conj_lhs: impl Fn(T) -> T,
28) {
29 let N = tril.nrows();
30 let n = N.unbound();
31
32 match n {
33 0 | 1 => (),
34 2 => {
35 let i0 = N.check(0);
36 let i1 = N.check(1);
37
38 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
39
40 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
41 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
42 _ = rhs;
43
44 zip!(x0, x1).for_each(|unzip!(x0, x1)| *x1 = *x1 + nl10_div_l11 * *x0);
45 },
46 3 => {
47 let i0 = N.check(0);
48 let i1 = N.check(1);
49 let i2 = N.check(2);
50
51 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
52 let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]);
53 let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]);
54
55 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
56 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
57 let (x2, rhs) = rhs.split_first_row_mut().unwrap();
58 _ = rhs;
59
60 zip!(x0, x1, x2).for_each(|unzip!(x0, x1, x2)| {
61 let y0 = copy(*x0);
62 let mut y1 = copy(*x1);
63 let mut y2 = copy(*x2);
64 y1 = y1 + nl10_div_l11 * y0;
65 y2 = y2 + nl20_div_l22 * y0 + nl21_div_l22 * y1;
66 *x0 = y0;
67 *x1 = y1;
68 *x2 = y2;
69 });
70 },
71 4 => {
72 let i0 = N.check(0);
73 let i1 = N.check(1);
74 let i2 = N.check(2);
75 let i3 = N.check(3);
76 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]);
77 let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]);
78 let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]);
79 let nl30_div_l33 = maybe_conj_lhs(-tril[(i3, i0)]);
80 let nl31_div_l33 = maybe_conj_lhs(-tril[(i3, i1)]);
81 let nl32_div_l33 = maybe_conj_lhs(-tril[(i3, i2)]);
82
83 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
84 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
85 let (x2, rhs) = rhs.split_first_row_mut().unwrap();
86 let (x3, rhs) = rhs.split_first_row_mut().unwrap();
87 _ = rhs;
88
89 zip!(x0, x1, x2, x3).for_each(|unzip!(x0, x1, x2, x3)| {
90 let y0 = copy(*x0);
91 let mut y1 = copy(*x1);
92 let mut y2 = copy(*x2);
93 let mut y3 = copy(*x3);
94 y1 = y1 + nl10_div_l11 * y0;
95 y2 = y2 + nl20_div_l22 * y0 + nl21_div_l22 * y1;
96 y3 = y3 + nl30_div_l33 * y0 + nl31_div_l33 * y1 + nl32_div_l33 * y2;
97 *x0 = y0;
98 *x1 = y1;
99 *x2 = y2;
100 *x3 = y3;
101 });
102 },
103 _ => unreachable!(),
104 }
105}
106
107#[inline(always)]
108#[math]
109fn solve_lower_triangular_in_place_base_case_generic_imp<'N, 'K, T: ComplexField>(
110 tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
111 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
112 maybe_conj_lhs: impl Fn(T) -> T,
113) {
114 let N = tril.nrows();
115 let n = N.unbound();
116
117 match n {
118 0 => (),
119 1 => {
120 let i0 = N.check(0);
121
122 let inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
123
124 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
125 _ = rhs;
126
127 zip!(x0).for_each(|unzip!(x0)| *x0 = *x0 * inv);
128 },
129 2 => {
130 let i0 = N.check(0);
131 let i1 = N.check(1);
132
133 let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
134 let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
135 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
136
137 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
138 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
139 _ = rhs;
140
141 zip!(x0, x1).for_each(|unzip!(x0, x1)| {
142 *x0 = *x0 * l00_inv;
143 *x1 = *x1 * l11_inv + nl10_div_l11 * x0;
144 });
145 },
146 3 => {
147 let i0 = N.check(0);
148 let i1 = N.check(1);
149 let i2 = N.check(2);
150
151 let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
152 let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
153 let l22_inv = maybe_conj_lhs(recip(tril[(i2, i2)]));
154 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
155 let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]) * l22_inv;
156 let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]) * l22_inv;
157
158 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
159 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
160 let (x2, rhs) = rhs.split_first_row_mut().unwrap();
161 _ = rhs;
162
163 zip!(x0, x1, x2).for_each(|unzip!(x0, x1, x2)| {
164 let mut y0 = copy(*x0);
165 let mut y1 = copy(*x1);
166 let mut y2 = copy(*x2);
167 y0 = y0 * l00_inv;
168 y1 = y1 * l11_inv + nl10_div_l11 * y0;
169 y2 = y2 * l22_inv + nl20_div_l22 * y0 + nl21_div_l22 * y1;
170 *x0 = y0;
171 *x1 = y1;
172 *x2 = y2;
173 });
174 },
175 4 => {
176 let i0 = N.check(0);
177 let i1 = N.check(1);
178 let i2 = N.check(2);
179 let i3 = N.check(3);
180
181 let l00_inv = maybe_conj_lhs(recip(tril[(i0, i0)]));
182 let l11_inv = maybe_conj_lhs(recip(tril[(i1, i1)]));
183 let l22_inv = maybe_conj_lhs(recip(tril[(i2, i2)]));
184 let l33_inv = maybe_conj_lhs(recip(tril[(i3, i3)]));
185 let nl10_div_l11 = maybe_conj_lhs(-tril[(i1, i0)]) * l11_inv;
186 let nl20_div_l22 = maybe_conj_lhs(-tril[(i2, i0)]) * l22_inv;
187 let nl21_div_l22 = maybe_conj_lhs(-tril[(i2, i1)]) * l22_inv;
188 let nl30_div_l33 = maybe_conj_lhs(-tril[(i3, i0)]) * l33_inv;
189 let nl31_div_l33 = maybe_conj_lhs(-tril[(i3, i1)]) * l33_inv;
190 let nl32_div_l33 = maybe_conj_lhs(-tril[(i3, i2)]) * l33_inv;
191
192 let (x0, rhs) = rhs.split_first_row_mut().unwrap();
193 let (x1, rhs) = rhs.split_first_row_mut().unwrap();
194 let (x2, rhs) = rhs.split_first_row_mut().unwrap();
195 let (x3, rhs) = rhs.split_first_row_mut().unwrap();
196 _ = rhs;
197
198 zip!(x0, x1, x2, x3).for_each(|unzip!(x0, x1, x2, x3)| {
199 let mut y0 = copy(*x0);
200 let mut y1 = copy(*x1);
201 let mut y2 = copy(*x2);
202 let mut y3 = copy(*x3);
203 y0 = y0 * l00_inv;
204 y1 = y1 * l11_inv + nl10_div_l11 * y0;
205 y2 = y2 * l22_inv + nl20_div_l22 * y0 + nl21_div_l22 * y1;
206 y3 = y3 * l33_inv + nl30_div_l33 * y0 + nl31_div_l33 * y1 + nl32_div_l33 * y2;
207 *x0 = y0;
208 *x1 = y1;
209 *x2 = y2;
210 *x3 = y3;
211 });
212 },
213 _ => unreachable!(),
214 }
215}
216
/// Returns the number of leading rows used when an `n`-row triangular system is split in two.
///
/// The trailing part receives `n / 2` rows rounded up to a multiple of 16 when `n >= 32`, of 8 when
/// `n >= 16`, of 4 when `n >= 8`, and not rounded at all for smaller `n`; the leading part gets the
/// rest. Presumably the rounding keeps the trailing block a multiple of the SIMD register width;
/// confirm against the matmul kernels before changing the constants.
#[inline]
fn blocksize(n: usize) -> usize {
    let half = n / 2;

    // Alignment applied to the trailing block; 1 means "no rounding".
    let align = if n >= 32 {
        16
    } else if n >= 16 {
        8
    } else if n >= 8 {
        4
    } else {
        1
    };

    // Round `half` up to the next multiple of `align`. The result never exceeds `n` because each
    // alignment is only selected once `n` is at least twice as large as it.
    let tail = (half + align - 1) / align * align;
    n - tail
}
231
/// Largest system size handed directly to the base-case kernels instead of being split further.
///
/// Must not exceed 4: the base-case kernels only implement sizes 0 through 4 and hit
/// `unreachable!()` for anything larger.
#[inline]
fn recursion_threshold() -> usize {
    4
}
236
237#[track_caller]
239#[inline]
240pub fn solve_lower_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
241 triangular_lower: MatRef<'_, T, N, N, impl Stride, impl Stride>,
242 conj_lhs: Conj,
243 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
244 par: Par,
245) {
246 assert!(all(
247 triangular_lower.nrows() == triangular_lower.ncols(),
248 rhs.nrows() == triangular_lower.ncols(),
249 ));
250
251 make_guard!(N);
252 make_guard!(K);
253 let N = rhs.nrows().bind(N);
254 let K = rhs.ncols().bind(K);
255
256 solve_lower_triangular_in_place_imp(
257 triangular_lower.as_dyn_stride().as_shape(N, N),
258 conj_lhs,
259 rhs.as_dyn_stride_mut().as_shape_mut(N, K),
260 par,
261 );
262}
263
264#[inline]
266#[track_caller]
267pub fn solve_lower_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
268 triangular_lower: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
269 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
270 par: Par,
271) {
272 let tri = triangular_lower.canonical();
273 solve_lower_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
274}
275
276#[track_caller]
279#[inline]
280pub fn solve_unit_lower_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
281 triangular_unit_lower: MatRef<'_, T, N, N, impl Stride, impl Stride>,
282 conj_lhs: Conj,
283 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
284 par: Par,
285) {
286 assert!(all(
287 triangular_unit_lower.nrows() == triangular_unit_lower.ncols(),
288 rhs.nrows() == triangular_unit_lower.ncols(),
289 ));
290
291 make_guard!(N);
292 make_guard!(K);
293 let N = rhs.nrows().bind(N);
294 let K = rhs.ncols().bind(K);
295
296 solve_unit_lower_triangular_in_place_imp(
297 triangular_unit_lower.as_dyn_stride().as_shape(N, N),
298 conj_lhs,
299 rhs.as_dyn_stride_mut().as_shape_mut(N, K),
300 par,
301 );
302}
303
304#[inline]
307#[track_caller]
308pub fn solve_unit_lower_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
309 triangular_unit_lower: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
310 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
311 par: Par,
312) {
313 let tri = triangular_unit_lower.canonical();
314 solve_unit_lower_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
315}
316
317#[track_caller]
319#[inline]
320pub fn solve_upper_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
321 triangular_upper: MatRef<'_, T, N, N, impl Stride, impl Stride>,
322 conj_lhs: Conj,
323 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
324 par: Par,
325) {
326 assert!(all(
327 triangular_upper.nrows() == triangular_upper.ncols(),
328 rhs.nrows() == triangular_upper.ncols(),
329 ));
330
331 make_guard!(N);
332 make_guard!(K);
333 let N = rhs.nrows().bind(N);
334 let K = rhs.ncols().bind(K);
335
336 solve_upper_triangular_in_place_imp(
337 triangular_upper.as_dyn_stride().as_shape(N, N),
338 conj_lhs,
339 rhs.as_dyn_stride_mut().as_shape_mut(N, K),
340 par,
341 );
342}
343
344#[inline]
346#[track_caller]
347pub fn solve_upper_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
348 triangular_upper: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
349 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
350 par: Par,
351) {
352 let tri = triangular_upper.canonical();
353 solve_upper_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
354}
355
356#[track_caller]
359#[inline]
360pub fn solve_unit_upper_triangular_in_place_with_conj<T: ComplexField, N: Shape, K: Shape>(
361 triangular_unit_upper: MatRef<'_, T, N, N, impl Stride, impl Stride>,
362 conj_lhs: Conj,
363 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
364 par: Par,
365) {
366 assert!(all(
367 triangular_unit_upper.nrows() == triangular_unit_upper.ncols(),
368 rhs.nrows() == triangular_unit_upper.ncols(),
369 ));
370
371 make_guard!(N);
372 make_guard!(K);
373 let N = rhs.nrows().bind(N);
374 let K = rhs.ncols().bind(K);
375
376 solve_unit_upper_triangular_in_place_imp(
377 triangular_unit_upper.as_dyn_stride().as_shape(N, N),
378 conj_lhs,
379 rhs.as_dyn_stride_mut().as_shape_mut(N, K),
380 par,
381 );
382}
383
384#[inline]
387#[track_caller]
388pub fn solve_unit_upper_triangular_in_place<T: ComplexField, LhsT: Conjugate<Canonical = T>, N: Shape, K: Shape>(
389 triangular_unit_upper: MatRef<'_, LhsT, N, N, impl Stride, impl Stride>,
390 rhs: MatMut<'_, T, N, K, impl Stride, impl Stride>,
391 par: Par,
392) {
393 let tri = triangular_unit_upper.canonical();
394 solve_unit_upper_triangular_in_place_with_conj(tri, Conj::get::<LhsT>(), rhs, par)
395}
396
397#[math]
398fn solve_unit_lower_triangular_in_place_imp<'N, 'K, T: ComplexField>(
399 tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
400 conj_lhs: Conj,
401 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
402 par: Par,
403) {
404 let N = tril.nrows();
405 let K = rhs.ncols();
406 let n = N.unbound();
407 let k = K.unbound();
408
409 if k > 64 && n <= 128 {
410 make_guard!(LEFT);
411 make_guard!(RIGHT);
412
413 let mid = K.partition(IdxInc::new_checked(k / 2, K), LEFT, RIGHT);
414
415 let (rhs_left, rhs_right) = rhs.split_cols_with_mut(mid);
416 join_raw(
417 |_| solve_unit_lower_triangular_in_place_imp(tril, conj_lhs, rhs_left, par),
418 |_| solve_unit_lower_triangular_in_place_imp(tril, conj_lhs, rhs_right, par),
419 par,
420 );
421 return;
422 }
423
424 debug_assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.ncols(),));
425
426 if n <= recursion_threshold() {
427 T::Arch::default().dispatch(
428 #[inline(always)]
429 || match conj_lhs {
430 Conj::Yes => solve_unit_lower_triangular_in_place_base_case_generic_imp(tril, rhs, conjugate),
431 Conj::No => solve_unit_lower_triangular_in_place_base_case_generic_imp(tril, rhs, identity),
432 },
433 );
434 return;
435 }
436
437 make_guard!(HEAD);
438 make_guard!(TAIL);
439 let bs = N.partition(IdxInc::new_checked(blocksize(n), N), HEAD, TAIL);
440
441 let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_with(bs, bs);
442 let (mut rhs_top, mut rhs_bot) = rhs.split_rows_with_mut(bs);
443
444 solve_unit_lower_triangular_in_place_imp(tril_top_left, conj_lhs, rhs_top.rb_mut(), par);
454
455 crate::linalg::matmul::matmul_with_conj(
456 rhs_bot.rb_mut(),
457 Accum::Add,
458 tril_bot_left,
459 conj_lhs,
460 rhs_top.into_const(),
461 Conj::No,
462 -one::<T>(),
463 par,
464 );
465
466 solve_unit_lower_triangular_in_place_imp(tril_bot_right, conj_lhs, rhs_bot, par);
467}
468
469#[math]
470fn solve_lower_triangular_in_place_imp<'N, 'K, T: ComplexField>(
471 tril: MatRef<'_, T, Dim<'N>, Dim<'N>>,
472 conj_lhs: Conj,
473 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
474 par: Par,
475) {
476 let N = tril.nrows();
477 let K = rhs.ncols();
478 let n = N.unbound();
479 let k = K.unbound();
480
481 if k > 64 && n <= 128 {
482 make_guard!(LEFT);
483 make_guard!(RIGHT);
484
485 let mid = K.partition(IdxInc::new_checked(k / 2, K), LEFT, RIGHT);
486
487 let (rhs_left, rhs_right) = rhs.split_cols_with_mut(mid);
488 join_raw(
489 |_| solve_lower_triangular_in_place_imp(tril, conj_lhs, rhs_left, par),
490 |_| solve_lower_triangular_in_place_imp(tril, conj_lhs, rhs_right, par),
491 par,
492 );
493 return;
494 }
495
496 debug_assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.ncols(),));
497
498 if n <= recursion_threshold() {
499 T::Arch::default().dispatch(
500 #[inline(always)]
501 || match conj_lhs {
502 Conj::Yes => solve_lower_triangular_in_place_base_case_generic_imp(tril, rhs, conjugate),
503 Conj::No => solve_lower_triangular_in_place_base_case_generic_imp(tril, rhs, identity),
504 },
505 );
506 return;
507 }
508
509 make_guard!(HEAD);
510 make_guard!(TAIL);
511 let bs = N.partition(IdxInc::new_checked(blocksize(n), N), HEAD, TAIL);
512
513 let (tril_top_left, _, tril_bot_left, tril_bot_right) = tril.split_with(bs, bs);
514 let (mut rhs_top, mut rhs_bot) = rhs.split_rows_with_mut(bs);
515
516 solve_lower_triangular_in_place_imp(tril_top_left, conj_lhs, rhs_top.rb_mut(), par);
526
527 crate::linalg::matmul::matmul_with_conj(
528 rhs_bot.rb_mut(),
529 Accum::Add,
530 tril_bot_left,
531 conj_lhs,
532 rhs_top.into_const(),
533 Conj::No,
534 -one::<T>(),
535 par,
536 );
537
538 solve_lower_triangular_in_place_imp(tril_bot_right, conj_lhs, rhs_bot, par);
539}
540
541#[inline]
542fn solve_unit_upper_triangular_in_place_imp<'N, 'K, T: ComplexField>(
543 triu: MatRef<'_, T, Dim<'N>, Dim<'N>>,
544 conj_lhs: Conj,
545 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
546 par: Par,
547) {
548 solve_unit_lower_triangular_in_place_imp(triu.reverse_rows_and_cols(), conj_lhs, rhs.reverse_rows_mut(), par);
549}
550
551#[inline]
552fn solve_upper_triangular_in_place_imp<'N, 'K, T: ComplexField>(
553 triu: MatRef<'_, T, Dim<'N>, Dim<'N>>,
554 conj_lhs: Conj,
555 rhs: MatMut<'_, T, Dim<'N>, Dim<'K>>,
556 par: Par,
557) {
558 solve_lower_triangular_in_place_imp(triu.reverse_rows_and_cols(), conj_lhs, rhs.reverse_rows_mut(), par);
559}