faer/sparse/
ops.rs

1use super::*;
2use crate::assert;
3use crate::internal_prelude::*;
4
5/// returns the resulting matrix obtained by applying `f` to the elements from `lhs` and `rhs`,
6/// skipping entries that are unavailable in both of `lhs` and `rhs`.
7///
8/// # panics
9/// panics if `lhs` and `rhs` don't have matching dimensions.  
10#[track_caller]
11pub fn binary_op<I: Index, T, LhsT, RhsT>(
12	lhs: SparseColMatRef<'_, I, LhsT>,
13	rhs: SparseColMatRef<'_, I, RhsT>,
14	f: impl FnMut(Option<&LhsT>, Option<&RhsT>) -> T,
15) -> Result<SparseColMat<I, T>, FaerError> {
16	assert!(lhs.nrows() == rhs.nrows());
17	assert!(lhs.ncols() == rhs.ncols());
18	let mut f = f;
19	let m = lhs.nrows();
20	let n = lhs.ncols();
21
22	let mut col_ptr = try_zeroed::<I>(n + 1)?;
23
24	let mut nnz = 0usize;
25	for j in 0..n {
26		let lhs = lhs.row_idx_of_col_raw(j);
27		let rhs = rhs.row_idx_of_col_raw(j);
28
29		let mut lhs_pos = 0usize;
30		let mut rhs_pos = 0usize;
31		while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
32			let lhs = lhs[lhs_pos];
33			let rhs = rhs[rhs_pos];
34
35			lhs_pos += (lhs <= rhs) as usize;
36			rhs_pos += (rhs <= lhs) as usize;
37			nnz += 1;
38		}
39		nnz += lhs.len() - lhs_pos;
40		nnz += rhs.len() - rhs_pos;
41		col_ptr[j + 1] = I::truncate(nnz);
42	}
43
44	if nnz > I::Signed::MAX.zx() {
45		return Err(FaerError::IndexOverflow);
46	}
47
48	let mut row_idx = try_zeroed(nnz)?;
49	let mut values = alloc::vec::Vec::new();
50	values.try_reserve_exact(nnz).map_err(|_| FaerError::OutOfMemory)?;
51
52	let mut nnz = 0usize;
53	for j in 0..n {
54		let lhs_values = lhs.val_of_col(j);
55		let rhs_values = rhs.val_of_col(j);
56		let lhs = lhs.row_idx_of_col_raw(j);
57		let rhs = rhs.row_idx_of_col_raw(j);
58
59		let mut lhs_pos = 0usize;
60		let mut rhs_pos = 0usize;
61		while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
62			let lhs = lhs[lhs_pos];
63			let rhs = rhs[rhs_pos];
64
65			match lhs.cmp(&rhs) {
66				core::cmp::Ordering::Less => {
67					row_idx[nnz] = lhs;
68					values.push(f(Some(&lhs_values[lhs_pos]), None));
69				},
70				core::cmp::Ordering::Equal => {
71					row_idx[nnz] = lhs;
72					values.push(f(Some(&lhs_values[lhs_pos]), Some(&rhs_values[rhs_pos])));
73				},
74				core::cmp::Ordering::Greater => {
75					row_idx[nnz] = rhs;
76					values.push(f(None, Some(&rhs_values[rhs_pos])));
77				},
78			}
79
80			lhs_pos += (lhs <= rhs) as usize;
81			rhs_pos += (rhs <= lhs) as usize;
82			nnz += 1;
83		}
84		row_idx[nnz..nnz + lhs.len() - lhs_pos].copy_from_slice(&lhs[lhs_pos..]);
85		for src in &lhs_values[lhs_pos..lhs.len()] {
86			values.push(f(Some(src), None));
87		}
88		nnz += lhs.len() - lhs_pos;
89
90		row_idx[nnz..nnz + rhs.len() - rhs_pos].copy_from_slice(&rhs[rhs_pos..]);
91		for src in &rhs_values[rhs_pos..rhs.len()] {
92			values.push(f(None, Some(src)));
93		}
94		nnz += rhs.len() - rhs_pos;
95	}
96
97	Ok(SparseColMat::<I, T>::new(
98		SymbolicSparseColMat::<I>::new_checked(m, n, col_ptr, None, row_idx),
99		values,
100	))
101}
102
103/// returns the resulting matrix obtained by applying `f` to the elements from `dst` and `src`
104/// skipping entries that are unavailable in both of them.  
105/// the sparsity patter of `dst` is unchanged.
106///
107/// # panics
108/// panics if `src` and `dst` don't have matching dimensions.  
109/// panics if `src` contains an index that's unavailable in `dst`.  
110#[track_caller]
111pub fn binary_op_assign_into<I: Index, T, SrcT>(
112	dst: SparseColMatMut<'_, I, T>,
113	src: SparseColMatRef<'_, I, SrcT>,
114	f: impl FnMut(&mut T, Option<&SrcT>),
115) {
116	{
117		assert!(dst.nrows() == src.nrows());
118		assert!(dst.ncols() == src.ncols());
119
120		let n = dst.ncols();
121		let mut dst = dst;
122		let mut f = f;
123
124		for j in 0..n {
125			let (dst, dst_val) = dst.rb_mut().parts_mut();
126
127			let dst_val = &mut dst_val[dst.col_range(j)];
128			let src_val = src.val_of_col(j);
129
130			let dst = dst.row_idx_of_col_raw(j);
131			let src = src.row_idx_of_col_raw(j);
132
133			let mut dst_pos = 0usize;
134			let mut src_pos = 0usize;
135
136			while src_pos < src.len() {
137				let src = src[src_pos];
138
139				if dst[dst_pos] < src {
140					f(&mut dst_val[dst_pos], None);
141					dst_pos += 1;
142					continue;
143				}
144
145				assert!(dst[dst_pos] == src);
146
147				f(&mut dst_val[dst_pos], Some(&src_val[src_pos]));
148
149				src_pos += 1;
150				dst_pos += 1;
151			}
152			while dst_pos < dst.len() {
153				f(&mut dst_val[dst_pos], None);
154				dst_pos += 1;
155			}
156		}
157	}
158}
159
160/// returns the resulting matrix obtained by applying `f` to the elements from `dst`, `lhs` and
161/// `rhs`, skipping entries that are unavailable in all of `dst`, `lhs` and `rhs`.  
162/// the sparsity patter of `dst` is unchanged.
163///
164/// # panics
165/// panics if `lhs`, `rhs` and `dst` don't have matching dimensions.  
166/// panics if `lhs` or `rhs` contains an index that's unavailable in `dst`.  
167#[track_caller]
168pub fn ternary_op_assign_into<I: Index, T, LhsT, RhsT>(
169	dst: SparseColMatMut<'_, I, T>,
170	lhs: SparseColMatRef<'_, I, LhsT>,
171	rhs: SparseColMatRef<'_, I, RhsT>,
172	f: impl FnMut(&mut T, Option<&LhsT>, Option<&RhsT>),
173) {
174	{
175		assert!(dst.nrows() == lhs.nrows());
176		assert!(dst.ncols() == lhs.ncols());
177		assert!(dst.nrows() == rhs.nrows());
178		assert!(dst.ncols() == rhs.ncols());
179
180		let n = dst.ncols();
181		let mut dst = dst;
182		let mut f = f;
183
184		for j in 0..n {
185			let (dst, dst_val) = dst.rb_mut().parts_mut();
186
187			let dst_val = &mut dst_val[dst.col_range(j)];
188			let lhs_val = lhs.val_of_col(j);
189			let rhs_val = rhs.val_of_col(j);
190
191			let dst = dst.row_idx_of_col_raw(j);
192			let rhs = rhs.row_idx_of_col_raw(j);
193			let lhs = lhs.row_idx_of_col_raw(j);
194
195			let mut dst_pos = 0usize;
196			let mut lhs_pos = 0usize;
197			let mut rhs_pos = 0usize;
198
199			while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
200				let lhs = lhs[lhs_pos];
201				let rhs = rhs[rhs_pos];
202
203				if dst[dst_pos] < Ord::min(lhs, rhs) {
204					f(&mut dst_val[dst_pos], None, None);
205					dst_pos += 1;
206					continue;
207				}
208
209				assert!(dst[dst_pos] == Ord::min(lhs, rhs));
210
211				match lhs.cmp(&rhs) {
212					core::cmp::Ordering::Less => {
213						f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), None);
214					},
215					core::cmp::Ordering::Equal => {
216						f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), Some(&rhs_val[rhs_pos]));
217					},
218					core::cmp::Ordering::Greater => {
219						f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
220					},
221				}
222
223				lhs_pos += (lhs <= rhs) as usize;
224				rhs_pos += (rhs <= lhs) as usize;
225				dst_pos += 1;
226			}
227			while lhs_pos < lhs.len() {
228				let lhs = lhs[lhs_pos];
229				if dst[dst_pos] < lhs {
230					f(&mut dst_val[dst_pos], None, None);
231					dst_pos += 1;
232					continue;
233				}
234				f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), None);
235				lhs_pos += 1;
236				dst_pos += 1;
237			}
238			while rhs_pos < rhs.len() {
239				let rhs = rhs[rhs_pos];
240				if dst[dst_pos] < rhs {
241					f(&mut dst_val[dst_pos], None, None);
242					dst_pos += 1;
243					continue;
244				}
245				f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
246				rhs_pos += 1;
247				dst_pos += 1;
248			}
249			while rhs_pos < rhs.len() {
250				let rhs = rhs[rhs_pos];
251				dst_pos += dst[dst_pos..].binary_search(&rhs).unwrap();
252				f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
253				rhs_pos += 1;
254			}
255		}
256	}
257}
258
259/// returns the sparsity pattern containing the union of those of `lhs` and `rhs`.
260///
261/// # panics
262/// panics if `lhs` and `rhs` don't have matching dimensions.  
263#[track_caller]
264#[inline]
265pub fn union_symbolic<I: Index>(
266	lhs: SymbolicSparseColMatRef<'_, I>,
267	rhs: SymbolicSparseColMatRef<'_, I>,
268) -> Result<SymbolicSparseColMat<I>, FaerError> {
269	Ok(binary_op(
270		SparseColMatRef::<I, Symbolic>::new(lhs, Symbolic::materialize(lhs.compute_nnz())),
271		SparseColMatRef::<I, Symbolic>::new(rhs, Symbolic::materialize(rhs.compute_nnz())),
272		#[inline(always)]
273		|_, _| Symbolic,
274	)?
275	.into_parts()
276	.0)
277}
278
279/// returns the sum of `lhs` and `rhs`.
280///
281/// # panics
282/// panics if `lhs` and `rhs` don't have matching dimensions.  
283#[track_caller]
284#[inline]
285pub fn add<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
286	lhs: SparseColMatRef<'_, I, LhsT>,
287	rhs: SparseColMatRef<'_, I, RhsT>,
288) -> Result<SparseColMat<I, T>, FaerError> {
289	binary_op(lhs, rhs, |lhs, rhs| match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
290		(None, None) => zero(),
291		(None, Some(rhs)) => rhs,
292		(Some(lhs), None) => lhs,
293		(Some(lhs), Some(rhs)) => faer_traits::math_utils::add(&lhs, &rhs),
294	})
295}
296
297/// returns the difference of `lhs` and `rhs`.
298///
299/// # panics
300/// panics if `lhs` and `rhs` don't have matching dimensions.  
301#[track_caller]
302#[inline]
303pub fn sub<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
304	lhs: SparseColMatRef<'_, I, LhsT>,
305	rhs: SparseColMatRef<'_, I, RhsT>,
306) -> Result<SparseColMat<I, T>, FaerError> {
307	binary_op(lhs, rhs, |lhs, rhs| match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
308		(None, None) => zero(),
309		(None, Some(rhs)) => rhs,
310		(Some(lhs), None) => lhs,
311		(Some(lhs), Some(rhs)) => faer_traits::math_utils::sub(&lhs, &rhs),
312	})
313}
314
315/// computes the sum of `dst` and `src` and stores the result in `dst` without changing its
316/// symbolic structure.
317///
318/// # panics
319/// panics if `dst` and `rhs` don't have matching dimensions.  
320/// panics if `rhs` contains an index that's unavailable in `dst`.  
321pub fn add_assign<I: Index, T: ComplexField, RhsT: Conjugate<Canonical = T>>(dst: SparseColMatMut<'_, I, T>, rhs: SparseColMatRef<'_, I, RhsT>) {
322	binary_op_assign_into(dst, rhs, |dst, rhs| {
323		*dst = faer_traits::math_utils::add(
324			dst,
325			&match rhs {
326				Some(rhs) => Conj::apply(rhs),
327				None => zero(),
328			},
329		)
330	})
331}
332
333/// computes the difference of `dst` and `src` and stores the result in `dst` without changing its
334/// symbolic structure.
335///
336/// # panics
337/// panics if `dst` and `rhs` don't have matching dimensions.  
338/// panics if `rhs` contains an index that's unavailable in `dst`.  
339pub fn sub_assign<I: Index, T: ComplexField, RhsT: Conjugate<Canonical = T>>(dst: SparseColMatMut<'_, I, T>, rhs: SparseColMatRef<'_, I, RhsT>) {
340	binary_op_assign_into(dst, rhs, |dst, rhs| {
341		*dst = faer_traits::math_utils::sub(
342			dst,
343			&match rhs {
344				Some(rhs) => Conj::apply(rhs),
345				None => zero(),
346			},
347		)
348	})
349}
350
351/// computes the sum of `lhs` and `rhs`, storing the result in `dst` without changing its
352/// symbolic structure.
353///
354/// # panics
355/// panics if `dst`, `lhs` and `rhs` don't have matching dimensions.  
356/// panics if `lhs` or `rhs` contains an index that's unavailable in `dst`.  
357#[track_caller]
358#[inline]
359pub fn add_into<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
360	dst: SparseColMatMut<'_, I, T>,
361	lhs: SparseColMatRef<'_, I, LhsT>,
362	rhs: SparseColMatRef<'_, I, RhsT>,
363) {
364	ternary_op_assign_into(dst, lhs, rhs, |dst, lhs, rhs| {
365		*dst = match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
366			(None, None) => zero(),
367			(None, Some(rhs)) => rhs,
368			(Some(lhs), None) => lhs,
369			(Some(lhs), Some(rhs)) => faer_traits::math_utils::add(&lhs, &rhs),
370		};
371	})
372}
373
374/// computes the difference of `lhs` and `rhs`, storing the result in `dst` without changing its
375/// symbolic structure.
376///
377/// # panics
378/// panics if `dst`, `lhs` and `rhs` don't have matching dimensions.  
379/// panics if `lhs` or `rhs` contains an index that's unavailable in `dst`.  
380#[track_caller]
381#[inline]
382pub fn sub_into<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
383	dst: SparseColMatMut<'_, I, T>,
384	lhs: SparseColMatRef<'_, I, LhsT>,
385	rhs: SparseColMatRef<'_, I, RhsT>,
386) {
387	ternary_op_assign_into(dst, lhs, rhs, |dst, lhs, rhs| {
388		*dst = match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
389			(None, None) => zero(),
390			(None, Some(rhs)) => rhs,
391			(Some(lhs), None) => lhs,
392			(Some(lhs), Some(rhs)) => faer_traits::math_utils::sub(&lhs, &rhs),
393		};
394	})
395}