1use crate::internal_prelude::*;
2use core::ops::Mul;
3use faer_traits::Real;
4
5extern crate alloc;
6
/// Comparator for approximate equality, configured by two tolerances.
///
/// The `equator::Cmp` implementation in this file accepts a pair of values
/// when their distance is at most `abs_tol`, or at most `rel_tol` times the
/// larger of the two magnitudes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ApproxEq<T> {
    /// Absolute tolerance: bound applied directly to the distance.
    pub abs_tol: T,
    /// Relative tolerance: bound applied after scaling by the larger magnitude.
    pub rel_tol: T,
}
12
/// Wrapper that applies the element comparator `Cmp` to every pair of
/// corresponding matrix elements (see the `equator::Cmp` implementation for
/// `CwiseMat` in this file).
///
/// `Clone`, `Copy` and `Debug` are derived so the wrapper is as convenient to
/// pass around and print as the comparator it holds; each impl is only
/// available when `Cmp` itself provides it.
#[derive(Clone, Copy, Debug)]
pub struct CwiseMat<Cmp>(pub Cmp);
14
15impl<T: RealField> ApproxEq<T> {
16 #[math]
17 #[inline]
18 pub fn eps() -> Self {
19 Self {
20 abs_tol: eps::<T>() * from_f64::<T>(128.0),
21 rel_tol: eps::<T>() * from_f64::<T>(128.0),
22 }
23 }
24}
25
26impl<T: RealField> Mul<T> for ApproxEq<T> {
27 type Output = ApproxEq<T>;
28
29 #[inline]
30 #[math]
31 fn mul(self, rhs: Real<T>) -> Self::Output {
32 ApproxEq {
33 abs_tol: self.abs_tol * rhs,
34 rel_tol: self.rel_tol * rhs,
35 }
36 }
37}
38
/// Marker error returned by the `ApproxEq` comparator when two values are not
/// within tolerance. It carries no data; the failure report is produced by
/// its `equator::CmpDisplay` implementation.
#[derive(Clone, Copy, Debug)]
pub struct ApproxEqError;
41
42#[derive(Clone, Debug)]
43pub enum CwiseMatError<Rows: Shape, Cols: Shape, Error> {
44 DimMismatch,
45 Elements(alloc::vec::Vec<(crate::Idx<Rows>, crate::Idx<Cols>, Error)>),
46}
47
48#[derive(Clone, Debug)]
49pub enum CwiseColError<Rows: Shape, Error> {
50 DimMismatch,
51 Elements(alloc::vec::Vec<(crate::Idx<Rows>, Error)>),
52}
53
54#[derive(Clone, Debug)]
55pub enum CwiseRowError<Cols: Shape, Error> {
56 DimMismatch,
57 Elements(alloc::vec::Vec<(crate::Idx<Cols>, Error)>),
58}
59
60impl<
61 T: ComplexField,
62 Rows: Shape,
63 Cols: Shape,
64 L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
65 R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
66 Error: equator::CmpDisplay<Cmp, T, T>,
67 Cmp: equator::Cmp<T, T, Error = Error>,
68> equator::CmpError<CwiseMat<Cmp>, L, R> for CwiseMat<Cmp>
69{
70 type Error = CwiseMatError<Rows, Cols, Error>;
71}
72
73impl<R: RealField, T: ComplexField<Real = R>> equator::CmpError<ApproxEq<R>, T, T> for ApproxEq<R> {
74 type Error = ApproxEqError;
75}
76
77impl<R: RealField, T: ComplexField<Real = R>> equator::CmpDisplay<ApproxEq<R>, T, T> for ApproxEqError {
78 #[math]
79 fn fmt(
80 &self,
81 cmp: &ApproxEq<R>,
82 lhs: &T,
83 mut lhs_source: &str,
84 lhs_debug: &dyn core::fmt::Debug,
85 rhs: &T,
86 rhs_source: &str,
87 rhs_debug: &dyn core::fmt::Debug,
88 f: &mut core::fmt::Formatter,
89 ) -> core::fmt::Result {
90 let ApproxEq { abs_tol, rel_tol } = cmp;
91
92 if let Some(source) = lhs_source.strip_prefix("__skip_prologue") {
93 lhs_source = source;
94 } else {
95 writeln!(
96 f,
97 "Assertion failed: {lhs_source} ~ {rhs_source}\nwith absolute tolerance = {abs_tol:?}\nwith relative tolerance = {rel_tol:?}"
98 )?;
99 }
100
101 let distance = abs(*lhs - *rhs);
102
103 write!(f, "- {lhs_source} = {lhs_debug:?}\n")?;
104 write!(f, "- {rhs_source} = {rhs_debug:?}\n")?;
105 write!(f, "- distance = {distance:?}")
106 }
107}
108
109impl<R: RealField, T: ComplexField<Real = R>> equator::Cmp<T, T> for ApproxEq<R> {
110 #[math]
111 fn test(&self, lhs: &T, rhs: &T) -> Result<(), Self::Error> {
112 let Self { abs_tol, rel_tol } = self;
113
114 let diff = abs(*lhs - *rhs);
115 let max = max(abs(*lhs), abs(*rhs));
116
117 if (max == zero() && diff <= *abs_tol) || (diff <= *abs_tol || diff <= *rel_tol * max) {
118 Ok(())
119 } else {
120 Err(ApproxEqError)
121 }
122 }
123}
124
125impl<
126 T: ComplexField,
127 Rows: Shape,
128 Cols: Shape,
129 L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
130 R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
131 Error: equator::CmpDisplay<Cmp, T, T>,
132 Cmp: equator::Cmp<T, T, Error = Error>,
133> equator::CmpDisplay<CwiseMat<Cmp>, L, R> for CwiseMatError<Rows, Cols, Error>
134{
135 #[math]
136 fn fmt(
137 &self,
138 cmp: &CwiseMat<Cmp>,
139 lhs: &L,
140 lhs_source: &str,
141 _: &dyn core::fmt::Debug,
142 rhs: &R,
143 rhs_source: &str,
144 _: &dyn core::fmt::Debug,
145 f: &mut core::fmt::Formatter,
146 ) -> core::fmt::Result {
147 let lhs = lhs.as_mat_ref();
148 let rhs = rhs.as_mat_ref();
149 match self {
150 Self::DimMismatch => {
151 let lhs_nrows = lhs.nrows();
152 let lhs_ncols = lhs.ncols();
153 let rhs_nrows = rhs.nrows();
154 let rhs_ncols = rhs.ncols();
155
156 writeln!(f, "Assertion failed: {lhs_source} ~ {rhs_source}\n")?;
157 write!(f, "- {lhs_source} = Mat[{lhs_nrows:?}, {lhs_ncols:?}]\n")?;
158 write!(f, "- {rhs_source} = Mat[{rhs_nrows:?}, {rhs_ncols:?}]")?;
159 },
160
161 Self::Elements(indices) => {
162 let mut prefix = "";
163
164 let mut count = 0;
165 for (i, j, e) in indices {
166 if count >= 10 {
167 write!(f, "\n\n... ({} mismatches omitted)\n\n", indices.len() - count,)?;
168 break;
169 }
170 count += 1;
171
172 let i = *i;
173 let j = *j;
174 let lhs = lhs.at(i, j).clone();
175 let rhs = rhs.at(i, j).clone();
176
177 e.fmt(
178 &cmp.0,
179 &lhs,
180 &alloc::format!("{prefix}{lhs_source} at ({i:?}, {j:?})"),
181 crate::hacks::hijack_debug(&lhs),
182 &rhs,
183 &alloc::format!("{rhs_source} at ({i:?}, {j:?})"),
184 crate::hacks::hijack_debug(&rhs),
185 f,
186 )?;
187 write!(f, "\n\n")?;
188 prefix = "__skip_prologue"
189 }
190 },
191 }
192 Ok(())
193 }
194}
195
196impl<
197 T: ComplexField,
198 Rows: Shape,
199 Cols: Shape,
200 L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
201 R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
202 Error: equator::CmpDisplay<Cmp, T, T>,
203 Cmp: equator::Cmp<T, T, Error = Error>,
204> equator::Cmp<L, R> for CwiseMat<Cmp>
205{
206 fn test(&self, lhs: &L, rhs: &R) -> Result<(), Self::Error> {
207 let lhs = lhs.as_mat_ref();
208 let rhs = rhs.as_mat_ref();
209
210 if lhs.nrows() != rhs.nrows() || lhs.ncols() != rhs.ncols() {
211 return Err(CwiseMatError::DimMismatch);
212 }
213
214 let mut indices = alloc::vec::Vec::new();
215 for j in 0..lhs.ncols().unbound() {
216 let j = lhs.ncols().checked_idx(j);
217 for i in 0..lhs.nrows().unbound() {
218 let i = lhs.nrows().checked_idx(i);
219
220 if let Err(err) = self.0.test(&lhs.at(i, j).clone(), &rhs.at(i, j).clone()) {
221 indices.push((i, j, err));
222 }
223 }
224 }
225
226 if indices.is_empty() {
227 Ok(())
228 } else {
229 Err(CwiseMatError::Elements(indices))
230 }
231 }
232}