1use std::{fmt::Display, iter::Sum};
4
5use crate::{
6 circuit_syntax::{TermC, pattern::PatternC, term::ClauseC},
7 normal_syntax::{Buildable, term::AtomN},
8 phase::Phase,
9 raw_syntax::{
10 TermR,
11 term::{AtomR, AtomRInner, TensorR, TensorRInner, TermRInner},
12 },
13 text::Name,
14 typed_syntax::{PatternT, PatternType},
15};
16
/// The type of a term: its width, i.e. how many positions it acts on.
///
/// Displayed as `q<width>`. Widths of tensored terms add up (see the `Sum`
/// impl). `Eq`/`Hash` are derived because the payload is a plain `usize`, so
/// the type can serve as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TermType(pub usize);
20
21impl Display for TermType {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "q{}", self.0)
24 }
25}
26
27impl Sum for TermType {
28 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
29 TermType(iter.map(|x| x.0).sum())
30 }
31}
32
33impl TermType {
34 pub fn to_pattern_type(self) -> PatternType {
36 PatternType(self.0, self.0)
37 }
38}
39
40#[derive(Clone, Debug, PartialEq)]
42pub enum TermT {
43 Comp(Vec<TermT>),
45 Tensor(Vec<TermT>),
47 Id(TermType),
49 Phase(Phase),
51 IfLet {
53 pattern: PatternT,
55 inner: Box<TermT>,
57 },
58 Gate {
60 name: Name,
62 def: Box<TermT>,
64 },
65 Inverse(Box<TermT>),
67 Sqrt(Box<TermT>),
69}
70
71impl TermT {
72 pub fn get_type(&self) -> TermType {
74 match self {
75 TermT::Comp(terms) => terms.first().unwrap().get_type(),
76 TermT::Tensor(terms) => terms.iter().map(TermT::get_type).sum(),
77 TermT::Id(ty) => *ty,
78 TermT::Phase(_) => TermType(0),
79 TermT::IfLet { pattern, .. } => TermType(pattern.get_type().0),
80 TermT::Gate { def, .. } => def.get_type(),
81 TermT::Inverse(inner) => inner.get_type(),
82 TermT::Sqrt(inner) => inner.get_type(),
83 }
84 }
85
86 pub fn eval<B: Buildable>(&self) -> B {
90 self.eval_with_phase_mul(1.0)
91 }
92
93 pub(super) fn eval_with_phase_mul<B: Buildable>(&self, phase_mul: f64) -> B {
94 match self {
95 TermT::Comp(terms) => {
96 let mut mapped_terms = terms.iter().map(|t| t.eval_with_phase_mul(phase_mul));
97 if terms.len() == 1 {
98 mapped_terms.next().unwrap()
99 } else if phase_mul > 0.0 {
100 B::comp(mapped_terms, &terms.first().unwrap().get_type())
101 } else {
102 B::comp(mapped_terms.rev(), &terms.first().unwrap().get_type())
103 }
104 }
105 TermT::Tensor(terms) => {
106 if terms.len() == 1 {
107 terms[0].eval_with_phase_mul(phase_mul)
108 } else {
109 B::tensor(terms.iter().map(|t| t.eval_with_phase_mul(phase_mul)))
110 }
111 }
112 TermT::Id(ty) => B::comp(std::iter::empty(), ty),
113 TermT::Phase(phase) => B::atom(AtomN::Phase(phase_mul * phase.eval())),
114 TermT::IfLet { pattern, inner } => B::atom(AtomN::IfLet(
115 pattern.eval(),
116 Box::new(inner.eval_with_phase_mul(phase_mul)),
117 TermType(pattern.get_type().0),
118 )),
119 TermT::Gate { def, .. } => def.eval_with_phase_mul(phase_mul),
120 TermT::Inverse(inner) => inner.eval_with_phase_mul(-phase_mul),
121 TermT::Sqrt(inner) => inner.eval_with_phase_mul(phase_mul / 2.0),
122 }
123 }
124
125 pub fn eval_circ(&self) -> TermC {
127 let mut clauses = vec![];
128 let size = self.get_type().0;
129 let inj = (0..size).collect::<Vec<_>>();
130 self.eval_circ_clause(&PatternC::id(size), &inj, 1.0, &mut clauses);
131 TermC {
132 clauses,
133 ty: self.get_type(),
134 }
135 }
136
137 pub(super) fn eval_circ_clause(
138 &self,
139 pattern: &PatternC,
140 inj: &[usize],
141 phase_mul: f64,
142 clauses: &mut Vec<ClauseC>,
143 ) {
144 match self {
145 TermT::Comp(terms) => {
146 if phase_mul < 0.0 {
147 for t in terms.iter().rev() {
148 t.eval_circ_clause(pattern, inj, phase_mul, clauses);
149 }
150 } else {
151 for t in terms {
152 t.eval_circ_clause(pattern, inj, phase_mul, clauses);
153 }
154 }
155 }
156 TermT::Tensor(terms) => {
157 let mut start = 0;
158 for t in terms {
159 let size = t.get_type().0;
160 let end = start + size;
161 t.eval_circ_clause(pattern, &inj[start..end], phase_mul, clauses);
162 start = end;
163 }
164 }
165 TermT::Id(_) => {
166 }
168 TermT::Phase(phase) => {
169 clauses.push(ClauseC {
170 pattern: pattern.clone(),
171 phase: phase_mul * phase.eval(),
172 });
173 }
174 TermT::IfLet {
175 pattern: if_pattern,
176 inner,
177 } => {
178 let mut unitary_clauses = Vec::new();
179 let mut inner_pattern = pattern.clone();
180 let mut inner_inj = inj.to_vec();
181 if_pattern.eval_circ(&mut inner_pattern, &mut inner_inj, &mut unitary_clauses);
182 let temp: Vec<_> = unitary_clauses.iter().rev().map(ClauseC::invert).collect();
183 clauses.extend(unitary_clauses);
184
185 inner.eval_circ_clause(&inner_pattern, &inner_inj, phase_mul, clauses);
186
187 clauses.extend(temp)
188 }
189 TermT::Gate { def, .. } => {
190 def.eval_circ_clause(pattern, inj, phase_mul, clauses);
191 }
192 TermT::Inverse(inner) => {
193 inner.eval_circ_clause(pattern, inj, -phase_mul, clauses);
194 }
195 TermT::Sqrt(inner) => {
196 inner.eval_circ_clause(pattern, inj, phase_mul / 2.0, clauses);
197 }
198 }
199 }
200
201 pub fn to_raw(&self) -> TermR<()> {
203 let terms = if let TermT::Comp(terms) = self {
204 terms.iter().map(|t| t.to_raw_tensor()).collect()
205 } else {
206 vec![self.to_raw_tensor()]
207 };
208 TermRInner { terms }.into()
209 }
210
211 fn to_raw_tensor(&self) -> TensorR<()> {
212 let terms = if let TermT::Tensor(terms) = self {
213 terms.iter().map(|t| t.to_raw_atom()).collect()
214 } else {
215 vec![self.to_raw_atom()]
216 };
217 TensorRInner { terms }.into()
218 }
219
220 fn to_raw_atom(&self) -> AtomR<()> {
221 match self {
222 TermT::Id(ty) => AtomRInner::Id(ty.0),
223 TermT::Phase(phase) => AtomRInner::Phase(*phase),
224 TermT::IfLet { pattern, inner } => AtomRInner::IfLet {
225 pattern: pattern.to_raw(),
226 inner: Box::new(inner.to_raw_tensor()),
227 },
228 TermT::Gate { name, .. } => AtomRInner::Gate(name.to_owned()),
229 TermT::Inverse(inner) => AtomRInner::Inverse(Box::new(inner.to_raw_atom())),
230 TermT::Sqrt(inner) => AtomRInner::Sqrt(Box::new(inner.to_raw_atom())),
231 t => AtomRInner::Brackets(t.to_raw()),
232 }
233 .into()
234 }
235}