faer/utils/
approx.rs

1use crate::internal_prelude::*;
2use core::ops::Mul;
3use faer_traits::Real;
4
5extern crate alloc;
6
/// Scalar comparator for approximate equality, parameterized by an absolute and a
/// relative tolerance.
///
/// When used as an `equator` comparator, two values are accepted when their distance
/// `abs(lhs - rhs)` is at most `abs_tol`, or at most `rel_tol` times the larger of
/// `abs(lhs)` and `abs(rhs)`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct ApproxEq<T> {
	/// Absolute tolerance on the distance `abs(lhs - rhs)`.
	pub abs_tol: T,
	/// Relative tolerance; multiplied by `max(abs(lhs), abs(rhs))` before comparing.
	pub rel_tol: T,
}
12
/// Adapter that applies the wrapped scalar comparator `Cmp` to every entry of two
/// matrices of the same shape.
pub struct CwiseMat<Cmp>(pub Cmp);
14
impl<T: RealField> ApproxEq<T> {
	/// Returns a comparator whose absolute and relative tolerances are both
	/// `128 * eps::<T>()` (`eps::<T>()` is presumably the machine epsilon of `T`).
	#[math]
	#[inline]
	pub fn eps() -> Self {
		Self {
			abs_tol: eps::<T>() * from_f64::<T>(128.0),
			rel_tol: eps::<T>() * from_f64::<T>(128.0),
		}
	}
}
25
/// Scales both tolerances of an `ApproxEq` by the same factor, e.g.
/// `ApproxEq::eps() * factor`.
impl<T: RealField> Mul<T> for ApproxEq<T> {
	type Output = ApproxEq<T>;

	/// Returns a comparator with `abs_tol` and `rel_tol` each multiplied by `rhs`.
	// `Real<T>` must resolve to `T` here for this signature to match `Mul<T>::mul`.
	#[inline]
	#[math]
	fn mul(self, rhs: Real<T>) -> Self::Output {
		ApproxEq {
			abs_tol: self.abs_tol * rhs,
			rel_tol: self.rel_tol * rhs,
		}
	}
}
38
/// Error returned by the `ApproxEq` comparator when two scalars are not within tolerance.
///
/// It carries no data; the diagnostic message is produced by its `CmpDisplay`
/// implementation from the operands and the comparator.
#[derive(Copy, Clone, Debug)]
pub struct ApproxEqError;
41
/// Failure report of an element-wise matrix comparison (`CwiseMat`).
#[derive(Clone, Debug)]
pub enum CwiseMatError<Rows: Shape, Cols: Shape, Error> {
	/// The operands have a different number of rows or columns.
	DimMismatch,
	/// `(row, column, inner error)` for every entry that failed the inner comparator,
	/// collected column by column.
	Elements(alloc::vec::Vec<(crate::Idx<Rows>, crate::Idx<Cols>, Error)>),
}
47
/// Failure report of an element-wise comparison indexed only by row
/// (presumably for column vectors).
///
/// NOTE(review): not constructed in the visible part of this file; presumably produced
/// by a column-wise comparator defined elsewhere — confirm.
#[derive(Clone, Debug)]
pub enum CwiseColError<Rows: Shape, Error> {
	/// The operands' dimensions differ (presumably their lengths — confirm).
	DimMismatch,
	/// Row index and inner comparator error for each mismatching entry.
	Elements(alloc::vec::Vec<(crate::Idx<Rows>, Error)>),
}
53
/// Failure report of an element-wise comparison indexed only by column
/// (presumably for row vectors).
///
/// NOTE(review): not constructed in the visible part of this file; presumably produced
/// by a row-wise comparator defined elsewhere — confirm.
#[derive(Clone, Debug)]
pub enum CwiseRowError<Cols: Shape, Error> {
	/// The operands' dimensions differ (presumably their lengths — confirm).
	DimMismatch,
	/// Column index and inner comparator error for each mismatching entry.
	Elements(alloc::vec::Vec<(crate::Idx<Cols>, Error)>),
}
59
60impl<
61	T: ComplexField,
62	Rows: Shape,
63	Cols: Shape,
64	L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
65	R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
66	Error: equator::CmpDisplay<Cmp, T, T>,
67	Cmp: equator::Cmp<T, T, Error = Error>,
68> equator::CmpError<CwiseMat<Cmp>, L, R> for CwiseMat<Cmp>
69{
70	type Error = CwiseMatError<Rows, Cols, Error>;
71}
72
73impl<R: RealField, T: ComplexField<Real = R>> equator::CmpError<ApproxEq<R>, T, T> for ApproxEq<R> {
74	type Error = ApproxEqError;
75}
76
impl<R: RealField, T: ComplexField<Real = R>> equator::CmpDisplay<ApproxEq<R>, T, T> for ApproxEqError {
	/// Writes the failure report for an `ApproxEq` comparison of `lhs` and `rhs`.
	///
	/// If `lhs_source` starts with the `"__skip_prologue"` marker, the marker is stripped
	/// and the header is omitted. Otherwise a header naming both expressions and both
	/// tolerances is written first. The report then lists each operand's `Debug`
	/// output and their distance `abs(lhs - rhs)`.
	#[math]
	fn fmt(
		&self,
		cmp: &ApproxEq<R>,
		lhs: &T,
		mut lhs_source: &str,
		lhs_debug: &dyn core::fmt::Debug,
		rhs: &T,
		rhs_source: &str,
		rhs_debug: &dyn core::fmt::Debug,
		f: &mut core::fmt::Formatter,
	) -> core::fmt::Result {
		let ApproxEq { abs_tol, rel_tol } = cmp;

		// Callers reporting several failures in a row (e.g. the element-wise matrix
		// formatter in this file) prepend this marker to every entry after the first,
		// so the header is printed only once.
		if let Some(source) = lhs_source.strip_prefix("__skip_prologue") {
			lhs_source = source;
		} else {
			writeln!(
				f,
				"Assertion failed: {lhs_source} ~ {rhs_source}\nwith absolute tolerance = {abs_tol:?}\nwith relative tolerance = {rel_tol:?}"
			)?;
		}

		// Same distance measure that the `ApproxEq` comparator tests against `abs_tol`.
		let distance = abs(*lhs - *rhs);

		write!(f, "- {lhs_source} = {lhs_debug:?}\n")?;
		write!(f, "- {rhs_source} = {rhs_debug:?}\n")?;
		write!(f, "- distance = {distance:?}")
	}
}
108
109impl<R: RealField, T: ComplexField<Real = R>> equator::Cmp<T, T> for ApproxEq<R> {
110	#[math]
111	fn test(&self, lhs: &T, rhs: &T) -> Result<(), Self::Error> {
112		let Self { abs_tol, rel_tol } = self;
113
114		let diff = abs(*lhs - *rhs);
115		let max = max(abs(*lhs), abs(*rhs));
116
117		if (max == zero() && diff <= *abs_tol) || (diff <= *abs_tol || diff <= *rel_tol * max) {
118			Ok(())
119		} else {
120			Err(ApproxEqError)
121		}
122	}
123}
124
125impl<
126	T: ComplexField,
127	Rows: Shape,
128	Cols: Shape,
129	L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
130	R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
131	Error: equator::CmpDisplay<Cmp, T, T>,
132	Cmp: equator::Cmp<T, T, Error = Error>,
133> equator::CmpDisplay<CwiseMat<Cmp>, L, R> for CwiseMatError<Rows, Cols, Error>
134{
135	#[math]
136	fn fmt(
137		&self,
138		cmp: &CwiseMat<Cmp>,
139		lhs: &L,
140		lhs_source: &str,
141		_: &dyn core::fmt::Debug,
142		rhs: &R,
143		rhs_source: &str,
144		_: &dyn core::fmt::Debug,
145		f: &mut core::fmt::Formatter,
146	) -> core::fmt::Result {
147		let lhs = lhs.as_mat_ref();
148		let rhs = rhs.as_mat_ref();
149		match self {
150			Self::DimMismatch => {
151				let lhs_nrows = lhs.nrows();
152				let lhs_ncols = lhs.ncols();
153				let rhs_nrows = rhs.nrows();
154				let rhs_ncols = rhs.ncols();
155
156				writeln!(f, "Assertion failed: {lhs_source} ~ {rhs_source}\n")?;
157				write!(f, "- {lhs_source} = Mat[{lhs_nrows:?}, {lhs_ncols:?}]\n")?;
158				write!(f, "- {rhs_source} = Mat[{rhs_nrows:?}, {rhs_ncols:?}]")?;
159			},
160
161			Self::Elements(indices) => {
162				let mut prefix = "";
163
164				let mut count = 0;
165				for (i, j, e) in indices {
166					if count >= 10 {
167						write!(f, "\n\n... ({} mismatches omitted)\n\n", indices.len() - count,)?;
168						break;
169					}
170					count += 1;
171
172					let i = *i;
173					let j = *j;
174					let lhs = lhs.at(i, j).clone();
175					let rhs = rhs.at(i, j).clone();
176
177					e.fmt(
178						&cmp.0,
179						&lhs,
180						&alloc::format!("{prefix}{lhs_source} at ({i:?}, {j:?})"),
181						crate::hacks::hijack_debug(&lhs),
182						&rhs,
183						&alloc::format!("{rhs_source} at ({i:?}, {j:?})"),
184						crate::hacks::hijack_debug(&rhs),
185						f,
186					)?;
187					write!(f, "\n\n")?;
188					prefix = "__skip_prologue"
189				}
190			},
191		}
192		Ok(())
193	}
194}
195
196impl<
197	T: ComplexField,
198	Rows: Shape,
199	Cols: Shape,
200	L: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
201	R: AsMatRef<T = T, Rows = Rows, Cols = Cols>,
202	Error: equator::CmpDisplay<Cmp, T, T>,
203	Cmp: equator::Cmp<T, T, Error = Error>,
204> equator::Cmp<L, R> for CwiseMat<Cmp>
205{
206	fn test(&self, lhs: &L, rhs: &R) -> Result<(), Self::Error> {
207		let lhs = lhs.as_mat_ref();
208		let rhs = rhs.as_mat_ref();
209
210		if lhs.nrows() != rhs.nrows() || lhs.ncols() != rhs.ncols() {
211			return Err(CwiseMatError::DimMismatch);
212		}
213
214		let mut indices = alloc::vec::Vec::new();
215		for j in 0..lhs.ncols().unbound() {
216			let j = lhs.ncols().checked_idx(j);
217			for i in 0..lhs.nrows().unbound() {
218				let i = lhs.nrows().checked_idx(i);
219
220				if let Err(err) = self.0.test(&lhs.at(i, j).clone(), &rhs.at(i, j).clone()) {
221					indices.push((i, j, err));
222				}
223			}
224		}
225
226		if indices.is_empty() {
227			Ok(())
228		} else {
229			Err(CwiseMatError::Elements(indices))
230		}
231	}
232}