phase_rs/normal_syntax/term.rs

1//! Normal-form terms.
2
3use std::f64::consts::PI;
4
5use faer::{Mat, mat};
6use num_complex::Complex;
7
8use crate::{
9    normal_syntax::PatternN,
10    phase::Phase,
11    typed_syntax::{TermT, TermType},
12};
13
14/// A normal-form term
15#[derive(Clone, Debug, PartialEq)]
16pub enum TermN {
17    /// A composition "t_1 ; ... ; t_n" with given type
18    Comp(Vec<TermN>, TermType),
19    /// A tensor "t_1 x ... x t_n"
20    Tensor(Vec<TermN>),
21    /// An "atomic" term
22    Atom(AtomN),
23}
24
25/// "Atomic" terms. Terms which are not compositions or tensors.
26#[derive(Clone, Debug, PartialEq)]
27pub enum AtomN {
28    /// A (global) phase operator, e.g. "-1" or "ph(0.1pi)"
29    Phase(f64),
30    /// An "if let" statement with given pattern, body term, and type
31    IfLet(PatternN, Box<TermN>, TermType),
32}
33
34impl TermN {
35    /// Convert a normal-form term of type qn <-> qn to an n x n unitary matrix.
36    pub fn to_unitary(&self) -> Mat<Complex<f64>> {
37        match self {
38            TermN::Comp(terms, ty) => {
39                let mut terms_iter = terms.iter().map(TermN::to_unitary);
40                match terms_iter.next() {
41                    None => Mat::identity(1 << ty.0, 1 << ty.0),
42                    Some(u) => terms_iter.fold(u, |x, y| y * x),
43                }
44            }
45            TermN::Tensor(terms) => {
46                let mut terms_iter = terms.iter().map(TermN::to_unitary);
47                match terms_iter.next() {
48                    None => Mat::identity(1, 1),
49                    Some(u) => terms_iter.fold(u, |x, y| x.kron(y)),
50                }
51            }
52            TermN::Atom(atom) => atom.to_unitary(),
53        }
54    }
55
56    /// Return a `TermT` which is the "quotation" of this normal-form term.
57    /// Realises that all normal-form terms are also terms.
58    pub fn quote(&self) -> TermT {
59        match self {
60            TermN::Comp(terms, ty) => {
61                if terms.is_empty() {
62                    TermT::Id(*ty)
63                } else {
64                    TermT::Comp(terms.iter().map(TermN::quote).collect())
65                }
66            }
67            TermN::Tensor(terms) => TermT::Tensor(terms.iter().map(TermN::quote).collect()),
68            TermN::Atom(atom) => atom.quote(),
69        }
70    }
71
72    fn squash_comp(mut self, acc: &mut Vec<TermN>) {
73        if let TermN::Comp(terms, _) = self {
74            for t in terms {
75                t.squash_comp(acc);
76            }
77        } else {
78            self.squash();
79            acc.push(self);
80        }
81    }
82
83    fn squash_tensor(mut self, acc: &mut Vec<TermN>) {
84        if let TermN::Tensor(terms) = self {
85            for t in terms {
86                t.squash_tensor(acc);
87            }
88        } else {
89            self.squash();
90            acc.push(self);
91        }
92    }
93
94    /// Simplifies compositions, tensors, and identities in the given normal-form term.
95    pub fn squash(&mut self) {
96        match self {
97            TermN::Comp(terms, _) => {
98                let old_terms = std::mem::take(terms);
99                for t in old_terms {
100                    t.squash_comp(terms);
101                }
102                if terms.len() == 1 {
103                    *self = terms.pop().unwrap();
104                }
105            }
106            TermN::Tensor(terms) => {
107                let old_terms = std::mem::take(terms);
108                for t in old_terms {
109                    t.squash_tensor(terms);
110                }
111                if terms.len() == 1 {
112                    *self = terms.pop().unwrap();
113                }
114            }
115            TermN::Atom(atom) => atom.squash(),
116        }
117    }
118}
119
120impl AtomN {
121    pub(crate) fn get_type(&self) -> TermType {
122        match self {
123            AtomN::Phase(_) => TermType(0),
124            AtomN::IfLet(_, _, ty) => *ty,
125        }
126    }
127
128    /// Convert a normal-form atom of type qn <-> qn to an n x n unitary matrix.
129    pub fn to_unitary(&self) -> Mat<Complex<f64>> {
130        match self {
131            AtomN::Phase(angle) => mat![[Complex::cis(angle * PI)]],
132            AtomN::IfLet(pattern, inner, _) => {
133                let (inj, proj) = pattern.to_inj_and_proj();
134                let u = inner.to_unitary();
135                proj + &inj * u * inj.adjoint()
136            }
137        }
138    }
139
140    pub(super) fn quote(&self) -> TermT {
141        match self {
142            AtomN::Phase(angle) => TermT::Phase(Phase::from_angle(*angle)),
143            AtomN::IfLet(pattern, inner, _) => TermT::IfLet {
144                pattern: pattern.quote(),
145                inner: Box::new(inner.quote()),
146            },
147        }
148    }
149
150    pub(super) fn squash(&mut self) {
151        if let AtomN::IfLet(pattern, inner, _) = self {
152            pattern.squash();
153            inner.squash();
154        }
155    }
156}