1use crate::internal_prelude::*;
2
3#[track_caller]
39#[math]
40pub fn kron<T: ComplexField>(dst: MatMut<'_, T>, lhs: MatRef<'_, impl Conjugate<Canonical = T>>, rhs: MatRef<'_, impl Conjugate<Canonical = T>>) {
41 let mut dst = dst;
44 let mut lhs = lhs;
45 let mut rhs = rhs;
46 if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
47 dst = dst.transpose_mut();
48 lhs = lhs.transpose();
49 rhs = rhs.transpose();
50 }
51
52 Assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
53 Assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
54
55 for lhs_j in 0..lhs.ncols() {
56 for lhs_i in 0..lhs.nrows() {
57 let lhs_val = Conj::apply(lhs.at(lhs_i, lhs_j));
58 let mut dst = dst
59 .rb_mut()
60 .submatrix_mut(lhs_i * rhs.nrows(), lhs_j * rhs.ncols(), rhs.nrows(), rhs.ncols());
61
62 for rhs_j in 0..rhs.ncols() {
63 for rhs_i in 0..rhs.nrows() {
64 unsafe {
66 let rhs_val = Conj::apply(rhs.at_unchecked(rhs_i, rhs_j));
67 *dst.rb_mut().at_mut_unchecked(rhs_i, rhs_j) = lhs_val * rhs_val;
68 }
69 }
70 }
71 }
72 }
73 }
75
#[cfg(test)]
mod tests {
	use super::kron;
	use crate::{Col, Mat, Row, assert};

	/// Shape checks for matrix/matrix, matrix/vector and vector/vector products.
	/// All operands are filled with ones, so only the output shape is really
	/// exercised here. `test_kron_distinct_entries` checks the placement of values.
	#[test]
	fn test_kron_ones() {
		for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1)] {
			let a = Mat::from_fn(m, n, |_, _| 1 as f64);
			let b = Mat::from_fn(p, q, |_, _| 1 as f64);
			let expected = Mat::from_fn(m * p, n * q, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref());
			assert!(out == expected);
		}

		for (m, n, p) in [(2, 3, 4), (3, 2, 5), (1, 1, 1)] {
			let a = Mat::from_fn(m, n, |_, _| 1 as f64);
			let b = Col::from_fn(p, |_| 1 as f64);
			let expected = Mat::from_fn(m * p, n, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref());
			assert!(out == expected);

			// `a` is reused unchanged. Only the vector operand switches to a row.
			let b = Row::from_fn(p, |_| 1 as f64);
			let expected = Mat::from_fn(m, n * p, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref());
			assert!(out == expected);
		}

		for (m, n) in [(2, 3), (3, 2), (1, 1)] {
			let a = Row::from_fn(m, |_| 1 as f64);
			let b = Col::from_fn(n, |_| 1 as f64);
			let expected = Mat::from_fn(n, m, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * b.nrows(), a.ncols() * b.ncols());
			kron(out.as_mut(), a.as_ref().as_mat(), b.as_ref().as_mat());
			assert!(out == expected);
			let mut out = Mat::zeros(b.nrows() * a.nrows(), b.ncols() * a.ncols());
			kron(out.as_mut(), b.as_ref().as_mat(), a.as_ref().as_mat());
			assert!(out == expected);

			let c = Row::from_fn(n, |_| 1 as f64);
			let expected = Mat::from_fn(1, m * n, |_, _| 1 as f64);
			let mut out = Mat::zeros(a.nrows() * c.nrows(), a.ncols() * c.ncols());
			kron(out.as_mut(), a.as_ref().as_mat(), c.as_ref().as_mat());
			assert!(out == expected);

			let d = Col::from_fn(m, |_| 1 as f64);
			let expected = Mat::from_fn(m * n, 1, |_, _| 1 as f64);
			let mut out = Mat::zeros(d.nrows() * b.nrows(), d.ncols() * b.ncols());
			kron(out.as_mut(), d.as_ref().as_mat(), b.as_ref().as_mat());
			assert!(out == expected);
		}
	}

	/// All-ones operands cannot reveal misplaced blocks or swapped operands.
	/// This test therefore uses distinct entries and checks every element against
	/// the definition `(A ⊗ B)[(i * p + k, j * q + l)] == A[(i, j)] * B[(k, l)]`.
	///
	/// It runs for both a column-major destination and a transposed one. The
	/// transposed case drives `kron` through its internal transposition branch.
	/// Empty-dimension cases are included as well.
	#[test]
	fn test_kron_distinct_entries() {
		// Small integers keep every product exactly representable in f64.
		let fa = |i: usize, j: usize| (10 * i + j + 1) as f64;
		let fb = |k: usize, l: usize| (100 * k + l + 1) as f64;

		for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1), (1, 3, 2, 1), (0, 3, 2, 2), (2, 2, 0, 3)] {
			let a = Mat::from_fn(m, n, |i, j| fa(i, j));
			let b = Mat::from_fn(p, q, |k, l| fb(k, l));

			// The closures are only called for in-bounds indices. When `p == 0` or
			// `q == 0` the expected matrix is empty, so there is no division by zero.
			let expected = Mat::from_fn(m * p, n * q, |r, c| fa(r / p, c / q) * fb(r % p, c % q));
			let expected_t = Mat::from_fn(n * q, m * p, |c, r| fa(r / p, c / q) * fb(r % p, c % q));

			// Column-major destination.
			let mut out = Mat::zeros(m * p, n * q);
			kron(out.as_mut(), a.as_ref(), b.as_ref());
			assert!(out == expected);

			// Transposed destination view. The storage of `out_t` must then hold
			// the transpose of the product.
			let mut out_t = Mat::zeros(n * q, m * p);
			kron(out_t.as_mut().transpose_mut(), a.as_ref(), b.as_ref());
			assert!(out_t == expected_t);
		}
	}
}