1use crate::internal_prelude::*;
29use core::marker::PhantomData;
30use dyn_stack::StackReq;
31use faer_traits::ComplexField;
32
33use crate::Shape;
34use crate::mat::matown::align_for;
35use crate::mat::{AsMatMut, MatMut};
36
37pub fn temp_mat_scratch<T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
39 let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
40
41 let mut col_stride = nrows;
42 if align > core::mem::size_of::<T>() {
43 col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
44 }
45 let len = col_stride.checked_mul(ncols).unwrap();
46 StackReq::new_aligned::<T>(len, align)
47}
48
49struct DynMat<'a, T: ComplexField, Rows: Shape, Cols: Shape> {
50 ptr: *mut T,
51 nrows: Rows,
52 ncols: Cols,
53 col_stride: usize,
54 __marker: PhantomData<(&'a T, T)>,
55}
56
57impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> Drop for DynMat<'a, T, Rows, Cols> {
58 #[inline]
59 fn drop(&mut self) {
60 unsafe { core::ptr::drop_in_place(core::slice::from_raw_parts_mut(self.ptr, self.col_stride * self.ncols.unbound())) };
61 }
62}
63
64impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> AsMatRef for DynMat<'a, T, Rows, Cols> {
65 type Cols = Cols;
66 type Owned = Mat<T, Rows, Cols>;
67 type Rows = Rows;
68 type T = T;
69
70 fn as_mat_ref(&self) -> crate::mat::MatRef<T, Rows, Cols> {
71 unsafe { MatRef::from_raw_parts(self.ptr as *const T, self.nrows, self.ncols, 1, self.col_stride as isize) }
72 }
73}
74
75impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> AsMatMut for DynMat<'a, T, Rows, Cols> {
76 fn as_mat_mut(&mut self) -> crate::mat::MatMut<T, Rows, Cols> {
77 unsafe { MatMut::from_raw_parts_mut(self.ptr, self.nrows, self.ncols, 1, self.col_stride as isize) }
78 }
79}
80
/// Panic-safety guard used while filling a buffer element by element.
///
/// Tracks how many leading elements starting at `ptr` have been written, so those can be
/// dropped if initialization unwinds part-way through.
struct DropGuard<T> {
	/// Count of elements at the front of the buffer that are initialized.
	len: usize,
	/// Start of the buffer being filled.
	ptr: *mut T,
}
85impl<T> Drop for DropGuard<T> {
86 #[inline]
87 fn drop(&mut self) {
88 unsafe { core::ptr::drop_in_place(core::slice::from_raw_parts_mut(self.ptr, self.len)) };
89 }
90}
91
92#[track_caller]
94pub unsafe fn temp_mat_uninit<'a, T: ComplexField + 'a, Rows: Shape + 'a, Cols: Shape + 'a>(
95 nrows: Rows,
96 ncols: Cols,
97 stack: &'a mut MemStack,
98) -> (impl 'a + AsMatMut<T = T, Rows = Rows, Cols = Cols>, &'a mut MemStack) {
99 let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
100
101 let mut col_stride = nrows.unbound();
102 if align > core::mem::size_of::<T>() {
103 col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
104 }
105 let len = col_stride.checked_mul(ncols.unbound()).unwrap();
106
107 let (uninit, stack) = stack.make_aligned_uninit::<T>(len, align);
108
109 let ptr = uninit.as_mut_ptr() as *mut T;
110 if core::mem::needs_drop::<T>() {
111 unsafe {
112 let mut guard = DropGuard { ptr, len: 0 };
113 for j in 0..len {
114 let ptr = ptr.add(j);
115 let val = T::nan_impl();
116 ptr.write(val);
117 guard.len += 1;
118 }
119 core::mem::forget(guard);
120 }
121 }
122 (
123 DynMat {
124 ptr,
125 nrows,
126 ncols,
127 col_stride,
128 __marker: PhantomData,
129 },
130 stack,
131 )
132}
133
134#[track_caller]
136pub fn temp_mat_zeroed<'a, T: ComplexField + 'a, Rows: Shape + 'a, Cols: Shape + 'a>(
137 nrows: Rows,
138 ncols: Cols,
139 stack: &'a mut MemStack,
140) -> (impl 'a + AsMatMut<T = T, Rows = Rows, Cols = Cols>, &'a mut MemStack) {
141 let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
142
143 let mut col_stride = nrows.unbound();
144 if align > core::mem::size_of::<T>() {
145 col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
146 }
147 let len = col_stride.checked_mul(ncols.unbound()).unwrap();
148 _ = stack.make_aligned_uninit::<T>(len, align);
149
150 let (uninit, stack) = stack.make_aligned_uninit::<T>(len, align);
151
152 let ptr = uninit.as_mut_ptr() as *mut T;
153
154 unsafe {
155 let mut guard = DropGuard { ptr, len: 0 };
156 for j in 0..len {
157 let ptr = ptr.add(j);
158 let val = T::zero_impl();
159 ptr.write(val);
160 guard.len += 1;
161 }
162 core::mem::forget(guard);
163 }
164
165 (
166 DynMat {
167 ptr,
168 nrows,
169 ncols,
170 col_stride,
171 __marker: PhantomData,
172 },
173 stack,
174 )
175}
176
177pub mod matmul;
178pub mod triangular_inverse;
180pub mod triangular_solve;
182
183pub(crate) mod reductions;
184pub mod zip;
186
187pub mod householder;
188pub mod jacobi;
190
191pub mod kron;
193
194pub mod cholesky;
195pub mod lu;
196pub mod qr;
197
198pub mod evd;
199pub mod svd;
200
201mod mat_ops;
202
203pub mod solvers;