1use crate::internal_prelude_sp::*;
4
5pub mod bicgstab;
7pub mod conjugate_gradient;
9pub mod lsmr;
11
12pub mod eigen;
14
15mod operator_impl;
16
/// Whether the initial guess supplied by the caller may be treated as zero.
///
/// Presumably consumed by the iterative solvers declared in this module
/// (`bicgstab`, `conjugate_gradient`, `lsmr`); their usage is not visible here.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InitialGuessStatus {
	/// The initial guess is known to be all zeros.
	Zero,
	/// No assumption is made about the initial guess. This is the `Default`.
	#[default]
	MaybeNonZero,
}
26
/// The identity operator / preconditioner of size `dim x dim`.
///
/// Its `LinOp`, `BiLinOp`, `Precond` and `BiPrecond` impls copy the input to the
/// output (or do nothing when applied in place) and request no scratch memory.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct IdentityPrecond {
	/// Number of rows, which is also the number of columns, of the operator.
	pub dim: usize,
}
33
/// A linear operator that can be applied to a dense right-hand side.
///
/// Implementors must also be `Sync` and `Debug` (supertrait bounds). Scratch
/// memory is sized up front with [`LinOp::apply_scratch`] and handed back as a
/// [`MemStack`] on each application.
pub trait LinOp<T: ComplexField>: Sync + core::fmt::Debug {
	/// Returns the scratch-space requirement of [`LinOp::apply`] for a
	/// right-hand side with `rhs_ncols` columns.
	///
	/// The default [`Precond::conj_apply_in_place`] also budgets the scratch for
	/// [`LinOp::conj_apply`] with this value, so it is expected to cover both.
	fn apply_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq;

	/// Number of rows of the operator.
	fn nrows(&self) -> usize;
	/// Number of columns of the operator.
	fn ncols(&self) -> usize;

	/// Writes `self` applied to `rhs` into `out`.
	///
	/// The default in-place methods of `Precond`/`BiPrecond` call this with an
	/// `out` of `nrows()` rows obtained from `temp_mat_uninit`, so an
	/// implementation is expected to overwrite every entry of `out` without
	/// reading it first (NOTE(review): confirm this is the intended contract).
	/// `stack` is expected to satisfy `apply_scratch(rhs.ncols(), par)`.
	fn apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack);

	/// Same as [`LinOp::apply`], but using the conjugate of `self`.
	fn conj_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack);
}
51
52impl<T: ComplexField> LinOp<T> for IdentityPrecond {
53 #[inline]
54 #[track_caller]
55 fn apply_scratch(&self, _rhs_ncols: usize, _par: Par) -> StackReq {
56 StackReq::EMPTY
57 }
58
59 #[inline]
60 fn nrows(&self) -> usize {
61 self.dim
62 }
63
64 #[inline]
65 fn ncols(&self) -> usize {
66 self.dim
67 }
68
69 #[inline]
70 #[track_caller]
71 fn apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, _par: Par, _stack: &mut MemStack) {
72 { out }.copy_from(rhs);
73 }
74
75 #[inline]
76 #[track_caller]
77 fn conj_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, _par: Par, _stack: &mut MemStack) {
78 { out }.copy_from(rhs);
79 }
80}
81impl<T: ComplexField> BiLinOp<T> for IdentityPrecond {
82 #[inline]
83 fn transpose_apply_scratch(&self, _rhs_ncols: usize, _par: Par) -> StackReq {
84 StackReq::EMPTY
85 }
86
87 #[inline]
88 #[track_caller]
89 fn transpose_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, _par: Par, _stack: &mut MemStack) {
90 { out }.copy_from(rhs);
91 }
92
93 #[inline]
94 #[track_caller]
95 fn adjoint_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, _par: Par, _stack: &mut MemStack) {
96 { out }.copy_from(rhs);
97 }
98}
99impl<T: ComplexField> Precond<T> for IdentityPrecond {
100 fn apply_in_place_scratch(&self, _rhs_ncols: usize, _par: Par) -> StackReq {
101 StackReq::EMPTY
102 }
103
104 fn apply_in_place(&self, _rhs: MatMut<'_, T>, _par: Par, _stack: &mut MemStack) {}
105
106 fn conj_apply_in_place(&self, _rhs: MatMut<'_, T>, _par: Par, _stack: &mut MemStack) {}
107}
108impl<T: ComplexField> BiPrecond<T> for IdentityPrecond {
109 fn transpose_apply_in_place_scratch(&self, _rhs_ncols: usize, _par: Par) -> StackReq {
110 StackReq::EMPTY
111 }
112
113 fn transpose_apply_in_place(&self, _rhs: MatMut<'_, T>, _par: Par, _stack: &mut MemStack) {}
114
115 fn adjoint_apply_in_place(&self, _rhs: MatMut<'_, T>, _par: Par, _stack: &mut MemStack) {}
116}
117
/// A [`LinOp`] that can additionally be applied as its transpose or adjoint.
pub trait BiLinOp<T: ComplexField>: LinOp<T> {
	/// Returns the scratch-space requirement of [`BiLinOp::transpose_apply`] for
	/// a right-hand side with `rhs_ncols` columns.
	///
	/// The default [`BiPrecond::adjoint_apply_in_place`] also budgets the
	/// scratch for [`BiLinOp::adjoint_apply`] with this value, so it is expected
	/// to cover both.
	fn transpose_apply_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq;

	/// Writes the transpose of `self` applied to `rhs` into `out`.
	///
	/// As with [`LinOp::apply`], the default `BiPrecond` in-place methods pass an
	/// `out` obtained from `temp_mat_uninit`, so every entry is expected to be
	/// overwritten (NOTE(review): confirm this is the intended contract).
	fn transpose_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack);

	/// Writes the adjoint (conjugate transpose) of `self` applied to `rhs` into
	/// `out`.
	fn adjoint_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack);
}
130
131pub trait Precond<T: ComplexField>: LinOp<T> {
135 fn apply_in_place_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
138 temp_mat_scratch::<T>(self.nrows(), rhs_ncols).and(self.apply_scratch(rhs_ncols, par))
139 }
140
141 #[track_caller]
143 fn apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
144 let (mut tmp, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.nrows(), rhs.ncols(), stack) };
145 let mut tmp = tmp.as_mat_mut();
146 self.apply(tmp.rb_mut(), rhs.rb(), par, stack);
147 { rhs }.copy_from(&tmp);
148 }
149
150 #[track_caller]
152 fn conj_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
153 let (mut tmp, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.nrows(), rhs.ncols(), stack) };
154 let mut tmp = tmp.as_mat_mut();
155
156 self.conj_apply(tmp.rb_mut(), rhs.rb(), par, stack);
157 { rhs }.copy_from(&tmp);
158 }
159}
160
161pub trait BiPrecond<T: ComplexField>: Precond<T> + BiLinOp<T> {
165 fn transpose_apply_in_place_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
168 temp_mat_scratch::<T>(self.nrows(), rhs_ncols).and(self.transpose_apply_scratch(rhs_ncols, par))
169 }
170
171 #[track_caller]
173 fn transpose_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
174 let (mut tmp, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.nrows(), rhs.ncols(), stack) };
175 let mut tmp = tmp.as_mat_mut();
176 self.transpose_apply(tmp.rb_mut(), rhs.rb(), par, stack);
177 { rhs }.copy_from(&tmp);
178 }
179
180 #[track_caller]
182 fn adjoint_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
183 let (mut tmp, stack) = unsafe { temp_mat_uninit::<T, _, _>(self.nrows(), rhs.ncols(), stack) };
184 let mut tmp = tmp.as_mat_mut();
185
186 self.adjoint_apply(tmp.rb_mut(), rhs.rb(), par, stack);
187 { rhs }.copy_from(&tmp);
188 }
189}
190
191impl<T: ComplexField, M: Sized + LinOp<T>> LinOp<T> for &M {
192 #[inline]
193 #[track_caller]
194 fn apply_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
195 (**self).apply_scratch(rhs_ncols, par)
196 }
197
198 #[inline]
199 fn nrows(&self) -> usize {
200 (**self).nrows()
201 }
202
203 #[inline]
204 fn ncols(&self) -> usize {
205 (**self).ncols()
206 }
207
208 #[inline]
209 #[track_caller]
210 fn apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
211 (**self).apply(out, rhs, par, stack)
212 }
213
214 #[inline]
215 #[track_caller]
216 fn conj_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
217 (**self).conj_apply(out, rhs, par, stack)
218 }
219}
220
221impl<T: ComplexField, M: Sized + BiLinOp<T>> BiLinOp<T> for &M {
222 #[inline]
223 #[track_caller]
224 fn transpose_apply_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
225 (**self).transpose_apply_scratch(rhs_ncols, par)
226 }
227
228 #[inline]
229 #[track_caller]
230 fn transpose_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
231 (**self).transpose_apply(out, rhs, par, stack)
232 }
233
234 #[inline]
235 #[track_caller]
236 fn adjoint_apply(&self, out: MatMut<'_, T>, rhs: MatRef<'_, T>, par: Par, stack: &mut MemStack) {
237 (**self).adjoint_apply(out, rhs, par, stack)
238 }
239}
240
241impl<T: ComplexField, M: Sized + Precond<T>> Precond<T> for &M {
242 fn apply_in_place_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
243 (**self).apply_in_place_scratch(rhs_ncols, par)
244 }
245
246 fn apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
247 (**self).apply_in_place(rhs, par, stack);
248 }
249
250 fn conj_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
251 (**self).conj_apply_in_place(rhs, par, stack);
252 }
253}
254
255impl<T: ComplexField, M: Sized + BiPrecond<T>> BiPrecond<T> for &M {
256 fn transpose_apply_in_place_scratch(&self, rhs_ncols: usize, par: Par) -> StackReq {
257 (**self).transpose_apply_in_place_scratch(rhs_ncols, par)
258 }
259
260 fn transpose_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
261 (**self).transpose_apply_in_place(rhs, par, stack);
262 }
263
264 fn adjoint_apply_in_place(&self, rhs: MatMut<'_, T>, par: Par, stack: &mut MemStack) {
265 (**self).adjoint_apply_in_place(rhs, par, stack);
266 }
267}