faer/row/
rowmut.rs

1use super::*;
2use crate::utils::bound::{Array, Dim, Partition};
3use crate::{ContiguousFwd, Idx, IdxInc};
4use equator::{assert, debug_assert};
5
6/// see [`super::RowMut`]
7pub struct Mut<'a, T, Cols = usize, CStride = isize> {
8	pub(crate) trans: ColMut<'a, T, Cols, CStride>,
9}
10
11impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Mut<'_, T, Rows, RStride> {
12	type Target = Ref<'short, T, Rows, RStride>;
13
14	#[inline]
15	fn rb(&'short self) -> Self::Target {
16		Ref { trans: self.trans.rb() }
17	}
18}
19impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Mut<'_, T, Rows, RStride> {
20	type Target = Mut<'short, T, Rows, RStride>;
21
22	#[inline]
23	fn rb_mut(&'short mut self) -> Self::Target {
24		Mut { trans: self.trans.rb_mut() }
25	}
26}
27impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Mut<'a, T, Rows, RStride> {
28	type Target = Ref<'a, T, Rows, RStride>;
29
30	#[inline]
31	fn into_const(self) -> Self::Target {
32		Ref {
33			trans: self.trans.into_const(),
34		}
35	}
36}
37
38impl<'a, T> RowMut<'a, T> {
39	/// creates a row view over the given element
40	#[inline]
41	pub fn from_mut(value: &'a mut T) -> Self {
42		unsafe { RowMut::from_raw_parts_mut(value as *mut T, 1, 1) }
43	}
44
45	/// creates a `RowMut` from slice views over the row vector data, the result has the same
46	/// number of columns as the length of the input slice
47	#[inline]
48	pub fn from_slice_mut(slice: &'a mut [T]) -> Self {
49		let len = slice.len();
50		unsafe { Self::from_raw_parts_mut(slice.as_mut_ptr(), len, 1) }
51	}
52}
53
54impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride> {
55	/// creates a `RowMut` from pointers to the column vector data, number of rows, and row stride
56	///
57	/// # safety
58	/// this function has the same safety requirements as
59	/// [`MatMut::from_raw_parts_mut(ptr, 1, ncols, 0, col_stride)`]
60	#[inline(always)]
61	#[track_caller]
62	pub const unsafe fn from_raw_parts_mut(ptr: *mut T, ncols: Cols, col_stride: CStride) -> Self {
63		Self {
64			0: Mut {
65				trans: ColMut::from_raw_parts_mut(ptr, ncols, col_stride),
66			},
67		}
68	}
69
70	/// returns a pointer to the row data
71	#[inline(always)]
72	pub fn as_ptr(&self) -> *const T {
73		self.trans.as_ptr()
74	}
75
76	/// returns the number of rows of the row (always 1)
77	#[inline(always)]
78	pub fn nrows(&self) -> usize {
79		1
80	}
81
82	/// returns the number of columns of the row
83	#[inline(always)]
84	pub fn ncols(&self) -> Cols {
85		self.trans.nrows()
86	}
87
88	/// returns the number of rows and columns of the row
89	#[inline(always)]
90	pub fn shape(&self) -> (usize, Cols) {
91		(self.nrows(), self.ncols())
92	}
93
94	/// returns the column stride of the row
95	#[inline(always)]
96	pub fn col_stride(&self) -> CStride {
97		self.trans.row_stride()
98	}
99
100	/// returns a raw pointer to the element at the given index
101	#[inline(always)]
102	pub fn ptr_at(&self, col: IdxInc<Cols>) -> *const T {
103		self.trans.ptr_at(col)
104	}
105
106	/// returns a raw pointer to the element at the given index, assuming the provided index
107	/// is within the row bounds
108	///
109	/// # safety
110	/// the behavior is undefined if any of the following conditions are violated:
111	/// * `col < self.ncols()`
112	#[inline(always)]
113	#[track_caller]
114	pub unsafe fn ptr_inbounds_at(&self, col: Idx<Cols>) -> *const T {
115		debug_assert!(all(col < self.ncols()));
116		self.trans.ptr_inbounds_at(col)
117	}
118
119	#[inline]
120	#[track_caller]
121	/// see [`RowRef::split_at_col`]
122	pub fn split_at_col(self, col: IdxInc<Cols>) -> (RowRef<'a, T, usize, CStride>, RowRef<'a, T, usize, CStride>) {
123		self.into_const().split_at_col(col)
124	}
125
126	#[inline(always)]
127	/// see [`RowRef::transpose`]
128	pub fn transpose(self) -> ColRef<'a, T, Cols, CStride> {
129		self.into_const().transpose()
130	}
131
132	#[inline(always)]
133	/// see [`RowRef::conjugate`]
134	pub fn conjugate(self) -> RowRef<'a, T::Conj, Cols, CStride>
135	where
136		T: Conjugate,
137	{
138		self.into_const().conjugate()
139	}
140
141	#[inline(always)]
142	/// see [`RowRef::canonical`]
143	pub fn canonical(self) -> RowRef<'a, T::Canonical, Cols, CStride>
144	where
145		T: Conjugate,
146	{
147		self.into_const().canonical()
148	}
149
150	#[inline(always)]
151	/// see [`RowRef::adjoint`]
152	pub fn adjoint(self) -> ColRef<'a, T::Conj, Cols, CStride>
153	where
154		T: Conjugate,
155	{
156		self.into_const().adjoint()
157	}
158
159	#[track_caller]
160	#[inline(always)]
161	/// see [`RowRef::get`]
162	pub fn get<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
163	where
164		RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
165	{
166		<RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self.into_const(), col)
167	}
168
169	#[track_caller]
170	#[inline(always)]
171	/// see [`RowRef::get_unchecked`]
172	pub unsafe fn get_unchecked<ColRange>(self, col: ColRange) -> <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
173	where
174		RowRef<'a, T, Cols, CStride>: RowIndex<ColRange>,
175	{
176		unsafe { <RowRef<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self.into_const(), col) }
177	}
178
179	#[inline]
180	/// see [`RowRef::reverse_cols`]
181	pub fn reverse_cols(self) -> RowRef<'a, T, Cols, CStride::Rev> {
182		self.into_const().reverse_cols()
183	}
184
185	#[inline]
186	/// see [`RowRef::subcols`]
187	pub fn subcols<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowRef<'a, T, V, CStride> {
188		self.into_const().subcols(col_start, ncols)
189	}
190
191	#[inline]
192	#[track_caller]
193	/// see [`RowRef::as_col_shape`]
194	pub fn as_col_shape<V: Shape>(self, ncols: V) -> RowRef<'a, T, V, CStride> {
195		self.into_const().as_col_shape(ncols)
196	}
197
198	#[inline]
199	/// see [`RowRef::as_dyn_cols`]
200	pub fn as_dyn_cols(self) -> RowRef<'a, T, usize, CStride> {
201		self.into_const().as_dyn_cols()
202	}
203
204	#[inline]
205	/// see [`RowRef::as_dyn_stride`]
206	pub fn as_dyn_stride(self) -> RowRef<'a, T, Cols, isize> {
207		self.into_const().as_dyn_stride()
208	}
209
210	#[inline]
211	/// see [`RowRef::iter`]
212	pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
213	where
214		Cols: 'a,
215	{
216		self.0.trans.iter()
217	}
218
219	#[inline]
220	#[cfg(feature = "rayon")]
221	/// see [`RowRef::par_iter`]
222	pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
223	where
224		T: Sync,
225		Cols: 'a,
226	{
227		self.0.trans.par_iter()
228	}
229
230	#[inline]
231	/// see [`RowRef::try_as_row_major`]
232	pub fn try_as_row_major(self) -> Option<RowRef<'a, T, Cols, ContiguousFwd>> {
233		self.into_const().try_as_row_major()
234	}
235
236	#[inline]
237	/// see [`RowRef::as_diagonal`]
238	pub fn as_diagonal(self) -> DiagRef<'a, T, Cols, CStride> {
239		DiagRef {
240			0: crate::diag::Ref {
241				inner: self.0.trans.into_const(),
242			},
243		}
244	}
245
246	#[inline(always)]
247	#[doc(hidden)]
248	pub unsafe fn const_cast(self) -> RowMut<'a, T, Cols, CStride> {
249		RowMut {
250			0: Mut {
251				trans: self.0.trans.const_cast(),
252			},
253		}
254	}
255
256	#[inline]
257	/// see [`RowRef::as_mat`]
258	pub fn as_mat(self) -> MatRef<'a, T, usize, Cols, isize, CStride> {
259		self.into_const().as_mat()
260	}
261
262	#[inline]
263	/// see [`RowRef::as_mat`]
264	pub fn as_mat_mut(self) -> MatMut<'a, T, usize, Cols, isize, CStride> {
265		unsafe { self.into_const().as_mat().const_cast() }
266	}
267}
268
269impl<T, Cols: Shape, CStride: Stride, Inner: for<'short> ReborrowMut<'short, Target = Mut<'short, T, Cols, CStride>>> generic::Row<Inner> {
270	#[inline]
271	/// returns a view over `self`
272	pub fn as_mut(&mut self) -> RowMut<'_, T, Cols, CStride> {
273		self.rb_mut()
274	}
275
276	#[inline]
277	/// copies `other` into `self`
278	pub fn copy_from<RhsT: Conjugate<Canonical = T>>(&mut self, other: impl AsRowRef<T = RhsT, Cols = Cols>)
279	where
280		T: ComplexField,
281	{
282		self.rb_mut().transpose_mut().copy_from(other.as_row_ref().transpose());
283	}
284
285	/// fills all the elements of `self` with `value`
286	pub fn fill(&mut self, value: T)
287	where
288		T: Clone,
289	{
290		self.rb_mut().transpose_mut().fill(value)
291	}
292}
293
294impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride> {
295	#[inline(always)]
296	/// see [`RowRef::as_ptr`]
297	pub fn as_ptr_mut(&self) -> *mut T {
298		self.trans.as_ptr_mut()
299	}
300
301	#[inline(always)]
302	/// see [`RowRef::ptr_at`]
303	pub fn ptr_at_mut(&self, col: IdxInc<Cols>) -> *mut T {
304		self.trans.ptr_at_mut(col)
305	}
306
307	#[inline(always)]
308	#[track_caller]
309	/// see [`RowRef::ptr_inbounds_at`]
310	pub unsafe fn ptr_inbounds_at_mut(&self, col: Idx<Cols>) -> *mut T {
311		debug_assert!(all(col < self.ncols()));
312		self.trans.ptr_inbounds_at_mut(col)
313	}
314
315	#[inline]
316	#[track_caller]
317	/// see [`RowRef::split_at_col`]
318	pub fn split_at_col_mut(self, col: IdxInc<Cols>) -> (RowMut<'a, T, usize, CStride>, RowMut<'a, T, usize, CStride>) {
319		let (a, b) = self.into_const().split_at_col(col);
320		unsafe { (a.const_cast(), b.const_cast()) }
321	}
322
323	#[inline(always)]
324	/// see [`RowRef::transpose`]
325	pub fn transpose_mut(self) -> ColMut<'a, T, Cols, CStride> {
326		self.0.trans
327	}
328
329	#[inline(always)]
330	/// see [`RowRef::conjugate`]
331	pub fn conjugate_mut(self) -> RowMut<'a, T::Conj, Cols, CStride>
332	where
333		T: Conjugate,
334	{
335		unsafe { self.into_const().conjugate().const_cast() }
336	}
337
338	#[inline(always)]
339	/// see [`RowRef::canonical`]
340	pub fn canonical_mut(self) -> RowMut<'a, T::Canonical, Cols, CStride>
341	where
342		T: Conjugate,
343	{
344		unsafe { self.into_const().canonical().const_cast() }
345	}
346
347	#[inline(always)]
348	/// see [`RowRef::adjoint`]
349	pub fn adjoint_mut(self) -> ColMut<'a, T::Conj, Cols, CStride>
350	where
351		T: Conjugate,
352	{
353		unsafe { self.into_const().adjoint().const_cast() }
354	}
355
356	#[inline(always)]
357	#[track_caller]
358	pub(crate) fn at_mut(self, col: Idx<Cols>) -> &'a mut T {
359		assert!(all(col < self.ncols()));
360		unsafe { self.at_mut_unchecked(col) }
361	}
362
363	#[inline(always)]
364	#[track_caller]
365	pub(crate) unsafe fn at_mut_unchecked(self, col: Idx<Cols>) -> &'a mut T {
366		&mut *self.ptr_inbounds_at_mut(col)
367	}
368
369	#[track_caller]
370	#[inline(always)]
371	/// see [`RowRef::get`]
372	pub fn get_mut<ColRange>(self, col: ColRange) -> <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
373	where
374		RowMut<'a, T, Cols, CStride>: RowIndex<ColRange>,
375	{
376		<RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::get(self, col)
377	}
378
379	#[track_caller]
380	#[inline(always)]
381	/// see [`RowRef::get`]
382	pub unsafe fn get_mut_unchecked<ColRange>(self, col: ColRange) -> <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::Target
383	where
384		RowMut<'a, T, Cols, CStride>: RowIndex<ColRange>,
385	{
386		unsafe { <RowMut<'a, T, Cols, CStride> as RowIndex<ColRange>>::get_unchecked(self, col) }
387	}
388
389	#[inline]
390	/// see [`RowRef::reverse_cols`]
391	pub fn reverse_cols_mut(self) -> RowMut<'a, T, Cols, CStride::Rev> {
392		unsafe { self.into_const().reverse_cols().const_cast() }
393	}
394
395	#[inline]
396	/// see [`RowRef::subcols`]
397	pub fn subcols_mut<V: Shape>(self, col_start: IdxInc<Cols>, ncols: V) -> RowMut<'a, T, V, CStride> {
398		unsafe { self.into_const().subcols(col_start, ncols).const_cast() }
399	}
400
401	#[inline]
402	#[track_caller]
403	/// see [`RowRef::as_col_shape`]
404	pub fn as_col_shape_mut<V: Shape>(self, ncols: V) -> RowMut<'a, T, V, CStride> {
405		unsafe { self.into_const().as_col_shape(ncols).const_cast() }
406	}
407
408	#[inline]
409	/// see [`RowRef::as_dyn_cols`]
410	pub fn as_dyn_cols_mut(self) -> RowMut<'a, T, usize, CStride> {
411		unsafe { self.into_const().as_dyn_cols().const_cast() }
412	}
413
414	#[inline]
415	/// see [`RowRef::as_dyn_stride`]
416	pub fn as_dyn_stride_mut(self) -> RowMut<'a, T, Cols, isize> {
417		unsafe { self.into_const().as_dyn_stride().const_cast() }
418	}
419
420	#[inline]
421	/// see [`RowRef::iter`]
422	pub fn iter_mut(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a mut T>
423	where
424		Cols: 'a,
425	{
426		self.0.trans.iter_mut()
427	}
428
429	#[inline]
430	#[cfg(feature = "rayon")]
431	/// see [`RowRef::par_iter`]
432	pub fn par_iter_mut(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a mut T>
433	where
434		T: Send,
435		Cols: 'a,
436	{
437		self.0.trans.par_iter_mut()
438	}
439
440	#[inline]
441	#[track_caller]
442	#[cfg(feature = "rayon")]
443	/// see [`RowRef::par_partition`]
444	pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, T, usize, CStride>>
445	where
446		T: Sync,
447		Cols: 'a,
448	{
449		self.into_const().par_partition(count)
450	}
451
452	#[inline]
453	#[track_caller]
454	#[cfg(feature = "rayon")]
455	/// see [`RowRef::par_partition`]
456	pub fn par_partition_mut(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowMut<'a, T, usize, CStride>>
457	where
458		T: Send,
459		Cols: 'a,
460	{
461		use crate::mat::matmut::SyncCell;
462		use rayon::prelude::*;
463		unsafe {
464			self.as_type::<SyncCell<T>>()
465				.into_const()
466				.par_partition(count)
467				.map(|col| col.const_cast().as_type::<T>())
468		}
469	}
470
471	pub(crate) unsafe fn as_type<U>(self) -> RowMut<'a, U, Cols, CStride> {
472		RowMut::from_raw_parts_mut(self.as_ptr_mut() as *mut U, self.ncols(), self.col_stride())
473	}
474
475	#[inline]
476	/// see [`RowRef::try_as_row_major`]
477	pub fn try_as_row_major_mut(self) -> Option<RowMut<'a, T, Cols, ContiguousFwd>> {
478		self.into_const().try_as_row_major().map(|x| unsafe { x.const_cast() })
479	}
480
481	#[inline]
482	/// see [`RowRef::as_diagonal`]
483	pub fn as_diagonal_mut(self) -> DiagMut<'a, T, Cols, CStride> {
484		DiagMut {
485			0: crate::diag::Mut { inner: self.0.trans },
486		}
487	}
488
489	#[inline]
490	pub(crate) fn __at_mut(self, i: Idx<Cols>) -> &'a mut T {
491		self.at_mut(i)
492	}
493}
494
495impl<'a, T, Rows: Shape> RowMut<'a, T, Rows, ContiguousFwd> {
496	/// returns a reference over the elements as a slice
497	#[inline]
498	pub fn as_slice(self) -> &'a [T] {
499		self.transpose().as_slice()
500	}
501}
502
503impl<'a, 'ROWS, T> RowMut<'a, T, Dim<'ROWS>, ContiguousFwd> {
504	/// returns a reference over the elements as a lifetime-bound slice
505	#[inline]
506	pub fn as_array(self) -> &'a Array<'ROWS, T> {
507		self.transpose().as_array()
508	}
509}
510
511impl<'a, T, Cols: Shape> RowMut<'a, T, Cols, ContiguousFwd> {
512	/// returns a reference over the elements as a slice
513	#[inline]
514	pub fn as_slice_mut(self) -> &'a mut [T] {
515		self.transpose_mut().as_slice_mut()
516	}
517}
518
519impl<'a, 'COLS, T> RowMut<'a, T, Dim<'COLS>, ContiguousFwd> {
520	/// returns a reference over the elements as a lifetime-bound slice
521	#[inline]
522	pub fn as_array_mut(self) -> &'a mut Array<'COLS, T> {
523		self.transpose_mut().as_array_mut()
524	}
525}
526
527impl<'COLS, 'a, T, CStride: Stride> RowMut<'a, T, Dim<'COLS>, CStride> {
528	#[doc(hidden)]
529	#[inline]
530	pub fn split_cols_with<'LEFT, 'RIGHT>(
531		self,
532		col: Partition<'LEFT, 'RIGHT, 'COLS>,
533	) -> (RowRef<'a, T, Dim<'LEFT>, CStride>, RowRef<'a, T, Dim<'RIGHT>, CStride>) {
534		let (a, b) = self.split_at_col(col.midpoint());
535		(a.as_col_shape(col.head), b.as_col_shape(col.tail))
536	}
537}
538
539impl<'COLS, 'a, T, CStride: Stride> RowMut<'a, T, Dim<'COLS>, CStride> {
540	#[doc(hidden)]
541	#[inline]
542	pub fn split_cols_with_mut<'LEFT, 'RIGHT>(
543		self,
544		col: Partition<'LEFT, 'RIGHT, 'COLS>,
545	) -> (RowMut<'a, T, Dim<'LEFT>, CStride>, RowMut<'a, T, Dim<'RIGHT>, CStride>) {
546		let (a, b) = self.split_at_col_mut(col.midpoint());
547		(a.as_col_shape_mut(col.head), b.as_col_shape_mut(col.tail))
548	}
549}
550
551impl<T: core::fmt::Debug, Cols: Shape, CStride: Stride> core::fmt::Debug for Mut<'_, T, Cols, CStride> {
552	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
553		self.rb().fmt(f)
554	}
555}
556
557impl<'a, T, Cols: Shape, CStride: Stride> RowMut<'a, T, Cols, CStride>
558where
559	T: RealField,
560{
561	/// Returns the maximum element in the row, or `None` if the row is empty
562	pub fn max(&self) -> Option<T> {
563		self.rb().as_dyn_cols().as_dyn_stride().internal_max()
564	}
565
566	/// Returns the minimum element in the row, or `None` if the row is empty
567	pub fn min(&self) -> Option<T> {
568		self.rb().as_dyn_cols().as_dyn_stride().internal_min()
569	}
570}
571
#[cfg(test)]
mod tests {
	use crate::Row;

	/// checks `RowMut::min` on a non-empty and an empty row; the views are obtained with
	/// `as_mut` so that the `RowMut` impl in this file is the one under test
	#[test]
	fn test_row_min() {
		let mut row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
		let rowmut = row.as_mut();
		assert_eq!(rowmut.min(), Some(1.0));

		let mut empty: Row<f64> = Row::from_fn(0, |_| 0.0);
		let emptymut = empty.as_mut();
		assert_eq!(emptymut.min(), None);
	}

	/// checks `RowMut::max` on a non-empty and an empty row; the views are obtained with
	/// `as_mut` so that the `RowMut` impl in this file is the one under test
	#[test]
	fn test_row_max() {
		let mut row: Row<f64> = Row::from_fn(5, |x| (x + 1) as f64);
		let rowmut = row.as_mut();
		assert_eq!(rowmut.max(), Some(5.0));

		let mut empty: Row<f64> = Row::from_fn(0, |_| 0.0);
		let emptymut = empty.as_mut();
		assert_eq!(emptymut.max(), None);
	}
}