1use crate::assert;
2use crate::internal_prelude_sp::*;
3use core::cell::UnsafeCell;
4
/// Opaque data produced by the symbolic phase of sparse-sparse matrix
/// multiplication, consumed by the numeric phase to balance work across
/// threads.
pub struct SparseMatMulInfo {
    // flops_prefix_sum[j] is the cumulative floating-point operation count
    // required to compute columns 0..j of the product (length ncols + 1,
    // entry 0 is zero). Used with `partition_point` to split columns into
    // roughly equal-cost chunks per thread.
    flops_prefix_sum: alloc::vec::Vec<f64>,
}
10
/// Computes the symbolic structure (sparsity pattern) of the product
/// `lhs * rhs`, along with per-column flop statistics used to balance the
/// numeric phase across threads.
///
/// # Panics
/// Panics if `lhs.ncols() != rhs.nrows()`.
///
/// # Errors
/// Returns `FaerError::OutOfMemory` if an allocation fails, or
/// `FaerError::IndexOverflow` if the number of stored entries exceeds the
/// signed range of the index type `I`.
#[track_caller]
pub fn sparse_sparse_matmul_symbolic<I: Index>(
    lhs: SymbolicSparseColMatRef<'_, I>,
    rhs: SymbolicSparseColMatRef<'_, I>,
) -> Result<(SymbolicSparseColMat<I>, SparseMatMulInfo), FaerError> {
    assert!(lhs.ncols() == rhs.nrows());

    let m = lhs.nrows();
    let n = rhs.ncols();

    let mut col_ptr = try_zeroed::<I>(n + 1)?;
    let mut row_idx = alloc::vec::Vec::new();
    // work[i] remembers the last product column in which row i was recorded,
    // so each row index is pushed at most once per column. The truncated
    // usize::MAX value acts as a "never seen" sentinel.
    let mut work = try_collect(repeat_n!(I::truncate(usize::MAX), m))?;
    // Prefix sums of the per-column flop counts; entry 0 stays zero.
    let mut info = try_zeroed::<f64>(n + 1)?;

    for j in 0..n {
        let mut count = 0usize;
        let mut flops = 0.0f64;
        // The pattern of product column j is the union of the lhs columns
        // selected by the row indices of rhs column j.
        for k in rhs.row_idx_of_col(j) {
            for i in lhs.row_idx_of_col(k) {
                if work[i] != I::truncate(j) {
                    // First occurrence of row i in product column j.
                    row_idx.try_reserve(1).ok().ok_or(FaerError::OutOfMemory)?;
                    row_idx.push(I::truncate(i));
                    work[i] = I::truncate(j);

                    count += 1;
                }
            }
            flops += lhs.row_idx_of_col_raw(k).len() as f64;
        }

        info[j + 1] = info[j] + flops;
        col_ptr[j + 1] = col_ptr[j] + I::truncate(count);
        // The total entry count must stay within the signed range of I.
        if col_ptr[j + 1] > I::from_signed(I::Signed::MAX) {
            return Err(FaerError::IndexOverflow);
        }
        // Row indices within each column are kept sorted.
        row_idx[col_ptr[j].zx()..col_ptr[j + 1].zx()].sort_unstable();
    }

    unsafe {
        // SAFETY: col_ptr is nondecreasing and bounded by row_idx.len(), and
        // the row indices of each column are sorted, unique, and in 0..m by
        // construction in the loop above.
        Ok((
            SymbolicSparseColMat::new_unchecked(m, n, col_ptr, None, row_idx),
            SparseMatMulInfo { flops_prefix_sum: info },
        ))
    }
}
62
63pub fn sparse_sparse_matmul_numeric_scratch<I: Index, T: ComplexField>(dst: SymbolicSparseColMatRef<'_, I>, par: Par) -> StackReq {
66 temp_mat_scratch::<T>(dst.nrows(), par.degree())
67}
68
69#[track_caller]
75#[math]
76pub fn sparse_sparse_matmul_numeric<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
77 dst: SparseColMatMut<'_, I, T>,
78 beta: Accum,
79 lhs: SparseColMatRef<'_, I, LhsT>,
80 rhs: SparseColMatRef<'_, I, RhsT>,
81 alpha: T,
82 info: &SparseMatMulInfo,
83 par: Par,
84 stack: &mut MemStack,
85) {
86 assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));
87 let m = lhs.nrows();
88 let n = rhs.ncols();
89 let mut dst = dst;
90 if let Accum::Replace = beta {
91 for j in 0..n {
92 dst.rb_mut().val_of_col_mut(j).fill(zero());
93 }
94 }
95 let alpha = α
96
97 let (c_symbolic, c_values) = dst.parts_mut();
98
99 let total_flop_count = info.flops_prefix_sum[n];
100
101 let (mut work, _) = temp_mat_zeroed::<T, _, _>(m, par.degree(), stack);
102 let work = work.as_mat_mut();
103 let work = work.rb();
104
105 #[derive(Copy, Clone)]
106 struct SyncWrapper<T>(T);
107 unsafe impl<T> Sync for SyncWrapper<T> {}
108 unsafe impl<T> Send for SyncWrapper<T> {}
109
110 let c_values = SyncWrapper(&*UnsafeCell::from_mut(c_values));
111
112 let nthreads = par.degree();
113 let job = &|tid: usize| {
114 assert!(tid < nthreads);
115
116 fn partition_fn(total_flop_count: f64, nthreads: usize, tid: usize) -> impl FnMut(&f64) -> bool {
117 move |&x| x < total_flop_count * (tid as f64 / nthreads as f64)
118 }
119
120 let mut work = unsafe { work.col(tid).const_cast().try_as_col_major_mut().unwrap() };
121 let col_start = info.flops_prefix_sum.partition_point(partition_fn(total_flop_count, nthreads, tid));
122 let col_end = col_start + info.flops_prefix_sum[col_start..].partition_point(partition_fn(total_flop_count, nthreads, tid + 1));
123
124 let c_values = unsafe { &*({ c_values }.0 as *const UnsafeCell<[T]> as *const [UnsafeCell<T>]) };
126
127 for j in col_start..col_end {
128 for (k, b_k) in iter::zip(rhs.row_idx_of_col(j), rhs.val_of_col(j)) {
129 let b_k = Conj::apply(b_k) * *alpha;
130
131 for (i, a_i) in iter::zip(lhs.row_idx_of_col(k), lhs.val_of_col(k)) {
132 let a_i = Conj::apply(a_i);
133 work[i] = work[i] + a_i * b_k;
134 }
135 }
136 let c_values =
140 unsafe { &mut *UnsafeCell::raw_get((&c_values[c_symbolic.col_range(j)]) as *const [UnsafeCell<T>] as *const UnsafeCell<[T]>) };
141
142 for (i, c_i) in iter::zip(c_symbolic.row_idx_of_col(j), c_values) {
143 *c_i = *c_i + work[i];
144 work[i] = zero();
145 }
146 }
147 };
148
149 match par {
150 Par::Seq => {
151 job(0);
152 },
153 #[cfg(feature = "rayon")]
154 Par::Rayon(nthreads) => {
155 use rayon::prelude::*;
156
157 (0..nthreads.get()).into_par_iter().for_each(|tid| {
158 job(tid);
159 });
160 },
161 }
162}
163
164#[track_caller]
170pub fn sparse_sparse_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
171 lhs: SparseColMatRef<'_, I, LhsT>,
172 rhs: SparseColMatRef<'_, I, RhsT>,
173 alpha: T,
174 par: Par,
175) -> Result<SparseColMat<I, T>, FaerError> {
176 assert!(lhs.ncols() == rhs.nrows());
177
178 let (symbolic, info) = sparse_sparse_matmul_symbolic(lhs.symbolic(), rhs.symbolic())?;
179 let mut val = alloc::vec::Vec::new();
180 val.try_reserve_exact(symbolic.row_idx().len()).ok().ok_or(FaerError::OutOfMemory)?;
181 val.resize(symbolic.row_idx().len(), zero());
182
183 sparse_sparse_matmul_numeric(
184 SparseColMatMut::new(symbolic.rb(), &mut val),
185 Accum::Add,
186 lhs,
187 rhs,
188 alpha,
189 &info,
190 par,
191 MemStack::new(&mut MemBuffer::try_new(sparse_sparse_matmul_numeric_scratch::<I, T>(symbolic.rb(), par))?),
192 );
193
194 Ok(SparseColMat::new(symbolic, val))
195}
196
/// Computes the sparse-times-dense product,
/// `dst = beta * dst + alpha * lhs * rhs`, where `lhs` is sparse and `rhs`
/// and `dst` are dense.
///
/// # Panics
/// Panics if the matrix dimensions are not compatible.
#[track_caller]
#[math]
pub fn sparse_dense_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
    dst: MatMut<'_, T>,
    beta: Accum,
    lhs: SparseColMatRef<'_, I, LhsT>,
    rhs: MatRef<'_, RhsT>,
    alpha: T,
    par: Par,
) {
    assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));

    // Parallelism is currently unused; the computation runs sequentially.
    let _ = par;
    let mut dst = dst;

    if let Accum::Replace = beta {
        dst.fill(zero());
    }
    // Bind the dimensions so indices are checked against them at the type
    // level.
    with_dim!(M, dst.nrows());
    with_dim!(N, dst.ncols());
    with_dim!(K, lhs.ncols());

    let mut dst = dst.as_shape_mut(M, N);
    let lhs = lhs.as_shape(M, K);
    let rhs = rhs.as_shape(K, N);

    // Column-by-column scatter: dst(:, j) += sum_k lhs(:, k) * (alpha *
    // rhs(k, j)), visiting only the stored entries of each lhs column.
    for j in N.indices() {
        for depth in K.indices() {
            let rhs_kj = Conj::apply(&rhs[(depth, j)]) * alpha;
            for (i, lhs_ik) in iter::zip(lhs.row_idx_of_col(depth), lhs.val_of_col(depth)) {
                dst[(i, j)] = dst[(i, j)] + Conj::apply(lhs_ik) * rhs_kj;
            }
        }
    }
}
238
/// Computes the dense-times-sparse product,
/// `dst = beta * dst + alpha * lhs * rhs`, where `lhs` and `dst` are dense
/// and `rhs` is sparse.
///
/// # Panics
/// Panics if the matrix dimensions are not compatible.
#[track_caller]
#[math]
pub fn dense_sparse_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
    dst: MatMut<'_, T>,
    beta: Accum,
    lhs: MatRef<'_, LhsT>,
    rhs: SparseColMatRef<'_, I, RhsT>,
    alpha: T,
    par: Par,
) {
    assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));

    // Parallelism is currently unused; the computation runs sequentially.
    let _ = par;

    // Bind the dimensions so indices are checked against them at the type
    // level.
    with_dim!(M, dst.nrows());
    with_dim!(N, dst.ncols());
    with_dim!(K, lhs.ncols());

    let mut dst = dst.as_shape_mut(M, N);
    let lhs = lhs.as_shape(M, K);
    let rhs = rhs.as_shape(K, N);

    // Entry-by-entry gather: dst(i, j) combines the dot product of dense row
    // i of lhs with the stored entries of sparse column j of rhs.
    for i in M.indices() {
        for j in N.indices() {
            let mut acc = zero::<T>();
            for (depth, rhs_kj) in iter::zip(rhs.row_idx_of_col(j), rhs.val_of_col(j)) {
                let l = Conj::apply(&lhs[(i, depth)]);
                let r = Conj::apply(rhs_kj);
                acc = acc + l * r;
            }
            // `beta` selects between overwriting and accumulating into dst.
            match beta {
                Accum::Replace => dst[(i, j)] = alpha * acc,
                Accum::Add => dst[(i, j)] = dst[(i, j)] + alpha * acc,
            }
        }
    }
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert;

    // Checks sparse * sparse multiplication against the dense reference
    // product, including the scaling factor and a parallel execution path.
    #[test]
    fn test_sp_matmul() {
        // 5x4 sparse operand built from explicit (row, col, value) triplets.
        let a = SparseColMat::<usize, f64>::try_new_from_triplets(
            5,
            4,
            &[
                Triplet::new(0, 0, 1.0),
                Triplet::new(1, 0, 2.0),
                Triplet::new(3, 0, 3.0),
                Triplet::new(1, 1, 5.0),
                Triplet::new(4, 1, 6.0),
                Triplet::new(0, 2, 7.0),
                Triplet::new(2, 2, 8.0),
                Triplet::new(0, 3, 9.0),
                Triplet::new(2, 3, 10.0),
                Triplet::new(3, 3, 11.0),
                Triplet::new(4, 3, 12.0),
            ],
        )
        .unwrap();

        // 4x6 sparse operand.
        let b = SparseColMat::<usize, f64>::try_new_from_triplets(
            4,
            6,
            &[
                Triplet::new(0, 0, 1.0),
                Triplet::new(1, 0, 2.0),
                Triplet::new(3, 0, 3.0),
                Triplet::new(1, 1, 5.0),
                Triplet::new(3, 1, 6.0),
                Triplet::new(1, 2, 7.0),
                Triplet::new(3, 2, 8.0),
                Triplet::new(1, 3, 9.0),
                Triplet::new(3, 3, 10.0),
                Triplet::new(1, 4, 11.0),
                Triplet::new(3, 4, 12.0),
                Triplet::new(1, 5, 13.0),
                Triplet::new(3, 5, 14.0),
            ],
        )
        .unwrap();

        // c = 2 * a * b, computed with 12 rayon threads.
        let c = sparse_sparse_matmul(a.rb(), b.rb(), 2.0, Par::rayon(12)).unwrap();

        // Compare against the dense reference product.
        assert!(c.to_dense() == Scale(2.0) * a.to_dense() * b.to_dense());
    }
}
342}