1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use equator::{assert, debug_assert};
5
/// Storage behind a mutable row view (`RowMut`).
///
/// A row is represented as the transpose of a column: `trans` is a mutable
/// column view whose row count and row stride are this row's column count and
/// column stride.
pub struct Mut<'a, T, Cols = usize, CStride = isize> {
    /// Underlying column view; `RowMut::ncols` reads its `nrows()` and
    /// `RowMut::col_stride` reads its `row_stride()`.
    pub(crate) trans: ColMut<'a, T, Cols, CStride>,
}
10
11impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Mut<'_, T, Rows, RStride> {
12 type Target = Ref<'short, T, Rows, RStride>;
13
14 #[inline]
15 fn rb(&'short self) -> Self::Target {
16 Ref { trans: self.trans.rb() }
17 }
18}
19impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Mut<'_, T, Rows, RStride> {
20 type Target = Mut<'short, T, Rows, RStride>;
21
22 #[inline]
23 fn rb_mut(&'short mut self) -> Self::Target {
24 Mut { trans: self.trans.rb_mut() }
25 }
26}
27impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Mut<'a, T, Rows, RStride> {
28 type Target = Ref<'a, T, Rows, RStride>;
29
30 #[inline]
31 fn into_const(self) -> Self::Target {
32 Ref {
33 trans: self.trans.into_const(),
34 }
35 }
36}
37
38impl<'a, T> RowMut<'a, T> {
39 #[inline]
41 pub fn from_mut(value: &'a mut T) -> Self {
42 unsafe { RowMut::from_raw_parts_mut(value as *mut T, 1, 1) }
43 }
44
45 #[inline]
48 pub fn from_slice_mut(slice: &'a mut [T]) -> Self {
49 let len = slice.len();
50 unsafe { Self::from_raw_parts_mut(slice.as_mut_ptr(), len, 1) }
51 }
52}
53
54impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride> {
55 #[inline(always)]
61 #[track_caller]
62 pub const unsafe fn from_raw_parts_mut(ptr: *mut T, ncols: Cols, col_stride: CStride) -> Self {
63 Self {
64 0: Mut {
65 trans: ColMut::from_raw_parts_mut(ptr, ncols, col_stride),
66 },
67 }
68 }
69
70 #[inline(always)]
72 pub fn as_ptr(&self) -> *const T {
73 self.trans.as_ptr()
74 }
75
76 #[inline(always)]
78 pub fn nrows(&self) -> usize {
79 1
80 }
81
82 #[inline(always)]
84 pub fn ncols(&self) -> Cols {
85 self.trans.nrows()
86 }
87
88 #[inline(always)]
90 pub fn shape(&self) -> (usize, Cols) {
91 (self.nrows(), self.ncols())
92 }
93
94 #[inline(always)]
96 pub fn col_stride(&self) -> CStride {
97 self.trans.row_stride()
98 }
99
100 #[inline(always)]
102 pub fn ptr_at(&self, col: IdxInc<Cols>) -> *const T {
103 self.trans.ptr_at(col)
104 }
105
106 #[inline(always)]
113 #[track_caller]
114 pub unsafe fn ptr_inbounds_at(&self, col: Idx<Cols>) -> *const T {
115 debug_assert!(all(col < self.ncols()));
116 self.trans.ptr_inbounds_at(col)
117 }
118
119 #[inline]
120 #[track_caller]
121 pub fn split_at_col(self, col: IdxInc<Cols>) -> (RowRef<'a, T, usize, CStride>, RowRef<'a, T, usize, CStride>) {
123 self.into_const().split_at_col(col)
124 }
125
126 #[inline(always)]
127 pub fn transpose(self) -> ColRef<'a, T, Cols, CStride> {
129 self.into_const().transpose()
130 }
131
132 #[inline(always)]
133 pub fn conjugate(self) -> RowRef<'a, T::Conj, Cols, CStride>
135 where
136 T: Conjugate,
137 {
138 self.into_const().conjugate()
139 }
140
141 #[inline(always)]
142 pub fn canonical(self) -> RowRef<'a, T::Canonical, Cols, CStride>
144 where
145 T: Conjugate,
146 {
147 self.into_const().canonical()
148 }
149
150 #[inline(always)]
151 pub fn adjoint(self) -> ColRef<'a, T::Conj, Cols, CStride>
153 where
154 T: Conjugate,
155 {
156 self.into_const().adjoint()
157 }
158
159 #[track_caller]
160 #[inline(always)]
161 pub fn get<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
163 where
164 RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
165 {
166 <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self.into_const(), col)
167 }
168
169 #[track_caller]
170 #[inline(always)]
171 pub unsafe fn get_unchecked<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
173 where
174 RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
175 {
176 unsafe { <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self.into_const(), col) }
177 }
178
179 #[inline]
180 pub fn reverse_cols(self) -> RowRef<'a, T, Cols, CStride::Rev> {
182 self.into_const().reverse_cols()
183 }
184
185 #[inline]
186 pub fn subcols<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowRef<'a, T, V, CStride> {
188 self.into_const().subcols(col_start, ncols)
189 }
190
191 #[inline]
192 #[track_caller]
193 pub fn as_col_shape<V: Shape>(self, ncols: V) -> RowRef<'a, T, V, CStride> {
195 self.into_const().as_col_shape(ncols)
196 }
197
198 #[inline]
199 pub fn as_dyn_cols(self) -> RowRef<'a, T, usize, CStride> {
201 self.into_const().as_dyn_cols()
202 }
203
204 #[inline]
205 pub fn as_dyn_stride(self) -> RowRef<'a, T, Cols, isize> {
207 self.into_const().as_dyn_stride()
208 }
209
210 #[inline]
211 pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
213 where
214 Cols: 'a,
215 {
216 self.0.trans.iter()
217 }
218
219 #[inline]
220 #[cfg(feature = "rayon")]
221 pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
223 where
224 T: Sync,
225 Cols: 'a,
226 {
227 self.0.trans.par_iter()
228 }
229
230 #[inline]
231 pub fn try_as_row_major(self) -> Option<RowRef<'a, T, Cols, ContiguousFwd>> {
233 self.into_const().try_as_row_major()
234 }
235
236 #[inline]
237 pub fn as_diagonal(self) -> DiagRef<'a, T, Cols, CStride> {
239 DiagRef {
240 0: crate::diag::Ref {
241 inner: self.0.trans.into_const(),
242 },
243 }
244 }
245
246 #[inline(always)]
247 #[doc(hidden)]
248 pub unsafe fn const_cast(self) -> RowMut<'a, T, Cols, CStride> {
249 RowMut {
250 0: Mut {
251 trans: self.0.trans.const_cast(),
252 },
253 }
254 }
255
256 #[inline]
257 pub fn as_mat(self) -> MatRef<'a, T, usize, Cols, isize, CStride> {
259 self.into_const().as_mat()
260 }
261
262 #[inline]
263 pub fn as_mat_mut(self) -> MatMut<'a, T, usize, Cols, isize, CStride> {
265 unsafe { self.into_const().as_mat().const_cast() }
266 }
267}
268
269impl<T, Cols: Shape, CStride: Stride, Inner: for<'short> ReborrowMut<'short, Target = Mut<'short, T, Cols, CStride>>> generic::Row<Inner> {
270 #[inline]
271 pub fn as_mut(&mut self) -> RowMut<'_, T, Cols, CStride> {
273 self.rb_mut()
274 }
275
276 #[inline]
277 pub fn copy_from<RhsT: Conjugate<Canonical = T>>(&mut self, other: impl AsRowRef<T = RhsT, Cols = Cols>)
279 where
280 T: ComplexField,
281 {
282 self.rb_mut().transpose_mut().copy_from(other.as_row_ref().transpose());
283 }
284
285 pub fn fill(&mut self, value: T)
287 where
288 T: Clone,
289 {
290 self.rb_mut().transpose_mut().fill(value)
291 }
292}
293
294impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride> {
295 #[inline(always)]
296 pub fn as_ptr_mut(&self) -> *mut T {
298 self.trans.as_ptr_mut()
299 }
300
301 #[inline(always)]
302 pub fn ptr_at_mut(&self, col: IdxInc<Cols>) -> *mut T {
304 self.trans.ptr_at_mut(col)
305 }
306
307 #[inline(always)]
308 #[track_caller]
309 pub unsafe fn ptr_inbounds_at_mut(&self, col: Idx<Cols>) -> *mut T {
311 debug_assert!(all(col < self.ncols()));
312 self.trans.ptr_inbounds_at_mut(col)
313 }
314
315 #[inline]
316 #[track_caller]
317 pub fn split_at_col_mut(self, col: IdxInc<Cols>) -> (RowMut<'a, T, usize, CStride>, RowMut<'a, T, usize, CStride>) {
319 let (a, b) = self.into_const().split_at_col(col);
320 unsafe { (a.const_cast(), b.const_cast()) }
321 }
322
323 #[inline(always)]
324 pub fn transpose_mut(self) -> ColMut<'a, T, Cols, CStride> {
326 self.0.trans
327 }
328
329 #[inline(always)]
330 pub fn conjugate_mut(self) -> RowMut<'a, T::Conj, Cols, CStride>
332 where
333 T: Conjugate,
334 {
335 unsafe { self.into_const().conjugate().const_cast() }
336 }
337
338 #[inline(always)]
339 pub fn canonical_mut(self) -> RowMut<'a, T::Canonical, Cols, CStride>
341 where
342 T: Conjugate,
343 {
344 unsafe { self.into_const().canonical().const_cast() }
345 }
346
347 #[inline(always)]
348 pub fn adjoint_mut(self) -> ColMut<'a, T::Conj, Cols, CStride>
350 where
351 T: Conjugate,
352 {
353 unsafe { self.into_const().adjoint().const_cast() }
354 }
355
356 #[inline(always)]
357 #[track_caller]
358 pub(crate) fn at_mut(self, col: Idx<Cols>) -> &'a mut T {
359 assert!(all(col < self.ncols()));
360 unsafe { self.at_mut_unchecked(col) }
361 }
362
363 #[inline(always)]
364 #[track_caller]
365 pub(crate) unsafe fn at_mut_unchecked(self, col: Idx<Cols>) -> &'a mut T {
366 &mut *self.ptr_inbounds_at_mut(col)
367 }
368
369 #[track_caller]
370 #[inline(always)]
371 pub fn get_mut<ColRange>(self, col: ColRange) -> <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
373 where
374 RowMut<'a, T, Cols, CStride>: RowIndex<ColRange>,
375 {
376 <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self, col)
377 }
378
379 #[track_caller]
380 #[inline(always)]
381 pub unsafe fn get_mut_unchecked<ColRange>(self, col: ColRange) -> <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
383 where
384 RowMut<'a, T, Cols, CStride>: RowIndex<ColRange>,
385 {
386 unsafe { <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self, col) }
387 }
388
389 #[inline]
390 pub fn reverse_cols_mut(self) -> RowMut<'a, T, Cols, CStride::Rev> {
392 unsafe { self.into_const().reverse_cols().const_cast() }
393 }
394
395 #[inline]
396 pub fn subcols_mut<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowMut<'a, T, V, CStride> {
398 unsafe { self.into_const().subcols(col_start, ncols).const_cast() }
399 }
400
401 #[inline]
402 #[track_caller]
403 pub fn as_col_shape_mut<V: Shape>(self, ncols: V) -> RowMut<'a, T, V, CStride> {
405 unsafe { self.into_const().as_col_shape(ncols).const_cast() }
406 }
407
408 #[inline]
409 pub fn as_dyn_cols_mut(self) -> RowMut<'a, T, usize, CStride> {
411 unsafe { self.into_const().as_dyn_cols().const_cast() }
412 }
413
414 #[inline]
415 pub fn as_dyn_stride_mut(self) -> RowMut<'a, T, Cols, isize> {
417 unsafe { self.into_const().as_dyn_stride().const_cast() }
418 }
419
420 #[inline]
421 pub fn iter_mut(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a mut T>
423 where
424 Cols: 'a,
425 {
426 self.0.trans.iter_mut()
427 }
428
429 #[inline]
430 #[cfg(feature = "rayon")]
431 pub fn par_iter_mut(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a mut T>
433 where
434 T: Send,
435 Cols: 'a,
436 {
437 self.0.trans.par_iter_mut()
438 }
439
440 #[inline]
441 #[track_caller]
442 #[cfg(feature = "rayon")]
443 pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, T, usize, CStride>>
445 where
446 T: Sync,
447 Cols: 'a,
448 {
449 self.into_const().par_partition(count)
450 }
451
452 #[inline]
453 #[track_caller]
454 #[cfg(feature = "rayon")]
455 pub fn par_partition_mut(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowMut<'a, T, usize, CStride>>
457 where
458 T: Send,
459 Cols: 'a,
460 {
461 use crate::mat::matmut::SyncCell;
462 use rayon::prelude::*;
463 unsafe {
464 self.as_type::<SyncCell<T>>()
465 .into_const()
466 .par_partition(count)
467 .map(|col| col.const_cast().as_type::<T>())
468 }
469 }
470
471 pub(crate) unsafe fn as_type<U>(self) -> RowMut<'a, U, Cols, CStride> {
472 RowMut::from_raw_parts_mut(self.as_ptr_mut() as *mut U, self.ncols(), self.col_stride())
473 }
474
475 #[inline]
476 pub fn try_as_row_major_mut(self) -> Option<RowMut<'a, T, Cols, ContiguousFwd>> {
478 self.into_const().try_as_row_major().map(|x| unsafe { x.const_cast() })
479 }
480
481 #[inline]
482 pub fn as_diagonal_mut(self) -> DiagMut<'a, T, Cols, CStride> {
484 DiagMut {
485 0: crate::diag::Mut { inner: self.0.trans },
486 }
487 }
488
489 #[inline]
490 pub(crate) fn __at_mut(self, i: Idx<Cols>) -> &'a mut T {
491 self.at_mut(i)
492 }
493}
494
495impl<'a, T, Rows: Shape> RowMut<'a, T, Rows, ContiguousFwd> {
496 #[inline]
498 pub fn as_slice(self) -> &'a [T] {
499 self.transpose().as_slice()
500 }
501}
502
503impl<'a, 'ROWS, T> RowMut<'a, T, Dim<'ROWS>, ContiguousFwd> {
504 #[inline]
506 pub fn as_array(self) -> &'a Array<'ROWS, T> {
507 self.transpose().as_array()
508 }
509}
510
511impl<'a, T, Cols: Shape> RowMut<'a, T, Cols, ContiguousFwd> {
512 #[inline]
514 pub fn as_slice_mut(self) -> &'a mut [T] {
515 self.transpose_mut().as_slice_mut()
516 }
517}
518
519impl<'a, 'COLS, T> RowMut<'a, T, Dim<'COLS>, ContiguousFwd> {
520 #[inline]
522 pub fn as_array_mut(self) -> &'a mut Array<'COLS, T> {
523 self.transpose_mut().as_array_mut()
524 }
525}
526
527impl<'COLS, 'a, T, CStride: Stride> RowMut<'a, T, Dim<'COLS>, CStride> {
528 #[doc(hidden)]
529 #[inline]
530 pub fn split_cols_with<'LEFT, 'RIGHT>(
531 self,
532 col: Partition<'LEFT, 'RIGHT, 'COLS>,
533 ) -> (RowRef<'a, T, Dim<'LEFT>, CStride>, RowRef<'a, T, Dim<'RIGHT>, CStride>) {
534 let (a, b) = self.split_at_col(col.midpoint());
535 (a.as_col_shape(col.head), b.as_col_shape(col.tail))
536 }
537}
538
539impl<'COLS, 'a, T, CStride: Stride> RowMut<'a, T, Dim<'COLS>, CStride> {
540 #[doc(hidden)]
541 #[inline]
542 pub fn split_cols_with_mut<'LEFT, 'RIGHT>(
543 self,
544 col: Partition<'LEFT, 'RIGHT, 'COLS>,
545 ) -> (RowMut<'a, T, Dim<'LEFT>, CStride>, RowMut<'a, T, Dim<'RIGHT>, CStride>) {
546 let (a, b) = self.split_at_col_mut(col.midpoint());
547 (a.as_col_shape_mut(col.head), b.as_col_shape_mut(col.tail))
548 }
549}
550
551impl<T: core::fmt::Debug, Cols: Shape, CStride: Stride> core::fmt::Debug for Mut<'_, T, Cols, CStride> {
552 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
553 self.rb().fmt(f)
554 }
555}
556
557impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride>
558where
559 T: RealField,
560{
561 pub fn max(&self) -> Option<T> {
563 self.rb().as_dyn_cols().as_dyn_stride().internal_max()
564 }
565
566 pub fn min(&self) -> Option<T> {
568 self.rb().as_dyn_cols().as_dyn_stride().internal_min()
569 }
570}
571
#[cfg(test)]
mod tests {
    use crate::Row;

    // These tests target `RowMut::min` / `RowMut::max`, so the views are taken
    // with `as_mut()`; `as_ref()` would exercise the `RowRef` methods instead.

    #[test]
    fn test_row_min() {
        let mut row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
        let rowmut = row.as_mut();
        assert_eq!(rowmut.min(), Some(1.0));

        // Unsorted values with a negative entry: the minimum is not at either end.
        let values = [3.0, -2.0, 7.5, 0.0];
        let mut mixed: Row<f64> = Row::from_fn(4, |x| values[x]);
        let mixedmut = mixed.as_mut();
        assert_eq!(mixedmut.min(), Some(-2.0));

        let mut empty: Row<f64> = Row::from_fn(0, |_| 0.0);
        let emptymut = empty.as_mut();
        assert_eq!(emptymut.min(), None);
    }

    #[test]
    fn test_row_max() {
        let mut row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
        let rowmut = row.as_mut();
        assert_eq!(rowmut.max(), Some(5.0));

        // Unsorted values with a negative entry: the maximum is not at either end.
        let values = [3.0, -2.0, 7.5, 0.0];
        let mut mixed: Row<f64> = Row::from_fn(4, |x| values[x]);
        let mixedmut = mixed.as_mut();
        assert_eq!(mixedmut.max(), Some(7.5));

        let mut empty: Row<f64> = Row::from_fn(0, |_| 0.0);
        let emptymut = empty.as_mut();
        assert_eq!(emptymut.max(), None);
    }
}