faer/sparse/linalg/
triangular_solve.rs

1use crate::internal_prelude_sp::*;
2use crate::{assert, debug_assert};
3
4/// assuming `tril` is a lower triangular matrix, solves the equation `tril * x = rhs`, and
5/// stores the result in `rhs`, implicitly conjugating `tril` if needed
6///
7/// # note
8/// the matrix indices need not be sorted, but
9/// the diagonal element is assumed to be the first stored element in each column.
10#[track_caller]
11pub fn solve_lower_triangular_in_place<I: Index, T: ComplexField>(tril: SparseColMatRef<'_, I, T>, conj_tril: Conj, rhs: MatMut<'_, T>, par: Par) {
12	solve_lower_triangular_in_place_impl(tril, conj_tril, DiagStatus::Generic, rhs, par)
13}
14
15/// assuming `tril` is a lower triangular matrix, solves the equation `tril * x = rhs`, and
16/// stores the result in `rhs`, implicitly conjugating `tril` if needed
17///
18/// # note
19/// the matrix indices need not be sorted, but
20/// the diagonal element is assumed to be the first stored element in each column.
21#[track_caller]
22pub fn solve_unit_lower_triangular_in_place<I: Index, T: ComplexField>(
23	tril: SparseColMatRef<'_, I, T>,
24	conj_tril: Conj,
25	rhs: MatMut<'_, T>,
26	par: Par,
27) {
28	solve_lower_triangular_in_place_impl(tril, conj_tril, DiagStatus::Unit, rhs, par)
29}
30
31/// assuming `tril` is a lower triangular matrix, solves the equation `tril.transpose() * x =
32/// rhs`, and stores the result in `rhs`, implicitly conjugating `tril` if needed
33///
34/// # note
35/// the matrix indices need not be sorted, but
36/// the diagonal element is assumed to be the first stored element in each column.
37#[track_caller]
38pub fn solve_lower_triangular_transpose_in_place<I: Index, T: ComplexField>(
39	tril: SparseColMatRef<'_, I, T>,
40	conj_tril: Conj,
41	rhs: MatMut<'_, T>,
42	par: Par,
43) {
44	solve_lower_triangular_transpose_in_place_impl(tril, conj_tril, DiagStatus::Generic, rhs, par)
45}
46
47/// assuming `tril` is a lower triangular matrix, solves the equation `tril.transpose() * x =
48/// rhs`, and stores the result in `rhs`, implicitly conjugating `tril` if needed
49///
50/// # note
51/// the matrix indices need not be sorted, but
52/// the diagonal element is assumed to be the first stored element in each column.
53#[track_caller]
54pub fn solve_unit_lower_triangular_transpose_in_place<I: Index, T: ComplexField>(
55	tril: SparseColMatRef<'_, I, T>,
56	conj_tril: Conj,
57	rhs: MatMut<'_, T>,
58	par: Par,
59) {
60	solve_lower_triangular_transpose_in_place_impl(tril, conj_tril, DiagStatus::Unit, rhs, par)
61}
62
63/// assuming `triu` is an upper triangular matrix, solves the equation `triu * x = rhs`, and
64/// stores the result in `rhs`, implicitly conjugating `triu` if needed
65///
66/// # note
67/// the matrix indices need not be sorted, but
68/// the diagonal element is assumed to be the last stored element in each column.
69#[track_caller]
70pub fn solve_upper_triangular_in_place<I: Index, T: ComplexField>(triu: SparseColMatRef<'_, I, T>, conj_triu: Conj, rhs: MatMut<'_, T>, par: Par) {
71	solve_upper_triangular_in_place_impl(triu, conj_triu, DiagStatus::Generic, rhs, par)
72}
73
74/// assuming `triu` is an upper triangular matrix, solves the equation `triu * x = rhs`, and
75/// stores the result in `rhs`, implicitly conjugating `triu` if needed
76///
77/// # note
78/// the matrix indices need not be sorted, but
79/// the diagonal element is assumed to be the last stored element in each column.
80#[track_caller]
81pub fn solve_unit_upper_triangular_in_place<I: Index, T: ComplexField>(
82	triu: SparseColMatRef<'_, I, T>,
83	conj_triu: Conj,
84	rhs: MatMut<'_, T>,
85	par: Par,
86) {
87	solve_upper_triangular_in_place_impl(triu, conj_triu, DiagStatus::Unit, rhs, par)
88}
89
90/// assuming `triu` is an upper triangular matrix, solves the equation `triu.transpose() * x =
91/// rhs`, and stores the result in `rhs`, implicitly conjugating `triu` if needed
92///
93/// # note
94/// the matrix indices need not be sorted, but
95/// the diagonal element is assumed to be the first stored element in each column.
96#[track_caller]
97pub fn solve_upper_triangular_transpose_in_place<I: Index, T: ComplexField>(
98	triu: SparseColMatRef<'_, I, T>,
99	conj_triu: Conj,
100	rhs: MatMut<'_, T>,
101	par: Par,
102) {
103	solve_upper_triangular_transpose_in_place_impl(triu, conj_triu, DiagStatus::Generic, rhs, par)
104}
105
106/// assuming `triu` is an upper triangular matrix, solves the equation `triu.transpose() * x =
107/// rhs`, and stores the result in `rhs`, implicitly conjugating `triu` if needed
108///
109/// # note
110/// the matrix indices need not be sorted, but
111/// the diagonal element is assumed to be the first stored element in each column.
112#[track_caller]
113pub fn solve_unit_upper_triangular_transpose_in_place<I: Index, T: ComplexField>(
114	triu: SparseColMatRef<'_, I, T>,
115	conj_triu: Conj,
116	rhs: MatMut<'_, T>,
117	par: Par,
118) {
119	solve_upper_triangular_transpose_in_place_impl(triu, conj_triu, DiagStatus::Unit, rhs, par)
120}
121
/// forward-substitution kernel shared by [`solve_lower_triangular_in_place`] and
/// [`solve_unit_lower_triangular_in_place`]: overwrites `rhs` with the solution `x` of
/// `tril * x = rhs`, applying `conj_tril` to every stored value of `tril` that is read
///
/// `diag_tril` selects between dividing by the stored diagonal (`DiagStatus::Generic`) and
/// treating the diagonal as one (`DiagStatus::Unit`). `par` is ignored.
///
/// the first stored element of each column of `tril` is taken to be the diagonal element.
///
/// # panics
/// panics if `tril` is not square, if `rhs.nrows() != tril.nrows()`, or if a column of `tril`
/// has no stored element while `rhs` has at least one column.
#[track_caller]
#[math]
fn solve_lower_triangular_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	diag_tril: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// the parallelism argument is discarded: everything below runs sequentially
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	// `N` is the row count of `rhs`, `K` its column count
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first column of `rhs` that has not been solved yet. columns are handled in
	// groups of up to four so that one sweep over the stored elements of `tril` updates
	// several right-hand sides
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		// as used here, each of these is `Some` only while that column index is still inside `K`
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// columns `k`, `k + 1`, `k + 2` and `k + 3` all exist: solve the four together
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				// the sub-view covers four columns, so the `else` branch is not expected to run
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				// forward substitution: columns of `tril` in increasing order
				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					// first stored element of column `j`, required to be the diagonal; the
					// `unwrap` panics if the column has no stored element
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					// `x0j..x3j` hold the finished values of row `j` of the solution
					let x0j;
					let x1j;
					let x2j;
					let x3j;
					match diag_tril {
						// implicit unit diagonal: row `j` is already final and `d` is not read
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
							x3j = copy(x3[j]);
						},
						// divide by the diagonal (as a multiplication by its reciprocal) and
						// write the finished value back
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x3j = x3[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
							x3[j] = copy(x3j);
						},
					}

					// remaining stored elements of column `j`: x[i] -= l[i, j] * x[j]
					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
						x2[i] = x2[i] - lij * x2j;
						x3[i] = x3[i] - lij * x3j;
					}
				}
				k = k3.next();
			},
			// exactly three columns remain: same steps as above on three columns
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					let x2j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
						x2[i] = x2[i] - lij * x2j;
					}
				}
				k = k2.next();
			},
			// exactly two columns remain
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
					}
				}
				k = k1.next();
			},
			// a single column remains
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x0[j] = copy(x0j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
					}
				}
				k = k0.next();
			},
		}
	}
}
295
/// for every column of `rhs`, visits the columns `j` of `tril` from last to first and replaces
/// `x[j]` with `x[j] * (1 / d_j) - sum(l_ij * x[i])`, where `d_j` is the first stored element
/// of column `j` of `tril`, the sum runs over the remaining stored elements `(i, l_ij)` of that
/// column, and `conj_tril` is applied to `1 / d_j` and to every `l_ij`
///
/// `par` is ignored.
///
/// NOTE(review): going by the name, `tril` presumably stores the diagonal `D` of an LDL^T
/// factorization on the diagonal of the unit lower factor `L`, so that this computes
/// `L^-T * D^-1 * rhs` - confirm against the callers.
///
/// # panics
/// panics if `tril` is not square, if `rhs.nrows() != tril.nrows()`, or if a column of `tril`
/// has no stored element while `rhs` has at least one column.
#[track_caller]
#[math]
pub(crate) fn ldlt_scale_solve_unit_lower_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// the parallelism argument is discarded: everything below runs sequentially
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	// `N` is the row count of `rhs`, `K` its column count
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first column of `rhs` that has not been processed yet; columns are handled
	// in groups of up to four per sweep over `tril`
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		// as used here, each of these is `Some` only while that column index is still inside `K`
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// columns `k`, `k + 1`, `k + 2` and `k + 3` all exist: process the four together
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				// the sub-view covers four columns, so the `else` branch is not expected to run
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				// columns of `tril` in decreasing order
				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					// split off the first stored element (the diagonal); `li` and `lv` keep
					// the remaining elements of the column
					let first = li.next().zip(lv.next());

					// running sums of `l[i, j] * x[i]`, one per right-hand side
					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// the remaining elements are visited in reverse storage order
					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
						acc3a = acc3a + lij * x3[i];
					}

					// panics if column `j` has no stored element
					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					// reciprocal of the diagonal element, with `conj_tril` applied
					let d = conj_tril.apply_rt(&recip(*d));

					// scale row `j` by the reciprocal diagonal, then subtract the sum
					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
					x2[j] = x2[j] * d - acc2a;
					x3[j] = x3[j] * d - acc3a;
				}
				k = k3.next();
			},
			// exactly three columns remain: same steps as above on three columns
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
					x2[j] = x2[j] * d - acc2a;
				}

				k = k2.next();
			},
			// exactly two columns remain
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
				}

				k = k1.next();
			},
			// a single column remains
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
				}

				k = k0.next();
			},
		}
	}
}
443
/// backward-substitution kernel shared by [`solve_lower_triangular_transpose_in_place`] and
/// [`solve_unit_lower_triangular_transpose_in_place`]: overwrites `rhs` with the solution `x`
/// of `tril.transpose() * x = rhs`, applying `conj_tril` to every stored value of `tril` that
/// is read
///
/// `diag_tril` selects between dividing by the stored diagonal (`DiagStatus::Generic`) and
/// treating the diagonal as one (`DiagStatus::Unit`). `par` is ignored.
///
/// the first stored element of each column of `tril` is taken to be the diagonal element.
///
/// # panics
/// panics if `tril` is not square, if `rhs.nrows() != tril.nrows()`, or if a column of `tril`
/// has no stored element while `rhs` has at least one column.
#[track_caller]
#[math]
fn solve_lower_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	diag_tril: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// the parallelism argument is discarded: everything below runs sequentially
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	// `N` is the row count of `rhs`, `K` its column count
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first column of `rhs` that has not been solved yet; columns are handled in
	// groups of up to four per sweep over `tril`
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		// as used here, each of these is `Some` only while that column index is still inside `K`
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// columns `k`, `k + 1`, `k + 2` and `k + 3` all exist: solve the four together
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				// the sub-view covers four columns, so the `else` branch is not expected to run
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				// column `j` of `tril` is row `j` of its transpose; rows are solved last to first
				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					// split off the first stored element (the diagonal); `li` and `lv` keep
					// the remaining elements of the column
					let first = li.next().zip(lv.next());

					// running sums of `l[i, j] * x[i]`, one per right-hand side
					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// the remaining elements are visited in reverse storage order
					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
						acc3a = acc3a + lij * x3[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;
					let mut x3j = x3[j] - acc3a;

					// panics if column `j` has no stored element
					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						// implicit unit diagonal: nothing to divide by, `d` is not read
						DiagStatus::Unit => {},
						// divide by the diagonal, as a multiplication by its reciprocal
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
							x3j = x3j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
					x3[j] = x3j;
				}
				k = k3.next();
			},
			// exactly three columns remain: same steps as above on three columns
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
				}

				k = k2.next();
			},
			// exactly two columns remain
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
				}

				k = k1.next();
			},
			// a single column remains
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
					}

					let mut x0j = x0[j] - acc0a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
						},
					}

					x0[j] = x0j;
				}

				k = k0.next();
			},
		}
	}
}
636
/// backward-substitution kernel shared by [`solve_upper_triangular_in_place`] and
/// [`solve_unit_upper_triangular_in_place`]: overwrites `rhs` with the solution `x` of
/// `triu * x = rhs`, applying `conj_triu` to every stored value of `triu` that is read
///
/// `diag_triu` selects between dividing by the stored diagonal (`DiagStatus::Generic`) and
/// treating the diagonal as one (`DiagStatus::Unit`). `par` is ignored.
///
/// the last stored element of each column of `triu` is taken to be the diagonal element.
///
/// # panics
/// panics if `triu` is not square, if `rhs.nrows() != triu.nrows()`, or if a column of `triu`
/// has no stored element while `rhs` has at least one column.
#[track_caller]
#[math]
fn solve_upper_triangular_in_place_impl<I: Index, T: ComplexField>(
	triu: SparseColMatRef<'_, I, T>,
	conj_triu: Conj,
	diag_triu: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// the parallelism argument is discarded: everything below runs sequentially
	let _ = par;

	assert!(all(triu.nrows() == triu.ncols(), rhs.nrows() == triu.nrows()));
	// `N` is the row count of `rhs`, `K` its column count
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let u = triu.as_shape(N, N);

	// `k` is the first column of `rhs` that has not been solved yet; columns are handled in
	// groups of up to four per sweep over `triu`
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		// as used here, each of these is `Some` only while that column index is still inside `K`
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// columns `k`, `k + 1`, `k + 2` and `k + 3` all exist: solve the four together
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				// the sub-view covers four columns, so the `else` branch is not expected to run
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				// backward substitution: columns of `triu` in decreasing order
				for j in N.indices().rev() {
					// stored elements of column `j`, walked from the last one to the first
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					// last stored element of column `j`, required to be the diagonal; the
					// `unwrap` panics if the column has no stored element
					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					// `x0j..x3j` hold the finished values of row `j` of the solution
					let x0j;
					let x1j;
					let x2j;
					let x3j;
					match diag_triu {
						// implicit unit diagonal: row `j` is already final and `d` is not read
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
							x3j = copy(x3[j]);
						},
						// divide by the diagonal (as a multiplication by its reciprocal) and
						// write the finished value back
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x3j = x3[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
							x3[j] = copy(x3j);
						},
					}

					// remaining stored elements of column `j`: x[i] -= u[i, j] * x[j]
					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
						x2[i] = x2[i] - uij * x2j;
						x3[i] = x3[i] - uij * x3j;
					}
				}
				k = k3.next();
			},
			// exactly three columns remain: same steps as above on three columns
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					let x2j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
						x2[i] = x2[i] - uij * x2j;
					}
				}
				k = k2.next();
			},
			// exactly two columns remain
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
					}
				}
				k = k1.next();
			},
			// a single column remains
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x0[j] = copy(x0j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
					}
				}
				k = k0.next();
			},
		}
	}
}
814
/// forward-substitution kernel shared by [`solve_upper_triangular_transpose_in_place`] and
/// [`solve_unit_upper_triangular_transpose_in_place`]: overwrites `rhs` with the solution `x`
/// of `triu.transpose() * x = rhs`, applying `conj_triu` to every stored value of `triu` that
/// is read
///
/// `diag_triu` selects between dividing by the stored diagonal (`DiagStatus::Generic`) and
/// treating the diagonal as one (`DiagStatus::Unit`). `par` is ignored.
///
/// the last stored element of each column of `triu` is taken to be the diagonal element.
///
/// # panics
/// panics if `triu` is not square, if `rhs.nrows() != triu.nrows()`, or if a column of `triu`
/// has no stored element while `rhs` has at least one column.
#[track_caller]
#[math]
fn solve_upper_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	triu: SparseColMatRef<'_, I, T>,
	conj_triu: Conj,
	diag_triu: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// the parallelism argument is discarded: everything below runs sequentially
	let _ = par;
	assert!(all(triu.nrows() == triu.ncols(), rhs.nrows() == triu.nrows()));

	// `N` is the row count of `rhs`, `K` its column count
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let u = triu.as_shape(N, N);

	// `k` is the first column of `rhs` that has not been solved yet; columns are handled in
	// groups of up to four per sweep over `triu`
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		// as used here, each of these is `Some` only while that column index is still inside `K`
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// columns `k`, `k + 1`, `k + 2` and `k + 3` all exist: solve the four together
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				// the sub-view covers four columns, so the `else` branch is not expected to run
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				// column `j` of `triu` is row `j` of its transpose; rows are solved first to last
				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					// split off the last stored element (the diagonal); `ui` and `uv` keep
					// the preceding elements of the column
					let first = ui.next_back().zip(uv.next_back());

					// running sums of `u[i, j] * x[i]`, one per right-hand side
					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// the preceding elements are visited in storage order
					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
						acc2a = acc2a + uij * x2[i];
						acc3a = acc3a + uij * x3[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;
					let mut x3j = x3[j] - acc3a;

					// panics if column `j` has no stored element
					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						// implicit unit diagonal: nothing to divide by, `d` is not read
						DiagStatus::Unit => {},
						// divide by the diagonal, as a multiplication by its reciprocal
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
							x3j = x3j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
					x3[j] = x3j;
				}
				k = k3.next();
			},
			// exactly three columns remain: same steps as above on three columns
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
						acc2a = acc2a + uij * x2[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
				}

				k = k2.next();
			},
			// exactly two columns remain
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
				}

				k = k1.next();
			},
			// a single column remains
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
					}

					let mut x0j = x0[j] - acc0a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
						},
					}

					x0[j] = x0j;
				}

				k = k0.next();
			},
		}
	}
}