// faer/sparse/utils.rs

1use crate::internal_prelude_sp::*;
2use crate::{assert, debug_assert};
3
4/// sorts `row_indices` and `values` simultaneously so that `row_indices` is nonincreasing.
5pub fn sort_indices<I: Index, T>(col_ptr: &[I], col_nnz: Option<&[I]>, row_idx: &mut [I], val: &mut [T]) {
6	assert!(col_ptr.len() > 0);
7
8	let n = col_ptr.len() - 1;
9	for j in 0..n {
10		let start = col_ptr[j].zx();
11		let end = col_nnz.map(|nnz| start + nnz[j].zx()).unwrap_or(col_ptr[j + 1].zx());
12		unsafe { crate::sort::sort_indices(&mut row_idx[start..end], &mut val[start..end]) };
13	}
14}
15
16/// sorts and deduplicates `row_indices` and `values` simultaneously so that `row_indices` is
17/// nonincreasing and contains no duplicate indices.
18pub fn sort_dedup_indices<I: Index, T: ComplexField>(col_ptr: &[I], col_nnz: &mut [I], row_idx: &mut [I], val: &mut [T]) {
19	assert!(col_ptr.len() > 0);
20
21	let n = col_ptr.len() - 1;
22	for j in 0..n {
23		let start = col_ptr[j].zx();
24		let end = start + col_nnz[j].zx();
25		unsafe { crate::sort::sort_indices(&mut row_idx[start..end], &mut val[start..end]) };
26
27		let mut prev = I::truncate(usize::MAX);
28
29		let mut writer = start;
30		let mut reader = start;
31		while reader < end {
32			let cur = row_idx[reader];
33			if cur == prev {
34				writer -= 1;
35				val[writer] = add(&val[writer], &val[reader]);
36			} else {
37				val[writer] = copy(&val[reader]);
38			}
39
40			prev = cur;
41			reader += 1;
42			writer += 1;
43		}
44
45		col_nnz[j] = I::truncate(writer - start);
46	}
47}
48
/// computes the workspace size and alignment required to apply a two sided permutation to a
/// self-adjoint sparse matrix
pub fn permute_self_adjoint_scratch<I: Index>(dim: usize) -> StackReq {
	// one index per column: the per-column write cursor used by `permute_self_adjoint_imp`
	StackReq::new::<I>(dim)
}
54
/// computes the workspace size and alignment required to apply a two sided permutation to a
/// self-adjoint sparse matrix and deduplicate its elements
pub fn permute_dedup_self_adjoint_scratch<I: Index>(dim: usize) -> StackReq {
	// same requirement as the non-dedup variant: the dedup pass works in place
	StackReq::new::<I>(dim)
}
60
61/// computes the self-adjoint permutation $P A P^\top$ of the matrix $A$
62///
63/// the result is stored in `new_col_ptrs`, `new_row_indices`
64///
65/// # note
66/// allows unsorted matrices, producing a sorted output. duplicate entries are kept, however
67pub fn permute_self_adjoint<'out, N: Shape, I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
68	new_val: &'out mut [T],
69	new_col_ptr: &'out mut [I],
70	new_row_idx: &'out mut [I],
71	A: SparseColMatRef<'_, I, C, N, N>,
72	perm: PermRef<'_, I, N>,
73	in_side: Side,
74	out_side: Side,
75	stack: &mut MemStack,
76) -> SparseColMatMut<'out, I, T, N, N> {
77	let n = A.nrows();
78	with_dim!(N, n.unbound());
79
80	permute_self_adjoint_imp(
81		new_val,
82		new_col_ptr,
83		new_row_idx,
84		A.as_shape(N, N).canonical(),
85		Conj::get::<C>(),
86		perm.as_shape(N),
87		in_side,
88		out_side,
89		true,
90		stack,
91	)
92	.as_shape_mut(n, n)
93}
94
/// computes the self-adjoint permutation $P A P^\top$ of the matrix $A$ without sorting the row
/// indices, and returns a view over it
///
/// the result is stored in `new_col_ptrs`, `new_row_indices`
///
/// # note
/// allows unsorted matrices, producing a possibly unsorted output. duplicate entries are kept
pub fn permute_self_adjoint_to_unsorted<'out, N: Shape, I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
	new_val: &'out mut [T],
	new_col_ptr: &'out mut [I],
	new_row_idx: &'out mut [I],
	A: SparseColMatRef<'_, I, C, N, N>,
	perm: PermRef<'_, I, N>,
	in_side: Side,
	out_side: Side,
	stack: &mut MemStack,
) -> SparseColMatMut<'out, I, T, N, N> {
	let n = A.nrows();
	with_dim!(N, n.unbound());

	permute_self_adjoint_imp(
		new_val,
		new_col_ptr,
		new_row_idx,
		A.as_shape(N, N).canonical(),
		Conj::get::<C>(),
		perm.as_shape(N),
		in_side,
		out_side,
		// `sort = false`: skip the final per-column sort of the row indices
		false,
		stack,
	)
	.as_shape_mut(n, n)
}
129
130/// computes the self-adjoint permutation $P A P^\top$ of the matrix $A$ and deduplicate the
131/// elements of the output matrix
132///
133/// the result is stored in `new_col_ptrs`, `new_row_indices`
134///
135/// # note
136/// allows unsorted matrices, producing a sorted output. duplicate entries are merged
137pub fn permute_dedup_self_adjoint<'out, N: Shape, I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
138	new_val: &'out mut [T],
139	new_col_ptr: &'out mut [I],
140	new_row_idx: &'out mut [I],
141	A: SparseColMatRef<'_, I, C, N, N>,
142	perm: PermRef<'_, I, N>,
143	in_side: Side,
144	out_side: Side,
145	stack: &mut MemStack,
146) -> SparseColMatMut<'out, I, T, N, N> {
147	let n = A.nrows();
148	with_dim!(N, n.unbound());
149
150	permute_dedup_self_adjoint_imp(
151		new_val,
152		new_col_ptr,
153		new_row_idx,
154		A.as_shape(N, N).canonical(),
155		Conj::get::<C>(),
156		perm.as_shape(N),
157		in_side,
158		out_side,
159		stack,
160	)
161	.as_shape_mut(n, n)
162}
163
/// implementation of the two sided self-adjoint permutation over branded dimensions
///
/// reads the `in_side` triangle of `A` (applying `conj_A`), scatters the permuted entries
/// into the `out_side` triangle of the output, and optionally sorts each output column.
/// the workspace from `stack` holds one `I` per column (see `permute_self_adjoint_scratch`)
fn permute_self_adjoint_imp<'N, 'out, I: Index, T: ComplexField>(
	new_val: &'out mut [T],
	new_col_ptr: &'out mut [I],
	new_row_idx: &'out mut [I],
	A: SparseColMatRef<'_, I, T, Dim<'N>, Dim<'N>>,
	conj_A: Conj,
	perm: PermRef<'_, I, Dim<'N>>,
	in_side: Side,
	out_side: Side,
	sort: bool,
	stack: &mut MemStack,
) -> SparseColMatMut<'out, I, T, Dim<'N>, Dim<'N>> {
	// old_i <= old_j => -old_i >= -old_j
	// reverse the order with bitwise not
	// x + !x == MAX
	// x + !x + 1 == 0
	// !x = -1 - x

	// if we flipped the side of A, then we need to check old_i <= old_j instead
	let src_to_cmp = {
		let mask = match in_side {
			Side::Lower => 0,
			Side::Upper => usize::MAX,
		};
		move |i: usize| mask ^ i
	};

	// same order-reversal trick for the output triangle
	let dst_to_cmp = {
		let mask = match out_side {
			Side::Lower => 0,
			Side::Upper => usize::MAX,
		};
		move |i: usize| mask ^ i
	};

	let conj_A = conj_A.is_conj();

	// in_side/out_side are assumed Side::Lower

	let N = A.ncols();
	let n = *N;

	assert!(new_col_ptr.len() == n + 1);
	let (_, perm_inv) = perm.bound_arrays();

	// per-column workspace: used first as column counts, then as write cursors
	let (mut cur_row_pos, _) = stack.collect(repeat_n!(I::truncate(0), n));
	let cur_row_pos = Array::from_mut(&mut cur_row_pos, N);

	// first pass: count how many entries land in each output column
	let col_counts = &mut *cur_row_pos;
	for old_j in N.indices() {
		let new_j = perm_inv[old_j].zx();

		let old_j_cmp = src_to_cmp(*old_j);
		let new_j_cmp = dst_to_cmp(*new_j);

		for old_i in A.row_idx_of_col(old_j) {
			let new_i = perm_inv[old_i].zx();

			let old_i_cmp = src_to_cmp(*old_i);
			let new_i_cmp = dst_to_cmp(*new_i);

			// only entries from the input triangle are read
			if old_i_cmp >= old_j_cmp {
				// if the permuted entry falls outside the output triangle, it is
				// stored mirrored at the transposed position instead
				let lower = new_i_cmp >= new_j_cmp;
				let new_j = if lower { new_j } else { new_i };

				// cannot overflow because A.compute_nnz() <= I::Signed::MAX
				// col_counts[new_j] always >= 0
				col_counts[new_j] += I::truncate(1);
			}
		}
	}

	// col_counts[_] >= 0
	// cumulative sum cannot overflow because it's <= A.compute_nnz()

	// prefix-sum the counts into `new_col_ptr`, leaving each column's start in `col_counts`
	new_col_ptr[0] = I::truncate(0);
	for (count, [ci0, ci1]) in iter::zip(col_counts.as_mut(), windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut *new_col_ptr)))) {
		let ci0 = ci0.get();
		ci1.set(ci0 + *count);
		*count = ci0;
	}

	// new_col_ptr is non-decreasing
	let nnz = new_col_ptr[n].zx();
	let new_row_idx = &mut new_row_idx[..nnz];
	let new_val = &mut new_val[..nnz];

	// second pass: scatter the values and row indices into the output
	{
		with_dim!(NNZ, nnz);
		let new_val = Array::from_mut(new_val, NNZ);
		let new_row_idx = Array::from_mut(new_row_idx, NNZ);

		// conjugate when exactly one of `cond` and `conj_A` holds (real types never conjugate)
		let conj_if = |cond: bool, x: &T| -> T {
			if try_const! { T::IS_REAL } {
				copy(x)
			} else {
				if cond != conj_A { conj(x) } else { copy(x) }
			}
		};

		for old_j in N.indices() {
			let new_j = perm_inv[old_j].zx();

			let old_j_cmp = src_to_cmp(*old_j);
			let new_j_cmp = dst_to_cmp(*new_j);

			for (old_i, val) in iter::zip(A.row_idx_of_col(old_j), A.val_of_col(old_j)) {
				let new_i = perm_inv[old_i].zx();

				let old_i_cmp = src_to_cmp(*old_i);
				let new_i_cmp = dst_to_cmp(*new_i);

				if old_i_cmp >= old_j_cmp {
					let lower = new_i_cmp >= new_j_cmp;

					// mirror the entry if it falls outside the output triangle
					let (new_j, new_i) = if lower { (new_j, new_i) } else { (new_i, new_j) };

					let cur_row_pos = &mut cur_row_pos[new_j];

					// SAFETY: cur_row_pos < NNZ
					let row_pos = unsafe { Idx::new_unchecked(cur_row_pos.zx(), NNZ) };

					*cur_row_pos += I::truncate(1);

					// mirrored entries are conjugated (adjoint of the stored triangle)
					new_val[row_pos] = conj_if(!lower, val);
					new_row_idx[row_pos] = I::truncate(*new_i);
				}
			}
		}
	}

	if sort {
		sort_indices(new_col_ptr, None, new_row_idx, new_val);
	}
	// SAFETY:
	// 0. new_col_ptr is non-decreasing
	// 1. all written row indices are less than n

	unsafe { SparseColMatMut::new(SymbolicSparseColMatRef::new_unchecked(N, N, new_col_ptr, None, new_row_idx), new_val) }
}
304
/// implementation of the two sided self-adjoint permutation with deduplication
///
/// first performs the unsorted permutation, then sorts each output column and merges
/// duplicate row indices in place, compacting the storage and rewriting `new_col_ptr`
fn permute_dedup_self_adjoint_imp<'N, 'out, I: Index, T: ComplexField>(
	new_val: &'out mut [T],
	new_col_ptr: &'out mut [I],
	new_row_idx: &'out mut [I],
	A: SparseColMatRef<'_, I, T, Dim<'N>, Dim<'N>>,
	conj_A: Conj,
	perm: PermRef<'_, I, Dim<'N>>,
	in_side: Side,
	out_side: Side,
	stack: &mut MemStack,
) -> SparseColMatMut<'out, I, T, Dim<'N>, Dim<'N>> {
	let N = A.nrows();

	// unsorted permutation: sorting happens below, per column, before merging
	permute_self_adjoint_imp(new_val, new_col_ptr, new_row_idx, A, conj_A, perm, in_side, out_side, false, stack);

	{
		// cells allow reading the old `start`/`end` while overwriting `start` with the
		// compacted write position
		let new_col_ptr = Cell::as_slice_of_cells(Cell::from_mut(new_col_ptr));

		let start = Array::from_ref(&new_col_ptr[..*N], N);
		let end = Array::from_ref(&new_col_ptr[1..], N);
		// global write cursor over the compacted storage
		let mut writer = 0usize;

		for j in N.indices() {
			// record the compacted start of column `j`, keep the old bounds for reading
			let start = start[j].replace(I::truncate(writer)).zx();
			let end = end[j].get().zx();

			// SAFETY: both slices have the same length `end - start`
			unsafe {
				crate::sort::sort_indices(&mut new_row_idx[start..end], &mut new_val[start..end]);
			}

			// sentinel that differs from the first row index of the column
			let mut prev = I::truncate(usize::MAX);

			let mut reader = start;
			while reader < end {
				let cur = new_row_idx[reader];

				if cur == prev {
					// same element, add
					writer -= 1;
					new_val[writer] = add(&new_val[writer], &new_val[reader]);
				} else {
					// new element, copy
					new_row_idx[writer] = new_row_idx[reader];
					new_val[writer] = copy(&new_val[reader]);
				}

				prev = cur;
				reader += 1;
				writer += 1;
			}
		}
		// total number of entries after merging
		new_col_ptr[*N].set(I::truncate(writer));
	}

	unsafe { SparseColMatMut::new(SymbolicSparseColMatRef::new_unchecked(N, N, new_col_ptr, None, new_row_idx), new_val) }
}
361
362/// computes the workspace size and alignment required to transpose a matrix
363pub fn transpose_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
364	_ = ncols;
365	StackReq::new::<usize>(nrows)
366}
367
368/// computes the workspace size and alignment required to transpose a matrix and deduplicate the
369/// output elements
370pub fn transpose_dedup_scratch<I: Index>(nrows: usize, ncols: usize) -> StackReq {
371	_ = ncols;
372	StackReq::new::<usize>(nrows).array(2)
373}
374
375/// computes the transpose of the matrix $A$ and returns a view over it.
376///
377/// the result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`.
378///
379/// # note
380/// allows unsorted matrices, producing a sorted output. duplicate entries are kept, however
381pub fn transpose<'out, Rows: Shape, Cols: Shape, I: Index, T: Clone>(
382	new_val: &'out mut [T],
383	new_col_ptr: &'out mut [I],
384	new_row_idx: &'out mut [I],
385	A: SparseColMatRef<'_, I, T, Rows, Cols>,
386	stack: &mut MemStack,
387) -> SparseColMatMut<'out, I, T, Cols, Rows> {
388	let (m, n) = A.shape();
389	with_dim!(M, m.unbound());
390	with_dim!(N, n.unbound());
391
392	transpose_imp(T::clone, new_val, new_col_ptr, new_row_idx, A.as_shape(M, N), stack).as_shape_mut(n, m)
393}
394
395/// computes the adjoint of the matrix $A$ and returns a view over it.
396///
397/// the result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`.
398///
399/// # note
400/// allows unsorted matrices, producing a sorted output. duplicate entries are kept, however
401pub fn adjoint<'out, Rows: Shape, Cols: Shape, I: Index, T: ComplexField>(
402	new_val: &'out mut [T],
403	new_col_ptr: &'out mut [I],
404	new_row_idx: &'out mut [I],
405	A: SparseColMatRef<'_, I, T, Rows, Cols>,
406	stack: &mut MemStack,
407) -> SparseColMatMut<'out, I, T, Cols, Rows> {
408	let (m, n) = A.shape();
409	with_dim!(M, m.unbound());
410	with_dim!(N, n.unbound());
411
412	transpose_imp(conj::<T>, new_val, new_col_ptr, new_row_idx, A.as_shape(M, N), stack).as_shape_mut(n, m)
413}
414
415/// computes the transpose of the matrix $A$ and returns a view over it.
416///
417/// the result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`.
418///
419/// # note
420/// allows unsorted matrices, producing a sorted output. duplicate entries are merged
421pub fn transpose_dedup<'out, Rows: Shape, Cols: Shape, I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
422	new_val: &'out mut [T],
423	new_col_ptr: &'out mut [I],
424	new_row_idx: &'out mut [I],
425	A: SparseColMatRef<'_, I, C, Rows, Cols>,
426	stack: &mut MemStack,
427) -> SparseColMatMut<'out, I, T, Cols, Rows> {
428	let (m, n) = A.shape();
429	with_dim!(M, m.unbound());
430	with_dim!(N, n.unbound());
431
432	transpose_dedup_imp(new_val, new_col_ptr, new_row_idx, A.as_shape(M, N), stack).as_shape_mut(n, m)
433}
434
/// implementation of the (possibly value-mapping) transpose over branded dimensions
///
/// classic two-pass CSC transpose: count entries per output column, prefix-sum into
/// `new_col_ptr`, then scatter. `clone` maps each value into the output (identity clone
/// for `transpose`, conjugation for `adjoint`)
fn transpose_imp<'ROWS, 'COLS, 'out, I: Index, T>(
	clone: impl Fn(&T) -> T,
	new_val: &'out mut [T],
	new_col_ptr: &'out mut [I],
	new_row_idx: &'out mut [I],
	A: SparseColMatRef<'_, I, T, Dim<'ROWS>, Dim<'COLS>>,
	stack: &mut MemStack,
) -> SparseColMatMut<'out, I, T, Dim<'COLS>, Dim<'ROWS>> {
	let (M, N) = A.shape();
	assert!(new_col_ptr.len() == *M + 1);
	// one counter per row of A (column of the output)
	let (mut col_count, _) = stack.collect(repeat_n!(I::truncate(0), *M));
	let col_count = Array::from_mut(&mut col_count, M);

	// can't overflow because the total count is A.compute_nnz() <= I::Signed::MAX
	for j in N.indices() {
		for i in A.row_idx_of_col(j) {
			col_count[i] += I::truncate(1);
		}
	}

	new_col_ptr[0] = I::truncate(0);

	// prefix-sum: col_count elements are >= 0
	for (j, [pj0, pj1]) in iter::zip(M.indices(), windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptr)))) {
		let cj = &mut col_count[j];
		let pj = pj0.get();
		// new_col_ptr is non-decreasing
		pj1.set(pj + *cj);

		// *cj = cur_row_pos
		*cj = pj;
	}

	let new_row_idx = &mut new_row_idx[..new_col_ptr[*M].zx()];
	let new_val = &mut new_val[..new_col_ptr[*M].zx()];
	let cur_row_pos = col_count;

	// scatter pass: place each entry (i, j) of A at position cur_row_pos[i] of output column i
	for j in N.indices() {
		for (i, val) in iter::zip(A.row_idx_of_col(j), A.val_of_col(j)) {
			let ci = &mut cur_row_pos[i];
			// SAFETY: see below
			unsafe {
				let ci = ci.zx();
				*new_row_idx.get_unchecked_mut(ci) = I::truncate(*j);
				*new_val.get_unchecked_mut(ci) = clone(val);
			}
			*ci += I::truncate(1);
		}
	}
	// cur_row_pos[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m]
	// so all the unchecked accesses were valid and non-overlapping, which means the entire array is
	// filled.
	debug_assert!(cur_row_pos.as_ref() == &new_col_ptr[1..]);

	// SAFETY:
	// 0. new_col_ptr is non-decreasing
	// 1. all written row indices are less than n
	unsafe { SparseColMatMut::new(SymbolicSparseColMatRef::new_unchecked(N, M, new_col_ptr, None, new_row_idx), new_val) }
}
494
/// implementation of the deduplicating transpose over branded dimensions
///
/// same two-pass structure as `transpose_imp`, but a `last_seen` array tracks, per output
/// column `i`, the last source column `j` encountered, so repeated `(i, j)` entries are
/// counted once and their values accumulated. applies the conjugation implied by `C`
fn transpose_dedup_imp<'ROWS, 'COLS, 'out, I: Index, T: ComplexField, C: Conjugate<Canonical = T>>(
	new_val: &'out mut [T],
	new_col_ptr: &'out mut [I],
	new_row_idx: &'out mut [I],
	A: SparseColMatRef<'_, I, C, Dim<'ROWS>, Dim<'COLS>>,
	stack: &mut MemStack,
) -> SparseColMatMut<'out, I, T, Dim<'COLS>, Dim<'ROWS>> {
	let (M, N) = A.shape();
	assert!(new_col_ptr.len() == *M + 1);
	let A = A.canonical();

	// sentinel that never matches a valid source column index
	let sentinel = I::truncate(usize::MAX);
	let (mut col_count, stack) = stack.collect(repeat_n!(I::truncate(0), *M));
	let (mut last_seen, _) = stack.collect(repeat_n!(sentinel, *M));

	let col_count = Array::from_mut(&mut col_count, M);
	let last_seen = Array::from_mut(&mut last_seen, M);

	// counting pass, skipping duplicates within a column
	// can't overflow because the total count is A.compute_nnz() <= I::Signed::MAX
	for j in N.indices() {
		for i in A.row_idx_of_col(j) {
			let j = I::truncate(*j);
			if last_seen[i] == j {
				continue;
			}
			last_seen[i] = j;
			col_count[i] += I::truncate(1);
		}
	}

	new_col_ptr[0] = I::truncate(0);

	// prefix-sum: col_count elements are >= 0
	for (j, [pj0, pj1]) in iter::zip(M.indices(), windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptr)))) {
		let cj = &mut col_count[j];
		let pj = pj0.get();
		// new_col_ptr is non-decreasing
		pj1.set(pj + *cj);

		// *cj = cur_row_pos
		*cj = pj;
	}

	// reset for the scatter pass, which re-detects duplicates the same way
	last_seen.as_mut().fill(sentinel);

	let new_row_idx = &mut new_row_idx[..new_col_ptr[*M].zx()];
	let new_val = &mut new_val[..new_col_ptr[*M].zx()];
	let cur_row_pos = col_count;

	for j in N.indices() {
		for (i, val) in iter::zip(A.row_idx_of_col(j), A.val_of_col(j)) {
			let ci = &mut cur_row_pos[i];

			// conjugate if `C` is the conjugated wrapper type
			let val = if Conj::get::<C>().is_conj() { conj(val) } else { copy(val) };

			let j = I::truncate(*j);
			// SAFETY: see below
			unsafe {
				if last_seen[i] == j {
					// duplicate: accumulate into the slot written when first seen
					let ci = ci.zx() - 1;
					*new_val.get_unchecked_mut(ci) = add(new_val.get_unchecked(ci), &val);
				} else {
					last_seen[i] = j;
					*ci += I::truncate(1);

					let ci = ci.zx() - 1;
					{
						*new_row_idx.get_unchecked_mut(ci) = j;
						*new_val.get_unchecked_mut(ci) = val;
					}
				}
			}
		}
	}
	// cur_row_pos[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m]
	// so all the unchecked accesses were valid and non-overlapping, which means the entire array is
	// filled.
	debug_assert!(cur_row_pos.as_ref() == &new_col_ptr[1..]);

	// SAFETY:
	// 0. new_col_ptr is non-decreasing
	// 1. all written row indices are less than n
	unsafe { SparseColMatMut::new(SymbolicSparseColMatRef::new_unchecked(N, M, new_col_ptr, None, new_row_idx), new_val) }
}
579
#[cfg(test)]
mod tests {
	use super::*;
	use crate::assert;
	use crate::stats::prelude::*;
	use dyn_stack::MemBuffer;

	// checks `transpose` and `transpose_dedup` against hand-computed targets,
	// on a 5×3 matrix that contains duplicate entries
	#[test]
	fn test_transpose() {
		let nrows = 5;
		let ncols = 3;
		// unsorted input with duplicates: (0, 0) appears twice in column 0,
		// (1, 1) twice in column 1
		let A = SparseColMatRef::new(
			SymbolicSparseColMatRef::new_unsorted_checked(
				nrows,
				ncols,
				&[0usize, 4, 8, 11],
				None,
				&[
					0, 0, 2, 4, //
					2, 1, 1, 0, //
					0, 1, 3,
				],
			),
			&[
				1.0, 2.0, 3.0, 4.0, //
				11.0, 12.0, 13.0, 14.0, //
				21.0, 22.0, 23.0,
			],
		);
		let nnz = A.compute_nnz();

		let new_col_ptr = &mut *vec![0usize; nrows + 1];
		let new_row_idx = &mut *vec![0usize; nnz];
		let new_val = &mut *vec![0.0; nnz];
		{
			let out = transpose(
				new_val,
				new_col_ptr,
				new_row_idx,
				A,
				MemStack::new(&mut MemBuffer::new(transpose_scratch::<usize>(nrows, ncols))),
			)
			.into_const();

			// plain transpose keeps duplicates: 11 entries, sorted within each column
			let target = SparseColMatRef::new(
				SymbolicSparseColMatRef::new_unsorted_checked(
					ncols,
					nrows,
					&[0usize, 4, 7, 9, 10, 11],
					None,
					&[
						0, 0, 1, 2, //
						1, 1, 2, //
						0, 1, //
						2, //
						0,
					],
				),
				&[
					1.0, 2.0, 14.0, 21.0, //
					12.0, 13.0, 22.0, //
					3.0, 11.0, //
					23.0, //
					4.0,
				],
			);

			assert!(all(
				out.col_ptr() == target.col_ptr(),
				out.row_idx() == target.row_idx(),
				out.val() == target.val()
			));
		}

		{
			let out = transpose_dedup(
				new_val,
				new_col_ptr,
				new_row_idx,
				A,
				MemStack::new(&mut MemBuffer::new(transpose_dedup_scratch::<usize>(nrows, ncols))),
			)
			.into_const();

			// dedup merges the duplicates: 1.0 + 2.0 = 3.0 and 12.0 + 13.0 = 25.0
			let target = SparseColMatRef::new(
				SymbolicSparseColMatRef::new_unsorted_checked(
					ncols,
					nrows,
					&[0usize, 3, 5, 7, 8, 9],
					None,
					&[
						0, 1, 2, //
						1, 2, //
						0, 1, //
						2, //
						0,
					],
				),
				&[
					3.0, 14.0, 21.0, //
					25.0, 22.0, //
					3.0, 11.0, //
					23.0, //
					4.0,
				],
			);

			assert!(all(
				out.col_ptr() == target.col_ptr(),
				out.row_idx() == target.row_idx(),
				out.val() == target.val()
			));
		}
	}

	// checks the three self-adjoint permutation variants, for all four
	// (in_side, out_side) combinations, against a dense reference computation
	#[test]
	fn test_permute_self_adjoint() {
		let n = 5;
		let rng = &mut StdRng::seed_from_u64(0);
		let diag_rng = &mut StdRng::seed_from_u64(1);

		// diagonal entries get a zero imaginary part so the dense
		// mirror-and-conjugate reference below is consistent
		let mut rand = || ComplexDistribution::new(StandardNormal, StandardNormal).rand::<c64>(rng);
		let mut rand_diag = || c64::new(StandardNormal.rand(diag_rng), 0.0);

		let val = &[
			rand_diag(),
			rand_diag(),
			rand(),
			rand(),
			//
			rand(),
			rand_diag(),
			rand_diag(),
			rand(),
			//
			rand(),
			rand(),
			rand(),
			//
			rand(),
			rand_diag(),
			rand(),
			//
			rand_diag(),
			rand(),
			rand(),
		];

		let A = SparseColMatRef::new(
			SymbolicSparseColMatRef::new_unsorted_checked(
				n,
				n,
				&[0usize, 4, 8, 11, 14, 17],
				None,
				&[
					0, 0, 2, 4, //
					2, 1, 1, 0, //
					0, 1, 3, //
					2, 3, 4, //
					4, 3, 2, //
				],
			),
			val,
		);
		let nnz = A.compute_nnz();

		// build the inverse permutation by hand
		let perm_fwd = &mut *vec![0, 4, 1, 3, 2usize];
		let perm_bwd = &mut *vec![0; 5];
		for i in 0..n {
			perm_bwd[perm_fwd[i]] = i;
		}

		let perm = PermRef::new_checked(perm_fwd, perm_bwd, n);

		let new_col_ptr = &mut *vec![0usize; n + 1];
		let new_row_idx = &mut *vec![0usize; nnz];
		let new_val = &mut *vec![c64::ZERO; nnz];

		for f in [permute_self_adjoint_to_unsorted, permute_self_adjoint, permute_dedup_self_adjoint] {
			for (in_side, out_side) in [
				(Side::Lower, Side::Lower),
				(Side::Lower, Side::Upper),
				(Side::Upper, Side::Lower),
				(Side::Upper, Side::Upper),
			] {
				let mut out = f(
					new_val,
					new_col_ptr,
					new_row_idx,
					A,
					perm,
					in_side,
					out_side,
					MemStack::new(&mut MemBuffer::new(permute_self_adjoint_scratch::<usize>(n))),
				)
				.to_dense();

				let mut A = A.to_dense();

				// reconstruct the full dense matrix from the `in_side` triangle:
				// zero the other triangle, then mirror with conjugation
				match in_side {
					Side::Lower => {
						z!(&mut A).for_each_triangular_upper(linalg::zip::Diag::Skip, |uz!(x)| *x = c64::ZERO);
						for j in 0..n {
							for i in 0..j {
								A[(i, j)] = A[(j, i)].conj();
							}
						}
					},
					Side::Upper => {
						z!(&mut A).for_each_triangular_lower(linalg::zip::Diag::Skip, |uz!(x)| *x = c64::ZERO);
						for j in 0..n {
							for i in j + 1..n {
								A[(i, j)] = A[(j, i)].conj();
							}
						}
					},
				}

				// the output only stores the `out_side` triangle: mirror it before comparing
				match out_side {
					Side::Lower => {
						for j in 0..n {
							for i in 0..j {
								out[(i, j)] = out[(j, i)].conj();
							}
						}
					},
					Side::Upper => {
						for j in 0..n {
							for i in j + 1..n {
								out[(i, j)] = out[(j, i)].conj();
							}
						}
					},
				}

				assert!(out == perm * &A * perm.inverse());
			}
		}
	}
}