phase_rs/typed_syntax/pattern.rs

1//! Term syntax patterns.
2
3use std::{fmt::Display, iter::Sum};
4
5use crate::{
6    circuit_syntax::{pattern::PatternC, term::ClauseC},
7    ket::CompKetState,
8    normal_syntax::PatternN,
9    raw_syntax::{
10        PatternR,
11        pattern::{PatAtomR, PatAtomRInner, PatTensorR, PatTensorRInner, PatternRInner},
12    },
13    typed_syntax::TermT,
14};
15
/// A pattern type "qn < qm", stored as `PatternType(m, n)`.
///
/// * Field `0` (`m`) is the number of qubits the pattern consumes; for
///   example, a ket pattern over `k` qubits has type `PatternType(k, 0)`.
/// * Field `1` (`n`) is the number of qubits the pattern outputs.
///
/// It is displayed as `"q{n} -> q{m}"`, i.e. field `1` first, then field `0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PatternType(pub usize, pub usize);

/// Component-wise sum of pattern types, as used for the type of a tensor
/// product of patterns. The empty sum is `PatternType(0, 0)`.
impl Sum for PatternType {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(PatternType(0, 0), |PatternType(a, b), PatternType(c, d)| {
            PatternType(a + c, b + d)
        })
    }
}

impl Display for PatternType {
    /// Formats as `"q{n} -> q{m}"` for `PatternType(m, n)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "q{} -> q{}", self.1, self.0)
    }
}
33
34/// Syntax of typed patterns
35#[derive(Clone, Debug, PartialEq)]
36pub enum PatternT {
37    /// A non-empty composition "p_1 . ... . p_n"
38    Comp(Vec<PatternT>),
39    /// A tensor "p_1 x ... x p_n"
40    Tensor(Vec<PatternT>),
41    /// A sequence of ket states "|xyz>", equivalent to "|x> x |y> x |z>"
42    Ket(CompKetState),
43    /// A unitary pattern
44    Unitary(Box<TermT>),
45}
46
47impl PatternT {
48    /// Returns the type of this pattern
49    pub fn get_type(&self) -> PatternType {
50        match self {
51            PatternT::Comp(patterns) => PatternType(
52                patterns.first().unwrap().get_type().0,
53                patterns.last().unwrap().get_type().1,
54            ),
55            PatternT::Tensor(patterns) => patterns.iter().map(PatternT::get_type).sum(),
56            PatternT::Ket(states) => PatternType(states.qubits(), 0),
57            PatternT::Unitary(inner) => inner.get_type().to_pattern_type(),
58        }
59    }
60
61    /// Evaluate a term to a `PatternN`, expanding top level definitions
62    /// and evaluating inverse and sqrt macros.
63    pub(super) fn eval(&self) -> PatternN {
64        match self {
65            PatternT::Comp(patterns) => {
66                if patterns.len() == 1 {
67                    patterns[0].eval()
68                } else {
69                    PatternN::Comp(
70                        patterns.iter().map(PatternT::eval).collect(),
71                        self.get_type(),
72                    )
73                }
74            }
75            PatternT::Tensor(patterns) => {
76                if patterns.len() == 1 {
77                    patterns[0].eval()
78                } else {
79                    PatternN::Tensor(patterns.iter().map(PatternT::eval).collect())
80                }
81            }
82            PatternT::Ket(states) => {
83                PatternN::Tensor(states.iter().map(|&state| PatternN::Ket(state)).collect())
84            }
85            PatternT::Unitary(inner) => inner.eval(),
86        }
87    }
88
89    pub(super) fn eval_circ(
90        &self,
91        pattern: &mut PatternC,
92        inj: &mut Vec<usize>,
93        clauses: &mut Vec<ClauseC>,
94    ) {
95        match self {
96            PatternT::Comp(patterns) => {
97                for p in patterns {
98                    p.eval_circ(pattern, inj, clauses);
99                }
100            }
101            PatternT::Tensor(patterns) => {
102                let mut stack: Vec<Vec<usize>> = Vec::new();
103                for p in patterns.iter().rev() {
104                    let size = p.get_type().0;
105                    let mut i = inj.split_off(inj.len() - size);
106                    p.eval_circ(pattern, &mut i, clauses);
107                    stack.push(i);
108                }
109                while let Some(i) = stack.pop() {
110                    inj.extend(i);
111                }
112            }
113            PatternT::Ket(states) => {
114                for (state, i) in states.iter().zip(inj.drain(0..states.qubits())) {
115                    pattern.parts[i] = Some(*state)
116                }
117            }
118            PatternT::Unitary(inner) => {
119                inner.eval_circ_clause(pattern, inj, -1.0, clauses);
120            }
121        }
122    }
123
124    /// Convert to a raw pattern.
125    pub fn to_raw(&self) -> PatternR<()> {
126        let patterns = if let PatternT::Comp(patterns) = self {
127            patterns.iter().map(|t| t.to_raw_tensor()).collect()
128        } else {
129            vec![self.to_raw_tensor()]
130        };
131        PatternRInner { patterns }.into()
132    }
133
134    fn to_raw_tensor(&self) -> PatTensorR<()> {
135        let patterns = if let PatternT::Tensor(patterns) = self {
136            patterns.iter().map(|t| t.to_raw_atom()).collect()
137        } else {
138            vec![self.to_raw_atom()]
139        };
140        PatTensorRInner { patterns }.into()
141    }
142
143    fn to_raw_atom(&self) -> PatAtomR<()> {
144        match self {
145            PatternT::Ket(states) => PatAtomRInner::Ket(states.clone()),
146            PatternT::Unitary(inner) => PatAtomRInner::Unitary(Box::new(inner.to_raw())),
147            p => PatAtomRInner::Brackets(p.to_raw()),
148        }
149        .into()
150    }
151}