1use super::super::utils::*;
500use super::ghost;
501use crate::assert;
502use crate::internal_prelude_sp::*;
503use linalg::cholesky::lblt::factor::{LbltInfo, LbltParams};
504use linalg::cholesky::ldlt::factor::{LdltError, LdltInfo, LdltParams, LdltRegularization};
505use linalg::cholesky::llt::factor::{LltError, LltInfo, LltParams, LltRegularization};
506use linalg_sp::{SupernodalThreshold, SymbolicSupernodalParams, amd, triangular_solve};
507
/// Ordering strategy for a symmetric sparse matrix.
///
/// NOTE(review): how each variant is consumed lives outside this view; presumably the
/// ordering is applied before the symbolic Cholesky analysis.
#[derive(Copy, Clone, Debug, Default)]
pub enum SymmetricOrdering<'a, I: Index> {
    /// Approximate minimum degree (AMD) ordering; this is the default variant.
    #[default]
    Amd,
    /// No reordering (identity permutation).
    Identity,
    /// A caller-provided permutation.
    Custom(PermRef<'a, I>),
}
519
520pub mod simplicial {
526 use super::*;
527 use crate::assert;
528
    /// Borrowed view of an elimination tree.
    ///
    /// Entry `i` holds the parent of node `i`, or `NONE` when `i` is a root, stored in the
    /// signed index type.
    #[derive(Copy, Clone, Debug)]
    pub struct EliminationTreeRef<'a, I: Index> {
        pub(crate) inner: &'a [I::Signed],
    }
538
539 impl<'a, I: Index> EliminationTreeRef<'a, I> {
540 pub fn len(&self) -> usize {
542 self.inner.len()
543 }
544
545 #[inline]
550 pub fn into_inner(self) -> &'a [I::Signed] {
551 self.inner
552 }
553
554 #[inline]
560 pub unsafe fn from_inner(inner: &'a [I::Signed]) -> Self {
561 Self { inner }
562 }
563
564 #[inline]
565 #[track_caller]
566 pub(crate) fn as_bound<'n>(self, N: ghost::Dim<'n>) -> &'a Array<'n, MaybeIdx<'n, I>> {
567 assert!(self.inner.len() == *N);
568 unsafe { Array::from_ref(MaybeIdx::from_slice_ref_unchecked(self.inner), N) }
569 }
570 }
571
572 pub fn prefactorize_symbolic_cholesky_scratch<I: Index>(n: usize, nnz: usize) -> StackReq {
575 _ = nnz;
576 StackReq::new::<I>(n)
577 }
578
    /// Computes the elimination tree of the square matrix `A` and the number of nonzeros in
    /// each column of its Cholesky factor (diagonal included).
    ///
    /// Only entries strictly above the diagonal of `A` (row index < column index) are read.
    /// The results are written to `etree` and `col_counts`. The returned view borrows `etree`.
    ///
    /// The workspace `stack` must satisfy [`prefactorize_symbolic_cholesky_scratch`].
    ///
    /// # Panics
    /// Panics if `A` is not square, or if `etree` or `col_counts` do not have length `A.nrows()`.
    pub fn prefactorize_symbolic_cholesky<'out, I: Index>(
        etree: &'out mut [I::Signed],
        col_counts: &mut [I],
        A: SymbolicSparseColMatRef<'_, I>,
        stack: &mut MemStack,
    ) -> EliminationTreeRef<'out, I> {
        let n = A.nrows();
        assert!(A.nrows() == A.ncols());
        assert!(etree.len() == n);
        assert!(col_counts.len() == n);

        with_dim!(N, n);
        // The branded helper fills `etree` in place; its returned bound view is not needed here.
        ghost_prefactorize_symbolic_cholesky(Array::from_mut(etree, N), Array::from_mut(col_counts, N), A.as_shape(N, N), stack);

        simplicial::EliminationTreeRef { inner: etree }
    }
600
    /// Branded implementation of [`prefactorize_symbolic_cholesky`].
    ///
    /// For each column `j`, each strictly-upper entry `(i, j)` of `A` starts a walk up the
    /// partially built elimination tree from `i`. The walk stops at the first node already
    /// visited while processing column `j`.
    ///
    /// - A node reached without a parent gets `j` as its parent.
    /// - Every node on the walk gains one entry in its column count.
    ///
    /// Returns `etree` reinterpreted as optional indices bound to `N`.
    fn ghost_prefactorize_symbolic_cholesky<'n, 'out, I: Index>(
        etree: &'out mut Array<'n, I::Signed>,
        col_counts: &mut Array<'n, I>,
        A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>,
        stack: &mut MemStack,
    ) -> &'out mut Array<'n, MaybeIdx<'n, I>> {
        let N = A.ncols();
        // `visited[i] == j` means node `i` was already reached while processing column `j`.
        // Because the marks are column stamps, the buffer never needs clearing. Entries are
        // uninitialized, but `visited[i]` is always set at column `i` before any later column
        // reads it.
        let (visited, _) = unsafe { stack.make_raw::<I>(*N) };
        // Every node starts out as a root.
        let etree = Array::from_mut(ghost::fill_none::<I>(etree.as_mut(), N), N);
        let visited = Array::from_mut(visited, N);

        for j in N.indices() {
            let j_ = j.truncate::<I>();
            // Marking `j` itself guarantees that every walk for column `j` stops at `j` at the latest.
            visited[j] = *j_;
            // Count the diagonal entry.
            col_counts[j] = I::truncate(1);

            for mut i in A.row_idx_of_col(j) {
                if i < j {
                    loop {
                        if visited[i] == *j_ {
                            break;
                        }

                        let next_i = if let Some(parent) = etree[i].idx() {
                            parent.zx()
                        } else {
                            // `i` had no parent yet: the current column becomes its parent.
                            etree[i] = MaybeIdx::from_index(j_);
                            j
                        };

                        // Row `j` of the factor has a nonzero in column `i`.
                        col_counts[i] += I::truncate(1);
                        visited[i] = *j_;
                        i = next_i;
                    }
                }
            }
        }

        etree
    }
642
    /// Computes the nonzero pattern of row `k` of the Cholesky factor (diagonal excluded)
    /// from the elimination tree.
    ///
    /// For every strictly-upper entry `(i, k)` of `A`, the function walks from `i` towards the
    /// root of `etree`, stopping at the first node already stamped with `k` in `visited`.
    /// Each path is moved to the end of `stack`, in front of the previously found paths.
    /// As a result, a node always appears before its ancestors in the returned slice.
    ///
    /// `visited` entries are stamped with `k`. The callers in this file initialize it once to
    /// `NONE` and reuse it for increasing `k` without clearing.
    ///
    /// Returns the filled tail `stack[top..]` as bound indices.
    fn ereach<'n, 'a, I: Index>(
        stack: &'a mut Array<'n, I>,
        A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>,
        etree: &Array<'n, MaybeIdx<'n, I>>,
        k: Idx<'n, usize>,
        visited: &mut Array<'n, I::Signed>,
    ) -> &'a [Idx<'n, I>] {
        let N = A.ncols();

        // The result is built at the end of `stack`; `top` is the start of the filled tail.
        let mut top = *N;
        let k_: I = *k.truncate();
        // Stamping `k` itself terminates every walk at `k` at the latest.
        visited[k] = k_.to_signed();
        for mut i in A.row_idx_of_col(k) {
            if i >= k {
                // Only the strictly-upper part of column `k` contributes to row `k` of L.
                continue;
            }
            let mut len = 0usize;
            // Push the path from `i` up to (excluding) the first stamped node into `stack[..len]`.
            loop {
                if visited[i] == k_.to_signed() {
                    break;
                }

                let pushed: Idx<'n, I> = i.truncate::<I>();
                stack[N.check(len)] = *pushed;
                len += 1;

                visited[i] = k_.to_signed();
                // NOTE(review): assumes `etree` is the elimination tree of `A`, so `k` is an
                // ancestor of `i` and a stamped node is reached before any root. A root here
                // would make `N.check` panic.
                i = N.check(etree[i].unbound().zx());
            }

            // Move the path in front of the paths found so far (order within the path is kept).
            stack.as_mut().copy_within(..len, top - len);
            top -= len;
        }

        let stack = &stack.as_ref()[top..];

        // SAFETY: every value in `stack[top..]` was written from a bound index `i < N` above.
        unsafe { Idx::from_slice_ref_unchecked(stack) }
    }
690
691 pub fn factorize_simplicial_symbolic_cholesky_scratch<I: Index>(n: usize) -> StackReq {
694 let n_scratch = StackReq::new::<I>(n);
695 StackReq::all_of(&[n_scratch, n_scratch, n_scratch])
696 }
697
    /// Computes the symbolic structure (column pointers and row indices) of the simplicial
    /// Cholesky factor of `A`.
    ///
    /// The inputs are the elimination tree and column counts, for example as produced by
    /// [`prefactorize_symbolic_cholesky`].
    ///
    /// The workspace `stack` must satisfy [`factorize_simplicial_symbolic_cholesky_scratch`].
    ///
    /// # Errors
    /// Propagates the [`FaerError`] returned by the fallible allocations (`try_zeroed` /
    /// `try_collect`) in the branded helper.
    ///
    /// # Panics
    /// Panics if `A` is not square or if `etree` / `col_counts` do not have length `A.nrows()`.
    pub fn factorize_simplicial_symbolic_cholesky<I: Index>(
        A: SymbolicSparseColMatRef<'_, I>,
        etree: EliminationTreeRef<'_, I>,
        col_counts: &[I],
        stack: &mut MemStack,
    ) -> Result<SymbolicSimplicialCholesky<I>, FaerError> {
        let n = A.nrows();
        assert!(A.nrows() == A.ncols());
        assert!(etree.inner.len() == n);
        assert!(col_counts.len() == n);

        with_dim!(N, n);
        ghost_factorize_simplicial_symbolic_cholesky(A.as_shape(N, N), etree.as_bound(N), Array::from_ref(col_counts, N), stack)
    }
721
    /// Branded implementation of [`factorize_simplicial_symbolic_cholesky`].
    ///
    /// The column pointers of `L` are the prefix sums of `col_counts`. Row indices are then
    /// filled row by row. For each `k`:
    /// - `k` is appended to every column `j` in the elimination reach of row `k`;
    /// - `k` is stored in the first (diagonal) slot of column `k`.
    pub(crate) fn ghost_factorize_simplicial_symbolic_cholesky<'n, I: Index>(
        A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>,
        etree: &Array<'n, MaybeIdx<'n, I>>,
        col_counts: &Array<'n, I>,
        stack: &mut MemStack,
    ) -> Result<SymbolicSimplicialCholesky<I>, FaerError> {
        let N = A.ncols();
        let n = *N;

        // L_col_ptr[0] = 0 and L_col_ptr[j + 1] = L_col_ptr[j] + col_counts[j].
        let mut L_col_ptr = try_zeroed::<I>(n + 1)?;
        for (&count, [p, p_next]) in iter::zip(col_counts.as_ref(), windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut L_col_ptr)))) {
            p_next.set(p.get() + count);
        }
        let l_nnz = L_col_ptr[n].zx();
        let mut L_row_idx = try_zeroed::<I>(l_nnz)?;

        with_dim!(L_NNZ, l_nnz);
        let (current_row_idxex, stack) = unsafe { stack.make_raw::<I>(n) };
        let (ereach_stack, stack) = unsafe { stack.make_raw::<I>(n) };
        let (visited, _) = unsafe { stack.make_raw::<I::Signed>(n) };

        let ereach_stack = Array::from_mut(ereach_stack, N);
        let visited = Array::from_mut(visited, N);

        // No node is stamped initially (see `ereach`).
        visited.as_mut().fill(I::Signed::truncate(NONE));
        {
            let L_row_idx = Array::from_mut(&mut L_row_idx, L_NNZ);
            let L_col_ptr_start = Array::from_ref(Idx::from_slice_ref_checked(&L_col_ptr[..n], L_NNZ), N);
            // `current_row_idxex[j]` is the position of the last entry written in column `j`.
            // It starts at the diagonal slot, so off-diagonal entries go right after it.
            let current_row_idxex = Array::from_mut(ghost::copy_slice(current_row_idxex, L_col_ptr_start.as_ref()), N);

            for k in N.indices() {
                let reach = ereach(ereach_stack, A, etree, k, visited);
                for &j in reach {
                    let j = j.zx();
                    let cj = &mut current_row_idxex[j];
                    let row_idx = L_NNZ.check(*cj.zx() + 1);
                    *cj = row_idx.truncate();
                    L_row_idx[row_idx] = *k.truncate();
                }
                let k_start = L_col_ptr_start[k].zx();
                L_row_idx[k_start] = *k.truncate();
            }
        }

        // Store the elimination tree as unsigned `I` values (bit-for-bit cast of the signed entries).
        let etree = try_collect(
            bytemuck::cast_slice::<I::Signed, I>(MaybeIdx::as_slice_ref(etree.as_ref()))
                .iter()
                .copied(),
        )?;

        // Sanity check of the constructed structure; the returned view is discarded.
        // NOTE(review): presumably `new_unsorted_checked` panics on an invalid structure — confirm.
        let _ = SymbolicSparseColMatRef::new_unsorted_checked(n, n, &L_col_ptr, None, &L_row_idx);
        Ok(SymbolicSimplicialCholesky {
            dimension: n,
            col_ptr: L_col_ptr,
            row_idx: L_row_idx,
            etree,
        })
    }
780
    /// Selects which simplicial numeric factorization to compute.
    ///
    /// - `Llt`: the diagonal slot of each column stores `sqrt(pivot)`.
    /// - `Ldlt`: the diagonal slot stores the pivot itself, and the off-diagonal entries are
    ///   those of the unit lower factor.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    enum FactorizationKind {
        Llt,
        Ldlt,
    }
786
    /// Up-looking simplicial numeric factorization that also fills in the row indices of `L`.
    ///
    /// For each row `k`:
    /// 1. Column `k` of `A` is scattered (conjugated) into the dense workspace `x`.
    /// 2. Every column `j` in the elimination reach of row `k` is processed in the order
    ///    returned by `ereach`. For each such `j`, the entry `L[k, j]` is computed from `x[j]`
    ///    and the stored pivot of column `j`. The entries of column `j` computed so far are
    ///    then used to update `x`.
    /// 3. What is left in `d` is the pivot of row `k`.
    ///
    /// With `Llt`, the diagonal slot of column `k` receives `sqrt(d)`; with `Ldlt`, it
    /// receives `d`.
    ///
    /// Dynamic regularization is active when `|dynamic_regularization_delta| > 0`. The inline
    /// comments give the exact replacement rules; replacements are counted in the returned
    /// `LltInfo`.
    ///
    /// `L_col_ptr` must hold the column pointers of the factor. `L_values` and `L_row_idx`
    /// must have the same length. The workspace `stack` must satisfy
    /// [`factorize_simplicial_numeric_ldlt_scratch`].
    ///
    /// # Errors
    /// Returns `LltError::NonPositivePivot { index }`, with a 1-based `index`, when the
    /// (possibly regularized) pivot is rejected:
    /// - for `Llt`, when it is not strictly positive (NaN included);
    /// - for `Ldlt`, when it is zero or not finite.
    ///
    /// # Panics
    /// Panics if the lengths of `L_values`, `L_row_idx`, `L_col_ptr` or `etree` are inconsistent
    /// with `A`.
    #[math]
    fn factorize_simplicial_numeric_with_row_idx<I: Index, T: ComplexField>(
        L_values: &mut [T],
        L_row_idx: &mut [I],
        L_col_ptr: &[I],
        kind: FactorizationKind,

        etree: EliminationTreeRef<'_, I>,
        A: SparseColMatRef<'_, I, T>,
        regularization: LdltRegularization<'_, T::Real>,

        stack: &mut MemStack,
    ) -> Result<LltInfo, LltError> {
        let n = A.ncols();

        assert!(L_values.len() == L_row_idx.len());
        assert!(L_col_ptr.len() == n + 1);
        assert!(etree.len() == n);
        let l_nnz = L_col_ptr[n].zx();

        with_dim!(N, n);
        with_dim!(L_NNZ, l_nnz);

        let etree = etree.as_bound(N);
        let A = A.as_shape(N, N);

        let eps = abs(regularization.dynamic_regularization_epsilon);
        let delta = abs(regularization.dynamic_regularization_delta);
        let has_delta = delta > zero::<T::Real>();
        let mut dynamic_regularization_count = 0usize;

        // Dense accumulator for the current row. Each entry is reset to zero once consumed, so
        // it is all zeros again at the start of every row.
        let (mut x, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
        let mut x = x.as_mat_mut().col_mut(0).as_row_shape_mut(N);

        let (current_row_idxex, stack) = unsafe { stack.make_raw::<I>(n) };
        let (ereach_stack, stack) = unsafe { stack.make_raw::<I>(n) };
        let (visited, _) = unsafe { stack.make_raw::<I::Signed>(n) };

        let ereach_stack = Array::from_mut(ereach_stack, N);
        let visited = Array::from_mut(visited, N);

        visited.as_mut().fill(I::Signed::truncate(NONE));

        let L_values = Array::from_mut(L_values, L_NNZ);
        let L_row_idx = Array::from_mut(L_row_idx, L_NNZ);

        let L_col_ptr_start = Array::from_ref(Idx::from_slice_ref_checked(&L_col_ptr[..n], L_NNZ), N);

        // `current_row_idxex[j]` is the position of the last entry written in column `j`
        // (initially its diagonal slot).
        let current_row_idxex = Array::from_mut(ghost::copy_slice(current_row_idxex, L_col_ptr_start.as_ref()), N);

        for k in N.indices() {
            let reach = ereach(ereach_stack, A.symbolic(), etree, k, visited);

            for (i, aik) in iter::zip(A.row_idx_of_col(k), A.val_of_col(k)) {
                x[i] = x[i] + conj(aik);
            }

            let mut d = real(x[k]);
            x[k] = zero::<T>();

            for &j in reach {
                let j = j.zx();

                let j_start = L_col_ptr_start[j].zx();
                let cj = &mut current_row_idxex[j];
                let row_idx = L_NNZ.check(*cj.zx() + 1);
                *cj = row_idx.truncate();

                let mut xj = copy(x[j]);
                x[j] = zero::<T>();

                // The diagonal slot of column `j` holds its pivot: sqrt(d_j) for LLT, d_j for LDLT.
                let dj = recip(real(L_values[j_start]));
                let lkj = mul_real(xj, dj);
                if kind == FactorizationKind::Llt {
                    xj = copy(lkj);
                }

                // Off-diagonal entries of column `j` written for earlier rows.
                let range = j_start.next()..row_idx.into();
                for (i, lij) in iter::zip(&L_row_idx[range.clone()], &L_values[range]) {
                    let i = N.check(i.zx());
                    x[i] = x[i] - conj(*lij) * xj;
                }

                d = d - real(lkj * conj(xj));

                L_values[row_idx] = lkj;
                L_row_idx[row_idx] = *k.truncate();
            }

            let k_start = L_col_ptr_start[k].zx();
            L_row_idx[k_start] = *k.truncate();

            if has_delta {
                match kind {
                    FactorizationKind::Llt => {
                        // Pivots that are too small (or negative) are replaced by `delta`.
                        if d <= eps {
                            d = copy(delta);
                            dynamic_regularization_count += 1;
                        }
                    },
                    FactorizationKind::Ldlt => {
                        if let Some(signs) = regularization.dynamic_regularization_signs {
                            // An expected sign is given. A pivot with the wrong sign, or too close
                            // to zero, becomes `±delta` following that sign. A sign of 0 disables
                            // regularization for this row.
                            if signs[*k] > 0 && d <= eps {
                                d = copy(delta);
                                dynamic_regularization_count += 1;
                            } else if signs[*k] < 0 && d >= -eps {
                                d = -delta;
                                dynamic_regularization_count += 1;
                            }
                        } else if abs(d) <= eps {
                            // No expected sign: a near-zero pivot becomes `delta` with the sign of `d`.
                            if d < zero::<T::Real>() {
                                d = -delta;
                                dynamic_regularization_count += 1;
                            } else {
                                d = copy(delta);
                                dynamic_regularization_count += 1;
                            }
                        }
                    },
                }
            }

            match kind {
                FactorizationKind::Llt => {
                    // `!(d > 0)` also rejects NaN.
                    if !(d > zero::<T::Real>()) {
                        return Err(LltError::NonPositivePivot { index: *k + 1 });
                    }
                    L_values[k_start] = from_real(sqrt(d));
                },
                FactorizationKind::Ldlt => {
                    if d == zero::<T::Real>() || !is_finite(d) {
                        return Err(LltError::NonPositivePivot { index: *k + 1 });
                    }
                    L_values[k_start] = from_real(d);
                },
            }
        }
        Ok(LltInfo {
            dynamic_regularization_count,
        })
    }
928
    /// Up-looking simplicial numeric factorization using a precomputed symbolic structure.
    ///
    /// The algorithm is the same as `factorize_simplicial_numeric_with_row_idx`, except that
    /// the row indices and elimination tree are read from `symbolic` and are not written.
    /// Only `L_values` is filled.
    ///
    /// `kind` selects what the diagonal slot stores:
    /// - `Llt`: `sqrt(pivot)`;
    /// - `Ldlt`: the pivot itself.
    ///
    /// Dynamic regularization follows the rules in the inline comments. The workspace `stack`
    /// must satisfy [`factorize_simplicial_numeric_ldlt_scratch`].
    ///
    /// # Errors
    /// Returns `LltError::NonPositivePivot { index }`, with a 1-based `index`, when the
    /// (possibly regularized) pivot is rejected:
    /// - for `Llt`, when it is not strictly positive (NaN included);
    /// - for `Ldlt`, when it is zero or not finite.
    ///
    /// # Panics
    /// Panics if `L_values` does not match `symbolic.len_val()`, or if `symbolic` does not have
    /// the dimension of `A`.
    #[math]
    fn factorize_simplicial_numeric_cholesky<I: Index, T: ComplexField>(
        L_values: &mut [T],
        kind: FactorizationKind,
        A: SparseColMatRef<'_, I, T>,
        regularization: LdltRegularization<'_, T::Real>,
        symbolic: &SymbolicSimplicialCholesky<I>,
        stack: &mut MemStack,
    ) -> Result<LltInfo, LltError> {
        let n = A.ncols();
        let L_row_idx = &*symbolic.row_idx;
        let L_col_ptr = &*symbolic.col_ptr;
        let etree = &*symbolic.etree;

        assert!(L_values.rb().len() == L_row_idx.len());
        assert!(L_col_ptr.len() == n + 1);
        let l_nnz = L_col_ptr[n].zx();

        with_dim!(N, n);
        with_dim!(L_NNZ, l_nnz);

        // The symbolic structure stores the tree as unsigned `I`; reinterpret it as signed
        // optional indices and validate them against `N`.
        let etree = Array::from_ref(MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice::<I, I::Signed>(etree), N), N);
        let A = A.as_shape(N, N);

        let eps = abs(regularization.dynamic_regularization_epsilon);
        let delta = abs(regularization.dynamic_regularization_delta);
        let has_delta = delta > zero::<T::Real>();
        let mut dynamic_regularization_count = 0usize;

        // Dense accumulator for the current row. Each entry is reset to zero once consumed.
        let (mut x, stack) = temp_mat_zeroed::<T, _, _>(n, 1, stack);
        let mut x = x.as_mat_mut().col_mut(0).as_row_shape_mut(N);
        let (current_row_idxex, stack) = unsafe { stack.make_raw::<I>(n) };
        let (ereach_stack, stack) = unsafe { stack.make_raw::<I>(n) };
        let (visited, _) = unsafe { stack.make_raw::<I::Signed>(n) };

        let ereach_stack = Array::from_mut(ereach_stack, N);
        let visited = Array::from_mut(visited, N);

        visited.as_mut().fill(I::Signed::truncate(NONE));

        let L_values = Array::from_mut(L_values, L_NNZ);
        let L_row_idx = Array::from_ref(L_row_idx, L_NNZ);

        let L_col_ptr_start = Array::from_ref(Idx::from_slice_ref_checked(&L_col_ptr[..n], L_NNZ), N);

        // `current_row_idxex[j]` is the position of the last entry written in column `j`
        // (initially its diagonal slot).
        let current_row_idxex = Array::from_mut(ghost::copy_slice(current_row_idxex, L_col_ptr_start.as_ref()), N);

        for k in N.indices() {
            let reach = ereach(ereach_stack, A.symbolic(), etree, k, visited);

            for (i, aik) in iter::zip(A.row_idx_of_col(k), A.val_of_col(k)) {
                x[i] = x[i] + conj(*aik);
            }

            let mut d = real(x[k]);
            x[k] = zero::<T>();

            for &j in reach {
                let j = j.zx();

                let j_start = L_col_ptr_start[j].zx();
                let cj = &mut current_row_idxex[j];
                let row_idx = L_NNZ.check(*cj.zx() + 1);
                *cj = row_idx.truncate();

                let mut xj = copy(x[j]);
                x[j] = zero::<T>();

                // The diagonal slot of column `j` holds its pivot: sqrt(d_j) for LLT, d_j for LDLT.
                let dj = recip(real(L_values[j_start]));
                let lkj = mul_real(xj, dj);
                if kind == FactorizationKind::Llt {
                    xj = copy(lkj);
                }

                // Off-diagonal entries of column `j` computed for earlier rows.
                let range = j_start.next()..row_idx.into();
                for (i, lij) in iter::zip(&L_row_idx[range.clone()], &L_values[range]) {
                    let i = N.check(i.zx());
                    x[i] = x[i] - conj(*lij) * xj;
                }

                d = d - real(lkj * conj(xj));

                L_values[row_idx] = lkj;
            }

            let k_start = L_col_ptr_start[k].zx();

            if has_delta {
                match kind {
                    FactorizationKind::Llt => {
                        // Pivots that are too small (or negative) are replaced by `delta`.
                        if d <= eps {
                            d = copy(delta);
                            dynamic_regularization_count += 1;
                        }
                    },
                    FactorizationKind::Ldlt => {
                        if let Some(signs) = regularization.dynamic_regularization_signs {
                            // An expected sign is given. A pivot with the wrong sign, or too close
                            // to zero, becomes `±delta` following that sign. A sign of 0 disables
                            // regularization for this row.
                            if signs[*k] > 0 && d <= eps {
                                d = copy(delta);
                                dynamic_regularization_count += 1;
                            } else if signs[*k] < 0 && d >= -eps {
                                d = -delta;
                                dynamic_regularization_count += 1;
                            }
                        } else if abs(d) <= eps {
                            // No expected sign: a near-zero pivot becomes `delta` with the sign of `d`.
                            if d < zero::<T::Real>() {
                                d = -delta;
                                dynamic_regularization_count += 1;
                            } else {
                                d = copy(delta);
                                dynamic_regularization_count += 1;
                            }
                        }
                    },
                }
            }

            match kind {
                FactorizationKind::Llt => {
                    // `!(d > 0)` also rejects NaN.
                    if !(d > zero::<T::Real>()) {
                        return Err(LltError::NonPositivePivot { index: *k + 1 });
                    }
                    L_values[k_start] = from_real(sqrt(d));
                },
                FactorizationKind::Ldlt => {
                    if d == zero::<T::Real>() || !is_finite(d) {
                        return Err(LltError::NonPositivePivot { index: *k + 1 });
                    }
                    L_values[k_start] = from_real(d);
                },
            }
        }
        Ok(LltInfo {
            dynamic_regularization_count,
        })
    }
1065
1066 pub fn factorize_simplicial_numeric_llt<I: Index, T: ComplexField>(
1077 L_values: &mut [T],
1078 A: SparseColMatRef<'_, I, T>,
1079 regularization: LltRegularization<T::Real>,
1080 symbolic: &SymbolicSimplicialCholesky<I>,
1081 stack: &mut MemStack,
1082 ) -> Result<LltInfo, LltError> {
1083 factorize_simplicial_numeric_cholesky(
1084 L_values,
1085 FactorizationKind::Llt,
1086 A,
1087 LdltRegularization {
1088 dynamic_regularization_signs: None,
1089 dynamic_regularization_delta: regularization.dynamic_regularization_delta,
1090 dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
1091 },
1092 symbolic,
1093 stack,
1094 )
1095 }
1096
1097 pub fn factorize_simplicial_numeric_llt_with_row_idx<I: Index, T: ComplexField>(
1109 L_values: &mut [T],
1110 L_row_idx: &mut [I],
1111 L_col_ptr: &[I],
1112
1113 etree: EliminationTreeRef<'_, I>,
1114 A: SparseColMatRef<'_, I, T>,
1115 regularization: LltRegularization<T::Real>,
1116
1117 stack: &mut MemStack,
1118 ) -> Result<LltInfo, LltError> {
1119 factorize_simplicial_numeric_with_row_idx(
1120 L_values,
1121 L_row_idx,
1122 L_col_ptr,
1123 FactorizationKind::Llt,
1124 etree,
1125 A,
1126 LdltRegularization {
1127 dynamic_regularization_signs: None,
1128 dynamic_regularization_delta: regularization.dynamic_regularization_delta,
1129 dynamic_regularization_epsilon: regularization.dynamic_regularization_epsilon,
1130 },
1131 stack,
1132 )
1133 }
1134
1135 pub fn factorize_simplicial_numeric_ldlt<I: Index, T: ComplexField>(
1146 L_values: &mut [T],
1147 A: SparseColMatRef<'_, I, T>,
1148 regularization: LdltRegularization<'_, T::Real>,
1149 symbolic: &SymbolicSimplicialCholesky<I>,
1150 stack: &mut MemStack,
1151 ) -> Result<LdltInfo, LdltError> {
1152 match factorize_simplicial_numeric_cholesky(L_values, FactorizationKind::Ldlt, A, regularization, symbolic, stack) {
1153 Ok(info) => Ok(LdltInfo {
1154 dynamic_regularization_count: info.dynamic_regularization_count,
1155 }),
1156 Err(LltError::NonPositivePivot { index }) => Err(LdltError::ZeroPivot { index }),
1157 }
1158 }
1159
1160 pub fn factorize_simplicial_numeric_ldlt_with_row_idx<I: Index, T: ComplexField>(
1172 L_values: &mut [T],
1173 L_row_idx: &mut [I],
1174 L_col_ptr: &[I],
1175
1176 etree: EliminationTreeRef<'_, I>,
1177 A: SparseColMatRef<'_, I, T>,
1178 regularization: LdltRegularization<'_, T::Real>,
1179
1180 stack: &mut MemStack,
1181 ) -> Result<LdltInfo, LdltError> {
1182 match factorize_simplicial_numeric_with_row_idx(L_values, L_row_idx, L_col_ptr, FactorizationKind::Ldlt, etree, A, regularization, stack) {
1183 Ok(info) => Ok(LdltInfo {
1184 dynamic_regularization_count: info.dynamic_regularization_count,
1185 }),
1186 Err(LltError::NonPositivePivot { index }) => Err(LdltError::ZeroPivot { index }),
1187 }
1188 }
1189
1190 impl<'a, I: Index, T> SimplicialLltRef<'a, I, T> {
1191 #[inline]
1196 pub fn new(symbolic: &'a SymbolicSimplicialCholesky<I>, values: &'a [T]) -> Self {
1197 assert!(values.len() == symbolic.len_val());
1198 Self { symbolic, values }
1199 }
1200
1201 #[inline]
1203 pub fn symbolic(self) -> &'a SymbolicSimplicialCholesky<I> {
1204 self.symbolic
1205 }
1206
1207 #[inline]
1209 pub fn values(self) -> &'a [T] {
1210 self.values
1211 }
1212
1213 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1219 where
1220 T: ComplexField,
1221 {
1222 let _ = par;
1223 let _ = stack;
1224 let n = self.symbolic().nrows();
1225 assert!(rhs.nrows() == n);
1226 let l = SparseColMatRef::<'_, I, T>::new(self.symbolic().factor(), self.values());
1227
1228 let mut rhs = rhs;
1229 triangular_solve::solve_lower_triangular_in_place(l, conj, rhs.rb_mut(), par);
1230 triangular_solve::solve_lower_triangular_transpose_in_place(l, conj.compose(Conj::Yes), rhs.rb_mut(), par);
1231 }
1232 }
1233
1234 impl<'a, I: Index, T> SimplicialLdltRef<'a, I, T> {
1235 #[inline]
1240 pub fn new(symbolic: &'a SymbolicSimplicialCholesky<I>, values: &'a [T]) -> Self {
1241 assert!(values.len() == symbolic.len_val());
1242 Self { symbolic, values }
1243 }
1244
1245 #[inline]
1247 pub fn symbolic(self) -> &'a SymbolicSimplicialCholesky<I> {
1248 self.symbolic
1249 }
1250
1251 #[inline]
1253 pub fn values(self) -> &'a [T] {
1254 self.values
1255 }
1256
1257 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1263 where
1264 T: ComplexField,
1265 {
1266 let _ = par;
1267 let _ = stack;
1268 let n = self.symbolic().nrows();
1269 let ld = SparseColMatRef::<'_, I, T>::new(self.symbolic().factor(), self.values());
1270 assert!(rhs.nrows() == n);
1271
1272 let mut x = rhs;
1273 triangular_solve::solve_unit_lower_triangular_in_place(ld, conj, x.rb_mut(), par);
1274 triangular_solve::ldlt_scale_solve_unit_lower_triangular_transpose_in_place_impl(ld, conj.compose(Conj::Yes), x.rb_mut(), par);
1275 }
1276 }
1277
1278 impl<I: Index> SymbolicSimplicialCholesky<I> {
1279 #[inline]
1281 pub fn nrows(&self) -> usize {
1282 self.dimension
1283 }
1284
1285 #[inline]
1287 pub fn ncols(&self) -> usize {
1288 self.nrows()
1289 }
1290
1291 #[inline]
1294 pub fn len_val(&self) -> usize {
1295 self.row_idx.len()
1296 }
1297
1298 #[inline]
1300 pub fn col_ptr(&self) -> &[I] {
1301 &self.col_ptr
1302 }
1303
1304 #[inline]
1306 pub fn row_idx(&self) -> &[I] {
1307 &self.row_idx
1308 }
1309
1310 #[inline]
1312 pub fn factor(&self) -> SymbolicSparseColMatRef<'_, I> {
1313 unsafe { SymbolicSparseColMatRef::new_unchecked(self.dimension, self.dimension, &self.col_ptr, None, &self.row_idx) }
1314 }
1315
1316 pub fn solve_in_place_scratch<T>(&self, rhs_ncols: usize) -> StackReq {
1319 let _ = rhs_ncols;
1320 StackReq::EMPTY
1321 }
1322 }
1323
1324 pub fn factorize_simplicial_numeric_ldlt_scratch<I: Index, T: ComplexField>(n: usize) -> StackReq {
1327 let n_scratch = StackReq::new::<I>(n);
1328 StackReq::all_of(&[temp_mat_scratch::<T>(n, 1), n_scratch, n_scratch, n_scratch])
1329 }
1330
1331 pub fn factorize_simplicial_numeric_llt_scratch<I: Index, T: ComplexField>(n: usize) -> StackReq {
1334 factorize_simplicial_numeric_ldlt_scratch::<I, T>(n)
1335 }
1336
    /// Simplicial LLᴴ factor: a symbolic structure paired with its numerical values.
    #[derive(Debug)]
    pub struct SimplicialLltRef<'a, I: Index, T> {
        symbolic: &'a SymbolicSimplicialCholesky<I>,
        values: &'a [T],
    }
1343
    /// Simplicial LDLᴴ factor: a symbolic structure paired with its numerical values.
    ///
    /// The diagonal slot of each column stores the pivot `D`.
    #[derive(Debug)]
    pub struct SimplicialLdltRef<'a, I: Index, T> {
        symbolic: &'a SymbolicSimplicialCholesky<I>,
        values: &'a [T],
    }
1350
    /// Symbolic structure of a simplicial Cholesky factor.
    #[derive(Debug, Clone)]
    pub struct SymbolicSimplicialCholesky<I> {
        // Dimension of the (square) factor.
        dimension: usize,
        // Column pointers, length `dimension + 1`.
        col_ptr: alloc::vec::Vec<I>,
        // Row indices; the first entry of each column is its diagonal.
        row_idx: alloc::vec::Vec<I>,
        // Elimination tree, stored as the unsigned bit pattern of the signed parent indices
        // (`NONE` for roots).
        etree: alloc::vec::Vec<I>,
    }
1359
1360 impl<I: Index, T> Copy for SimplicialLltRef<'_, I, T> {}
1361 impl<I: Index, T> Clone for SimplicialLltRef<'_, I, T> {
1362 fn clone(&self) -> Self {
1363 *self
1364 }
1365 }
1366
1367 impl<I: Index, T> Copy for SimplicialLdltRef<'_, I, T> {}
1368 impl<I: Index, T> Clone for SimplicialLdltRef<'_, I, T> {
1369 fn clone(&self) -> Self {
1370 *self
1371 }
1372 }
1373}
1374
1375pub mod supernodal {
1381 use super::*;
1382 use crate::linalg::matmul::internal::{spicy_matmul, spicy_matmul_scratch};
1383 use crate::{Shape, assert, debug_assert};
1384
    /// Supernodal variant of the elimination reach.
    ///
    /// For every strictly-upper entry `(i, k)` of `A`, the function walks the supernodal
    /// elimination tree starting from the supernode containing `i`. The walk stops at the
    /// first supernode already stamped with `k` in `visited`. Row `k` is appended to the
    /// pattern of every supernode on the walk.
    ///
    /// Row `k` of supernode `s` is written to `row_idx[current_row_positions[s]]`, and the
    /// position is then advanced. `current_row_positions[s]` must therefore point to the next
    /// free slot of supernode `s`'s pattern.
    #[doc(hidden)]
    pub fn ereach_super<'n, 'nsuper, I: Index>(
        A: SymbolicSparseColMatRef<'_, I, Dim<'n>, Dim<'n>>,
        super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>,
        index_to_super: &Array<'n, Idx<'nsuper, I>>,
        current_row_positions: &mut Array<'nsuper, I>,
        row_idx: &mut [Idx<'n, I>],
        k: Idx<'n, usize>,
        visited: &mut Array<'nsuper, I::Signed>,
    ) {
        let k_: I = *k.truncate();
        // Stamping the supernode that contains `k` terminates every walk there at the latest.
        visited[index_to_super[k].zx()] = k_.to_signed();
        for i in A.row_idx_of_col(k) {
            if i >= k {
                continue;
            }
            let mut supernode_i = index_to_super[i].zx();
            loop {
                if visited[supernode_i] == k_.to_signed() {
                    break;
                }

                row_idx[current_row_positions[supernode_i].zx()] = k.truncate();
                current_row_positions[supernode_i] += I::truncate(1);

                visited[supernode_i] = k_.to_signed();
                // NOTE(review): `unwrap` assumes a stamped supernode is reached before a root,
                // i.e. `super_etree` is consistent with `A`; a root here panics.
                supernode_i = super_etree[supernode_i].sx().idx().unwrap();
            }
        }
    }
1415
1416 fn ereach_super_ata<'m, 'n, 'nsuper, I: Index>(
1417 A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>,
1418 perm: Option<PermRef<'_, I, Dim<'n>>>,
1419 min_col: &Array<'m, MaybeIdx<'n, I>>,
1420 super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>,
1421 index_to_super: &Array<'n, Idx<'nsuper, I>>,
1422 current_row_positions: &mut Array<'nsuper, I>,
1423 row_idx: &mut [Idx<'n, I>],
1424 k: Idx<'n, usize>,
1425 visited: &mut Array<'nsuper, I::Signed>,
1426 ) {
1427 let k_: I = *k.truncate();
1428 visited[index_to_super[k].zx()] = k_.to_signed();
1429
1430 let fwd = perm.map(|perm| perm.bound_arrays().0);
1431 let fwd = |i: Idx<'n, usize>| fwd.map(|fwd| fwd[k].zx()).unwrap_or(i);
1432 for i in A.row_idx_of_col(fwd(k)) {
1433 let Some(i) = min_col[i].idx() else { continue };
1434 let i = i.zx();
1435
1436 if i >= k {
1437 continue;
1438 }
1439 let mut supernode_i = index_to_super[i].zx();
1440 loop {
1441 if visited[supernode_i] == k_.to_signed() {
1442 break;
1443 }
1444
1445 row_idx[current_row_positions[supernode_i].zx()] = k.truncate();
1446 current_row_positions[supernode_i] += I::truncate(1);
1447
1448 visited[supernode_i] = k_.to_signed();
1449 supernode_i = super_etree[supernode_i].sx().idx().unwrap();
1450 }
1451 }
1452 }
1453
    /// Symbolic view of a single supernode.
    ///
    /// It holds the supernode's first column and its row pattern. The pattern lists the rows
    /// below the supernode's own diagonal block: the value block has `pattern.len() + ncols`
    /// rows.
    #[derive(Debug)]
    pub struct SymbolicSupernodeRef<'a, I> {
        start: usize,
        pattern: &'a [I],
    }
1460
    /// A supernode together with its numerical values.
    ///
    /// The values form a column-major matrix whose rows are the supernode's own columns
    /// followed by its row pattern.
    #[derive(Debug)]
    pub struct SupernodeRef<'a, I: Index, T> {
        matrix: MatRef<'a, T>,
        symbolic: SymbolicSupernodeRef<'a, I>,
    }
1467
1468 impl<I: Index> Copy for SymbolicSupernodeRef<'_, I> {}
1469 impl<I: Index> Clone for SymbolicSupernodeRef<'_, I> {
1470 fn clone(&self) -> Self {
1471 *self
1472 }
1473 }
1474
1475 impl<I: Index, T> Copy for SupernodeRef<'_, I, T> {}
1476 impl<I: Index, T> Clone for SupernodeRef<'_, I, T> {
1477 fn clone(&self) -> Self {
1478 *self
1479 }
1480 }
1481
1482 impl<'a, I: Index> SymbolicSupernodeRef<'a, I> {
1483 #[inline]
1485 pub fn start(self) -> usize {
1486 self.start
1487 }
1488
1489 pub fn pattern(self) -> &'a [I] {
1492 self.pattern
1493 }
1494 }
1495
1496 impl<'a, I: Index, T> SupernodeRef<'a, I, T> {
1497 #[inline]
1499 pub fn start(self) -> usize {
1500 self.symbolic.start
1501 }
1502
1503 pub fn pattern(self) -> &'a [I] {
1506 self.symbolic.pattern
1507 }
1508
1509 pub fn val(self) -> MatRef<'a, T> {
1511 self.matrix
1512 }
1513 }
1514
    /// Supernodal LLᴴ factor: a symbolic structure paired with its numerical values.
    #[derive(Debug)]
    pub struct SupernodalLltRef<'a, I: Index, T> {
        symbolic: &'a SymbolicSupernodalCholesky<I>,
        values: &'a [T],
    }
1521
    /// Supernodal LDLᴴ factor: a symbolic structure paired with its numerical values.
    #[derive(Debug)]
    pub struct SupernodalLdltRef<'a, I: Index, T> {
        symbolic: &'a SymbolicSupernodalCholesky<I>,
        values: &'a [T],
    }
1528
    /// Supernodal LBLᵀ factor with pivoting restricted to within each supernode.
    ///
    /// NOTE(review): `subdiag` presumably holds the subdiagonal of the block-diagonal factor,
    /// and `perm` the pivoting permutation — confirm against the factorization code.
    #[derive(Debug)]
    pub struct SupernodalIntranodeLbltRef<'a, I: Index, T> {
        symbolic: &'a SymbolicSupernodalCholesky<I>,
        values: &'a [T],
        subdiag: &'a [T],
        pub(super) perm: PermRef<'a, I>,
    }
1537
    /// Symbolic structure of a supernodal Cholesky factor.
    #[derive(Debug)]
    pub struct SymbolicSupernodalCholesky<I> {
        // Dimension of the (square) factor.
        pub(crate) dimension: usize,
        // One entry per supernode; its length defines `n_supernodes()`.
        // NOTE(review): presumably a postorder of the supernodal elimination tree, with its
        // inverse in `supernode_postorder_inv`.
        pub(crate) supernode_postorder: alloc::vec::Vec<I>,
        pub(crate) supernode_postorder_inv: alloc::vec::Vec<I>,
        // NOTE(review): presumably the number of descendants of each supernode — confirm.
        pub(crate) descendant_count: alloc::vec::Vec<I>,

        // Column boundaries: supernode `s` spans columns `[supernode_begin[s], supernode_begin[s + 1])`.
        pub(crate) supernode_begin: alloc::vec::Vec<I>,
        // Supernode `s`'s row pattern is `row_idx[col_ptr_for_row_idx[s]..col_ptr_for_row_idx[s + 1]]`.
        pub(crate) col_ptr_for_row_idx: alloc::vec::Vec<I>,
        // Supernode `s`'s column-major value block is `values[col_ptr_for_val[s]..col_ptr_for_val[s + 1]]`.
        pub(crate) col_ptr_for_val: alloc::vec::Vec<I>,
        // Concatenated row patterns of all supernodes.
        pub(crate) row_idx: alloc::vec::Vec<I>,

        // Per-supernode pattern lengths; set by `__prepare_for_refactorize`, updated by `__refactorize`.
        pub(crate) nnz_per_super: Option<alloc::vec::Vec<I>>,
    }
1553
1554 impl<I: Index> SymbolicSupernodalCholesky<I> {
1555 #[inline]
1557 pub fn n_supernodes(&self) -> usize {
1558 self.supernode_postorder.len()
1559 }
1560
1561 #[inline]
1563 pub fn nrows(&self) -> usize {
1564 self.dimension
1565 }
1566
1567 #[inline]
1569 pub fn ncols(&self) -> usize {
1570 self.nrows()
1571 }
1572
1573 #[inline]
1576 pub fn len_val(&self) -> usize {
1577 self.col_ptr_for_val()[self.n_supernodes()].zx()
1578 }
1579
1580 #[inline]
1583 pub fn supernode_begin(&self) -> &[I] {
1584 &self.supernode_begin[..self.n_supernodes()]
1585 }
1586
1587 #[inline]
1590 pub fn supernode_end(&self) -> &[I] {
1591 &self.supernode_begin[1..]
1592 }
1593
1594 #[inline]
1596 pub fn col_ptr_for_row_idx(&self) -> &[I] {
1597 &self.col_ptr_for_row_idx
1598 }
1599
1600 #[inline]
1602 pub fn col_ptr_for_val(&self) -> &[I] {
1603 &self.col_ptr_for_val
1604 }
1605
1606 #[inline]
1612 pub fn row_idx(&self) -> &[I] {
1613 &self.row_idx
1614 }
1615
1616 #[inline]
1618 pub fn supernode(&self, s: usize) -> supernodal::SymbolicSupernodeRef<'_, I> {
1619 let symbolic = self;
1620 let start = symbolic.supernode_begin[s].zx();
1621 let pattern = &symbolic.row_idx()[symbolic.col_ptr_for_row_idx()[s].zx()..symbolic.col_ptr_for_row_idx()[s + 1].zx()];
1622 supernodal::SymbolicSupernodeRef { start, pattern }
1623 }
1624
1625 pub fn solve_in_place_scratch<T: ComplexField>(&self, rhs_ncols: usize, par: Par) -> StackReq {
1628 _ = par;
1629 let mut req = StackReq::EMPTY;
1630 let symbolic = self;
1631 for s in 0..symbolic.n_supernodes() {
1632 let s = self.supernode(s);
1633 req = req.or(temp_mat_scratch::<T>(s.pattern.len(), rhs_ncols));
1634 }
1635 req
1636 }
1637
1638 #[doc(hidden)]
1639 pub fn __prepare_for_refactorize(&mut self) -> Result<(), FaerError> {
1640 let mut v = try_zeroed(self.n_supernodes())?;
1641 for s in 0..self.n_supernodes() {
1642 v[s] = self.col_ptr_for_row_idx[s + 1] - self.col_ptr_for_row_idx[s];
1643 }
1644 self.nnz_per_super = Some(v);
1645 Ok(())
1646 }
1647
1648 #[doc(hidden)]
1649 #[track_caller]
1650 pub fn __nnz_per_super(&self) -> &[I] {
1651 self.nnz_per_super.as_deref().unwrap()
1652 }
1653
1654 #[doc(hidden)]
1655 pub fn __refactorize(&mut self, A: SymbolicSparseColMatRef<'_, I>, etree: &[I::Signed], stack: &mut MemStack) {
1656 generativity::make_guard!(N);
1657 generativity::make_guard!(N_SUPERNODES);
1658 let N = self.nrows().bind(N);
1659 let N_SUPERNODES = self.nrows().bind(N_SUPERNODES);
1660
1661 let A = A.as_shape(N, N);
1662 let n = *N;
1663 let n_supernodes = *N_SUPERNODES;
1664 let none = I::Signed::truncate(NONE);
1665
1666 let etree = MaybeIdx::<I>::from_slice_ref_checked(etree, N);
1667 let etree = Array::from_ref(etree, N);
1668
1669 let (index_to_super, stack) = unsafe { stack.make_raw::<I>(n) };
1670 let (current_row_positions, stack) = unsafe { stack.make_raw::<I>(n_supernodes) };
1671 let (visited, stack) = unsafe { stack.make_raw::<I::Signed>(n_supernodes) };
1672 let (super_etree, _) = unsafe { stack.make_raw::<I::Signed>(n_supernodes) };
1673
1674 let super_etree = Array::from_mut(super_etree, N_SUPERNODES);
1675 let index_to_super = Array::from_mut(index_to_super, N);
1676
1677 let mut supernode_begin = 0usize;
1678 for s in N_SUPERNODES.indices() {
1679 let size = self.supernode_end()[*s].zx() - self.supernode_begin()[*s].zx();
1680 index_to_super.as_mut()[supernode_begin..][..size].fill(*s.truncate::<I>());
1681 supernode_begin += size;
1682 }
1683
1684 let index_to_super = Array::from_mut(Idx::from_slice_mut_checked(index_to_super.as_mut(), N_SUPERNODES), N);
1685
1686 let mut supernode_begin = 0usize;
1687 for s in N_SUPERNODES.indices() {
1688 let size = self.supernode_end()[*s + 1].zx() - self.supernode_begin()[*s].zx();
1689 let last = supernode_begin + size - 1;
1690 if let Some(parent) = etree[N.check(last)].idx() {
1691 super_etree[s] = index_to_super[parent.zx()].to_signed();
1692 } else {
1693 super_etree[s] = none;
1694 }
1695 supernode_begin += size;
1696 }
1697
1698 let super_etree = Array::from_mut(
1699 MaybeIdx::<'_, I>::from_slice_mut_checked(super_etree.as_mut(), N_SUPERNODES),
1700 N_SUPERNODES,
1701 );
1702
1703 let visited = Array::from_mut(visited, N_SUPERNODES);
1704 let current_row_positions = Array::from_mut(current_row_positions, N_SUPERNODES);
1705
1706 visited.as_mut().fill(I::Signed::truncate(NONE));
1707 current_row_positions.as_mut().fill(I::truncate(0));
1708
1709 for s in N_SUPERNODES.indices() {
1710 let k1 = ghost::IdxInc::new_checked(self.supernode_begin()[*s].zx(), N);
1711 let k2 = ghost::IdxInc::new_checked(self.supernode_end()[*s].zx(), N);
1712
1713 for k in k1.range_to(k2) {
1714 ereach_super(
1715 A,
1716 super_etree,
1717 index_to_super,
1718 current_row_positions,
1719 unsafe { Idx::from_slice_mut_unchecked(&mut self.row_idx) },
1720 k,
1721 visited,
1722 );
1723 }
1724 }
1725
1726 let Some(nnz_per_super) = self.nnz_per_super.as_deref_mut() else {
1727 panic!()
1728 };
1729
1730 for s in N_SUPERNODES.indices() {
1731 nnz_per_super[*s] = current_row_positions[s] - self.supernode_begin[*s];
1732 }
1733 }
1734 }
1735
1736 impl<I: Index, T> Copy for SupernodalLdltRef<'_, I, T> {}
1737 impl<I: Index, T> Clone for SupernodalLdltRef<'_, I, T> {
1738 fn clone(&self) -> Self {
1739 *self
1740 }
1741 }
1742 impl<I: Index, T> Copy for SupernodalLltRef<'_, I, T> {}
1743 impl<I: Index, T> Clone for SupernodalLltRef<'_, I, T> {
1744 fn clone(&self) -> Self {
1745 *self
1746 }
1747 }
1748 impl<I: Index, T> Copy for SupernodalIntranodeLbltRef<'_, I, T> {}
1749 impl<I: Index, T> Clone for SupernodalIntranodeLbltRef<'_, I, T> {
1750 fn clone(&self) -> Self {
1751 *self
1752 }
1753 }
1754
1755 impl<'a, I: Index, T> SupernodalLdltRef<'a, I, T> {
1756 #[inline]
1762 pub fn new(symbolic: &'a SymbolicSupernodalCholesky<I>, values: &'a [T]) -> Self {
1763 assert!(values.len() == symbolic.len_val());
1764 Self { symbolic, values }
1765 }
1766
1767 #[inline]
1769 pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
1770 self.symbolic
1771 }
1772
1773 #[inline]
1775 pub fn values(self) -> &'a [T] {
1776 self.values
1777 }
1778
1779 #[inline]
1781 pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, T> {
1782 let symbolic = self.symbolic();
1783 let L_values = self.values();
1784 let s_start = symbolic.supernode_begin[s].zx();
1785 let s_end = symbolic.supernode_begin[s + 1].zx();
1786
1787 let s_pattern = &symbolic.row_idx()[symbolic.col_ptr_for_row_idx()[s].zx()..symbolic.col_ptr_for_row_idx()[s + 1].zx()];
1788 let s_ncols = s_end - s_start;
1789 let s_nrows = s_pattern.len() + s_ncols;
1790
1791 let Ls = MatRef::from_column_major_slice(
1792 &L_values[symbolic.col_ptr_for_val()[s].zx()..symbolic.col_ptr_for_val()[s + 1].zx()],
1793 s_nrows,
1794 s_ncols,
1795 );
1796
1797 SupernodeRef {
1798 matrix: Ls,
1799 symbolic: SymbolicSupernodeRef {
1800 start: s_start,
1801 pattern: s_pattern,
1802 },
1803 }
1804 }
1805
1806 #[math]
1812 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1813 where
1814 T: ComplexField,
1815 {
1816 let symbolic = self.symbolic();
1817 let n = symbolic.nrows();
1818 assert!(rhs.nrows() == n);
1819
1820 let mut x = rhs;
1821 let k = x.ncols();
1822 for s in 0..symbolic.n_supernodes() {
1823 let s = self.supernode(s);
1824 let size = s.matrix.ncols();
1825 let Ls = s.matrix;
1826 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1827 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1828 linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(Ls_top, conj, x_top.rb_mut(), par);
1829
1830 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
1831 let mut tmp = tmp.as_mat_mut();
1832 linalg::matmul::matmul_with_conj(tmp.rb_mut(), Accum::Replace, Ls_bot, conj, x_top.rb(), Conj::No, one::<T>(), par);
1833
1834 for j in 0..k {
1835 for (idx, i) in s.pattern().iter().enumerate() {
1836 let i = i.zx();
1837 x[(i, j)] = x[(i, j)] - tmp[(idx, j)]
1838 }
1839 }
1840 }
1841 for s in 0..symbolic.n_supernodes() {
1842 let s = self.supernode(s);
1843 let size = s.matrix.ncols();
1844 let Ds = s.matrix.diagonal().column_vector();
1845 for j in 0..k {
1846 for idx in 0..size {
1847 let d_inv = recip(real(Ds[idx]));
1848 let i = idx + s.start();
1849 x[(i, j)] = mul_real(x[(i, j)], d_inv)
1850 }
1851 }
1852 }
1853 for s in (0..symbolic.n_supernodes()).rev() {
1854 let s = self.supernode(s);
1855 let size = s.matrix.ncols();
1856 let Ls = s.matrix;
1857 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1858
1859 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
1860 let mut tmp = tmp.as_mat_mut();
1861 for j in 0..k {
1862 for (idx, i) in s.pattern().iter().enumerate() {
1863 let i = i.zx();
1864 tmp[(idx, j)] = copy(x[(i, j)]);
1865 }
1866 }
1867
1868 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1869 linalg::matmul::matmul_with_conj(
1870 x_top.rb_mut(),
1871 Accum::Add,
1872 Ls_bot.transpose(),
1873 conj.compose(Conj::Yes),
1874 tmp.rb(),
1875 Conj::No,
1876 -one::<T>(),
1877 par,
1878 );
1879 linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(
1880 Ls_top.transpose(),
1881 conj.compose(Conj::Yes),
1882 x_top.rb_mut(),
1883 par,
1884 );
1885 }
1886 }
1887 }
1888
1889 impl<'a, I: Index, T> SupernodalLltRef<'a, I, T> {
1890 #[inline]
1896 pub fn new(symbolic: &'a SymbolicSupernodalCholesky<I>, values: &'a [T]) -> Self {
1897 assert!(values.len() == symbolic.len_val());
1898 Self { symbolic, values }
1899 }
1900
1901 #[inline]
1903 pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
1904 self.symbolic
1905 }
1906
1907 #[inline]
1909 pub fn values(self) -> &'a [T] {
1910 self.values
1911 }
1912
1913 #[inline]
1915 pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, T> {
1916 let symbolic = self.symbolic();
1917 let L_values = self.values();
1918 let s_start = symbolic.supernode_begin[s].zx();
1919 let s_end = symbolic.supernode_begin[s + 1].zx();
1920
1921 let s_pattern = &symbolic.row_idx()[symbolic.col_ptr_for_row_idx()[s].zx()..symbolic.col_ptr_for_row_idx()[s + 1].zx()];
1922 let s_ncols = s_end - s_start;
1923 let s_nrows = s_pattern.len() + s_ncols;
1924
1925 let Ls = MatRef::from_column_major_slice(
1926 &L_values[symbolic.col_ptr_for_val()[s].zx()..symbolic.col_ptr_for_val()[s + 1].zx()],
1927 s_nrows,
1928 s_ncols,
1929 );
1930
1931 SupernodeRef {
1932 matrix: Ls,
1933 symbolic: SymbolicSupernodeRef {
1934 start: s_start,
1935 pattern: s_pattern,
1936 },
1937 }
1938 }
1939
1940 #[math]
1946 pub fn l_solve_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1947 where
1948 T: ComplexField,
1949 {
1950 let symbolic = self.symbolic();
1951 let n = symbolic.nrows();
1952 assert!(rhs.nrows() == n);
1953
1954 let mut x = rhs;
1955 let k = x.ncols();
1956 for s in 0..symbolic.n_supernodes() {
1957 let s = self.supernode(s);
1958 let size = s.matrix.ncols();
1959 let Ls = s.matrix;
1960 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1961 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
1962 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(Ls_top, conj, x_top.rb_mut(), par);
1963
1964 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
1965 let mut tmp = tmp.as_mat_mut();
1966 linalg::matmul::matmul_with_conj(tmp.rb_mut(), Accum::Replace, Ls_bot, conj, x_top.rb(), Conj::No, one::<T>(), par);
1967
1968 for j in 0..k {
1969 for (idx, i) in s.pattern().iter().enumerate() {
1970 let i = i.zx();
1971 x[(i, j)] = x[(i, j)] - tmp[(idx, j)]
1972 }
1973 }
1974 }
1975 }
1976
1977 #[inline]
1983 #[math]
1984 pub fn l_transpose_solve_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
1985 where
1986 T: ComplexField,
1987 {
1988 let symbolic = self.symbolic();
1989 let n = symbolic.nrows();
1990 assert!(rhs.nrows() == n);
1991
1992 let mut x = rhs;
1993 let k = x.ncols();
1994 for s in (0..symbolic.n_supernodes()).rev() {
1995 let s = self.supernode(s);
1996 let size = s.matrix.ncols();
1997 let Ls = s.matrix;
1998 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
1999
2000 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
2001 let mut tmp = tmp.as_mat_mut();
2002 for j in 0..k {
2003 for (idx, i) in s.pattern().iter().enumerate() {
2004 let i = i.zx();
2005 tmp[(idx, j)] = copy(x[(i, j)]);
2006 }
2007 }
2008
2009 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
2010 linalg::matmul::matmul_with_conj(x_top.rb_mut(), Accum::Add, Ls_bot.transpose(), conj, tmp.rb(), Conj::No, -one::<T>(), par);
2011 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(Ls_top.transpose(), conj, x_top.rb_mut(), par);
2012 }
2013 }
2014
2015 #[math]
2021 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
2022 where
2023 T: ComplexField,
2024 {
2025 let symbolic = self.symbolic();
2026 let n = symbolic.nrows();
2027 assert!(rhs.nrows() == n);
2028
2029 let mut x = rhs;
2030 let k = x.ncols();
2031 for s in 0..symbolic.n_supernodes() {
2032 let s = self.supernode(s);
2033 let size = s.matrix.ncols();
2034 let Ls = s.matrix;
2035 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
2036 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
2037 linalg::triangular_solve::solve_lower_triangular_in_place_with_conj(Ls_top, conj, x_top.rb_mut(), par);
2038
2039 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
2040 let mut tmp = tmp.as_mat_mut();
2041 linalg::matmul::matmul_with_conj(tmp.rb_mut(), Accum::Replace, Ls_bot, conj, x_top.rb(), Conj::No, one::<T>(), par);
2042
2043 for j in 0..k {
2044 for (idx, i) in s.pattern().iter().enumerate() {
2045 let i = i.zx();
2046 x[(i, j)] = x[(i, j)] - tmp[(idx, j)]
2047 }
2048 }
2049 }
2050 for s in (0..symbolic.n_supernodes()).rev() {
2051 let s = self.supernode(s);
2052 let size = s.matrix.ncols();
2053 let Ls = s.matrix;
2054 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
2055
2056 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
2057 let mut tmp = tmp.as_mat_mut();
2058 for j in 0..k {
2059 for (idx, i) in s.pattern().iter().enumerate() {
2060 let i = i.zx();
2061 tmp[(idx, j)] = copy(x[(i, j)]);
2062 }
2063 }
2064
2065 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
2066 linalg::matmul::matmul_with_conj(
2067 x_top.rb_mut(),
2068 Accum::Add,
2069 Ls_bot.transpose(),
2070 conj.compose(Conj::Yes),
2071 tmp.rb(),
2072 Conj::No,
2073 -one::<T>(),
2074 par,
2075 );
2076 linalg::triangular_solve::solve_upper_triangular_in_place_with_conj(Ls_top.transpose(), conj.compose(Conj::Yes), x_top.rb_mut(), par);
2077 }
2078 }
2079 }
2080
2081 impl<'a, I: Index, T> SupernodalIntranodeLbltRef<'a, I, T> {
2082 #[inline]
2090 pub fn new(symbolic: &'a SymbolicSupernodalCholesky<I>, values: &'a [T], subdiag: &'a [T], perm: PermRef<'a, I>) -> Self {
2091 assert!(all(
2092 values.len() == symbolic.len_val(),
2093 subdiag.len() == symbolic.nrows(),
2094 perm.len() == symbolic.nrows(),
2095 ));
2096 Self {
2097 symbolic,
2098 values,
2099 subdiag,
2100 perm,
2101 }
2102 }
2103
2104 #[inline]
2106 pub fn symbolic(self) -> &'a SymbolicSupernodalCholesky<I> {
2107 self.symbolic
2108 }
2109
2110 #[inline]
2112 pub fn val(self) -> &'a [T] {
2113 self.values
2114 }
2115
2116 #[inline]
2118 pub fn supernode(self, s: usize) -> SupernodeRef<'a, I, T> {
2119 let symbolic = self.symbolic();
2120 let L_values = self.val();
2121 let s_start = symbolic.supernode_begin[s].zx();
2122 let s_end = symbolic.supernode_begin[s + 1].zx();
2123
2124 let s_pattern = &symbolic.row_idx()[symbolic.col_ptr_for_row_idx()[s].zx()..symbolic.col_ptr_for_row_idx()[s + 1].zx()];
2125 let s_ncols = s_end - s_start;
2126 let s_nrows = s_pattern.len() + s_ncols;
2127
2128 let Ls = MatRef::from_column_major_slice(
2129 &L_values[symbolic.col_ptr_for_val()[s].zx()..symbolic.col_ptr_for_val()[s + 1].zx()],
2130 s_nrows,
2131 s_ncols,
2132 );
2133
2134 SupernodeRef {
2135 matrix: Ls,
2136 symbolic: SymbolicSupernodeRef {
2137 start: s_start,
2138 pattern: s_pattern,
2139 },
2140 }
2141 }
2142
2143 #[inline]
2145 pub fn perm(&self) -> PermRef<'a, I> {
2146 self.perm
2147 }
2148
2149 #[math]
2159 pub fn solve_in_place_no_numeric_permute_with_conj(self, conj_lb: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
2160 where
2161 T: ComplexField,
2162 {
2163 let symbolic = self.symbolic();
2164 let n = symbolic.nrows();
2165 assert!(rhs.nrows() == n);
2166
2167 let mut x = rhs;
2168
2169 let k = x.ncols();
2170 for s in 0..symbolic.n_supernodes() {
2171 let s = self.supernode(s);
2172 let size = s.matrix.ncols();
2173 let Ls = s.matrix;
2174 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
2175 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
2176 linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj(Ls_top, conj_lb, x_top.rb_mut(), par);
2177
2178 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
2179 let mut tmp = tmp.as_mat_mut();
2180 linalg::matmul::matmul_with_conj(tmp.rb_mut(), Accum::Replace, Ls_bot, conj_lb, x_top.rb(), Conj::No, one::<T>(), par);
2181
2182 let inv = self.perm.arrays().1;
2183 for j in 0..k {
2184 for (idx, i) in s.pattern().iter().enumerate() {
2185 let i = i.zx();
2186 let i = inv[i].zx();
2187 x[(i, j)] = x[(i, j)] - tmp[(idx, j)];
2188 }
2189 }
2190 }
2191 for s in 0..symbolic.n_supernodes() {
2192 let s = self.supernode(s);
2193 let size = s.matrix.ncols();
2194 let Bs = s.val();
2195 let subdiag = &self.subdiag[s.start()..s.start() + size];
2196
2197 let mut idx = 0;
2198 while idx < size {
2199 let subdiag = copy(subdiag[idx]);
2200 let i = idx + s.start();
2201 if subdiag == zero::<T>() {
2202 let d = recip(real(Bs[(idx, idx)]));
2203 for j in 0..k {
2204 x[(i, j)] = mul_real(x[(i, j)], d);
2205 }
2206 idx += 1;
2207 } else {
2208 let mut d21 = conj_lb.apply_rt(&subdiag);
2209 d21 = recip(d21);
2210 let d11 = mul_real(conj(d21), real(Bs[(idx, idx)]));
2211 let d22 = mul_real(d21, real(Bs[(idx + 1, idx + 1)]));
2212
2213 let denom = recip(d11 * d22 - one::<T>());
2214
2215 for j in 0..k {
2216 let xk = x[(i, j)] * conj(d21);
2217 let xkp1 = x[(i + 1, j)] * d21;
2218
2219 x[(i, j)] = (d22 * xk - xkp1) * denom;
2220 x[(i + 1, j)] = (d11 * xkp1 - xk) * denom;
2221 }
2222 idx += 2;
2223 }
2224 }
2225 }
2226 for s in (0..symbolic.n_supernodes()).rev() {
2227 let s = self.supernode(s);
2228 let size = s.matrix.ncols();
2229 let Ls = s.matrix;
2230 let (Ls_top, Ls_bot) = Ls.split_at_row(size);
2231
2232 let (mut tmp, _) = unsafe { temp_mat_uninit::<T, _, _>(s.pattern().len(), k, stack) };
2233 let mut tmp = tmp.as_mat_mut();
2234 let inv = self.perm.arrays().1;
2235 for j in 0..k {
2236 for (idx, i) in s.pattern().iter().enumerate() {
2237 let i = i.zx();
2238 let i = inv[i].zx();
2239 tmp[(idx, j)] = copy(x[(i, j)]);
2240 }
2241 }
2242
2243 let mut x_top = x.rb_mut().subrows_mut(s.start(), size);
2244 linalg::matmul::matmul_with_conj(
2245 x_top.rb_mut(),
2246 Accum::Add,
2247 Ls_bot.transpose(),
2248 conj_lb.compose(Conj::Yes),
2249 tmp.rb(),
2250 Conj::No,
2251 -one::<T>(),
2252 par,
2253 );
2254 linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj(
2255 Ls_top.transpose(),
2256 conj_lb.compose(Conj::Yes),
2257 x_top.rb_mut(),
2258 par,
2259 );
2260 }
2261 }
2262 }
2263
2264 pub fn factorize_supernodal_symbolic_cholesky_scratch<I: Index>(n: usize) -> StackReq {
2267 StackReq::new::<I>(n).array(4)
2268 }
2269
2270 pub fn factorize_supernodal_symbolic_cholesky<I: Index>(
2280 A: SymbolicSparseColMatRef<'_, I>,
2281 etree: simplicial::EliminationTreeRef<'_, I>,
2282 col_counts: &[I],
2283 stack: &mut MemStack,
2284 params: SymbolicSupernodalParams<'_>,
2285 ) -> Result<SymbolicSupernodalCholesky<I>, FaerError> {
2286 let n = A.nrows();
2287 assert!(A.nrows() == A.ncols());
2288 assert!(etree.into_inner().len() == n);
2289 assert!(col_counts.len() == n);
2290 with_dim!(N, n);
2291 ghost_factorize_supernodal_symbolic(
2292 A.as_shape(N, N),
2293 None,
2294 None,
2295 CholeskyInput::A,
2296 etree.as_bound(N),
2297 Array::from_ref(col_counts, N),
2298 stack,
2299 params,
2300 )
2301 }
2302
/// Selects which matrix `ghost_factorize_supernodal_symbolic` analyzes.
pub(crate) enum CholeskyInput {
    /// The square input matrix `A` itself; row patterns are computed with `ereach_super`.
    A,
    /// Presumably `Aᵀ A` for a possibly rectangular `A` (row patterns are computed with
    /// `ereach_super_ata`, which also needs the column permutation and `min_col`); confirm.
    ATA,
}
2307
/// Computes the supernodal symbolic Cholesky structure from an elimination tree and column counts.
///
/// The structure is that of `A`, or, for [`CholeskyInput::ATA`], presumably of `Aᵀ A`.
///
/// The steps below are:
/// 1. Detect fundamental supernodes. Column `j` joins the supernode of `j - 1` when `j` is the
///    only child's parent of `j - 1` and its column count is exactly one smaller.
/// 2. Optionally merge a supernode into its parent when `params.relax` allows the resulting
///    explicit zeros.
/// 3. Build the supernodal elimination tree, the column pointers for row indices and values,
///    and the off-diagonal row pattern of every supernode.
/// 4. Postorder the supernodal tree and count descendants.
///
/// `stack` must hold four index arrays of length `n`. The relaxation step allocates its own
/// `MemBuffer`. For `CholeskyInput::ATA`, `min_col` must be `Some` (it is unwrapped).
///
/// # Errors
/// - `FaerError::OutOfMemory` if the relaxation workspace cannot be allocated.
/// - `FaerError::IndexOverflow` if the total number of stored values does not fit in `I`.
/// - Any error propagated from `try_zeroed`/`try_collect`.
pub(crate) fn ghost_factorize_supernodal_symbolic<'m, 'n, I: Index>(
    A: SymbolicSparseColMatRef<'_, I, Dim<'m>, Dim<'n>>,
    col_perm: Option<PermRef<'_, I, Dim<'n>>>,
    min_col: Option<&Array<'m, MaybeIdx<'n, I>>>,
    input: CholeskyInput,
    etree: &Array<'n, MaybeIdx<'n, I>>,
    col_counts: &Array<'n, I>,
    stack: &mut MemStack,
    params: SymbolicSupernodalParams<'_>,
) -> Result<SymbolicSupernodalCholesky<I>, FaerError> {
    // The value count is accumulated in `u128` so that overflow of `I` can be detected afterwards.
    let to_wide = |i: I| i.zx() as u128;
    let from_wide = |i: u128| I::truncate(i as usize);
    let from_wide_checked = |i: u128| -> Option<I> { (i <= to_wide(I::from_signed(I::Signed::MAX))).then_some(I::truncate(i as usize)) };

    let N = A.ncols();
    let n = *N;

    let zero = I::truncate(0);
    let one = I::truncate(1);
    let none = I::Signed::truncate(NONE);

    // Empty matrix: no supernodes; every pointer array holds a single zero.
    if n == 0 {
        return Ok(SymbolicSupernodalCholesky {
            dimension: n,
            supernode_postorder: alloc::vec::Vec::new(),
            supernode_postorder_inv: alloc::vec::Vec::new(),
            descendant_count: alloc::vec::Vec::new(),

            supernode_begin: try_collect([zero])?,
            col_ptr_for_row_idx: try_collect([zero])?,
            col_ptr_for_val: try_collect([zero])?,
            row_idx: alloc::vec::Vec::new(),
            nnz_per_super: None,
        });
    }
    // Kept so the whole workspace can be reused by the postorder at the very end.
    let original_stack = stack;

    // Four length-`n` scratch arrays carved from the caller's workspace.
    let (index_to_super__, stack) = unsafe { original_stack.make_raw::<I>(n) };
    let (super_etree__, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
    let (supernode_sizes__, stack) = unsafe { stack.make_raw::<I>(n) };
    let (child_count__, _) = unsafe { stack.make_raw::<I>(n) };

    let child_count = Array::from_mut(child_count__, N);
    let index_to_super = Array::from_mut(index_to_super__, N);

    // Number of children of every column in the elimination tree.
    child_count.as_mut().fill(zero);
    for j in N.indices() {
        if let Some(parent) = etree[j].idx() {
            child_count[parent.zx()] += one;
        }
    }

    // Fundamental supernodes. Column `j` extends the current supernode only if:
    // - it is the parent of `j - 1`,
    // - `j - 1` is its only child, and
    // - its column count is exactly one less than that of `j - 1`.
    supernode_sizes__.fill(zero);
    let mut current_supernode = 0usize;
    supernode_sizes__[0] = one;
    for (j_prev, j) in iter::zip(N.indices().take(n - 1), N.indices().skip(1)) {
        let is_parent_of_prev = (*etree[j_prev]).sx() == *j;
        let is_parent_of_only_prev = child_count[j] == one;
        let same_pattern_as_prev = col_counts[j_prev] == col_counts[j] + one;

        if !(is_parent_of_prev && is_parent_of_only_prev && same_pattern_as_prev) {
            current_supernode += 1;
        }
        supernode_sizes__[current_supernode] += one;
    }
    let n_fundamental_supernodes = current_supernode + 1;

    // On exit from this block, `supernode_begin__[s + 1]` holds the *degree* (number of
    // off-diagonal rows) of supernode `s`. It is converted to column starts further below.
    let supernode_begin__ = {
        with_dim!(N_FUNDAMENTAL_SUPERNODES, n_fundamental_supernodes);
        let supernode_sizes = Array::from_mut(&mut supernode_sizes__[..n_fundamental_supernodes], N_FUNDAMENTAL_SUPERNODES);
        let super_etree = Array::from_mut(&mut super_etree__[..n_fundamental_supernodes], N_FUNDAMENTAL_SUPERNODES);

        // Map every column to the fundamental supernode containing it.
        let mut supernode_begin = 0usize;
        for s in N_FUNDAMENTAL_SUPERNODES.indices() {
            let size = supernode_sizes[s].zx();
            index_to_super.as_mut()[supernode_begin..][..size].fill(*s.truncate::<I>());
            supernode_begin += size;
        }

        let index_to_super = Array::from_mut(Idx::from_slice_mut_checked(index_to_super.as_mut(), N_FUNDAMENTAL_SUPERNODES), N);

        // Supernodal elimination tree: parent of a supernode = supernode of the etree parent of
        // its last column.
        let mut supernode_begin = 0usize;
        for s in N_FUNDAMENTAL_SUPERNODES.indices() {
            let size = supernode_sizes[s].zx();
            let last = supernode_begin + size - 1;
            let last = N.check(last);
            if let Some(parent) = etree[last].idx() {
                super_etree[s] = index_to_super[parent.zx()].to_signed();
            } else {
                super_etree[s] = none;
            }
            supernode_begin += size;
        }

        let super_etree = Array::from_mut(
            MaybeIdx::<'_, I>::from_slice_mut_checked(super_etree.as_mut(), N_FUNDAMENTAL_SUPERNODES),
            N_FUNDAMENTAL_SUPERNODES,
        );

        if let Some(relax) = params.relax {
            // Relaxed (amalgamated) supernodes. This step uses its own buffer of five arrays
            // plus the `child_count` storage, which is reinterpreted as `child_lists`.
            let mut mem = dyn_stack::MemBuffer::try_new(StackReq::all_of(&[StackReq::new::<I>(n_fundamental_supernodes); 5]))
                .ok()
                .ok_or(FaerError::OutOfMemory)?;
            let stack = MemStack::new(&mut mem);

            let child_lists = bytemuck::cast_slice_mut(&mut child_count.as_mut()[..n_fundamental_supernodes]);
            let (child_list_heads, stack) = unsafe { stack.make_raw::<I::Signed>(n_fundamental_supernodes) };
            let (last_merged_children, stack) = unsafe { stack.make_raw::<I::Signed>(n_fundamental_supernodes) };
            let (merge_parents, stack) = unsafe { stack.make_raw::<I::Signed>(n_fundamental_supernodes) };
            let (fundamental_supernode_degrees, stack) = unsafe { stack.make_raw::<I>(n_fundamental_supernodes) };
            let (num_zeros, _) = unsafe { stack.make_raw::<I>(n_fundamental_supernodes) };

            let child_lists = Array::from_mut(ghost::fill_none::<I>(child_lists, N_FUNDAMENTAL_SUPERNODES), N_FUNDAMENTAL_SUPERNODES);
            let child_list_heads = Array::from_mut(
                ghost::fill_none::<I>(child_list_heads, N_FUNDAMENTAL_SUPERNODES),
                N_FUNDAMENTAL_SUPERNODES,
            );
            let last_merged_children = Array::from_mut(
                ghost::fill_none::<I>(last_merged_children, N_FUNDAMENTAL_SUPERNODES),
                N_FUNDAMENTAL_SUPERNODES,
            );
            let merge_parents = Array::from_mut(ghost::fill_none::<I>(merge_parents, N_FUNDAMENTAL_SUPERNODES), N_FUNDAMENTAL_SUPERNODES);
            let fundamental_supernode_degrees = Array::from_mut(fundamental_supernode_degrees, N_FUNDAMENTAL_SUPERNODES);
            let num_zeros = Array::from_mut(num_zeros, N_FUNDAMENTAL_SUPERNODES);

            // Degree of a supernode = column count of its last column minus one
            // (assumes `col_counts` include the diagonal entry).
            let mut supernode_begin = 0usize;
            for s in N_FUNDAMENTAL_SUPERNODES.indices() {
                let size = supernode_sizes[s].zx();
                fundamental_supernode_degrees[s] = col_counts[N.check(supernode_begin + size - 1)] - one;
                supernode_begin += size;
            }

            // Singly linked child lists of the supernodal tree: `child_list_heads[p]` is the
            // first child of `p`, and `child_lists[c]` is the next sibling of `c`.
            for s in N_FUNDAMENTAL_SUPERNODES.indices() {
                if let Some(parent) = super_etree[s].idx() {
                    let parent = parent.zx();
                    child_lists[s] = child_list_heads[parent];
                    child_list_heads[parent] = MaybeIdx::from_index(s.truncate());
                }
            }

            // Greedy merging: for each parent, repeatedly pick the best mergeable child until
            // none qualifies.
            num_zeros.as_mut().fill(I::truncate(0));
            for parent in N_FUNDAMENTAL_SUPERNODES.indices() {
                loop {
                    let mut merging_child = MaybeIdx::none();
                    let mut num_new_zeros = 0usize;
                    let mut num_merged_zeros = 0usize;
                    let mut largest_mergable_size = 0usize;

                    let mut child_ = child_list_heads[parent];
                    while let Some(child) = child_.idx() {
                        let child = child.zx();
                        // Only the child immediately preceding the parent can be merged, so that
                        // the merged supernode remains a contiguous range of columns.
                        if *child + 1 != *parent {
                            child_ = child_lists[child];
                            continue;
                        }

                        // Already merged into some supernode.
                        if merge_parents[child].idx().is_some() {
                            child_ = child_lists[child];
                            continue;
                        }

                        let parent_size = supernode_sizes[parent].zx();
                        let child_size = supernode_sizes[child].zx();
                        if child_size < largest_mergable_size {
                            child_ = child_lists[child];
                            continue;
                        }

                        let parent_degree = fundamental_supernode_degrees[parent].zx();
                        let child_degree = fundamental_supernode_degrees[child].zx();

                        let num_parent_zeros = num_zeros[parent].zx();
                        let num_child_zeros = num_zeros[child].zx();

                        // Total explicit zeros of the merged supernode, or `NONE` if no
                        // `(max_size, max_zero_fraction)` cutoff in `relax` accepts the merge.
                        let status_num_merged_zeros = {
                            let num_new_zeros = (parent_size + parent_degree - child_degree) * child_size;

                            if num_new_zeros == 0 {
                                num_parent_zeros + num_child_zeros
                            } else {
                                let num_old_zeros = num_child_zeros + num_parent_zeros;
                                let num_zeros = num_new_zeros + num_old_zeros;

                                // Entries stored by the merged supernode: triangular diagonal block plus
                                // the rectangular off-diagonal part.
                                let combined_size = child_size + parent_size;
                                let num_expanded_entries = (combined_size * (combined_size + 1)) / 2 + parent_degree * combined_size;

                                let f = || {
                                    for cutoff in relax {
                                        let num_zeros_cutoff = num_expanded_entries as f64 * cutoff.1;
                                        if cutoff.0 >= combined_size && num_zeros_cutoff >= num_zeros as f64 {
                                            return num_zeros;
                                        }
                                    }
                                    NONE
                                };
                                f()
                            }
                        };
                        if status_num_merged_zeros == NONE {
                            child_ = child_lists[child];
                            continue;
                        }

                        // Prefer larger children; among equal sizes, prefer fewer new zeros.
                        let num_proposed_new_zeros = status_num_merged_zeros - (num_child_zeros + num_parent_zeros);
                        if child_size > largest_mergable_size || num_proposed_new_zeros < num_new_zeros {
                            merging_child = MaybeIdx::from_index(child);
                            num_new_zeros = num_proposed_new_zeros;
                            num_merged_zeros = status_num_merged_zeros;
                            largest_mergable_size = child_size;
                        }

                        child_ = child_lists[child];
                    }

                    if let Some(merging_child) = merging_child.idx() {
                        // Absorb the child into the parent; a size of zero marks it as merged away.
                        supernode_sizes[parent] = supernode_sizes[parent] + supernode_sizes[merging_child];
                        supernode_sizes[merging_child] = zero;
                        num_zeros[parent] = I::truncate(num_merged_zeros);

                        merge_parents[merging_child] = if let Some(child) = last_merged_children[parent].idx() {
                            MaybeIdx::from_index(child)
                        } else {
                            MaybeIdx::from_index(parent.truncate())
                        };

                        last_merged_children[parent] = if let Some(child) = last_merged_children[merging_child].idx() {
                            MaybeIdx::from_index(child)
                        } else {
                            MaybeIdx::from_index(merging_child.truncate())
                        };
                    } else {
                        break;
                    }
                }
            }

            // Compact the surviving (non-empty) supernodes to the front. The merged node keeps
            // the parent's degree.
            let original_to_relaxed = last_merged_children;
            original_to_relaxed.as_mut().fill(MaybeIdx::none());

            let mut pos = 0usize;
            for s in N_FUNDAMENTAL_SUPERNODES.indices() {
                let idx = N_FUNDAMENTAL_SUPERNODES.check(pos);
                let size = supernode_sizes[s];
                let degree = fundamental_supernode_degrees[s];
                if size > zero {
                    supernode_sizes[idx] = size;
                    fundamental_supernode_degrees[idx] = degree;
                    original_to_relaxed[s] = MaybeIdx::from_index(idx.truncate());

                    pos += 1;
                }
            }
            let n_relaxed_supernodes = pos;

            // Store degrees in `[1..]`; converted to column starts below.
            let mut supernode_begin__ = try_zeroed(n_relaxed_supernodes + 1)?;
            supernode_begin__[1..].copy_from_slice(&fundamental_supernode_degrees.as_ref()[..n_relaxed_supernodes]);

            supernode_begin__
        } else {
            // No relaxation: store the degree of each fundamental supernode in `[1..]`.
            let mut supernode_begin__ = try_zeroed(n_fundamental_supernodes + 1)?;

            let mut supernode_begin = 0usize;
            for s in N_FUNDAMENTAL_SUPERNODES.indices() {
                let size = supernode_sizes[s].zx();
                supernode_begin__[*s + 1] = col_counts[N.check(supernode_begin + size - 1)] - one;
                supernode_begin += size;
            }

            supernode_begin__
        }
    };

    let n_supernodes = supernode_begin__.len() - 1;

    let (supernode_begin__, col_ptr_for_row_idx__, col_ptr_for_val__, row_idx__) = {
        with_dim!(N_SUPERNODES, n_supernodes);
        let supernode_sizes = Array::from_mut(&mut supernode_sizes__[..n_supernodes], N_SUPERNODES);

        // Relaxation changed the partition: rebuild the column → supernode map and the
        // supernodal elimination tree for the relaxed supernodes.
        if n_supernodes != n_fundamental_supernodes {
            let mut supernode_begin = 0usize;
            for s in N_SUPERNODES.indices() {
                let size = supernode_sizes[s].zx();
                index_to_super.as_mut()[supernode_begin..][..size].fill(*s.truncate::<I>());
                supernode_begin += size;
            }

            let index_to_super = Array::from_mut(Idx::<'_, I>::from_slice_mut_checked(index_to_super.as_mut(), N_SUPERNODES), N);
            let super_etree = Array::from_mut(&mut super_etree__[..n_supernodes], N_SUPERNODES);

            let mut supernode_begin = 0usize;
            for s in N_SUPERNODES.indices() {
                let size = supernode_sizes[s].zx();
                let last = supernode_begin + size - 1;
                if let Some(parent) = etree[N.check(last)].idx() {
                    super_etree[s] = index_to_super[parent.zx()].to_signed();
                } else {
                    super_etree[s] = none;
                }
                supernode_begin += size;
            }
        }

        let index_to_super = Array::from_mut(Idx::from_slice_mut_checked(index_to_super.as_mut(), N_SUPERNODES), N);

        let mut supernode_begin__ = supernode_begin__;
        let mut col_ptr_for_row_idx__ = try_zeroed::<I>(n_supernodes + 1)?;
        let mut col_ptr_for_val__ = try_zeroed::<I>(n_supernodes + 1)?;

        let mut row_ptr = zero;
        let mut val_ptr = zero;

        supernode_begin__[0] = zero;

        let mut row_idx__ = {
            let mut wide_val_count = 0u128;
            // Walk consecutive pairs of `supernode_begin__`. `next` initially holds the degree of
            // supernode `s` and is overwritten in place with the column start of `s + 1`.
            for (s, [current, next]) in iter::zip(
                N_SUPERNODES.indices(),
                windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut *supernode_begin__))),
            ) {
                let degree = next.get();
                let ncols = supernode_sizes[s];
                let nrows = degree + ncols;
                // From here on `supernode_sizes[s]` is the write cursor into `row_idx` for `s`.
                supernode_sizes[s] = row_ptr;
                next.set(current.get() + ncols);

                col_ptr_for_row_idx__[*s] = row_ptr;
                col_ptr_for_val__[*s] = val_ptr;

                // Each supernode stores a dense `nrows × ncols` block.
                let wide_matrix_size = to_wide(nrows) * to_wide(ncols);
                wide_val_count += wide_matrix_size;

                row_ptr += degree;
                val_ptr = from_wide(to_wide(val_ptr) + wide_matrix_size);
            }
            col_ptr_for_row_idx__[n_supernodes] = row_ptr;
            col_ptr_for_val__[n_supernodes] = val_ptr;
            // `val_ptr` may have been truncated above; fail if the true total does not fit in `I`.
            from_wide_checked(wide_val_count).ok_or(FaerError::IndexOverflow)?;

            try_zeroed::<I>(row_ptr.zx())?
        };

        let super_etree = Array::from_ref(
            MaybeIdx::from_slice_ref_checked(&super_etree__[..n_supernodes], N_SUPERNODES),
            N_SUPERNODES,
        );

        let current_row_positions = supernode_sizes;

        let row_idx = Idx::from_slice_mut_checked(&mut row_idx__, N);
        // `child_count` storage is reused as the `visited` marker array.
        let visited = Array::from_mut(bytemuck::cast_slice_mut(&mut child_count.as_mut()[..n_supernodes]), N_SUPERNODES);

        // Fill each supernode's off-diagonal row pattern by walking the row subtrees of its columns.
        visited.as_mut().fill(I::Signed::truncate(NONE));
        if matches!(input, CholeskyInput::A) {
            let A = A.as_shape(N, N);
            for s in N_SUPERNODES.indices() {
                let k1 = ghost::IdxInc::new_checked(supernode_begin__[*s].zx(), N);
                let k2 = ghost::IdxInc::new_checked(supernode_begin__[*s + 1].zx(), N);

                for k in k1.range_to(k2) {
                    ereach_super(A, super_etree, index_to_super, current_row_positions, row_idx, k, visited);
                }
            }
        } else {
            let min_col = min_col.unwrap();
            for s in N_SUPERNODES.indices() {
                let k1 = ghost::IdxInc::new_checked(supernode_begin__[*s].zx(), N);
                let k2 = ghost::IdxInc::new_checked(supernode_begin__[*s + 1].zx(), N);

                for k in k1.range_to(k2) {
                    ereach_super_ata(
                        A,
                        col_perm,
                        min_col,
                        super_etree,
                        index_to_super,
                        current_row_positions,
                        row_idx,
                        k,
                        visited,
                    );
                }
            }
        }

        // Every supernode's row pattern must have been filled exactly up to the next one's start.
        debug_assert!(current_row_positions.as_ref() == &col_ptr_for_row_idx__[1..]);

        (supernode_begin__, col_ptr_for_row_idx__, col_ptr_for_val__, row_idx__)
    };

    // Owned copy of the supernodal tree; after postordering, its storage is reused for the
    // inverse postorder.
    let mut supernode_etree__: alloc::vec::Vec<I> = try_collect(bytemuck::cast_slice(&super_etree__[..n_supernodes]).iter().copied())?;
    let mut supernode_postorder__ = try_zeroed::<I>(n_supernodes)?;

    let mut descendent_count__ = try_zeroed::<I>(n_supernodes)?;

    {
        with_dim!(N_SUPERNODES, n_supernodes);
        let post = Array::from_mut(&mut supernode_postorder__, N_SUPERNODES);
        let desc_count = Array::from_mut(&mut descendent_count__, N_SUPERNODES);
        let etree: &Array<'_, MaybeIdx<'_, I>> = Array::from_ref(
            MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(&supernode_etree__), N_SUPERNODES),
            N_SUPERNODES,
        );

        // Single forward pass; assumes every parent index is larger than its children's
        // (as in an elimination tree), so each child's count is final before it is added.
        for s in N_SUPERNODES.indices() {
            if let Some(parent) = etree[s].idx() {
                let parent = parent.zx();
                desc_count[parent] = desc_count[parent] + desc_count[s] + one;
            }
        }

        // The scratch arrays carved from `original_stack` are no longer used, so the full
        // workspace is available to the postorder.
        ghost_postorder(post, etree, original_stack);
        let post_inv = Array::from_mut(bytemuck::cast_slice_mut(&mut supernode_etree__), N_SUPERNODES);
        for i in N_SUPERNODES.indices() {
            post_inv[N_SUPERNODES.check(post[i].zx())] = I::truncate(*i);
        }
    };

    Ok(SymbolicSupernodalCholesky {
        dimension: n,
        supernode_postorder: supernode_postorder__,
        supernode_postorder_inv: supernode_etree__,
        descendant_count: descendent_count__,
        supernode_begin: supernode_begin__,
        col_ptr_for_row_idx: col_ptr_for_row_idx__,
        col_ptr_for_val: col_ptr_for_val__,
        row_idx: row_idx__,
        nnz_per_super: None,
    })
}
2739
2740 #[inline]
2741 pub(crate) fn partition_fn<I: Index>(idx: usize) -> impl Fn(&I) -> bool {
2742 let idx = I::truncate(idx);
2743 move |&i| i < idx
2744 }
2745
2746 pub fn factorize_supernodal_numeric_llt_scratch<I: Index, T: ComplexField>(
2749 symbolic: &SymbolicSupernodalCholesky<I>,
2750 par: Par,
2751 params: Spec<LltParams, T>,
2752 ) -> StackReq {
2753 let n_supernodes = symbolic.n_supernodes();
2754 let n = symbolic.nrows();
2755 let post = &*symbolic.supernode_postorder;
2756 let post_inv = &*symbolic.supernode_postorder_inv;
2757
2758 let desc_count = &*symbolic.descendant_count;
2759
2760 let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
2761 let row_idx = &*symbolic.row_idx;
2762
2763 let mut req = StackReq::empty();
2764 for s in 0..n_supernodes {
2765 let s_start = symbolic.supernode_begin[s].zx();
2766 let s_end = symbolic.supernode_begin[s + 1].zx();
2767
2768 let s_ncols = s_end - s_start;
2769
2770 let s_postordered = post_inv[s].zx();
2771 let desc_count = desc_count[s].zx();
2772 for d in &post[s_postordered - desc_count..s_postordered] {
2773 let mut d_scratch = StackReq::empty();
2774 let d = d.zx();
2775 let d_start = symbolic.supernode_begin[d].zx();
2776 let d_end = symbolic.supernode_begin[d + 1].zx();
2777
2778 let d_pattern = &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2779 let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2780 let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2781
2782 d_scratch = d_scratch
2783 .and(StackReq::new::<I>(d_pattern.len() - d_pattern_start))
2784 .and(StackReq::new::<I>(d_pattern_mid_len));
2785
2786 let d_ncols = d_end - d_start;
2787
2788 d_scratch = d_scratch.and(spicy_matmul_scratch::<T>(
2789 d_pattern.len() - d_pattern_start,
2790 d_pattern_mid_len,
2791 d_ncols,
2792 true,
2793 false,
2794 ));
2795 req = req.or(d_scratch);
2796 }
2797 req = req.or(linalg::cholesky::llt::factor::cholesky_in_place_scratch::<T>(s_ncols, par, params));
2798 }
2799 req.and(StackReq::new::<I>(n))
2800 }
2801
2802 pub fn factorize_supernodal_numeric_ldlt_scratch<I: Index, T: ComplexField>(
2805 symbolic: &SymbolicSupernodalCholesky<I>,
2806 par: Par,
2807 params: Spec<LdltParams, T>,
2808 ) -> StackReq {
2809 let n_supernodes = symbolic.n_supernodes();
2810 let n = symbolic.nrows();
2811 let post = &*symbolic.supernode_postorder;
2812 let post_inv = &*symbolic.supernode_postorder_inv;
2813
2814 let desc_count = &*symbolic.descendant_count;
2815
2816 let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
2817 let row_idx = &*symbolic.row_idx;
2818
2819 let mut req = StackReq::empty();
2820 for s in 0..n_supernodes {
2821 let s_start = symbolic.supernode_begin[s].zx();
2822 let s_end = symbolic.supernode_begin[s + 1].zx();
2823
2824 let s_ncols = s_end - s_start;
2825
2826 let s_postordered = post_inv[s].zx();
2827 let desc_count = desc_count[s].zx();
2828 for d in &post[s_postordered - desc_count..s_postordered] {
2829 let mut d_scratch = StackReq::empty();
2830
2831 let d = d.zx();
2832 let d_start = symbolic.supernode_begin[d].zx();
2833 let d_end = symbolic.supernode_begin[d + 1].zx();
2834
2835 let d_pattern = &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2836
2837 let d_ncols = d_end - d_start;
2838
2839 let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2840 let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2841
2842 d_scratch = d_scratch
2843 .and(StackReq::new::<I>(d_pattern.len() - d_pattern_start))
2844 .and(StackReq::new::<I>(d_pattern_mid_len));
2845
2846 d_scratch = d_scratch.and(spicy_matmul_scratch::<T>(
2847 d_pattern.len() - d_pattern_start,
2848 d_pattern_mid_len,
2849 d_ncols,
2850 true,
2851 true,
2852 ));
2853 req = req.or(d_scratch);
2854 }
2855 req = req.or(linalg::cholesky::ldlt::factor::cholesky_in_place_scratch::<T>(s_ncols, par, params));
2856 }
2857 req.and(StackReq::new::<I>(n))
2858 }
2859
2860 pub fn factorize_supernodal_numeric_intranode_lblt_scratch<I: Index, T: ComplexField>(
2864 symbolic: &SymbolicSupernodalCholesky<I>,
2865 par: Par,
2866 params: Spec<LbltParams, T>,
2867 ) -> StackReq {
2868 let n_supernodes = symbolic.n_supernodes();
2869 let n = symbolic.nrows();
2870 let post = &*symbolic.supernode_postorder;
2871 let post_inv = &*symbolic.supernode_postorder_inv;
2872
2873 let desc_count = &*symbolic.descendant_count;
2874
2875 let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
2876 let row_idx = &*symbolic.row_idx;
2877
2878 let mut req = StackReq::empty();
2879 for s in 0..n_supernodes {
2880 let s_start = symbolic.supernode_begin[s].zx();
2881 let s_end = symbolic.supernode_begin[s + 1].zx();
2882
2883 let s_ncols = s_end - s_start;
2884 let s_pattern = &row_idx[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
2885
2886 let s_postordered = post_inv[s].zx();
2887 let desc_count = desc_count[s].zx();
2888 for d in &post[s_postordered - desc_count..s_postordered] {
2889 let mut d_scratch = StackReq::empty();
2890
2891 let d = d.zx();
2892 let d_start = symbolic.supernode_begin[d].zx();
2893 let d_end = symbolic.supernode_begin[d + 1].zx();
2894
2895 let d_pattern = &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
2896
2897 let d_ncols = d_end - d_start;
2898
2899 let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
2900 let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
2901
2902 d_scratch = d_scratch.and(temp_mat_scratch::<T>(d_pattern.len() - d_pattern_start, d_pattern_mid_len));
2903 d_scratch = d_scratch.and(temp_mat_scratch::<T>(d_ncols, d_pattern_mid_len));
2904 req = req.or(d_scratch);
2905 }
2906 req = StackReq::any_of(&[
2907 req,
2908 linalg::cholesky::lblt::factor::cholesky_in_place_scratch::<I, T>(s_ncols, par, params),
2909 crate::perm::permute_cols_in_place_scratch::<I, T>(s_pattern.len(), s_ncols),
2910 ]);
2911 }
2912 req.and(StackReq::new::<I>(n))
2913 }
2914
    #[math]
    /// Computes the numeric supernodal $LL^H$ factor of `A_lower`, writing it into `L_values` using the
    /// layout described by `symbolic`.
    ///
    /// Only entries of `A_lower` on or below the diagonal (`row >= col`) are read. Supernodes are
    /// visited in index order. For each supernode `s` the kernel:
    /// 1. scatters columns `s_start..s_end` of `A_lower` into the dense block `Ls`;
    /// 2. subtracts the contributions of every descendant supernode;
    /// 3. factorizes the diagonal block in place and solves for the off-diagonal rows.
    ///
    /// `Ls` is a column-major `(s_ncols + s_pattern.len()) × s_ncols` view of `L_values`. Its first
    /// `s_ncols` rows are the diagonal block, and the remaining rows follow `s_pattern` in order.
    ///
    /// # Errors
    /// Returns [`LltError::NonPositivePivot`] if the dense factorization of a diagonal block fails.
    /// The reported `index` is shifted by the supernode's first column, so it refers to a global column.
    ///
    /// # Panics
    /// Panics if `A_lower` is not `symbolic.nrows()` × `symbolic.nrows()`, or if
    /// `L_values.len() != symbolic.len_val()`. `L_values` is zero-filled before these checks run.
    ///
    /// Scratch requirements are given by `factorize_supernodal_numeric_llt_scratch`.
    pub fn factorize_supernodal_numeric_llt<I: Index, T: ComplexField>(
        L_values: &mut [T],
        A_lower: SparseColMatRef<'_, I, T>,
        regularization: LltRegularization<T::Real>,
        symbolic: &SymbolicSupernodalCholesky<I>,
        par: Par,
        stack: &mut MemStack,
        params: Spec<LltParams, T>,
    ) -> Result<LltInfo, LltError> {
        let n_supernodes = symbolic.n_supernodes();
        let n = symbolic.nrows();
        let mut dynamic_regularization_count = 0usize;
        L_values.fill(zero::<T>());

        assert!(A_lower.nrows() == n);
        assert!(A_lower.ncols() == n);
        assert!(L_values.len() == symbolic.len_val());

        let none = I::Signed::truncate(NONE);

        let post = &*symbolic.supernode_postorder;
        let post_inv = &*symbolic.supernode_postorder_inv;

        let desc_count = &*symbolic.descendant_count;

        let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
        let col_ptr_val = &*symbolic.col_ptr_for_val;
        let row_idx = &*symbolic.row_idx;

        // Maps a global row index to its local row in the current supernode's block `Ls`.
        // Rows outside the current pattern hold `NONE`; the map is reset after each supernode.
        let (global_to_local, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
        global_to_local.fill(I::Signed::truncate(NONE));

        for s in 0..n_supernodes {
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();

            let s_pattern = &row_idx[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
            let s_ncols = s_end - s_start;
            let s_nrows = s_pattern.len() + s_ncols;

            // pattern rows sit below the `s_ncols` rows of the diagonal block
            for (i, &row) in s_pattern.iter().enumerate() {
                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
            }

            // Values of earlier supernodes (`head`) stay readable while `Ls` (in `tail`) is mutated.
            // Descendants are read from `head`, so they must be stored before `s`.
            let (head, tail) = L_values.split_at_mut(col_ptr_val[s].zx());
            let head = head.rb();
            let mut Ls = MatMut::from_column_major_slice_mut(&mut tail[..col_ptr_val[s + 1].zx() - col_ptr_val[s].zx()], s_nrows, s_ncols);

            // scatter the lower-triangular entries of A's columns `s_start..s_end` into `Ls`
            for j in s_start..s_end {
                let j_shifted = j - s_start;
                for (i, val) in iter::zip(A_lower.row_idx_of_col(j), A_lower.val_of_col(j)) {
                    if i < j {
                        continue;
                    }

                    let (ix, iy) = if i >= s_end {
                        (global_to_local[i].sx(), j_shifted)
                    } else {
                        (i - s_start, j_shifted)
                    };
                    Ls[(ix, iy)] = Ls[(ix, iy)] + *val;
                }
            }

            // The `desc_count[s]` supernodes immediately preceding `s` in the postorder are its descendants.
            let s_postordered = post_inv[s].zx();
            let desc_count = desc_count[s].zx();
            for d in &post[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_start = symbolic.supernode_begin[d].zx();
                let d_end = symbolic.supernode_begin[d + 1].zx();

                let d_pattern = &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
                let d_ncols = d_end - d_start;
                let d_nrows = d_pattern.len() + d_ncols;

                let Ld = MatRef::from_column_major_slice(&head[col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()], d_nrows, d_ncols);

                // d_pattern_start: first pattern row >= s_start
                // d_pattern_mid_len: number of pattern rows that fall inside `s_start..s_end`
                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
                let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));

                // Ld_mid_bot: rows of `d` at/below `s_start`; Ld_mid: the part inside `s`'s columns
                let (_, Ld_mid_bot) = Ld.split_at_row(d_ncols);
                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
                let (Ld_mid, _) = Ld_mid_bot.split_at_row(d_pattern_mid_len);

                use linalg::matmul::triangular;
                // Target rows in `Ls`: rows inside the diagonal block map to `row - s_start`,
                // rows below it map through `global_to_local`.
                let (row_idx, stack) = stack.make_with(Ld_mid_bot.nrows(), |i| {
                    if i < d_pattern_mid_len {
                        I::truncate(d_pattern[d_pattern_start + i].zx() - s_start)
                    } else {
                        I::from_signed(global_to_local[d_pattern[d_pattern_start + i].zx()])
                    }
                });
                let (col_idx, stack) = stack.make_with(d_pattern_mid_len, |j| I::truncate(d_pattern[d_pattern_start + j].zx() - s_start));

                // Presumably `Ls[row_idx, col_idx] -= Ld_mid_bot * Ld_mid^H`, lower triangle only;
                // see `spicy_matmul` for the exact scatter semantics.
                spicy_matmul(
                    Ls.rb_mut(),
                    triangular::BlockStructure::TriangularLower,
                    Some(&row_idx),
                    Some(&col_idx),
                    Accum::Add,
                    Ld_mid_bot,
                    Conj::No,
                    Ld_mid.transpose(),
                    Conj::Yes,
                    None,
                    -one::<T>(),
                    par,
                    stack,
                );
            }

            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);

            // dense factorization of the diagonal block; pivot indices are made global on failure
            dynamic_regularization_count +=
                match linalg::cholesky::llt::factor::cholesky_in_place(Ls_top.rb_mut(), regularization.clone(), par, stack, params) {
                    Ok(count) => count,
                    Err(LltError::NonPositivePivot { index }) => {
                        return Err(LltError::NonPositivePivot { index: index + s_start });
                    },
                }
                .dynamic_regularization_count;
            // Ls_bot <- Ls_bot * Ls_top^{-H}, solved on the transposed view against conj(Ls_top)
            linalg::triangular_solve::solve_lower_triangular_in_place(Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), par);

            // restore the map to all-`NONE` for the next supernode
            for &row in s_pattern {
                global_to_local[row.zx()] = none;
            }
        }
        Ok(LltInfo {
            dynamic_regularization_count,
        })
    }
3059
    #[math]
    /// Computes the numeric supernodal $LDL^H$ factor of `A_lower`, writing it into `L_values` using the
    /// layout described by `symbolic`.
    ///
    /// The structure matches `factorize_supernodal_numeric_llt`, with these differences:
    /// - descendant updates are scaled by the descendant's diagonal `D`;
    /// - each diagonal block is factorized with the dense LDLT routine. Its diagonal then holds `D`,
    ///   and its strictly upper part is zeroed;
    /// - the off-diagonal rows are solved against the unit lower factor and then divided by `D`.
    ///
    /// When `symbolic.nnz_per_super` is present, each supernode's pattern is the prefix of its row
    /// indices with that length instead of the full `col_ptr_for_row_idx` range.
    ///
    /// `regularization.dynamic_regularization_signs`, if given, is sliced to `s_start..s_end` for
    /// each supernode's dense factorization.
    ///
    /// # Errors
    /// Returns [`LdltError::ZeroPivot`] if a dense block factorization fails, with `index` shifted to a
    /// global column.
    ///
    /// # Panics
    /// Panics if `A_lower` is not `symbolic.nrows()` square, or if `L_values.len() != symbolic.len_val()`.
    /// `L_values` is zero-filled before these checks run.
    pub fn factorize_supernodal_numeric_ldlt<I: Index, T: ComplexField>(
        L_values: &mut [T],
        A_lower: SparseColMatRef<'_, I, T>,
        regularization: LdltRegularization<'_, T::Real>,
        symbolic: &SymbolicSupernodalCholesky<I>,
        par: Par,
        stack: &mut MemStack,
        params: Spec<LdltParams, T>,
    ) -> Result<LdltInfo, LdltError> {
        let n_supernodes = symbolic.n_supernodes();
        let n = symbolic.nrows();
        let mut dynamic_regularization_count = 0usize;
        L_values.fill(zero());

        assert!(A_lower.nrows() == n);
        assert!(A_lower.ncols() == n);
        assert!(L_values.len() == symbolic.len_val());

        let none = I::Signed::truncate(NONE);

        let post = &*symbolic.supernode_postorder;
        let post_inv = &*symbolic.supernode_postorder_inv;

        let desc_count = &*symbolic.descendant_count;

        let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
        let col_ptr_val = &*symbolic.col_ptr_for_val;
        let row_idx = &*symbolic.row_idx;

        // global row index -> local row of the current supernode's block (`NONE` otherwise)
        let (global_to_local, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
        global_to_local.fill(I::Signed::truncate(NONE));

        for s in 0..n_supernodes {
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();
            // optionally truncated pattern (see `nnz_per_super` in the doc comment)
            let s_pattern = if let Some(nnz_per_super) = symbolic.nnz_per_super.as_deref() {
                &row_idx[col_ptr_row[s].zx()..][..nnz_per_super[s].zx()]
            } else {
                &row_idx[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()]
            };

            let s_ncols = s_end - s_start;
            let s_nrows = s_pattern.len() + s_ncols;

            for (i, &row) in s_pattern.iter().enumerate() {
                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
            }

            // earlier supernodes' values (`head`) are read-only; `Ls` is the current block
            let (head, tail) = L_values.split_at_mut(col_ptr_val[s].zx());
            let head = head.rb();
            let mut Ls = MatMut::from_column_major_slice_mut(&mut tail[..col_ptr_val[s + 1].zx() - col_ptr_val[s].zx()], s_nrows, s_ncols);

            // scatter the lower-triangular entries of A's columns `s_start..s_end`
            for j in s_start..s_end {
                let j_shifted = j - s_start;
                for (i, val) in iter::zip(A_lower.row_idx_of_col(j), A_lower.val_of_col(j)) {
                    if i < j {
                        continue;
                    }

                    let (ix, iy) = if i >= s_end {
                        (global_to_local[i].sx(), j_shifted)
                    } else {
                        (i - s_start, j_shifted)
                    };
                    Ls[(ix, iy)] = Ls[(ix, iy)] + *val;
                }
            }

            // descendants of `s`: the `desc_count[s]` supernodes right before it in postorder
            let s_postordered = post_inv[s].zx();
            let desc_count = desc_count[s].zx();
            for d in &post[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_start = symbolic.supernode_begin[d].zx();
                let d_end = symbolic.supernode_begin[d + 1].zx();
                let d_pattern = if let Some(nnz_per_super) = symbolic.nnz_per_super.as_deref() {
                    &row_idx[col_ptr_row[d].zx()..][..nnz_per_super[d].zx()]
                } else {
                    &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()]
                };

                let d_ncols = d_end - d_start;
                let d_nrows = d_pattern.len() + d_ncols;

                let Ld = MatRef::from_column_major_slice(&head[col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()], d_nrows, d_ncols);

                // first pattern row >= s_start, and how many pattern rows lie in `s_start..s_end`
                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
                let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));

                let (Ld_top, Ld_mid_bot) = Ld.split_at_row(d_ncols);
                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
                let (Ld_mid, _) = Ld_mid_bot.split_at_row(d_pattern_mid_len);
                // the descendant's `D` is stored on the diagonal of its diagonal block
                let D = Ld_top.diagonal().column_vector();

                use linalg::matmul::triangular;
                // local target rows/cols in `Ls` (diagonal-block rows by offset, others via the map)
                let (row_idx, stack) = stack.make_with(Ld_mid_bot.nrows(), |i| {
                    if i < d_pattern_mid_len {
                        I::truncate(d_pattern[d_pattern_start + i].zx() - s_start)
                    } else {
                        I::from_signed(global_to_local[d_pattern[d_pattern_start + i].zx()])
                    }
                });
                let (col_idx, stack) = stack.make_with(d_pattern_mid_len, |j| I::truncate(d_pattern[d_pattern_start + j].zx() - s_start));

                // presumably `Ls[row_idx, col_idx] -= Ld_mid_bot * D * Ld_mid^H` (lower triangle)
                spicy_matmul(
                    Ls.rb_mut(),
                    triangular::BlockStructure::TriangularLower,
                    Some(&row_idx),
                    Some(&col_idx),
                    Accum::Add,
                    Ld_mid_bot,
                    Conj::No,
                    Ld_mid.transpose(),
                    Conj::Yes,
                    Some(D.as_diagonal()),
                    -one::<T>(),
                    par,
                    stack,
                );
            }

            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);

            // dense LDLT of the diagonal block, with regularization signs restricted to this supernode
            dynamic_regularization_count += match linalg::cholesky::ldlt::factor::cholesky_in_place(
                Ls_top.rb_mut(),
                LdltRegularization {
                    dynamic_regularization_signs: regularization.dynamic_regularization_signs.map(|signs| &signs[s_start..s_end]),
                    ..regularization.clone()
                },
                par,
                stack,
                params,
            ) {
                Ok(count) => count.dynamic_regularization_count,
                Err(LdltError::ZeroPivot { index }) => {
                    return Err(LdltError::ZeroPivot { index: index + s_start });
                },
            };
            // keep only the lower part (plus the diagonal `D`) of the block
            z!(Ls_top.rb_mut()).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero::<T>());
            // Ls_bot <- Ls_bot * L^{-H} with the unit lower factor L, then divide each column by D[j]
            linalg::triangular_solve::solve_unit_lower_triangular_in_place(Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), par);
            for j in 0..s_ncols {
                let d = recip(real(Ls_top[(j, j)]));
                for i in 0..s_pattern.len() {
                    Ls_bot[(i, j)] = mul_real(Ls_bot[(i, j)], d);
                }
            }

            for &row in s_pattern {
                global_to_local[row.zx()] = none;
            }
        }
        Ok(LdltInfo {
            dynamic_regularization_count,
        })
    }
3227
    #[math]
    /// Computes a supernodal $LBL^H$ factorization of `A_lower` in which pivoting is confined to the
    /// diagonal block of each supernode ("intranode" pivoting).
    ///
    /// `B` is block diagonal with 1×1 and 2×2 blocks. `subdiag[j] == 0` marks a 1×1 block at `j`.
    /// A nonzero `subdiag[j]` marks a 2×2 block covering `j, j + 1`, with `subdiag[j]` as its
    /// `(j + 1, j)` entry. The diagonal of each supernode's diagonal block stores the diagonal of `B`.
    ///
    /// For each supernode `s`, the dense LBLT routine writes the local pivoting permutation into
    /// `perm_forward[s_start..s_end]` / `perm_inverse[s_start..s_end]`. These entries are then offset
    /// by `s_start`, so they hold global indices. The off-diagonal rows of `s` are column-permuted
    /// to match.
    ///
    /// Returns the total `transposition_count` reported by the dense factorizations.
    ///
    /// # Panics
    /// Panics if `A_lower` is not `n × n`, if the `perm_*` or `subdiag` buffers do not have length `n`,
    /// or if `L_values.len() != symbolic.len_val()` (`L_values` is zero-filled first).
    pub fn factorize_supernodal_numeric_intranode_lblt<I: Index, T: ComplexField>(
        L_values: &mut [T],
        subdiag: &mut [T],
        perm_forward: &mut [I],
        perm_inverse: &mut [I],
        A_lower: SparseColMatRef<'_, I, T>,
        symbolic: &SymbolicSupernodalCholesky<I>,
        par: Par,
        stack: &mut MemStack,
        params: Spec<LbltParams, T>,
    ) -> LbltInfo {
        let n_supernodes = symbolic.n_supernodes();
        let n = symbolic.nrows();
        let mut transposition_count = 0usize;
        L_values.fill(zero());

        assert!(A_lower.nrows() == n);
        assert!(A_lower.ncols() == n);
        assert!(perm_forward.len() == n);
        assert!(perm_inverse.len() == n);
        assert!(subdiag.len() == n);
        assert!(L_values.len() == symbolic.len_val());

        let none = I::Signed::truncate(NONE);

        let post = &*symbolic.supernode_postorder;
        let post_inv = &*symbolic.supernode_postorder_inv;

        let desc_count = &*symbolic.descendant_count;

        let col_ptr_row = &*symbolic.col_ptr_for_row_idx;
        let col_ptr_val = &*symbolic.col_ptr_for_val;
        let row_idx = &*symbolic.row_idx;

        // global row index -> local row of the current supernode's block (`NONE` otherwise)
        let (global_to_local, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
        global_to_local.fill(I::Signed::truncate(NONE));

        for s in 0..n_supernodes {
            let s_start = symbolic.supernode_begin[s].zx();
            let s_end = symbolic.supernode_begin[s + 1].zx();

            let s_pattern = &row_idx[col_ptr_row[s].zx()..col_ptr_row[s + 1].zx()];
            let s_ncols = s_end - s_start;
            let s_nrows = s_pattern.len() + s_ncols;

            for (i, &row) in s_pattern.iter().enumerate() {
                global_to_local[row.zx()] = I::Signed::truncate(i + s_ncols);
            }

            let (head, tail) = L_values.split_at_mut(col_ptr_val[s].zx());
            let head = head.rb();
            let mut Ls = MatMut::from_column_major_slice_mut(&mut tail[..col_ptr_val[s + 1].zx() - col_ptr_val[s].zx()], s_nrows, s_ncols);

            // scatter the lower-triangular entries of A's columns `s_start..s_end`
            for j in s_start..s_end {
                let j_shifted = j - s_start;
                for (i, val) in iter::zip(A_lower.row_idx_of_col(j), A_lower.val_of_col(j)) {
                    if i < j {
                        continue;
                    }

                    let (ix, iy) = if i >= s_end {
                        (global_to_local[i].sx(), j_shifted)
                    } else {
                        (i - s_start, j_shifted)
                    };
                    Ls[(ix, iy)] = Ls[(ix, iy)] + *val;
                }
            }

            // descendants of `s`: the `desc_count[s]` supernodes right before it in postorder
            let s_postordered = post_inv[s].zx();
            let desc_count = desc_count[s].zx();
            for d in &post[s_postordered - desc_count..s_postordered] {
                let d = d.zx();
                let d_start = symbolic.supernode_begin[d].zx();
                let d_end = symbolic.supernode_begin[d + 1].zx();

                let d_pattern = &row_idx[col_ptr_row[d].zx()..col_ptr_row[d + 1].zx()];
                let d_ncols = d_end - d_start;
                let d_nrows = d_pattern.len() + d_ncols;

                let Ld = MatRef::from_column_major_slice(&head[col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()], d_nrows, d_ncols);

                // d_pattern[d_pattern_start..d_pattern_mid] are the rows of `d` inside `s_start..s_end`
                let d_pattern_start = d_pattern.partition_point(partition_fn(s_start));
                let d_pattern_mid_len = d_pattern[d_pattern_start..].partition_point(partition_fn(s_end));
                let d_pattern_mid = d_pattern_start + d_pattern_mid_len;

                let (Ld_top, Ld_mid_bot) = Ld.split_at_row(d_ncols);
                let (_, Ld_mid_bot) = Ld_mid_bot.split_at_row(d_pattern_start);
                let (Ld_mid, Ld_bot) = Ld_mid_bot.split_at_row(d_pattern_mid_len);
                let d_subdiag = &subdiag[d_start..d_start + d_ncols];

                // tmp: the update product block; tmp2 (viewed transposed) holds Ld_mid * B_d
                let (mut tmp, stack) = unsafe { temp_mat_uninit::<T, _, _>(Ld_mid_bot.nrows(), d_pattern_mid_len, stack) };
                let (mut tmp2, _) = unsafe { temp_mat_uninit::<T, _, _>(Ld_mid.ncols(), Ld_mid.nrows(), stack) };
                let tmp = tmp.as_mat_mut();
                let mut Ld_mid_x_D = tmp2.as_mat_mut().transpose_mut();

                // Ld_mid_x_D = Ld_mid * B_d, walking the 1x1 / 2x2 diagonal blocks of `d`
                let mut j = 0;
                while j < d_ncols {
                    let subdiag = copy(d_subdiag[j]);
                    if subdiag == zero::<T>() {
                        let d = real(Ld_top[(j, j)]);
                        for i in 0..d_pattern_mid_len {
                            Ld_mid_x_D[(i, j)] = mul_real(Ld_mid[(i, j)], d);
                        }
                        j += 1;
                    } else {
                        // 2x2 block [[ak, conj(akp1k)], [akp1k, akp1]]
                        let akp1k = subdiag;
                        let ak = real(Ld_top[(j, j)]);
                        let akp1 = real(Ld_top[(j + 1, j + 1)]);

                        for i in 0..d_pattern_mid_len {
                            let xk = copy(Ld_mid[(i, j)]);
                            let xkp1 = copy(Ld_mid[(i, j + 1)]);

                            Ld_mid_x_D[(i, j)] = mul_real(xk, ak) + xkp1 * akp1k;
                            Ld_mid_x_D[(i, j + 1)] = mul_real(xkp1, akp1) + xk * conj(akp1k);
                        }
                        j += 2;
                    }
                }

                let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len);

                use linalg::matmul;
                use linalg::matmul::triangular;
                // tmp_top = lower(Ld_mid * (Ld_mid * B_d)^H), tmp_bot = Ld_bot * (Ld_mid * B_d)^H
                triangular::matmul(
                    tmp_top.rb_mut(),
                    triangular::BlockStructure::TriangularLower,
                    Accum::Replace,
                    Ld_mid,
                    triangular::BlockStructure::Rectangular,
                    Ld_mid_x_D.rb().adjoint(),
                    triangular::BlockStructure::Rectangular,
                    one::<T>(),
                    par,
                );
                matmul::matmul(tmp_bot.rb_mut(), Accum::Replace, Ld_bot, Ld_mid_x_D.rb().adjoint(), one::<T>(), par);

                // subtract tmp_top from the diagonal block of `Ls` (lower triangle only)
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_start..d_pattern_mid][j_idx..].iter().enumerate() {
                        let i_idx = i_idx + j_idx;

                        let i = i.zx();
                        let i_s = i - s_start;

                        debug_assert!(i_s >= j_s);
                        Ls[(i_s, j_s)] = Ls[(i_s, j_s)] - tmp_top[(i_idx, j_idx)];
                    }
                }

                // subtract tmp_bot from the off-diagonal rows of `Ls`, located through `global_to_local`
                for (j_idx, j) in d_pattern[d_pattern_start..d_pattern_mid].iter().enumerate() {
                    let j = j.zx();
                    let j_s = j - s_start;
                    for (i_idx, i) in d_pattern[d_pattern_mid..].iter().enumerate() {
                        let i = i.zx();
                        let i_s = global_to_local[i].zx();
                        Ls[(i_s, j_s)] = Ls[(i_s, j_s)] - tmp_bot[(i_idx, j_idx)];
                    }
                }
            }

            let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols);
            let s_subdiag = &mut subdiag[s_start..s_end];

            // dense LBLT of the diagonal block; writes the local permutation and the subdiagonal of B
            let (info, perm) = linalg::cholesky::lblt::factor::cholesky_in_place(
                Ls_top.rb_mut(),
                ColMut::from_slice_mut(s_subdiag).as_diagonal_mut(),
                &mut perm_forward[s_start..s_end],
                &mut perm_inverse[s_start..s_end],
                par,
                stack,
                params,
            );
            transposition_count += info.transposition_count;
            z!(Ls_top.rb_mut()).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = zero::<T>());

            // apply the same (local) column permutation to the off-diagonal rows
            crate::perm::permute_cols_in_place(Ls_bot.rb_mut(), perm.rb(), stack);

            // convert local permutation indices to global ones
            for p in &mut perm_forward[s_start..s_end] {
                *p += I::truncate(s_start);
            }
            for p in &mut perm_inverse[s_start..s_end] {
                *p += I::truncate(s_start);
            }

            // Ls_bot <- Ls_bot * L^{-H} (unit lower L), then multiply on the right by B^{-1} below
            linalg::triangular_solve::solve_unit_lower_triangular_in_place(Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), par);

            let mut j = 0;
            while j < s_ncols {
                if s_subdiag[j] == zero::<T>() {
                    let d = recip(real(Ls_top[(j, j)]));
                    for i in 0..s_pattern.len() {
                        Ls_bot[(i, j)] = mul_real(Ls_bot[(i, j)], d);
                    }
                    j += 1;
                } else {
                    // Inverse of the 2x2 block, computed in a scaled form. With b = s_subdiag[j] this yields
                    // [xk, xkp1] * B^{-1} = [c*xk - b*xkp1, a*xkp1 - conj(b)*xk] / (a*c - |b|^2).
                    let akp1k = recip(conj(s_subdiag[j]));
                    let ak = mul_real(conj(akp1k), real(Ls_top[(j, j)]));
                    let akp1 = mul_real(akp1k, real(Ls_top[(j + 1, j + 1)]));

                    let denom = recip(ak * akp1 - one::<T>());

                    for i in 0..s_pattern.len() {
                        let xk = Ls_bot[(i, j)] * conj(akp1k);
                        let xkp1 = Ls_bot[(i, j + 1)] * akp1k;

                        Ls_bot[(i, j)] = (akp1 * xk - xkp1) * denom;
                        Ls_bot[(i, j + 1)] = (ak * xkp1 - xk) * denom;
                    }
                    j += 2;
                }
            }

            for &row in s_pattern {
                global_to_local[row.zx()] = none;
            }
        }
        LbltInfo { transposition_count }
    }
3462}
3463
3464fn postorder_depth_first_search<'n, I: Index>(
3465 post: &mut Array<'n, I>,
3466 root: usize,
3467 mut start_index: usize,
3468 stack: &mut Array<'n, I>,
3469 first_child: &mut Array<'n, MaybeIdx<'n, I>>,
3470 next_child: &Array<'n, I::Signed>,
3471) -> usize {
3472 let mut top = 1usize;
3473 let N = post.len();
3474
3475 stack[N.check(0)] = I::truncate(root);
3476 while top != 0 {
3477 let current_node = stack[N.check(top - 1)].zx();
3478 let first_child = &mut first_child[N.check(current_node)];
3479 let current_child = (*first_child).sx();
3480
3481 if let Some(current_child) = current_child.idx() {
3482 stack[N.check(top)] = *current_child.truncate::<I>();
3483 top += 1;
3484 *first_child = MaybeIdx::new_checked(next_child[current_child], N);
3485 } else {
3486 post[N.check(start_index)] = I::truncate(current_node);
3487 start_index += 1;
3488 top -= 1;
3489 }
3490 }
3491 start_index
3492}
3493
3494pub(crate) fn ghost_postorder<'n, I: Index>(post: &mut Array<'n, I>, etree: &Array<'n, MaybeIdx<'n, I>>, stack: &mut MemStack) {
3495 let N = post.len();
3496 let n = *N;
3497
3498 if n == 0 {
3499 return;
3500 }
3501
3502 let (stack_, stack) = unsafe { stack.make_raw::<I>(n) };
3503 let (first_child, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
3504 let (next_child, _) = unsafe { stack.make_raw::<I::Signed>(n) };
3505
3506 let stack = Array::from_mut(stack_, N);
3507 let next_child = Array::from_mut(next_child, N);
3508 let first_child = Array::from_mut(ghost::fill_none::<I>(first_child, N), N);
3509
3510 for j in N.indices().rev() {
3511 let parent = etree[j];
3512 if let Some(parent) = parent.idx() {
3513 let first = &mut first_child[parent.zx()];
3514 next_child[j] = **first;
3515 *first = MaybeIdx::from_index(j.truncate::<I>());
3516 }
3517 }
3518
3519 let mut start_index = 0usize;
3520 for (root, &parent) in etree.as_ref().iter().enumerate() {
3521 if parent.idx().is_none() {
3522 start_index = postorder_depth_first_search(post, root, start_index, stack, first_child, next_child);
3523 }
3524 }
3525}
3526
#[derive(Copy, Clone, Debug, Default)]
/// Tuning parameters for the symbolic Cholesky analysis.
pub struct CholeskySymbolicParams<'a> {
    /// Control parameters for the AMD ordering (presumably only used when an AMD ordering is requested).
    pub amd_params: amd::Control,
    /// Threshold that presumably decides between the simplicial and supernodal representations — confirm at the use site.
    pub supernodal_flop_ratio_threshold: SupernodalThreshold,
    /// Parameters forwarded to the supernodal symbolic analysis.
    pub supernodal_params: SymbolicSupernodalParams<'a>,
}
3537
#[derive(Debug)]
/// The concrete symbolic structure behind a [`SymbolicCholesky`].
///
/// The numeric routines dispatch on this. Simplicial kernels are fed the upper triangle of the
/// (permuted) matrix, and supernodal kernels the lower triangle.
pub enum SymbolicCholeskyRaw<I> {
    /// Simplicial (column-by-column) symbolic structure.
    Simplicial(simplicial::SymbolicSimplicialCholesky<I>),
    /// Supernodal (dense column-block) symbolic structure.
    Supernodal(supernodal::SymbolicSupernodalCholesky<I>),
}
3546
#[derive(Debug)]
/// Symbolic Cholesky factorization together with the optional symmetric permutation applied to the
/// input matrix before numeric factorization.
pub struct SymbolicCholesky<I> {
    /// Simplicial or supernodal symbolic structure.
    raw: SymbolicCholeskyRaw<I>,
    /// Forward permutation array; [`Self::perm`] exposes the permutation only when both arrays are present.
    perm_fwd: Option<alloc::vec::Vec<I>>,
    /// Inverse permutation array (see `perm_fwd`).
    perm_inv: Option<alloc::vec::Vec<I>>,
    /// Used to size the scratch copy of `A` (values and row indices) in the numeric routines;
    /// presumably the number of stored entries of the analysed matrix.
    A_nnz: usize,
}
3555
3556impl<I: Index> SymbolicCholesky<I> {
3557 #[inline]
3559 pub fn nrows(&self) -> usize {
3560 match &self.raw {
3561 SymbolicCholeskyRaw::Simplicial(this) => this.nrows(),
3562 SymbolicCholeskyRaw::Supernodal(this) => this.nrows(),
3563 }
3564 }
3565
3566 #[inline]
3568 pub fn ncols(&self) -> usize {
3569 self.nrows()
3570 }
3571
3572 #[inline]
3574 pub fn raw(&self) -> &SymbolicCholeskyRaw<I> {
3575 &self.raw
3576 }
3577
3578 #[inline]
3580 pub fn perm(&self) -> Option<PermRef<'_, I>> {
3581 match (&self.perm_fwd, &self.perm_inv) {
3582 (Some(perm_fwd), Some(perm_inv)) => unsafe { Some(PermRef::new_unchecked(perm_fwd, perm_inv, self.ncols())) },
3583 _ => None,
3584 }
3585 }
3586
3587 #[inline]
3590 pub fn len_val(&self) -> usize {
3591 match &self.raw {
3592 SymbolicCholeskyRaw::Simplicial(this) => this.len_val(),
3593 SymbolicCholeskyRaw::Supernodal(this) => this.len_val(),
3594 }
3595 }
3596
3597 #[inline]
3599 pub fn factorize_numeric_llt_scratch<T: ComplexField>(&self, par: Par, params: Spec<LltParams, T>) -> StackReq {
3600 let n = self.nrows();
3601 let A_nnz = self.A_nnz;
3602
3603 let n_scratch = StackReq::new::<I>(n);
3604 let A_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(n + 1), StackReq::new::<I>(A_nnz)]);
3605 let permute_scratch = n_scratch;
3606
3607 let factor_scratch = match &self.raw {
3608 SymbolicCholeskyRaw::Simplicial(_) => simplicial::factorize_simplicial_numeric_llt_scratch::<I, T>(n),
3609 SymbolicCholeskyRaw::Supernodal(this) => supernodal::factorize_supernodal_numeric_llt_scratch::<I, T>(this, par, params),
3610 };
3611
3612 StackReq::all_of(&[A_scratch, StackReq::or(permute_scratch, factor_scratch)])
3613 }
3614
3615 #[inline]
3617 pub fn factorize_numeric_ldlt_scratch<T: ComplexField>(&self, par: Par, params: Spec<LdltParams, T>) -> StackReq {
3618 let n = self.nrows();
3619 let A_nnz = self.A_nnz;
3620
3621 let regularization_signs = StackReq::new::<i8>(n);
3622
3623 let n_scratch = StackReq::new::<I>(n);
3624 let A_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(n + 1), StackReq::new::<I>(A_nnz)]);
3625 let permute_scratch = n_scratch;
3626
3627 let factor_scratch = match &self.raw {
3628 SymbolicCholeskyRaw::Simplicial(_) => simplicial::factorize_simplicial_numeric_ldlt_scratch::<I, T>(n),
3629 SymbolicCholeskyRaw::Supernodal(this) => supernodal::factorize_supernodal_numeric_ldlt_scratch::<I, T>(this, par, params),
3630 };
3631
3632 StackReq::all_of(&[regularization_signs, A_scratch, StackReq::or(permute_scratch, factor_scratch)])
3633 }
3634
3635 #[inline]
3638 pub fn factorize_numeric_intranode_lblt_scratch<T: ComplexField>(&self, par: Par, params: Spec<LbltParams, T>) -> StackReq {
3639 let n = self.nrows();
3640 let A_nnz = self.A_nnz;
3641
3642 let regularization_signs = StackReq::new::<i8>(n);
3643
3644 let n_scratch = StackReq::new::<I>(n);
3645 let A_scratch = StackReq::all_of(&[temp_mat_scratch::<T>(A_nnz, 1), StackReq::new::<I>(n + 1), StackReq::new::<I>(A_nnz)]);
3646 let permute_scratch = n_scratch;
3647
3648 let factor_scratch = match &self.raw {
3649 SymbolicCholeskyRaw::Simplicial(_) => simplicial::factorize_simplicial_numeric_ldlt_scratch::<I, T>(n),
3650 SymbolicCholeskyRaw::Supernodal(this) => supernodal::factorize_supernodal_numeric_intranode_lblt_scratch::<I, T>(this, par, params),
3651 };
3652
3653 StackReq::all_of(&[regularization_signs, A_scratch, StackReq::or(permute_scratch, factor_scratch)])
3654 }
3655
3656 #[track_caller]
3659 pub fn factorize_numeric_llt<'out, T: ComplexField>(
3660 &'out self,
3661 L_values: &'out mut [T],
3662 A: SparseColMatRef<'_, I, T>,
3663 side: Side,
3664 regularization: LltRegularization<T::Real>,
3665 par: Par,
3666 stack: &mut MemStack,
3667 params: Spec<LltParams, T>,
3668 ) -> Result<LltRef<'out, I, T>, LltError> {
3669 assert!(A.nrows() == A.ncols());
3670 let n = A.nrows();
3671 with_dim!(N, n);
3672
3673 let A_nnz = self.A_nnz;
3674 let A = A.as_shape(N, N);
3675
3676 let (mut new_values, stack) = unsafe { temp_mat_uninit::<T, _, _>(A_nnz, 1, stack) };
3677 let new_values = new_values.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
3678 let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(n + 1) };
3679 let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };
3680
3681 let out_side = match &self.raw {
3682 SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
3683 SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
3684 };
3685
3686 let A = match self.perm() {
3687 Some(perm) => {
3688 let perm = perm.as_shape(N);
3689 permute_self_adjoint_to_unsorted(new_values, new_col_ptr, new_row_idx, A, perm, side, out_side, stack).into_const()
3690 },
3691 None => {
3692 if side == out_side {
3693 A
3694 } else {
3695 adjoint(new_values, new_col_ptr, new_row_idx, A, stack).into_const()
3696 }
3697 },
3698 };
3699
3700 match &self.raw {
3701 SymbolicCholeskyRaw::Simplicial(this) => {
3702 simplicial::factorize_simplicial_numeric_llt(L_values, A.as_dyn().into_const(), regularization, this, stack)?;
3703 },
3704 SymbolicCholeskyRaw::Supernodal(this) => {
3705 supernodal::factorize_supernodal_numeric_llt(L_values, A.as_dyn().into_const(), regularization, this, par, stack, params)?;
3706 },
3707 }
3708 Ok(LltRef::<'out, I, T>::new(self, L_values))
3709 }
3710
    #[inline]
    /// Computes the numeric $LDL^H$ factorization of `A` into `L_values`, using this symbolic structure.
    ///
    /// `side` says which triangle of `A` holds the data. `A` is permuted or adjointed into scratch
    /// storage as needed, as in [`Self::factorize_numeric_llt`]. When both a permutation and
    /// `dynamic_regularization_signs` are present, the signs are reordered into scratch with the
    /// forward permutation (`new_signs[i] = signs[fwd[i]]`) so they line up with the permuted matrix.
    ///
    /// # Errors
    /// Propagates [`LdltError`] from the numeric kernel.
    ///
    /// # Panics
    /// Panics if `A` is not square.
    pub fn factorize_numeric_ldlt<'out, T: ComplexField>(
        &'out self,
        L_values: &'out mut [T],
        A: SparseColMatRef<'_, I, T>,
        side: Side,
        regularization: LdltRegularization<'_, T::Real>,
        par: Par,
        stack: &mut MemStack,
        params: Spec<LdltParams, T>,
    ) -> Result<LdltRef<'out, I, T>, LdltError> {
        assert!(A.nrows() == A.ncols());
        let n = A.nrows();

        with_dim!(N, n);
        let A_nnz = self.A_nnz;
        let A = A.as_shape(N, N);

        // buffer for the permuted signs; only needed (non-empty) when signs must actually be reordered
        let (new_signs, stack) = unsafe {
            stack.make_raw::<i8>(if regularization.dynamic_regularization_signs.is_some() && self.perm().is_some() {
                n
            } else {
                0
            })
        };

        // scratch storage for a rewritten (permuted and/or adjointed) copy of `A`
        let (mut new_values, stack) = unsafe { temp_mat_uninit::<T, _, _>(A_nnz, 1, stack) };
        let new_values = new_values.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
        let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(n + 1) };
        let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };

        // simplicial kernels consume the upper triangle, supernodal kernels the lower one
        let out_side = match &self.raw {
            SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
            SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
        };

        let (A, signs) = match self.perm() {
            Some(perm) => {
                let perm = perm.as_shape(N);
                let A = permute_self_adjoint_to_unsorted(new_values, new_col_ptr, new_row_idx, A, perm, side, out_side, stack).into_const();
                let fwd = perm.bound_arrays().0;
                // reorder the signs with the forward permutation so they follow the permuted matrix
                let signs = regularization.dynamic_regularization_signs.map(|signs| {
                    {
                        let new_signs = Array::from_mut(new_signs, N);
                        let signs = Array::from_ref(signs, N);
                        for i in N.indices() {
                            new_signs[i] = signs[fwd[i].zx()];
                        }
                    }
                    &*new_signs
                });

                (A, signs)
            },
            None => {
                if side == out_side {
                    (A, regularization.dynamic_regularization_signs)
                } else {
                    (
                        adjoint(new_values, new_col_ptr, new_row_idx, A, stack).into_const(),
                        regularization.dynamic_regularization_signs,
                    )
                }
            },
        };

        let regularization = LdltRegularization {
            dynamic_regularization_signs: signs,
            ..regularization
        };

        match &self.raw {
            SymbolicCholeskyRaw::Simplicial(this) => {
                simplicial::factorize_simplicial_numeric_ldlt(L_values, A.as_dyn().into_const(), regularization, this, stack)?;
            },
            SymbolicCholeskyRaw::Supernodal(this) => {
                supernodal::factorize_supernodal_numeric_ldlt(L_values, A.as_dyn().into_const(), regularization, this, par, stack, params)?;
            },
        }

        Ok(LdltRef::<'out, I, T>::new(self, L_values))
    }
3794
3795 #[inline]
3797 pub fn factorize_numeric_intranode_lblt<'out, T: ComplexField>(
3798 &'out self,
3799 L_values: &'out mut [T],
3800 subdiag: &'out mut [T],
3801 perm_forward: &'out mut [I],
3802 perm_inverse: &'out mut [I],
3803 A: SparseColMatRef<'_, I, T>,
3804 side: Side,
3805 par: Par,
3806 stack: &mut MemStack,
3807 params: Spec<LbltParams, T>,
3808 ) -> IntranodeLbltRef<'out, I, T> {
3809 assert!(A.nrows() == A.ncols());
3810 let n = A.nrows();
3811
3812 with_dim!(N, n);
3813 let A_nnz = self.A_nnz;
3814 let A = A.as_shape(N, N);
3815
3816 let (mut new_values, stack) = unsafe { temp_mat_uninit::<T, _, _>(A_nnz, 1, stack) };
3817 let new_values = new_values.as_mat_mut().col_mut(0).try_as_col_major_mut().unwrap().as_slice_mut();
3818 let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(n + 1) };
3819 let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };
3820
3821 let out_side = match &self.raw {
3822 SymbolicCholeskyRaw::Simplicial(_) => Side::Upper,
3823 SymbolicCholeskyRaw::Supernodal(_) => Side::Lower,
3824 };
3825
3826 let A = match self.perm() {
3827 Some(perm) => {
3828 let perm = perm.as_shape(N);
3829 let A = permute_self_adjoint_to_unsorted(new_values, new_col_ptr, new_row_idx, A, perm, side, out_side, stack).into_const();
3830
3831 A
3832 },
3833 None => {
3834 if side == out_side {
3835 A
3836 } else {
3837 adjoint(new_values, new_col_ptr, new_row_idx, A, stack).into_const()
3838 }
3839 },
3840 };
3841
3842 match &self.raw {
3843 SymbolicCholeskyRaw::Simplicial(this) => {
3844 let regularization = LdltRegularization::default();
3845 for (i, p) in perm_forward.iter_mut().enumerate() {
3846 *p = I::truncate(i);
3847 }
3848 for (i, p) in perm_inverse.iter_mut().enumerate() {
3849 *p = I::truncate(i);
3850 }
3851 let _ = simplicial::factorize_simplicial_numeric_ldlt(L_values, A.as_dyn().into_const(), regularization, this, stack);
3852 },
3853 SymbolicCholeskyRaw::Supernodal(this) => {
3854 supernodal::factorize_supernodal_numeric_intranode_lblt(
3855 L_values,
3856 subdiag,
3857 perm_forward,
3858 perm_inverse,
3859 A.as_dyn().into_const(),
3860 this,
3861 par,
3862 stack,
3863 params,
3864 );
3865 },
3866 }
3867
3868 IntranodeLbltRef::<'out, I, T>::new(self, L_values, subdiag, unsafe {
3869 PermRef::<'out, I>::new_unchecked(perm_forward, perm_inverse, n)
3870 })
3871 }
3872
3873 pub fn solve_in_place_scratch<T: ComplexField>(&self, rhs_ncols: usize, par: Par) -> StackReq {
3876 temp_mat_scratch::<T>(self.nrows(), rhs_ncols).and(match self.raw() {
3877 SymbolicCholeskyRaw::Simplicial(this) => this.solve_in_place_scratch::<T>(rhs_ncols),
3878 SymbolicCholeskyRaw::Supernodal(this) => this.solve_in_place_scratch::<T>(rhs_ncols, par),
3879 })
3880 }
3881}
3882
/// Numeric `LLᴴ` factorization: borrowed factor values paired with the symbolic structure
/// they were computed for.
#[derive(Debug)]
pub struct LltRef<'a, I: Index, T> {
 // Symbolic structure (ordering + simplicial/supernodal layout) of the factor.
 symbolic: &'a SymbolicCholesky<I>,
 // Factor values; `new` checks that the length equals `symbolic.len_val()`.
 values: &'a [T],
}

/// Numeric `LDLᴴ` factorization: borrowed factor values paired with the symbolic structure
/// they were computed for.
#[derive(Debug)]
pub struct LdltRef<'a, I: Index, T> {
 // Symbolic structure (ordering + simplicial/supernodal layout) of the factor.
 symbolic: &'a SymbolicCholesky<I>,
 // Factor values; `new` checks that the length equals `symbolic.len_val()`.
 values: &'a [T],
}

/// Numeric intranode Bunch–Kaufman (`LBLᴴ`) factorization: factor values, the subdiagonal of
/// the block-diagonal factor, and the pivoting permutation chosen during factorization.
#[derive(Debug)]
pub struct IntranodeLbltRef<'a, I: Index, T> {
 // Symbolic structure (ordering + simplicial/supernodal layout) of the factor.
 symbolic: &'a SymbolicCholesky<I>,
 // Factor values; `new` checks that the length equals `symbolic.len_val()`.
 values: &'a [T],
 // Subdiagonal of the block-diagonal factor; length `symbolic.nrows()` (checked in `new`).
 subdiag: &'a [T],
 // Pivoting permutation applied on top of the symbolic ordering; length `symbolic.nrows()`.
 perm: PermRef<'a, I>,
}
3905
3906impl<'a, I: Index, T> core::ops::Deref for LltRef<'a, I, T> {
3907 type Target = SymbolicCholesky<I>;
3908
3909 #[inline]
3910 fn deref(&self) -> &Self::Target {
3911 self.symbolic
3912 }
3913}
3914impl<'a, I: Index, T> core::ops::Deref for LdltRef<'a, I, T> {
3915 type Target = SymbolicCholesky<I>;
3916
3917 #[inline]
3918 fn deref(&self) -> &Self::Target {
3919 self.symbolic
3920 }
3921}
3922impl<'a, I: Index, T> core::ops::Deref for IntranodeLbltRef<'a, I, T> {
3923 type Target = SymbolicCholesky<I>;
3924
3925 #[inline]
3926 fn deref(&self) -> &Self::Target {
3927 self.symbolic
3928 }
3929}
3930
3931impl<'a, I: Index, T> Copy for LltRef<'a, I, T> {}
3932impl<'a, I: Index, T> Copy for LdltRef<'a, I, T> {}
3933impl<'a, I: Index, T> Copy for IntranodeLbltRef<'a, I, T> {}
3934
3935impl<'a, I: Index, T> Clone for LltRef<'a, I, T> {
3936 fn clone(&self) -> Self {
3937 *self
3938 }
3939}
3940impl<'a, I: Index, T> Clone for LdltRef<'a, I, T> {
3941 fn clone(&self) -> Self {
3942 *self
3943 }
3944}
3945impl<'a, I: Index, T> Clone for IntranodeLbltRef<'a, I, T> {
3946 fn clone(&self) -> Self {
3947 *self
3948 }
3949}
3950
3951impl<'a, I: Index, T> IntranodeLbltRef<'a, I, T> {
3952 #[inline]
3960 pub fn new(symbolic: &'a SymbolicCholesky<I>, values: &'a [T], subdiag: &'a [T], perm: PermRef<'a, I>) -> Self {
3961 assert!(all(
3962 values.len() == symbolic.len_val(),
3963 subdiag.len() == symbolic.nrows(),
3964 perm.len() == symbolic.nrows(),
3965 ));
3966 Self {
3967 symbolic,
3968 values,
3969 subdiag,
3970 perm,
3971 }
3972 }
3973
3974 #[inline]
3976 pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
3977 self.symbolic
3978 }
3979
3980 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
3986 where
3987 T: ComplexField,
3988 {
3989 let k = rhs.ncols();
3990 let n = self.symbolic.nrows();
3991
3992 let mut rhs = rhs;
3993
3994 let (mut x, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
3995 let mut x = x.as_mat_mut();
3996
3997 match self.symbolic.raw() {
3998 SymbolicCholeskyRaw::Simplicial(symbolic) => {
3999 let this = simplicial::SimplicialLdltRef::new(symbolic, self.values);
4000
4001 if let Some(perm) = self.symbolic.perm() {
4002 for j in 0..k {
4003 for (i, fwd) in perm.arrays().0.iter().enumerate() {
4004 x[(i, j)] = copy(&rhs[(fwd.zx(), j)]);
4005 }
4006 }
4007 }
4008 this.solve_in_place_with_conj(conj, if self.symbolic.perm().is_some() { x.rb_mut() } else { rhs.rb_mut() }, par, stack);
4009 if let Some(perm) = self.symbolic.perm() {
4010 for j in 0..k {
4011 for (i, inv) in perm.arrays().1.iter().enumerate() {
4012 rhs[(i, j)] = copy(&x[(inv.zx(), j)]);
4013 }
4014 }
4015 }
4016 },
4017 SymbolicCholeskyRaw::Supernodal(symbolic) => {
4018 let (dyn_fwd, dyn_inv) = self.perm.arrays();
4019 let (fwd, inv) = match self.symbolic.perm() {
4020 Some(perm) => {
4021 let (fwd, inv) = perm.arrays();
4022 (Some(fwd), Some(inv))
4023 },
4024 None => (None, None),
4025 };
4026
4027 if let Some(fwd) = fwd {
4028 for j in 0..k {
4029 for (i, dyn_fwd) in dyn_fwd.iter().enumerate() {
4030 x[(i, j)] = copy(&rhs[(fwd[dyn_fwd.zx()].zx(), j)]);
4031 }
4032 }
4033 } else {
4034 for j in 0..k {
4035 for (i, dyn_fwd) in dyn_fwd.iter().enumerate() {
4036 x[(i, j)] = copy(&rhs[(dyn_fwd.zx(), j)]);
4037 }
4038 }
4039 }
4040
4041 let this = supernodal::SupernodalIntranodeLbltRef::new(symbolic, self.values, self.subdiag, self.perm);
4042 this.solve_in_place_no_numeric_permute_with_conj(conj, x.rb_mut(), par, stack);
4043
4044 if let Some(inv) = inv {
4045 for j in 0..k {
4046 for (i, inv) in inv.iter().enumerate() {
4047 rhs[(i, j)] = copy(&x[(dyn_inv[inv.zx()].zx(), j)]);
4048 }
4049 }
4050 } else {
4051 for j in 0..k {
4052 for (i, dyn_inv) in dyn_inv.iter().enumerate() {
4053 rhs[(i, j)] = copy(&x[(dyn_inv.zx(), j)]);
4054 }
4055 }
4056 }
4057 },
4058 }
4059 }
4060}
4061
4062impl<'a, I: Index, T> LltRef<'a, I, T> {
4063 #[inline]
4069 pub fn new(symbolic: &'a SymbolicCholesky<I>, values: &'a [T]) -> Self {
4070 assert!(symbolic.len_val() == values.len());
4071 Self { symbolic, values }
4072 }
4073
4074 #[inline]
4076 pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
4077 self.symbolic
4078 }
4079
4080 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
4086 where
4087 T: ComplexField,
4088 {
4089 let k = rhs.ncols();
4090 let n = self.symbolic.nrows();
4091
4092 let mut rhs = rhs;
4093
4094 let (mut x, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
4095 let mut x = x.as_mat_mut();
4096
4097 if let Some(perm) = self.symbolic.perm() {
4098 for j in 0..k {
4099 for (i, fwd) in perm.arrays().0.iter().enumerate() {
4100 x[(i, j)] = copy(&rhs[(fwd.zx(), j)]);
4101 }
4102 }
4103 }
4104
4105 match self.symbolic.raw() {
4106 SymbolicCholeskyRaw::Simplicial(symbolic) => {
4107 let this = simplicial::SimplicialLltRef::new(symbolic, self.values);
4108 this.solve_in_place_with_conj(conj, if self.symbolic.perm().is_some() { x.rb_mut() } else { rhs.rb_mut() }, par, stack);
4109 },
4110 SymbolicCholeskyRaw::Supernodal(symbolic) => {
4111 let this = supernodal::SupernodalLltRef::new(symbolic, self.values);
4112 this.solve_in_place_with_conj(conj, if self.symbolic.perm().is_some() { x.rb_mut() } else { rhs.rb_mut() }, par, stack);
4113 },
4114 }
4115
4116 if let Some(perm) = self.symbolic.perm() {
4117 for j in 0..k {
4118 for (i, inv) in perm.arrays().1.iter().enumerate() {
4119 rhs[(i, j)] = copy(&x[(inv.zx(), j)]);
4120 }
4121 }
4122 }
4123 }
4124}
4125
4126impl<'a, I: Index, T> LdltRef<'a, I, T> {
4127 #[inline]
4133 pub fn new(symbolic: &'a SymbolicCholesky<I>, values: &'a [T]) -> Self {
4134 assert!(symbolic.len_val() == values.len());
4135 Self { symbolic, values }
4136 }
4137
4138 #[inline]
4140 pub fn symbolic(self) -> &'a SymbolicCholesky<I> {
4141 self.symbolic
4142 }
4143
4144 pub fn solve_in_place_with_conj(&self, conj: Conj, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack)
4150 where
4151 T: ComplexField,
4152 {
4153 let k = rhs.ncols();
4154 let n = self.symbolic.nrows();
4155
4156 let mut rhs = rhs;
4157
4158 let (mut x, stack) = unsafe { temp_mat_uninit::<T, _, _>(n, k, stack) };
4159 let mut x = x.as_mat_mut();
4160
4161 if let Some(perm) = self.symbolic.perm() {
4162 for j in 0..k {
4163 for (i, fwd) in perm.arrays().0.iter().enumerate() {
4164 x[(i, j)] = copy(&rhs[(fwd.zx(), j)]);
4165 }
4166 }
4167 }
4168
4169 match self.symbolic.raw() {
4170 SymbolicCholeskyRaw::Simplicial(symbolic) => {
4171 let this = simplicial::SimplicialLdltRef::new(symbolic, self.values);
4172 this.solve_in_place_with_conj(conj, if self.symbolic.perm().is_some() { x.rb_mut() } else { rhs.rb_mut() }, par, stack);
4173 },
4174 SymbolicCholeskyRaw::Supernodal(symbolic) => {
4175 let this = supernodal::SupernodalLdltRef::new(symbolic, self.values);
4176 this.solve_in_place_with_conj(conj, if self.symbolic.perm().is_some() { x.rb_mut() } else { rhs.rb_mut() }, par, stack);
4177 },
4178 }
4179
4180 if let Some(perm) = self.symbolic.perm() {
4181 for j in 0..k {
4182 for (i, inv) in perm.arrays().1.iter().enumerate() {
4183 rhs[(i, j)] = copy(&x[(inv.zx(), j)]);
4184 }
4185 }
4186 }
4187 }
4188}
4189
/// Computes the symbolic Cholesky factorization of the self-adjoint sparsity pattern `A`.
///
/// Steps:
/// 1. Computes a fill-reducing ordering according to `ord` (AMD, identity, or a caller-supplied
///    permutation).
/// 2. Permutes the triangle of `A` selected by `side` into an upper-triangular pattern.
/// 3. Computes the elimination tree and the column counts of the factor.
/// 4. Picks a supernodal layout when `flops / nnz(L)` exceeds
///    `params.supernodal_flop_ratio_threshold.0 * CHOLESKY_SUPERNODAL_RATIO_FACTOR`, and a
///    simplicial one otherwise.
///
/// # Errors
/// - `FaerError::OutOfMemory` if the workspace cannot be allocated.
/// - `FaerError::IndexOverflow` if the number of nonzeros of the factor overflows `I`.
/// - Errors from permutation allocation, AMD, or the symbolic factorization are propagated.
///
/// # Panics
/// Panics if `A` is not square, or (for `SymmetricOrdering::Custom`) if the permutation length
/// differs from `A.nrows()` (via `copy_from_slice`).
pub fn factorize_symbolic_cholesky<I: Index>(
 A: SymbolicSparseColMatRef<'_, I>,
 side: Side,
 ord: SymmetricOrdering<'_, I>,
 params: CholeskySymbolicParams<'_>,
) -> Result<SymbolicCholesky<I>, FaerError> {
 let n = A.nrows();
 let A_nnz = A.compute_nnz();

 assert!(A.nrows() == A.ncols());

 with_dim!(N, n);
 let A = A.as_shape(N, N);

 // Workspace: AMD runs first on the whole buffer; afterwards the same memory is reused for
 // the permuted pattern, the elimination tree, the column counts, and the symbolic step.
 let req = {
 let n_scratch = StackReq::new::<I>(n);
 let A_scratch = StackReq::and(
 // new_col_ptr
 StackReq::new::<I>(n + 1),
 // new_row_idx
 StackReq::new::<I>(A_nnz),
 );

 StackReq::or(
 match ord {
 SymmetricOrdering::Amd => amd::order_maybe_unsorted_scratch::<I>(n, A_nnz),
 _ => StackReq::empty(),
 },
 StackReq::all_of(&[
 A_scratch,
 // etree (stored as `I::Signed`, assumed to have the same size as `I`)
 n_scratch,
 // col_counts
 n_scratch,
 // scratch for the permutation / `prefactorize_symbolic_cholesky` (n indices)
 n_scratch,
 StackReq::or(
 supernodal::factorize_supernodal_symbolic_cholesky_scratch::<I>(n),
 simplicial::factorize_simplicial_symbolic_cholesky_scratch::<I>(n),
 ),
 ]),
 )
 };

 let mut mem = dyn_stack::MemBuffer::try_new(req).ok().ok_or(FaerError::OutOfMemory)?;
 let stack = MemStack::new(&mut mem);

 // The identity ordering stores no permutation at all.
 let mut perm_fwd = match ord {
 SymmetricOrdering::Identity => None,
 _ => Some(try_zeroed(n)?),
 };
 let mut perm_inv = match ord {
 SymmetricOrdering::Identity => None,
 _ => Some(try_zeroed(n)?),
 };
 // Only AMD reports a flop count; other orderings get an estimate further below.
 let flops = match ord {
 SymmetricOrdering::Amd => Some(amd::order_maybe_unsorted(
 perm_fwd.as_mut().unwrap(),
 perm_inv.as_mut().unwrap(),
 A.as_dyn(),
 params.amd_params,
 stack,
 )?),
 SymmetricOrdering::Identity => None,
 SymmetricOrdering::Custom(perm) => {
 let (fwd, inv) = perm.arrays();
 perm_fwd.as_mut().unwrap().copy_from_slice(fwd);
 perm_inv.as_mut().unwrap().copy_from_slice(inv);
 None
 },
 };

 // Permute `A` into an upper-triangular pattern (the identity ordering uses `A` as-is).
 let (new_col_ptr, stack) = unsafe { stack.make_raw::<I>(n + 1) };
 let (new_row_idx, stack) = unsafe { stack.make_raw::<I>(A_nnz) };
 let A = match ord {
 SymmetricOrdering::Identity => A,
 _ => permute_self_adjoint_to_unsorted(
 Symbolic::materialize(A_nnz),
 new_col_ptr,
 new_row_idx,
 SparseColMatRef::new(A, Symbolic::materialize(A.row_idx().len())),
 PermRef::new_checked(perm_fwd.as_ref().unwrap(), perm_inv.as_ref().unwrap(), n).as_shape(N),
 side,
 Side::Upper,
 stack,
 )
 .symbolic(),
 };

 // Elimination tree and per-column nonzero counts of the factor.
 let (etree, stack) = unsafe { stack.make_raw::<I::Signed>(n) };
 let (col_counts, stack) = unsafe { stack.make_raw::<I>(n) };
 let etree = simplicial::prefactorize_symbolic_cholesky::<I>(etree, col_counts, A.as_shape(n, n), stack);
 let L_nnz = I::sum_nonnegative(col_counts.as_ref()).ok_or(FaerError::IndexOverflow)?;

 let col_counts = Array::from_mut(col_counts, N);
 // Fallback flop estimate for orderings that do not report one.
 // NOTE(review): the column counts appear to include the diagonal (see `test_counts`, where the
 // last column has count 1), so this estimate is presumably not on exactly the same scale as
 // AMD's `FlopCount`; confirm before comparing the heuristic across orderings.
 let flops = match flops {
 Some(flops) => flops,
 None => {
 let mut n_div = 0u128;
 let mut n_mult_subs_ldl = 0u128;
 for i in N.indices() {
 let c = col_counts[i].zx();
 n_div += c as u128;
 n_mult_subs_ldl += (c as u128 * (c as u128 + 1)) / 2;
 }
 amd::FlopCount {
 n_div: n_div as f64,
 n_mult_subs_ldl: n_mult_subs_ldl as f64,
 n_mult_subs_lu: 0.0,
 }
 },
 };

 // Dense-enough factors (high flops per stored nonzero) go supernodal. When `L_nnz` is 0 the
 // ratio is NaN, the comparison is false, and the simplicial layout is chosen.
 let flops = flops.n_div + flops.n_mult_subs_ldl;
 let raw = if (flops / L_nnz.zx() as f64) > params.supernodal_flop_ratio_threshold.0 * crate::sparse::linalg::CHOLESKY_SUPERNODAL_RATIO_FACTOR {
 SymbolicCholeskyRaw::Supernodal(supernodal::ghost_factorize_supernodal_symbolic(
 A,
 None,
 None,
 supernodal::CholeskyInput::A,
 etree.as_bound(N),
 col_counts,
 stack,
 params.supernodal_params,
 )?)
 } else {
 SymbolicCholeskyRaw::Simplicial(simplicial::ghost_factorize_simplicial_symbolic_cholesky(
 A,
 etree.as_bound(N),
 col_counts,
 stack,
 )?)
 };

 Ok(SymbolicCholesky {
 raw,
 perm_fwd,
 perm_inv,
 A_nnz,
 })
}
4334
4335#[cfg(test)]
4336pub(super) mod tests {
4337 use super::*;
4338 use crate::assert;
4339 use crate::stats::prelude::*;
4340 use crate::utils::approx::*;
4341 use dyn_stack::MemBuffer;
4342 use matrix_market_rs::MtxData;
4343 use std::path::PathBuf;
4344 use std::str::FromStr;
4345
 /// Boxed, type-erased error used by the fallible tests in this module.
 type Error = Box<dyn std::error::Error>;
 /// Test result alias: a bare `Result` means `Result<(), Error>`.
 type Result<T = (), E = Error> = core::result::Result<T, E>;
4348
4349 pub(crate) fn load_mtx<I: Index>(data: MtxData<f64>) -> (usize, usize, Vec<I>, Vec<I>, Vec<f64>) {
4350 let I = I::truncate;
4351
4352 let MtxData::Sparse([nrows, ncols], coo_indices, coo_values, _) = data else {
4353 panic!()
4354 };
4355
4356 let m = nrows;
4357 let n = ncols;
4358 let mut col_counts = vec![I(0); n];
4359 let mut col_ptr = vec![I(0); n + 1];
4360
4361 for &[_, j] in &coo_indices {
4362 col_counts[j] += I(1);
4363 }
4364
4365 for i in 0..n {
4366 col_ptr[i + 1] = col_ptr[i] + col_counts[i];
4367 }
4368 let nnz = col_ptr[n].zx();
4369
4370 let mut row_idx = vec![I(0); nnz];
4371 let mut values = vec![0.0; nnz];
4372
4373 col_counts.copy_from_slice(&col_ptr[..n]);
4374
4375 for (&[i, j], &val) in iter::zip(&coo_indices, &coo_values) {
4376 values[col_counts[j].zx()] = val;
4377 row_idx[col_counts[j].zx()] = I(i);
4378 col_counts[j] += I(1);
4379 }
4380
4381 (m, n, col_ptr, row_idx, values)
4382 }
4383
4384 #[track_caller]
4385 pub(crate) fn parse_vec<F: FromStr>(text: &str) -> (Vec<F>, &str) {
4386 let mut text = text;
4387 let mut out = Vec::new();
4388
4389 assert!(text.trim().starts_with('['));
4390 text = &text.trim()[1..];
4391 while !text.trim().starts_with(']') {
4392 let i = text.find(',').unwrap();
4393 let num = &text[..i];
4394
4395 let num = num.trim().parse::<F>().ok().unwrap();
4396 out.push(num);
4397 text = &text[i + 1..];
4398 }
4399
4400 assert!(text.trim().starts_with("],"));
4401 text = &text.trim()[2..];
4402
4403 (out, text)
4404 }
4405
4406 pub(crate) fn parse_csc_symbolic(text: &str) -> (SymbolicSparseColMat<usize>, &str) {
4407 let (col_ptr, text) = parse_vec::<usize>(text);
4408 let (row_idx, text) = parse_vec::<usize>(text);
4409 let n = col_ptr.len() - 1;
4410
4411 (SymbolicSparseColMat::new_unsorted_checked(n, n, col_ptr, None, row_idx), text)
4412 }
4413
4414 pub(crate) fn parse_csc<T: FromStr>(text: &str) -> (SparseColMat<usize, T>, &str) {
4415 let (symbolic, text) = parse_csc_symbolic(text);
4416 let (numeric, text) = parse_vec::<T>(text);
4417 (SparseColMat::new(symbolic, numeric), text)
4418 }
4419
4420 #[test]
4421 fn test_counts() {
4422 let n = 11;
4423 let col_ptr = &[0, 3, 6, 10, 13, 16, 21, 24, 29, 31, 37, 43usize];
4424 let row_idx = &[
4425 0, 5, 6, 1, 2, 7, 1, 2, 9, 10, 3, 5, 9, 4, 7, 10, 0, 3, 5, 8, 9, 0, 6, 10, 1, 4, 7, 9, 10, 5, 8, 2, 3, 5, 7, 9, 10, 2, 4, 6, 7, 9, 10usize, ];
4437
4438 let A = SymbolicSparseColMatRef::new_unsorted_checked(n, n, col_ptr, None, row_idx);
4439 let mut etree = vec![0isize; n];
4440 let mut col_count = vec![0usize; n];
4441
4442 simplicial::prefactorize_symbolic_cholesky(
4443 &mut etree,
4444 &mut col_count,
4445 A,
4446 MemStack::new(&mut MemBuffer::new(StackReq::new::<usize>(n))),
4447 );
4448
4449 assert!(etree == [5, 2, 7, 5, 7, 6, 8, 9, 9, 10, NONE as isize]);
4450 assert!(col_count == [3, 3, 4, 3, 3, 4, 4, 3, 3, 2, 1usize]);
4451 }
4452
4453 #[test]
4454 fn test_amd() -> Result {
4455 for file in [
4456 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/small.txt"),
4457 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-0.txt"),
4458 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt"),
4459 ] {
4460 let (A, _) = parse_csc_symbolic(&std::fs::read_to_string(&file)?);
4461 let n = A.nrows();
4462
4463 let (target_fwd, target_bwd, _) = ::amd::order(A.nrows(), A.col_ptr(), A.row_idx(), &::amd::Control::default()).unwrap();
4464
4465 let fwd = &mut *vec![0usize; n];
4466 let bwd = &mut *vec![0usize; n];
4467 amd::order_maybe_unsorted(
4468 fwd,
4469 bwd,
4470 A.rb(),
4471 amd::Control::default(),
4472 MemStack::new(&mut MemBuffer::new(amd::order_maybe_unsorted_scratch::<usize>(n, A.compute_nnz()))),
4473 )?;
4474
4475 assert!(fwd == &target_fwd);
4476 assert!(bwd == &target_bwd);
4477 }
4478 Ok(())
4479 }
4480
4481 fn reconstruct_from_supernodal_ldlt<I: Index, T: ComplexField>(symbolic: &supernodal::SymbolicSupernodalCholesky<I>, L_values: &[T]) -> Mat<T> {
4482 let ldlt = supernodal::SupernodalLdltRef::new(symbolic, L_values);
4483 let n_supernodes = ldlt.symbolic().n_supernodes();
4484 let n = ldlt.symbolic().nrows();
4485
4486 let mut dense = Mat::<T>::zeros(n, n);
4487
4488 for s in 0..n_supernodes {
4489 let s = ldlt.supernode(s);
4490 let node = s.val();
4491 let size = node.ncols();
4492
4493 let (Ls_top, Ls_bot) = node.split_at_row(size);
4494 dense
4495 .rb_mut()
4496 .submatrix_mut(s.start(), s.start(), size, size)
4497 .copy_from_triangular_lower(Ls_top);
4498
4499 for col in 0..size {
4500 for (i, &row) in s.pattern().iter().enumerate() {
4501 dense[(row.zx(), s.start() + col)] = Ls_bot[(i, col)].clone();
4502 }
4503 }
4504 }
4505 let mut D = Col::<T>::zeros(n);
4506 D.copy_from(dense.rb().diagonal().column_vector());
4507 dense.rb_mut().diagonal_mut().fill(one::<T>());
4508
4509 &dense * D.as_diagonal() * dense.adjoint()
4510 }
4511
4512 pub(crate) fn reconstruct_from_supernodal_llt<I: Index, T: ComplexField>(
4513 symbolic: &supernodal::SymbolicSupernodalCholesky<I>,
4514 L_values: &[T],
4515 ) -> Mat<T> {
4516 let llt = supernodal::SupernodalLltRef::new(symbolic, L_values);
4517 let n_supernodes = llt.symbolic().n_supernodes();
4518 let n = llt.symbolic().nrows();
4519
4520 let mut dense = Mat::<T>::zeros(n, n);
4521
4522 for s in 0..n_supernodes {
4523 let s = llt.supernode(s);
4524 let node = s.val();
4525 let size = node.ncols();
4526
4527 let (Ls_top, Ls_bot) = node.split_at_row(size);
4528 dense
4529 .rb_mut()
4530 .submatrix_mut(s.start(), s.start(), size, size)
4531 .copy_from_triangular_lower(Ls_top);
4532
4533 for col in 0..size {
4534 for (i, &row) in s.pattern().iter().enumerate() {
4535 dense[(row.zx(), s.start() + col)] = Ls_bot[(i, col)].clone();
4536 }
4537 }
4538 }
4539 &dense * dense.adjoint()
4540 }
4541 fn reconstruct_from_simplicial_ldlt<I: Index, T: ComplexField>(symbolic: &simplicial::SymbolicSimplicialCholesky<I>, L_values: &[T]) -> Mat<T> {
4542 let n = symbolic.nrows();
4543
4544 let mut dense = SparseColMatRef::new(symbolic.factor(), L_values).to_dense();
4545 let mut D = Col::<T>::zeros(n);
4546 D.copy_from(dense.rb().diagonal().column_vector());
4547 dense.rb_mut().diagonal_mut().fill(one::<T>());
4548
4549 &dense * D.as_diagonal() * dense.adjoint()
4550 }
4551
4552 fn reconstruct_from_simplicial_llt<I: Index, T: ComplexField>(symbolic: &simplicial::SymbolicSimplicialCholesky<I>, L_values: &[T]) -> Mat<T> {
4553 let dense = SparseColMatRef::new(symbolic.factor(), L_values).to_dense();
4554 &dense * dense.adjoint()
4555 }
4556
 /// End-to-end check of the supernodal kernels on `test_data/sparse_cholesky/medium-1.txt`.
 ///
 /// Builds the supernodal symbolic structure from the stored upper triangle, then checks:
 /// 1. `LDLᴴ`: dense reconstruction and solves with and without conjugation;
 /// 2. intranode `LBLᴴ`: solves, with the pivoting permutation applied by hand;
 /// 3. `LLᴴ` (after scaling `A_lower`): dense reconstruction and solves.
 #[test]
 fn test_supernodal() -> Result {
 let file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt");
 let A_upper = parse_csc::<c64>(&std::fs::read_to_string(&file)?).0;
 // The supernodal numeric kernels below are fed the lower triangle (adjoint of the file's upper one).
 let mut A_lower = A_upper.adjoint().to_col_major()?;
 let A_upper = A_upper.rb();

 let n = A_upper.nrows();
 let etree = &mut *vec![0isize; n];
 let col_counts = &mut *vec![0usize; n];
 let etree = simplicial::prefactorize_symbolic_cholesky(
 etree,
 col_counts,
 A_upper.symbolic(),
 MemStack::new(&mut MemBuffer::new(simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(
 n,
 A_upper.compute_nnz(),
 ))),
 );

 let symbolic = &supernodal::factorize_supernodal_symbolic_cholesky(
 A_upper.symbolic(),
 etree,
 col_counts,
 MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_symbolic_cholesky_scratch::<usize>(
 n,
 ))),
 Default::default(),
 )?;

 // Section 1: LDLᴴ factorization, reconstruction, and solves.
 {
 let A_lower = A_lower.rb();
 let approx_eq = CwiseMat(ApproxEq::eps() * 1e5);
 let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
 supernodal::factorize_supernodal_numeric_ldlt(
 L_val,
 A_lower,
 Default::default(),
 symbolic,
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_numeric_ldlt_scratch::<usize, c64>(
 symbolic,
 Par::Seq,
 Default::default(),
 ))),
 Default::default(),
 )?;

 // Full dense Hermitian matrix built from the lower triangle.
 let mut target = A_lower.to_dense();
 let adjoint = target.adjoint().to_owned();
 target.copy_from_strict_triangular_upper(adjoint);
 let A = reconstruct_from_supernodal_ldlt(symbolic, L_val);

 assert!(A ~ target);

 let k = 3;
 let rng = &mut StdRng::seed_from_u64(0);

 let rhs = CwiseMatDistribution {
 nrows: n,
 ncols: k,
 dist: ComplexDistribution::new(StandardNormal, StandardNormal),
 }
 .rand::<Mat<c64>>(rng);

 let supernodal = supernodal::SupernodalLdltRef::new(symbolic, L_val);
 for conj in [Conj::No, Conj::Yes] {
 let mut x = rhs.clone();
 supernodal.solve_in_place_with_conj(
 conj,
 x.rb_mut(),
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<c64>(k, Par::Seq))),
 );

 // Multiplying the solution back must recover the right-hand side.
 let target = rhs.rb();
 let rhs = match conj {
 Conj::No => &A * &x,
 Conj::Yes => A.conjugate() * &x,
 };

 assert!(rhs ~ target);
 }
 }

 // Section 2: intranode LBLᴴ factorization and solves.
 {
 let A_lower = A_lower.rb();
 let approx_eq = CwiseMat(ApproxEq::eps() * 1e2);
 let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
 let fwd = &mut *vec![0usize; n];
 let bwd = &mut *vec![0usize; n];
 let subdiag = &mut *vec![zero::<c64>(); n];

 supernodal::factorize_supernodal_numeric_intranode_lblt(
 L_val,
 subdiag,
 fwd,
 bwd,
 A_lower,
 symbolic,
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_numeric_intranode_lblt_scratch::<
 usize,
 c64,
 >(symbolic, Par::Seq, Default::default()))),
 Default::default(),
 );

 let mut A = A_lower.to_dense();
 let adjoint = A.adjoint().to_owned();
 A.copy_from_strict_triangular_upper(adjoint);

 let k = 3;
 let rng = &mut StdRng::seed_from_u64(0);

 let rhs = CwiseMatDistribution {
 nrows: n,
 ncols: k,
 dist: ComplexDistribution::new(StandardNormal, StandardNormal),
 }
 .rand::<Mat<c64>>(rng);

 let supernodal = supernodal::SupernodalIntranodeLbltRef::new(symbolic, L_val, subdiag, PermRef::new_checked(fwd, bwd, n));
 for conj in [Conj::No, Conj::Yes] {
 let mut x = rhs.clone();
 let mut tmp = x.clone();

 // The "no numeric permute" solve expects the pivoting permutation applied by the caller.
 for j in 0..k {
 for (i, &fwd) in fwd.iter().enumerate() {
 tmp[(i, j)] = x[(fwd, j)];
 }
 }

 supernodal.solve_in_place_no_numeric_permute_with_conj(
 conj,
 tmp.rb_mut(),
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<c64>(k, Par::Seq))),
 );

 for j in 0..k {
 for (i, &bwd) in bwd.iter().enumerate() {
 x[(i, j)] = tmp[(bwd, j)];
 }
 }

 let target = rhs.rb();
 let rhs = match conj {
 Conj::No => &A * &x,
 Conj::Yes => A.conjugate() * &x,
 };

 assert!(rhs ~ target);
 }
 }

 // Section 3: LLᴴ factorization, reconstruction, and solves.
 {
 // Scale the first stored entry of every column by 1e3 — presumably the diagonal entry,
 // so that the matrix is safely positive definite for LLᴴ.
 // NOTE(review): relies on the diagonal being stored first in each column of `A_lower`; confirm.
 for j in 0..n {
 *A_lower.val_of_col_mut(j).first_mut().unwrap() *= 1e3;
 }
 let A_lower = A_lower.rb();

 let approx_eq = CwiseMat(ApproxEq::eps() * 1e5);
 let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
 supernodal::factorize_supernodal_numeric_llt(
 L_val,
 A_lower,
 Default::default(),
 symbolic,
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(supernodal::factorize_supernodal_numeric_llt_scratch::<usize, c64>(
 symbolic,
 Par::Seq,
 Default::default(),
 ))),
 Default::default(),
 )?;

 let mut target = A_lower.to_dense();
 let adjoint = target.adjoint().to_owned();
 target.copy_from_strict_triangular_upper(adjoint);
 let A = reconstruct_from_supernodal_llt(symbolic, L_val);

 assert!(A ~ target);

 let k = 3;
 let rng = &mut StdRng::seed_from_u64(0);

 let rhs = CwiseMatDistribution {
 nrows: n,
 ncols: k,
 dist: ComplexDistribution::new(StandardNormal, StandardNormal),
 }
 .rand::<Mat<c64>>(rng);

 let supernodal = supernodal::SupernodalLltRef::new(symbolic, L_val);
 for conj in [Conj::No, Conj::Yes] {
 let mut x = rhs.clone();
 supernodal.solve_in_place_with_conj(
 conj,
 x.rb_mut(),
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<c64>(k, Par::Seq))),
 );

 let target = rhs.rb();
 let rhs = match conj {
 Conj::No => &A * &x,
 Conj::Yes => A.conjugate() * &x,
 };

 assert!(rhs ~ target);
 }
 }
 Ok(())
 }
4773
 /// End-to-end check of the simplicial kernels on `test_data/sparse_cholesky/medium-1.txt`.
 ///
 /// Builds the simplicial symbolic structure from the stored upper triangle, then checks
 /// `LDLᴴ` and (after scaling the diagonal by 1e3) `LLᴴ`, each by dense reconstruction and by
 /// solving with and without conjugation.
 #[test]
 fn test_simplicial() -> Result {
 let file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt");
 let mut A_upper = parse_csc::<c64>(&std::fs::read_to_string(&file)?).0;

 let n = A_upper.nrows();
 let etree = &mut *vec![0isize; n];
 let col_counts = &mut *vec![0usize; n];
 let etree = simplicial::prefactorize_symbolic_cholesky(
 etree,
 col_counts,
 A_upper.symbolic(),
 MemStack::new(&mut MemBuffer::new(simplicial::prefactorize_symbolic_cholesky_scratch::<usize>(
 n,
 A_upper.compute_nnz(),
 ))),
 );

 let symbolic = &simplicial::factorize_simplicial_symbolic_cholesky(
 A_upper.symbolic(),
 etree,
 col_counts,
 MemStack::new(&mut MemBuffer::new(simplicial::factorize_simplicial_symbolic_cholesky_scratch::<usize>(
 n,
 ))),
 )?;

 // Section 1: LDLᴴ factorization, reconstruction, and solves.
 {
 let approx_eq = CwiseMat(ApproxEq::eps() * 1e5);
 let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
 let A_upper = A_upper.rb();
 simplicial::factorize_simplicial_numeric_ldlt(
 L_val,
 A_upper,
 Default::default(),
 symbolic,
 MemStack::new(&mut MemBuffer::new(simplicial::factorize_simplicial_numeric_ldlt_scratch::<usize, c64>(
 n,
 ))),
 )?;

 // Full dense Hermitian matrix built from the upper triangle.
 let mut target = A_upper.to_dense();
 let adjoint = target.adjoint().to_owned();
 target.copy_from_strict_triangular_lower(adjoint);
 let A = reconstruct_from_simplicial_ldlt(symbolic, L_val);

 assert!(A ~ target);

 let k = 3;
 let rng = &mut StdRng::seed_from_u64(0);

 let rhs = CwiseMatDistribution {
 nrows: n,
 ncols: k,
 dist: ComplexDistribution::new(StandardNormal, StandardNormal),
 }
 .rand::<Mat<c64>>(rng);

 let simplicial = simplicial::SimplicialLdltRef::new(symbolic, L_val);
 for conj in [Conj::No, Conj::Yes] {
 let mut x = rhs.clone();
 simplicial.solve_in_place_with_conj(
 conj,
 x.rb_mut(),
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<c64>(k))),
 );

 // Multiplying the solution back must recover the right-hand side.
 let target = rhs.rb();
 let rhs = match conj {
 Conj::No => &A * &x,
 Conj::Yes => A.conjugate() * &x,
 };

 assert!(rhs ~ target);
 }
 }

 // Section 2: LLᴴ factorization, reconstruction, and solves.
 {
 // Scale the diagonal entries by 1e3 so the matrix is safely positive definite for LLᴴ.
 for j in 0..n {
 let (i, x) = A_upper.rb_mut().idx_val_of_col_mut(j);
 for (i, x) in iter::zip(i, x) {
 if i == j {
 *x *= 1e3;
 }
 }
 }
 let A_upper = A_upper.rb();

 let approx_eq = CwiseMat(ApproxEq::eps() * 1e5);
 let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
 simplicial::factorize_simplicial_numeric_llt(
 L_val,
 A_upper,
 Default::default(),
 symbolic,
 MemStack::new(&mut MemBuffer::new(simplicial::factorize_simplicial_numeric_llt_scratch::<usize, c64>(n))),
 )?;

 let mut target = A_upper.to_dense();
 let adjoint = target.adjoint().to_owned();
 target.copy_from_strict_triangular_lower(adjoint);
 let A = reconstruct_from_simplicial_llt(symbolic, L_val);

 assert!(A ~ target);

 let k = 3;
 let rng = &mut StdRng::seed_from_u64(0);

 let rhs = CwiseMatDistribution {
 nrows: n,
 ncols: k,
 dist: ComplexDistribution::new(StandardNormal, StandardNormal),
 }
 .rand::<Mat<c64>>(rng);

 let simplicial = simplicial::SimplicialLltRef::new(symbolic, L_val);
 for conj in [Conj::No, Conj::Yes] {
 let mut x = rhs.clone();
 simplicial.solve_in_place_with_conj(
 conj,
 x.rb_mut(),
 Par::Seq,
 MemStack::new(&mut MemBuffer::new(symbolic.solve_in_place_scratch::<c64>(k))),
 );

 let target = rhs.rb();
 let rhs = match conj {
 Conj::No => &A * &x,
 Conj::Yes => A.conjugate() * &x,
 };

 assert!(rhs ~ target);
 }
 }
 Ok(())
 }
4911
#[test]
fn test_solver_llt() -> Result {
    // Solves `A x = b` and `conj(A) x = b` through the high-level sparse LLT
    // solver for every combination of:
    //   - stored triangle (lower / upper),
    //   - factorization kind (forced simplicial / forced supernodal),
    //   - execution mode (sequential / rayon with 4 threads),
    // and checks `A x ≈ b` against a dense reconstruction of `A`.
    let file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt");
    let mut A_upper = parse_csc::<c64>(&std::fs::read_to_string(&file)?).0;
    let n = A_upper.nrows();

    // Multiply every diagonal entry by 1e3. LLT requires a positive definite
    // matrix; this presumably makes the fixture diagonally dominant enough to
    // be one (TODO confirm against the fixture).
    for j in 0..n {
        let (i, x) = A_upper.rb_mut().idx_val_of_col_mut(j);
        for (i, x) in iter::zip(i, x) {
            if i == j {
                *x *= 1e3;
            }
        }
    }
    let A_upper = A_upper.rb();
    let A_lower = A_upper.adjoint().to_col_major()?;
    let A_lower = A_lower.rb();

    // Dense reference matrix: lower triangle taken from `A_lower`, upper
    // triangle (including the diagonal) overwritten with its adjoint.
    let mut A_full = A_lower.to_dense();
    let adjoint = A_full.adjoint().to_owned();
    A_full.copy_from_triangular_upper(adjoint);
    let A_full = A_full.rb();

    let rng = &mut StdRng::seed_from_u64(0);
    // Not referenced directly below; presumably picked up by the
    // `assert!(lhs ~ rhs)` form, so the name must stay `approx_eq`.
    let approx_eq = CwiseMat(ApproxEq::eps() * 1e4);

    for (A, side) in [(A_lower, Side::Lower), (A_upper, Side::Upper)] {
        for supernodal_flop_ratio_threshold in [SupernodalThreshold::FORCE_SIMPLICIAL, SupernodalThreshold::FORCE_SUPERNODAL] {
            for par in [Par::Seq, Par::rayon(4)] {
                let symbolic = &factorize_symbolic_cholesky(
                    A.symbolic(),
                    side,
                    SymmetricOrdering::Amd,
                    CholeskySymbolicParams {
                        supernodal_flop_ratio_threshold,
                        ..Default::default()
                    },
                )?;

                let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
                let llt = symbolic.factorize_numeric_llt(
                    L_val,
                    A,
                    side,
                    Default::default(),
                    par,
                    MemStack::new(&mut MemBuffer::new(
                        symbolic.factorize_numeric_llt_scratch::<c64>(par, Default::default()),
                    )),
                    Default::default(),
                )?;

                for k in (1..16).chain(128..132) {
                    let rhs = CwiseMatDistribution {
                        nrows: n,
                        ncols: k,
                        dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                    }
                    .rand::<Mat<c64>>(rng);

                    // Query the workspace with the same `par` the solve runs
                    // with, so parallel solves are never under-allocated.
                    // The buffer depends only on `k`, so allocate it once and
                    // reuse it for both conjugation modes.
                    let mut solve_mem = MemBuffer::new(llt.solve_in_place_scratch::<c64>(k, par));

                    for conj in [Conj::No, Conj::Yes] {
                        let mut x = rhs.clone();
                        llt.solve_in_place_with_conj(conj, x.rb_mut(), par, MemStack::new(&mut solve_mem));

                        // Multiply the solution back and compare it with the
                        // original right-hand side.
                        let target = rhs.as_ref();
                        let reconstructed = match conj {
                            Conj::No => A_full * &x,
                            Conj::Yes => A_full.conjugate() * &x,
                        };
                        assert!(reconstructed ~ target);
                    }
                }
            }
        }
    }

    Ok(())
}
4994
#[test]
fn test_solver_ldlt() -> Result {
    // Solves `A x = b` and `conj(A) x = b` through the high-level sparse LDLT
    // solver for every combination of:
    //   - stored triangle (lower / upper),
    //   - factorization kind (forced simplicial / forced supernodal),
    //   - execution mode (sequential / rayon with 4 threads),
    // and checks `A x ≈ b` against a dense reconstruction of `A`.
    // Unlike the LLT test, the fixture is used as-is (no diagonal scaling).
    let file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt");
    let A_upper = parse_csc::<c64>(&std::fs::read_to_string(&file)?).0;
    let n = A_upper.nrows();

    let A_upper = A_upper.rb();
    let A_lower = A_upper.adjoint().to_col_major()?;
    let A_lower = A_lower.rb();

    // Dense reference matrix: lower triangle taken from `A_lower`, upper
    // triangle (including the diagonal) overwritten with its adjoint.
    let mut A_full = A_lower.to_dense();
    let adjoint = A_full.adjoint().to_owned();
    A_full.copy_from_triangular_upper(adjoint);
    let A_full = A_full.rb();

    let rng = &mut StdRng::seed_from_u64(0);
    // Not referenced directly below; presumably picked up by the
    // `assert!(lhs ~ rhs)` form, so the name must stay `approx_eq`.
    let approx_eq = CwiseMat(ApproxEq::eps() * 1e5);

    for (A, side) in [(A_lower, Side::Lower), (A_upper, Side::Upper)] {
        for supernodal_flop_ratio_threshold in [SupernodalThreshold::FORCE_SIMPLICIAL, SupernodalThreshold::FORCE_SUPERNODAL] {
            for par in [Par::Seq, Par::rayon(4)] {
                let symbolic = &factorize_symbolic_cholesky(
                    A.symbolic(),
                    side,
                    SymmetricOrdering::Amd,
                    CholeskySymbolicParams {
                        supernodal_flop_ratio_threshold,
                        ..Default::default()
                    },
                )?;

                let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
                let ldlt = symbolic.factorize_numeric_ldlt(
                    L_val,
                    A,
                    side,
                    Default::default(),
                    par,
                    MemStack::new(&mut MemBuffer::new(
                        symbolic.factorize_numeric_ldlt_scratch::<c64>(par, Default::default()),
                    )),
                    Default::default(),
                )?;

                for k in (1..16).chain(128..132) {
                    let rhs = CwiseMatDistribution {
                        nrows: n,
                        ncols: k,
                        dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                    }
                    .rand::<Mat<c64>>(rng);

                    // Query the workspace with the same `par` the solve runs
                    // with, so parallel solves are never under-allocated.
                    // The buffer depends only on `k`, so allocate it once and
                    // reuse it for both conjugation modes.
                    let mut solve_mem = MemBuffer::new(ldlt.solve_in_place_scratch::<c64>(k, par));

                    for conj in [Conj::No, Conj::Yes] {
                        let mut x = rhs.clone();
                        ldlt.solve_in_place_with_conj(conj, x.rb_mut(), par, MemStack::new(&mut solve_mem));

                        // Multiply the solution back and compare it with the
                        // original right-hand side.
                        let target = rhs.as_ref();
                        let reconstructed = match conj {
                            Conj::No => A_full * &x,
                            Conj::Yes => A_full.conjugate() * &x,
                        };
                        assert!(reconstructed ~ target);
                    }
                }
            }
        }
    }

    Ok(())
}
5070
#[test]
fn test_solver_bk() -> Result {
    // Solves `A x = b` and `conj(A) x = b` through the sparse intranode
    // Bunch–Kaufman (LBLᴴ) solver for every combination of:
    //   - stored triangle (lower / upper),
    //   - factorization kind (forced simplicial / forced supernodal),
    //   - execution mode (sequential / rayon with 4 threads),
    // and checks `A x ≈ b` against a dense reconstruction of `A`.
    let file = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test_data/sparse_cholesky/medium-1.txt");
    let A_upper = parse_csc::<c64>(&std::fs::read_to_string(&file)?).0;
    let n = A_upper.nrows();

    let A_upper = A_upper.rb();
    let A_lower = A_upper.adjoint().to_col_major()?;
    let A_lower = A_lower.rb();

    // Dense reference matrix: lower triangle taken from `A_lower`, upper
    // triangle (including the diagonal) overwritten with its adjoint.
    let mut A_full = A_lower.to_dense();
    let adjoint = A_full.adjoint().to_owned();
    A_full.copy_from_triangular_upper(adjoint);
    let A_full = A_full.rb();

    let rng = &mut StdRng::seed_from_u64(0);
    // Not referenced directly below; presumably picked up by the
    // `assert!(lhs ~ rhs)` form, so the name must stay `approx_eq`.
    let approx_eq = CwiseMat(ApproxEq::eps() * 1e4);

    for (A, side) in [(A_lower, Side::Lower), (A_upper, Side::Upper)] {
        for supernodal_flop_ratio_threshold in [SupernodalThreshold::FORCE_SIMPLICIAL, SupernodalThreshold::FORCE_SUPERNODAL] {
            for par in [Par::Seq, Par::rayon(4)] {
                let symbolic = &factorize_symbolic_cholesky(
                    A.symbolic(),
                    side,
                    SymmetricOrdering::Amd,
                    CholeskySymbolicParams {
                        supernodal_flop_ratio_threshold,
                        ..Default::default()
                    },
                )?;

                // Output buffers for the factorization: the forward/backward
                // permutation slices and the subdiagonal of the block-diagonal
                // factor (TODO confirm exact semantics against the
                // `factorize_numeric_intranode_lblt` docs).
                let fwd = &mut *vec![0usize; n];
                let bwd = &mut *vec![0usize; n];
                let subdiag = &mut *vec![zero::<c64>(); n];

                let L_val = &mut *vec![zero::<c64>(); symbolic.len_val()];
                let lblt = symbolic.factorize_numeric_intranode_lblt(
                    L_val,
                    subdiag,
                    fwd,
                    bwd,
                    A,
                    side,
                    par,
                    MemStack::new(&mut MemBuffer::new(
                        symbolic.factorize_numeric_intranode_lblt_scratch::<c64>(par, Default::default()),
                    )),
                    Default::default(),
                );

                for k in (1..16).chain(128..132) {
                    let rhs = CwiseMatDistribution {
                        nrows: n,
                        ncols: k,
                        dist: ComplexDistribution::new(StandardNormal, StandardNormal),
                    }
                    .rand::<Mat<c64>>(rng);

                    // Query the workspace with the same `par` the solve runs
                    // with, so parallel solves are never under-allocated.
                    // The buffer depends only on `k`, so allocate it once and
                    // reuse it for both conjugation modes.
                    let mut solve_mem = MemBuffer::new(lblt.solve_in_place_scratch::<c64>(k, par));

                    for conj in [Conj::No, Conj::Yes] {
                        let mut x = rhs.clone();
                        lblt.solve_in_place_with_conj(conj, x.rb_mut(), par, MemStack::new(&mut solve_mem));

                        // Multiply the solution back and compare it with the
                        // original right-hand side.
                        let target = rhs.as_ref();
                        let reconstructed = match conj {
                            Conj::No => A_full * &x,
                            Conj::Yes => A_full.conjugate() * &x,
                        };
                        assert!(reconstructed ~ target);
                    }
                }
            }
        }
    }

    Ok(())
}
5151}