1use super::{Mat, MatMut, MatRef, *};
2use crate::col::ColRef;
3use crate::internal_prelude::*;
4use crate::row::RowRef;
5use crate::utils::bound::{Dim, Partition};
6use crate::{ContiguousFwd, Idx, IdxInc};
7use equator::{assert, debug_assert};
8use faer_traits::Real;
9use generativity::Guard;
10
/// Immutable, non-owning view over a strided matrix of `T`.
///
/// `imp` stores the raw view data (base pointer, dimensions and strides, see `from_raw_parts`),
/// while `__marker` ties the view to the lifetime `'a` of the borrowed elements, exactly as a
/// `&'a T` would.
pub struct Ref<'a, T, Rows = usize, Cols = usize, RStride = isize, CStride = isize> {
    pub(super) imp: MatView<T, Rows, Cols, RStride, CStride>,
    pub(super) __marker: PhantomData<&'a T>,
}
16
// `Ref` is a shared view: copying it only duplicates the pointer, shape and strides, so only the
// shape/stride parameters need to be `Copy`, never `T` itself.
impl<T, Rows: Copy, Cols: Copy, RStride: Copy, CStride: Copy> Copy for Ref<'_, T, Rows, Cols, RStride, CStride> {}
impl<T, Rows: Copy, Cols: Copy, RStride: Copy, CStride: Copy> Clone for Ref<'_, T, Rows, Cols, RStride, CStride> {
    /// Returns a bitwise copy of the view (delegates to `Copy`).
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
24
// A shared reborrow of an immutable view is just a copy with a (possibly) shorter lifetime.
impl<'short, T, Rows: Copy, Cols: Copy, RStride: Copy, CStride: Copy> Reborrow<'short> for Ref<'_, T, Rows, Cols, RStride, CStride> {
    type Target = Ref<'short, T, Rows, Cols, RStride, CStride>;

    #[inline]
    fn rb(&'short self) -> Self::Target {
        *self
    }
}
// A "mutable" reborrow of an immutable view still yields an immutable view (again a plain copy).
impl<'short, T, Rows: Copy, Cols: Copy, RStride: Copy, CStride: Copy> ReborrowMut<'short> for Ref<'_, T, Rows, Cols, RStride, CStride> {
    type Target = Ref<'short, T, Rows, Cols, RStride, CStride>;

    #[inline]
    fn rb_mut(&'short mut self) -> Self::Target {
        *self
    }
}
// The const form of an immutable view is the view itself.
impl<'a, T, Rows: Copy, Cols: Copy, RStride: Copy, CStride: Copy> IntoConst for Ref<'a, T, Rows, Cols, RStride, CStride> {
    type Target = Ref<'a, T, Rows, Cols, RStride, CStride>;

    #[inline]
    fn into_const(self) -> Self::Target {
        self
    }
}
49
// SAFETY: `Ref` only grants shared access to its elements (it carries a `PhantomData<&'a T>` and
// hands out `&'a T` through `at`), so sharing it between threads is sound when `T: Sync` and the
// shape/stride metadata is itself `Sync`.
unsafe impl<T: Sync, Rows: Sync, Cols: Sync, RStride: Sync, CStride: Sync> Sync for Ref<'_, T, Rows, Cols, RStride, CStride> {}
// SAFETY: sending a shared view to another thread is equivalent to sending a `&T`, which requires
// `T: Sync` (not `T: Send`); the metadata values are moved, so they must be `Send`.
unsafe impl<T: Sync, Rows: Send, Cols: Send, RStride: Send, CStride: Send> Send for Ref<'_, T, Rows, Cols, RStride, CStride> {}
52
/// Checks that a column-major `nrows x ncols` matrix whose columns start `col_stride` elements
/// apart fits inside a slice of length `len`.
///
/// Empty matrices never touch memory and are always accepted.
///
/// # Panics
/// Panics if computing the offset of the last element overflows `usize`, or if that element
/// lies at or beyond `len`.
#[track_caller]
#[inline]
fn from_strided_column_major_slice_assert(nrows: usize, ncols: usize, col_stride: usize, len: usize) {
    if nrows == 0 || ncols == 0 {
        return;
    }

    // The element at `(nrows - 1, ncols - 1)` is the one furthest from the start of the slice.
    let Some(last) = col_stride.checked_mul(ncols - 1).and_then(|offset| offset.checked_add(nrows - 1)) else {
        panic!("address computation of the last matrix element overflowed");
    };
    assert!(last < len);
}
64
65impl<'a, T> MatRef<'a, T> {
66 #[inline]
68 pub fn from_row_major_array<const ROWS: usize, const COLS: usize>(array: &'a [[T; COLS]; ROWS]) -> Self {
69 unsafe { Self::from_raw_parts(array as *const _ as *const T, ROWS, COLS, COLS as isize, 1) }
70 }
71
72 #[inline]
74 pub fn from_column_major_array<const ROWS: usize, const COLS: usize>(array: &'a [[T; ROWS]; COLS]) -> Self {
75 unsafe { Self::from_raw_parts(array as *const _ as *const T, ROWS, COLS, 1, ROWS as isize) }
76 }
77
78 #[inline]
80 pub fn from_ref(value: &'a T) -> Self {
81 unsafe { MatRef::from_raw_parts(value as *const T, 1, 1, 0, 0) }
82 }
83}
84
85impl<'a, T, Rows: Shape, Cols: Shape> MatRef<'a, T, Rows, Cols> {
86 #[inline]
88 pub fn from_repeated_ref(value: &'a T, nrows: Rows, ncols: Cols) -> Self {
89 unsafe { MatRef::from_raw_parts(value as *const T, nrows, ncols, 0, 0) }
90 }
91
92 #[inline]
112 #[track_caller]
113 pub fn from_column_major_slice(slice: &'a [T], nrows: Rows, ncols: Cols) -> Self {
114 from_slice_assert(nrows.unbound(), ncols.unbound(), slice.len());
115
116 unsafe { MatRef::from_raw_parts(slice.as_ptr(), nrows, ncols, 1, nrows.unbound() as isize) }
117 }
118
119 #[inline]
123 #[track_caller]
124 pub fn from_column_major_slice_with_stride(slice: &'a [T], nrows: Rows, ncols: Cols, col_stride: usize) -> Self {
125 from_strided_column_major_slice_assert(nrows.unbound(), ncols.unbound(), col_stride, slice.len());
126
127 unsafe { MatRef::from_raw_parts(slice.as_ptr(), nrows, ncols, 1, col_stride as isize) }
128 }
129
130 #[inline]
150 #[track_caller]
151 pub fn from_row_major_slice(slice: &'a [T], nrows: Rows, ncols: Cols) -> Self {
152 MatRef::from_column_major_slice(slice, ncols, nrows).transpose()
153 }
154
155 #[inline]
159 #[track_caller]
160 pub fn from_row_major_slice_with_stride(slice: &'a [T], nrows: Rows, ncols: Cols, row_stride: usize) -> Self {
161 MatRef::from_column_major_slice_with_stride(slice, ncols, nrows, row_stride).transpose()
162 }
163}
164
165impl<'a, T, Rows: Shape, Cols: Shape, RStride: Stride, CStride: Stride> MatRef<'a, T, Rows, Cols, RStride, CStride> {
    /// Creates a matrix view from a base pointer, dimensions and strides.
    ///
    /// Element `(i, j)` is located at `ptr + i * row_stride + j * col_stride` (in elements).
    ///
    /// # Safety
    /// - `ptr` must be non-null (it is wrapped with `NonNull::new_unchecked`).
    /// - For every `i < nrows` and `j < ncols`, the element address above must be valid for
    ///   reads of `T` for the lifetime `'a`, since safe methods such as `at` hand out `&'a T`.
    /// - Those elements must not be mutated while the view (or anything derived from it) is
    ///   alive.
    #[inline]
    #[track_caller]
    pub const unsafe fn from_raw_parts(ptr: *const T, nrows: Rows, ncols: Cols, row_stride: RStride, col_stride: CStride) -> Self {
        Self(Ref {
            imp: MatView {
                ptr: NonNull::new_unchecked(ptr as *mut T),
                nrows,
                ncols,
                row_stride,
                col_stride,
            },
            __marker: PhantomData,
        })
    }
216
    /// Returns the base pointer, i.e. the address of element `(0, 0)` when the view is non-empty.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.imp.ptr.as_ptr() as *const T
    }

    /// Returns the number of rows.
    #[inline]
    pub fn nrows(&self) -> Rows {
        self.imp.nrows
    }

    /// Returns the number of columns.
    #[inline]
    pub fn ncols(&self) -> Cols {
        self.imp.ncols
    }

    /// Returns `(nrows, ncols)`.
    #[inline]
    pub fn shape(&self) -> (Rows, Cols) {
        (self.nrows(), self.ncols())
    }

    /// Returns the row stride: the distance, in elements, between vertically adjacent elements.
    #[inline]
    pub fn row_stride(&self) -> RStride {
        self.imp.row_stride
    }

    /// Returns the column stride: the distance, in elements, between horizontally adjacent
    /// elements.
    #[inline]
    pub fn col_stride(&self) -> CStride {
        self.imp.col_stride
    }
252
253 #[inline]
255 pub fn ptr_at(&self, row: IdxInc<Rows>, col: IdxInc<Cols>) -> *const T {
256 let ptr = self.as_ptr();
257
258 if row >= self.nrows() || col >= self.ncols() {
259 ptr
260 } else {
261 ptr.wrapping_offset(row.unbound() as isize * self.row_stride().element_stride())
262 .wrapping_offset(col.unbound() as isize * self.col_stride().element_stride())
263 }
264 }
265
    /// Returns a pointer to the element at `(row, col)`, computed with `pointer::offset`.
    ///
    /// # Safety
    /// `row < self.nrows()` and `col < self.ncols()` must hold; this is only checked by a
    /// `debug_assert!`. `offset` additionally requires every intermediate address to stay within
    /// the same allocation, which holds for in-bounds indices of a view built per
    /// `from_raw_parts`' contract.
    #[inline]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at(&self, row: Idx<Rows>, col: Idx<Cols>) -> *const T {
        debug_assert!(all(row < self.nrows(), col < self.ncols()));
        self.as_ptr()
            .offset(row.unbound() as isize * self.row_stride().element_stride())
            .offset(col.unbound() as isize * self.col_stride().element_stride())
    }
281
282 #[inline]
294 #[track_caller]
295 pub fn split_at(
296 self,
297 row: IdxInc<Rows>,
298 col: IdxInc<Cols>,
299 ) -> (
300 MatRef<'a, T, usize, usize, RStride, CStride>,
301 MatRef<'a, T, usize, usize, RStride, CStride>,
302 MatRef<'a, T, usize, usize, RStride, CStride>,
303 MatRef<'a, T, usize, usize, RStride, CStride>,
304 ) {
305 assert!(all(row <= self.nrows(), col <= self.ncols()));
306
307 let rs = self.row_stride();
308 let cs = self.col_stride();
309
310 let top_left = self.ptr_at(Rows::start(), Cols::start());
311 let top_right = self.ptr_at(Rows::start(), col);
312 let bot_left = self.ptr_at(row, Cols::start());
313 let bot_right = self.ptr_at(row, col);
314
315 unsafe {
316 (
317 MatRef::from_raw_parts(top_left, row.unbound(), col.unbound(), rs, cs),
318 MatRef::from_raw_parts(top_right, row.unbound(), self.ncols().unbound() - col.unbound(), rs, cs),
319 MatRef::from_raw_parts(bot_left, self.nrows().unbound() - row.unbound(), col.unbound(), rs, cs),
320 MatRef::from_raw_parts(
321 bot_right,
322 self.nrows().unbound() - row.unbound(),
323 self.ncols().unbound() - col.unbound(),
324 rs,
325 cs,
326 ),
327 )
328 }
329 }
330
331 #[inline]
340 #[track_caller]
341 pub fn split_at_row(self, row: IdxInc<Rows>) -> (MatRef<'a, T, usize, Cols, RStride, CStride>, MatRef<'a, T, usize, Cols, RStride, CStride>) {
342 assert!(all(row <= self.nrows()));
343
344 let rs = self.row_stride();
345 let cs = self.col_stride();
346
347 let top = self.ptr_at(Rows::start(), Cols::start());
348 let bot = self.ptr_at(row, Cols::start());
349
350 unsafe {
351 (
352 MatRef::from_raw_parts(top, row.unbound(), self.ncols(), rs, cs),
353 MatRef::from_raw_parts(bot, self.nrows().unbound() - row.unbound(), self.ncols(), rs, cs),
354 )
355 }
356 }
357
358 #[inline]
367 #[track_caller]
368 pub fn split_at_col(self, col: IdxInc<Cols>) -> (MatRef<'a, T, Rows, usize, RStride, CStride>, MatRef<'a, T, Rows, usize, RStride, CStride>) {
369 assert!(all(col <= self.ncols()));
370
371 let rs = self.row_stride();
372 let cs = self.col_stride();
373
374 let left = self.ptr_at(Rows::start(), Cols::start());
375 let right = self.ptr_at(Rows::start(), col);
376
377 unsafe {
378 (
379 MatRef::from_raw_parts(left, self.nrows(), col.unbound(), rs, cs),
380 MatRef::from_raw_parts(right, self.nrows(), self.ncols().unbound() - col.unbound(), rs, cs),
381 )
382 }
383 }
384
    /// Returns the transposed view.
    ///
    /// No data is moved: the dimensions and the strides are swapped while the base pointer is
    /// kept.
    #[inline]
    pub fn transpose(self) -> MatRef<'a, T, Cols, Rows, CStride, RStride> {
        MatRef {
            0: Ref {
                imp: MatView {
                    ptr: self.imp.ptr,
                    nrows: self.imp.ncols,
                    ncols: self.imp.nrows,
                    row_stride: self.imp.col_stride,
                    col_stride: self.imp.row_stride,
                },
                __marker: PhantomData,
            },
        }
    }
413
414 #[inline]
416 pub fn conjugate(self) -> MatRef<'a, T::Conj, Rows, Cols, RStride, CStride>
417 where
418 T: Conjugate,
419 {
420 unsafe {
421 MatRef::from_raw_parts(
422 self.as_ptr() as *const T::Conj,
423 self.nrows(),
424 self.ncols(),
425 self.row_stride(),
426 self.col_stride(),
427 )
428 }
429 }
430
431 #[inline]
433 pub fn canonical(self) -> MatRef<'a, T::Canonical, Rows, Cols, RStride, CStride>
434 where
435 T: Conjugate,
436 {
437 unsafe {
438 MatRef::from_raw_parts(
439 self.as_ptr() as *const T::Canonical,
440 self.nrows(),
441 self.ncols(),
442 self.row_stride(),
443 self.col_stride(),
444 )
445 }
446 }
447
    /// Returns the canonical view together with the `Conj` flag describing `T`
    /// (as reported by `Conj::get::<T>()`).
    #[inline]
    #[doc(hidden)]
    pub fn __canonicalize(self) -> (MatRef<'a, T::Canonical, Rows, Cols, RStride, CStride>, Conj)
    where
        T: Conjugate,
    {
        (self.canonical(), Conj::get::<T>())
    }

    /// Returns the adjoint (conjugate transpose) view, without moving any data.
    #[inline]
    pub fn adjoint(self) -> MatRef<'a, T::Conj, Cols, Rows, CStride, RStride>
    where
        T: Conjugate,
    {
        self.conjugate().transpose()
    }
465
    /// Returns a reference to the element at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
    #[inline]
    #[track_caller]
    pub(crate) fn at(self, row: Idx<Rows>, col: Idx<Cols>) -> &'a T {
        assert!(all(row < self.nrows(), col < self.ncols()));
        unsafe { self.at_unchecked(row, col) }
    }

    /// Returns a reference to the element at `(row, col)` without a release-mode bounds check.
    ///
    /// # Safety
    /// `row < self.nrows()` and `col < self.ncols()` must hold (see `ptr_inbounds_at`).
    #[inline]
    #[track_caller]
    pub(crate) unsafe fn at_unchecked(self, row: Idx<Rows>, col: Idx<Cols>) -> &'a T {
        &*self.ptr_inbounds_at(row, col)
    }
478
479 #[inline]
493 pub fn reverse_rows(self) -> MatRef<'a, T, Rows, Cols, RStride::Rev, CStride> {
494 let row = unsafe { IdxInc::<Rows>::new_unbound(self.nrows().unbound().saturating_sub(1)) };
495 let ptr = self.ptr_at(row, Cols::start());
496 unsafe { MatRef::from_raw_parts(ptr, self.nrows(), self.ncols(), self.row_stride().rev(), self.col_stride()) }
497 }
498
499 #[inline]
513 pub fn reverse_cols(self) -> MatRef<'a, T, Rows, Cols, RStride, CStride::Rev> {
514 let col = unsafe { IdxInc::<Cols>::new_unbound(self.ncols().unbound().saturating_sub(1)) };
515 let ptr = self.ptr_at(Rows::start(), col);
516 unsafe { MatRef::from_raw_parts(ptr, self.nrows(), self.ncols(), self.row_stride(), self.col_stride().rev()) }
517 }
518
519 #[inline]
533 pub fn reverse_rows_and_cols(self) -> MatRef<'a, T, Rows, Cols, RStride::Rev, CStride::Rev> {
534 self.reverse_rows().reverse_cols()
535 }
536
    /// Returns the `nrows x ncols` sub-view whose top-left corner is `(row_start, col_start)`.
    ///
    /// # Panics
    /// Panics unless `row_start <= self.nrows()`, `col_start <= self.ncols()`,
    /// `row_start + nrows <= self.nrows()` and `col_start + ncols <= self.ncols()`.
    #[inline]
    #[track_caller]
    pub fn submatrix<V: Shape, H: Shape>(
        self,
        row_start: IdxInc<Rows>,
        col_start: IdxInc<Cols>,
        nrows: V,
        ncols: H,
    ) -> MatRef<'a, T, V, H, RStride, CStride> {
        assert!(all(row_start <= self.nrows(), col_start <= self.ncols()));
        {
            // Written as `len <= full - start` (not `start + len <= full`) so the check cannot
            // overflow; the first assertion guarantees the subtraction does not underflow.
            let nrows = nrows.unbound();
            let full_nrows = self.nrows().unbound();
            let row_start = row_start.unbound();
            let ncols = ncols.unbound();
            let full_ncols = self.ncols().unbound();
            let col_start = col_start.unbound();
            assert!(all(nrows <= full_nrows - row_start, ncols <= full_ncols - col_start,));
        }
        let rs = self.row_stride();
        let cs = self.col_stride();

        unsafe { MatRef::from_raw_parts(self.ptr_at(row_start, col_start), nrows, ncols, rs, cs) }
    }

    /// Returns the `nrows` rows starting at `row_start`, with all columns.
    ///
    /// # Panics
    /// Panics unless `row_start <= self.nrows()` and `row_start + nrows <= self.nrows()`.
    #[inline]
    #[track_caller]
    pub fn subrows<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> MatRef<'a, T, V, Cols, RStride, CStride> {
        assert!(all(row_start <= self.nrows()));
        {
            // Overflow-free form of `row_start + nrows <= full_nrows`.
            let nrows = nrows.unbound();
            let full_nrows = self.nrows().unbound();
            let row_start = row_start.unbound();
            assert!(all(nrows <= full_nrows - row_start));
        }
        let rs = self.row_stride();
        let cs = self.col_stride();

        unsafe { MatRef::from_raw_parts(self.ptr_at(row_start, Cols::start()), nrows, self.ncols(), rs, cs) }
    }

    /// Returns the `ncols` columns starting at `col_start`, with all rows.
    ///
    /// # Panics
    /// Panics unless `col_start <= self.ncols()` and `col_start + ncols <= self.ncols()`.
    #[inline]
    #[track_caller]
    pub fn subcols<H: Shape>(self, col_start: IdxInc<Cols>, ncols: H) -> MatRef<'a, T, Rows, H, RStride, CStride> {
        assert!(all(col_start <= self.ncols()));
        {
            // Overflow-free form of `col_start + ncols <= full_ncols`.
            let ncols = ncols.unbound();
            let full_ncols = self.ncols().unbound();
            let col_start = col_start.unbound();
            assert!(all(ncols <= full_ncols - col_start));
        }
        let rs = self.row_stride();
        let cs = self.col_stride();

        unsafe { MatRef::from_raw_parts(self.ptr_at(Rows::start(), col_start), self.nrows(), ncols, rs, cs) }
    }
672
    /// Re-types the dimensions as `V x H`, keeping pointer and strides.
    ///
    /// # Panics
    /// Panics unless the new dimensions have the same values as the current ones.
    #[inline]
    #[track_caller]
    pub fn as_shape<V: Shape, H: Shape>(self, nrows: V, ncols: H) -> MatRef<'a, T, V, H, RStride, CStride> {
        assert!(all(self.nrows().unbound() == nrows.unbound(), self.ncols().unbound() == ncols.unbound(),));
        unsafe { MatRef::from_raw_parts(self.as_ptr(), nrows, ncols, self.row_stride(), self.col_stride()) }
    }

    /// Re-types the row dimension as `V`.
    ///
    /// # Panics
    /// Panics unless `nrows` has the same value as the current row count.
    #[inline]
    #[track_caller]
    pub fn as_row_shape<V: Shape>(self, nrows: V) -> MatRef<'a, T, V, Cols, RStride, CStride> {
        assert!(all(self.nrows().unbound() == nrows.unbound()));
        unsafe { MatRef::from_raw_parts(self.as_ptr(), nrows, self.ncols(), self.row_stride(), self.col_stride()) }
    }

    /// Re-types the column dimension as `H`.
    ///
    /// # Panics
    /// Panics unless `ncols` has the same value as the current column count.
    #[inline]
    #[track_caller]
    pub fn as_col_shape<H: Shape>(self, ncols: H) -> MatRef<'a, T, Rows, H, RStride, CStride> {
        assert!(all(self.ncols().unbound() == ncols.unbound()));
        unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), ncols, self.row_stride(), self.col_stride()) }
    }
699
700 #[inline]
702 pub fn as_dyn_stride(self) -> MatRef<'a, T, Rows, Cols, isize, isize> {
703 unsafe {
704 MatRef::from_raw_parts(
705 self.as_ptr(),
706 self.nrows(),
707 self.ncols(),
708 self.row_stride().element_stride(),
709 self.col_stride().element_stride(),
710 )
711 }
712 }
713
714 #[inline]
716 pub fn as_dyn(self) -> MatRef<'a, T, usize, usize, RStride, CStride> {
717 unsafe {
718 MatRef::from_raw_parts(
719 self.as_ptr(),
720 self.nrows().unbound(),
721 self.ncols().unbound(),
722 self.row_stride(),
723 self.col_stride(),
724 )
725 }
726 }
727
728 #[inline]
730 pub fn as_dyn_rows(self) -> MatRef<'a, T, usize, Cols, RStride, CStride> {
731 unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows().unbound(), self.ncols(), self.row_stride(), self.col_stride()) }
732 }
733
734 #[inline]
736 pub fn as_dyn_cols(self) -> MatRef<'a, T, Rows, usize, RStride, CStride> {
737 unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols().unbound(), self.row_stride(), self.col_stride()) }
738 }
739
740 #[inline]
746 #[track_caller]
747 pub fn row(self, i: Idx<Rows>) -> RowRef<'a, T, Cols, CStride> {
748 assert!(i < self.nrows());
749
750 unsafe { RowRef::from_raw_parts(self.ptr_at(i.into(), Cols::start()), self.ncols(), self.col_stride()) }
751 }
752
753 #[inline]
759 #[track_caller]
760 pub fn col(self, j: Idx<Cols>) -> ColRef<'a, T, Rows, RStride> {
761 assert!(j < self.ncols());
762
763 unsafe { ColRef::from_raw_parts(self.ptr_at(Rows::start(), j.into()), self.nrows(), self.row_stride()) }
764 }
765
    /// Returns an iterator over the columns of the matrix, from first to last.
    #[inline]
    pub fn col_iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = ColRef<'a, T, Rows, RStride>>
    where
        Rows: 'a,
        Cols: 'a,
    {
        Cols::indices(Cols::start(), self.ncols().end()).map(move |j| self.col(j))
    }

    /// Returns an iterator over the rows of the matrix, from first to last.
    #[inline]
    pub fn row_iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = RowRef<'a, T, Cols, CStride>>
    where
        Rows: 'a,
        Cols: 'a,
    {
        Rows::indices(Rows::start(), self.nrows().end()).map(move |i| self.row(i))
    }
785
    /// Returns a parallel (rayon) iterator over the columns of the matrix.
    ///
    /// Implemented as `par_col_chunks(1)` with each single-column chunk mapped to its column.
    #[inline]
    #[cfg(feature = "rayon")]
    pub fn par_col_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, T, Rows, RStride>>
    where
        T: Sync,
        Rows: 'a,
        Cols: 'a,
    {
        use rayon::prelude::*;

        // Named fn (rather than a closure) for the chunk -> column mapping.
        #[inline]
        fn col_fn<T, Rows: Shape, RStride: Stride, CStride: Stride>(
            col: MatRef<'_, T, Rows, usize, RStride, CStride>,
        ) -> ColRef<'_, T, Rows, RStride> {
            col.col(0)
        }

        self.par_col_chunks(1).map(col_fn)
    }

    /// Returns a parallel (rayon) iterator over the rows of the matrix.
    ///
    /// Implemented as the column iterator of the transposed view.
    #[inline]
    #[cfg(feature = "rayon")]
    pub fn par_row_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, T, Cols, CStride>>
    where
        T: Sync,
        Rows: 'a,
        Cols: 'a,
    {
        use rayon::prelude::*;
        self.transpose().par_col_iter().map(ColRef::transpose)
    }
819
820 #[inline]
828 #[track_caller]
829 #[cfg(feature = "rayon")]
830 pub fn par_col_chunks(
831 self,
832 chunk_size: usize,
833 ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, T, Rows, usize, RStride, CStride>>
834 where
835 T: Sync,
836 Rows: 'a,
837 Cols: 'a,
838 {
839 use rayon::prelude::*;
840
841 let this = self.as_dyn_cols();
842
843 assert!(chunk_size > 0);
844 let chunk_count = this.ncols().msrv_div_ceil(chunk_size);
845 (0..chunk_count).into_par_iter().map(move |chunk_idx| {
846 let pos = chunk_size * chunk_idx;
847 this.subcols(pos, Ord::min(chunk_size, this.ncols() - pos))
848 })
849 }
850
851 #[inline]
856 #[track_caller]
857 #[cfg(feature = "rayon")]
858 pub fn par_col_partition(
859 self,
860 count: usize,
861 ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, T, Rows, usize, RStride, CStride>>
862 where
863 T: Sync,
864 Rows: 'a,
865 Cols: 'a,
866 {
867 use rayon::prelude::*;
868
869 let this = self.as_dyn_cols();
870
871 assert!(count > 0);
872 (0..count).into_par_iter().map(move |chunk_idx| {
873 let (start, len) = crate::utils::thread::par_split_indices(this.ncols(), chunk_idx, count);
874 this.subcols(start, len)
875 })
876 }
877
    /// Returns a parallel (rayon) iterator over consecutive row blocks of `chunk_size` rows
    /// each; the last block may be shorter.
    ///
    /// Implemented through `par_col_chunks` on the transposed view.
    ///
    /// # Panics
    /// Panics if `chunk_size == 0`.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rayon")]
    pub fn par_row_chunks(
        self,
        chunk_size: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, T, usize, Cols, RStride, CStride>>
    where
        T: Sync,
        Rows: 'a,
        Cols: 'a,
    {
        use rayon::prelude::*;
        self.transpose().par_col_chunks(chunk_size).map(MatRef::transpose)
    }

    /// Returns a parallel (rayon) iterator over `count` row blocks.
    ///
    /// Implemented through `par_col_partition` on the transposed view.
    ///
    /// # Panics
    /// Panics if `count == 0`.
    #[inline]
    #[track_caller]
    #[cfg(feature = "rayon")]
    pub fn par_row_partition(
        self,
        count: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, T, usize, Cols, RStride, CStride>>
    where
        T: Sync,
        Rows: 'a,
        Cols: 'a,
    {
        use rayon::prelude::*;
        self.transpose().par_col_partition(count).map(MatRef::transpose)
    }
920
921 #[inline]
924 pub fn split_first_row(self) -> Option<(RowRef<'a, T, Cols, CStride>, MatRef<'a, T, usize, Cols, RStride, CStride>)> {
925 if let Some(i0) = self.nrows().idx_inc(1) {
926 let (head, tail) = self.split_at_row(i0);
927 Some((head.row(0), tail))
928 } else {
929 None
930 }
931 }
932
933 #[inline]
936 pub fn split_first_col(self) -> Option<(ColRef<'a, T, Rows, RStride>, MatRef<'a, T, Rows, usize, RStride, CStride>)> {
937 if let Some(i0) = self.ncols().idx_inc(1) {
938 let (head, tail) = self.split_at_col(i0);
939 Some((head.col(0), tail))
940 } else {
941 None
942 }
943 }
944
945 #[inline]
948 pub fn split_last_row(self) -> Option<(RowRef<'a, T, Cols, CStride>, MatRef<'a, T, usize, Cols, RStride, CStride>)> {
949 if self.nrows().unbound() > 0 {
950 let i0 = self.nrows().checked_idx_inc(self.nrows().unbound() - 1);
951 let (head, tail) = self.split_at_row(i0);
952 Some((tail.row(0), head))
953 } else {
954 None
955 }
956 }
957
958 #[inline]
961 pub fn split_last_col(self) -> Option<(ColRef<'a, T, Rows, RStride>, MatRef<'a, T, Rows, usize, RStride, CStride>)> {
962 if self.ncols().unbound() > 0 {
963 let i0 = self.ncols().checked_idx_inc(self.ncols().unbound() - 1);
964 let (head, tail) = self.split_at_col(i0);
965 Some((tail.col(0), head))
966 } else {
967 None
968 }
969 }
970
971 #[inline]
974 pub fn try_as_row_major(self) -> Option<MatRef<'a, T, Rows, Cols, RStride, ContiguousFwd>> {
975 if self.col_stride().element_stride() == 1 {
976 Some(unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols(), self.row_stride(), ContiguousFwd) })
977 } else {
978 None
979 }
980 }
981
982 #[doc(hidden)]
983 #[inline]
984 pub fn bind<'M, 'N>(self, row: Guard<'M>, col: Guard<'N>) -> MatRef<'a, T, Dim<'M>, Dim<'N>, RStride, CStride> {
985 unsafe {
986 MatRef::from_raw_parts(
987 self.as_ptr(),
988 self.nrows().bind(row),
989 self.ncols().bind(col),
990 self.row_stride(),
991 self.col_stride(),
992 )
993 }
994 }
995
996 #[doc(hidden)]
997 #[inline]
998 pub fn bind_r<'M>(self, row: Guard<'M>) -> MatRef<'a, T, Dim<'M>, Cols, RStride, CStride> {
999 unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows().bind(row), self.ncols(), self.row_stride(), self.col_stride()) }
1000 }
1001
1002 #[doc(hidden)]
1003 #[inline]
1004 pub fn bind_c<'N>(self, col: Guard<'N>) -> MatRef<'a, T, Rows, Dim<'N>, RStride, CStride> {
1005 unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols().bind(col), self.row_stride(), self.col_stride()) }
1006 }
1007
    /// Reinterprets this immutable view as a mutable one over the same memory.
    ///
    /// # Safety
    /// The caller must have exclusive access to the viewed elements for `'a` (no other live
    /// references, including the view this was derived from) before writing through the result.
    #[doc(hidden)]
    #[inline]
    pub unsafe fn const_cast(self) -> MatMut<'a, T, Rows, Cols, RStride, CStride> {
        MatMut::from_raw_parts_mut(self.as_ptr() as *mut T, self.nrows(), self.ncols(), self.row_stride(), self.col_stride())
    }
1013
1014 #[inline]
1016 pub fn try_as_col_major(self) -> Option<MatRef<'a, T, Rows, Cols, ContiguousFwd, CStride>> {
1017 if self.row_stride().element_stride() == 1 {
1018 Some(unsafe { MatRef::from_raw_parts(self.as_ptr(), self.nrows(), self.ncols(), ContiguousFwd, self.col_stride()) })
1019 } else {
1020 None
1021 }
1022 }
1023
    /// Indexes the matrix with a row selector and a column selector (indices or ranges),
    /// dispatching to the matching `MatIndex` implementation.
    ///
    /// # Panics
    /// Panics if the selection is out of bounds, as defined by the `MatIndex` implementation.
    #[track_caller]
    #[inline]
    pub fn get<RowRange, ColRange>(
        self,
        row: RowRange,
        col: ColRange,
    ) -> <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, ColRange>>::Target
    where
        MatRef<'a, T, Rows, Cols, RStride, CStride>: MatIndex<RowRange, ColRange>,
    {
        <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, ColRange>>::get(self, row, col)
    }

    /// Like [`get`](Self::get) with the full column range (`..`).
    #[track_caller]
    #[inline]
    pub fn get_r<RowRange>(self, row: RowRange) -> <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, core::ops::RangeFull>>::Target
    where
        MatRef<'a, T, Rows, Cols, RStride, CStride>: MatIndex<RowRange, core::ops::RangeFull>,
    {
        <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, core::ops::RangeFull>>::get(self, row, ..)
    }

    /// Like [`get`](Self::get) with the full row range (`..`).
    #[track_caller]
    #[inline]
    pub fn get_c<ColRange>(self, col: ColRange) -> <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<core::ops::RangeFull, ColRange>>::Target
    where
        MatRef<'a, T, Rows, Cols, RStride, CStride>: MatIndex<core::ops::RangeFull, ColRange>,
    {
        <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<core::ops::RangeFull, ColRange>>::get(self, .., col)
    }

    /// Unchecked variant of [`get`](Self::get).
    ///
    /// # Safety
    /// The selection must be in bounds; see the `MatIndex::get_unchecked` contract.
    #[track_caller]
    #[inline]
    pub unsafe fn get_unchecked<RowRange, ColRange>(
        self,
        row: RowRange,
        col: ColRange,
    ) -> <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, ColRange>>::Target
    where
        MatRef<'a, T, Rows, Cols, RStride, CStride>: MatIndex<RowRange, ColRange>,
    {
        unsafe { <MatRef<'a, T, Rows, Cols, RStride, CStride> as MatIndex<RowRange, ColRange>>::get_unchecked(self, row, col) }
    }

    /// Tuple-argument form of `at`, convenient as a function value taking `(row, col)`.
    #[inline]
    pub(crate) fn __at(self, (i, j): (Idx<Rows>, Idx<Cols>)) -> &'a T {
        self.at(i, j)
    }
1088}
1089
1090impl<
1091 T,
1092 Rows: Shape,
1093 Cols: Shape,
1094 RStride: Stride,
1095 CStride: Stride,
1096 Inner: for<'short> Reborrow<'short, Target = Ref<'short, T, Rows, Cols, RStride, CStride>>,
1097> generic::Mat<Inner>
1098{
    /// Returns an immutable view of this matrix, borrowed from `self`.
    #[inline]
    pub fn as_ref(&self) -> MatRef<'_, T, Rows, Cols, RStride, CStride> {
        self.rb()
    }
1104
1105 #[inline]
1107 pub fn cloned(&self) -> Mat<T, Rows, Cols>
1108 where
1109 T: Clone,
1110 {
1111 fn imp<'M, 'N, T: Clone, RStride: Stride, CStride: Stride>(
1112 this: MatRef<'_, T, Dim<'M>, Dim<'N>, RStride, CStride>,
1113 ) -> Mat<T, Dim<'M>, Dim<'N>> {
1114 Mat::from_fn(this.nrows(), this.ncols(), |i, j| this.at(i, j).clone())
1115 }
1116
1117 let this = self.rb();
1118
1119 with_dim!(M, this.nrows().unbound());
1120 with_dim!(N, this.ncols().unbound());
1121 imp(this.as_shape(M, N)).into_shape(this.nrows(), this.ncols())
1122 }
1123
1124 #[inline]
1126 pub fn to_owned(&self) -> Mat<T::Canonical, Rows, Cols>
1127 where
1128 T: Conjugate,
1129 {
1130 fn imp<'M, 'N, T, RStride: Stride, CStride: Stride>(
1131 this: MatRef<'_, T, Dim<'M>, Dim<'N>, RStride, CStride>,
1132 ) -> Mat<T::Canonical, Dim<'M>, Dim<'N>>
1133 where
1134 T: Conjugate,
1135 {
1136 Mat::from_fn(this.nrows(), this.ncols(), |i, j| Conj::apply::<T>(this.at(i, j)))
1137 }
1138
1139 let this = self.rb();
1140 with_dim!(M, this.nrows().unbound());
1141 with_dim!(N, this.ncols().unbound());
1142 imp(this.as_shape(M, N)).into_shape(this.nrows(), this.ncols())
1143 }
1144
1145 #[inline]
1147 pub fn norm_max(&self) -> Real<T>
1148 where
1149 T: Conjugate,
1150 {
1151 linalg::reductions::norm_max::norm_max(self.rb().canonical().as_dyn_stride().as_dyn())
1152 }
1153
1154 #[inline]
1156 pub fn norm_l2(&self) -> Real<T>
1157 where
1158 T: Conjugate,
1159 {
1160 linalg::reductions::norm_l2::norm_l2(self.rb().canonical().as_dyn_stride().as_dyn())
1161 }
1162
1163 #[inline]
1165 pub fn squared_norm_l2(&self) -> Real<T>
1166 where
1167 T: Conjugate,
1168 {
1169 linalg::reductions::norm_l2_sqr::norm_l2_sqr(self.rb().canonical().as_dyn_stride().as_dyn())
1170 }
1171
1172 #[inline]
1174 pub fn norm_l1(&self) -> Real<T>
1175 where
1176 T: Conjugate,
1177 {
1178 linalg::reductions::norm_l1::norm_l1(self.rb().canonical().as_dyn_stride().as_dyn())
1179 }
1180
    /// Returns the sum of all elements.
    ///
    /// The reduction runs on the canonical reinterpretation of the data. If `T` is a conjugated
    /// type, the stored values are the conjugates of the logical ones, so the result is
    /// conjugated back (the sum of conjugates is the conjugate of the sum).
    #[inline]
    #[math]
    pub fn sum(&self) -> T::Canonical
    where
        T: Conjugate,
    {
        let val = linalg::reductions::sum::sum(self.rb().canonical().as_dyn_stride().as_dyn());
        if try_const! { Conj::get::<T>().is_conj() } { conj(val) } else { val }
    }

    /// Returns the determinant of the matrix.
    ///
    /// Computed on the canonical data; for a non-canonical (conjugated) `T` the result is
    /// conjugated, since `det(conj(A)) = conj(det(A))`. The squareness requirement is presumably
    /// enforced by `linalg::reductions::determinant` (confirm).
    #[inline]
    #[math]
    pub fn determinant(&self) -> T::Canonical
    where
        T: Conjugate,
    {
        let det = linalg::reductions::determinant::determinant(self.rb().canonical().as_dyn_stride().as_dyn());
        if const { T::IS_CANONICAL } { det } else { conj(det) }
    }
1202
1203 #[inline]
1239 pub fn kron(&self, rhs: impl AsMatRef<T: Conjugate<Canonical = T::Canonical>>) -> Mat<T::Canonical>
1240 where
1241 T: Conjugate,
1242 {
1243 fn imp<T: ComplexField>(lhs: MatRef<'_, impl Conjugate<Canonical = T>>, rhs: MatRef<'_, impl Conjugate<Canonical = T>>) -> Mat<T> {
1244 let mut out = Mat::zeros(lhs.nrows() * rhs.nrows(), lhs.ncols() * rhs.ncols());
1245 linalg::kron::kron(out.rb_mut(), lhs, rhs);
1246 out
1247 }
1248
1249 imp(self.rb().as_dyn().as_dyn_stride(), rhs.as_mat_ref().as_dyn().as_dyn_stride())
1250 }
1251
1252 #[inline]
1255 pub fn is_all_finite(&self) -> bool
1256 where
1257 T: Conjugate,
1258 {
1259 fn imp<T: ComplexField>(A: MatRef<'_, T>) -> bool {
1260 with_dim!({
1261 let M = A.nrows();
1262 let N = A.ncols();
1263 });
1264
1265 let A = A.as_shape(M, N);
1266
1267 for j in N.indices() {
1268 for i in M.indices() {
1269 if !is_finite(&A[(i, j)]) {
1270 return false;
1271 }
1272 }
1273 }
1274
1275 true
1276 }
1277
1278 imp(self.rb().as_dyn().as_dyn_stride().canonical())
1279 }
1280
1281 #[inline]
1284 pub fn has_nan(&self) -> bool
1285 where
1286 T: Conjugate,
1287 {
1288 fn imp<T: ComplexField>(A: MatRef<'_, T>) -> bool {
1289 with_dim!({
1290 let M = A.nrows();
1291 let N = A.ncols();
1292 });
1293
1294 let A = A.as_shape(M, N);
1295
1296 for j in N.indices() {
1297 for i in M.indices() {
1298 if is_nan(&A[(i, j)]) {
1299 return true;
1300 }
1301 }
1302 }
1303
1304 false
1305 }
1306
1307 imp(self.rb().as_dyn().as_dyn_stride().canonical())
1308 }
1309}
1310
1311impl<'a, T, Dim: Shape, RStride: Stride, CStride: Stride> MatRef<'a, T, Dim, Dim, RStride, CStride> {
1312 #[inline]
1314 pub fn diagonal(self) -> DiagRef<'a, T, Dim, isize> {
1315 let k = Ord::min(self.nrows(), self.ncols());
1316 DiagRef {
1317 0: crate::diag::Ref {
1318 inner: unsafe { ColRef::from_raw_parts(self.as_ptr(), k, self.row_stride().element_stride() + self.col_stride().element_stride()) },
1319 },
1320 }
1321 }
1322}
1323
1324impl<'ROWS, 'COLS, 'a, T, RStride: Stride, CStride: Stride> MatRef<'a, T, Dim<'ROWS>, Dim<'COLS>, RStride, CStride> {
1325 #[doc(hidden)]
1326 #[inline]
1327 pub fn split_with<'TOP, 'BOT, 'LEFT, 'RIGHT>(
1328 self,
1329 row: Partition<'TOP, 'BOT, 'ROWS>,
1330 col: Partition<'LEFT, 'RIGHT, 'COLS>,
1331 ) -> (
1332 MatRef<'a, T, Dim<'TOP>, Dim<'LEFT>, RStride, CStride>,
1333 MatRef<'a, T, Dim<'TOP>, Dim<'RIGHT>, RStride, CStride>,
1334 MatRef<'a, T, Dim<'BOT>, Dim<'LEFT>, RStride, CStride>,
1335 MatRef<'a, T, Dim<'BOT>, Dim<'RIGHT>, RStride, CStride>,
1336 ) {
1337 let (a, b, c, d) = self.split_at(row.midpoint(), col.midpoint());
1338 (
1339 a.as_shape(row.head, col.head),
1340 b.as_shape(row.head, col.tail),
1341 c.as_shape(row.tail, col.head),
1342 d.as_shape(row.tail, col.tail),
1343 )
1344 }
1345}
1346
1347impl<'ROWS, 'a, T, Cols: Shape, RStride: Stride, CStride: Stride> MatRef<'a, T, Dim<'ROWS>, Cols, RStride, CStride> {
1348 #[doc(hidden)]
1349 #[inline]
1350 pub fn split_rows_with<'TOP, 'BOT>(
1351 self,
1352 row: Partition<'TOP, 'BOT, 'ROWS>,
1353 ) -> (
1354 MatRef<'a, T, Dim<'TOP>, Cols, RStride, CStride>,
1355 MatRef<'a, T, Dim<'BOT>, Cols, RStride, CStride>,
1356 ) {
1357 let (a, b) = self.split_at_row(row.midpoint());
1358 (a.as_row_shape(row.head), b.as_row_shape(row.tail))
1359 }
1360}
1361
1362impl<'COLS, 'a, T, Rows: Shape, RStride: Stride, CStride: Stride> MatRef<'a, T, Rows, Dim<'COLS>, RStride, CStride> {
1363 #[doc(hidden)]
1364 #[inline]
1365 pub fn split_cols_with<'LEFT, 'RIGHT>(
1366 self,
1367 col: Partition<'LEFT, 'RIGHT, 'COLS>,
1368 ) -> (
1369 MatRef<'a, T, Rows, Dim<'LEFT>, RStride, CStride>,
1370 MatRef<'a, T, Rows, Dim<'RIGHT>, RStride, CStride>,
1371 ) {
1372 let (a, b) = self.split_at_col(col.midpoint());
1373 (a.as_col_shape(col.head), b.as_col_shape(col.tail))
1374 }
1375}
1376
1377impl<'a, T: core::fmt::Debug, Rows: Shape, Cols: Shape, RStride: Stride, CStride: Stride> core::fmt::Debug
1378 for Ref<'a, T, Rows, Cols, RStride, CStride>
1379{
1380 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1381 fn imp<'M, 'N, T: core::fmt::Debug>(this: MatRef<'_, T, Dim<'M>, Dim<'N>>, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1382 writeln!(f, "[")?;
1383 for i in this.nrows().indices() {
1384 this.row(i).fmt(f)?;
1385 f.write_str(",\n")?;
1386 }
1387 write!(f, "]")
1388 }
1389
1390 let this = generic::Mat::from_inner_ref(self);
1391
1392 with_dim!(M, this.nrows().unbound());
1393 with_dim!(N, this.ncols().unbound());
1394 imp(this.as_shape(M, N).as_dyn_stride(), f)
1395 }
1396}
1397
1398impl<'a, T> MatRef<'a, T, usize, usize>
1399where
1400 T: RealField,
1401{
1402 pub(crate) fn internal_max(self) -> Option<T> {
1403 if self.nrows().unbound() == 0 || self.ncols().unbound() == 0 {
1404 return None;
1405 }
1406
1407 let mut max_val = self.get(0, 0);
1408
1409 let this = if self.row_stride().unsigned_abs() == 1 { self.transpose() } else { self };
1410
1411 let this = if this.col_stride() > 0 { this } else { this.reverse_cols() };
1412
1413 this.row_iter().for_each(|row| {
1414 row.iter().for_each(|val| {
1415 if val > max_val {
1416 max_val = &val;
1417 }
1418 });
1419 });
1420
1421 Some((*max_val).clone())
1422 }
1423
1424 pub(crate) fn internal_min(self) -> Option<T> {
1425 if self.nrows().unbound() == 0 || self.ncols().unbound() == 0 {
1426 return None;
1427 }
1428
1429 let mut min_val = self.get(0, 0);
1430
1431 let this = if self.row_stride().unsigned_abs() == 1 { self.transpose() } else { self };
1432
1433 let this = if this.col_stride() > 0 { this } else { this.reverse_cols() };
1434
1435 this.row_iter().for_each(|row| {
1436 row.iter().for_each(|val| {
1437 if val < min_val {
1438 min_val = &val;
1439 }
1440 });
1441 });
1442
1443 Some((*min_val).clone())
1444 }
1445}
1446
1447impl<'a, T, Rows: Shape, Cols: Shape> MatRef<'a, T, Rows, Cols>
1448where
1449 T: RealField,
1450{
1451 pub fn max(self) -> Option<T> {
1470 MatRef::internal_max(self.as_dyn())
1471 }
1472
1473 pub fn min(self) -> Option<T> {
1492 MatRef::internal_min(self.as_dyn())
1493 }
1494}
1495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_min() {
        let m = mat![
            [1.0, 5.0, 3.0],
            [4.0, 2.0, 9.0],
            [7.0, 8.0, 6.0], ];

        assert_eq!(m.as_ref().min(), Some(1.0));

        let empty: Mat<f64> = Mat::new();
        assert_eq!(empty.as_ref().min(), None);
    }

    #[test]
    fn test_max() {
        let m = mat![
            [1.0, 5.0, 3.0],
            [4.0, 2.0, 9.0],
            [7.0, 8.0, 6.0], ];

        assert_eq!(m.as_ref().max(), Some(9.0));

        let empty: Mat<f64> = Mat::new();
        assert_eq!(empty.as_ref().max(), None);
    }

    /// `min`/`max` reorient the view depending on its strides (transpose when rows are
    /// unit-stride, reverse when the column stride is not positive); every orientation of the
    /// same data must yield the same extrema.
    #[test]
    fn test_min_max_strided_views() {
        let m = mat![
            [1.0, 5.0, 3.0],
            [4.0, -2.0, 9.0],
            [7.0, 8.0, 6.0], ];
        let base = m.as_ref();

        let views = [
            base.as_dyn_stride(),
            base.transpose().as_dyn_stride(),
            base.reverse_rows().as_dyn_stride(),
            base.reverse_cols().as_dyn_stride(),
            // no unit row stride and a negative column stride: exercises the `reverse_cols` branch
            base.transpose().reverse_cols().as_dyn_stride(),
            base.reverse_rows_and_cols().as_dyn_stride(),
        ];

        for view in views {
            assert_eq!(view.min(), Some(-2.0));
            assert_eq!(view.max(), Some(9.0));
        }
    }

    #[test]
    fn test_min_max_degenerate_shapes() {
        let single = mat![[3.5]];
        assert_eq!(single.as_ref().min(), Some(3.5));
        assert_eq!(single.as_ref().max(), Some(3.5));

        let no_rows: Mat<f64> = Mat::zeros(0, 3);
        assert_eq!(no_rows.as_ref().min(), None);
        assert_eq!(no_rows.as_ref().max(), None);

        let no_cols: Mat<f64> = Mat::zeros(3, 0);
        assert_eq!(no_cols.as_ref().min(), None);
        assert_eq!(no_cols.as_ref().max(), None);
    }
}