phase_rs/typed_syntax/
pattern.rs1use std::{fmt::Display, iter::Sum};
4
5use crate::{
6 circuit_syntax::{pattern::PatternC, term::ClauseC},
7 ket::CompKetState,
8 normal_syntax::PatternN,
9 raw_syntax::{
10 PatternR,
11 pattern::{PatAtomR, PatAtomRInner, PatTensorR, PatTensorRInner, PatternRInner},
12 },
13 typed_syntax::TermT,
14};
15
/// The type of a pattern, as a pair of qubit counts.
///
/// The first component is the number of qubits the pattern consumes: a ket
/// pattern over `n` qubits has type `PatternType(n, 0)`, and circuit lowering
/// splits that many wires off for each tensor component. The second component
/// is the number of qubits the pattern leaves behind (zero for a ket).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct PatternType(pub usize, pub usize);

impl Sum for PatternType {
    /// Adds pattern types componentwise; an empty iterator sums to
    /// `PatternType(0, 0)`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::default(), |PatternType(a, b), PatternType(c, d)| {
            PatternType(a + c, b + d)
        })
    }
}

impl Display for PatternType {
    /// Formats the type as `q{second} -> q{first}`; for example
    /// `PatternType(2, 0)` is shown as `q0 -> q2`.
    // NOTE(review): the components are printed in reverse order, presumably so
    // the pattern type reads like the type of the corresponding term — confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "q{} -> q{}", self.1, self.0)
    }
}

34#[derive(Clone, Debug, PartialEq)]
36pub enum PatternT {
37 Comp(Vec<PatternT>),
39 Tensor(Vec<PatternT>),
41 Ket(CompKetState),
43 Unitary(Box<TermT>),
45}
46
47impl PatternT {
48 pub fn get_type(&self) -> PatternType {
50 match self {
51 PatternT::Comp(patterns) => PatternType(
52 patterns.first().unwrap().get_type().0,
53 patterns.last().unwrap().get_type().1,
54 ),
55 PatternT::Tensor(patterns) => patterns.iter().map(PatternT::get_type).sum(),
56 PatternT::Ket(states) => PatternType(states.qubits(), 0),
57 PatternT::Unitary(inner) => inner.get_type().to_pattern_type(),
58 }
59 }
60
61 pub(super) fn eval(&self) -> PatternN {
64 match self {
65 PatternT::Comp(patterns) => {
66 if patterns.len() == 1 {
67 patterns[0].eval()
68 } else {
69 PatternN::Comp(
70 patterns.iter().map(PatternT::eval).collect(),
71 self.get_type(),
72 )
73 }
74 }
75 PatternT::Tensor(patterns) => {
76 if patterns.len() == 1 {
77 patterns[0].eval()
78 } else {
79 PatternN::Tensor(patterns.iter().map(PatternT::eval).collect())
80 }
81 }
82 PatternT::Ket(states) => {
83 PatternN::Tensor(states.iter().map(|&state| PatternN::Ket(state)).collect())
84 }
85 PatternT::Unitary(inner) => inner.eval(),
86 }
87 }
88
89 pub(super) fn eval_circ(
90 &self,
91 pattern: &mut PatternC,
92 inj: &mut Vec<usize>,
93 clauses: &mut Vec<ClauseC>,
94 ) {
95 match self {
96 PatternT::Comp(patterns) => {
97 for p in patterns {
98 p.eval_circ(pattern, inj, clauses);
99 }
100 }
101 PatternT::Tensor(patterns) => {
102 let mut stack: Vec<Vec<usize>> = Vec::new();
103 for p in patterns.iter().rev() {
104 let size = p.get_type().0;
105 let mut i = inj.split_off(inj.len() - size);
106 p.eval_circ(pattern, &mut i, clauses);
107 stack.push(i);
108 }
109 while let Some(i) = stack.pop() {
110 inj.extend(i);
111 }
112 }
113 PatternT::Ket(states) => {
114 for (state, i) in states.iter().zip(inj.drain(0..states.qubits())) {
115 pattern.parts[i] = Some(*state)
116 }
117 }
118 PatternT::Unitary(inner) => {
119 inner.eval_circ_clause(pattern, inj, -1.0, clauses);
120 }
121 }
122 }
123
124 pub fn to_raw(&self) -> PatternR<()> {
126 let patterns = if let PatternT::Comp(patterns) = self {
127 patterns.iter().map(|t| t.to_raw_tensor()).collect()
128 } else {
129 vec![self.to_raw_tensor()]
130 };
131 PatternRInner { patterns }.into()
132 }
133
134 fn to_raw_tensor(&self) -> PatTensorR<()> {
135 let patterns = if let PatternT::Tensor(patterns) = self {
136 patterns.iter().map(|t| t.to_raw_atom()).collect()
137 } else {
138 vec![self.to_raw_atom()]
139 };
140 PatTensorRInner { patterns }.into()
141 }
142
143 fn to_raw_atom(&self) -> PatAtomR<()> {
144 match self {
145 PatternT::Ket(states) => PatAtomRInner::Ket(states.clone()),
146 PatternT::Unitary(inner) => PatAtomRInner::Unitary(Box::new(inner.to_raw())),
147 p => PatAtomRInner::Brackets(p.to_raw()),
148 }
149 .into()
150 }
151}