Function matmul 

pub fn matmul<T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>, M: Shape, N: Shape, K: Shape>(
    dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
    dst_structure: BlockStructure,
    beta: Accum,
    lhs: impl AsMatRef<T = LhsT, Rows = M, Cols = K>,
    lhs_structure: BlockStructure,
    rhs: impl AsMatRef<T = RhsT, Rows = K, Cols = N>,
    rhs_structure: BlockStructure,
    alpha: T,
    par: Par,
)

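computes the matrix product [acc +] alpha * lhs * rhs (implicitly conjugating the operands if needed) and stores the result in acc, where acc denotes the destination matrix passed as dst and beta selects whether the previous contents of acc are added to the product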

performs the operation:

  • acc = alpha * lhs * rhs if beta is Accum::Replace (in this case, the preexisting values in acc are not read)
  • acc = acc + alpha * lhs * rhs if beta is Accum::Add
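
the two accumulation modes can be sketched in plain rust on nested vectors. this is a minimal illustration of the semantics only, not faer's implementation: AccumMode and naive_matmul are hypothetical names, and triangular structures and conjugation are ignored.

```rust
/// Hypothetical stand-in for the accumulation mode; not faer's own type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AccumMode {
    Replace,
    Add,
}

/// Computes `acc = [acc +] alpha * lhs * rhs` on row-major nested vectors.
/// Assumes the shapes are already compatible.
pub fn naive_matmul(
    acc: &mut [Vec<f64>],
    mode: AccumMode,
    lhs: &[Vec<f64>],
    rhs: &[Vec<f64>],
    alpha: f64,
) {
    let k = rhs.len();
    for i in 0..acc.len() {
        for j in 0..acc[i].len() {
            // dot product of row i of lhs with column j of rhs
            let mut dot = 0.0;
            for p in 0..k {
                dot += lhs[i][p] * rhs[p][j];
            }
            acc[i][j] = match mode {
                // the previous value is never read, so it may hold anything (even NaN)
                AccumMode::Replace => alpha * dot,
                // the previous value is read and kept as part of the result
                AccumMode::Add => acc[i][j] + alpha * dot,
            };
        }
    }
}
```

because the replace mode never reads the destination, a destination filled with NaN still yields a finite result, whereas the add mode would propagate the NaN.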

the left hand side and right hand side are each interpreted as rectangular or triangular, according to lhs_structure and rhs_structure respectively.
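
one way to picture the interpretation of an operand is as a rule for the value each stored entry contributes to the product. the sketch below assumes the usual conventions (entries outside the triangle count as zero, a strict triangle has a zero diagonal, a unit triangle has an implicit diagonal of ones); these conventions are an assumption not spelled out in this page, and OperandStructure and effective_entry are hypothetical names covering only the lower-triangular variants.

```rust
/// Hypothetical mirror of the lower-triangular operand structures; not faer's own enum.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum OperandStructure {
    Rectangular,
    TriangularLower,
    StrictTriangularLower,
    UnitTriangularLower,
}

/// Value that entry (i, j) of the stored matrix `m` contributes to the product
/// under the given interpretation (assumed semantics, see the lead-in).
pub fn effective_entry(m: &[Vec<f64>], s: OperandStructure, i: usize, j: usize) -> f64 {
    match s {
        // every stored entry is used as-is
        OperandStructure::Rectangular => m[i][j],
        // entries above the diagonal are treated as zero
        OperandStructure::TriangularLower => if i >= j { m[i][j] } else { 0.0 },
        // the diagonal is treated as zero as well
        OperandStructure::StrictTriangularLower => if i > j { m[i][j] } else { 0.0 },
        // the stored diagonal is ignored and replaced by ones
        OperandStructure::UnitTriangularLower => {
            if i > j {
                m[i][j]
            } else if i == j {
                1.0
            } else {
                0.0
            }
        }
    }
}
```

under these assumptions, the stored values outside the selected triangle never influence the result, so they may be left uninitialized or hold unrelated data.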

for the destination matrix, dst_structure determines which entries are computed:

  • every entry, if the structure is rectangular,
  • only the triangular half (including the diagonal), if the structure is triangular,
  • only the strict triangular half (excluding the diagonal), if the structure is strictly triangular or unit triangular.
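
the three cases above amount to a mask over the destination entries. the following sketch builds that mask for the lower-triangular variants; DstStructure, is_computed and computed_mask are hypothetical names used for illustration, and the unit triangular case is folded into the strict one since both exclude the diagonal.

```rust
/// Hypothetical mirror of the lower-triangular destination structures; not faer's own enum.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DstStructure {
    Rectangular,
    TriangularLower,
    /// also stands for the unit triangular case, which excludes the diagonal too
    StrictTriangularLower,
}

/// Returns true if entry (i, j) of the destination is computed.
pub fn is_computed(structure: DstStructure, i: usize, j: usize) -> bool {
    match structure {
        DstStructure::Rectangular => true,
        // diagonal included
        DstStructure::TriangularLower => i >= j,
        // diagonal excluded
        DstStructure::StrictTriangularLower => i > j,
    }
}

/// Builds the n-by-n mask of computed entries.
pub fn computed_mask(structure: DstStructure, n: usize) -> Vec<Vec<bool>> {
    (0..n)
        .map(|i| (0..n).map(|j| is_computed(structure, i, j)).collect())
        .collect()
}
```

for a 2-by-2 destination, the triangular mask marks three entries and the strict mask marks only the entry below the diagonal; the upper-triangular variants mirror these conditions.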

§panics

panics if the matrix dimensions are not compatible for matrix multiplication, i.e., if any of the following conditions does not hold:

  • acc.nrows() == lhs.nrows()
  • acc.ncols() == rhs.ncols()
  • lhs.ncols() == rhs.nrows()

additionally, matrices that are marked as triangular must be square, i.e., they must have the same number of rows and columns.
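
the conditions above can be collected into a single check. the sketch below is an illustration only: check_matmul_dims and Dims are hypothetical names, and it returns a Result so the outcome can be inspected, whereas the function documented here panics.

```rust
/// Shape of a matrix as (nrows, ncols).
pub type Dims = (usize, usize);

/// Checks the shape requirements listed above; each `*_triangular` flag says
/// whether the corresponding matrix is marked with a triangular structure.
pub fn check_matmul_dims(
    acc: Dims,
    acc_triangular: bool,
    lhs: Dims,
    lhs_triangular: bool,
    rhs: Dims,
    rhs_triangular: bool,
) -> Result<(), String> {
    if acc.0 != lhs.0 {
        return Err(format!("acc.nrows() = {} but lhs.nrows() = {}", acc.0, lhs.0));
    }
    if acc.1 != rhs.1 {
        return Err(format!("acc.ncols() = {} but rhs.ncols() = {}", acc.1, rhs.1));
    }
    if lhs.1 != rhs.0 {
        return Err(format!("lhs.ncols() = {} but rhs.nrows() = {}", lhs.1, rhs.0));
    }
    // matrices marked as triangular must be square
    let marked = [
        ("acc", acc, acc_triangular),
        ("lhs", lhs, lhs_triangular),
        ("rhs", rhs, rhs_triangular),
    ];
    for (name, dims, triangular) in marked {
        if triangular && dims.0 != dims.1 {
            return Err(format!(
                "{} is marked triangular but has shape {}x{}",
                name, dims.0, dims.1
            ));
        }
    }
    Ok(())
}
```

for example, a 2-by-2 triangular destination with a 2-by-3 lhs and a 3-by-2 rhs passes the check, while marking a 2-by-3 destination as triangular fails it.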

§example

use faer::linalg::matmul::triangular::{BlockStructure, matmul};
use faer::{Accum, Mat, Par, mat, unzip, zip};

let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
let rhs = mat![[4.0, 6.0], [5.0, 7.0]];

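// destination, initially zero; only its lower triangle is written by the call below
let mut acc = Mat::<f64>::zeros(2, 2);
// expected result: the lower triangle (diagonal included) of 2.5 * lhs * rhs,
// with the entry above the diagonal left at zero
let target = mat![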
	[
		2.5 * (lhs[(0, 0)] * rhs[(0, 0)] + lhs[(0, 1)] * rhs[(1, 0)]),
		0.0,
	],
	[
		2.5 * (lhs[(1, 0)] * rhs[(0, 0)] + lhs[(1, 1)] * rhs[(1, 0)]),
		2.5 * (lhs[(1, 0)] * rhs[(0, 1)] + lhs[(1, 1)] * rhs[(1, 1)]),
	],
];

// acc = 2.5 * lhs * rhs, restricted to the lower triangle of acc
matmul(
	&mut acc,
	BlockStructure::TriangularLower,
	Accum::Replace,
	&lhs,
	BlockStructure::Rectangular,
	&rhs,
	BlockStructure::Rectangular,
	2.5,
	Par::Seq,
);

// compare the result with the expected values entry by entry
zip!(&acc, &target).for_each(|unzip!(acc, target)| assert!((acc - target).abs() < 1e-10));
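
the expected values in the example can be worked out by hand: lhs * rhs is [[10, 14], [19, 27]], so scaling by 2.5 and keeping only the lower triangle gives [[25, 0], [47.5, 67.5]]. the dependency-free sketch below reproduces that arithmetic; lower_triangle_of_scaled_product is a hypothetical helper written for this illustration and is not part of faer.

```rust
/// Lower triangle (diagonal included) of `alpha * lhs * rhs` for 2x2 inputs.
/// Entries above the diagonal stay at 0.0, matching a zero-initialized destination.
pub fn lower_triangle_of_scaled_product(
    lhs: [[f64; 2]; 2],
    rhs: [[f64; 2]; 2],
    alpha: f64,
) -> [[f64; 2]; 2] {
    let mut out = [[0.0; 2]; 2];
    for i in 0..2 {
        // only columns j <= i belong to the lower triangle
        for j in 0..=i {
            out[i][j] = alpha * (lhs[i][0] * rhs[0][j] + lhs[i][1] * rhs[1][j]);
        }
    }
    out
}
```

calling it with the lhs, rhs and alpha from the example yields the same three nonzero entries as target.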