1use faer_traits::{Real, RealReg};
2use linalg::matmul::matmul;
3use pulp::Simd;
4
5use crate::internal_prelude::*;
6use crate::perm::{swap_cols_idx, swap_rows_idx};
7use crate::utils::thread::par_split_indices;
8
9#[inline(always)]
10fn best_value<T: ComplexField, S: Simd>(
11 simd: &SimdCtx<T, S>,
12 best_value: RealReg<T::SimdVec<S>>,
13 best_indices: T::SimdIndex<S>,
14 value: T::SimdVec<S>,
15 indices: T::SimdIndex<S>,
16) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
17 let value = simd.abs1(value);
18 let is_better = (**simd).gt(value, best_value);
19 (
20 RealReg(simd.select(is_better, value.0, best_value.0)),
21 simd.iselect(is_better, indices, best_indices),
22 )
23}
24
25#[inline(always)]
26fn best_score<T: ComplexField, S: Simd>(
27 simd: &SimdCtx<T, S>,
28 best_score: RealReg<T::SimdVec<S>>,
29 best_indices: T::SimdIndex<S>,
30 score: RealReg<T::SimdVec<S>>,
31 indices: T::SimdIndex<S>,
32) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
33 let is_better = (**simd).gt(score, best_score);
34 (
35 RealReg(simd.select(is_better, score.0, best_score.0)),
36 simd.iselect(is_better, indices, best_indices),
37 )
38}
39
40#[inline(always)]
41fn best_score_2d<T: ComplexField, S: Simd>(
42 simd: &SimdCtx<T, S>,
43 best_score: RealReg<T::SimdVec<S>>,
44 best_row: T::SimdIndex<S>,
45 best_col: T::SimdIndex<S>,
46 score: RealReg<T::SimdVec<S>>,
47 row: T::SimdIndex<S>,
48 col: T::SimdIndex<S>,
49) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>, T::SimdIndex<S>) {
50 let is_better = (**simd).gt(score, best_score);
51 (
52 RealReg(simd.select(is_better, score.0, best_score.0)),
53 simd.iselect(is_better, row, best_row),
54 simd.iselect(is_better, col, best_col),
55 )
56}
57
58#[inline(always)]
59#[math]
60fn reduce_2d<T: ComplexField, S: Simd>(
61 simd: &SimdCtx<T, S>,
62 best_values: RealReg<T::SimdVec<S>>,
63 best_row: T::SimdIndex<S>,
64 best_col: T::SimdIndex<S>,
65) -> (usize, usize, Real<T>) {
66 let best_val = simd.reduce_max_real(best_values);
67
68 let best_val_splat = simd.splat_real(&best_val);
69 let is_best = (**simd).ge(best_values, best_val_splat);
70 let idx = simd.first_true_mask(is_best);
71
72 let best_row = bytemuck::cast_slice::<T::SimdIndex<S>, T::Index>(core::slice::from_ref(&best_row))[idx];
73 let best_col = bytemuck::cast_slice::<T::SimdIndex<S>, T::Index>(core::slice::from_ref(&best_col))[idx];
74
75 (best_row.zx(), best_col.zx(), best_val)
76}
77
78#[inline(always)]
79#[math]
80fn best_in_col_simd<'M, T: ComplexField, S: Simd>(
81 simd: SimdCtx<'M, T, S>,
82 data: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
83) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
84 let (head, body4, body1, tail) = simd.batch_indices::<4>();
85
86 let iota = T::simd_iota(&simd.0);
87 let lane_count = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
88
89 let inc1 = simd.isplat(T::Index::truncate(lane_count));
90 let inc4 = simd.isplat(T::Index::truncate(4 * lane_count));
91
92 let mut best_val0 = simd.splat_real(&zero());
93 let mut best_val1 = simd.splat_real(&zero());
94 let mut best_val2 = simd.splat_real(&zero());
95 let mut best_val3 = simd.splat_real(&zero());
96
97 let mut best_idx0 = simd.isplat(T::Index::truncate(0));
98 let mut best_idx1 = simd.isplat(T::Index::truncate(0));
99 let mut best_idx2 = simd.isplat(T::Index::truncate(0));
100 let mut best_idx3 = simd.isplat(T::Index::truncate(0));
101
102 let mut idx0 = simd.iadd(iota, simd.isplat(T::Index::truncate(simd.offset().wrapping_neg())));
103 let mut idx1 = simd.iadd(idx0, inc1);
104 let mut idx2 = simd.iadd(idx1, inc1);
105 let mut idx3 = simd.iadd(idx2, inc1);
106
107 if let Some(i0) = head {
108 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, simd.read(data, i0), idx0);
109 idx0 = simd.iadd(idx0, inc1);
110 }
111
112 for [i0, i1, i2, i3] in body4 {
113 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, simd.read(data, i0), idx0);
114 (best_val1, best_idx1) = best_value(&simd, best_val1, best_idx1, simd.read(data, i1), idx1);
115 (best_val2, best_idx2) = best_value(&simd, best_val2, best_idx2, simd.read(data, i2), idx2);
116 (best_val3, best_idx3) = best_value(&simd, best_val3, best_idx3, simd.read(data, i3), idx3);
117
118 idx0 = simd.iadd(idx0, inc4);
119 idx1 = simd.iadd(idx1, inc4);
120 idx2 = simd.iadd(idx2, inc4);
121 idx3 = simd.iadd(idx3, inc4);
122 }
123
124 for i0 in body1 {
125 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, simd.read(data, i0), idx0);
126 idx0 = simd.iadd(idx0, inc1);
127 }
128
129 if let Some(i0) = tail {
130 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, simd.read(data, i0), idx0);
131 }
132
133 (best_val0, best_idx0) = best_score(&simd, best_val0, best_idx0, best_val1, best_idx1);
134 (best_val2, best_idx2) = best_score(&simd, best_val2, best_idx2, best_val3, best_idx3);
135 best_score(&simd, best_val0, best_idx0, best_val2, best_idx2)
136}
137
138#[inline(always)]
139#[math]
140fn update_and_best_in_col_simd<'M, T: ComplexField, S: Simd>(
141 simd: SimdCtx<'M, T, S>,
142 data: ColMut<'_, T, Dim<'M>, ContiguousFwd>,
143 lhs: ColRef<'_, T, Dim<'M>, ContiguousFwd>,
144 rhs: T,
145) -> (RealReg<T::SimdVec<S>>, T::SimdIndex<S>) {
146 let mut data = data;
147
148 let (head, body4, body1, tail) = simd.batch_indices::<3>();
149
150 let iota = T::simd_iota(&simd.0);
151 let lane_count = core::mem::size_of::<T::SimdVec<S>>() / core::mem::size_of::<T>();
152
153 let inc1 = simd.isplat(T::Index::truncate(lane_count));
154 let inc3 = simd.isplat(T::Index::truncate(3 * lane_count));
155
156 let mut best_val0 = simd.splat_real(&zero());
157 let mut best_val1 = simd.splat_real(&zero());
158 let mut best_val2 = simd.splat_real(&zero());
159
160 let mut best_idx0 = simd.isplat(T::Index::truncate(0));
161 let mut best_idx1 = simd.isplat(T::Index::truncate(0));
162 let mut best_idx2 = simd.isplat(T::Index::truncate(0));
163
164 let mut idx0 = simd.iadd(iota, simd.isplat(T::Index::truncate(simd.offset().wrapping_neg())));
165 let mut idx1 = simd.iadd(idx0, inc1);
166 let mut idx2 = simd.iadd(idx1, inc1);
167
168 let rhs = simd.splat(&-rhs);
169
170 if let Some(i0) = head {
171 let mut x0 = simd.read(data.rb(), i0);
172 let l0 = simd.read(lhs, i0);
173 x0 = simd.mul_add(l0, rhs, x0);
174
175 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, x0, idx0);
176 idx0 = simd.iadd(idx0, inc1);
177
178 simd.write(data.rb_mut(), i0, x0);
179 }
180
181 for [i0, i1, i2] in body4 {
182 let mut x0 = simd.read(data.rb(), i0);
183 let l0 = simd.read(lhs, i0);
184 x0 = simd.mul_add(l0, rhs, x0);
185 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, simd.read(data.rb(), i0), idx0);
186 simd.write(data.rb_mut(), i0, x0);
187
188 let mut x1 = simd.read(data.rb(), i1);
189 let l1 = simd.read(lhs, i1);
190 x1 = simd.mul_add(l1, rhs, x1);
191 (best_val1, best_idx1) = best_value(&simd, best_val1, best_idx1, simd.read(data.rb(), i1), idx1);
192 simd.write(data.rb_mut(), i1, x1);
193
194 let mut x2 = simd.read(data.rb(), i2);
195 let l2 = simd.read(lhs, i2);
196 x2 = simd.mul_add(l2, rhs, x2);
197 (best_val2, best_idx2) = best_value(&simd, best_val2, best_idx2, simd.read(data.rb(), i2), idx2);
198 simd.write(data.rb_mut(), i2, x2);
199
200 idx0 = simd.iadd(idx0, inc3);
201 idx1 = simd.iadd(idx1, inc3);
202 idx2 = simd.iadd(idx2, inc3);
203 }
204
205 for i0 in body1 {
206 let mut x0 = simd.read(data.rb(), i0);
207 let l0 = simd.read(lhs, i0);
208 x0 = simd.mul_add(l0, rhs, x0);
209
210 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, x0, idx0);
211 idx0 = simd.iadd(idx0, inc1);
212
213 simd.write(data.rb_mut(), i0, x0);
214 }
215
216 if let Some(i0) = tail {
217 let mut x0 = simd.read(data.rb(), i0);
218 let l0 = simd.read(lhs, i0);
219 x0 = simd.mul_add(l0, rhs, x0);
220
221 (best_val0, best_idx0) = best_value(&simd, best_val0, best_idx0, x0, idx0);
222
223 simd.write(data.rb_mut(), i0, x0);
224 }
225
226 (best_val0, best_idx0) = best_score(&simd, best_val0, best_idx0, best_val1, best_idx1);
227 best_score(&simd, best_val0, best_idx0, best_val2, best_idx2)
228}
229
230#[inline(always)]
231fn best_in_mat_simd<T: ComplexField>(data: MatRef<'_, T, usize, usize, ContiguousFwd>) -> (usize, usize, Real<T>) {
232 struct Impl<'a, 'M, 'N, T: ComplexField> {
233 data: MatRef<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
234 }
235
236 impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
237 type Output = (usize, usize, Real<T>);
238
239 #[math]
240 #[inline(always)]
241 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
242 let Self { data } = self;
243
244 let M = data.nrows();
245 let N = data.ncols();
246 let simd = SimdCtx::<'_, T, S>::new(T::simd_ctx(simd), M);
247
248 let mut best_row = simd.isplat(T::Index::truncate(0));
249 let mut best_col = simd.isplat(T::Index::truncate(0));
250 let mut best_val = simd.splat_real(&zero());
251
252 for j in N.indices() {
253 let col = data.col(j);
254 let (best_val_j, best_row_j) = best_in_col_simd(simd, col);
255
256 (best_val, best_row, best_col) = best_score_2d(
257 &simd,
258 best_val,
259 best_row,
260 best_col,
261 best_val_j,
262 best_row_j,
263 simd.isplat(T::Index::truncate(*j)),
264 );
265 }
266 reduce_2d(&simd, best_val, best_row, best_col)
267 }
268 }
269
270 with_dim!(M, data.nrows());
271 with_dim!(N, data.ncols());
272 dispatch!(Impl { data: data.as_shape(M, N) }, Impl, T)
273}
274
275#[inline(always)]
276fn update_and_best_in_mat_simd<T: ComplexField>(
277 data: MatMut<'_, T, usize, usize, ContiguousFwd>,
278 lhs: ColRef<'_, T, usize, ContiguousFwd>,
279 rhs: RowRef<'_, T, usize>,
280 align: usize,
281) -> (usize, usize, Real<T>) {
282 struct Impl<'a, 'M, 'N, T: ComplexField> {
283 data: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
284 lhs: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
285 rhs: RowRef<'a, T, Dim<'N>>,
286 align: usize,
287 }
288
289 impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
290 type Output = (usize, usize, Real<T>);
291
292 #[math]
293 #[inline(always)]
294 fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
295 let Self { data, lhs, rhs, align } = self;
296
297 let M = data.nrows();
298 let N = data.ncols();
299 let simd = SimdCtx::<'_, T, S>::new_align(T::simd_ctx(simd), M, align);
300
301 let mut best_row = simd.isplat(T::Index::truncate(0));
302 let mut best_col = simd.isplat(T::Index::truncate(0));
303 let mut best_val = simd.splat_real(&zero());
304 let mut data = data;
305
306 for j in N.indices() {
307 let data = data.rb_mut().col_mut(j);
308 let rhs = copy(rhs[j]);
309 let (best_val_j, best_row_j) = update_and_best_in_col_simd(simd, data, lhs, rhs);
310
311 (best_val, best_row, best_col) = best_score_2d(
312 &simd,
313 best_val,
314 best_row,
315 best_col,
316 best_val_j,
317 best_row_j,
318 simd.isplat(T::Index::truncate(*j)),
319 );
320 }
321 reduce_2d(&simd, best_val, best_row, best_col)
322 }
323 }
324
325 with_dim!(M, data.nrows());
326 with_dim!(N, data.ncols());
327 dispatch!(
328 Impl {
329 data: data.as_shape_mut(M, N),
330 lhs: lhs.as_row_shape(M),
331 rhs: rhs.as_col_shape(N),
332 align,
333 },
334 Impl,
335 T
336 )
337}
338
339#[math]
340fn best_in_matrix_fallback<T: ComplexField>(data: MatRef<'_, T>) -> (usize, usize, Real<T>) {
341 let mut max = zero();
342 let mut row = 0;
343 let mut col = 0;
344
345 let (m, n) = data.shape();
346
347 for j in 0..n {
348 for i in 0..m {
349 let abs = abs1(data[(i, j)]);
350 if abs > max {
351 row = i;
352 col = j;
353 max = abs;
354 }
355 }
356 }
357
358 (row, col, max)
359}
360
361#[math]
362fn best_in_matrix<T: ComplexField>(data: MatRef<'_, T>) -> (usize, usize, Real<T>) {
363 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
364 if let Some(dst) = data.try_as_col_major() {
365 best_in_mat_simd(dst)
366 } else {
367 best_in_matrix_fallback(data)
368 }
369 } else {
370 best_in_matrix_fallback(data)
371 }
372}
373#[math]
374fn rank_one_update_and_best_in_matrix<T: ComplexField>(
375 mut dst: MatMut<'_, T>,
376 lhs: ColRef<'_, T>,
377 rhs: RowRef<'_, T>,
378 align: usize,
379) -> (usize, usize, Real<T>) {
380 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
381 if let (Some(dst), Some(lhs)) = (dst.rb_mut().try_as_col_major_mut(), lhs.try_as_col_major()) {
382 update_and_best_in_mat_simd(dst, lhs, rhs, align)
383 } else {
384 matmul(dst.rb_mut(), Accum::Add, lhs.as_mat(), rhs.as_mat(), -one::<T>(), Par::Seq);
385 best_in_matrix(dst.rb())
386 }
387 } else {
388 matmul(dst.rb_mut(), Accum::Add, lhs.as_mat(), rhs.as_mat(), -one::<T>(), Par::Seq);
389 best_in_matrix(dst.rb())
390 }
391}
392
/// Unblocked LU factorization with full (row and column) pivoting, performed in place on `A`.
///
/// At each step `k`:
/// - the entry of largest `abs1` magnitude in the trailing block `A[k.., k..]` is swapped into
///   position `(k, k)`;
/// - the pivot column (or row, see `transpose`) is scaled by the reciprocal of the pivot;
/// - a rank-one update is applied to the remaining trailing block, and the same pass locates
///   the next pivot.
///
/// - `row_trans[k]` / `col_trans[k]` receive the row / column index swapped with `k` at step
///   `k`. Once the largest remaining magnitude is below `min_positive()`, the remaining slots
///   are filled with identity transpositions and the factorization stops.
/// - `par` controls the rank-one update. It is switched to `Par::Seq` for good once the
///   trailing block has fewer than `params.par_threshold` entries.
/// - `transpose == false` scales the entries below the pivot, and `transpose == true` scales
///   the entries to its right. The caller passes `true` together with a transposed view, so
///   that in the caller's orientation the scaled entries are always the ones below the diagonal.
///
/// Returns the number of swaps that exchanged two distinct indices (row and column swaps are
/// counted separately).
#[math]
fn lu_in_place_unblocked<T: ComplexField>(
	A: MatMut<'_, T>,
	row_trans: &mut [usize],
	col_trans: &mut [usize],
	par: Par,
	transpose: bool,
	params: Spec<FullPivLuParams, T>,
) -> usize {
	let params = params.config;
	let mut n_trans = 0;

	let (m, n) = A.shape();
	if m == 0 || n == 0 {
		return 0;
	}

	let mut par = par;

	let mut A = A;
	// initial pivot search over the whole matrix; later searches are fused with the updates
	let (mut max_row, mut max_col, mut max_score) = best_in_matrix(A.rb());

	for k in 0..Ord::min(m, n) {
		// the remaining trailing block is numerically zero: record identity transpositions
		// for every remaining step and stop
		if max_score < min_positive() {
			for (i, (row, col)) in core::iter::zip(&mut row_trans[k..], &mut col_trans[k..]).enumerate() {
				*row = i + k;
				*col = i + k;
			}
			break;
		}

		row_trans[k] = max_row;
		col_trans[k] = max_col;

		// bring the pivot to (k, k)
		if max_row != k {
			swap_rows_idx(A.rb_mut(), k, max_row);
			n_trans += 1;
		}
		if max_col != k {
			swap_cols_idx(A.rb_mut(), k, max_col);
			n_trans += 1;
		}

		let inv = recip(A[(k, k)]);
		if transpose {
			for j in k + 1..n {
				A[(k, j)] = A[(k, j)] * inv;
			}
		} else {
			for i in k + 1..m {
				A[(i, k)] = A[(i, k)] * inv;
			}
		}

		// last pivot: there is no trailing block left to update
		if k + 1 == Ord::min(m, n) {
			break;
		}
		// trailing blocks only shrink, so once below the threshold stay sequential
		if (m - k - 1) * (n - k - 1) < params.par_threshold {
			par = Par::Seq;
		}

		let (_, A01, A10, mut A11) = A.rb_mut().split_at_mut(k + 1, k + 1);

		// lhs: column k below the pivot; rhs: row k right of the pivot
		let lhs = A10.col(k);
		let rhs = A01.row(k);

		// A11 -= lhs * rhs, returning the location (relative to A11) of its largest entry.
		// NOTE(review): `simd_align(k + 1)` presumably gives the SIMD alignment of A11's first
		// row for the fused kernel; confirm against its definition.
		match par {
			Par::Seq => {
				(max_row, max_col, max_score) = rank_one_update_and_best_in_matrix(A11.rb_mut(), lhs, rhs, simd_align(k + 1));
			},
			#[cfg(feature = "rayon")]
			Par::Rayon(nthreads) => {
				use rayon::prelude::*;
				let nthreads = nthreads.get();

				// one (row, col, magnitude) slot per column partition
				let mut best = core::iter::repeat_with(|| (0, 0, zero())).take(nthreads).collect::<alloc::vec::Vec<_>>();
				let full_cols = A11.ncols();

				best.par_iter_mut()
					.zip_eq(A11.rb_mut().par_col_partition_mut(nthreads))
					.zip_eq(rhs.par_partition(nthreads))
					.enumerate()
					.for_each(|(idx, (((max_row, max_col, max_score), A11), rhs))| {
						(*max_row, *max_col, *max_score) = {
							let (a, mut b, c) = rank_one_update_and_best_in_matrix(A11, lhs, rhs, simd_align(k + 1));
							// shift the partition-local column index back to an A11 column
							b += par_split_indices(full_cols, idx, nthreads).0;
							(a, b, c)
						};
					});

				// reduce the per-partition candidates; strict `>` keeps the earliest partition on ties
				max_row = 0;
				max_col = 0;
				max_score = zero();

				for (row, col, val) in best {
					if val > max_score {
						max_row = row;
						max_col = col;
						max_score = val;
					}
				}
			},
		}

		// convert from A11-relative to A-relative indices
		max_row += k + 1;
		max_col += k + 1;
	}

	n_trans
}
503
504#[derive(Copy, Clone, Debug)]
506pub struct FullPivLuParams {
507 pub par_threshold: usize,
509
510 #[doc(hidden)]
511 pub non_exhaustive: NonExhaustive,
512}
513
514impl<T: ComplexField> Auto<T> for FullPivLuParams {
515 #[inline]
516 fn auto() -> Self {
517 Self {
518 par_threshold: 256 * 512,
519 non_exhaustive: NonExhaustive(()),
520 }
521 }
522}
523
524#[inline]
525pub fn lu_in_place_scratch<I: Index, T: ComplexField>(nrows: usize, ncols: usize, par: Par, params: Spec<FullPivLuParams, T>) -> StackReq {
526 _ = par;
527 _ = params;
528 let size = Ord::min(nrows, ncols);
529 StackReq::new::<usize>(size).array(2)
530}
531
/// Information about a factorization computed by [`lu_in_place`].
#[derive(Copy, Clone, Debug)]
pub struct FullPivLuInfo {
	/// Number of swaps that exchanged two distinct indices during the factorization. Row and
	/// column swaps are counted separately.
	pub transposition_count: usize,
}
536
537pub fn lu_in_place<'out, I: Index, T: ComplexField>(
538 mat: MatMut<'_, T>,
539 row_perm: &'out mut [I],
540 row_perm_inv: &'out mut [I],
541 col_perm: &'out mut [I],
542 col_perm_inv: &'out mut [I],
543 par: Par,
544 stack: &mut MemStack,
545 params: Spec<FullPivLuParams, T>,
546) -> (FullPivLuInfo, PermRef<'out, I>, PermRef<'out, I>) {
547 #[cfg(feature = "perf-warn")]
548 if (mat.col_stride().unsigned_abs() == 1 || mat.row_stride().unsigned_abs() != 1) && crate::__perf_warn!(LU_WARN) {
549 log::warn!(target: "faer_perf", "LU with full pivoting prefers column-major or row-major matrix. Found matrix with generic strides.");
550 }
551
552 let (M, N) = mat.shape();
553
554 let size = Ord::min(M, N);
555
556 let (mut row_transpositions, stack) = stack.make_with(size, |_| 0);
557 let row_transpositions = row_transpositions.as_mut();
558 let (mut col_transpositions, _) = stack.make_with(size, |_| 0);
559 let col_transpositions = col_transpositions.as_mut();
560
561 let n_transpositions = if mat.row_stride().abs() < mat.col_stride().abs() {
562 lu_in_place_unblocked(mat, row_transpositions, col_transpositions, par, false, params)
563 } else {
564 lu_in_place_unblocked(mat.transpose_mut(), col_transpositions, row_transpositions, par, true, params)
565 };
566
567 for i in 0..M {
568 row_perm[i] = I::truncate(i);
569 }
570 for (i, t) in row_transpositions.iter().copied().enumerate() {
571 row_perm.as_mut().swap(i, t);
572 }
573 for i in 0..M {
574 row_perm_inv[row_perm[i].zx()] = I::truncate(i);
575 }
576
577 for j in 0..N {
578 col_perm[j] = I::truncate(j);
579 }
580 for (i, t) in col_transpositions.iter().copied().enumerate() {
581 col_perm.as_mut().swap(i, t);
582 }
583 for j in 0..N {
584 col_perm_inv[col_perm[j].zx()] = I::truncate(j);
585 }
586
587 unsafe {
588 (
589 FullPivLuInfo {
590 transposition_count: n_transpositions,
591 },
592 PermRef::new_unchecked(row_perm, row_perm_inv, M),
593 PermRef::new_unchecked(col_perm, col_perm_inv, N),
594 )
595 }
596}
597
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, assert, c64};
	use dyn_stack::MemBuffer;

	#[test]
	fn test_flu() {
		let rng = &mut StdRng::seed_from_u64(0);

		for par in [Par::Seq, Par::rayon(8)] {
			// tall complex matrices: check the reconstruction `A = P^-1 * L * U * Q`
			for m in [8, 16, 24, 32, 128, 255, 256, 257] {
				let n = 8;

				let approx_eq = CwiseMat(ApproxEq {
					abs_tol: 1e-10,
					rel_tol: 1e-10,
				});

				let A = CwiseMatDistribution {
					nrows: m,
					ncols: n,
					dist: ComplexDistribution::new(StandardNormal, StandardNormal),
				}
				.rand::<Mat<c64>>(rng);
				let A = A.as_ref();

				let mut LU = A.cloned();
				let row_perm = &mut *vec![0usize; m];
				let row_perm_inv = &mut *vec![0usize; m];

				let col_perm = &mut *vec![0usize; n];
				let col_perm_inv = &mut *vec![0usize; n];

				let (_, p, q) = lu_in_place(
					LU.as_mut(),
					row_perm,
					row_perm_inv,
					col_perm,
					col_perm_inv,
					par,
					// request scratch for the actual m x n shape being factored
					MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, c64>(m, n, par, default()))),
					default(),
				);

				let mut L = LU.as_ref().cloned();
				let mut U = LU.as_ref().cloned();

				for j in 0..n {
					for i in 0..j {
						L[(i, j)] = c64::ZERO;
					}
					L[(j, j)] = c64::ONE;
				}
				for j in 0..n {
					for i in j + 1..m {
						U[(i, j)] = c64::ZERO;
					}
				}
				let L = L.as_ref();
				let U = U.as_ref();

				let U = U.subrows(0, n);

				assert!(p.inverse() * L * U * q ~ A);
			}

			// square real matrices: reconstruction plus the full-pivoting growth bound
			for n in [16, 24, 32, 128, 255, 256, 257] {
				let approx_eq = CwiseMat(ApproxEq {
					abs_tol: 1e-10,
					rel_tol: 1e-10,
				});

				let A = CwiseMatDistribution {
					nrows: n,
					ncols: n,
					dist: StandardNormal,
				}
				.rand::<Mat<f64>>(rng);
				let A = A.as_ref();

				let mut LU = A.cloned();
				let row_perm = &mut *vec![0usize; n];
				let row_perm_inv = &mut *vec![0usize; n];

				let col_perm = &mut *vec![0usize; n];
				let col_perm_inv = &mut *vec![0usize; n];

				let (_, p, q) = lu_in_place(
					LU.as_mut(),
					row_perm,
					row_perm_inv,
					col_perm,
					col_perm_inv,
					par,
					MemStack::new(&mut MemBuffer::new(lu_in_place_scratch::<usize, f64>(n, n, par, default()))),
					default(),
				);

				let mut L = LU.as_ref().cloned();
				let mut U = LU.as_ref().cloned();

				for j in 0..n {
					for i in 0..j {
						L[(i, j)] = 0.0;
					}
					L[(j, j)] = 1.0;
				}
				for j in 0..n {
					for i in j + 1..n {
						U[(i, j)] = 0.0;
					}
				}

				// with full pivoting every pivot is the largest entry of the trailing block, so
				// every multiplier stored in `L` has magnitude at most one. The slack absorbs the
				// rounding of `x * recip(pivot)`. The reconstruction check alone cannot detect a
				// pivot search that reports the wrong location or inspects stale values, but
				// this check can.
				for j in 0..n {
					for i in j + 1..n {
						assert!(L[(i, j)].abs() <= 1.0 + 1e-12);
					}
				}

				let L = L.as_ref();
				let U = U.as_ref();

				assert!(p.inverse() * L * U * q ~ A);
			}
		}
	}
}