1use crate::Idx;
2use crate::internal_prelude::*;
3use dyn_stack::StackReq;
4use linalg::zip::{Last, Zip};
5use reborrow::*;
6
7#[track_caller]
38#[inline]
39pub fn swap_cols<N: Shape, T>(a: ColMut<'_, T, N>, b: ColMut<'_, T, N>) {
40 fn swap<T>() -> impl FnMut(Zip<&mut T, Last<&mut T>>) {
41 |unzip!(a, b)| core::mem::swap(a, b)
42 }
43
44 zip!(a, b).for_each(swap::<T>());
45}
46
47#[track_caller]
78#[inline]
79pub fn swap_rows<N: Shape, T>(a: RowMut<'_, T, N>, b: RowMut<'_, T, N>) {
80 swap_cols(a.transpose_mut(), b.transpose_mut())
81}
82
83#[track_caller]
113#[inline]
114pub fn swap_rows_idx<M: Shape, N: Shape, T>(mat: MatMut<'_, T, M, N>, a: Idx<M>, b: Idx<M>) {
115 if a != b {
116 let (a, b) = mat.two_rows_mut(a, b);
117 swap_rows(a, b);
118 }
119}
120
121#[track_caller]
151#[inline]
152pub fn swap_cols_idx<M: Shape, N: Shape, T>(mat: MatMut<'_, T, M, N>, a: Idx<N>, b: Idx<N>) {
153 if a != b {
154 let (a, b) = mat.two_cols_mut(a, b);
155 swap_cols(a, b);
156 }
157}
158
159mod permown;
160mod permref;
161
/// Permutation whose storage is the `Own` type from the `permown` submodule, wrapped
/// in [`generic::Perm`].
///
/// `I` is the index type used by the permutation arrays. `N` is the dimension type and
/// defaults to `usize`.
pub type Perm<I, N = usize> = generic::Perm<Own<I, N>>;

/// Permutation view whose storage is the `Ref<'a, ..>` type from the `permref`
/// submodule, wrapped in [`generic::Perm`].
///
/// It carries the lifetime `'a` of the data it refers to. This is the type the
/// `permute_*` functions in this module accept.
pub type PermRef<'a, I, N = usize> = generic::Perm<Ref<'a, I, N>>;
167
168pub use permown::Own;
169pub use permref::Ref;
170
171pub mod generic {
173 use core::fmt::Debug;
174 use reborrow::*;
175
176 #[derive(Copy, Clone)]
178 #[repr(transparent)]
179 pub struct Perm<Inner>(pub Inner);
180
181 impl<Inner: Debug> Debug for Perm<Inner> {
182 #[inline(always)]
183 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
184 self.0.fmt(f)
185 }
186 }
187
188 impl<Inner> Perm<Inner> {
189 #[inline(always)]
191 pub fn from_inner_ref(inner: &Inner) -> &Self {
192 unsafe { &*(inner as *const Inner as *const Self) }
193 }
194
195 #[inline(always)]
197 pub fn from_inner_mut(inner: &mut Inner) -> &mut Self {
198 unsafe { &mut *(inner as *mut Inner as *mut Self) }
199 }
200 }
201
202 impl<Inner> core::ops::Deref for Perm<Inner> {
203 type Target = Inner;
204
205 #[inline(always)]
206 fn deref(&self) -> &Self::Target {
207 &self.0
208 }
209 }
210
211 impl<Inner> core::ops::DerefMut for Perm<Inner> {
212 #[inline(always)]
213 fn deref_mut(&mut self) -> &mut Self::Target {
214 &mut self.0
215 }
216 }
217
218 impl<'short, Inner: Reborrow<'short>> Reborrow<'short> for Perm<Inner> {
219 type Target = Perm<Inner::Target>;
220
221 #[inline(always)]
222 fn rb(&'short self) -> Self::Target {
223 Perm(self.0.rb())
224 }
225 }
226
227 impl<'short, Inner: ReborrowMut<'short>> ReborrowMut<'short> for Perm<Inner> {
228 type Target = Perm<Inner::Target>;
229
230 #[inline(always)]
231 fn rb_mut(&'short mut self) -> Self::Target {
232 Perm(self.0.rb_mut())
233 }
234 }
235
236 impl<Inner: IntoConst> IntoConst for Perm<Inner> {
237 type Target = Perm<Inner::Target>;
238
239 #[inline(always)]
240 fn into_const(self) -> Self::Target {
241 Perm(self.0.into_const())
242 }
243 }
244}
245
246use self::linalg::temp_mat_scratch;
247
248#[inline]
256#[track_caller]
257pub fn permute_cols<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
258 Assert!(all(
259 src.nrows() == dst.nrows(),
260 src.ncols() == dst.ncols(),
261 perm_indices.arrays().0.len() == src.ncols(),
262 ));
263
264 permute_rows(dst.transpose_mut(), src.transpose(), perm_indices.canonicalized());
265}
266
267#[inline]
275#[track_caller]
276pub fn permute_rows<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
277 #[track_caller]
278 #[math]
279 fn implementation<I: Index, T: ComplexField>(dst: MatMut<'_, T>, src: MatRef<'_, T>, perm_indices: PermRef<'_, I>) {
280 Assert!(all(
281 src.nrows() == dst.nrows(),
282 src.ncols() == dst.ncols(),
283 perm_indices.len() == src.nrows(),
284 ));
285
286 with_dim!(m, src.nrows());
287 with_dim!(n, src.ncols());
288 let mut dst = dst.as_shape_mut(m, n);
289 let src = src.as_shape(m, n);
290 let perm = perm_indices.as_shape(m).bound_arrays().0;
291
292 if dst.rb().row_stride().unsigned_abs() < dst.rb().col_stride().unsigned_abs() {
293 for j in n.indices() {
294 for i in m.indices() {
295 dst[(i, j)] = copy(src[(perm[i].zx(), j)]);
296 }
297 }
298 } else {
299 for i in m.indices() {
300 let src_i = src.row(perm[i].zx());
301 let mut dst_i = dst.rb_mut().row_mut(i);
302
303 dst_i.copy_from(src_i);
304 }
305 }
306 }
307
308 implementation(dst, src, perm_indices.canonicalized())
309}
310
311pub fn permute_rows_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
314 temp_mat_scratch::<T>(nrows, ncols)
315}
316
317pub fn permute_cols_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
320 temp_mat_scratch::<T>(nrows, ncols)
321}
322
323#[inline]
330#[track_caller]
331pub fn permute_rows_in_place<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
332 #[inline]
333 #[track_caller]
334 fn implementation<T: ComplexField, I: Index>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
335 let mut matrix = matrix;
336 let (mut tmp, _) = unsafe { temp_mat_uninit(matrix.nrows(), matrix.ncols(), stack) };
337 let mut tmp = tmp.as_mat_mut();
338 tmp.copy_from(matrix.rb());
339 permute_rows(matrix.rb_mut(), tmp.rb(), perm_indices);
340 }
341
342 implementation(matrix, perm_indices.canonicalized(), stack)
343}
344
345#[inline]
352#[track_caller]
353pub fn permute_cols_in_place<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
354 #[inline]
355 #[track_caller]
356 fn implementation<I: Index, T: ComplexField>(matrix: MatMut<'_, T>, perm_indices: PermRef<'_, I>, stack: &mut MemStack) {
357 let mut matrix = matrix;
358 let (mut tmp, _) = unsafe { temp_mat_uninit(matrix.nrows(), matrix.ncols(), stack) };
359 let mut tmp = tmp.as_mat_mut();
360 tmp.copy_from(matrix.rb());
361 permute_cols(matrix.rb_mut(), tmp.rb(), perm_indices);
362 }
363
364 implementation(matrix, perm_indices.canonicalized(), stack)
365}