1use super::bound::{Dim, Idx};
2use crate::internal_prelude::*;
3use core::marker::PhantomData;
4use faer_traits::SimdCapabilities;
5use pulp::Simd;
6
/// Precomputed SIMD loop bounds and masks for traversing a column of length `'N`
/// in register-sized chunks, split into an optional masked head, a run of
/// full-width body chunks, and an optional masked tail.
pub struct SimdCtx<'N, T: ComplexField, S: Simd> {
    /// Underlying SIMD context for the scalar type `T` on backend `S`.
    pub ctx: T::SimdCtx<S>,
    /// Length of the dimension this context was built for.
    pub len: Dim<'N>,
    // Number of scalar elements the logical start is shifted back by (0 for the
    // unaligned constructors); negated and added to every chunk start in
    // `indices`/`batch_indices`.
    offset: usize,
    // Exclusive end, counted in register-sized chunks, of the masked head section.
    head_end: usize,
    // Exclusive end, counted in register-sized chunks, of the full-width body section.
    body_end: usize,
    // Exclusive end, counted in register-sized chunks, of the masked tail section.
    tail_end: usize,
    // Lane mask selecting the valid lanes of the head chunk.
    head_mask: T::SimdMask<S>,
    // Lane mask selecting the valid lanes of the tail chunk.
    tail_mask: T::SimdMask<S>,
    // Memory-access mask used for masked loads/stores of the head chunk.
    head_mem_mask: T::SimdMemMask<S>,
    // Memory-access mask used for masked loads/stores of the tail chunk.
    tail_mem_mask: T::SimdMemMask<S>,
}
19
20impl<'N, T: ComplexField, S: Simd> core::fmt::Debug for SimdCtx<'N, T, S> {
21 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22 f.debug_struct("SimdCtx")
23 .field("len", &self.len)
24 .field("offset", &self.offset)
25 .field("head_end", &self.head_end)
26 .field("body_end", &self.body_end)
27 .field("tail_end", &self.tail_end)
28 .field("head_mask", &self.head_mask)
29 .field("tail_mask", &self.tail_mask)
30 .finish_non_exhaustive()
31 }
32}
33
// NOTE(review): manual impls rather than `#[derive(Copy, Clone)]` — deriving would
// presumably add unwanted `T: Copy`/`T: Clone` bounds on the type parameters, while
// only the stored associated SIMD types need to be copyable. Confirm against
// `faer_traits` bounds.
impl<T: ComplexField, S: Simd> Copy for SimdCtx<'_, T, S> {}
impl<T: ComplexField, S: Simd> Clone for SimdCtx<'_, T, S> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
41
impl<T: ComplexField, S: Simd> core::ops::Deref for SimdCtx<'_, T, S> {
    type Target = faer_traits::SimdCtx<T, S>;

    /// Exposes the wrapped `faer_traits::SimdCtx` so its SIMD operations
    /// (`load`, `store`, `mask_load`, `mask_store`, ...) can be called directly
    /// on this bounded context.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        Self::Target::new(&self.ctx)
    }
}
50
/// Chunk-index types ([`SimdHead`], [`SimdBody`], [`SimdTail`]) that know how to
/// load and store one register's worth of elements from a contiguous column,
/// applying the appropriate memory mask for partial chunks.
pub trait SimdIndex<'N, T: ComplexField, S: Simd> {
    /// Loads the chunk designated by `index` from `slice`.
    fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S>;

    /// Stores `value` into the chunk designated by `index` in `slice`.
    fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>);
}
56
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdBody<'N, T, S> {
    /// Full-width (unmasked) load of the body chunk starting at element offset
    /// `index.start`.
    #[inline(always)]
    fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
        // SAFETY(review): assumes `index` was produced by `indices`/`batch_indices`
        // of a `SimdCtx` built for this slice's `Dim<'N>`, so the full register at
        // `index.start` is in bounds — confirm against call sites.
        unsafe { simd.load(&*(slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>)) }
    }

    /// Full-width (unmasked) store into the body chunk starting at element offset
    /// `index.start`.
    #[inline(always)]
    fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
        // SAFETY(review): same in-bounds assumption as `read` above.
        unsafe {
            simd.store(&mut *(slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>), value);
        }
    }
}
70
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdHead<'N, T, S> {
    /// Masked load of the head chunk: only the lanes enabled by
    /// `simd.head_mem_mask` are touched.
    #[inline(always)]
    fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
        // SAFETY(review): assumes `index` came from the same `SimdCtx`, whose head
        // memory mask disables the lanes that fall outside the slice — note that
        // `index.start` may be negative for aligned contexts; confirm the masked
        // load never dereferences disabled lanes.
        unsafe { simd.mask_load(simd.head_mem_mask, slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>) }
    }

    /// Masked store into the head chunk: only the lanes enabled by
    /// `simd.head_mem_mask` are written.
    #[inline(always)]
    fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
        // SAFETY(review): same masked-access assumption as `read` above.
        unsafe {
            simd.mask_store(
                simd.head_mem_mask,
                slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>,
                value,
            );
        }
    }
}
88
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdTail<'N, T, S> {
    /// Masked load of the tail chunk: only the lanes enabled by
    /// `simd.tail_mem_mask` are touched.
    #[inline(always)]
    fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
        // SAFETY(review): assumes `index` came from the same `SimdCtx`, whose tail
        // memory mask disables the lanes past the end of the slice — confirm the
        // masked load never dereferences disabled lanes.
        unsafe { simd.mask_load(simd.tail_mem_mask, slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>) }
    }

    /// Masked store into the tail chunk: only the lanes enabled by
    /// `simd.tail_mem_mask` are written.
    #[inline(always)]
    fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
        // SAFETY(review): same masked-access assumption as `read` above.
        unsafe {
            simd.mask_store(
                simd.tail_mem_mask,
                slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>,
                value,
            );
        }
    }
}
106
impl<'N, T: ComplexField, S: Simd> SimdCtx<'N, T, S> {
    /// Creates a context for a column of length `len` whose data starts at an
    /// address treated as aligned: no head section, `len / stride` full body
    /// chunks, and a masked tail holding the remaining `len % stride` elements
    /// (empty when `len` is a multiple of the register length).
    #[inline(always)]
    pub fn new(simd: T::SimdCtx<S>, len: Dim<'N>) -> Self {
        // Only valid when `T` advertises full SIMD support for this backend.
        core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

        // Number of scalar elements per SIMD register.
        let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
        let iota = T::simd_iota(&simd);

        // Lane masks are built as `start <= iota < end`: the head is empty
        // (start == end == 0) and the tail covers lanes `0 .. len % stride`.
        let head_start = T::simd_index_splat(&simd, T::Index::truncate(0));
        let head_end = T::simd_index_splat(&simd, T::Index::truncate(0));
        let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
        let tail_end = T::simd_index_splat(&simd, T::Index::truncate(*len % stride));

        Self {
            ctx: simd,
            len,
            offset: 0,
            head_end: 0,
            body_end: *len / stride,
            // Rounds up: one extra (masked) chunk when `len % stride != 0`.
            tail_end: (*len + stride - 1) / stride,
            head_mask: T::simd_and_mask(
                &simd,
                T::simd_index_greater_than_or_equal(&simd, iota, head_start),
                T::simd_index_less_than(&simd, iota, head_end),
            ),
            tail_mask: T::simd_and_mask(
                &simd,
                T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
                T::simd_index_less_than(&simd, iota, tail_end),
            ),
            head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
            tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(*len % stride)),
        }
    }

    /// Creates a context whose body chunks land on register-aligned addresses.
    ///
    /// `align_offset` is reduced modulo the register length; when nonzero, the
    /// logical start is shifted back by `offset = stride - align_offset`
    /// elements, and the first `align_offset` elements are handled by a masked
    /// head chunk spanning lanes `offset .. stride` of a register that begins
    /// `offset` elements *before* the data (chunk starts from
    /// [`Self::indices`] are negative for the head in that case).
    /// NOTE(review): `align_offset` presumably is the element misalignment of the
    /// data pointer relative to the register size — confirm against callers.
    #[inline(always)]
    pub fn new_align(simd: T::SimdCtx<S>, len: Dim<'N>, align_offset: usize) -> Self {
        // Only valid when `T` advertises full SIMD support for this backend.
        core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

        // Number of scalar elements per SIMD register.
        let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
        let align_offset = align_offset % stride;
        let iota = T::simd_iota(&simd);

        if align_offset == 0 {
            // Already aligned: identical to the plain constructor.
            Self::new(simd, len)
        } else {
            // Elements by which the logical start is shifted back, and the total
            // span (shift + data) that the chunks must cover.
            let offset = stride - align_offset;
            let full_len = offset + *len;

            // Head occupies lanes `offset .. stride`; tail covers the leftover
            // `full_len % stride` lanes of the final chunk.
            let head_start = T::simd_index_splat(&simd, T::Index::truncate(offset));
            let head_end = T::simd_index_splat(&simd, T::Index::truncate(stride));
            let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
            let tail_end = T::simd_index_splat(&simd, T::Index::truncate(full_len % stride));

            if align_offset <= *len {
                // General case: the head chunk is completely filled
                // (`align_offset` elements), followed by full body chunks and a
                // round-up tail.
                Self {
                    ctx: simd,
                    len,
                    offset,
                    head_end: 1,
                    body_end: full_len / stride,
                    tail_end: (full_len + stride - 1) / stride,
                    head_mask: T::simd_and_mask(
                        &simd,
                        T::simd_index_greater_than_or_equal(&simd, iota, head_start),
                        T::simd_index_less_than(&simd, iota, head_end),
                    ),
                    tail_mask: T::simd_and_mask(
                        &simd,
                        T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
                        T::simd_index_less_than(&simd, iota, tail_end),
                    ),
                    head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(offset), T::Index::truncate(stride)),
                    tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(full_len % stride)),
                }
            } else {
                // Short-column case (`len < align_offset`): all elements fit in
                // the single head chunk, which now ends at `full_len % stride`
                // instead of a full register; the tail is empty. The mask bounds
                // shadow the ones computed above.
                let head_start = T::simd_index_splat(&simd, T::Index::truncate(offset));
                let head_end = T::simd_index_splat(&simd, T::Index::truncate(full_len % stride));
                let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
                let tail_end = T::simd_index_splat(&simd, T::Index::truncate(0));

                Self {
                    ctx: simd,
                    len,
                    offset,
                    // head_end == body_end == tail_end == 1: one masked head
                    // chunk, no body, no tail.
                    head_end: 1,
                    body_end: 1,
                    tail_end: 1,
                    head_mask: T::simd_and_mask(
                        &simd,
                        T::simd_index_greater_than_or_equal(&simd, iota, head_start),
                        T::simd_index_less_than(&simd, iota, head_end),
                    ),
                    tail_mask: T::simd_and_mask(
                        &simd,
                        T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
                        T::simd_index_less_than(&simd, iota, tail_end),
                    ),
                    head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(offset), T::Index::truncate(full_len % stride)),
                    tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
                }
            }
        }
    }

    /// Number of scalar elements the logical start is shifted back by
    /// (0 unless the context was built with [`Self::new_align`]).
    #[inline]
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// Like [`Self::new`], but guarantees the last chunk is always accessed
    /// through the masked tail, even when `len` is an exact multiple of the
    /// register length: bounds are computed for `len - 1`, and the tail mask is
    /// widened by one lane to still cover the final element.
    ///
    /// # Panics
    ///
    /// Panics if `len == 0`.
    #[inline(always)]
    pub fn new_force_mask(simd: T::SimdCtx<S>, len: Dim<'N>) -> Self {
        // Only valid when `T` advertises full SIMD support for this backend.
        core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

        crate::assert!(*len != 0);
        // Computing bounds for `len - 1` forces the last element out of the body
        // and into the tail chunk.
        let new_len = *len - 1;

        // Number of scalar elements per SIMD register.
        let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
        let iota = T::simd_iota(&simd);

        // Head is empty; tail covers `(new_len % stride) + 1` lanes — the `+ 1`
        // re-includes the element removed from `new_len`.
        let head_start = T::simd_index_splat(&simd, T::Index::truncate(0));
        let head_end = T::simd_index_splat(&simd, T::Index::truncate(0));
        let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
        let tail_end = T::simd_index_splat(&simd, T::Index::truncate((new_len % stride) + 1));

        Self {
            ctx: simd,
            len,
            offset: 0,
            head_end: 0,
            body_end: new_len / stride,
            // Always exactly one tail chunk after the body.
            tail_end: new_len / stride + 1,
            head_mask: T::simd_and_mask(
                &simd,
                T::simd_index_greater_than_or_equal(&simd, iota, head_start),
                T::simd_index_less_than(&simd, iota, head_end),
            ),
            tail_mask: T::simd_and_mask(
                &simd,
                T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
                T::simd_index_less_than(&simd, iota, tail_end),
            ),
            head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
            tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate((new_len % stride) + 1)),
        }
    }

    /// Loads the chunk designated by `index`, dispatching to the head/body/tail
    /// implementation of [`SimdIndex`].
    #[inline(always)]
    pub fn read<I: SimdIndex<'N, T, S>>(&self, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: I) -> T::SimdVec<S> {
        I::read(self, slice, index)
    }

    /// Stores `value` into the chunk designated by `index`, dispatching to the
    /// head/body/tail implementation of [`SimdIndex`].
    #[inline(always)]
    pub fn write<I: SimdIndex<'N, T, S>>(&self, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: I, value: T::SimdVec<S>) {
        I::write(self, slice, index, value)
    }

    /// Lane mask selecting the valid lanes of the head chunk.
    #[inline(always)]
    pub fn head_mask(&self) -> T::SimdMask<S> {
        self.head_mask
    }

    /// Lane mask selecting the valid lanes of the tail chunk.
    #[inline(always)]
    pub fn tail_mask(&self) -> T::SimdMask<S> {
        self.tail_mask
    }

    /// Returns `(head, body, tail)`: the optional masked head index, an iterator
    /// over the full-width body chunk indices, and the optional masked tail
    /// index. All starts are element offsets shifted back by `self.offset`, so
    /// the head start is negative for contexts built with [`Self::new_align`].
    #[inline]
    pub fn indices(
        &self,
    ) -> (
        Option<SimdHead<'N, T, S>>,
        impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = SimdBody<'N, T, S>>,
        Option<SimdTail<'N, T, S>>,
    ) {
        // Element count per SIMD register.
        macro_rules! stride {
            () => {
                core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
            };
        }

        // Shift applied to every chunk start so chunk 0 lines up with the data.
        let offset = -(self.offset as isize);
        (
            // Head exists only when the head section is non-empty.
            if 0 == self.head_end {
                None
            } else {
                Some(SimdHead {
                    start: offset,
                    mask: PhantomData,
                })
            },
            // One full-width chunk per register in `head_end .. body_end`.
            (self.head_end..self.body_end).map(
                #[inline(always)]
                move |i| SimdBody {
                    start: offset + (i * stride!()) as isize,
                    mask: PhantomData,
                },
            ),
            // Tail exists only when the tail section extends past the body.
            if self.body_end == self.tail_end {
                None
            } else {
                Some(SimdTail {
                    start: offset + (self.body_end * stride!()) as isize,
                    mask: PhantomData,
                })
            },
        )
    }

    /// Like [`Self::indices`], but groups the body into arrays of `BATCH`
    /// consecutive chunk indices for manual unrolling. Returns
    /// `(head, batched_body, leftover_body, tail)`, where `leftover_body`
    /// yields the trailing body chunks that do not fill a whole batch.
    #[inline]
    pub fn batch_indices<const BATCH: usize>(
        &self,
    ) -> (
        Option<SimdHead<'N, T, S>>,
        impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = [SimdBody<'N, T, S>; BATCH]>,
        impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = SimdBody<'N, T, S>>,
        Option<SimdTail<'N, T, S>>,
    ) {
        // Element count per SIMD register.
        macro_rules! stride {
            () => {
                core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
            };
        }

        // Total number of full-width body chunks.
        let len = self.body_end - self.head_end;

        // Shift applied to every chunk start so chunk 0 lines up with the data.
        let offset = -(self.offset as isize);

        (
            // Head exists only when the head section is non-empty.
            if 0 == self.head_end {
                None
            } else {
                Some(SimdHead {
                    start: offset,
                    mask: PhantomData,
                })
            },
            // Batched portion: for each batch start `i` (stepping by BATCH),
            // build the array of indices `i, i + 1, ..., i + BATCH - 1`.
            (self.head_end..self.head_end + len / BATCH * BATCH)
                .map(move |i| {
                    core::array::from_fn(
                        #[inline(always)]
                        |k| SimdBody {
                            start: offset + ((i + k) * stride!()) as isize,
                            mask: PhantomData,
                        },
                    )
                })
                .step_by(BATCH),
            // Leftover body chunks that don't fill a whole batch.
            (self.head_end + len / BATCH * BATCH..self.body_end).map(
                #[inline(always)]
                move |i| SimdBody {
                    start: offset + (i * stride!()) as isize,
                    mask: PhantomData,
                },
            ),
            // Tail exists only when the tail section extends past the body.
            if self.body_end == self.tail_end {
                None
            } else {
                Some(SimdTail {
                    start: offset + (self.body_end * stride!()) as isize,
                    mask: PhantomData,
                })
            },
        )
    }
}
373
/// Index of a full-width body chunk, produced by [`SimdCtx::indices`] or
/// [`SimdCtx::batch_indices`].
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdBody<'N, T: ComplexField, S: Simd> {
    // Signed element offset of the chunk start within the column.
    start: isize,
    // Ties the index to the dimension `'N` and the mask type without storing data.
    mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
380
381impl<T: ComplexField, S: Simd> SimdBody<'_, T, S> {
382 pub fn offset(&self) -> isize {
383 self.start
384 }
385}
386
/// Index of the masked head chunk, produced by [`SimdCtx::indices`] or
/// [`SimdCtx::batch_indices`].
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdHead<'N, T: ComplexField, S: Simd> {
    // Signed element offset of the chunk start (negative for aligned contexts).
    start: isize,
    // Ties the index to the dimension `'N` and the mask type without storing data.
    mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
/// Index of the masked tail chunk, produced by [`SimdCtx::indices`] or
/// [`SimdCtx::batch_indices`].
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdTail<'N, T: ComplexField, S: Simd> {
    // Signed element offset of the chunk start within the column.
    start: isize,
    // Ties the index to the dimension `'N` and the mask type without storing data.
    mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
399
// NOTE(review): manual `Copy`/`Clone` impls for the three index types — deriving
// would presumably add `T: Copy`/`T: Clone` bounds, while only the `PhantomData`
// and `isize` fields need to be copyable.
impl<'N, T: ComplexField, S: Simd> Copy for SimdBody<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdBody<'N, T, S> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
impl<'N, T: ComplexField, S: Simd> Copy for SimdHead<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdHead<'N, T, S> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
impl<'N, T: ComplexField, S: Simd> Copy for SimdTail<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdTail<'N, T, S> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}