//! faer/utils/simd.rs

1use super::bound::{Dim, Idx};
2use crate::internal_prelude::*;
3use core::marker::PhantomData;
4use faer_traits::SimdCapabilities;
5use pulp::Simd;
6
/// Bounds-tracked SIMD iteration plan over a contiguous column of length `'N`.
///
/// The element range is split into an optional masked *head* vector, a run of
/// full *body* vectors, and an optional masked *tail* vector. Positions are
/// handed out by [`SimdCtx::indices`] / [`SimdCtx::batch_indices`].
pub struct SimdCtx<'N, T: ComplexField, S: Simd> {
	/// Underlying SIMD context used for all vector operations on `T`.
	pub ctx: T::SimdCtx<S>,
	/// Length of the column this plan was built for.
	pub len: Dim<'N>,
	// Element count by which vector starts are shifted back from the column
	// base pointer; nonzero only for plans built with `new_align`.
	offset: usize,
	// Vector-index boundaries: head occupies `0..head_end` (0 or 1 vectors),
	// full body vectors occupy `head_end..body_end`, and the masked tail
	// occupies `body_end..tail_end`.
	head_end: usize,
	body_end: usize,
	tail_end: usize,
	// Lane masks for the partial head/tail vectors, exposed via
	// `head_mask()`/`tail_mask()`.
	head_mask: T::SimdMask<S>,
	tail_mask: T::SimdMask<S>,
	// Memory masks applied by the masked loads/stores of the head/tail
	// positions (see the `SimdIndex` impls below).
	head_mem_mask: T::SimdMemMask<S>,
	tail_mem_mask: T::SimdMemMask<S>,
}
19
20impl<'N, T: ComplexField, S: Simd> core::fmt::Debug for SimdCtx<'N, T, S> {
21	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22		f.debug_struct("SimdCtx")
23			.field("len", &self.len)
24			.field("offset", &self.offset)
25			.field("head_end", &self.head_end)
26			.field("body_end", &self.body_end)
27			.field("tail_end", &self.tail_end)
28			.field("head_mask", &self.head_mask)
29			.field("tail_mask", &self.tail_mask)
30			.finish_non_exhaustive()
31	}
32}
33
// Manual `Copy`/`Clone`: a derive would add `T: Copy`/`T: Clone` bounds on the
// type parameters, which are not needed — the fields' copyability comes from
// the associated SIMD types, not from `T` itself.
impl<T: ComplexField, S: Simd> Copy for SimdCtx<'_, T, S> {}
impl<T: ComplexField, S: Simd> Clone for SimdCtx<'_, T, S> {
	#[inline]
	fn clone(&self) -> Self {
		*self
	}
}
41
/// Dereferences to the `faer_traits` SIMD context wrapper so its operations
/// (`load`, `store`, `mask_load`, `mask_store`, …) can be called directly on
/// this plan, as done throughout this file.
impl<T: ComplexField, S: Simd> core::ops::Deref for SimdCtx<'_, T, S> {
	type Target = faer_traits::SimdCtx<T, S>;

	#[inline(always)]
	fn deref(&self) -> &Self::Target {
		// NOTE(review): `Self::Target::new` turns `&T::SimdCtx<S>` into
		// `&faer_traits::SimdCtx<T, S>` — presumably a transparent wrapper;
		// confirm in `faer_traits`.
		Self::Target::new(&self.ctx)
	}
}
50
/// A position within a [`SimdCtx`] iteration plan (head, body, or tail vector)
/// that can be loaded from / stored to a contiguous column of length `'N`.
pub trait SimdIndex<'N, T: ComplexField, S: Simd> {
	/// Loads the SIMD vector at `index` from `slice`; partial (head/tail)
	/// positions use the plan's corresponding memory mask.
	fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S>;

	/// Stores `value` at `index` into `slice`; partial (head/tail) positions
	/// use the plan's corresponding memory mask.
	fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>);
}
56
// Full body vectors: plain, unmasked vector load/store.
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdBody<'N, T, S> {
	#[inline(always)]
	fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
		// SAFETY: `SimdBody` values are only constructible via
		// `SimdCtx::indices`/`batch_indices` for a column of the same length
		// `'N`, which places the full vector at `index.start` inside the
		// column — TODO confirm the alignment assumptions for plans built with
		// `new_align`.
		unsafe { simd.load(&*(slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>)) }
	}

	#[inline(always)]
	fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
		// SAFETY: same bounds reasoning as `read`.
		unsafe {
			simd.store(&mut *(slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>), value);
		}
	}
}
70
// Masked head vector: only the lanes enabled by `head_mem_mask` are accessed,
// so the shifted-back lanes preceding the column start are skipped in memory.
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdHead<'N, T, S> {
	#[inline(always)]
	fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
		// SAFETY: `SimdHead` values come only from `SimdCtx::indices`/
		// `batch_indices` of a plan for this column length, and the memory
		// mask restricts the access to in-bounds lanes.
		unsafe { simd.mask_load(simd.head_mem_mask, slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>) }
	}

	#[inline(always)]
	fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
		// SAFETY: same reasoning as `read`; masked lanes are not written.
		unsafe {
			simd.mask_store(
				simd.head_mem_mask,
				slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>,
				value,
			);
		}
	}
}
88
// Masked tail vector: only the lanes enabled by `tail_mem_mask` are accessed,
// so the lanes past the end of the column are skipped in memory.
impl<'N, T: ComplexField, S: Simd> SimdIndex<'N, T, S> for SimdTail<'N, T, S> {
	#[inline(always)]
	fn read(simd: &SimdCtx<'N, T, S>, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: Self) -> T::SimdVec<S> {
		// SAFETY: `SimdTail` values come only from `SimdCtx::indices`/
		// `batch_indices` of a plan for this column length, and the memory
		// mask restricts the access to in-bounds lanes.
		unsafe { simd.mask_load(simd.tail_mem_mask, slice.as_ptr().wrapping_offset(index.start) as *const T::SimdVec<S>) }
	}

	#[inline(always)]
	fn write(simd: &SimdCtx<'N, T, S>, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: Self, value: T::SimdVec<S>) {
		// SAFETY: same reasoning as `read`; masked lanes are not written.
		unsafe {
			simd.mask_store(
				simd.tail_mem_mask,
				slice.as_ptr_mut().wrapping_offset(index.start) as *mut T::SimdVec<S>,
				value,
			);
		}
	}
}
106
impl<'N, T: ComplexField, S: Simd> SimdCtx<'N, T, S> {
	/// Creates an iteration plan for a vector-aligned column of length `len`:
	/// no head, `len / stride` full body vectors, and a masked tail covering
	/// the `len % stride` leftover elements (absent when the division is exact).
	///
	/// # Panics
	/// Panics unless `T` reports full SIMD support.
	#[inline(always)]
	pub fn new(simd: T::SimdCtx<S>, len: Dim<'N>) -> Self {
		core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

		// Number of `T` elements held by one SIMD register.
		let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
		let iota = T::simd_iota(&simd);

		// Head is empty (start == end == 0); the tail selects lanes
		// `0..len % stride`.
		let head_start = T::simd_index_splat(&simd, T::Index::truncate(0));
		let head_end = T::simd_index_splat(&simd, T::Index::truncate(0));
		let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
		let tail_end = T::simd_index_splat(&simd, T::Index::truncate(*len % stride));

		Self {
			ctx: simd,
			len,
			offset: 0,
			head_end: 0,
			body_end: *len / stride,
			// Rounded up: one extra, masked vector when `len % stride != 0`.
			tail_end: (*len + stride - 1) / stride,
			// Lane masks select lanes with `start <= iota < end`.
			head_mask: T::simd_and_mask(
				&simd,
				T::simd_index_greater_than_or_equal(&simd, iota, head_start),
				T::simd_index_less_than(&simd, iota, head_end),
			),
			tail_mask: T::simd_and_mask(
				&simd,
				T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
				T::simd_index_less_than(&simd, iota, tail_end),
			),
			head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
			tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(*len % stride)),
		}
	}

	/// Creates an iteration plan whose vector boundaries are shifted so that
	/// element `0` of the column falls `align_offset % stride` lanes into its
	/// vector (e.g. to line up with another column's layout).
	///
	/// With a nonzero shift, the first position is a masked head starting
	/// `stride - align_offset` elements *before* the column base pointer, with
	/// the leading out-of-range lanes disabled by the head masks.
	///
	/// # Panics
	/// Panics unless `T` reports full SIMD support.
	#[inline(always)]
	pub fn new_align(simd: T::SimdCtx<S>, len: Dim<'N>, align_offset: usize) -> Self {
		core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

		let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
		let align_offset = align_offset % stride;
		let iota = T::simd_iota(&simd);

		if align_offset == 0 {
			// Already aligned: identical to `new`.
			Self::new(simd, len)
		} else {
			// Elements by which the base pointer is shifted back.
			let offset = stride - align_offset;
			// Total lane span including the masked-off leading lanes.
			let full_len = offset + *len;

			let head_start = T::simd_index_splat(&simd, T::Index::truncate(offset));
			let head_end = T::simd_index_splat(&simd, T::Index::truncate(stride));
			let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
			let tail_end = T::simd_index_splat(&simd, T::Index::truncate(full_len % stride));

			if align_offset <= *len {
				// The head vector's data lanes `offset..stride` are all real
				// elements; the rest is full body vectors plus a masked tail.
				Self {
					ctx: simd,
					len,
					offset,
					head_end: 1,
					body_end: full_len / stride,
					tail_end: (full_len + stride - 1) / stride,
					head_mask: T::simd_and_mask(
						&simd,
						T::simd_index_greater_than_or_equal(&simd, iota, head_start),
						T::simd_index_less_than(&simd, iota, head_end),
					),
					tail_mask: T::simd_and_mask(
						&simd,
						T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
						T::simd_index_less_than(&simd, iota, tail_end),
					),
					head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(offset), T::Index::truncate(stride)),
					tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(full_len % stride)),
				}
			} else {
				// The whole column fits inside the single head vector, in
				// lanes `offset..full_len` (`full_len < stride` here). Body
				// and tail are empty: `head_end == body_end == tail_end == 1`.
				let head_start = T::simd_index_splat(&simd, T::Index::truncate(offset));
				let head_end = T::simd_index_splat(&simd, T::Index::truncate(full_len % stride));
				let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
				let tail_end = T::simd_index_splat(&simd, T::Index::truncate(0));

				Self {
					ctx: simd,
					len,
					offset,
					head_end: 1,
					body_end: 1,
					tail_end: 1,
					head_mask: T::simd_and_mask(
						&simd,
						T::simd_index_greater_than_or_equal(&simd, iota, head_start),
						T::simd_index_less_than(&simd, iota, head_end),
					),
					tail_mask: T::simd_and_mask(
						&simd,
						T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
						T::simd_index_less_than(&simd, iota, tail_end),
					),
					head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(offset), T::Index::truncate(full_len % stride)),
					tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
				}
			}
		}
	}

	/// Returns the element count by which vector starts are shifted back from
	/// the column base (`0` unless built with `new_align` and a nonzero shift).
	#[inline]
	pub fn offset(&self) -> usize {
		self.offset
	}

	/// Like [`Self::new`], but the final vector is always emitted as a masked
	/// tail, even when `len` is an exact multiple of the vector width —
	/// presumably so callers can route the last vector through a single masked
	/// path; confirm at call sites.
	///
	/// # Panics
	/// Panics if `len == 0`, or unless `T` reports full SIMD support.
	#[inline(always)]
	pub fn new_force_mask(simd: T::SimdCtx<S>, len: Dim<'N>) -> Self {
		core::assert!(try_const! { matches!(T::SIMD_CAPABILITIES, SimdCapabilities::Simd) });

		crate::assert!(*len != 0);
		// Using `len - 1` makes the tail cover `(new_len % stride) + 1` lanes,
		// i.e. between 1 and `stride` lanes — never zero.
		let new_len = *len - 1;

		let stride = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
		let iota = T::simd_iota(&simd);

		let head_start = T::simd_index_splat(&simd, T::Index::truncate(0));
		let head_end = T::simd_index_splat(&simd, T::Index::truncate(0));
		let tail_start = T::simd_index_splat(&simd, T::Index::truncate(0));
		let tail_end = T::simd_index_splat(&simd, T::Index::truncate((new_len % stride) + 1));

		Self {
			ctx: simd,
			len,
			offset: 0,
			head_end: 0,
			body_end: new_len / stride,
			// Exactly one masked tail vector after the body.
			tail_end: new_len / stride + 1,
			head_mask: T::simd_and_mask(
				&simd,
				T::simd_index_greater_than_or_equal(&simd, iota, head_start),
				T::simd_index_less_than(&simd, iota, head_end),
			),
			tail_mask: T::simd_and_mask(
				&simd,
				T::simd_index_greater_than_or_equal(&simd, iota, tail_start),
				T::simd_index_less_than(&simd, iota, tail_end),
			),
			head_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate(0)),
			tail_mem_mask: T::simd_mem_mask_between(&simd, T::Index::truncate(0), T::Index::truncate((new_len % stride) + 1)),
		}
	}

	/// Loads the SIMD vector at `index` from `slice`, dispatching on the
	/// position kind (body/head/tail) via [`SimdIndex`].
	#[inline(always)]
	pub fn read<I: SimdIndex<'N, T, S>>(&self, slice: ColRef<'_, T, Dim<'N>, ContiguousFwd>, index: I) -> T::SimdVec<S> {
		I::read(self, slice, index)
	}

	/// Stores `value` at `index` into `slice`, dispatching on the position
	/// kind (body/head/tail) via [`SimdIndex`].
	#[inline(always)]
	pub fn write<I: SimdIndex<'N, T, S>>(&self, slice: ColMut<'_, T, Dim<'N>, ContiguousFwd>, index: I, value: T::SimdVec<S>) {
		I::write(self, slice, index, value)
	}

	/// Lane mask selecting the head vector's valid lanes.
	#[inline(always)]
	pub fn head_mask(&self) -> T::SimdMask<S> {
		self.head_mask
	}

	/// Lane mask selecting the tail vector's valid lanes.
	#[inline(always)]
	pub fn tail_mask(&self) -> T::SimdMask<S> {
		self.tail_mask
	}

	/// Returns the plan's positions as `(head, body, tail)`: an optional
	/// masked head, an iterator over full body vectors, and an optional
	/// masked tail. Starts are element offsets relative to the column base
	/// pointer; the head start is negative for alignment-shifted plans.
	#[inline]
	pub fn indices(
		&self,
	) -> (
		Option<SimdHead<'N, T, S>>,
		impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = SimdBody<'N, T, S>>,
		Option<SimdTail<'N, T, S>>,
	) {
		// Macro rather than a `let`, so the constant folds inside the closures.
		macro_rules! stride {
			() => {
				core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
			};
		}

		// Shift every start back by the alignment offset.
		let offset = -(self.offset as isize);
		(
			if 0 == self.head_end {
				None
			} else {
				Some(SimdHead {
					start: offset,
					mask: PhantomData,
				})
			},
			(self.head_end..self.body_end).map(
				#[inline(always)]
				move |i| SimdBody {
					start: offset + (i * stride!()) as isize,
					mask: PhantomData,
				},
			),
			if self.body_end == self.tail_end {
				None
			} else {
				Some(SimdTail {
					start: offset + (self.body_end * stride!()) as isize,
					mask: PhantomData,
				})
			},
		)
	}

	/// Like [`Self::indices`], but additionally groups body vectors into
	/// arrays of `BATCH` consecutive positions (for unrolled loops), returning
	/// `(head, batched body, leftover body, tail)`.
	#[inline]
	pub fn batch_indices<const BATCH: usize>(
		&self,
	) -> (
		Option<SimdHead<'N, T, S>>,
		impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = [SimdBody<'N, T, S>; BATCH]>,
		impl Clone + ExactSizeIterator + DoubleEndedIterator<Item = SimdBody<'N, T, S>>,
		Option<SimdTail<'N, T, S>>,
	) {
		// Macro rather than a `let`, so the constant folds inside the closures.
		macro_rules! stride {
			() => {
				core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>()
			};
		}

		// Number of full body vectors; `len / BATCH * BATCH` of them are
		// batched, the rest are yielded one at a time.
		let len = self.body_end - self.head_end;

		// Shift every start back by the alignment offset.
		let offset = -(self.offset as isize);

		(
			if 0 == self.head_end {
				None
			} else {
				Some(SimdHead {
					start: offset,
					mask: PhantomData,
				})
			},
			(self.head_end..self.head_end + len / BATCH * BATCH)
				.map(move |i| {
					core::array::from_fn(
						#[inline(always)]
						|k| SimdBody {
							start: offset + ((i + k) * stride!()) as isize,
							mask: PhantomData,
						},
					)
				})
				.step_by(BATCH),
			(self.head_end + len / BATCH * BATCH..self.body_end).map(
				#[inline(always)]
				move |i| SimdBody {
					start: offset + (i * stride!()) as isize,
					mask: PhantomData,
				},
			),
			if self.body_end == self.tail_end {
				None
			} else {
				Some(SimdTail {
					start: offset + (self.body_end * stride!()) as isize,
					mask: PhantomData,
				})
			},
		)
	}
}
373
/// Position of a full (unmasked) SIMD vector within a [`SimdCtx`] plan.
///
/// Only constructible through [`SimdCtx::indices`]/[`SimdCtx::batch_indices`]
/// (the fields are private), which is what ties it to a valid plan.
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdBody<'N, T: ComplexField, S: Simd> {
	// Element offset of the vector's first lane relative to the column base.
	start: isize,
	// Zero-sized: ties the position to the column length `'N` and SIMD type.
	mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
380
381impl<T: ComplexField, S: Simd> SimdBody<'_, T, S> {
382	pub fn offset(&self) -> isize {
383		self.start
384	}
385}
386
/// Position of the masked head vector within a [`SimdCtx`] plan.
///
/// Only constructible through [`SimdCtx::indices`]/[`SimdCtx::batch_indices`].
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdHead<'N, T: ComplexField, S: Simd> {
	// Element offset of the vector's first lane; negative for
	// alignment-shifted plans (the leading lanes are masked off).
	start: isize,
	// Zero-sized: ties the position to the column length `'N` and SIMD type.
	mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
/// Position of the masked tail vector within a [`SimdCtx`] plan.
///
/// Only constructible through [`SimdCtx::indices`]/[`SimdCtx::batch_indices`].
#[repr(transparent)]
#[derive(Debug)]
pub struct SimdTail<'N, T: ComplexField, S: Simd> {
	// Element offset of the vector's first lane; lanes past the column end
	// are masked off.
	start: isize,
	// Zero-sized: ties the position to the column length `'N` and SIMD type.
	mask: PhantomData<(Idx<'N>, T::SimdMask<S>)>,
}
399
// Manual `Copy`/`Clone` for the position types: a derive would add
// `T: Copy`/`T: Clone` bounds, while the actual contents (`isize` plus a
// `PhantomData`) are always copyable.
impl<'N, T: ComplexField, S: Simd> Copy for SimdBody<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdBody<'N, T, S> {
	#[inline]
	fn clone(&self) -> Self {
		*self
	}
}
impl<'N, T: ComplexField, S: Simd> Copy for SimdHead<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdHead<'N, T, S> {
	#[inline]
	fn clone(&self) -> Self {
		*self
	}
}
impl<'N, T: ComplexField, S: Simd> Copy for SimdTail<'N, T, S> {}
impl<'N, T: ComplexField, S: Simd> Clone for SimdTail<'N, T, S> {
	#[inline]
	fn clone(&self) -> Self {
		*self
	}
}