1use crate::internal_prelude_sp::*;
2use crate::{assert, debug_assert};
3
4#[track_caller]
11pub fn solve_lower_triangular_in_place<I: Index, T: ComplexField>(tril: SparseColMatRef<'_, I, T>, conj_tril: Conj, rhs: MatMut<'_, T>, par: Par) {
12 solve_lower_triangular_in_place_impl(tril, conj_tril, DiagStatus::Generic, rhs, par)
13}
14
15#[track_caller]
22pub fn solve_unit_lower_triangular_in_place<I: Index, T: ComplexField>(
23 tril: SparseColMatRef<'_, I, T>,
24 conj_tril: Conj,
25 rhs: MatMut<'_, T>,
26 par: Par,
27) {
28 solve_lower_triangular_in_place_impl(tril, conj_tril, DiagStatus::Unit, rhs, par)
29}
30
31#[track_caller]
38pub fn solve_lower_triangular_transpose_in_place<I: Index, T: ComplexField>(
39 tril: SparseColMatRef<'_, I, T>,
40 conj_tril: Conj,
41 rhs: MatMut<'_, T>,
42 par: Par,
43) {
44 solve_lower_triangular_transpose_in_place_impl(tril, conj_tril, DiagStatus::Generic, rhs, par)
45}
46
47#[track_caller]
54pub fn solve_unit_lower_triangular_transpose_in_place<I: Index, T: ComplexField>(
55 tril: SparseColMatRef<'_, I, T>,
56 conj_tril: Conj,
57 rhs: MatMut<'_, T>,
58 par: Par,
59) {
60 solve_lower_triangular_transpose_in_place_impl(tril, conj_tril, DiagStatus::Unit, rhs, par)
61}
62
63#[track_caller]
70pub fn solve_upper_triangular_in_place<I: Index, T: ComplexField>(triu: SparseColMatRef<'_, I, T>, conj_triu: Conj, rhs: MatMut<'_, T>, par: Par) {
71 solve_upper_triangular_in_place_impl(triu, conj_triu, DiagStatus::Generic, rhs, par)
72}
73
74#[track_caller]
81pub fn solve_unit_upper_triangular_in_place<I: Index, T: ComplexField>(
82 triu: SparseColMatRef<'_, I, T>,
83 conj_triu: Conj,
84 rhs: MatMut<'_, T>,
85 par: Par,
86) {
87 solve_upper_triangular_in_place_impl(triu, conj_triu, DiagStatus::Unit, rhs, par)
88}
89
90#[track_caller]
97pub fn solve_upper_triangular_transpose_in_place<I: Index, T: ComplexField>(
98 triu: SparseColMatRef<'_, I, T>,
99 conj_triu: Conj,
100 rhs: MatMut<'_, T>,
101 par: Par,
102) {
103 solve_upper_triangular_transpose_in_place_impl(triu, conj_triu, DiagStatus::Generic, rhs, par)
104}
105
106#[track_caller]
113pub fn solve_unit_upper_triangular_transpose_in_place<I: Index, T: ComplexField>(
114 triu: SparseColMatRef<'_, I, T>,
115 conj_triu: Conj,
116 rhs: MatMut<'_, T>,
117 par: Par,
118) {
119 solve_upper_triangular_transpose_in_place_impl(triu, conj_triu, DiagStatus::Unit, rhs, par)
120}
121
/// Shared kernel for the lower triangular solves `op(L) x = rhs`, where `op`
/// is the identity or elementwise conjugation (selected by `conj_tril`).
///
/// This is column-oriented forward substitution. Columns `j` of `L` are
/// visited in increasing order. For each one, `x[j]` is finalized: it is
/// multiplied by `op(l_jj)^{-1}` unless `diag_tril` is `DiagStatus::Unit`.
/// Then `x[j]` is eliminated from every row `i` stored after the diagonal in
/// column `j`, via `x[i] -= op(l_ij) * x[j]`.
///
/// The first stored entry of every column must be the diagonal. This is
/// checked only by `debug_assert!`. The entry is consumed unconditionally, so
/// it must be present even for unit-diagonal solves. An empty column makes the
/// `unwrap` below panic.
///
/// Right-hand side columns are processed in groups of four, three, two or one.
/// This lets a single traversal of `L`'s sparsity structure update up to four
/// columns of `rhs`.
#[track_caller]
#[math]
fn solve_lower_triangular_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	diag_tril: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// `par` is currently ignored: the solve runs sequentially.
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first right-hand side column that has not been solved yet.
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// At least four columns remain: solve columns k..k+4 together.
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					// The first stored entry of column `j` is its diagonal.
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					// Finalize x[j]; the copies are reused in the elimination loop below.
					let x0j;
					let x1j;
					let x2j;
					let x3j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
							x3j = copy(x3[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x3j = x3[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
							x3[j] = copy(x3j);
						},
					}

					// Eliminate x[j] from the rows stored below the diagonal.
					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
						x2[i] = x2[i] - lij * x2j;
						x3[i] = x3[i] - lij * x3j;
					}
				}
				k = k3.next();
			},
			// Exactly three columns remain.
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					let x2j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
						x2[i] = x2[i] - lij * x2j;
					}
				}
				k = k2.next();
			},
			// Exactly two columns remain.
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
						x1[i] = x1[i] - lij * x1j;
					}
				}
				k = k1.next();
			},
			// A single column remains.
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices() {
					let mut l = iter::zip(l.row_idx_of_col(j), l.val_of_col(j));
					let (i, d) = l.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					match diag_tril {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
						},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x0[j] = copy(x0j);
						},
					}

					for (i, lij) in l {
						let lij = conj_tril.apply_rt(lij);
						x0[i] = x0[i] - lij * x0j;
					}
				}
				k = k0.next();
			},
		}
	}
}
295
/// Fused diagonal scaling and unit lower triangular transpose solve.
///
/// `tril` stores a unit lower triangular factor `L` in its strictly lower
/// part and a diagonal `D` on its diagonal. This solves
/// `op(L)^T x = op(D)^{-1} rhs` in place. Here `L` has an implicit unit
/// diagonal, and `op` conjugates the stored values when `conj_tril` requests
/// it.
///
/// NOTE(review): the name suggests this is the `L^T` back-substitution step of
/// an LDL^T solve, with the `D^{-1}` scaling folded in; confirm against the
/// callers.
///
/// This is backward substitution in dot-product form. Columns `j` are visited
/// in decreasing order, and each step computes
/// `x[j] = rhs[j] * op(d_j)^{-1} - sum_i op(l_ij) * x[i]`.
/// The sum runs over the entries stored after the diagonal in column `j`,
/// accumulated in reverse storage order.
///
/// The first stored entry of every column must be the diagonal. This is
/// checked only by `debug_assert!`. Right-hand side columns are processed in
/// groups of four, three, two or one, so a single traversal of `L` updates
/// several columns.
#[track_caller]
#[math]
pub(crate) fn ldlt_scale_solve_unit_lower_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// `par` is currently ignored: the solve runs sequentially.
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first right-hand side column that has not been solved yet.
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// At least four columns remain: solve columns k..k+4 together.
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					// Split off the first stored entry (the diagonal, holding d_j);
					// the rest are the off-diagonal entries of column `j`.
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// Dot product of the off-diagonal part of column `j` with the
					// already-solved components x[i].
					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
						acc3a = acc3a + lij * x3[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					// Scale by D^{-1} first, then subtract the unit-L^T contribution.
					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
					x2[j] = x2[j] * d - acc2a;
					x3[j] = x3[j] * d - acc3a;
				}
				k = k3.next();
			},
			// Exactly three columns remain.
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
					x2[j] = x2[j] * d - acc2a;
				}

				k = k2.next();
			},
			// Exactly two columns remain.
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
					x1[j] = x1[j] * d - acc1a;
				}

				k = k1.next();
			},
			// A single column remains.
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
					}

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					let d = conj_tril.apply_rt(&recip(*d));

					x0[j] = x0[j] * d - acc0a;
				}

				k = k0.next();
			},
		}
	}
}
443
/// Shared kernel for the lower triangular transpose solves `op(L)^T x = rhs`.
/// Here `op` is the identity or elementwise conjugation (selected by
/// `conj_tril`), so the system matrix is `L^T` or `L^H`.
///
/// This is backward substitution in dot-product form. Columns `j` of `L`
/// (rows of `L^T`) are visited in decreasing order, and each step computes
/// `x[j] = (rhs[j] - sum_i op(l_ij) * x[i]) * op(l_jj)^{-1}`.
/// The sum runs over the entries stored after the diagonal in column `j`,
/// accumulated in reverse storage order. The diagonal scaling is skipped when
/// `diag_tril` is `DiagStatus::Unit`.
///
/// The first stored entry of every column must be the diagonal. This is
/// checked only by `debug_assert!`. The entry is required even for
/// unit-diagonal solves, because it is always split off before the dot
/// product. Right-hand side columns are processed in groups of four, three,
/// two or one, so a single traversal of `L` updates several columns.
#[track_caller]
#[math]
fn solve_lower_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	tril: SparseColMatRef<'_, I, T>,
	conj_tril: Conj,
	diag_tril: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// `par` is currently ignored: the solve runs sequentially.
	let _ = par;
	assert!(all(tril.nrows() == tril.ncols(), rhs.nrows() == tril.nrows()));

	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let l = tril.as_shape(N, N);

	// `k` is the first right-hand side column that has not been solved yet.
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// At least four columns remain: solve columns k..k+4 together.
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					// Split off the first stored entry (the diagonal); the rest are
					// the off-diagonal entries of column `j`.
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// Dot product with the already-solved components x[i].
					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
						acc3a = acc3a + lij * x3[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;
					let mut x3j = x3[j] - acc3a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
							x3j = x3j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
					x3[j] = x3j;
				}
				k = k3.next();
			},
			// Exactly three columns remain.
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
						acc2a = acc2a + lij * x2[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
				}

				k = k2.next();
			},
			// Exactly two columns remain.
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
						acc1a = acc1a + lij * x1[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
				}

				k = k1.next();
			},
			// A single column remains.
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut li = l.row_idx_of_col(j);
					let mut lv = l.val_of_col(j).iter();
					let first = li.next().zip(lv.next());

					let mut acc0a = zero::<T>();

					for (i, lij) in iter::zip(li.rev(), lv.rev()) {
						let lij = conj_tril.apply_rt(lij);
						acc0a = acc0a + lij * x0[i];
					}

					let mut x0j = x0[j] - acc0a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_tril {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_tril.apply_rt(&recip(*d));
							x0j = x0j * d;
						},
					}

					x0[j] = x0j;
				}

				k = k0.next();
			},
		}
	}
}
636
/// Shared kernel for the upper triangular solves `op(U) x = rhs`, where `op`
/// is the identity or elementwise conjugation (selected by `conj_triu`).
///
/// This is column-oriented backward substitution. Columns `j` of `U` are
/// visited in decreasing order. For each one, `x[j]` is finalized: it is
/// multiplied by `op(u_jj)^{-1}` unless `diag_triu` is `DiagStatus::Unit`.
/// Then `x[j]` is eliminated from every row `i` stored before the diagonal in
/// column `j`, via `x[i] -= op(u_ij) * x[j]`.
///
/// The *last* stored entry of every column must be the diagonal. This is
/// checked only by `debug_assert!`. The entry is consumed unconditionally, so
/// it must be present even for unit-diagonal solves. An empty column makes the
/// `unwrap` below panic. Right-hand side columns are processed in groups of
/// four, three, two or one, so a single traversal of `U` updates several
/// columns.
#[track_caller]
#[math]
fn solve_upper_triangular_in_place_impl<I: Index, T: ComplexField>(
	triu: SparseColMatRef<'_, I, T>,
	conj_triu: Conj,
	diag_triu: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// `par` is currently ignored: the solve runs sequentially.
	let _ = par;

	assert!(all(triu.nrows() == triu.ncols(), rhs.nrows() == triu.nrows()));
	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let u = triu.as_shape(N, N);

	// `k` is the first right-hand side column that has not been solved yet.
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// At least four columns remain: solve columns k..k+4 together.
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					// Walk column `j` back to front: the last stored entry is the diagonal.
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					// Finalize x[j]; the copies are reused in the elimination loop below.
					let x0j;
					let x1j;
					let x2j;
					let x3j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
							x3j = copy(x3[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x3j = x3[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
							x3[j] = copy(x3j);
						},
					}

					// Eliminate x[j] from the rows stored above the diagonal.
					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
						x2[i] = x2[i] - uij * x2j;
						x3[i] = x3[i] - uij * x3j;
					}
				}
				k = k3.next();
			},
			// Exactly three columns remain.
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					let x2j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
							x2j = copy(x2[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x2j = x2[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
							x2[j] = copy(x2j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
						x2[i] = x2[i] - uij * x2j;
					}
				}
				k = k2.next();
			},
			// Exactly two columns remain.
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					let x1j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
							x1j = copy(x1[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x1j = x1[j] * d;
							x0[j] = copy(x0j);
							x1[j] = copy(x1j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
						x1[i] = x1[i] - uij * x1j;
					}
				}
				k = k1.next();
			},
			// A single column remains.
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices().rev() {
					let mut u = iter::zip(u.row_idx_of_col(j).rev(), u.val_of_col(j).iter().rev());

					let (i, d) = u.next().unwrap();
					debug_assert!(i == j);

					let x0j;
					match diag_triu {
						DiagStatus::Unit => {
							x0j = copy(x0[j]);
						},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0[j] * d;
							x0[j] = copy(x0j);
						},
					}

					for (i, u) in u {
						let uij = conj_triu.apply_rt(u);
						x0[i] = x0[i] - uij * x0j;
					}
				}
				k = k0.next();
			},
		}
	}
}
814
/// Shared kernel for the upper triangular transpose solves `op(U)^T x = rhs`.
/// Here `op` is the identity or elementwise conjugation (selected by
/// `conj_triu`), so the system matrix is `U^T` or `U^H`.
///
/// This is forward substitution in dot-product form. Columns `j` of `U`
/// (rows of `U^T`) are visited in increasing order, and each step computes
/// `x[j] = (rhs[j] - sum_i op(u_ij) * x[i]) * op(u_jj)^{-1}`.
/// The sum runs over the entries stored before the diagonal in column `j`, in
/// storage order. The diagonal scaling is skipped when `diag_triu` is
/// `DiagStatus::Unit`.
///
/// The *last* stored entry of every column must be the diagonal. This is
/// checked only by `debug_assert!`. The entry is required even for
/// unit-diagonal solves, because it is always split off before the dot
/// product. Right-hand side columns are processed in groups of four, three,
/// two or one, so a single traversal of `U` updates several columns.
#[track_caller]
#[math]
fn solve_upper_triangular_transpose_in_place_impl<I: Index, T: ComplexField>(
	triu: SparseColMatRef<'_, I, T>,
	conj_triu: Conj,
	diag_triu: DiagStatus,
	rhs: MatMut<'_, T>,
	par: Par,
) {
	// `par` is currently ignored: the solve runs sequentially.
	let _ = par;
	assert!(all(triu.nrows() == triu.ncols(), rhs.nrows() == triu.nrows()));

	with_dim!(N, rhs.nrows());
	with_dim!(K, rhs.ncols());

	let mut x = rhs.as_shape_mut(N, K);
	let u = triu.as_shape(N, N);

	// `k` is the first right-hand side column that has not been solved yet.
	let mut k = IdxInc::ZERO;
	while let Some(k0) = K.try_check(*k) {
		let k1 = K.try_check(*k + 1);
		let k2 = K.try_check(*k + 2);
		let k3 = K.try_check(*k + 3);

		match (k1, k2, k3) {
			// At least four columns remain: solve columns k..k+4 together.
			(Some(_), Some(_), Some(k3)) => {
				let mut x = x.rb_mut().get_mut(.., k..k3.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2), Some(mut x3)) = (x.next(), x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					// Split off the last stored entry (the diagonal); the rest are
					// the off-diagonal entries of column `j`.
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();
					let mut acc3a = zero::<T>();

					// Dot product with the already-solved components x[i].
					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
						acc2a = acc2a + uij * x2[i];
						acc3a = acc3a + uij * x3[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;
					let mut x3j = x3[j] - acc3a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
							x3j = x3j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
					x3[j] = x3j;
				}
				k = k3.next();
			},
			// Exactly three columns remain.
			(Some(_), Some(k2), _) => {
				let mut x = x.rb_mut().get_mut(.., k..k2.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1), Some(mut x2)) = (x.next(), x.next(), x.next()) else {
					panic!()
				};

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();
					let mut acc2a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
						acc2a = acc2a + uij * x2[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;
					let mut x2j = x2[j] - acc2a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
							x2j = x2j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
					x2[j] = x2j;
				}

				k = k2.next();
			},
			// Exactly two columns remain.
			(Some(k1), _, _) => {
				let mut x = x.rb_mut().get_mut(.., k..k1.next()).col_iter_mut();
				let (Some(mut x0), Some(mut x1)) = (x.next(), x.next()) else { panic!() };

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();
					let mut acc1a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
						acc1a = acc1a + uij * x1[i];
					}

					let mut x0j = x0[j] - acc0a;
					let mut x1j = x1[j] - acc1a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
							x1j = x1j * d;
						},
					}

					x0[j] = x0j;
					x1[j] = x1j;
				}

				k = k1.next();
			},
			// A single column remains.
			(_, _, _) => {
				let mut x0 = x.rb_mut().get_mut(.., k0);

				for j in N.indices() {
					let mut ui = u.row_idx_of_col(j);
					let mut uv = u.val_of_col(j).iter();
					let first = ui.next_back().zip(uv.next_back());

					let mut acc0a = zero::<T>();

					for (i, uij) in iter::zip(ui, uv) {
						let uij = conj_triu.apply_rt(uij);
						acc0a = acc0a + uij * x0[i];
					}

					let mut x0j = x0[j] - acc0a;

					let (i, d) = first.unwrap();
					debug_assert!(i == j);
					match diag_triu {
						DiagStatus::Unit => {},
						DiagStatus::Generic => {
							let d = conj_triu.apply_rt(&recip(*d));
							x0j = x0j * d;
						},
					}

					x0[j] = x0j;
				}

				k = k0.next();
			},
		}
	}
}