1use crate::prelude::*;
2
#[cfg(feature = "npy")]
/// Reading of NumPy `.npy` buffers into `Mat` / `MatRef`.
pub mod npy {
    use super::*;

    /// A parsed, borrowed view of an in-memory `.npy` file.
    ///
    /// Only the header is parsed by `Npy::new`; the element data is read on demand by
    /// `Npy::as_aligned_ref` (zero-copy) or `Npy::to_mat` (copying).
    pub struct Npy<'a> {
        /// The complete file contents, header included.
        aligned_bytes: &'a [u8],
        /// Number of rows: the first shape entry, or 1 for a 0-dimensional array.
        nrows: usize,
        /// Number of columns: the second shape entry, or 1 for arrays with fewer than 2 dimensions.
        ncols: usize,
        /// Byte offset of the element data (magic + version + length field + header).
        prefix_len: usize,
        /// Element type recorded in the header.
        dtype: NpyDType,
        /// `true` if the data is stored column-major (Fortran order), `false` for row-major.
        fortran_order: bool,
    }

    /// Element type of a `.npy` array, as far as this module can read it.
    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
    pub enum NpyDType {
        /// 32-bit float (`f4`) in native byte order.
        F32,
        /// 64-bit float (`f8`) in native byte order.
        F64,
        /// Complex number with 32-bit float parts (`c8`) in native byte order.
        C32,
        /// Complex number with 64-bit float parts (`c16`) in native byte order.
        C64,
        /// Any other dtype, including float/complex data stored in non-native byte order.
        Other,
    }

    /// Element types that can be read out of a `.npy` buffer by reinterpreting its bytes.
    pub trait FromNpy: bytemuck::Pod {
        /// The dtype a file must carry for its data to be read as `Self`.
        const DTYPE: NpyDType;
    }

    impl FromNpy for f32 {
        const DTYPE: NpyDType = NpyDType::F32;
    }
    impl FromNpy for f64 {
        const DTYPE: NpyDType = NpyDType::F64;
    }
    impl FromNpy for c32 {
        const DTYPE: NpyDType = NpyDType::C32;
    }
    impl FromNpy for c64 {
        const DTYPE: NpyDType = NpyDType::C64;
    }

    impl<'a> Npy<'a> {
        /// Maps an npyz dtype to an `NpyDType`.
        ///
        /// Only plain float/complex types of the supported widths in the machine's native
        /// byte order are recognised. The data is later reinterpreted in place without any
        /// byte swapping, so everything else maps to `NpyDType::Other`.
        fn npy_dtype(dtype: &npyz::DType) -> NpyDType {
            let ty = match dtype {
                npyz::DType::Plain(ty) => ty,
                _ => return NpyDType::Other,
            };

            let native_byte_order = match ty.endianness() {
                npyz::Endianness::Little => cfg!(target_endian = "little"),
                npyz::Endianness::Big => cfg!(target_endian = "big"),
                // "not applicable" byte order imposes no constraint.
                _ => true,
            };
            if !native_byte_order {
                return NpyDType::Other;
            }

            // `size_field` is the byte width of the whole element (both parts for complex).
            match (ty.type_char(), ty.size_field()) {
                (npyz::TypeChar::Float, 4) => NpyDType::F32,
                (npyz::TypeChar::Float, 8) => NpyDType::F64,
                (npyz::TypeChar::Complex, 8) => NpyDType::C32,
                (npyz::TypeChar::Complex, 16) => NpyDType::C64,
                _ => NpyDType::Other,
            }
        }

        /// Size in bytes of one element of `dtype`, or `None` for `NpyDType::Other`.
        fn elem_size(dtype: NpyDType) -> Option<usize> {
            match dtype {
                NpyDType::F32 => Some(core::mem::size_of::<f32>()),
                NpyDType::F64 => Some(core::mem::size_of::<f64>()),
                NpyDType::C32 => Some(core::mem::size_of::<c32>()),
                NpyDType::C64 => Some(core::mem::size_of::<c64>()),
                NpyDType::Other => None,
            }
        }

        /// Extracts `(dtype, nrows, ncols, prefix_len, fortran_order)` from a header that
        /// `npyz` has already accepted. `data` must be the buffer `npyz` was created from.
        fn parse_npyz(
            data: &[u8],
            npyz: npyz::NpyFile<&[u8]>,
        ) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> {
            // Byte 6 is the major format version. The header-length field starting at byte 8
            // is a little-endian u16 in version 1 and a little-endian u32 in versions 2 and 3.
            let ver_major = data[6];
            let length = match ver_major {
                0..=1 => 2usize,
                2..=3 => 4usize,
                _ => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "unsupported version",
                    ))
                }
            };
            let header_len = if length == 2 {
                u16::from_le_bytes([data[8], data[9]]) as usize
            } else {
                u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
            };

            let shape = npyz.shape();
            // Higher-rank arrays cannot be represented as a matrix; reading only the first
            // two dimensions would silently use the wrong strides.
            if shape.len() > 2 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "only arrays with at most 2 dimensions are supported",
                ));
            }
            let dim = |axis: usize| {
                usize::try_from(shape.get(axis).copied().unwrap_or(1)).map_err(|_| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "array dimension does not fit in usize",
                    )
                })
            };
            let nrows = dim(0)?;
            let ncols = dim(1)?;

            let dtype = Self::npy_dtype(&npyz.dtype());
            let prefix_len = 8 + length + header_len;
            let fortran_order = npyz.header().order() == npyz::Order::Fortran;
            Ok((dtype, nrows, ncols, prefix_len, fortran_order))
        }

        /// Parses the header of the `.npy` file held in `data`.
        ///
        /// # Errors
        /// Returns an error if `npyz` rejects the header, if the format version is not 1, 2
        /// or 3, if the array has more than two dimensions, or, for supported dtypes, if
        /// `data` is too short to hold the element data the header describes.
        #[inline]
        pub fn new(data: &'a [u8]) -> Result<Self, std::io::Error> {
            let npyz = npyz::NpyFile::new(data)?;

            let (dtype, nrows, ncols, prefix_len, fortran_order) = Self::parse_npyz(data, npyz)?;

            // Validate up front so `to_mat` / `as_aligned_ref` cannot index out of bounds.
            if let Some(elem_size) = Self::elem_size(dtype) {
                let end = nrows
                    .checked_mul(ncols)
                    .and_then(|count| count.checked_mul(elem_size))
                    .and_then(|bytes| bytes.checked_add(prefix_len));
                if !matches!(end, Some(end) if end <= data.len()) {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "npy data section is shorter than the header describes",
                    ));
                }
            }

            Ok(Self {
                aligned_bytes: data,
                prefix_len,
                nrows,
                ncols,
                dtype,
                fortran_order,
            })
        }

        /// The element type recorded in the file header.
        #[inline]
        pub fn dtype(&self) -> NpyDType {
            self.dtype
        }

        /// Returns `true` if the underlying byte buffer starts on a 64-byte boundary,
        /// which `as_aligned_ref` requires.
        #[inline]
        pub fn is_aligned(&self) -> bool {
            self.aligned_bytes.as_ptr().align_offset(64) == 0
        }

        /// The element data (`nrows * ncols` values of `T`) as raw bytes, excluding any
        /// trailing bytes after it.
        fn data_bytes<T: FromNpy>(&self) -> &[u8] {
            let len = self.nrows * self.ncols * core::mem::size_of::<T>();
            &self.aligned_bytes[self.prefix_len..][..len]
        }

        /// Borrows the element data in place as a matrix view, without copying.
        ///
        /// # Panics
        /// Panics if the buffer does not start on a 64-byte boundary (see `is_aligned`), if
        /// the file's dtype is not `T::DTYPE`, or (via `bytemuck::cast_slice`) if the data
        /// section itself is not aligned for `T`.
        #[inline]
        pub fn as_aligned_ref<T: FromNpy>(&self) -> MatRef<'_, T> {
            assert!(self.is_aligned());
            assert!(self.dtype == T::DTYPE);

            let bytes = self.data_bytes::<T>();
            if self.fortran_order {
                MatRef::from_column_major_slice(bytemuck::cast_slice(bytes), self.nrows, self.ncols)
            } else {
                MatRef::from_row_major_slice(bytemuck::cast_slice(bytes), self.nrows, self.ncols)
            }
        }

        /// Copies the element data into a newly allocated matrix.
        ///
        /// Unlike `as_aligned_ref`, this does not require the buffer to be aligned.
        ///
        /// # Panics
        /// Panics if the file's dtype is not `T::DTYPE`.
        #[inline]
        pub fn to_mat<T: FromNpy>(&self) -> Mat<T> {
            assert!(self.dtype == T::DTYPE);

            let size = core::mem::size_of::<T>();
            let (nrows, ncols) = (self.nrows, self.ncols);
            let data = self.data_bytes::<T>();

            let mut mat = Mat::<T>::with_capacity(nrows, ncols);
            // SAFETY: the capacity was reserved for exactly `nrows x ncols` above, and every
            // element is written by the loops below before `mat` is returned. `T: Pod`, so
            // no destructor can observe a not-yet-written element.
            unsafe { mat.set_dims(nrows, ncols) };

            if self.fortran_order {
                // Column-major on disk: each column is one contiguous byte copy.
                let col_bytes = nrows * size;
                for j in 0..ncols {
                    bytemuck::cast_slice_mut::<T, u8>(mat.col_as_slice_mut(j))
                        .copy_from_slice(&data[j * col_bytes..][..col_bytes]);
                }
            } else {
                // Row-major on disk: scatter element by element. Use an unaligned read, since
                // the source buffer is not guaranteed to be aligned for `T`.
                for j in 0..ncols {
                    for i in 0..nrows {
                        let offset = (i * ncols + j) * size;
                        mat[(i, j)] = bytemuck::pod_read_unaligned::<T>(&data[offset..][..size]);
                    }
                }
            }

            mat
        }
    }
}