faer/perm/
mod.rs

1use crate::Idx;
2use crate::internal_prelude::*;
3use dyn_stack::StackReq;
4use linalg::zip::{Last, Zip};
5use reborrow::*;
6
7/// swaps the values in the columns `a` and `b`
8///
9/// # panics
10///
11/// panics if `a` and `b` don't have the same number of columns
12///
13/// # example
14///
15/// ```
16/// use faer::{mat, perm};
17///
18/// let mut m = mat![
19/// 	[1.0, 2.0, 3.0], //
20/// 	[4.0, 5.0, 6.0],
21/// 	[7.0, 8.0, 9.0],
22/// 	[10.0, 14.0, 12.0],
23/// ];
24///
25/// let (a, b) = m.two_cols_mut(0, 2);
26/// perm::swap_cols(a, b);
27///
28/// let swapped = mat![
29/// 	[3.0, 2.0, 1.0], //
30/// 	[6.0, 5.0, 4.0],
31/// 	[9.0, 8.0, 7.0],
32/// 	[12.0, 14.0, 10.0],
33/// ];
34///
35/// assert_eq!(m, swapped);
36/// ```
37#[track_caller]
38#[inline]
39pub fn swap_cols<N: Shape, T>(a: ColMut<'_, T, N>, b: ColMut<'_, T, N>) {
40	fn swap<T>() -> impl FnMut(Zip<&mut T, Last<&mut T>>) {
41		|unzip!(a, b)| core::mem::swap(a, b)
42	}
43
44	zip!(a, b).for_each(swap::<T>());
45}
46
47/// swaps the values in the rows `a` and `b`
48///
49/// # panics
50///
51/// panics if `a` and `b` don't have the same number of columns
52///
53/// # example
54///
55/// ```
56/// use faer::{mat, perm};
57///
58/// let mut m = mat![
59/// 	[1.0, 2.0, 3.0], //
60/// 	[4.0, 5.0, 6.0],
61/// 	[7.0, 8.0, 9.0],
62/// 	[10.0, 14.0, 12.0],
63/// ];
64///
65/// let (a, b) = m.two_rows_mut(0, 2);
66/// perm::swap_rows(a, b);
67///
68/// let swapped = mat![
69/// 	[7.0, 8.0, 9.0], //
70/// 	[4.0, 5.0, 6.0],
71/// 	[1.0, 2.0, 3.0],
72/// 	[10.0, 14.0, 12.0],
73/// ];
74///
75/// assert_eq!(m, swapped);
76/// ```
77#[track_caller]
78#[inline]
79pub fn swap_rows<N: Shape, T>(a: RowMut<'_, T, N>, b: RowMut<'_, T, N>) {
80	swap_cols(a.transpose_mut(), b.transpose_mut())
81}
82
83/// swaps the two rows at indices `a` and `b` in the given matrix
84///
85/// # panics
86///
87/// panics if either `a` or `b` is out of bounds
88///
89/// # example
90///
91/// ```
92/// use faer::{mat, perm};
93///
94/// let mut m = mat![
95/// 	[1.0, 2.0, 3.0], //
96/// 	[4.0, 5.0, 6.0],
97/// 	[7.0, 8.0, 9.0],
98/// 	[10.0, 14.0, 12.0],
99/// ];
100///
101/// perm::swap_rows_idx(m.as_mut(), 0, 2);
102///
103/// let swapped = mat![
104/// 	[7.0, 8.0, 9.0], //
105/// 	[4.0, 5.0, 6.0],
106/// 	[1.0, 2.0, 3.0],
107/// 	[10.0, 14.0, 12.0],
108/// ];
109///
110/// assert_eq!(m, swapped);
111/// ```
112#[track_caller]
113#[inline]
114pub fn swap_rows_idx<M: Shape, N: Shape, T>(mat: MatMut<'_, T, M, N>, a: Idx<M>, b: Idx<M>) {
115	if a != b {
116		let (a, b) = mat.two_rows_mut(a, b);
117		swap_rows(a, b);
118	}
119}
120
121/// swaps the two columns at indices `a` and `b` in the given matrix
122///
123/// # panics
124///
125/// panics if either `a` or `b` is out of bounds
126///
127/// # example
128///
129/// ```
130/// use faer::{mat, perm};
131///
132/// let mut m = mat![
133/// 	[1.0, 2.0, 3.0], //
134/// 	[4.0, 5.0, 6.0],
135/// 	[7.0, 8.0, 9.0],
136/// 	[10.0, 14.0, 12.0],
137/// ];
138///
139/// perm::swap_cols_idx(m.as_mut(), 0, 2);
140///
141/// let swapped = mat![
142/// 	[3.0, 2.0, 1.0], //
143/// 	[6.0, 5.0, 4.0],
144/// 	[9.0, 8.0, 7.0],
145/// 	[12.0, 14.0, 10.0],
146/// ];
147///
148/// assert_eq!(m, swapped);
149/// ```
150#[track_caller]
151#[inline]
152pub fn swap_cols_idx<M: Shape, N: Shape, T>(mat: MatMut<'_, T, M, N>, a: Idx<N>, b: Idx<N>) {
153	if a != b {
154		let (a, b) = mat.two_cols_mut(a, b);
155		swap_cols(a, b);
156	}
157}
158
159mod permown;
160mod permref;
161
162/// permutation matrix
163pub type Perm<I, N = usize> = generic::Perm<Own<I, N>>;
164
165/// immutable permutation matrix view
166pub type PermRef<'a, I, N = usize> = generic::Perm<Ref<'a, I, N>>;
167
168pub use permown::Own;
169pub use permref::Ref;
170
171/// generic `Perm` wrapper
172pub mod generic {
173	use core::fmt::Debug;
174	use reborrow::*;
175
176	/// generic `Perm` wrapper
177	#[derive(Copy, Clone)]
178	#[repr(transparent)]
179	pub struct Perm<Inner>(pub Inner);
180
181	impl<Inner: Debug> Debug for Perm<Inner> {
182		#[inline(always)]
183		fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
184			self.0.fmt(f)
185		}
186	}
187
188	impl<Inner> Perm<Inner> {
189		/// wrap by reference
190		#[inline(always)]
191		pub fn from_inner_ref(inner: &Inner) -> &Self {
192			unsafe { &*(inner as *const Inner as *const Self) }
193		}
194
195		/// wrap by mutable reference
196		#[inline(always)]
197		pub fn from_inner_mut(inner: &mut Inner) -> &mut Self {
198			unsafe { &mut *(inner as *mut Inner as *mut Self) }
199		}
200	}
201
202	impl<Inner> core::ops::Deref for Perm<Inner> {
203		type Target = Inner;
204
205		#[inline(always)]
206		fn deref(&self) -> &Self::Target {
207			&self.0
208		}
209	}
210
211	impl<Inner> core::ops::DerefMut for Perm<Inner> {
212		#[inline(always)]
213		fn deref_mut(&mut self) -> &mut Self::Target {
214			&mut self.0
215		}
216	}
217
218	impl<'short, Inner: Reborrow<'short>> Reborrow<'short> for Perm<Inner> {
219		type Target = Perm<Inner::Target>;
220
221		#[inline(always)]
222		fn rb(&'short self) -> Self::Target {
223			Perm(self.0.rb())
224		}
225	}
226
227	impl<'short, Inner: ReborrowMut<'short>> ReborrowMut<'short> for Perm<Inner> {
228		type Target = Perm<Inner::Target>;
229
230		#[inline(always)]
231		fn rb_mut(&'short mut self) -> Self::Target {
232			Perm(self.0.rb_mut())
233		}
234	}
235
236	impl<Inner: IntoConst> IntoConst for Perm<Inner> {
237		type Target = Perm<Inner::Target>;
238
239		#[inline(always)]
240		fn into_const(self) -> Self::Target {
241			Perm(self.0.into_const())
242		}
243	}
244}
245
246use self::linalg::temp_mat_scratch;
247
248/// computes a permutation of the columns of the source matrix using the given permutation, and
249/// stores the result in the destination matrix
250///
251/// # panics
252///
253/// - panics if the matrices do not have the same shape
254/// - panics if the size of the permutation doesn't match the number of columns of the matrices
255#[inline]
256#[track_caller]
257pub fn permute_cols<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
258	Assert!(all(
259		src.nrows() == dst.nrows(),
260		src.ncols() == dst.ncols(),
261		perm_indices.arrays().0.len() == src.ncols(),
262	));
263
264	permute_rows(dst.transpose_mut(), src.transpose(), perm_indices.canonicalized());
265}
266
267/// computes a permutation of the rows of the source matrix using the given permutation, and
268/// stores the result in the destination matrix
269///
270/// # panics
271///
272/// - panics if the matrices do not have the same shape
273/// - panics if the size of the permutation doesn't match the number of rows of the matrices
274#[inline]
275#[track_caller]
276pub fn permute_rows<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
277	#[track_caller]
278	#[math]
279	fn implementation<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
280		Assert!(all(
281			src.nrows() == dst.nrows(),
282			src.ncols() == dst.ncols(),
283			perm_indices.len() == src.nrows(),
284		));
285
286		with_dim!(m, src.nrows());
287		with_dim!(n, src.ncols());
288		let mut dst = dst.as_shape_mut(m, n);
289		let src = src.as_shape(m, n);
290		let perm = perm_indices.as_shape(m).bound_arrays().0;
291
292		if dst.rb().row_stride().unsigned_abs() < dst.rb().col_stride().unsigned_abs() {
293			for j in n.indices() {
294				for i in m.indices() {
295					dst[(i, j)] = copy(src[(perm[i].zx(), j)]);
296				}
297			}
298		} else {
299			for i in m.indices() {
300				let src_i = src.row(perm[i].zx());
301				let mut dst_i = dst.rb_mut().row_mut(i);
302
303				dst_i.copy_from(src_i);
304			}
305		}
306	}
307
308	implementation(dst, src, perm_indices.canonicalized())
309}
310
311/// computes the size and alignment of required workspace for applying a row permutation to a
312/// matrix in place
313pub fn permute_rows_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
314	temp_mat_scratch::<T>(nrows, ncols)
315}
316
317/// computes the size and alignment of required workspace for applying a column permutation to a
318/// matrix in place
319pub fn permute_cols_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
320	temp_mat_scratch::<T>(nrows, ncols)
321}
322
323/// computes a permutation of the rows of the matrix using the given permutation, and
324/// stores the result in the same matrix
325///
326/// # panics
327///
328/// - panics if the size of the permutation doesn't match the number of rows of the matrix
329#[inline]
330#[track_caller]
331pub fn permute_rows_in_place<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
332	#[inline]
333	#[track_caller]
334	fn implementation<T: ComplexField, I: Index>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
335		let mut matrix = matrix;
336		let (mut tmp, _) = unsafe { temp_mat_uninit(matrix.nrows(), matrix.ncols(), stack) };
337		let mut tmp = tmp.as_mat_mut();
338		tmp.copy_from(matrix.rb());
339		permute_rows(matrix.rb_mut(), tmp.rb(), perm_indices);
340	}
341
342	implementation(matrix, perm_indices.canonicalized(), stack)
343}
344
345/// computes a permutation of the columns of the matrix using the given permutation, and
346/// stores the result in the same matrix.
347///
348/// # panics
349///
350/// - panics if the size of the permutation doesn't match the number of columns of the matrix
351#[inline]
352#[track_caller]
353pub fn permute_cols_in_place<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
354	#[inline]
355	#[track_caller]
356	fn implementation<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
357		let mut matrix = matrix;
358		let (mut tmp, _) = unsafe { temp_mat_uninit(matrix.nrows(), matrix.ncols(), stack) };
359		let mut tmp = tmp.as_mat_mut();
360		tmp.copy_from(matrix.rb());
361		permute_cols(matrix.rb_mut(), tmp.rb(), perm_indices);
362	}
363
364	implementation(matrix, perm_indices.canonicalized(), stack)
365}