faer/linalg/
kron.rs

1use crate::internal_prelude::*;
2
3/// kronecker product of two matrices
4///
5/// the kronecker product of two matrices $A$ and $B$ is a block matrix
6/// $B$ with the following structure:
7///
8/// ```text
9/// C = [ a[(0, 0)] * B    , a[(0, 1)] * B    , ... , a[(0, n-1)] * B    ]
10///     [ a[(1, 0)] * B    , a[(1, 1)] * B    , ... , a[(1, n-1)] * B    ]
11///     [ ...              , ...              , ... , ...              ]
12///     [ a[(m-1, 0)] * B  , a[(m-1, 1)] * B  , ... , a[(m-1, n-1)] * B  ]
13/// ```
14///
15/// # panics
16///
17/// panics if `dst` does not have the correct dimensions. the dimensions
18/// of `dst` must be `A.nrows() * B.nrows()` by `A.ncols() * B.ncols()`.
19///
20/// # example
21///
22/// ```
23/// use faer::linalg::kron::kron;
24/// use faer::{Mat, mat};
25///
26/// let a = mat![[1.0, 2.0], [3.0, 4.0]];
27/// let b = mat![[0.0, 5.0], [6.0, 7.0]];
28/// let c = mat![
29/// 	[0.0, 5.0, 0.0, 10.0],
30/// 	[6.0, 7.0, 12.0, 14.0],
31/// 	[0.0, 15.0, 0.0, 20.0],
32/// 	[18.0, 21.0, 24.0, 28.0],
33/// ];
34/// let mut dst = Mat::zeros(4, 4);
35/// kron(dst.as_mut(), a.as_ref(), b.as_ref());
36/// assert_eq!(dst, c);
37/// ```
38#[track_caller]
39#[math]
40pub fn kron<T: ComplexField>(dst: MatMut<'_, T>, lhs: MatRef<'_, impl Conjugate<Canonical = T>>, rhs: MatRef<'_, impl Conjugate<Canonical = T>>) {
41	// pull the lever kron
42
43	let mut dst = dst;
44	let mut lhs = lhs;
45	let mut rhs = rhs;
46	if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
47		dst = dst.transpose_mut();
48		lhs = lhs.transpose();
49		rhs = rhs.transpose();
50	}
51
52	Assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
53	Assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
54
55	for lhs_j in 0..lhs.ncols() {
56		for lhs_i in 0..lhs.nrows() {
57			let lhs_val = Conj::apply(lhs.at(lhs_i, lhs_j));
58			let mut dst = dst
59				.rb_mut()
60				.submatrix_mut(lhs_i * rhs.nrows(), lhs_j * rhs.ncols(), rhs.nrows(), rhs.ncols());
61
62			for rhs_j in 0..rhs.ncols() {
63				for rhs_i in 0..rhs.nrows() {
64					// SAFETY: Bounds have been checked.
65					unsafe {
66						let rhs_val = Conj::apply(rhs.at_unchecked(rhs_i, rhs_j));
67						*dst.rb_mut().at_mut_unchecked(rhs_i, rhs_j) = lhs_val * rhs_val;
68					}
69				}
70			}
71		}
72	}
73	// the other lever
74}
75
#[cfg(test)]
mod tests {
	use super::kron;
	use crate::{Col, Mat, Row, assert};

	#[test]
	fn test_kron_ones() {
		for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1)] {
			let a = Mat::from_fn(m, n, |_, _| 1 as f64);
			let b = Mat::from_fn(p, q, |_, _| 1 as f64);
			let expected = Mat::from_fn(m * p, n * q, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref());
			assert!(out == expected);
		}

		for (m, n, p) in [(2, 3, 4), (3, 2, 5), (1, 1, 1)] {
			let a = Mat::from_fn(m, n, |_, _| 1 as f64);
			let b = Col::from_fn(p, |_| 1 as f64);
			let expected = Mat::from_fn(m * p, n, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref());
			assert!(out == expected);

			let a = Mat::from_fn(m, n, |_, _| 1 as f64);
			let b = Row::from_fn(p, |_| 1 as f64);
			let expected = Mat::from_fn(m, n * p, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref());
			assert!(out == expected);
		}

		for (m, n) in [(2, 3), (3, 2), (1, 1)] {
			let a = Row::from_fn(m, |_| 1 as f64);
			let b = Col::from_fn(n, |_| 1 as f64);
			let expected = Mat::from_fn(n, m, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref().as_mat(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref().as_mat());
			assert!(out == expected);

			let c = Row::from_fn(n, |_| 1 as f64);
			let expected = Mat::from_fn(1, m * n, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * c.nrows(), a.ncols() * c.ncols());
			kron(out.as_mut(), a.as_ref().as_mat(), c.as_ref().as_mat());
			assert!(out == expected);

			let d = Col::from_fn(m, |_| 1 as f64);
			let expected = Mat::from_fn(m * n, 1, |_, _| 1 as f64);
			let mut out = Mat::zeros(d.nrows() * b.nrows(), d.ncols() * b.ncols());
			kron(out.as_mut(), d.as_ref().as_mat(), b.as_ref().as_mat());
			assert!(out == expected);
		}
	}

	// all-ones inputs cannot detect a misplaced block or a swapped index, so this test uses
	// distinct integer-valued entries (products are exact in f64) and compares against the
	// element-wise definition C[(i * p + k, j * q + l)] = A[(i, j)] * B[(k, l)].
	#[test]
	fn test_kron_values() {
		for (m, n, p, q) in [(2usize, 3usize, 4usize, 5usize), (3, 2, 5, 4), (1, 1, 1, 1), (2, 1, 1, 3)] {
			let a = Mat::from_fn(m, n, |i: usize, j: usize| (1 + i + 10 * j) as f64);
			let b = Mat::from_fn(p, q, |k: usize, l: usize| (1 + 2 * k + 7 * l) as f64);
			let expected = Mat::from_fn(m * p, n * q, |r: usize, c: usize| a[(r / p, c / q)] * b[(r % p, c % q)]);

			// column-major destination
			let mut out: Mat<f64> = Mat::zeros(m * p, n * q);
			kron(out.as_mut(), a.as_ref(), b.as_ref());
			assert!(out == expected);

			// row-major destination: the transpose of a column-major (n * q) x (m * p) matrix.
			// for shapes where the strides differ, this goes through the transposed branch of `kron`.
			let mut out_t: Mat<f64> = Mat::zeros(n * q, m * p);
			kron(out_t.as_mut().transpose_mut(), a.as_ref(), b.as_ref());
			for r in 0..m * p {
				for c in 0..n * q {
					assert!(out_t[(c, r)] == expected[(r, c)]);
				}
			}
		}
	}
}