1use super::*;
2use crate::assert;
3use crate::internal_prelude::*;
4
/// Returns a new matrix whose sparsity pattern is the union of the patterns of `lhs` and `rhs`,
/// with every stored value computed by `f`.
///
/// For each position stored in `lhs` and/or `rhs`, `f` receives `Some(&value)` for every operand
/// that stores that position and `None` for the other one; it is never called with
/// `(None, None)`. `f` is called exactly once per output entry, column by column, in increasing
/// row order.
///
/// The per-column two-pointer merge assumes that the row indices of each column of `lhs` and
/// `rhs` are sorted in increasing order.
///
/// # Panics
/// Panics if `lhs` and `rhs` do not have the same dimensions.
///
/// # Errors
/// Returns [`FaerError::IndexOverflow`] if the number of entries of the result exceeds
/// `I::Signed::MAX`, and [`FaerError::OutOfMemory`] if the value buffer cannot be reserved.
/// The index buffers come from `try_zeroed`, whose error is propagated with `?` (presumably also
/// an out-of-memory error — confirm).
#[track_caller]
pub fn binary_op<I: Index, T, LhsT, RhsT>(
	lhs: SparseColMatRef<'_, I, LhsT>,
	rhs: SparseColMatRef<'_, I, RhsT>,
	f: impl FnMut(Option<&LhsT>, Option<&RhsT>) -> T,
) -> Result<SparseColMat<I, T>, FaerError> {
	assert!(lhs.nrows() == rhs.nrows());
	assert!(lhs.ncols() == rhs.ncols());
	let mut f = f;
	let m = lhs.nrows();
	let n = lhs.ncols();

	// `col_ptr[0]` stays zero; `col_ptr[j + 1]` is filled with the running entry count below.
	let mut col_ptr = try_zeroed::<I>(n + 1)?;

	// First pass: count the size of the union of the row sets of each column.
	let mut nnz = 0usize;
	for j in 0..n {
		let lhs = lhs.row_idx_of_col_raw(j);
		let rhs = rhs.row_idx_of_col_raw(j);

		let mut lhs_pos = 0usize;
		let mut rhs_pos = 0usize;
		while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
			let lhs = lhs[lhs_pos];
			let rhs = rhs[rhs_pos];

			// Branchless advance: move the cursor holding the smaller row; when both rows
			// coincide, both cursors move and the shared row is counted once.
			lhs_pos += (lhs <= rhs) as usize;
			rhs_pos += (rhs <= lhs) as usize;
			nnz += 1;
		}
		// Whatever is left in the non-exhausted operand is unique to it.
		nnz += lhs.len() - lhs_pos;
		nnz += rhs.len() - rhs_pos;
		// This may truncate before the overflow check below runs; in that case the function
		// returns an error and `col_ptr` is discarded.
		col_ptr[j + 1] = I::truncate(nnz);
	}

	if nnz > I::Signed::MAX.zx() {
		return Err(FaerError::IndexOverflow);
	}

	let mut row_idx = try_zeroed(nnz)?;
	let mut values = alloc::vec::Vec::new();
	values.try_reserve_exact(nnz).map_err(|_| FaerError::OutOfMemory)?;

	// Second pass: same merge as above, now writing row indices and calling `f` for each
	// output entry in the order the first pass counted them.
	let mut nnz = 0usize;
	for j in 0..n {
		let lhs_values = lhs.val_of_col(j);
		let rhs_values = rhs.val_of_col(j);
		let lhs = lhs.row_idx_of_col_raw(j);
		let rhs = rhs.row_idx_of_col_raw(j);

		let mut lhs_pos = 0usize;
		let mut rhs_pos = 0usize;
		while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
			let lhs = lhs[lhs_pos];
			let rhs = rhs[rhs_pos];

			match lhs.cmp(&rhs) {
				core::cmp::Ordering::Less => {
					row_idx[nnz] = lhs;
					values.push(f(Some(&lhs_values[lhs_pos]), None));
				},
				core::cmp::Ordering::Equal => {
					row_idx[nnz] = lhs;
					values.push(f(Some(&lhs_values[lhs_pos]), Some(&rhs_values[rhs_pos])));
				},
				core::cmp::Ordering::Greater => {
					row_idx[nnz] = rhs;
					values.push(f(None, Some(&rhs_values[rhs_pos])));
				},
			}

			lhs_pos += (lhs <= rhs) as usize;
			rhs_pos += (rhs <= lhs) as usize;
			nnz += 1;
		}
		// At most one of the two tails below is non-empty.
		row_idx[nnz..nnz + lhs.len() - lhs_pos].copy_from_slice(&lhs[lhs_pos..]);
		for src in &lhs_values[lhs_pos..lhs.len()] {
			values.push(f(Some(src), None));
		}
		nnz += lhs.len() - lhs_pos;

		row_idx[nnz..nnz + rhs.len() - rhs_pos].copy_from_slice(&rhs[rhs_pos..]);
		for src in &rhs_values[rhs_pos..rhs.len()] {
			values.push(f(None, Some(src)));
		}
		nnz += rhs.len() - rhs_pos;
	}

	// NOTE(review): `new_checked` presumably validates the assembled structure — confirm.
	Ok(SparseColMat::<I, T>::new(
		SymbolicSparseColMat::<I>::new_checked(m, n, col_ptr, None, row_idx),
		values,
	))
}
102
103#[track_caller]
111pub fn binary_op_assign_into<I: Index, T, SrcT>(
112 dst: SparseColMatMut<'_, I, T>,
113 src: SparseColMatRef<'_, I, SrcT>,
114 f: impl FnMut(&mut T, Option<&SrcT>),
115) {
116 {
117 assert!(dst.nrows() == src.nrows());
118 assert!(dst.ncols() == src.ncols());
119
120 let n = dst.ncols();
121 let mut dst = dst;
122 let mut f = f;
123
124 for j in 0..n {
125 let (dst, dst_val) = dst.rb_mut().parts_mut();
126
127 let dst_val = &mut dst_val[dst.col_range(j)];
128 let src_val = src.val_of_col(j);
129
130 let dst = dst.row_idx_of_col_raw(j);
131 let src = src.row_idx_of_col_raw(j);
132
133 let mut dst_pos = 0usize;
134 let mut src_pos = 0usize;
135
136 while src_pos < src.len() {
137 let src = src[src_pos];
138
139 if dst[dst_pos] < src {
140 f(&mut dst_val[dst_pos], None);
141 dst_pos += 1;
142 continue;
143 }
144
145 assert!(dst[dst_pos] == src);
146
147 f(&mut dst_val[dst_pos], Some(&src_val[src_pos]));
148
149 src_pos += 1;
150 dst_pos += 1;
151 }
152 while dst_pos < dst.len() {
153 f(&mut dst_val[dst_pos], None);
154 dst_pos += 1;
155 }
156 }
157 }
158}
159
160#[track_caller]
168pub fn ternary_op_assign_into<I: Index, T, LhsT, RhsT>(
169 dst: SparseColMatMut<'_, I, T>,
170 lhs: SparseColMatRef<'_, I, LhsT>,
171 rhs: SparseColMatRef<'_, I, RhsT>,
172 f: impl FnMut(&mut T, Option<&LhsT>, Option<&RhsT>),
173) {
174 {
175 assert!(dst.nrows() == lhs.nrows());
176 assert!(dst.ncols() == lhs.ncols());
177 assert!(dst.nrows() == rhs.nrows());
178 assert!(dst.ncols() == rhs.ncols());
179
180 let n = dst.ncols();
181 let mut dst = dst;
182 let mut f = f;
183
184 for j in 0..n {
185 let (dst, dst_val) = dst.rb_mut().parts_mut();
186
187 let dst_val = &mut dst_val[dst.col_range(j)];
188 let lhs_val = lhs.val_of_col(j);
189 let rhs_val = rhs.val_of_col(j);
190
191 let dst = dst.row_idx_of_col_raw(j);
192 let rhs = rhs.row_idx_of_col_raw(j);
193 let lhs = lhs.row_idx_of_col_raw(j);
194
195 let mut dst_pos = 0usize;
196 let mut lhs_pos = 0usize;
197 let mut rhs_pos = 0usize;
198
199 while lhs_pos < lhs.len() && rhs_pos < rhs.len() {
200 let lhs = lhs[lhs_pos];
201 let rhs = rhs[rhs_pos];
202
203 if dst[dst_pos] < Ord::min(lhs, rhs) {
204 f(&mut dst_val[dst_pos], None, None);
205 dst_pos += 1;
206 continue;
207 }
208
209 assert!(dst[dst_pos] == Ord::min(lhs, rhs));
210
211 match lhs.cmp(&rhs) {
212 core::cmp::Ordering::Less => {
213 f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), None);
214 },
215 core::cmp::Ordering::Equal => {
216 f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), Some(&rhs_val[rhs_pos]));
217 },
218 core::cmp::Ordering::Greater => {
219 f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
220 },
221 }
222
223 lhs_pos += (lhs <= rhs) as usize;
224 rhs_pos += (rhs <= lhs) as usize;
225 dst_pos += 1;
226 }
227 while lhs_pos < lhs.len() {
228 let lhs = lhs[lhs_pos];
229 if dst[dst_pos] < lhs {
230 f(&mut dst_val[dst_pos], None, None);
231 dst_pos += 1;
232 continue;
233 }
234 f(&mut dst_val[dst_pos], Some(&lhs_val[lhs_pos]), None);
235 lhs_pos += 1;
236 dst_pos += 1;
237 }
238 while rhs_pos < rhs.len() {
239 let rhs = rhs[rhs_pos];
240 if dst[dst_pos] < rhs {
241 f(&mut dst_val[dst_pos], None, None);
242 dst_pos += 1;
243 continue;
244 }
245 f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
246 rhs_pos += 1;
247 dst_pos += 1;
248 }
249 while rhs_pos < rhs.len() {
250 let rhs = rhs[rhs_pos];
251 dst_pos += dst[dst_pos..].binary_search(&rhs).unwrap();
252 f(&mut dst_val[dst_pos], None, Some(&rhs_val[rhs_pos]));
253 rhs_pos += 1;
254 }
255 }
256 }
257}
258
259#[track_caller]
264#[inline]
265pub fn union_symbolic<I: Index>(
266 lhs: SymbolicSparseColMatRef<'_, I>,
267 rhs: SymbolicSparseColMatRef<'_, I>,
268) -> Result<SymbolicSparseColMat<I>, FaerError> {
269 Ok(binary_op(
270 SparseColMatRef::<I, Symbolic>::new(lhs, Symbolic::materialize(lhs.compute_nnz())),
271 SparseColMatRef::<I, Symbolic>::new(rhs, Symbolic::materialize(rhs.compute_nnz())),
272 #[inline(always)]
273 |_, _| Symbolic,
274 )?
275 .into_parts()
276 .0)
277}
278
279#[track_caller]
284#[inline]
285pub fn add<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
286 lhs: SparseColMatRef<'_, I, LhsT>,
287 rhs: SparseColMatRef<'_, I, RhsT>,
288) -> Result<SparseColMat<I, T>, FaerError> {
289 binary_op(lhs, rhs, |lhs, rhs| match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
290 (None, None) => zero(),
291 (None, Some(rhs)) => rhs,
292 (Some(lhs), None) => lhs,
293 (Some(lhs), Some(rhs)) => faer_traits::math_utils::add(&lhs, &rhs),
294 })
295}
296
297#[track_caller]
302#[inline]
303pub fn sub<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
304 lhs: SparseColMatRef<'_, I, LhsT>,
305 rhs: SparseColMatRef<'_, I, RhsT>,
306) -> Result<SparseColMat<I, T>, FaerError> {
307 binary_op(lhs, rhs, |lhs, rhs| match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
308 (None, None) => zero(),
309 (None, Some(rhs)) => rhs,
310 (Some(lhs), None) => lhs,
311 (Some(lhs), Some(rhs)) => faer_traits::math_utils::sub(&lhs, &rhs),
312 })
313}
314
315pub fn add_assign<I: Index, T: ComplexField, RhsT: Conjugate<Canonical = T>>(dst: SparseColMatMut<'_, I, T>, rhs: SparseColMatRef<'_, I, RhsT>) {
322 binary_op_assign_into(dst, rhs, |dst, rhs| {
323 *dst = faer_traits::math_utils::add(
324 dst,
325 &match rhs {
326 Some(rhs) => Conj::apply(rhs),
327 None => zero(),
328 },
329 )
330 })
331}
332
333pub fn sub_assign<I: Index, T: ComplexField, RhsT: Conjugate<Canonical = T>>(dst: SparseColMatMut<'_, I, T>, rhs: SparseColMatRef<'_, I, RhsT>) {
340 binary_op_assign_into(dst, rhs, |dst, rhs| {
341 *dst = faer_traits::math_utils::sub(
342 dst,
343 &match rhs {
344 Some(rhs) => Conj::apply(rhs),
345 None => zero(),
346 },
347 )
348 })
349}
350
351#[track_caller]
358#[inline]
359pub fn add_into<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
360 dst: SparseColMatMut<'_, I, T>,
361 lhs: SparseColMatRef<'_, I, LhsT>,
362 rhs: SparseColMatRef<'_, I, RhsT>,
363) {
364 ternary_op_assign_into(dst, lhs, rhs, |dst, lhs, rhs| {
365 *dst = match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
366 (None, None) => zero(),
367 (None, Some(rhs)) => rhs,
368 (Some(lhs), None) => lhs,
369 (Some(lhs), Some(rhs)) => faer_traits::math_utils::add(&lhs, &rhs),
370 };
371 })
372}
373
374#[track_caller]
381#[inline]
382pub fn sub_into<I: Index, T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>>(
383 dst: SparseColMatMut<'_, I, T>,
384 lhs: SparseColMatRef<'_, I, LhsT>,
385 rhs: SparseColMatRef<'_, I, RhsT>,
386) {
387 ternary_op_assign_into(dst, lhs, rhs, |dst, lhs, rhs| {
388 *dst = match (lhs.map(Conj::apply), rhs.map(Conj::apply)) {
389 (None, None) => zero(),
390 (None, Some(rhs)) => rhs,
391 (Some(lhs), None) => lhs,
392 (Some(lhs), Some(rhs)) => faer_traits::math_utils::sub(&lhs, &rhs),
393 };
394 })
395}