faer/linalg/
mod.rs

1//! linear algebra module
2//!
3//! contains low level routines and the implementation of their corresponding high level
4//! wrappers
5//!
6//! # memory allocation
7//! since most `faer` crates aim to expose a low level api for optimal performance, most algorithms
8//! try to defer memory allocation to the user
9//!
10//! however, since a lot of algorithms need some form of temporary space for intermediate
11//! computations, they may ask for a slice of memory for that purpose, by taking a [`stack:
12//! MemStack`](dyn_stack::MemStack) parameter. a [`MemStack`] is a thin wrapper over a slice of
13//! memory bytes. this memory may come from any valid source (heap allocation, fixed-size array on
14//! the stack, etc.). the functions taking a [`MemStack`] parameter have a corresponding function
15//! with a similar name ending in `_scratch` that returns the memory requirements of the algorithm.
16//! for example:
17//! [`householder::apply_block_householder_on_the_left_in_place_with_conj`] and
18//! [`householder::apply_block_householder_on_the_left_in_place_scratch`]
19//!
20//! the memory stack may be reused in user-code to avoid repeated allocations, and it is also
21//! possible to compute the sum ([`dyn_stack::StackReq::all_of`]) or union
22//! ([`dyn_stack::StackReq::any_of`]) of multiple scratch requirements, in order to optimally
23//! combine them into a single allocation
24//!
25//! after computing a [`dyn_stack::StackReq`], one can query its size and alignment to allocate the
26//! required memory. the simplest way to do so is through [`dyn_stack::MemBuffer::new`]
27
28use crate::internal_prelude::*;
29use core::marker::PhantomData;
30use dyn_stack::StackReq;
31use faer_traits::ComplexField;
32
33use crate::Shape;
34use crate::mat::matown::align_for;
35use crate::mat::{AsMatMut, MatMut};
36
37/// returns the stack requirements for creating a temporary matrix with the given dimensions.
38pub fn temp_mat_scratch<T: ComplexField>(nrows: usize, ncols: usize) -> StackReq {
39	let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
40
41	let mut col_stride = nrows;
42	if align > core::mem::size_of::<T>() {
43		col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
44	}
45	let len = col_stride.checked_mul(ncols).unwrap();
46	StackReq::new_aligned::<T>(len, align)
47}
48
49struct DynMat<'a, T: ComplexField, Rows: Shape, Cols: Shape> {
50	ptr: *mut T,
51	nrows: Rows,
52	ncols: Cols,
53	col_stride: usize,
54	__marker: PhantomData<(&'a T, T)>,
55}
56
57impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> Drop for DynMat<'a, T, Rows, Cols> {
58	#[inline]
59	fn drop(&mut self) {
60		unsafe { core::ptr::drop_in_place(core::slice::from_raw_parts_mut(self.ptr, self.col_stride * self.ncols.unbound())) };
61	}
62}
63
64impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> AsMatRef for DynMat<'a, T, Rows, Cols> {
65	type Cols = Cols;
66	type Owned = Mat<T, Rows, Cols>;
67	type Rows = Rows;
68	type T = T;
69
70	fn as_mat_ref(&self) -> crate::mat::MatRef<T, Rows, Cols> {
71		unsafe { MatRef::from_raw_parts(self.ptr as *const T, self.nrows, self.ncols, 1, self.col_stride as isize) }
72	}
73}
74
75impl<'a, T: ComplexField, Rows: Shape, Cols: Shape> AsMatMut for DynMat<'a, T, Rows, Cols> {
76	fn as_mat_mut(&mut self) -> crate::mat::MatMut<T, Rows, Cols> {
77		unsafe { MatMut::from_raw_parts_mut(self.ptr, self.nrows, self.ncols, 1, self.col_stride as isize) }
78	}
79}
80
/// guard that drops the first `len` elements starting at `ptr` when it goes out of scope.
///
/// the fill loops in this file bump `len` after each write and `forget` the guard once every
/// element has been written, so it only runs if producing a value unwinds part-way through.
struct DropGuard<T> {
	/// start of the storage being filled
	ptr: *mut T,
	/// number of leading elements written so far
	len: usize,
}

impl<T> Drop for DropGuard<T> {
	#[inline]
	fn drop(&mut self) {
		let written = core::ptr::slice_from_raw_parts_mut(self.ptr, self.len);
		// SAFETY: whoever builds the guard keeps `len` equal to the number of initialized
		// elements at the front of `ptr`
		unsafe { core::ptr::drop_in_place(written) };
	}
}
91
92/// creates a temporary matrix of uninit values, from the given memory stack.
93#[track_caller]
94pub unsafe fn temp_mat_uninit<'a, T: ComplexField + 'a, Rows: Shape + 'a, Cols: Shape + 'a>(
95	nrows: Rows,
96	ncols: Cols,
97	stack: &'a mut MemStack,
98) -> (impl 'a + AsMatMut<T = T, Rows = Rows, Cols = Cols>, &'a mut MemStack) {
99	let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
100
101	let mut col_stride = nrows.unbound();
102	if align > core::mem::size_of::<T>() {
103		col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
104	}
105	let len = col_stride.checked_mul(ncols.unbound()).unwrap();
106
107	let (uninit, stack) = stack.make_aligned_uninit::<T>(len, align);
108
109	let ptr = uninit.as_mut_ptr() as *mut T;
110	if core::mem::needs_drop::<T>() {
111		unsafe {
112			let mut guard = DropGuard { ptr, len: 0 };
113			for j in 0..len {
114				let ptr = ptr.add(j);
115				let val = T::nan_impl();
116				ptr.write(val);
117				guard.len += 1;
118			}
119			core::mem::forget(guard);
120		}
121	}
122	(
123		DynMat {
124			ptr,
125			nrows,
126			ncols,
127			col_stride,
128			__marker: PhantomData,
129		},
130		stack,
131	)
132}
133
134/// creates a temporary matrix of zero values, from the given memory stack.
135#[track_caller]
136pub fn temp_mat_zeroed<'a, T: ComplexField + 'a, Rows: Shape + 'a, Cols: Shape + 'a>(
137	nrows: Rows,
138	ncols: Cols,
139	stack: &'a mut MemStack,
140) -> (impl 'a + AsMatMut<T = T, Rows = Rows, Cols = Cols>, &'a mut MemStack) {
141	let align = align_for(core::mem::size_of::<T>(), core::mem::align_of::<T>(), core::mem::needs_drop::<T>());
142
143	let mut col_stride = nrows.unbound();
144	if align > core::mem::size_of::<T>() {
145		col_stride = col_stride.msrv_next_multiple_of(align / core::mem::size_of::<T>());
146	}
147	let len = col_stride.checked_mul(ncols.unbound()).unwrap();
148	_ = stack.make_aligned_uninit::<T>(len, align);
149
150	let (uninit, stack) = stack.make_aligned_uninit::<T>(len, align);
151
152	let ptr = uninit.as_mut_ptr() as *mut T;
153
154	unsafe {
155		let mut guard = DropGuard { ptr, len: 0 };
156		for j in 0..len {
157			let ptr = ptr.add(j);
158			let val = T::zero_impl();
159			ptr.write(val);
160			guard.len += 1;
161		}
162		core::mem::forget(guard);
163	}
164
165	(
166		DynMat {
167			ptr,
168			nrows,
169			ncols,
170			col_stride,
171			__marker: PhantomData,
172		},
173		stack,
174	)
175}
176
// submodules of the linear algebra module.
// NOTE(review): the modules without a doc comment below are described by name only; presumably
// each carries its own module-level docs — confirm
177pub mod matmul;
178/// triangular matrix inverse
179pub mod triangular_inverse;
180/// triangular matrix solve
181pub mod triangular_solve;
182
// crate-private (`pub(crate)`): usable inside this crate only
183pub(crate) mod reductions;
184/// matrix zipping implementation
185pub mod zip;
186
// the module docs at the top of this file reference
// `householder::apply_block_householder_on_the_left_in_place_with_conj` and its `_scratch`
// counterpart from this module
187pub mod householder;
188/// jacobi rotation matrix
189pub mod jacobi;
190
191/// kronecker product
192pub mod kron;
193
// NOTE(review): presumably matrix decompositions (cholesky, lu, qr) — confirm against the modules
194pub mod cholesky;
195pub mod lu;
196pub mod qr;
197
// NOTE(review): presumably eigenvalue and singular value decompositions — confirm
198pub mod evd;
199pub mod svd;
200
// private module (no `pub`): not exported from this module
201mod mat_ops;
202
203/// high level solvers
204pub mod solvers;