phase_rs/normal_syntax/
pattern.rs

1//! Normal form patterns
2
3use faer::Mat;
4use num_complex::Complex;
5
6use crate::{
7    ket::{CompKetState, KetState},
8    normal_syntax::term::AtomN,
9    typed_syntax::{PatternT, PatternType, TermT, TermType},
10};
11
12/// A normal-form patterns
13#[derive(Clone, Debug, PartialEq)]
14pub enum PatternN {
15    /// A composition "p_1 . ... . p_n" with given type
16    Comp(Vec<PatternN>, PatternType),
17    /// A tensor "p_1 x ... x p_n"
18    Tensor(Vec<PatternN>),
19    /// A single ket state "|x>"
20    Ket(KetState),
21    /// An "atomic" term. Compound terms are evaluated to pattern compositions/tensors.
22    Unitary(Box<AtomN>),
23}
24
25impl PatternN {
26    /// Convert a normal-form pattern of type qm < qn to an m x n isometry matrix `i`
27    /// and an n x n projector `p` such that
28    /// p + ii^dagger = id
29    pub fn to_inj_and_proj(&self) -> (Mat<Complex<f64>>, Mat<Complex<f64>>) {
30        match self {
31            PatternN::Comp(patterns, ty) => {
32                let mut patterns_iter = patterns.iter().map(PatternN::to_inj_and_proj);
33                if let Some(i) = patterns_iter.next() {
34                    patterns_iter.fold(i, |(i1, p1), (i2, p2)| {
35                        (&i1 * i2, p1 + &i1 * p2 * i1.adjoint())
36                    })
37                } else {
38                    (
39                        Mat::identity(1 << ty.0, 1 << ty.0),
40                        Mat::zeros(1 << ty.0, 1 << ty.0),
41                    )
42                }
43            }
44            PatternN::Tensor(patterns) => {
45                let mut patterns_iter = patterns.iter().map(PatternN::to_inj_and_proj);
46                let i = patterns_iter.next().unwrap();
47                patterns_iter.fold(i, |(i1, p1), (i2, p2)| {
48                    (
49                        i1.kron(i2),
50                        p1.kron(Mat::<Complex<f64>>::identity(p2.nrows(), p2.nrows()))
51                            + (&i1 * i1.adjoint()).kron(p2),
52                    )
53                })
54            }
55            PatternN::Ket(state) => {
56                let m = state.to_state();
57                let cm = state.compl().to_state();
58                (m, cm.as_ref() * cm.adjoint())
59            }
60            PatternN::Unitary(inner) => {
61                let size = inner.get_type().0;
62                (inner.to_unitary(), Mat::zeros(1 << size, 1 << size))
63            }
64        }
65    }
66
67    /// Return a `PatternT` which is the "quotation" of this normal-form pattern.
68    /// Realises that all normal-form patterns are also patterns.
69    pub fn quote(&self) -> PatternT {
70        match self {
71            PatternN::Comp(patterns, ty) => {
72                if patterns.is_empty() {
73                    PatternT::Unitary(Box::new(TermT::Id(TermType(ty.0))))
74                } else {
75                    PatternT::Comp(patterns.iter().map(PatternN::quote).collect())
76                }
77            }
78            PatternN::Tensor(patterns) => {
79                PatternT::Tensor(patterns.iter().map(PatternN::quote).collect())
80            }
81            PatternN::Ket(state) => PatternT::Ket(CompKetState::single(*state)),
82            PatternN::Unitary(inner) => PatternT::Unitary(Box::new(inner.quote())),
83        }
84    }
85
86    fn squash_comp(mut self, acc: &mut Vec<PatternN>) {
87        if let PatternN::Comp(patterns, _) = self {
88            for p in patterns {
89                p.squash_comp(acc);
90            }
91        } else {
92            self.squash();
93            acc.push(self);
94        }
95    }
96
97    fn squash_tensor(mut self, acc: &mut Vec<PatternN>) {
98        if let PatternN::Tensor(patterns) = self {
99            for p in patterns {
100                p.squash_tensor(acc);
101            }
102        } else {
103            self.squash();
104            acc.push(self);
105        }
106    }
107
108    /// Simplifies compositions, tensors, and identities in the given normal-form pattern.
109    pub fn squash(&mut self) {
110        match self {
111            PatternN::Comp(patterns, _) => {
112                let old_patterns = std::mem::take(patterns);
113                for p in old_patterns {
114                    p.squash_comp(patterns);
115                }
116                if patterns.len() == 1 {
117                    *self = patterns.pop().unwrap();
118                }
119            }
120            PatternN::Tensor(patterns) => {
121                let old_patterns = std::mem::take(patterns);
122                for p in old_patterns {
123                    p.squash_tensor(patterns);
124                }
125                if patterns.len() == 1 {
126                    *self = patterns.pop().unwrap();
127                }
128            }
129            PatternN::Ket(_) => {}
130            PatternN::Unitary(inner) => inner.squash(),
131        }
132    }
133}