faer/
io.rs

1use crate::prelude::*;
2
/// npy format conversions
#[cfg(feature = "npy")]
pub mod npy {
	use super::*;

	/// memory view over a buffer in `npy` format
	///
	/// the buffer is borrowed, not copied: the header is parsed once by [`Npy::new`], and the
	/// element data is later read from the same buffer by [`Npy::as_aligned_ref`] or
	/// [`Npy::to_mat`]
	pub struct Npy<'a> {
		/// the complete `npy` buffer, header included
		aligned_bytes: &'a [u8],
		/// number of rows (first dimension, or `1` for a 0-d array)
		nrows: usize,
		/// number of columns (second dimension, or `1` for a 0-d or 1-d array)
		ncols: usize,
		/// byte offset of the first element: magic string, version, header length field and header
		prefix_len: usize,
		/// element type of the stored data
		dtype: NpyDType,
		/// `true` if the elements are stored column-major (fortran order), `false` if row-major
		fortran_order: bool,
	}

	/// data type of an `npy` buffer
	#[derive(Debug, Copy, Clone, PartialEq, Eq)]
	pub enum NpyDType {
		/// 32-bit floating point
		F32,
		/// 64-bit floating point
		F64,
		/// complex floating point with 32-bit real and imaginary parts
		C32,
		/// complex floating point with 64-bit real and imaginary parts
		C64,
		/// unknown type
		Other,
	}

	impl NpyDType {
		/// size in bytes of one element of this type, or `None` for [`NpyDType::Other`]
		fn byte_size(self) -> Option<usize> {
			match self {
				NpyDType::F32 => Some(4),
				NpyDType::F64 | NpyDType::C32 => Some(8),
				NpyDType::C64 => Some(16),
				NpyDType::Other => None,
			}
		}
	}

	/// trait implemented for native types that can be read from a `npy` buffer
	pub trait FromNpy: bytemuck::Pod {
		/// data type of the buffer data
		const DTYPE: NpyDType;
	}

	impl FromNpy for f32 {
		const DTYPE: NpyDType = NpyDType::F32;
	}
	impl FromNpy for f64 {
		const DTYPE: NpyDType = NpyDType::F64;
	}
	impl FromNpy for c32 {
		const DTYPE: NpyDType = NpyDType::C32;
	}
	impl FromNpy for c64 {
		const DTYPE: NpyDType = NpyDType::C64;
	}

	impl<'a> Npy<'a> {
		/// extracts the matrix metadata from an `npy` buffer and its parsed header
		///
		/// returns `(dtype, nrows, ncols, prefix_len, fortran_order)`, where `prefix_len` is the
		/// byte offset of the first element within `data`
		///
		/// # errors
		/// returns an error if the major format version is greater than 3, if the array has more
		/// than two dimensions, or if a dimension does not fit in `usize`
		fn parse_npyz(data: &[u8], npyz: npyz::NpyFile<&[u8]>) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> {
			// byte 6 is the major format version. the preamble was already parsed by
			// `npyz::NpyFile::new`, so these indices are assumed to be in bounds.
			// version 1.x stores the header length as a 2-byte little-endian integer at byte 8,
			// versions 2.x and 3.x as a 4-byte one.
			let (length, header_len) = match data[6] {
				0..=1 => (2usize, u16::from_le_bytes([data[8], data[9]]) as usize),
				2..=3 => (4usize, u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize),
				_ => return Err(std::io::Error::new(std::io::ErrorKind::Other, "unsupported version")),
			};

			// NOTE(review): the byte order of the dtype is not checked, so data stored in a
			// non-native byte order would be misread by the bytemuck casts — confirm whether this
			// needs handling
			let dtype = match npyz.dtype() {
				npyz::DType::Plain(ty) => match (ty.type_char(), ty.size_field()) {
					(npyz::TypeChar::Float, 4) => NpyDType::F32,
					(npyz::TypeChar::Float, 8) => NpyDType::F64,
					(npyz::TypeChar::Complex, 8) => NpyDType::C32,
					(npyz::TypeChar::Complex, 16) => NpyDType::C64,
					_ => NpyDType::Other,
				},
				_ => NpyDType::Other,
			};

			let shape = npyz.shape();
			// a matrix has two dimensions; keeping only the first two dimensions of a higher rank
			// array would silently misinterpret its data
			if shape.len() > 2 {
				return Err(std::io::Error::new(
					std::io::ErrorKind::InvalidData,
					"npy arrays with more than two dimensions are not supported",
				));
			}
			// missing dimensions default to 1: a 1-d array becomes a single column
			let dim = |k: usize| {
				usize::try_from(shape.get(k).copied().unwrap_or(1))
					.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "npy dimension does not fit in usize"))
			};
			let nrows = dim(0)?;
			let ncols = dim(1)?;

			// 6-byte magic string + 2 version bytes, then the header length field and the header
			let prefix_len = 8 + length + header_len;
			let fortran_order = npyz.header().order() == npyz::Order::Fortran;
			Ok((dtype, nrows, ncols, prefix_len, fortran_order))
		}

		/// parse a npy file from a memory buffer
		///
		/// # errors
		/// returns an error if the header cannot be parsed, if the array has more than two
		/// dimensions, or if the buffer is too short to hold every element declared by the header
		/// (checked only for the supported data types)
		#[inline]
		pub fn new(data: &'a [u8]) -> Result<Self, std::io::Error> {
			let npyz = npyz::NpyFile::new(data)?;

			let (dtype, nrows, ncols, prefix_len, fortran_order) = Self::parse_npyz(data, npyz)?;

			// reject truncated buffers here so that `as_aligned_ref` and `to_mat` can slice the
			// element data without going out of bounds
			if let Some(elem_size) = dtype.byte_size() {
				let required_len = nrows
					.checked_mul(ncols)
					.and_then(|len| len.checked_mul(elem_size))
					.and_then(|len| len.checked_add(prefix_len));
				if !matches!(required_len, Some(len) if len <= data.len()) {
					return Err(std::io::Error::new(
						std::io::ErrorKind::UnexpectedEof,
						"npy buffer is too short for the declared shape",
					));
				}
			}

			Ok(Self {
				aligned_bytes: data,
				prefix_len,
				nrows,
				ncols,
				dtype,
				fortran_order,
			})
		}

		/// returns the data type of the memory buffer
		#[inline]
		pub fn dtype(&self) -> NpyDType {
			self.dtype
		}

		/// checks if the memory buffer is aligned, in which case the data can be referenced
		/// in-place
		///
		/// this checks that the start of the buffer is 64-byte aligned. the element data starts
		/// `prefix_len` bytes later; it is presumably aligned for the element type because of
		/// the header padding required by the npy format, but that is not verified here
		#[inline]
		pub fn is_aligned(&self) -> bool {
			self.aligned_bytes.as_ptr().align_offset(64) == 0
		}

		/// if the memory buffer is aligned, and the provided type matches the one stored in the
		/// buffer, returns a matrix view over the data
		///
		/// # panics
		/// panics if the buffer is not aligned (see [`Npy::is_aligned`]), if `T::DTYPE` does not
		/// match the stored data type, or if the element data is not aligned for `T`
		#[inline]
		pub fn as_aligned_ref<T: FromNpy>(&self) -> MatRef<'_, T> {
			assert!(self.is_aligned());
			assert!(self.dtype == T::DTYPE);

			// take exactly `nrows * ncols` elements; `new` checked that the buffer holds them and
			// that this product does not overflow
			let len = self.nrows * self.ncols * core::mem::size_of::<T>();
			let data: &[T] = bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..][..len]);

			if self.fortran_order {
				MatRef::from_column_major_slice(data, self.nrows, self.ncols)
			} else {
				MatRef::from_row_major_slice(data, self.nrows, self.ncols)
			}
		}

		/// if the provided type matches the one stored in the buffer, returns a matrix containing
		/// the data
		///
		/// unlike [`Npy::as_aligned_ref`], this works for buffers of any alignment, since the
		/// elements are copied out of the buffer
		///
		/// # panics
		/// panics if `T::DTYPE` does not match the stored data type
		#[inline]
		pub fn to_mat<T: FromNpy>(&self) -> Mat<T> {
			assert!(self.dtype == T::DTYPE);

			let (nrows, ncols) = (self.nrows, self.ncols);
			let elem = core::mem::size_of::<T>();
			// `new` checked that the buffer holds `nrows * ncols` elements after the header
			let data = &self.aligned_bytes[self.prefix_len..][..nrows * ncols * elem];

			let mut mat = Mat::<T>::with_capacity(nrows, ncols);
			// SAFETY: the capacity covers `nrows x ncols`. every entry exposed here is overwritten
			// below before `mat` is returned, and `T: bytemuck::Pod` has no drop glue, so the
			// writes never read or drop the uninitialized values.
			unsafe { mat.set_dims(nrows, ncols) };

			if self.fortran_order {
				// columns are contiguous in both the buffer and `mat`: copy them as raw bytes
				let col_bytes = nrows * elem;
				for j in 0..ncols {
					bytemuck::cast_slice_mut(mat.col_as_slice_mut(j)).copy_from_slice(&data[j * col_bytes..][..col_bytes]);
				}
			} else {
				// the buffer may be unaligned for `T`, so read each element with an unaligned load
				// instead of reinterpreting the bytes in place (which would panic)
				for j in 0..ncols {
					for i in 0..nrows {
						let offset = (i * ncols + j) * elem;
						mat[(i, j)] = bytemuck::pod_read_unaligned(&data[offset..offset + elem]);
					}
				}
			}

			mat
		}
	}
}