phase_rs/typed_syntax/term.rs

//! Typed syntax terms.
2
3use std::{fmt::Display, iter::Sum};
4
5use crate::{
6    circuit_syntax::{TermC, pattern::PatternC, term::ClauseC},
7    normal_syntax::{Buildable, term::AtomN},
8    phase::Phase,
9    raw_syntax::{
10        TermR,
11        term::{AtomR, AtomRInner, TensorR, TensorRInner, TermRInner},
12    },
13    text::Name,
14    typed_syntax::{PatternT, PatternType},
15};
16
/// The type of a unitary term, written "qn <-> qn".
///
/// The wrapped value `n` is the width of the term: a term of this type maps
/// `qn` to `qn`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TermType(pub usize);
20
21impl Display for TermType {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "q{}", self.0)
24    }
25}
26
27impl Sum for TermType {
28    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
29        TermType(iter.map(|x| x.0).sum())
30    }
31}
32
33impl TermType {
34    /// Convert a unitary type qn <-> qn to pattern type qn < qn
35    pub fn to_pattern_type(self) -> PatternType {
36        PatternType(self.0, self.0)
37    }
38}
39
40/// Syntax of typed terms
41#[derive(Clone, Debug, PartialEq)]
42pub enum TermT {
43    /// A non-empty composition "t_1 ; ... ; t_n"
44    Comp(Vec<TermT>),
45    /// A tensor "t_1 x ... x t_n"
46    Tensor(Vec<TermT>),
47    /// An identity "id(n)"
48    Id(TermType),
49    /// A (global) phase operator, e.g. "-1" or "ph(0.1pi)"
50    Phase(Phase),
51    /// An "if let" statement, "if let pattern then inner"
52    IfLet {
53        /// Pattern to match on in "if let"
54        pattern: PatternT,
55        /// Body of the "if let"
56        inner: Box<TermT>,
57    },
58    /// Top level symbol, a named gate
59    Gate {
60        /// Name of symbol/gate
61        name: Name,
62        /// Definition of symbol
63        def: Box<TermT>,
64    },
65    /// Inverse of a term "t ^ -1"
66    Inverse(Box<TermT>),
67    /// Square root of a term "sqrt(t)"
68    Sqrt(Box<TermT>),
69}
70
71impl TermT {
72    /// Returns the type of this term
73    pub fn get_type(&self) -> TermType {
74        match self {
75            TermT::Comp(terms) => terms.first().unwrap().get_type(),
76            TermT::Tensor(terms) => terms.iter().map(TermT::get_type).sum(),
77            TermT::Id(ty) => *ty,
78            TermT::Phase(_) => TermType(0),
79            TermT::IfLet { pattern, .. } => TermType(pattern.get_type().0),
80            TermT::Gate { def, .. } => def.get_type(),
81            TermT::Inverse(inner) => inner.get_type(),
82            TermT::Sqrt(inner) => inner.get_type(),
83        }
84    }
85
86    /// Evaluate a term to a given `Buildable` type, expanding top level definitions
87    /// and evaluating inverse and sqrt macros.
88    /// In particular this can be used to generate a `TermN` from a `TermT`.
89    pub fn eval<B: Buildable>(&self) -> B {
90        self.eval_with_phase_mul(1.0)
91    }
92
93    pub(super) fn eval_with_phase_mul<B: Buildable>(&self, phase_mul: f64) -> B {
94        match self {
95            TermT::Comp(terms) => {
96                let mut mapped_terms = terms.iter().map(|t| t.eval_with_phase_mul(phase_mul));
97                if terms.len() == 1 {
98                    mapped_terms.next().unwrap()
99                } else if phase_mul > 0.0 {
100                    B::comp(mapped_terms, &terms.first().unwrap().get_type())
101                } else {
102                    B::comp(mapped_terms.rev(), &terms.first().unwrap().get_type())
103                }
104            }
105            TermT::Tensor(terms) => {
106                if terms.len() == 1 {
107                    terms[0].eval_with_phase_mul(phase_mul)
108                } else {
109                    B::tensor(terms.iter().map(|t| t.eval_with_phase_mul(phase_mul)))
110                }
111            }
112            TermT::Id(ty) => B::comp(std::iter::empty(), ty),
113            TermT::Phase(phase) => B::atom(AtomN::Phase(phase_mul * phase.eval())),
114            TermT::IfLet { pattern, inner } => B::atom(AtomN::IfLet(
115                pattern.eval(),
116                Box::new(inner.eval_with_phase_mul(phase_mul)),
117                TermType(pattern.get_type().0),
118            )),
119            TermT::Gate { def, .. } => def.eval_with_phase_mul(phase_mul),
120            TermT::Inverse(inner) => inner.eval_with_phase_mul(-phase_mul),
121            TermT::Sqrt(inner) => inner.eval_with_phase_mul(phase_mul / 2.0),
122        }
123    }
124
125    /// Returns a `TermC` representing the "circuit-normal-form" of the term.
126    pub fn eval_circ(&self) -> TermC {
127        let mut clauses = vec![];
128        let size = self.get_type().0;
129        let inj = (0..size).collect::<Vec<_>>();
130        self.eval_circ_clause(&PatternC::id(size), &inj, 1.0, &mut clauses);
131        TermC {
132            clauses,
133            ty: self.get_type(),
134        }
135    }
136
137    pub(super) fn eval_circ_clause(
138        &self,
139        pattern: &PatternC,
140        inj: &[usize],
141        phase_mul: f64,
142        clauses: &mut Vec<ClauseC>,
143    ) {
144        match self {
145            TermT::Comp(terms) => {
146                if phase_mul < 0.0 {
147                    for t in terms.iter().rev() {
148                        t.eval_circ_clause(pattern, inj, phase_mul, clauses);
149                    }
150                } else {
151                    for t in terms {
152                        t.eval_circ_clause(pattern, inj, phase_mul, clauses);
153                    }
154                }
155            }
156            TermT::Tensor(terms) => {
157                let mut start = 0;
158                for t in terms {
159                    let size = t.get_type().0;
160                    let end = start + size;
161                    t.eval_circ_clause(pattern, &inj[start..end], phase_mul, clauses);
162                    start = end;
163                }
164            }
165            TermT::Id(_) => {
166                // Intentionally blank
167            }
168            TermT::Phase(phase) => {
169                clauses.push(ClauseC {
170                    pattern: pattern.clone(),
171                    phase: phase_mul * phase.eval(),
172                });
173            }
174            TermT::IfLet {
175                pattern: if_pattern,
176                inner,
177            } => {
178                let mut unitary_clauses = Vec::new();
179                let mut inner_pattern = pattern.clone();
180                let mut inner_inj = inj.to_vec();
181                if_pattern.eval_circ(&mut inner_pattern, &mut inner_inj, &mut unitary_clauses);
182                let temp: Vec<_> = unitary_clauses.iter().rev().map(ClauseC::invert).collect();
183                clauses.extend(unitary_clauses);
184
185                inner.eval_circ_clause(&inner_pattern, &inner_inj, phase_mul, clauses);
186
187                clauses.extend(temp)
188            }
189            TermT::Gate { def, .. } => {
190                def.eval_circ_clause(pattern, inj, phase_mul, clauses);
191            }
192            TermT::Inverse(inner) => {
193                inner.eval_circ_clause(pattern, inj, -phase_mul, clauses);
194            }
195            TermT::Sqrt(inner) => {
196                inner.eval_circ_clause(pattern, inj, phase_mul / 2.0, clauses);
197            }
198        }
199    }
200
201    /// Convert to a raw term.
202    pub fn to_raw(&self) -> TermR<()> {
203        let terms = if let TermT::Comp(terms) = self {
204            terms.iter().map(|t| t.to_raw_tensor()).collect()
205        } else {
206            vec![self.to_raw_tensor()]
207        };
208        TermRInner { terms }.into()
209    }
210
211    fn to_raw_tensor(&self) -> TensorR<()> {
212        let terms = if let TermT::Tensor(terms) = self {
213            terms.iter().map(|t| t.to_raw_atom()).collect()
214        } else {
215            vec![self.to_raw_atom()]
216        };
217        TensorRInner { terms }.into()
218    }
219
220    fn to_raw_atom(&self) -> AtomR<()> {
221        match self {
222            TermT::Id(ty) => AtomRInner::Id(ty.0),
223            TermT::Phase(phase) => AtomRInner::Phase(*phase),
224            TermT::IfLet { pattern, inner } => AtomRInner::IfLet {
225                pattern: pattern.to_raw(),
226                inner: Box::new(inner.to_raw_tensor()),
227            },
228            TermT::Gate { name, .. } => AtomRInner::Gate(name.to_owned()),
229            TermT::Inverse(inner) => AtomRInner::Inverse(Box::new(inner.to_raw_atom())),
230            TermT::Sqrt(inner) => AtomRInner::Sqrt(Box::new(inner.to_raw_atom())),
231            t => AtomRInner::Brackets(t.to_raw()),
232        }
233        .into()
234    }
235}