1use crate::assert;
11use crate::internal_prelude_sp::*;
12use crate::sparse::utils;
13use linalg::qr::no_pivoting::factor::QrParams;
14use linalg_sp::cholesky::ghost_postorder;
15use linalg_sp::cholesky::simplicial::EliminationTreeRef;
16use linalg_sp::{SupernodalThreshold, SymbolicSupernodalParams, colamd, ghost};
17
18#[inline]
19pub(crate) fn ghost_col_etree<'n, I: Index>(
20 A: SymbolicSparseColMatRef<'_, I, Dim<'_>, Dim<'n>>,
21 col_perm: Option<PermRef<'_, I, Dim<'n>>>,
22 etree: &mut Array<'n, I::Signed>,
23 stack: &mut MemStack,
24) {
25 let I = I::truncate;
26
27 let N = A.ncols();
28 let M = A.nrows();
29
30 let (ancestor, stack) = unsafe { stack.make_raw::<I::Signed>(*N) };
31 let (prev, _) = unsafe { stack.make_raw::<I::Signed>(*M) };
32
33 let ancestor = Array::from_mut(ghost::fill_none::<I>(ancestor, N), N);
34 let prev = Array::from_mut(ghost::fill_none::<I>(prev, N), M);
35
36 etree.as_mut().fill(I::Signed::truncate(NONE));
37 for j in N.indices() {
38 let pj = col_perm.map(|perm| perm.bound_arrays().0[j].zx()).unwrap_or(j);
39 for i_ in A.row_idx_of_col(pj) {
40 let mut i = prev[i_].sx();
41 while let Some(i_) = i.idx() {
42 if i_ == j {
43 break;
44 }
45 let next_i = ancestor[i_];
46 ancestor[i_] = MaybeIdx::from_index(j.truncate());
47 if next_i.idx().is_none() {
48 etree[i_] = I(*j).to_signed();
49 break;
50 }
51 i = next_i.sx();
52 }
53 prev[i_] = MaybeIdx::from_index(j.truncate());
54 }
55 }
56}
57
58#[inline]
61pub fn col_etree_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
62 StackReq::all_of(&[StackReq::new::<I>(nrows), StackReq::new::<I>(ncols)])
63}
64
65#[inline]
70pub fn col_etree<'out, I: Index>(
71 A: SymbolicSparseColMatRef<'_, I>,
72 col_perm: Option<PermRef<'_, I>>,
73 etree: &'out mut [I],
74 stack: &mut MemStack,
75) -> EliminationTreeRef<'out, I> {
76 with_dim!(M, A.nrows());
77 with_dim!(N, A.ncols());
78 ghost_col_etree(
79 A.as_shape(M, N),
80 col_perm.map(|perm| perm.as_shape(N)),
81 Array::from_mut(bytemuck::cast_slice_mut(etree), N),
82 stack,
83 );
84
85 EliminationTreeRef {
86 inner: bytemuck::cast_slice_mut(etree),
87 }
88}
89
90pub(crate) fn ghost_least_common_ancestor<'n, I: Index>(
91 i: Idx<'n, usize>,
92 j: Idx<'n, usize>,
93 first: &Array<'n, MaybeIdx<'n, I>>,
94 max_first: &mut Array<'n, MaybeIdx<'n, I>>,
95 prev_leaf: &mut Array<'n, MaybeIdx<'n, I>>,
96 ancestor: &mut Array<'n, Idx<'n, I>>,
97) -> isize {
98 if i <= j || *first[j] <= *max_first[i] {
99 return -2;
100 }
101
102 max_first[i] = first[j];
103 let j_prev = prev_leaf[i].sx();
104 prev_leaf[i] = MaybeIdx::from_index(j.truncate());
105 let Some(j_prev) = j_prev.idx() else {
106 return -1;
107 };
108 let mut lca = j_prev;
109 while lca != ancestor[lca].zx() {
110 lca = ancestor[lca].zx();
111 }
112
113 let mut node = j_prev;
114 while node != lca {
115 let next = ancestor[node].zx();
116 ancestor[node] = lca.truncate();
117 node = next;
118 }
119
120 *lca as isize
121}
122
123pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>(
124 col_counts: &mut Array<'m, I>,
125 min_row: &mut Array<'n, I::Signed>,
126 A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>,
127 row_perm: Option<PermRef<'_, I, Dim<'m>>>,
128 etree: &Array<'m, MaybeIdx<'m, I>>,
129 post: &Array<'m, Idx<'m, I>>,
130 stack: &mut MemStack,
131) {
132 let M: Dim<'m> = A.nrows();
133 let N: Dim<'n> = A.ncols();
134 let n = *N;
135 let m = *M;
136
137 let delta = col_counts;
138 let (first, stack) = unsafe { stack.make_raw::<I::Signed>(m) };
139 let (max_first, stack) = unsafe { stack.make_raw::<I::Signed>(m) };
140 let (prev_leaf, stack) = unsafe { stack.make_raw::<I::Signed>(m) };
141 let (ancestor, stack) = unsafe { stack.make_raw::<I>(m) };
142 let (next, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
143 let (head, _) = unsafe { stack.make_raw::<I::Signed>(m) };
144
145 let post_inv = &mut *first;
146 let post_inv = Array::from_mut(ghost::fill_zero::<I>(bytemuck::cast_slice_mut(post_inv), M), M);
147 for j in M.indices() {
148 post_inv[post[j].zx()] = j.truncate();
149 }
150 let next = Array::from_mut(ghost::fill_none::<I>(next, N), N);
151 let head = Array::from_mut(ghost::fill_none::<I>(head, N), M);
152
153 for j in N.indices() {
154 if let Some(perm) = row_perm {
155 let inv = perm.bound_arrays().1;
156 min_row[j] = match Iterator::min(A.row_idx_of_col(j).map(|j| inv[j].zx())) {
157 Some(first_row) => I::Signed::truncate(*first_row),
158 None => *MaybeIdx::<'_, I>::none(),
159 };
160 } else {
161 min_row[j] = match Iterator::min(A.row_idx_of_col(j)) {
162 Some(first_row) => I::Signed::truncate(*first_row),
163 None => *MaybeIdx::<'_, I>::none(),
164 };
165 }
166
167 let min_row = if let Some(perm) = row_perm {
168 let inv = perm.bound_arrays().1;
169 Iterator::min(A.row_idx_of_col(j).map(|row| post_inv[inv[row].zx()]))
170 } else {
171 Iterator::min(A.row_idx_of_col(j).map(|row| post_inv[row]))
172 };
173 if let Some(min_row) = min_row {
174 let min_row = min_row.zx();
175 let head = &mut head[min_row];
176 next[j] = *head;
177 *head = MaybeIdx::from_index(j.truncate());
178 };
179 }
180
181 let first = Array::from_mut(ghost::fill_none::<I>(first, M), M);
182 let max_first = Array::from_mut(ghost::fill_none::<I>(max_first, M), M);
183 let prev_leaf = Array::from_mut(ghost::fill_none::<I>(prev_leaf, M), M);
184 for (i, p) in ancestor.iter_mut().enumerate() {
185 *p = I::truncate(i);
186 }
187 let ancestor = Array::from_mut(unsafe { Idx::from_slice_mut_unchecked(ancestor) }, M);
188
189 let incr = |i: &mut I| {
190 *i = I::from_signed((*i).to_signed() + I::Signed::truncate(1));
191 };
192 let decr = |i: &mut I| {
193 *i = I::from_signed((*i).to_signed() - I::Signed::truncate(1));
194 };
195
196 for k in M.indices() {
197 let mut pk = post[k].zx();
198 delta[pk] = I::truncate(if first[pk].idx().is_none() { 1 } else { 0 });
199 loop {
200 if first[pk].idx().is_some() {
201 break;
202 }
203
204 first[pk] = MaybeIdx::from_index(k.truncate());
205 if let Some(parent) = etree[pk].idx() {
206 pk = parent.zx();
207 } else {
208 break;
209 }
210 }
211 }
212
213 for k in M.indices() {
214 let pk = post[k].zx();
215
216 if let Some(parent) = etree[pk].idx() {
217 decr(&mut delta[parent.zx()]);
218 }
219
220 let head_k = &mut head[k];
221 let mut j = (*head_k).sx();
222 *head_k = MaybeIdx::none();
223
224 while let Some(j_) = j.idx() {
225 for i in A.row_idx_of_col(j_) {
226 let i = row_perm.map(|perm| perm.bound_arrays().1[i].zx()).unwrap_or(i);
227 let lca = ghost_least_common_ancestor::<I>(i, pk, first, max_first, prev_leaf, ancestor);
228
229 if lca != -2 {
230 incr(&mut delta[pk]);
231
232 if lca != -1 {
233 decr(&mut delta[M.check(lca as usize)]);
234 }
235 }
236 }
237 j = next[j_].sx();
238 }
239 if let Some(parent) = etree[pk].idx() {
240 ancestor[pk] = parent;
241 }
242 }
243
244 for k in M.indices() {
245 if let Some(parent) = etree[k].idx() {
246 let parent = parent.zx();
247 delta[parent] = I::from_signed(delta[parent].to_signed() + delta[k].to_signed());
248 }
249 }
250}
251
252#[inline]
255pub fn column_counts_aat_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
256 StackReq::all_of(&[StackReq::new::<I>(nrows).array(5), StackReq::new::<I>(ncols)])
257}
258
259pub fn column_counts_ata<'m, 'n, I: Index>(
270 col_counts: &mut [I],
271 min_col: &mut [I],
272 AT: SymbolicSparseColMatRef<'_, I>,
273 col_perm: Option<PermRef<'_, I>>,
274 etree: EliminationTreeRef<'_, I>,
275 post: &[I],
276 stack: &mut MemStack,
277) {
278 with_dim!(M, AT.nrows());
279 with_dim!(N, AT.ncols());
280
281 let A = AT.as_shape(M, N);
282 ghost_column_counts_aat(
283 Array::from_mut(col_counts, M),
284 Array::from_mut(bytemuck::cast_slice_mut(min_col), N),
285 A,
286 col_perm.map(|perm| perm.as_shape(M)),
287 etree.as_bound(M),
288 Array::from_ref(Idx::from_slice_ref_checked(post, M), M),
289 stack,
290 )
291}
292
293#[inline]
296pub fn postorder_scratch<I: Index>(n: usize) -> StackReq {
297 StackReq::new::<I>(n).array(3)
298}
299
300#[inline]
302pub fn postorder<I: Index>(post: &mut [I], etree: EliminationTreeRef<'_, I>, stack: &mut MemStack) {
303 with_dim!(N, etree.inner.len());
304 ghost_postorder(Array::from_mut(post, N), etree.as_bound(N), stack)
305}
306
307pub mod supernodal {
313 use super::*;
314 use crate::assert;
315 use linalg_sp::cholesky::supernodal::{SupernodalLltRef, SymbolicSupernodalCholesky};
316
317 #[derive(Debug)]
323 pub struct SymbolicSupernodalHouseholder<I> {
324 col_ptr_for_row_idx: alloc::vec::Vec<I>,
325 col_ptr_for_tau_val: alloc::vec::Vec<I>,
326 col_ptr_for_val: alloc::vec::Vec<I>,
327 super_etree: alloc::vec::Vec<I>,
328 max_blocksize: alloc::vec::Vec<I>,
329 nrows: usize,
330 }
331
332 impl<I: Index> SymbolicSupernodalHouseholder<I> {
333 #[inline]
335 pub fn nrows(&self) -> usize {
336 self.nrows
337 }
338
339 #[inline]
341 pub fn n_supernodes(&self) -> usize {
342 self.super_etree.len()
343 }
344
345 #[inline]
347 pub fn col_ptr_for_householder_val(&self) -> &[I] {
348 self.col_ptr_for_val.as_ref()
349 }
350
351 #[inline]
353 pub fn col_ptr_for_tau_val(&self) -> &[I] {
354 self.col_ptr_for_tau_val.as_ref()
355 }
356
357 #[inline]
359 pub fn col_ptr_for_householder_row_idx(&self) -> &[I] {
360 self.col_ptr_for_row_idx.as_ref()
361 }
362
363 #[inline]
366 pub fn len_householder_val(&self) -> usize {
367 self.col_ptr_for_householder_val()[self.n_supernodes()].zx()
368 }
369
370 #[inline]
373 pub fn len_householder_row_idx(&self) -> usize {
374 self.col_ptr_for_householder_row_idx()[self.n_supernodes()].zx()
375 }
376
377 #[inline]
380 pub fn len_tau_val(&self) -> usize {
381 self.col_ptr_for_tau_val()[self.n_supernodes()].zx()
382 }
383 }
384 #[derive(Debug)]
386 pub struct SymbolicSupernodalQr<I> {
387 L: SymbolicSupernodalCholesky<I>,
388 H: SymbolicSupernodalHouseholder<I>,
389 min_col: alloc::vec::Vec<I>,
390 min_col_perm: alloc::vec::Vec<I>,
391 index_to_super: alloc::vec::Vec<I>,
392 child_head: alloc::vec::Vec<I>,
393 child_next: alloc::vec::Vec<I>,
394 }
395
396 impl<I: Index> SymbolicSupernodalQr<I> {
397 #[inline]
399 pub fn R_adjoint(&self) -> &SymbolicSupernodalCholesky<I> {
400 &self.L
401 }
402
403 #[inline]
405 pub fn householder(&self) -> &SymbolicSupernodalHouseholder<I> {
406 &self.H
407 }
408
409 pub fn solve_in_place_scratch<T: ComplexField>(&self, rhs_ncols: usize, par: Par) -> StackReq {
412 let _ = par;
413 let L_symbolic = self.R_adjoint();
414 let H_symbolic = self.householder();
415 let n_supernodes = L_symbolic.n_supernodes();
416
417 let mut loop_scratch = StackReq::empty();
418 for s in 0..n_supernodes {
419 let s_h_row_begin = H_symbolic.col_ptr_for_row_idx[s].zx();
420 let s_h_row_full_end = H_symbolic.col_ptr_for_row_idx[s + 1].zx();
421 let max_blocksize = H_symbolic.max_blocksize[s].zx();
422
423 loop_scratch = loop_scratch.or(
424 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(
425 s_h_row_full_end - s_h_row_begin,
426 max_blocksize,
427 rhs_ncols,
428 ),
429 );
430 }
431
432 loop_scratch
433 }
434 }
435
436 pub fn factorize_supernodal_symbolic_qr_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
439 let _ = nrows;
440 linalg_sp::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_scratch::<I>(ncols)
441 }
442
443 pub fn factorize_supernodal_symbolic_qr<I: Index>(
446 A: SymbolicSparseColMatRef<'_, I>,
447 col_perm: Option<PermRef<'_, I>>,
448 min_col: alloc::vec::Vec<I>,
449 etree: EliminationTreeRef<'_, I>,
450 col_counts: &[I],
451 stack: &mut MemStack,
452 params: SymbolicSupernodalParams<'_>,
453 ) -> Result<SymbolicSupernodalQr<I>, FaerError> {
454 let m = A.nrows();
455 let n = A.ncols();
456
457 with_dim!(M, m);
458 with_dim!(N, n);
459 let A = A.as_shape(M, N);
460 let mut stack = stack;
461 let (L, H) = {
462 let etree = etree.as_bound(N);
463 let min_col = Array::from_ref(MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(&min_col), N), M);
464 let L = linalg_sp::cholesky::supernodal::ghost_factorize_supernodal_symbolic(
465 A,
466 col_perm.map(|perm| perm.as_shape(N)),
467 Some(min_col),
468 linalg_sp::cholesky::supernodal::CholeskyInput::ATA,
469 etree,
470 Array::from_ref(col_counts, N),
471 stack.rb_mut(),
472 params,
473 )?;
474
475 let H = ghost_factorize_supernodal_householder_symbolic(&L, M, N, min_col, etree, stack)?;
476
477 (L, H)
478 };
479 let n_supernodes = L.n_supernodes();
480
481 let mut min_col_perm = try_zeroed::<I>(m)?;
482 let mut index_to_super = try_zeroed::<I>(n)?;
483 let mut child_head = try_zeroed::<I>(n_supernodes)?;
484 let mut child_next = try_zeroed::<I>(n_supernodes)?;
485 for i in 0..m {
486 min_col_perm[i] = I::truncate(i);
487 }
488 min_col_perm.sort_unstable_by_key(|i| min_col[i.zx()]);
489 for s in 0..n_supernodes {
490 index_to_super[L.supernode_begin()[s].zx()..L.supernode_end()[s].zx()].fill(I::truncate(s));
491 }
492
493 child_head.fill(I::truncate(NONE));
494 child_next.fill(I::truncate(NONE));
495
496 for s in 0..n_supernodes {
497 let parent = H.super_etree[s];
498 if parent.to_signed() >= I::Signed::truncate(0) {
499 let parent = parent.zx();
500 let head = child_head[parent];
501 child_next[s] = head;
502 child_head[parent] = I::truncate(s);
503 }
504 }
505
506 Ok(SymbolicSupernodalQr {
507 L,
508 H,
509 min_col,
510 min_col_perm,
511 index_to_super,
512 child_head,
513 child_next,
514 })
515 }
516
517 fn ghost_factorize_supernodal_householder_symbolic<'m, 'n, I: Index>(
518 L_symbolic: &SymbolicSupernodalCholesky<I>,
519 M: Dim<'m>,
520 N: Dim<'n>,
521 min_col: &Array<'m, MaybeIdx<'n, I>>,
522 etree: &Array<'n, MaybeIdx<'n, I>>,
523 stack: &mut MemStack,
524 ) -> Result<SymbolicSupernodalHouseholder<I>, FaerError> {
525 let n_supernodes = L_symbolic.n_supernodes();
526
527 with_dim!(N_SUPERNODES, n_supernodes);
528
529 let mut col_ptr_for_row_idx = try_zeroed::<I>(n_supernodes + 1)?;
530 let mut col_ptr_for_tau_val = try_zeroed::<I>(n_supernodes + 1)?;
531 let mut col_ptr_for_val = try_zeroed::<I>(n_supernodes + 1)?;
532 let mut super_etree_ = try_zeroed::<I>(n_supernodes)?;
533 let mut max_blocksize = try_zeroed::<I>(n_supernodes)?;
534 let super_etree = bytemuck::cast_slice_mut::<I, I::Signed>(&mut super_etree_);
535
536 let to_wide = |i: I| i.zx() as u128;
537 let from_wide = |i: u128| I::truncate(i as usize);
538 let from_wide_checked = |i: u128| -> Option<I> { (i <= to_wide(I::from_signed(I::Signed::MAX))).then_some(I::truncate(i as usize)) };
539
540 let supernode_begin = Array::from_ref(L_symbolic.supernode_begin(), N_SUPERNODES);
541 let supernode_end = Array::from_ref(L_symbolic.supernode_end(), N_SUPERNODES);
542 let L_col_ptr_for_row_idx = L_symbolic.col_ptr_for_row_idx();
543
544 let (index_to_super, _) = unsafe { stack.make_raw::<I>(*N) };
545
546 for s in N_SUPERNODES.indices() {
547 index_to_super[supernode_begin[s].zx()..supernode_end[s].zx()].fill(*s.truncate::<I>());
548 }
549 let index_to_super = Array::from_ref(Idx::from_slice_ref_checked(index_to_super, N_SUPERNODES), N);
550
551 let super_etree = Array::from_mut(super_etree, N_SUPERNODES);
552 for s in N_SUPERNODES.indices() {
553 let last = supernode_end[s].zx() - 1;
554 if let Some(parent) = etree[N.check(last)].idx() {
555 super_etree[s] = index_to_super[parent.zx()].to_signed();
556 } else {
557 super_etree[s] = I::Signed::truncate(NONE);
558 }
559 }
560 let super_etree = Array::from_ref(
561 MaybeIdx::<'_, I>::from_slice_ref_checked(super_etree.as_ref(), N_SUPERNODES),
562 N_SUPERNODES,
563 );
564
565 let non_zero_count = Array::from_mut(&mut col_ptr_for_row_idx[1..], N_SUPERNODES);
566 for i in M.indices() {
567 let Some(min_col) = min_col[i].idx() else {
568 continue;
569 };
570 non_zero_count[index_to_super[min_col.zx()].zx()] += I::truncate(1);
571 }
572
573 for s in N_SUPERNODES.indices() {
574 if let Some(parent) = super_etree[s].idx() {
575 let s_col_count = L_col_ptr_for_row_idx[*s + 1] - L_col_ptr_for_row_idx[*s];
576 let panel_width = supernode_end[s] - supernode_begin[s];
577
578 let s_count = non_zero_count[s];
579 non_zero_count[parent.zx()] += Ord::min(Ord::max(s_count, panel_width) - panel_width, s_col_count);
580 }
581 }
582
583 let mut val_count = to_wide(I::truncate(0));
584 let mut tau_count = to_wide(I::truncate(0));
585 let mut row_count = to_wide(I::truncate(0));
586 for (s, ((next_row_ptr, next_val_ptr), next_tau_ptr)) in iter::zip(
587 N_SUPERNODES.indices(),
588 iter::zip(
589 iter::zip(&mut col_ptr_for_row_idx[1..], &mut col_ptr_for_val[1..]),
590 &mut col_ptr_for_tau_val[1..],
591 ),
592 ) {
593 let panel_width = supernode_end[s] - supernode_begin[s];
594 let s_row_count = *next_row_ptr;
595 let s_col_count = panel_width + (L_col_ptr_for_row_idx[*s + 1] - L_col_ptr_for_row_idx[*s]);
596 val_count += to_wide(s_row_count) * to_wide(s_col_count);
597 row_count += to_wide(s_row_count);
598 let blocksize = linalg::qr::no_pivoting::factor::recommended_blocksize::<Symbolic>(s_row_count.zx(), s_col_count.zx()) as u128;
599 max_blocksize[*s] = from_wide(blocksize);
600 tau_count += blocksize * to_wide(Ord::min(s_row_count, s_col_count));
601 *next_val_ptr = from_wide(val_count);
602 *next_row_ptr = from_wide(row_count);
603 *next_tau_ptr = from_wide(tau_count);
604 }
605 from_wide_checked(row_count).ok_or(FaerError::IndexOverflow)?;
606 from_wide_checked(tau_count).ok_or(FaerError::IndexOverflow)?;
607 from_wide_checked(val_count).ok_or(FaerError::IndexOverflow)?;
608
609 Ok(SymbolicSupernodalHouseholder {
610 col_ptr_for_row_idx,
611 col_ptr_for_val,
612 super_etree: super_etree_,
613 col_ptr_for_tau_val,
614 max_blocksize,
615 nrows: *M,
616 })
617 }
618
619 #[derive(Debug)]
621 pub struct SupernodalQrRef<'a, I: Index, T> {
622 symbolic: &'a SymbolicSupernodalQr<I>,
623 rt_val: &'a [T],
624 householder_val: &'a [T],
625 tau_val: &'a [T],
626 householder_row_idx: &'a [I],
627 tau_blocksize: &'a [I],
628 householder_nrows: &'a [I],
629 householder_ncols: &'a [I],
630 }
631
632 impl<I: Index, T> Copy for SupernodalQrRef<'_, I, T> {}
633 impl<I: Index, T> Clone for SupernodalQrRef<'_, I, T> {
634 #[inline]
635 fn clone(&self) -> Self {
636 *self
637 }
638 }
639
640 impl<'a, I: Index, T> SupernodalQrRef<'a, I, T> {
641 #[inline]
646 pub unsafe fn new_unchecked(
647 symbolic: &'a SymbolicSupernodalQr<I>,
648 householder_row_idx: &'a [I],
649 tau_blocksize: &'a [I],
650 householder_nrows: &'a [I],
651 householder_ncols: &'a [I],
652 r_val: &'a [T],
653 householder_val: &'a [T],
654 tau_val: &'a [T],
655 ) -> Self {
656 let rt_val = r_val;
657 let householder_val = householder_val;
658 let tau_val = tau_val;
659 assert!(rt_val.len() == symbolic.R_adjoint().len_val());
660 assert!(tau_val.len() == symbolic.householder().len_tau_val());
661 assert!(householder_val.len() == symbolic.householder().len_householder_val());
662 assert!(tau_blocksize.len() == householder_nrows.len());
663 Self {
664 symbolic,
665 tau_blocksize,
666 householder_nrows,
667 householder_ncols,
668 rt_val,
669 householder_val,
670 tau_val,
671 householder_row_idx,
672 }
673 }
674
675 #[inline]
677 pub fn symbolic(self) -> &'a SymbolicSupernodalQr<I> {
678 self.symbolic
679 }
680
681 #[inline]
683 pub fn R_val(self) -> &'a [T] {
684 self.rt_val
685 }
686
687 #[inline]
689 pub fn householder_val(self) -> &'a [T] {
690 self.householder_val
691 }
692
693 #[inline]
695 pub fn tau_val(self) -> &'a [T] {
696 self.tau_val
697 }
698
699 #[math]
703 pub fn apply_Q_transpose_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, work: MatMut<'_, T>, stack: &mut MemStack)
704 where
705 T: ComplexField,
706 {
707 let L_symbolic = self.symbolic().R_adjoint();
708 let H_symbolic = self.symbolic().householder();
709 let n_supernodes = L_symbolic.n_supernodes();
710 let mut stack = stack;
711
712 assert!(rhs.nrows() == self.symbolic().householder().nrows);
713
714 let mut x = rhs;
715 let k = x.ncols();
716 let mut tmp = work;
717 tmp.fill(zero());
718
719 {
721 let H = self.householder_val;
722 let tau = self.tau_val;
723
724 let mut block_count = 0usize;
725 for s in 0..n_supernodes {
726 let tau_begin = H_symbolic.col_ptr_for_tau_val[s].zx();
727 let tau_end = H_symbolic.col_ptr_for_tau_val[s + 1].zx();
728
729 let s_h_row_begin = H_symbolic.col_ptr_for_row_idx[s].zx();
730 let s_h_row_full_end = H_symbolic.col_ptr_for_row_idx[s + 1].zx();
731
732 let s_col_begin = L_symbolic.supernode_begin()[s].zx();
733 let s_col_end = L_symbolic.supernode_end()[s].zx();
734 let s_ncols = s_col_end - s_col_begin;
735
736 let s_row_idx_in_panel = &self.householder_row_idx[s_h_row_begin..s_h_row_full_end];
737
738 let mut tmp = tmp.rb_mut().subrows_mut(s_col_begin, s_h_row_full_end - s_h_row_begin);
739 for j in 0..k {
740 for idx in 0..s_h_row_full_end - s_h_row_begin {
741 let i = s_row_idx_in_panel[idx].zx();
742 tmp[(idx, j)] = copy(x[(i, j)]);
743 }
744 }
745
746 let s_H = &H[H_symbolic.col_ptr_for_val[s].zx()..H_symbolic.col_ptr_for_val[s + 1].zx()];
747
748 let s_H = MatRef::from_column_major_slice(
749 s_H,
750 s_h_row_full_end - s_h_row_begin,
751 s_ncols + (L_symbolic.col_ptr_for_row_idx()[s + 1].zx() - L_symbolic.col_ptr_for_row_idx()[s].zx()),
752 );
753 let s_tau = &tau[tau_begin..tau_end];
754 let max_blocksize = H_symbolic.max_blocksize[s].zx();
755 let s_tau = MatRef::from_column_major_slice(s_tau, max_blocksize, Ord::min(s_H.ncols(), s_h_row_full_end - s_h_row_begin));
756
757 let mut start = 0;
758 let end = s_H.ncols();
759 while start < end {
760 let bs = self.tau_blocksize[block_count].zx();
761 let nrows = self.householder_nrows[block_count].zx();
762 let ncols = self.householder_ncols[block_count].zx();
763
764 let b_H = s_H.submatrix(start, start, nrows, ncols);
765 let b_tau = s_tau.subcols(start, ncols).subrows(0, bs);
766
767 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
768 b_H.rb(),
769 b_tau.rb(),
770 conj,
771 tmp.rb_mut().subrows_mut(start, nrows),
772 par,
773 stack.rb_mut(),
774 );
775
776 start += ncols;
777 block_count += 1;
778
779 if start >= s_H.nrows() {
780 break;
781 }
782 }
783
784 for j in 0..k {
785 for idx in 0..s_h_row_full_end - s_h_row_begin {
786 let i = s_row_idx_in_panel[idx].zx();
787 x[(i, j)] = copy(tmp[(idx, j)]);
788 }
789 }
790 }
791 }
792 let m = H_symbolic.nrows;
793 let n = L_symbolic.nrows();
794 x.rb_mut().subrows_mut(0, n).copy_from(tmp.rb().subrows(0, n));
795 x.rb_mut().subrows_mut(n, m - n).fill(zero());
796 }
797
798 #[track_caller]
803 #[math]
804 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, work: MatMut<'_, T>, stack: &mut MemStack)
805 where
806 T: ComplexField,
807 {
808 let mut work = work;
809 let mut rhs = rhs;
810 self.apply_Q_transpose_in_place_with_conj(conj.compose(Conj::Yes), rhs.rb_mut(), par, work.rb_mut(), stack);
811
812 let L_symbolic = self.symbolic().R_adjoint();
813 let n_supernodes = L_symbolic.n_supernodes();
814
815 let mut tmp = work;
816 let mut x = rhs;
817 let k = x.ncols();
818
819 {
821 let L = SupernodalLltRef::<'_, I, T>::new(L_symbolic, self.rt_val);
822
823 for s in (0..n_supernodes).rev() {
824 let s = L.supernode(s);
825 let size = s.val().ncols();
826 let s_L = s.val();
827 let (s_L_top, s_L_bot) = s_L.split_at_row(size);
828
829 let mut tmp = tmp.rb_mut().subrows_mut(0, s.pattern().len());
830 for j in 0..k {
831 for (idx, i) in s.pattern().iter().enumerate() {
832 let i = i.zx();
833 tmp[(idx, j)] = copy(x[(i, j)]);
834 }
835 }
836
837 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
838 linalg::matmul::matmul_with_conj(
839 x_top.rb_mut(),
840 Accum::Add,
841 s_L_bot.transpose(),
842 conj.compose(Conj::Yes),
843 tmp.rb(),
844 Conj::No,
845 -one::<T>(),
846 par,
847 );
848 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(
849 s_L_top.transpose(),
850 conj.compose(Conj::Yes),
851 x_top.rb_mut(),
852 par,
853 );
854 }
855 }
856 }
857 }
858
859 #[track_caller]
862 pub fn factorize_supernodal_numeric_qr_scratch<I: Index, T: ComplexField>(
863 symbolic: &SymbolicSupernodalQr<I>,
864 par: Par,
865 params: Spec<QrParams, T>,
866 ) -> StackReq {
867 let n_supernodes = symbolic.L.n_supernodes();
868 let n = symbolic.L.dimension;
869 let m = symbolic.H.nrows;
870 let init_scratch = StackReq::all_of(&[
871 StackReq::new::<I>(symbolic.H.len_householder_row_idx()),
872 StackReq::new::<I>(n_supernodes),
873 StackReq::new::<I>(n),
874 StackReq::new::<I>(n),
875 StackReq::new::<I>(m),
876 StackReq::new::<I>(m),
877 ]);
878
879 let mut loop_scratch = StackReq::empty();
880 for s in 0..n_supernodes {
881 let s_h_row_begin = symbolic.H.col_ptr_for_row_idx[s].zx();
882 let s_h_row_full_end = symbolic.H.col_ptr_for_row_idx[s + 1].zx();
883 let max_blocksize = symbolic.H.max_blocksize[s].zx();
884 let s_col_begin = symbolic.L.supernode_begin()[s].zx();
885 let s_col_end = symbolic.L.supernode_end()[s].zx();
886 let s_ncols = s_col_end - s_col_begin;
887 let s_pattern_len = symbolic.L.col_ptr_for_row_idx()[s + 1].zx() - symbolic.L.col_ptr_for_row_idx()[s].zx();
888
889 loop_scratch = loop_scratch.or(linalg::qr::no_pivoting::factor::qr_in_place_scratch::<T>(
890 s_h_row_full_end - s_h_row_begin,
891 s_ncols + s_pattern_len,
892 max_blocksize,
893 par,
894 params,
895 ));
896
897 loop_scratch = loop_scratch.or(
898 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(
899 s_h_row_full_end - s_h_row_begin,
900 max_blocksize,
901 s_ncols + s_pattern_len,
902 ),
903 );
904 }
905
906 init_scratch.and(loop_scratch)
907 }
908
909 #[track_caller]
925 pub fn factorize_supernodal_numeric_qr<'a, I: Index, T: ComplexField>(
926 householder_row_idx: &'a mut [I],
927 tau_blocksize: &'a mut [I],
928 householder_nrows: &'a mut [I],
929 householder_ncols: &'a mut [I],
930
931 r_val: &'a mut [T],
932 householder_val: &'a mut [T],
933 tau_val: &'a mut [T],
934
935 AT: SparseColMatRef<'_, I, T>,
936 col_perm: Option<PermRef<'_, I>>,
937 symbolic: &'a SymbolicSupernodalQr<I>,
938 par: Par,
939 stack: &mut MemStack,
940 params: Spec<QrParams, T>,
941 ) -> SupernodalQrRef<'a, I, T> {
942 assert!(all(
943 householder_row_idx.len() == symbolic.householder().len_householder_row_idx(),
944 r_val.len() == symbolic.R_adjoint().len_val(),
945 householder_val.len() == symbolic.householder().len_householder_val(),
946 tau_val.len() == symbolic.householder().len_tau_val(),
947 tau_blocksize.len() == symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes(),
948 householder_nrows.len() == symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes(),
949 householder_ncols.len() == symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes(),
950 ));
951
952 factorize_supernodal_numeric_qr_impl(
953 householder_row_idx,
954 tau_blocksize,
955 householder_nrows,
956 householder_ncols,
957 r_val,
958 householder_val,
959 tau_val,
960 AT,
961 col_perm,
962 &symbolic.L,
963 &symbolic.H,
964 &symbolic.min_col,
965 &symbolic.min_col_perm,
966 &symbolic.index_to_super,
967 bytemuck::cast_slice(&symbolic.child_head),
968 bytemuck::cast_slice(&symbolic.child_next),
969 par,
970 stack,
971 params,
972 );
973
974 unsafe {
975 SupernodalQrRef::<'_, I, T>::new_unchecked(
976 symbolic,
977 householder_row_idx,
978 tau_blocksize,
979 householder_nrows,
980 householder_ncols,
981 r_val,
982 householder_val,
983 tau_val,
984 )
985 }
986 }
987
988 #[math]
989 pub(crate) fn factorize_supernodal_numeric_qr_impl<I: Index, T: ComplexField>(
990 householder_row_idx: &mut [I],
992
993 tau_blocksize: &mut [I],
994 householder_nrows: &mut [I],
995 householder_ncols: &mut [I],
996
997 L_val: &mut [T],
998 householder_val: &mut [T],
999 tau_val: &mut [T],
1000
1001 AT: SparseColMatRef<'_, I, T>,
1002 col_perm: Option<PermRef<'_, I>>,
1003 L_symbolic: &SymbolicSupernodalCholesky<I>,
1004 H_symbolic: &SymbolicSupernodalHouseholder<I>,
1005 min_col: &[I],
1006 min_col_perm: &[I],
1007 index_to_super: &[I],
1008 child_head: &[I::Signed],
1009 child_next: &[I::Signed],
1010
1011 par: Par,
1012 stack: &mut MemStack,
1013 params: Spec<QrParams, T>,
1014 ) -> usize {
1015 let n_supernodes = L_symbolic.n_supernodes();
1016 let m = AT.ncols();
1017 let n = AT.nrows();
1018
1019 let mut block_count = 0;
1020
1021 let (min_col_in_panel, stack) = unsafe { stack.make_raw::<I>(H_symbolic.len_householder_row_idx()) };
1022 let (min_col_in_panel_perm, stack) = unsafe { stack.make_raw::<I>(m) };
1023 let (col_end_for_row_idx_in_panel, stack) = unsafe { stack.make_raw::<I>(n_supernodes) };
1024 let (col_global_to_local, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
1025 let (child_col_global_to_local, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
1026 let (child_row_global_to_local, mut stack) = unsafe { stack.make_raw::<I::Signed>(m) };
1027
1028 tau_val.fill(zero());
1029 L_val.fill(zero());
1030 householder_val.fill(zero());
1031
1032 col_end_for_row_idx_in_panel.copy_from_slice(&H_symbolic.col_ptr_for_row_idx[..n_supernodes]);
1033
1034 for i in 0..m {
1035 let i = min_col_perm[i].zx();
1036 let min_col = min_col[i].zx();
1037 if min_col < n {
1038 let s = index_to_super[min_col].zx();
1039 let pos = &mut col_end_for_row_idx_in_panel[s];
1040 householder_row_idx[pos.zx()] = I::truncate(i);
1041 min_col_in_panel[pos.zx()] = I::truncate(min_col);
1042 *pos += I::truncate(1);
1043 }
1044 }
1045
1046 col_global_to_local.fill(I::Signed::truncate(NONE));
1047 child_col_global_to_local.fill(I::Signed::truncate(NONE));
1048 child_row_global_to_local.fill(I::Signed::truncate(NONE));
1049
1050 let supernode_begin = L_symbolic.supernode_begin();
1051 let supernode_end = L_symbolic.supernode_end();
1052
1053 let super_etree = &*H_symbolic.super_etree;
1054
1055 let col_pattern =
1056 |node: usize| &L_symbolic.row_idx()[L_symbolic.col_ptr_for_row_idx()[node].zx()..L_symbolic.col_ptr_for_row_idx()[node + 1].zx()];
1057
1058 for s in 0..n_supernodes {
1060 let s_h_row_begin = H_symbolic.col_ptr_for_row_idx[s].zx();
1062 let s_h_row_full_end = H_symbolic.col_ptr_for_row_idx[s + 1].zx();
1063 let s_h_row_end = col_end_for_row_idx_in_panel[s].zx();
1064
1065 let s_col_begin = supernode_begin[s].zx();
1066 let s_col_end = supernode_end[s].zx();
1067 let s_ncols = s_col_end - s_col_begin;
1068
1069 let s_pattern = col_pattern(s);
1070
1071 for i in 0..s_ncols {
1072 col_global_to_local[s_col_begin + i] = I::Signed::truncate(i);
1073 }
1074 for (i, &col) in s_pattern.iter().enumerate() {
1075 col_global_to_local[col.zx()] = I::Signed::truncate(i + s_ncols);
1076 }
1077
1078 let (s_min_col_in_panel, parent_min_col_in_panel) = min_col_in_panel.split_at_mut(s_h_row_end);
1079 let parent_offset = s_h_row_end;
1080 let (c_min_col_in_panel, s_min_col_in_panel) = s_min_col_in_panel.split_at_mut(s_h_row_begin);
1081
1082 let (householder_row_idx, parent_row_idx_in_panel) = householder_row_idx.split_at_mut(s_h_row_end);
1083
1084 let (s_H, _) = householder_val.split_at_mut(H_symbolic.col_ptr_for_val[s + 1].zx());
1085 let (c_H, s_H) = s_H.split_at_mut(H_symbolic.col_ptr_for_val[s].zx());
1086 let c_H = &*c_H;
1087
1088 let mut s_H = MatMut::from_column_major_slice_mut(s_H, s_h_row_full_end - s_h_row_begin, s_ncols + s_pattern.len())
1089 .subrows_mut(0, s_h_row_end - s_h_row_begin);
1090
1091 {
1092 let s_min_col_in_panel_perm = &mut min_col_in_panel_perm[0..s_h_row_end - s_h_row_begin];
1093 for (i, p) in s_min_col_in_panel_perm.iter_mut().enumerate() {
1094 *p = I::truncate(i);
1095 }
1096 s_min_col_in_panel_perm.sort_unstable_by_key(|i| s_min_col_in_panel[i.zx()]);
1097
1098 let s_row_idx_in_panel = &mut householder_row_idx[s_h_row_begin..];
1099 let tmp: &mut [I] = bytemuck::cast_slice_mut(&mut child_row_global_to_local[..s_h_row_end - s_h_row_begin]);
1100
1101 for (i, p) in s_min_col_in_panel_perm.iter().enumerate() {
1102 let p = p.zx();
1103 tmp[i] = s_min_col_in_panel[p];
1104 }
1105 s_min_col_in_panel.copy_from_slice(tmp);
1106
1107 for (i, p) in s_min_col_in_panel_perm.iter().enumerate() {
1108 let p = p.zx();
1109 tmp[i] = s_row_idx_in_panel[p];
1110 }
1111 s_row_idx_in_panel.copy_from_slice(tmp);
1112 for (i, p) in s_min_col_in_panel_perm.iter_mut().enumerate() {
1113 *p = I::truncate(i);
1114 }
1115
1116 tmp.fill(I::truncate(NONE));
1117 }
1118
1119 let s_row_idx_in_panel = &householder_row_idx[s_h_row_begin..];
1120
1121 for idx in 0..s_h_row_end - s_h_row_begin {
1122 let i = s_row_idx_in_panel[idx].zx();
1123 if min_col[i].zx() >= s_col_begin {
1124 for (j, value) in iter::zip(AT.row_idx_of_col(i), AT.val_of_col(i)) {
1125 let pj = col_perm.map(|perm| perm.arrays().1[j].zx()).unwrap_or(j);
1126 let ix = idx;
1127 let iy = col_global_to_local[pj].zx();
1128 s_H[(ix, iy)] = s_H[(ix, iy)] + *value;
1129 }
1130 }
1131 }
1132
1133 let mut child_ = child_head[s];
1134 while child_ >= I::Signed::truncate(0) {
1135 let child = child_.zx();
1136 assert!(super_etree[child].zx() == s);
1137 let c_pattern = col_pattern(child);
1138 let c_col_begin = supernode_begin[child].zx();
1139 let c_col_end = supernode_end[child].zx();
1140 let c_ncols = c_col_end - c_col_begin;
1141
1142 let c_h_row_begin = H_symbolic.col_ptr_for_row_idx[child].zx();
1143 let c_h_row_end = H_symbolic.col_ptr_for_row_idx[child + 1].zx();
1144
1145 let c_row_idx_in_panel = &householder_row_idx[c_h_row_begin..c_h_row_end];
1146 let c_min_col_in_panel = &c_min_col_in_panel[c_h_row_begin..c_h_row_end];
1147
1148 let c_H = &c_H[H_symbolic.col_ptr_for_val[child].zx()..H_symbolic.col_ptr_for_val[child + 1].zx()];
1149 let c_H = MatRef::from_column_major_slice(
1150 c_H,
1151 H_symbolic.col_ptr_for_row_idx[child + 1].zx() - c_h_row_begin,
1152 c_ncols + c_pattern.len(),
1153 );
1154
1155 for (idx, &col) in c_pattern.iter().enumerate() {
1156 child_col_global_to_local[col.zx()] = I::Signed::truncate(idx + c_ncols);
1157 }
1158 for (idx, &p) in c_row_idx_in_panel.iter().enumerate() {
1159 child_row_global_to_local[p.zx()] = I::Signed::truncate(idx);
1160 }
1161
1162 for s_idx in 0..s_h_row_end - s_h_row_begin {
1163 let i = s_row_idx_in_panel[s_idx].zx();
1164 let c_idx = child_row_global_to_local[i];
1165 if c_idx < I::Signed::truncate(0) {
1166 continue;
1167 }
1168
1169 let c_idx = c_idx.zx();
1170 let c_min_col = c_min_col_in_panel[c_idx].zx();
1171
1172 for (j_idx_in_c, j) in c_pattern.iter().enumerate() {
1173 let j_idx_in_c = j_idx_in_c + c_ncols;
1174 if j.zx() >= c_min_col {
1175 s_H[(s_idx, col_global_to_local[j.zx()].zx())] = copy(c_H[(c_idx, j_idx_in_c)]);
1176 }
1177 }
1178 }
1179
1180 for &row in c_row_idx_in_panel {
1181 child_row_global_to_local[row.zx()] = I::Signed::truncate(NONE);
1182 }
1183 for &col in c_pattern {
1184 child_col_global_to_local[col.zx()] = I::Signed::truncate(NONE);
1185 }
1186 child_ = child_next[child];
1187 }
1188
1189 let s_col_local_to_global = |local: usize| {
1190 if local < s_ncols {
1191 s_col_begin + local
1192 } else {
1193 s_pattern[local - s_ncols].zx()
1194 }
1195 };
1196
1197 {
1198 let s_h_nrows = s_h_row_end - s_h_row_begin;
1199
1200 let tau_begin = H_symbolic.col_ptr_for_tau_val[s].zx();
1201 let tau_end = H_symbolic.col_ptr_for_tau_val[s + 1].zx();
1202 let L_begin = L_symbolic.col_ptr_for_val()[s].zx();
1203 let L_end = L_symbolic.col_ptr_for_val()[s + 1].zx();
1204
1205 let s_tau = &mut tau_val[tau_begin..tau_end];
1206 let s_L = &mut L_val[L_begin..L_end];
1207
1208 let max_blocksize = H_symbolic.max_blocksize[s].zx();
1209 let mut s_tau = MatMut::from_column_major_slice_mut(s_tau, max_blocksize, Ord::min(s_H.ncols(), s_h_row_full_end - s_h_row_begin));
1210
1211 {
1212 let mut current_min_col = 0usize;
1213 let mut current_start = 0usize;
1214 for idx in 0..s_h_nrows + 1 {
1215 let idx_global_min_col = if idx < s_h_nrows { s_min_col_in_panel[idx].zx() } else { n };
1216
1217 let idx_min_col = if idx_global_min_col < n {
1218 col_global_to_local[idx_global_min_col.zx()].zx()
1219 } else {
1220 s_H.ncols()
1221 };
1222
1223 if idx_min_col == s_H.ncols() || idx_min_col >= current_min_col.saturating_add(Ord::max(1, max_blocksize / 2)) {
1224 let nrows = idx.saturating_sub(current_start);
1225 let full_ncols = s_H.ncols() - current_start;
1226 let ncols = Ord::min(nrows, idx_min_col - current_min_col);
1227
1228 let s_H = s_H.rb_mut().submatrix_mut(current_start, current_start, nrows, full_ncols);
1229
1230 let (mut left, mut right) = s_H.split_at_col_mut(ncols);
1231 let bs = linalg::qr::no_pivoting::factor::recommended_blocksize::<Symbolic>(left.nrows(), left.ncols());
1232 let bs = Ord::min(max_blocksize, bs);
1233 tau_blocksize[block_count] = I::truncate(bs);
1234 householder_nrows[block_count] = I::truncate(nrows);
1235 householder_ncols[block_count] = I::truncate(ncols);
1236 block_count += 1;
1237
1238 let mut s_tau = s_tau.rb_mut().subrows_mut(0, bs).subcols_mut(current_start, ncols);
1239
1240 linalg::qr::no_pivoting::factor::qr_in_place(left.rb_mut(), s_tau.rb_mut(), par, stack.rb_mut(), params);
1241
1242 if right.ncols() > 0 {
1243 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
1244 left.rb(),
1245 s_tau.rb(),
1246 Conj::Yes,
1247 right.rb_mut(),
1248 par,
1249 stack.rb_mut(),
1250 );
1251 }
1252
1253 current_min_col = idx_min_col;
1254 current_start += ncols;
1255 }
1256 }
1257 }
1258
1259 let mut s_L = MatMut::from_column_major_slice_mut(s_L, s_pattern.len() + s_ncols, s_ncols);
1260 let nrows = Ord::min(s_H.nrows(), s_L.ncols());
1261 z!(s_L.rb_mut().transpose_mut().subrows_mut(0, nrows), s_H.rb().subrows(0, nrows))
1262 .for_each_triangular_upper(linalg::zip::Diag::Include, |uz!(dst, src)| *dst = conj(*src));
1263 }
1264
1265 col_end_for_row_idx_in_panel[s] = Ord::min(I::truncate(s_h_row_begin + s_ncols + s_pattern.len()), col_end_for_row_idx_in_panel[s]);
1266
1267 let s_h_row_end = col_end_for_row_idx_in_panel[s].zx();
1268 let s_h_nrows = s_h_row_end - s_h_row_begin;
1269
1270 let mut current_min_col = 0usize;
1271 for idx in 0..s_h_nrows {
1272 let idx_global_min_col = s_min_col_in_panel[idx];
1273 if idx_global_min_col.zx() >= n {
1274 break;
1275 }
1276 let idx_min_col = col_global_to_local[idx_global_min_col.zx()].zx();
1277 if current_min_col > idx_min_col {
1278 s_min_col_in_panel[idx] = I::truncate(s_col_local_to_global(current_min_col));
1279 }
1280 current_min_col += 1;
1281 }
1282
1283 let s_pivot_row_end = s_ncols;
1284
1285 let parent = super_etree[s];
1286 if parent.to_signed() < I::Signed::truncate(0) {
1287 for i in 0..s_ncols {
1288 col_global_to_local[s_col_begin + i] = I::Signed::truncate(NONE);
1289 }
1290 for &row in s_pattern {
1291 col_global_to_local[row.zx()] = I::Signed::truncate(NONE);
1292 }
1293 continue;
1294 }
1295 let parent = parent.zx();
1296 let p_h_row_begin = H_symbolic.col_ptr_for_row_idx[parent].zx();
1297 let mut pos = col_end_for_row_idx_in_panel[parent].zx() - p_h_row_begin;
1298 let parent_min_col_in_panel = &mut parent_min_col_in_panel[p_h_row_begin - parent_offset..];
1299 let parent_row_idx_in_panel = &mut parent_row_idx_in_panel[p_h_row_begin - parent_offset..];
1300
1301 for idx in s_pivot_row_end..s_h_nrows {
1302 parent_row_idx_in_panel[pos] = s_row_idx_in_panel[idx];
1303 parent_min_col_in_panel[pos] = s_min_col_in_panel[idx];
1304 pos += 1;
1305 }
1306 col_end_for_row_idx_in_panel[parent] = I::truncate(pos + p_h_row_begin);
1307
1308 for i in 0..s_ncols {
1309 col_global_to_local[s_col_begin + i] = I::Signed::truncate(NONE);
1310 }
1311 for &row in s_pattern {
1312 col_global_to_local[row.zx()] = I::Signed::truncate(NONE);
1313 }
1314 }
1315 block_count
1316 }
1317}
1318
1319pub mod simplicial {
1325 use super::*;
1326 use crate::assert;
1327
/// Symbolic structure of a simplicial sparse QR factorization.
///
/// Stores the dimensions of the factorized matrix, the number of entries reserved for the
/// Householder factor and for `R`, and the elimination-tree bookkeeping that the numeric
/// factorization reads.
#[derive(Debug)]
pub struct SymbolicSimplicialQr<I> {
    /// Number of rows of the factorized matrix.
    nrows: usize,
    /// Number of columns of the factorized matrix.
    ncols: usize,
    /// Number of entries of the Householder factor (reported by `len_householder`).
    h_nnz: usize,
    /// Number of entries of the `R` factor (reported by `len_r`).
    l_nnz: usize,

    /// Postordering of the column elimination tree.
    postorder: alloc::vec::Vec<I>,
    /// Inverse of `postorder`: `postorder_inv[postorder[i]] == i`.
    postorder_inv: alloc::vec::Vec<I>,
    /// Per-column descendant count accumulated along the elimination tree.
    desc_count: alloc::vec::Vec<I>,
}
1340
1341 impl<I: Index> SymbolicSimplicialQr<I> {
1342 #[inline]
1344 pub fn nrows(&self) -> usize {
1345 self.nrows
1346 }
1347
1348 #[inline]
1350 pub fn ncols(&self) -> usize {
1351 self.ncols
1352 }
1353
1354 #[inline]
1356 pub fn len_householder(&self) -> usize {
1357 self.h_nnz
1358 }
1359
1360 #[inline]
1362 pub fn len_r(&self) -> usize {
1363 self.l_nnz
1364 }
1365 }
1366
1367 #[derive(Debug)]
1369 pub struct SimplicialQrRef<'a, I, T> {
1370 symbolic: &'a SymbolicSimplicialQr<I>,
1371 r_col_ptr: &'a [I],
1372 r_row_idx: &'a [I],
1373 r_val: &'a [T],
1374 householder_col_ptr: &'a [I],
1375 householder_row_idx: &'a [I],
1376 householder_val: &'a [T],
1377 tau_val: &'a [T],
1378 }
1379
1380 impl<I, T> Copy for SimplicialQrRef<'_, I, T> {}
1381 impl<I, T> Clone for SimplicialQrRef<'_, I, T> {
1382 #[inline]
1383 fn clone(&self) -> Self {
1384 *self
1385 }
1386 }
1387
1388 impl<'a, I: Index, T> SimplicialQrRef<'a, I, T> {
1389 #[inline]
1391 pub fn new(
1392 symbolic: &'a SymbolicSimplicialQr<I>,
1393 r: SparseColMatRef<'a, I, T>,
1394 householder: SparseColMatRef<'a, I, T>,
1395 tau_val: &'a [T],
1396 ) -> Self {
1397 assert!(householder.nrows() == symbolic.nrows);
1398 assert!(householder.ncols() == symbolic.ncols);
1399 assert!(r.nrows() == symbolic.ncols);
1400 assert!(r.ncols() == symbolic.ncols);
1401
1402 let r_col_ptr = r.col_ptr();
1403 let r_row_idx = r.row_idx();
1404 let r_val = r.val();
1405 assert!(r.col_nnz().is_none());
1406
1407 let householder_col_ptr = householder.col_ptr();
1408 let householder_row_idx = householder.row_idx();
1409 let householder_val = householder.val();
1410 assert!(householder.col_nnz().is_none());
1411
1412 assert!(r_val.len() == symbolic.len_r());
1413 assert!(tau_val.len() == symbolic.ncols);
1414 assert!(householder_val.len() == symbolic.len_householder());
1415 Self {
1416 symbolic,
1417 householder_val,
1418 tau_val,
1419 r_val,
1420 r_col_ptr,
1421 r_row_idx,
1422 householder_col_ptr,
1423 householder_row_idx,
1424 }
1425 }
1426
1427 #[inline]
1429 pub fn symbolic(&self) -> &SymbolicSimplicialQr<I> {
1430 self.symbolic
1431 }
1432
1433 #[inline]
1435 pub fn R_val(self) -> &'a [T] {
1436 self.r_val
1437 }
1438
1439 #[inline]
1441 pub fn R(self) -> SparseColMatRef<'a, I, T> {
1442 let n = self.symbolic().ncols();
1443 SparseColMatRef::<'_, I, T>::new(
1444 unsafe { SymbolicSparseColMatRef::new_unchecked(n, n, self.r_col_ptr, None, self.r_row_idx) },
1445 self.r_val,
1446 )
1447 }
1448
1449 #[inline]
1451 pub fn householder(self) -> SparseColMatRef<'a, I, T> {
1452 let m = self.symbolic.nrows;
1453 let n = self.symbolic.ncols;
1454 SparseColMatRef::<'_, I, T>::new(
1455 unsafe { SymbolicSparseColMatRef::new_unchecked(m, n, self.householder_col_ptr, None, self.householder_row_idx) },
1456 self.householder_val,
1457 )
1458 }
1459
1460 #[inline]
1462 pub fn householder_val(self) -> &'a [T] {
1463 self.householder_val
1464 }
1465
1466 #[inline]
1468 pub fn tau_val(self) -> &'a [T] {
1469 self.tau_val
1470 }
1471
1472 #[math]
1477 pub fn apply_qt_in_place_with_conj(&self, conj_qr: Conj, rhs: MatMut<'_, T>, par: Par, work: MatMut<'_, T>)
1478 where
1479 T: ComplexField,
1480 {
1481 let _ = par;
1482 assert!(rhs.nrows() == self.symbolic.nrows);
1483 let mut x = rhs;
1484
1485 let m = self.symbolic.nrows;
1486 let n = self.symbolic.ncols;
1487
1488 let h = SparseColMatRef::<'_, I, T>::new(
1489 unsafe { SymbolicSparseColMatRef::new_unchecked(m, n, self.householder_col_ptr, None, self.householder_row_idx) },
1490 self.householder_val,
1491 );
1492 let tau = self.tau_val;
1493
1494 let mut tmp = work;
1495 tmp.fill(zero());
1496
1497 {
1499 for j in 0..n {
1500 let hi = h.row_idx_of_col_raw(j);
1501 let hx = h.val_of_col(j);
1502 let tau_inv = recip(real(tau[j]));
1503
1504 if hi.is_empty() {
1505 tmp.rb_mut().row_mut(j).fill(zero());
1506 continue;
1507 }
1508
1509 let hi0 = hi[0].zx();
1510 for k in 0..x.ncols() {
1511 let mut dot = zero::<T>();
1512 for (i, v) in iter::zip(hi, hx) {
1513 let i = i.zx();
1514 let v = if conj_qr == Conj::Yes { copy(*v) } else { conj(*v) };
1515 dot = dot + v * x[(i, k)];
1516 }
1517 dot = mul_real(dot, tau_inv);
1518 for (i, v) in iter::zip(hi, hx) {
1519 let i = i.zx();
1520 let v = if conj_qr == Conj::Yes { conj(*v) } else { copy(*v) };
1521 x[(i, k)] = x[(i, k)] - dot * v;
1522 }
1523
1524 tmp.rb_mut().row_mut(j).copy_from(x.rb().row(hi0));
1525 }
1526 }
1527 }
1528 x.rb_mut().subrows_mut(0, n).copy_from(tmp.rb().subrows(0, n));
1529 x.rb_mut().subrows_mut(n, m - n).fill(zero());
1530 }
1531
1532 #[track_caller]
1537 #[math]
1538 pub fn solve_in_place_with_conj(&self, conj_qr: Conj, rhs: MatMut<'_, T>, par: Par, work: MatMut<'_, T>)
1539 where
1540 T: ComplexField,
1541 {
1542 let mut work = work;
1543 let mut rhs = rhs;
1544 self.apply_qt_in_place_with_conj(conj_qr, rhs.rb_mut(), par, work.rb_mut());
1545
1546 let _ = par;
1547 assert!(rhs.nrows() == self.symbolic.nrows);
1548 let mut x = rhs;
1549
1550 let n = self.symbolic.ncols;
1551 let r = SparseColMatRef::<'_, I, T>::new(
1552 unsafe { SymbolicSparseColMatRef::new_unchecked(n, n, self.r_col_ptr, None, self.r_row_idx) },
1553 self.r_val,
1554 );
1555
1556 linalg_sp::triangular_solve::solve_upper_triangular_in_place(r, conj_qr, x.rb_mut().subrows_mut(0, n), par);
1557 }
1558 }
1559
1560 pub fn factorize_simplicial_symbolic_qr_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
1563 let _ = nrows;
1564 StackReq::new::<I>(ncols).array(3)
1565 }
1566
1567 pub fn factorize_simplicial_symbolic_qr<I: Index>(
1570 min_col: &[I],
1571 etree: EliminationTreeRef<'_, I>,
1572 col_counts: &[I],
1573 stack: &mut MemStack,
1574 ) -> Result<SymbolicSimplicialQr<I>, FaerError> {
1575 let m = min_col.len();
1576 let n = col_counts.len();
1577
1578 let mut post = try_zeroed::<I>(n)?;
1579 let mut post_inv = try_zeroed::<I>(n)?;
1580 let mut desc_count = try_zeroed::<I>(n)?;
1581
1582 let h_non_zero_count = &mut *post_inv;
1583 for i in 0..m {
1584 let min_col = min_col[i];
1585 if min_col.to_signed() < I::Signed::truncate(0) {
1586 continue;
1587 }
1588 h_non_zero_count[min_col.zx()] += I::truncate(1);
1589 }
1590 for j in 0..n {
1591 let parent = etree.inner[j];
1592 if parent < I::Signed::truncate(0) || h_non_zero_count[j] == I::truncate(0) {
1593 continue;
1594 }
1595 h_non_zero_count[parent.zx()] += h_non_zero_count[j] - I::truncate(1);
1596 }
1597
1598 let h_nnz = I::sum_nonnegative(h_non_zero_count).ok_or(FaerError::IndexOverflow)?.zx();
1599 let l_nnz = I::sum_nonnegative(col_counts).ok_or(FaerError::IndexOverflow)?.zx();
1600
1601 postorder(&mut post, etree, stack);
1602 for (i, p) in post.iter().enumerate() {
1603 post_inv[p.zx()] = I::truncate(i);
1604 }
1605 for j in 0..n {
1606 let parent = etree.inner[j];
1607 if parent >= I::Signed::truncate(0) {
1608 desc_count[parent.zx()] = desc_count[parent.zx()] + desc_count[j] + I::truncate(1);
1609 }
1610 }
1611
1612 Ok(SymbolicSimplicialQr {
1613 nrows: m,
1614 ncols: n,
1615 postorder: post,
1616 postorder_inv: post_inv,
1617 desc_count,
1618 h_nnz,
1619 l_nnz,
1620 })
1621 }
1622
1623 pub fn factorize_simplicial_numeric_qr_scratch<I: Index, T: ComplexField>(symbolic: &SymbolicSimplicialQr<I>) -> StackReq {
1626 let m = symbolic.nrows;
1627 StackReq::all_of(&[
1628 StackReq::new::<I>(m),
1629 StackReq::new::<I>(m),
1630 StackReq::new::<I>(m),
1631 temp_mat_scratch::<T>(m, 1),
1632 ])
1633 }
1634
1635 #[math]
1645 pub fn factorize_simplicial_numeric_qr_unsorted<'a, I: Index, T: ComplexField>(
1646 r_col_ptr: &'a mut [I],
1647 r_row_idx: &'a mut [I],
1648 r_val: &'a mut [T],
1649 householder_col_ptr: &'a mut [I],
1650 householder_row_idx: &'a mut [I],
1651 householder_val: &'a mut [T],
1652 tau_val: &'a mut [T],
1653
1654 A: SparseColMatRef<'_, I, T>,
1655 col_perm: Option<PermRef<'_, I>>,
1656 symbolic: &'a SymbolicSimplicialQr<I>,
1657 stack: &mut MemStack,
1658 ) -> SimplicialQrRef<'a, I, T> {
1659 assert!(all(A.nrows() == symbolic.nrows, A.ncols() == symbolic.ncols,));
1660
1661 let I = I::truncate;
1662 let m = A.nrows();
1663 let n = A.ncols();
1664 let (r_idx, stack) = unsafe { stack.make_raw::<I::Signed>(m) };
1665 let (marked, stack) = unsafe { stack.make_raw::<I>(m) };
1666 let (pattern, stack) = unsafe { stack.make_raw::<I>(m) };
1667 let (mut x, _) = temp_mat_zeroed::<T, _, _>(m, 1, stack);
1668 let x = x.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
1669 marked.fill(I(0));
1670 r_idx.fill(I::Signed::truncate(NONE));
1671
1672 r_col_ptr[0] = I(0);
1673 let mut r_pos = 0usize;
1674 let mut h_pos = 0usize;
1675 for j in 0..n {
1676 let pj = col_perm.map(|perm| perm.arrays().0[j].zx()).unwrap_or(j);
1677
1678 let mut pattern_len = 0usize;
1679 for (i, val) in iter::zip(A.row_idx_of_col(pj), A.val_of_col(pj)) {
1680 if marked[i] < I(j + 1) {
1681 marked[i] = I(j + 1);
1682 pattern[pattern_len] = I(i);
1683 pattern_len += 1;
1684 }
1685 x[i] = x[i] + *val;
1686 }
1687
1688 let j_postordered = symbolic.postorder_inv[j].zx();
1689 let desc_count = symbolic.desc_count[j].zx();
1690 for d in &symbolic.postorder[j_postordered - desc_count..j_postordered] {
1691 let d = d.zx();
1692
1693 let d_h_pattern = &householder_row_idx[householder_col_ptr[d].zx()..householder_col_ptr[d + 1].zx()];
1694 let d_h_val = &householder_val[householder_col_ptr[d].zx()..householder_col_ptr[d + 1].zx()];
1695
1696 let mut intersects = false;
1697 for i in d_h_pattern {
1698 if marked[i.zx()] == I(j + 1) {
1699 intersects = true;
1700 break;
1701 }
1702 }
1703 if !intersects {
1704 continue;
1705 }
1706
1707 for i in d_h_pattern {
1708 let i = i.zx();
1709 if marked[i] < I(j + 1) {
1710 marked[i] = I(j + 1);
1711 pattern[pattern_len] = I(i);
1712 pattern_len += 1;
1713 }
1714 }
1715
1716 let tau_inv = recip(real(tau_val[d]));
1717 let mut dot = zero::<T>();
1718 for (i, vi) in iter::zip(d_h_pattern, d_h_val) {
1719 let i = i.zx();
1720 dot = dot + conj(*vi) * x[i];
1721 }
1722 dot = mul_real(dot, tau_inv);
1723 for (i, vi) in iter::zip(d_h_pattern, d_h_val) {
1724 let i = i.zx();
1725 x[i] = x[i] - dot * *vi;
1726 }
1727 }
1728 let pattern = &pattern[..pattern_len];
1729
1730 let h_begin = h_pos;
1731 for i in pattern.iter() {
1732 let i = i.zx();
1733 if r_idx[i] >= I(0).to_signed() {
1734 r_val[r_pos] = copy(x[i]);
1735 x[i] = zero();
1736 r_row_idx[r_pos] = I::from_signed(r_idx[i]);
1737 r_pos += 1;
1738 } else {
1739 householder_val[h_pos] = copy(x[i]);
1740 x[i] = zero();
1741 householder_row_idx[h_pos] = I(i);
1742 h_pos += 1;
1743 }
1744 }
1745
1746 householder_col_ptr[j + 1] = I(h_pos);
1747
1748 if h_begin == h_pos {
1749 tau_val[j] = zero();
1750 r_val[r_pos] = zero();
1751 r_row_idx[r_pos] = I(j);
1752 r_pos += 1;
1753 r_col_ptr[j + 1] = I(r_pos);
1754 continue;
1755 }
1756
1757 let mut h_col = ColMut::from_slice_mut(&mut householder_val[h_begin..h_pos]);
1758
1759 let (mut head, tail) = h_col.rb_mut().split_at_row_mut(1);
1760 let head = &mut head[0];
1761 let crate::linalg::householder::HouseholderInfo { tau, .. } = crate::linalg::householder::make_householder_in_place(head, tail);
1762 tau_val[j] = from_real(tau);
1763 r_val[r_pos] = copy(*head);
1764 *head = one();
1765
1766 r_row_idx[r_pos] = I(j);
1767 r_idx[householder_row_idx[h_begin].zx()] = I(j).to_signed();
1768 r_pos += 1;
1769 r_col_ptr[j + 1] = I(r_pos);
1770 }
1771
1772 unsafe {
1773 SimplicialQrRef::new(
1774 symbolic,
1775 SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_unchecked(n, n, r_col_ptr, None, r_row_idx), r_val),
1776 SparseColMatRef::<'_, I, T>::new(
1777 SymbolicSparseColMatRef::new_unchecked(m, n, householder_col_ptr, None, householder_row_idx),
1778 householder_val,
1779 ),
1780 tau_val,
1781 )
1782 }
1783 }
1784}
1785
1786#[derive(Copy, Clone, Debug, Default)]
1788pub struct QrSymbolicParams<'a> {
1789 pub colamd_params: colamd::Control,
1791 pub supernodal_flop_ratio_threshold: SupernodalThreshold,
1793 pub supernodal_params: SymbolicSupernodalParams<'a>,
1795}
1796
1797#[derive(Debug)]
1799pub enum SymbolicQrRaw<I> {
1800 Simplicial(simplicial::SymbolicSimplicialQr<I>),
1802 Supernodal(supernodal::SymbolicSupernodalQr<I>),
1804}
1805
1806#[derive(Debug)]
1808pub struct SymbolicQr<I> {
1809 raw: SymbolicQrRaw<I>,
1810 col_perm_fwd: alloc::vec::Vec<I>,
1811 col_perm_inv: alloc::vec::Vec<I>,
1812 A_nnz: usize,
1813}
1814
1815#[derive(Debug)]
1817pub struct QrRef<'a, I: Index, T> {
1818 symbolic: &'a SymbolicQr<I>,
1819 indices: &'a [I],
1820 val: &'a [T],
1821}
1822
1823impl<I: Index, T> Copy for QrRef<'_, I, T> {}
1824impl<I: Index, T> Clone for QrRef<'_, I, T> {
1825 fn clone(&self) -> Self {
1826 *self
1827 }
1828}
1829
1830impl<'a, I: Index, T> QrRef<'a, I, T> {
1831 #[inline]
1837 pub unsafe fn new_unchecked(symbolic: &'a SymbolicQr<I>, indices: &'a [I], val: &'a [T]) -> Self {
1838 let val = val;
1839 assert!(all(symbolic.len_val() == val.len(), symbolic.len_idx() == indices.len(),));
1840 Self { symbolic, val, indices }
1841 }
1842
1843 #[inline]
1845 pub fn symbolic(self) -> &'a SymbolicQr<I> {
1846 self.symbolic
1847 }
1848
1849 #[track_caller]
1854 pub fn solve_in_place_with_conj(self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1855 where
1856 T: ComplexField,
1857 {
1858 let k = rhs.ncols();
1859 let m = self.symbolic.nrows();
1860 let n = self.symbolic.ncols();
1861
1862 assert!(all(rhs.nrows() == self.symbolic.nrows(), self.symbolic.nrows() >= self.symbolic.ncols(),));
1863 let mut rhs = rhs;
1864
1865 let (mut x, stack) = unsafe { temp_mat_uninit::<T, _, _>(m, k, stack) };
1866 let mut x = x.as_mat_mut();
1867
1868 let (_, inv) = self.symbolic.col_perm().arrays();
1869 x.copy_from(rhs.rb());
1870
1871 let indices = self.indices;
1872 let val = self.val;
1873
1874 match &self.symbolic.raw {
1875 SymbolicQrRaw::Simplicial(symbolic) => {
1876 let (r_col_ptr, indices) = indices.split_at(n + 1);
1877 let (r_row_idx, indices) = indices.split_at(symbolic.len_r());
1878 let (householder_col_ptr, indices) = indices.split_at(n + 1);
1879 let (householder_row_idx, _) = indices.split_at(symbolic.len_householder());
1880
1881 let (r_val, val) = val.rb().split_at(symbolic.len_r());
1882 let (householder_val, val) = val.split_at(symbolic.len_householder());
1883 let (tau_val, _) = val.split_at(n);
1884
1885 let r = SparseColMatRef::<'_, I, T>::new(unsafe { SymbolicSparseColMatRef::new_unchecked(n, n, r_col_ptr, None, r_row_idx) }, r_val);
1886 let h = SparseColMatRef::<'_, I, T>::new(
1887 unsafe { SymbolicSparseColMatRef::new_unchecked(m, n, householder_col_ptr, None, householder_row_idx) },
1888 householder_val,
1889 );
1890
1891 let this = simplicial::SimplicialQrRef::<'_, I, T>::new(symbolic, r, h, tau_val);
1892 this.solve_in_place_with_conj(conj, x.rb_mut(), par, rhs.rb_mut());
1893 },
1894 SymbolicQrRaw::Supernodal(symbolic) => {
1895 let (householder_row_idx, indices) = indices.split_at(symbolic.householder().len_householder_row_idx());
1896 let (tau_blocksize, indices) =
1897 indices.split_at(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
1898 let (householder_nrows, indices) =
1899 indices.split_at(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
1900 let (householder_ncols, _) =
1901 indices.split_at(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
1902
1903 let (r_val, val) = val.rb().split_at(symbolic.R_adjoint().len_val());
1904 let (householder_val, val) = val.split_at(symbolic.householder().len_householder_val());
1905 let (tau_val, _) = val.split_at(symbolic.householder().len_tau_val());
1906
1907 let this = unsafe {
1908 supernodal::SupernodalQrRef::<'_, I, T>::new_unchecked(
1909 symbolic,
1910 householder_row_idx,
1911 tau_blocksize,
1912 householder_nrows,
1913 householder_ncols,
1914 r_val,
1915 householder_val,
1916 tau_val,
1917 )
1918 };
1919 this.solve_in_place_with_conj(conj, x.rb_mut(), par, rhs.rb_mut(), stack);
1920 },
1921 }
1922
1923 for j in 0..k {
1924 for (i, p) in inv.iter().enumerate() {
1925 rhs[(i, j)] = copy(&x[(p.zx(), j)]);
1926 }
1927 }
1928 }
1929}
1930
1931impl<I: Index> SymbolicQr<I> {
1932 #[inline]
1934 pub fn nrows(&self) -> usize {
1935 match &self.raw {
1936 SymbolicQrRaw::Simplicial(this) => this.nrows(),
1937 SymbolicQrRaw::Supernodal(this) => this.householder().nrows(),
1938 }
1939 }
1940
1941 #[inline]
1943 pub fn ncols(&self) -> usize {
1944 match &self.raw {
1945 SymbolicQrRaw::Simplicial(this) => this.ncols(),
1946 SymbolicQrRaw::Supernodal(this) => this.R_adjoint().ncols(),
1947 }
1948 }
1949
1950 #[inline]
1952 pub fn col_perm(&self) -> PermRef<'_, I> {
1953 unsafe { PermRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv, self.ncols()) }
1954 }
1955
1956 #[inline]
1959 pub fn len_idx(&self) -> usize {
1960 match &self.raw {
1961 SymbolicQrRaw::Simplicial(symbolic) => symbolic.len_r() + symbolic.len_householder() + 2 * self.ncols() + 2,
1962 SymbolicQrRaw::Supernodal(symbolic) => 4 * symbolic.householder().len_householder_row_idx() + 3 * symbolic.householder().n_supernodes(),
1963 }
1964 }
1965
1966 #[inline]
1969 pub fn len_val(&self) -> usize {
1970 match &self.raw {
1971 SymbolicQrRaw::Simplicial(symbolic) => symbolic.len_r() + symbolic.len_householder() + self.ncols(),
1972 SymbolicQrRaw::Supernodal(symbolic) => {
1973 symbolic.householder().len_householder_val() + symbolic.R_adjoint().len_val() + symbolic.householder().len_tau_val()
1974 },
1975 }
1976 }
1977
1978 pub fn solve_in_place_scratch<T>(&self, rhs_ncols: usize, par: Par) -> StackReq
1981 where
1982 T: ComplexField,
1983 {
1984 temp_mat_scratch::<T>(self.nrows(), rhs_ncols).and(match &self.raw {
1985 SymbolicQrRaw::Simplicial(_) => StackReq::empty(),
1986 SymbolicQrRaw::Supernodal(this) => this.solve_in_place_scratch::<T>(rhs_ncols, par),
1987 })
1988 }
1989
1990 pub fn factorize_numeric_qr_scratch<T>(&self, par: Par, params: Spec<QrParams, T>) -> StackReq
1992 where
1993 T: ComplexField,
1994 {
1995 let m = self.nrows();
1996 let A_nnz = self.A_nnz;
1997 let AT_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(m + 1), StackReq::new::<I>(A_nnz)]);
1998
1999 match &self.raw {
2000 SymbolicQrRaw::Simplicial(symbolic) => simplicial::factorize_simplicial_numeric_qr_scratch::<I, T>(symbolic),
2001 SymbolicQrRaw::Supernodal(symbolic) => StackReq::and(
2002 AT_scratch,
2003 supernodal::factorize_supernodal_numeric_qr_scratch::<I, T>(symbolic, par, params),
2004 ),
2005 }
2006 }
2007
2008 #[track_caller]
2010 pub fn factorize_numeric_qr<'out, T: ComplexField>(
2011 &'out self,
2012 indices: &'out mut [I],
2013 val: &'out mut [T],
2014 A: SparseColMatRef<'_, I, T>,
2015 par: Par,
2016 stack: &mut MemStack,
2017 params: Spec<QrParams, T>,
2018 ) -> QrRef<'out, I, T> {
2019 assert!(all(val.len() == self.len_val(), indices.len() == self.len_idx(),));
2020 assert!(all(A.nrows() == self.nrows(), A.ncols() == self.ncols()));
2021
2022 let m = A.nrows();
2023 let n = A.ncols();
2024
2025 match &self.raw {
2026 SymbolicQrRaw::Simplicial(symbolic) => {
2027 let (r_col_ptr, indices) = indices.split_at_mut(n + 1);
2028 let (r_row_idx, indices) = indices.split_at_mut(symbolic.len_r());
2029 let (householder_col_ptr, indices) = indices.split_at_mut(n + 1);
2030 let (householder_row_idx, _) = indices.split_at_mut(symbolic.len_householder());
2031
2032 let (r_val, val) = val.split_at_mut(symbolic.len_r());
2033 let (householder_val, val) = val.split_at_mut(symbolic.len_householder());
2034 let (tau_val, _) = val.split_at_mut(n);
2035
2036 simplicial::factorize_simplicial_numeric_qr_unsorted::<I, T>(
2037 r_col_ptr,
2038 r_row_idx,
2039 r_val,
2040 householder_col_ptr,
2041 householder_row_idx,
2042 householder_val,
2043 tau_val,
2044 A,
2045 Some(self.col_perm()),
2046 symbolic,
2047 stack,
2048 );
2049 },
2050 SymbolicQrRaw::Supernodal(symbolic) => {
2051 let (householder_row_idx, indices) = indices.split_at_mut(symbolic.householder().len_householder_row_idx());
2052 let (tau_blocksize, indices) =
2053 indices.split_at_mut(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
2054 let (householder_nrows, indices) =
2055 indices.split_at_mut(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
2056 let (householder_ncols, _) =
2057 indices.split_at_mut(symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes());
2058
2059 let (r_val, val) = val.split_at_mut(symbolic.R_adjoint().len_val());
2060 let (householder_val, val) = val.split_at_mut(symbolic.householder().len_householder_val());
2061 let (tau_val, _) = val.split_at_mut(symbolic.householder().len_tau_val());
2062
2063 let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
2064 let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(self.A_nnz) };
2065 let (mut new_val, mut stack) = unsafe { temp_mat_uninit::<T, _, _>(self.A_nnz, 1, stack) };
2066 let new_val = new_val.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
2067
2068 let AT = utils::transpose(new_val, new_col_ptr, new_row_idx, A, stack.rb_mut()).into_const();
2069
2070 supernodal::factorize_supernodal_numeric_qr::<I, T>(
2071 householder_row_idx,
2072 tau_blocksize,
2073 householder_nrows,
2074 householder_ncols,
2075 r_val,
2076 householder_val,
2077 tau_val,
2078 AT,
2079 Some(self.col_perm()),
2080 symbolic,
2081 par,
2082 stack,
2083 params,
2084 );
2085 },
2086 }
2087
2088 unsafe { QrRef::new_unchecked(self, indices, val) }
2089 }
2090}
2091
2092#[track_caller]
2095pub fn factorize_symbolic_qr<I: Index>(A: SymbolicSparseColMatRef<'_, I>, params: QrSymbolicParams<'_>) -> Result<SymbolicQr<I>, FaerError> {
2096 assert!(A.nrows() >= A.ncols());
2097 let m = A.nrows();
2098 let n = A.ncols();
2099 let A_nnz = A.compute_nnz();
2100
2101 with_dim!(M, m);
2102 with_dim!(N, n);
2103 let A = A.as_shape(M, N);
2104
2105 let req = {
2106 let n_scratch = StackReq::new::<I>(n);
2107 let m_scratch = StackReq::new::<I>(m);
2108 let AT_scratch = StackReq::and(
2109 StackReq::new::<I>(m + 1),
2111 StackReq::new::<I>(A_nnz),
2113 );
2114
2115 StackReq::or(
2116 colamd::order_scratch::<I>(m, n, A_nnz),
2117 StackReq::all_of(&[
2118 n_scratch,
2119 n_scratch,
2120 n_scratch,
2121 n_scratch,
2122 AT_scratch,
2123 StackReq::any_of(&[
2124 StackReq::and(n_scratch, m_scratch),
2125 StackReq::all_of(&[n_scratch; 3]),
2126 StackReq::all_of(&[n_scratch, n_scratch, n_scratch, n_scratch, n_scratch, m_scratch]),
2127 supernodal::factorize_supernodal_symbolic_qr_scratch::<I>(m, n),
2128 simplicial::factorize_simplicial_symbolic_qr_scratch::<I>(m, n),
2129 ]),
2130 ]),
2131 )
2132 };
2133
2134 let mut mem = dyn_stack::MemBuffer::try_new(req).ok().ok_or(FaerError::OutOfMemory)?;
2135 let mut stack = MemStack::new(&mut mem);
2136
2137 let mut col_perm_fwd = try_zeroed::<I>(n)?;
2138 let mut col_perm_inv = try_zeroed::<I>(n)?;
2139 let mut min_row = try_zeroed::<I>(m)?;
2140
2141 colamd::order(&mut col_perm_fwd, &mut col_perm_inv, A.as_dyn(), params.colamd_params, stack.rb_mut())?;
2142
2143 let col_perm = PermRef::new_checked(&col_perm_fwd, &col_perm_inv, n).as_shape(N);
2144
2145 let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(m + 1) };
2146 let (new_row_idx, mut stack) = unsafe { stack.make_raw::<I>(A_nnz) };
2147 let AT = utils::adjoint(
2148 Symbolic::materialize(new_row_idx.len()),
2149 new_col_ptr,
2150 new_row_idx,
2151 SparseColMatRef::new(A, Symbolic::materialize(A.row_idx().len())),
2152 stack.rb_mut(),
2153 )
2154 .symbolic();
2155
2156 let (etree, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
2157 let (post, stack) = unsafe { stack.make_raw::<I>(n) };
2158 let (col_counts, stack) = unsafe { stack.make_raw::<I>(n) };
2159 let (h_col_counts, mut stack) = unsafe { stack.make_raw::<I>(n) };
2160
2161 ghost_col_etree(A, Some(col_perm), Array::from_mut(etree, N), stack.rb_mut());
2162 let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(etree, N), N);
2163 ghost_postorder(Array::from_mut(post, N), etree_, stack.rb_mut());
2164
2165 ghost_column_counts_aat(
2166 Array::from_mut(col_counts, N),
2167 Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
2168 AT,
2169 Some(col_perm),
2170 etree_,
2171 Array::from_ref(Idx::from_slice_ref_checked(post, N), N),
2172 stack.rb_mut(),
2173 );
2174 let min_col = min_row;
2175
2176 let mut threshold = params.supernodal_flop_ratio_threshold;
2177 if threshold != SupernodalThreshold::FORCE_SIMPLICIAL && threshold != SupernodalThreshold::FORCE_SUPERNODAL {
2178 h_col_counts.fill(I::truncate(0));
2179 for i in 0..m {
2180 let min_col = min_col[i];
2181 if min_col.to_signed() < I::Signed::truncate(0) {
2182 continue;
2183 }
2184 h_col_counts[min_col.zx()] += I::truncate(1);
2185 }
2186 for j in 0..n {
2187 let parent = etree[j];
2188 if parent < I::Signed::truncate(0) || h_col_counts[j] == I::truncate(0) {
2189 continue;
2190 }
2191 h_col_counts[parent.zx()] += h_col_counts[j] - I::truncate(1);
2192 }
2193
2194 let mut nnz = 0.0f64;
2195 let mut flops = 0.0f64;
2196 for j in 0..n {
2197 let hj = h_col_counts[j].zx() as f64;
2198 let rj = col_counts[j].zx() as f64;
2199 flops += hj + 2.0 * hj * rj;
2200 nnz += hj + rj;
2201 }
2202
2203 if flops / nnz > threshold.0 * linalg_sp::QR_SUPERNODAL_RATIO_FACTOR {
2204 threshold = SupernodalThreshold::FORCE_SUPERNODAL;
2205 } else {
2206 threshold = SupernodalThreshold::FORCE_SIMPLICIAL;
2207 }
2208 }
2209
2210 if threshold == SupernodalThreshold::FORCE_SUPERNODAL {
2211 let symbolic = supernodal::factorize_supernodal_symbolic_qr::<I>(
2212 A.as_dyn(),
2213 Some(col_perm.as_shape(n)),
2214 min_col,
2215 EliminationTreeRef::<'_, I> { inner: etree },
2216 col_counts,
2217 stack.rb_mut(),
2218 params.supernodal_params,
2219 )?;
2220 Ok(SymbolicQr {
2221 raw: SymbolicQrRaw::Supernodal(symbolic),
2222 col_perm_fwd,
2223 col_perm_inv,
2224 A_nnz,
2225 })
2226 } else {
2227 let symbolic =
2228 simplicial::factorize_simplicial_symbolic_qr::<I>(&min_col, EliminationTreeRef::<'_, I> { inner: etree }, col_counts, stack.rb_mut())?;
2229 Ok(SymbolicQr {
2230 raw: SymbolicQrRaw::Simplicial(symbolic),
2231 col_perm_fwd,
2232 col_perm_inv,
2233 A_nnz,
2234 })
2235 }
2236}
2237
2238#[cfg(test)]
2239mod tests {
2240 use super::*;
2241 use crate::assert;
2242 use crate::stats::prelude::*;
2243 use dyn_stack::MemBuffer;
2244 use linalg::solvers::SolveLstsqCore;
2245 use linalg_sp::cholesky::tests::{load_mtx, reconstruct_from_supernodal_llt};
2246 use matrix_market_rs::MtxData;
2247 use std::path::PathBuf;
2248
/// Checks the column elimination tree and the column counts computed for a fixed
/// 11 x 11 sparsity pattern against hard-coded expected values.
#[test]
fn test_symbolic_qr() {
    let n = 11;

    // Compressed-column pattern; `row_idx` is laid out with one matrix column per line.
    let col_ptr = &[0, 3, 6, 10, 13, 16, 21, 24, 29, 31, 37, 43usize];
    #[rustfmt::skip]
    let row_idx = &[
        0, 5, 6,
        1, 2, 7,
        1, 2, 9, 10,
        3, 5, 9,
        4, 7, 10,
        0, 3, 5, 8, 9,
        0, 6, 10,
        1, 4, 7, 9, 10,
        5, 8,
        2, 3, 5, 7, 9, 10,
        2, 4, 6, 7, 9, 10usize,
    ];

    let A = SymbolicSparseColMatRef::new_checked(n, n, col_ptr, None, row_idx);
    let mut etree_buf = vec![0isize; n];
    let mut post = vec![0usize; n];
    let mut col_counts = vec![0usize; n];

    with_dim!(N, n);
    let A = A.as_shape(N, N);

    // Fresh scratch allocation able to hold `len` indices.
    let scratch = |len: usize| MemBuffer::new(StackReq::new::<usize>(len));

    // Elimination tree first, then its postordering.
    ghost_col_etree(A, None, Array::from_mut(&mut etree_buf, N), MemStack::new(&mut scratch(*N + *N)));
    let etree = Array::from_ref(MaybeIdx::from_slice_ref_checked(&etree_buf, N), N);
    ghost_postorder(Array::from_mut(&mut post, N), etree, MemStack::new(&mut scratch(20 * *N)));

    // The column-count routine is fed the adjoint pattern, so build it here.
    let mut min_row = vec![0usize.to_signed(); n];
    let mut new_col_ptr = vec![0usize; n + 1];
    let mut new_row_idx = vec![0usize; row_idx.len()];

    let AT = utils::adjoint(
        Symbolic::materialize(new_row_idx.len()),
        &mut new_col_ptr,
        &mut new_row_idx,
        SparseColMatRef::new(A, Symbolic::materialize(A.row_idx().len())),
        MemStack::new(&mut scratch(20 * *N)),
    )
    .symbolic();

    ghost_column_counts_aat(
        Array::from_mut(&mut col_counts, N),
        Array::from_mut(&mut min_row, N),
        AT,
        None,
        etree,
        Array::from_ref(Idx::from_slice_ref_checked(&post, N), N),
        MemStack::new(&mut scratch(20 * *N)),
    );

    assert!(MaybeIdx::<'_, usize>::as_slice_ref(etree.as_ref()) == [3, 2, 3, 4, 5, 6, 7, 8, 9, 10, NONE as isize]);
    assert!(col_counts == [7, 6, 8, 8, 7, 6, 5, 4, 3, 2, 1usize]);
}
2312
/// Runs the supernodal symbolic + numeric QR on the real-valued `lp_share2b` matrix and
/// checks that the matrix rebuilt by `reconstruct_from_supernodal_llt` from the computed
/// factor matches `AᴴA` to within `1e-10`.
#[test]
fn test_numeric_qr_1_no_transpose() {
    type I = usize;

    // Fresh scratch allocation able to hold `len` indices.
    let scratch = |len: usize| MemBuffer::new(StackReq::new::<I>(len));

    let mtx_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_qr/lp_share2b.mtx");
    let (m, n, col_ptr, row_idx, val) = load_mtx::<usize>(MtxData::from_file(mtx_path).unwrap());
    let nnz = row_idx.len();

    let A = SparseColMatRef::<'_, I, f64>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

    with_dim!(M, m);
    with_dim!(N, n);
    let A = A.as_shape(M, N);

    // Storage for the adjoint of `A`; the numeric factorization below is handed `AT`.
    let mut new_col_ptr = vec![0usize; m + 1];
    let mut new_row_idx = vec![0usize; nnz];
    let mut new_val = vec![0.0; nnz];
    let AT = utils::adjoint(&mut new_val, &mut new_col_ptr, &mut new_row_idx, A, MemStack::new(&mut scratch(20 * *N))).into_const();

    // Symbolic analysis: elimination tree -> postorder -> column counts.
    let mut etree = vec![0usize.to_signed(); n];
    let mut post = vec![0usize; n];
    let mut col_counts = vec![0usize; n];
    let mut min_row = vec![0usize; m];

    ghost_col_etree(A.symbolic(), None, Array::from_mut(&mut etree, N), MemStack::new(&mut scratch(*M + *N)));
    let etree_ = Array::from_ref(MaybeIdx::from_slice_ref_checked(&etree, N), N);
    ghost_postorder(Array::from_mut(&mut post, N), etree_, MemStack::new(&mut scratch(20 * *N)));

    ghost_column_counts_aat(
        Array::from_mut(&mut col_counts, N),
        Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
        AT.symbolic(),
        None,
        etree_,
        Array::from_ref(Idx::from_slice_ref_checked(&post, N), N),
        MemStack::new(&mut scratch(20 * *N)),
    );

    // The buffer filled above is consumed by the symbolic factorization under this name.
    let min_col = min_row;

    let symbolic = supernodal::factorize_supernodal_symbolic_qr::<I>(
        A.symbolic().as_dyn(),
        None,
        min_col,
        EliminationTreeRef::<'_, I> { inner: &etree },
        &col_counts,
        MemStack::new(&mut scratch(20 * *N)),
        Default::default(),
    )
    .unwrap();

    // Output buffers, sized from the symbolic structure.
    let per_block_len = symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes();

    let mut householder_row_idx = vec![0usize; symbolic.householder().len_householder_row_idx()];
    let mut L_val = vec![0.0; symbolic.R_adjoint().len_val()];
    let mut householder_val = vec![0.0; symbolic.householder().len_householder_val()];
    let mut tau_val = vec![0.0; symbolic.householder().len_tau_val()];
    let mut tau_blocksize = vec![0usize; per_block_len];
    let mut householder_nrows = vec![0usize; per_block_len];
    let mut householder_ncols = vec![0usize; per_block_len];

    let numeric_req = supernodal::factorize_supernodal_numeric_qr_scratch::<usize, f64>(&symbolic, Par::Seq, Default::default());
    supernodal::factorize_supernodal_numeric_qr::<I, f64>(
        &mut householder_row_idx,
        &mut tau_blocksize,
        &mut householder_nrows,
        &mut householder_ncols,
        &mut L_val,
        &mut householder_val,
        &mut tau_val,
        AT.as_dyn(),
        None,
        &symbolic,
        Par::Seq,
        MemStack::new(&mut MemBuffer::new(numeric_req)),
        Default::default(),
    );

    // Compare the reconstructed product against the dense `AᴴA`.
    let llt = reconstruct_from_supernodal_llt::<I, f64>(symbolic.R_adjoint(), &L_val);
    let a = A.as_dyn().to_dense();
    let ata = a.adjoint() * &a;

    let llt_diff = &llt - &ata;
    assert!(llt_diff.norm_max() <= 1e-10);
}
2418
/// Supernodal QR on the transpose of `lp_share2b`, with seeded random imaginary parts
/// added to the entries. Verifies both the reconstructed `AᴴA` and the normal-equations
/// residual of the solve to within `1e-10`.
#[test]
fn test_numeric_qr_1_transpose() {
    type I = usize;
    type T = c64;

    // Fresh scratch allocation able to hold `len` indices.
    let scratch = |len: usize| MemBuffer::new(StackReq::new::<I>(len));

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);

    let mtx_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_qr/lp_share2b.mtx");
    let (m, n, col_ptr, row_idx, val) = load_mtx::<usize>(MtxData::from_file(mtx_path).unwrap());
    // Keep the loaded value as the real part and draw the imaginary part from the seeded RNG.
    let val = val.iter().map(|&re| c64::new(re, rng.gen())).collect::<Vec<_>>();
    let nnz = row_idx.len();

    let A = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

    with_dim!(M, m);
    with_dim!(N, n);
    let A = A.as_shape(M, N);

    let mut new_col_ptr = vec![0usize; m + 1];
    let mut new_row_idx = vec![0usize; nnz];
    let mut new_val = vec![T::ZERO; nnz];
    let AT = utils::transpose(&mut new_val, &mut new_col_ptr, &mut new_row_idx, A, MemStack::new(&mut scratch(*M))).into_const();

    // From here on the transposed matrix is the operand: swap the matrices and dimensions.
    let (A, AT) = (AT, A);
    let (M, N) = (N, M);
    let (m, n) = (n, m);

    // Symbolic analysis: elimination tree -> postorder -> column counts.
    let mut etree = vec![0usize.to_signed(); n];
    let mut post = vec![0usize; n];
    let mut col_counts = vec![0usize; n];
    let mut min_row = vec![0usize; m];

    ghost_col_etree(A.symbolic(), None, Array::from_mut(&mut etree, N), MemStack::new(&mut scratch(*M + *N)));
    let etree_ = Array::from_ref(MaybeIdx::from_slice_ref_checked(&etree, N), N);
    ghost_postorder(Array::from_mut(&mut post, N), etree_, MemStack::new(&mut scratch(3 * *N)));

    ghost_column_counts_aat(
        Array::from_mut(&mut col_counts, N),
        Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
        AT.symbolic(),
        None,
        etree_,
        Array::from_ref(Idx::from_slice_ref_checked(&post, N), N),
        MemStack::new(&mut scratch(5 * *N + *M)),
    );

    // The buffer filled above is consumed by the symbolic factorization under this name.
    let min_col = min_row;

    let symbolic_req = supernodal::factorize_supernodal_symbolic_qr_scratch::<usize>(*M, *N);
    let symbolic = supernodal::factorize_supernodal_symbolic_qr::<I>(
        A.symbolic().as_dyn(),
        None,
        min_col,
        EliminationTreeRef::<'_, I> { inner: &etree },
        &col_counts,
        MemStack::new(&mut MemBuffer::new(symbolic_req)),
        Default::default(),
    )
    .unwrap();

    // Output buffers, sized from the symbolic structure.
    let per_block_len = symbolic.householder().len_householder_row_idx() + symbolic.householder().n_supernodes();

    let mut householder_row_idx = vec![0usize; symbolic.householder().len_householder_row_idx()];
    let mut L_val = vec![T::ZERO; symbolic.R_adjoint().len_val()];
    let mut householder_val = vec![T::ZERO; symbolic.householder().len_householder_val()];
    let mut tau_val = vec![T::ZERO; symbolic.householder().len_tau_val()];
    let mut tau_blocksize = vec![0usize; per_block_len];
    let mut householder_nrows = vec![0usize; per_block_len];
    let mut householder_ncols = vec![0usize; per_block_len];

    let numeric_req = supernodal::factorize_supernodal_numeric_qr_scratch::<usize, T>(&symbolic, Par::Seq, Default::default());
    let qr = supernodal::factorize_supernodal_numeric_qr::<I, T>(
        &mut householder_row_idx,
        &mut tau_blocksize,
        &mut householder_nrows,
        &mut householder_ncols,
        &mut L_val,
        &mut householder_val,
        &mut tau_val,
        AT.as_dyn(),
        None,
        &symbolic,
        Par::Seq,
        MemStack::new(&mut MemBuffer::new(numeric_req)),
        Default::default(),
    );

    let a = A.as_dyn().to_dense();

    // Solve with a random two-column right-hand side.
    let rhs = Mat::<T>::from_fn(m, 2, |_, _| c64::new(rng.gen(), rng.gen()));
    let mut x = rhs.clone();
    let mut work = rhs.clone();
    qr.solve_in_place_with_conj(
        Conj::No,
        x.as_mut(),
        Par::Seq,
        work.as_mut(),
        MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(2, Par::Seq))),
    );
    // Only the first `n` rows of the in-place result are used as the solution.
    let x = x.as_ref().subrows(0, n);

    // Normal-equations residual `Aᴴ(Ax - b)`.
    let linsolve_diff = a.adjoint() * (&a * &x - &rhs);

    let llt = reconstruct_from_supernodal_llt::<I, T>(symbolic.R_adjoint(), &L_val);
    let ata = a.adjoint() * &a;

    let llt_diff = &llt - &ata;
    assert!(llt_diff.norm_max() <= 1e-10);
    assert!(linsolve_diff.norm_max() <= 1e-10);
}
2548
/// Simplicial QR on the transpose of `lp_share2b`, with seeded random imaginary parts
/// added to the entries. Verifies the normal-equations residual of the plain and
/// conjugated solves, and that `RᴴR` matches `AᴴA`.
#[test]
fn test_numeric_simplicial_qr_1_transpose() {
    type I = usize;
    type T = c64;

    // Fresh scratch allocation able to hold `len` indices.
    let scratch = |len: usize| MemBuffer::new(StackReq::new::<I>(len));

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);

    let mtx_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_qr/lp_share2b.mtx");
    let (m, n, col_ptr, row_idx, val) = load_mtx::<usize>(MtxData::from_file(mtx_path).unwrap());

    // Keep the loaded value as the real part and draw the imaginary part from the seeded RNG.
    let val = val.iter().map(|&re| c64::new(re, rng.gen())).collect::<Vec<_>>();
    let nnz = row_idx.len();

    let A = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

    with_dim!(M, m);
    with_dim!(N, n);
    let A = A.as_shape(M, N);

    let mut new_col_ptr = vec![0usize; m + 1];
    let mut new_row_idx = vec![0usize; nnz];
    let mut new_val = vec![T::ZERO; nnz];
    let AT = utils::transpose(&mut new_val, &mut new_col_ptr, &mut new_row_idx, A, MemStack::new(&mut scratch(*M))).into_const();

    // From here on the transposed matrix is the operand: swap the matrices and dimensions.
    let (A, AT) = (AT, A);
    let (M, N) = (N, M);
    let (m, n) = (n, m);

    // Symbolic analysis: elimination tree -> postorder -> column counts.
    let mut etree = vec![0usize.to_signed(); n];
    let mut post = vec![0usize; n];
    let mut col_counts = vec![0usize; n];
    let mut min_row = vec![0usize; m];

    ghost_col_etree(A.symbolic(), None, Array::from_mut(&mut etree, N), MemStack::new(&mut scratch(*M + *N)));
    let etree_ = Array::from_ref(MaybeIdx::from_slice_ref_checked(&etree, N), N);
    ghost_postorder(Array::from_mut(&mut post, N), etree_, MemStack::new(&mut scratch(3 * *N)));

    ghost_column_counts_aat(
        Array::from_mut(&mut col_counts, N),
        Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M),
        AT.symbolic(),
        None,
        etree_,
        Array::from_ref(Idx::from_slice_ref_checked(&post, N), N),
        MemStack::new(&mut scratch(5 * *N + *M)),
    );

    // The buffer filled above is consumed by the symbolic factorization under this name.
    let min_col = min_row;

    let symbolic = simplicial::factorize_simplicial_symbolic_qr::<I>(
        &min_col,
        EliminationTreeRef::<'_, I> { inner: &etree },
        &col_counts,
        MemStack::new(&mut scratch(3 * *N)),
    )
    .unwrap();

    // Index and value storage for `R` and the Householder factors.
    let mut r_col_ptr = vec![0usize; n + 1];
    let mut r_row_idx = vec![0usize; symbolic.len_r()];
    let mut householder_col_ptr = vec![0usize; n + 1];
    let mut householder_row_idx = vec![0usize; symbolic.len_householder()];

    let mut r_val = vec![T::ZERO; symbolic.len_r()];
    let mut householder_val = vec![T::ZERO; symbolic.len_householder()];
    let mut tau_val = vec![T::ZERO; n];

    let numeric_req = simplicial::factorize_simplicial_numeric_qr_scratch::<usize, T>(&symbolic);
    let qr = simplicial::factorize_simplicial_numeric_qr_unsorted(
        &mut r_col_ptr,
        &mut r_row_idx,
        &mut r_val,
        &mut householder_col_ptr,
        &mut householder_row_idx,
        &mut householder_val,
        &mut tau_val,
        A.as_dyn(),
        None,
        &symbolic,
        MemStack::new(&mut MemBuffer::new(numeric_req)),
    );

    let a = A.as_dyn().to_dense();
    let rhs = Mat::<T>::from_fn(m, 2, |_, _| c64::new(rng.gen(), rng.gen()));

    // Plain solve.
    {
        let mut x = rhs.clone();
        let mut work = rhs.clone();
        qr.solve_in_place_with_conj(Conj::No, x.as_mut(), Par::Seq, work.as_mut());

        // NOTE(review): this dense least-squares solve is executed but `y` is never
        // compared against `x` — confirm whether an assertion was intended.
        let mut y = rhs.clone();
        A.to_dense().as_dyn().qr().solve_lstsq_in_place_with_conj(Conj::No, y.as_mut());

        let x = x.as_ref().subrows(0, n);
        let linsolve_diff = a.adjoint() * (&a * &x - &rhs);
        assert!(linsolve_diff.norm_max() <= 1e-10);
    }

    // Conjugated solve, checked against the conjugate of the dense matrix.
    {
        let mut x = rhs.clone();
        let mut work = rhs.clone();
        qr.solve_in_place_with_conj(Conj::Yes, x.as_mut(), Par::Seq, work.as_mut());

        let x = x.as_ref().subrows(0, n);
        let a = a.conjugate();
        let linsolve_diff = a.adjoint() * (a * &x - &rhs);
        assert!(linsolve_diff.norm_max() <= 1e-10);
    }

    // `RᴴR` must reproduce `AᴴA`; the row indices of `R` are not required to be sorted.
    let R = SparseColMatRef::<'_, usize, T>::new(SymbolicSparseColMatRef::new_unsorted_checked(n, n, &r_col_ptr, None, &r_row_idx), &r_val);
    let r = R.to_dense();
    let ata = a.adjoint() * &a;
    let rtr = r.adjoint() * &r;
    assert!((&ata - &rtr).norm_max() < 1e-10);
}
2679
/// Exercises the high-level symbolic/numeric QR API on the transpose of `lp_share2b`
/// (with seeded random imaginary parts) for each supernodal threshold setting, checking
/// the normal-equations residual of both the plain and the conjugated solve.
#[test]
fn test_solver_qr_1_transpose() {
    type I = usize;
    type T = c64;

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);

    let mtx_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_qr/lp_share2b.mtx");
    let (m, n, col_ptr, row_idx, val) = load_mtx::<usize>(MtxData::from_file(mtx_path).unwrap());
    // Keep the loaded value as the real part and draw the imaginary part from the seeded RNG.
    let val = val.iter().map(|&re| c64::new(re, rng.gen())).collect::<Vec<_>>();
    let nnz = row_idx.len();

    let A = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_idx), &val);

    let mut new_col_ptr = vec![0usize; m + 1];
    let mut new_row_idx = vec![0usize; nnz];
    let mut new_val = vec![T::ZERO; nnz];

    // The transposed matrix is the one being factorized; dimensions swap accordingly.
    let A = utils::transpose(
        &mut new_val,
        &mut new_col_ptr,
        &mut new_row_idx,
        A,
        MemStack::new(&mut MemBuffer::new(StackReq::new::<I>(m))),
    )
    .into_const();
    let (m, n) = (n, m);

    let a = A.to_dense();
    let rhs = Mat::<T>::from_fn(m, 2, |_, _| c64::new(rng.gen(), rng.gen()));

    let thresholds = [
        SupernodalThreshold::FORCE_SUPERNODAL,
        SupernodalThreshold::FORCE_SIMPLICIAL,
        SupernodalThreshold::AUTO,
    ];
    for threshold in thresholds {
        let symbolic = factorize_symbolic_qr(
            A.symbolic(),
            QrSymbolicParams {
                supernodal_flop_ratio_threshold: threshold,
                ..Default::default()
            },
        )
        .unwrap();

        let mut indices = vec![0usize; symbolic.len_idx()];
        let mut val = vec![T::ZERO; symbolic.len_val()];
        let numeric_req = symbolic.factorize_numeric_qr_scratch::<T>(Par::Seq, Default::default());
        let qr = symbolic.factorize_numeric_qr::<T>(
            &mut indices,
            &mut val,
            A,
            Par::Seq,
            MemStack::new(&mut MemBuffer::new(numeric_req)),
            Default::default(),
        );

        // Plain solve.
        {
            let mut x = rhs.clone();
            qr.solve_in_place_with_conj(
                Conj::No,
                x.as_mut(),
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(2, Par::Seq))),
            );

            let x = x.as_ref().subrows(0, n);
            let linsolve_diff = a.adjoint() * (&a * &x - &rhs);
            assert!(linsolve_diff.norm_max() <= 1e-10);
        }

        // Conjugated solve, checked against the conjugate of the dense matrix.
        {
            let mut x = rhs.clone();
            qr.solve_in_place_with_conj(
                Conj::Yes,
                x.as_mut(),
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<T>(2, Par::Seq))),
            );

            let x = x.as_ref().subrows(0, n);
            let a = a.conjugate();
            let linsolve_diff = a.adjoint() * (a * &x - &rhs);
            assert!(linsolve_diff.norm_max() <= 1e-10);
        }
    }
}
2767
/// Smoke test on degenerate inputs: the symbolic and numeric factorizations must complete
/// for every threshold setting on an all-empty 40 x 20 matrix and on two 40 x 5 matrices
/// whose only entries are two values sharing one row (row 0 and row 4 respectively).
#[test]
fn test_solver_qr_edge_case() {
    type I = usize;
    type T = c64;

    let mut rng = rand::rngs::StdRng::seed_from_u64(0);

    // No stored entries at all.
    let empty_col_ptr = vec![0usize; 21];
    let A0 = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(40, 20, &empty_col_ptr, None, &[]), &[]);

    // One entry in each of the first two columns, both in the same row; remaining columns empty.
    let shared_val = [c64::new(rng.gen(), rng.gen()), c64::new(rng.gen(), rng.gen())];
    let A1 = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(40, 5, &[0, 1, 2, 2, 2, 2], None, &[0, 0]), &shared_val);
    let A2 = SparseColMatRef::<'_, I, T>::new(SymbolicSparseColMatRef::new_checked(40, 5, &[0, 1, 2, 2, 2, 2], None, &[4, 4]), &shared_val);

    let thresholds = [
        SupernodalThreshold::AUTO,
        SupernodalThreshold::FORCE_SUPERNODAL,
        SupernodalThreshold::FORCE_SIMPLICIAL,
    ];

    for A in [A0, A1, A2] {
        for threshold in thresholds {
            let symbolic = factorize_symbolic_qr(
                A.symbolic(),
                QrSymbolicParams {
                    supernodal_flop_ratio_threshold: threshold,
                    ..Default::default()
                },
            )
            .unwrap();

            let mut indices = vec![0usize; symbolic.len_idx()];
            let mut val = vec![T::ZERO; symbolic.len_val()];
            let numeric_req = symbolic.factorize_numeric_qr_scratch::<T>(Par::Seq, Default::default());
            // Only completion is checked; the resulting factorization is discarded.
            symbolic.factorize_numeric_qr::<T>(
                &mut indices,
                &mut val,
                A,
                Par::Seq,
                MemStack::new(&mut MemBuffer::new(numeric_req)),
                Default::default(),
            );
        }
    }
}
2811}