// faer/sparse/linalg/matmul.rs

1use crate::assert;
2use crate::internal_prelude_sp::*;
3use core::cell::UnsafeCell;
4
/// info about the matrix multiplication operation to help split the workload between multiple
/// threads
pub struct SparseMatMulInfo {
	// prefix sum of the estimated flop count per output column:
	// `flops_prefix_sum[j]` is the cumulative flop count of columns `0..j`, so the
	// vector has length `ncols + 1` and the last entry holds the grand total.
	// filled in by `sparse_sparse_matmul_symbolic`, consumed by
	// `sparse_sparse_matmul_numeric` to assign disjoint column ranges to threads
	flops_prefix_sum: alloc::vec::Vec<f64>,
}
10
/// performs a symbolic matrix multiplication of a sparse matrix `lhs` by a sparse matrix `rhs`,
/// and returns the result.
///
/// # note
/// allows unsorted matrices, and produces a sorted output.
#[track_caller]
pub fn sparse_sparse_matmul_symbolic<I: Index>(
	lhs: SymbolicSparseColMatRef<'_, I>,
	rhs: SymbolicSparseColMatRef<'_, I>,
) -> Result<(SymbolicSparseColMat<I>, SparseMatMulInfo), FaerError> {
	assert!(lhs.ncols() == rhs.nrows());

	let m = lhs.nrows();
	let n = rhs.ncols();

	let mut col_ptr = try_zeroed::<I>(n + 1)?;
	let mut row_idx = alloc::vec::Vec::new();
	// `work[i]` remembers the most recent output column in which row `i` was recorded,
	// so duplicate rows within one column are skipped. initialized to a sentinel value
	// that can never equal a valid column index `j < n`
	let mut work = try_collect(repeat_n!(I::truncate(usize::MAX), m))?;
	// flop-count prefix sum per output column, later used to balance the numeric
	// multiplication across threads (see `SparseMatMulInfo`)
	let mut info = try_zeroed::<f64>(n + 1)?;

	for j in 0..n {
		let mut count = 0usize;
		let mut flops = 0.0f64;
		// the pattern of output column `j` is the union of the lhs columns `k`,
		// one for each structural nonzero `(k, j)` of `rhs`
		for k in rhs.row_idx_of_col(j) {
			for i in lhs.row_idx_of_col(k) {
				// record row `i` only the first time it is seen for this column
				if work[i] != I::truncate(j) {
					row_idx.try_reserve(1).ok().ok_or(FaerError::OutOfMemory)?;
					row_idx.push(I::truncate(i));
					work[i] = I::truncate(j);

					count += 1;
				}
			}
			// one multiply-add per entry of lhs column `k`
			flops += lhs.row_idx_of_col_raw(k).len() as f64;
		}

		info[j + 1] = info[j] + flops;
		col_ptr[j + 1] = col_ptr[j] + I::truncate(count);
		// the running nonzero count must stay representable in the signed variant of `I`
		if col_ptr[j + 1] > I::from_signed(I::Signed::MAX) {
			return Err(FaerError::IndexOverflow);
		}
		// sort the rows of the column just built so the output is sorted
		row_idx[col_ptr[j].zx()..col_ptr[j + 1].zx()].sort_unstable();
	}

	unsafe {
		Ok((
			// SAFETY: `col_ptr` is nondecreasing and each column of `row_idx` holds
			// sorted, deduplicated, in-bounds row indices
			SymbolicSparseColMat::new_unchecked(m, n, col_ptr, None, row_idx),
			SparseMatMulInfo { flops_prefix_sum: info },
		))
	}
}
62
63/// computes the size and alignment of the workspace required to perform the numeric matrix
64/// multiplication into `dst`.
65pub fn sparse_sparse_matmul_numeric_scratch<I: Index, T: ComplexField>(dst: SymbolicSparseColMatRef<'_, I>, par: Par) -> StackReq {
66	temp_mat_scratch::<T>(dst.nrows(), par.degree())
67}
68
/// performs a numeric matrix multiplication of a sparse matrix `lhs` by a sparse matrix `rhs`
/// multiplied by `alpha`, and stores or adds the result to `dst`.
///
/// # note
/// `lhs` and `rhs` are allowed to be unsorted matrices.
#[track_caller]
#[math]
pub fn sparse_sparse_matmul_numeric<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
	dst: SparseColMatMut<'_, I, T>,
	beta: Accum,
	lhs: SparseColMatRef<'_, I, LhsT>,
	rhs: SparseColMatRef<'_, I, RhsT>,
	alpha: T,
	info: &SparseMatMulInfo,
	par: Par,
	stack: &mut MemStack,
) {
	assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));
	let m = lhs.nrows();
	let n = rhs.ncols();
	let mut dst = dst;
	// `Accum::Replace` means the result overwrites `dst`, so clear its values first;
	// the accumulation below then always adds into `dst`
	if let Accum::Replace = beta {
		for j in 0..n {
			dst.rb_mut().val_of_col_mut(j).fill(zero());
		}
	}
	let alpha = &alpha;

	let (c_symbolic, c_values) = dst.parts_mut();

	let total_flop_count = info.flops_prefix_sum[n];

	// one dense accumulator column of length `m` per thread
	let (mut work, _) = temp_mat_zeroed::<T, _, _>(m, par.degree(), stack);
	let work = work.as_mat_mut();
	let work = work.rb();

	// wrapper that force-implements `Send`/`Sync` so the shared pointer to the output
	// values can be captured by the parallel closure. soundness relies on the disjoint
	// column partitioning below, not on this wrapper
	#[derive(Copy, Clone)]
	struct SyncWrapper<T>(T);
	unsafe impl<T> Sync for SyncWrapper<T> {}
	unsafe impl<T> Send for SyncWrapper<T> {}

	let c_values = SyncWrapper(&*UnsafeCell::from_mut(c_values));

	let nthreads = par.degree();
	let job = &|tid: usize| {
		assert!(tid < nthreads);

		// predicate for `partition_point`: true while the prefix-sum entry lies below
		// thread `tid`'s share of the total flop count
		fn partition_fn(total_flop_count: f64, nthreads: usize, tid: usize) -> impl FnMut(&f64) -> bool {
			move |&x| x < total_flop_count * (tid as f64 / nthreads as f64)
		}

		// SAFETY (const_cast): each thread only mutates its own column `tid` of the workspace
		let mut work = unsafe { work.col(tid).const_cast().try_as_col_major_mut().unwrap() };
		// split the output columns into contiguous, disjoint ranges with roughly equal
		// flop counts; consecutive `tid`s produce abutting ranges since the prefix sum
		// is nondecreasing
		let col_start = info.flops_prefix_sum.partition_point(partition_fn(total_flop_count, nthreads, tid));
		let col_end = col_start + info.flops_prefix_sum[col_start..].partition_point(partition_fn(total_flop_count, nthreads, tid + 1));

		// SAFETY: UnsafeCell<[T]> ~ [T] ~ [UnsafeCell<T>]
		let c_values = unsafe { &*({ c_values }.0 as *const UnsafeCell<[T]> as *const [UnsafeCell<T>]) };

		for j in col_start..col_end {
			// scatter: work += lhs[:, k] * (conj?(rhs[k, j]) * alpha) for each nonzero (k, j)
			for (k, b_k) in iter::zip(rhs.row_idx_of_col(j), rhs.val_of_col(j)) {
				let b_k = Conj::apply(b_k) * *alpha;

				for (i, a_i) in iter::zip(lhs.row_idx_of_col(k), lhs.val_of_col(k)) {
					let a_i = Conj::apply(a_i);
					work[i] = work[i] + a_i * b_k;
				}
			}
			// SAFETY: UnsafeCell<[T]> ~ [T] ~ [UnsafeCell<T>]
			// and only thread `tid` has access to the range of column `j`
			// since `col_start..col_end` denote disjoint ranges for each `tid`
			let c_values =
				unsafe { &mut *UnsafeCell::raw_get((&c_values[c_symbolic.col_range(j)]) as *const [UnsafeCell<T>] as *const UnsafeCell<[T]>) };

			// gather: add the accumulated entries into the output column, and reset
			// exactly the workspace entries that were touched so `work` is all zeros
			// again for the next column
			for (i, c_i) in iter::zip(c_symbolic.row_idx_of_col(j), c_values) {
				*c_i = *c_i + work[i];
				work[i] = zero();
			}
		}
	};

	match par {
		Par::Seq => {
			job(0);
		},
		#[cfg(feature = "rayon")]
		Par::Rayon(nthreads) => {
			use rayon::prelude::*;

			(0..nthreads.get()).into_par_iter().for_each(|tid| {
				job(tid);
			});
		},
	}
}
163
164/// performs a numeric matrix multiplication of a sparse matrix `lhs` by a sparse matrix `rhs`
165/// multiplied by `alpha`, and returns the result.
166///
167/// # note
168/// `lhs` and `rhs` are allowed to be unsorted matrices.
169#[track_caller]
170pub fn sparse_sparse_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
171	lhs: SparseColMatRef<'_, I, LhsT>,
172	rhs: SparseColMatRef<'_, I, RhsT>,
173	alpha: T,
174	par: Par,
175) -> Result<SparseColMat<I, T>, FaerError> {
176	assert!(lhs.ncols() == rhs.nrows());
177
178	let (symbolic, info) = sparse_sparse_matmul_symbolic(lhs.symbolic(), rhs.symbolic())?;
179	let mut val = alloc::vec::Vec::new();
180	val.try_reserve_exact(symbolic.row_idx().len()).ok().ok_or(FaerError::OutOfMemory)?;
181	val.resize(symbolic.row_idx().len(), zero());
182
183	sparse_sparse_matmul_numeric(
184		SparseColMatMut::new(symbolic.rb(), &mut val),
185		Accum::Add,
186		lhs,
187		rhs,
188		alpha,
189		&info,
190		par,
191		MemStack::new(&mut MemBuffer::try_new(sparse_sparse_matmul_numeric_scratch::<I, T>(symbolic.rb(), par))?),
192	);
193
194	Ok(SparseColMat::new(symbolic, val))
195}
196
197/// multiplies a sparse matrix `lhs` by a dense matrix `rhs`, and stores or adds the result to
198/// `dst`. see [`faer::linalg::matmul::matmul`](crate::linalg::matmul::matmul) for more details.
199///
200/// # note
201/// allows unsorted matrices.
202#[track_caller]
203#[math]
204pub fn sparse_dense_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
205	dst: MatMut<'_, T>,
206	beta: Accum,
207	lhs: SparseColMatRef<'_, I, LhsT>,
208	rhs: MatRef<'_, RhsT>,
209	alpha: T,
210	par: Par,
211) {
212	assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));
213
214	// TODO: parallelize this
215	let _ = par;
216	let mut dst = dst;
217
218	if let Accum::Replace = beta {
219		dst.fill(zero());
220	}
221	with_dim!(M, dst.nrows());
222	with_dim!(N, dst.ncols());
223	with_dim!(K, lhs.ncols());
224
225	let mut dst = dst.as_shape_mut(M, N);
226	let lhs = lhs.as_shape(M, K);
227	let rhs = rhs.as_shape(K, N);
228
229	for j in N.indices() {
230		for depth in K.indices() {
231			let rhs_kj = Conj::apply(&rhs[(depth, j)]) * alpha;
232			for (i, lhs_ik) in iter::zip(lhs.row_idx_of_col(depth), lhs.val_of_col(depth)) {
233				dst[(i, j)] = dst[(i, j)] + Conj::apply(lhs_ik) * rhs_kj;
234			}
235		}
236	}
237}
238
239/// multiplies a dense matrix `lhs` by a sparse matrix `rhs`, and stores or adds the result to
240/// `dst`. see [`faer::linalg::matmul::matmul`](crate::linalg::matmul::matmul) for more details.
241///
242/// # note
243/// allows unsorted matrices.
244#[track_caller]
245#[math]
246pub fn dense_sparse_matmul<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
247	dst: MatMut<'_, T>,
248	beta: Accum,
249	lhs: MatRef<'_, LhsT>,
250	rhs: SparseColMatRef<'_, I, RhsT>,
251	alpha: T,
252	par: Par,
253) {
254	assert!(all(dst.nrows() == lhs.nrows(), dst.ncols() == rhs.ncols(), lhs.ncols() == rhs.nrows()));
255
256	// TODO: parallelize this
257	let _ = par;
258
259	with_dim!(M, dst.nrows());
260	with_dim!(N, dst.ncols());
261	with_dim!(K, lhs.ncols());
262
263	let mut dst = dst.as_shape_mut(M, N);
264	let lhs = lhs.as_shape(M, K);
265	let rhs = rhs.as_shape(K, N);
266
267	for i in M.indices() {
268		for j in N.indices() {
269			let mut acc = zero::<T>();
270			for (depth, rhs_kj) in iter::zip(rhs.row_idx_of_col(j), rhs.val_of_col(j)) {
271				let l = Conj::apply(&lhs[(i, depth)]);
272				let r = Conj::apply(rhs_kj);
273				acc = acc + l * r;
274			}
275			match beta {
276				Accum::Replace => dst[(i, j)] = alpha * acc,
277				Accum::Add => dst[(i, j)] = dst[(i, j)] + alpha * acc,
278			}
279		}
280	}
281}
282
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;

	#[test]
	fn test_sp_matmul() {
		// build a 5x4 lhs and a 4x6 rhs from (row, col, value) tuples and check the
		// sparse product against the scaled dense product
		let lhs_entries: [(usize, usize, f64); 11] = [
			(0, 0, 1.0),
			(1, 0, 2.0),
			(3, 0, 3.0),
			(1, 1, 5.0),
			(4, 1, 6.0),
			(0, 2, 7.0),
			(2, 2, 8.0),
			(0, 3, 9.0),
			(2, 3, 10.0),
			(3, 3, 11.0),
			(4, 3, 12.0),
		];
		let rhs_entries: [(usize, usize, f64); 13] = [
			(0, 0, 1.0),
			(1, 0, 2.0),
			(3, 0, 3.0),
			(1, 1, 5.0),
			(3, 1, 6.0),
			(1, 2, 7.0),
			(3, 2, 8.0),
			(1, 3, 9.0),
			(3, 3, 10.0),
			(1, 4, 11.0),
			(3, 4, 12.0),
			(1, 5, 13.0),
			(3, 5, 14.0),
		];

		let lhs_triplets: alloc::vec::Vec<_> = lhs_entries.iter().map(|&(i, j, v)| Triplet::new(i, j, v)).collect();
		let rhs_triplets: alloc::vec::Vec<_> = rhs_entries.iter().map(|&(i, j, v)| Triplet::new(i, j, v)).collect();

		let lhs = SparseColMat::<usize, f64>::try_new_from_triplets(5, 4, &lhs_triplets).unwrap();
		let rhs = SparseColMat::<usize, f64>::try_new_from_triplets(4, 6, &rhs_triplets).unwrap();

		let prod = sparse_sparse_matmul(lhs.rb(), rhs.rb(), 2.0, Par::rayon(12)).unwrap();

		assert!(prod.to_dense() == Scale(2.0) * lhs.to_dense() * rhs.to_dense());
	}
}