1use std::ops::Range;
4
5use pretty::RcDoc;
6use winnow::{
7 LocatingSlice, ModalResult, Parser,
8 ascii::{dec_uint, multispace0, multispace1},
9 combinator::{alt, cut_err, delimited, opt, preceded, separated, seq},
10 error::{StrContext, StrContextValue},
11};
12
13use crate::{
14 phase::Phase,
15 raw_syntax::PatternR,
16 text::{HasParser, Name, Span, Spanned, ToDoc},
17 typecheck::{Env, TypeCheckError},
18 typed_syntax::{TermT, TermType},
19};
20
/// A raw (not yet type-checked) term paired with its source span `S`.
///
/// A term is a `;`-separated sequence of tensors; [`TermR::check`] turns it
/// into a typed composition.
pub type TermR<S> = Spanned<S, TermRInner<S>>;

/// Contents of a [`TermR`]: the tensors of the term in source order.
#[derive(Clone, Debug, PartialEq)]
pub struct TermRInner<S> {
    /// Tensors separated by `;`. The parser always produces at least one.
    pub(crate) terms: Vec<TensorR<S>>,
}
31
32impl<S> ToDoc for TermRInner<S> {
33 fn to_doc(&self) -> RcDoc<'_> {
34 RcDoc::intersperse(
35 self.terms.iter().map(TensorR::to_doc),
36 RcDoc::text(";").append(RcDoc::line()),
37 )
38 .group()
39 }
40}
41
/// A raw tensor product of atoms paired with its source span `S`.
pub type TensorR<S> = Spanned<S, TensorRInner<S>>;

/// Contents of a [`TensorR`]: the atoms of the tensor in source order.
#[derive(Clone, Debug, PartialEq)]
pub struct TensorRInner<S> {
    /// Atoms separated by `x`. The parser always produces at least one.
    pub(crate) terms: Vec<AtomR<S>>,
}
52
53impl<S> ToDoc for TensorRInner<S> {
54 fn to_doc(&self) -> RcDoc<'_> {
55 RcDoc::intersperse(
56 self.terms.iter().map(AtomR::to_doc),
57 RcDoc::line().append("x "),
58 )
59 .group()
60 }
61}
62
/// A raw atom paired with its source span `S`.
pub type AtomR<S> = Spanned<S, AtomRInner<S>>;

/// The kinds of atom that can appear inside a tensor.
#[derive(Clone, Debug, PartialEq)]
pub enum AtomRInner<S> {
    /// A parenthesised term: `( term )`.
    Brackets(TermR<S>),
    /// Identity, written `id` (count 1) or `id<N>`; the count is the `qubits`
    /// argument used to build the identity's type.
    Id(usize),
    /// A phase literal, as accepted by `Phase::parser`.
    Phase(Phase),
    /// Conditional application: `if let <pattern> then <tensor>`.
    IfLet {
        /// The pattern matched by the condition.
        pattern: PatternR<S>,
        /// The tensor applied when the pattern matches.
        inner: Box<TensorR<S>>,
    },
    /// A reference to a named gate, resolved against the `Env` when checking.
    Gate(Name),
    /// The inverse of an atom, written `<atom> ^ -1`.
    Inverse(Box<AtomR<S>>),
    /// The square root of an atom, written `sqrt <atom>`.
    Sqrt(Box<AtomR<S>>),
}
91
92impl<S> ToDoc for AtomRInner<S> {
93 fn to_doc(&self) -> RcDoc<'_> {
94 match self {
95 AtomRInner::Brackets(term) => RcDoc::text("(")
96 .append(RcDoc::line().append(term.to_doc()).nest(2))
97 .append(RcDoc::line())
98 .append(")")
99 .group(),
100 AtomRInner::Id(qubits) => RcDoc::text("id").append(if *qubits == 1 {
101 RcDoc::nil()
102 } else {
103 RcDoc::as_string(qubits)
104 }),
105 AtomRInner::Phase(phase) => phase.to_doc(),
106 AtomRInner::IfLet { pattern, inner, .. } => RcDoc::text("if let")
107 .append(RcDoc::line().append(pattern.to_doc()).nest(2))
108 .append(RcDoc::line())
109 .append("then")
110 .group()
111 .append(RcDoc::line().append(inner.to_doc()).nest(2))
112 .group(),
113 AtomRInner::Gate(name) => name.to_doc(),
114 AtomRInner::Inverse(inner) => inner.to_doc().append(" ^ -1"),
115 AtomRInner::Sqrt(inner) => RcDoc::text("sqrt(")
116 .append(RcDoc::line().append(inner.to_doc()).nest(2))
117 .append(RcDoc::line())
118 .append(")")
119 .group(),
120 }
121 }
122}
123
124impl<S: Span> TermR<S> {
125 pub fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
128 if let Some(span) = check_sqrt
129 && self.inner.terms.len() != 1
130 {
131 return Err(TypeCheckError::TermNotRootable {
132 tm: self.clone(),
133 span_of_root: span.clone(),
134 });
135 }
136 let mut term_iter = self.inner.terms.iter();
137 let mut raw = term_iter.next().unwrap();
138 let t = raw.check(env, check_sqrt)?;
139 let ty1 = t.get_type();
140 let mut v = vec![t];
141 for r in term_iter {
142 let term = r.check(env, check_sqrt)?;
143 let ty2 = term.get_type();
144 if ty1 != ty2 {
145 return Err(TypeCheckError::TypeMismatch {
146 t1: raw.clone(),
147 ty1,
148 t2: r.clone(),
149 ty2,
150 });
151 }
152 raw = r;
153 v.push(term);
154 }
155 Ok(TermT::Comp(v))
156 }
157}
158
159impl<S: Span> TensorR<S> {
160 fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
161 Ok(TermT::Tensor(
162 self.inner
163 .terms
164 .iter()
165 .map(|t| t.check(env, check_sqrt))
166 .collect::<Result<_, _>>()?,
167 ))
168 }
169}
170
171impl<S: Span> AtomR<S> {
172 fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
173 match &self.inner {
174 AtomRInner::Brackets(term) => term.check(env, check_sqrt),
175 AtomRInner::Id(qubits) => Ok(TermT::Id(TermType(*qubits))),
176 AtomRInner::Phase(phase) => Ok(TermT::Phase(*phase)),
177 AtomRInner::IfLet { pattern, inner, .. } => {
178 let p = pattern.check(env)?;
179 let t = inner.check(env, check_sqrt)?;
180 let pty = p.get_type();
181 let tty = t.get_type();
182 if pty.1 != tty.0 {
183 Err(TypeCheckError::IfTypeMismatch {
184 p: pattern.clone(),
185 pty,
186 t: inner.as_ref().clone(),
187 tty,
188 })
189 } else {
190 Ok(TermT::IfLet {
191 pattern: p,
192 inner: Box::new(t),
193 })
194 }
195 }
196 AtomRInner::Gate(name) => {
197 if let Some(def) = env.0.get(name) {
198 Ok(TermT::Gate {
199 name: name.clone(),
200 def: Box::new(def.clone()),
201 })
202 } else {
203 Err(TypeCheckError::UnknownSymbol {
204 name: name.clone(),
205 span: self.span.clone(),
206 })
207 }
208 }
209 AtomRInner::Inverse(inner) => {
210 let inner_t = inner.check(env, check_sqrt)?;
211 Ok(TermT::Inverse(Box::new(inner_t)))
212 }
213 AtomRInner::Sqrt(inner) => {
214 let inner_t = if check_sqrt.is_some() {
215 inner.check(env, None)?
216 } else {
217 inner.check(env, Some(&self.span))?
218 };
219
220 Ok(TermT::Sqrt(Box::new(inner_t)))
221 }
222 }
223 }
224}
225
226impl HasParser for TermRInner<Range<usize>> {
227 fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
229 separated(1.., TensorR::parser, (multispace0, ';', multispace0))
230 .context(StrContext::Label("term"))
231 .map(|terms| TermRInner { terms })
232 .parse_next(input)
233 }
234}
235
236impl HasParser for TensorRInner<Range<usize>> {
237 fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
238 separated(1.., AtomR::parser, (multispace0, 'x', multispace0))
239 .context(StrContext::Label("term"))
240 .map(|terms| TensorRInner { terms })
241 .parse_next(input)
242 }
243}
244
245impl HasParser for AtomRInner<Range<usize>> {
246 fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
247 let without_inverse = alt((
248 delimited(
249 ("(", multispace0),
250 cut_err(TermR::parser),
251 cut_err(
252 (multispace0, ")")
253 .context(StrContext::Expected(StrContextValue::CharLiteral(')'))),
254 ),
255 )
256 .map(AtomRInner::Brackets),
257 preceded(("sqrt", multispace0), cut_err(AtomR::parser))
258 .map(|inner| AtomRInner::Sqrt(Box::new(inner))),
259 preceded("id", opt(dec_uint)).map(|qubits| AtomRInner::Id(qubits.unwrap_or(1))),
260 preceded(
261 "if",
262 cut_err(seq!(
263 _: multispace1,
264 _: "let".context(StrContext::Expected(StrContextValue::StringLiteral("let"))),
265 _: multispace1,
266 PatternR::parser,
267 _: multispace1,
268 _: "then".context(StrContext::Expected(StrContextValue::StringLiteral("then"))),
269 _: multispace1,
270 TensorR::parser)),
271 )
272 .map(|(pattern, inner)| AtomRInner::IfLet {
273 pattern,
274 inner: Box::new(inner),
275 }),
276 Phase::parser.map(AtomRInner::Phase),
277 Name::parser.map(AtomRInner::Gate),
278 ))
279 .context(StrContext::Expected(StrContextValue::CharLiteral('(')))
280 .context(StrContext::Expected(StrContextValue::StringLiteral("sqrt")))
281 .context(StrContext::Expected(StrContextValue::StringLiteral("id")))
282 .context(StrContext::Expected(StrContextValue::StringLiteral("if")))
283 .context(StrContext::Expected(StrContextValue::CharLiteral('H')))
284 .context(StrContext::Expected(StrContextValue::Description(
285 "identifier",
286 )));
287
288 (
289 without_inverse,
290 opt((
291 multispace0,
292 "^",
293 multispace0,
294 cut_err("-1").context(StrContext::Expected(StrContextValue::StringLiteral("-1"))),
295 ))
296 .context(StrContext::Label("term")),
297 )
298 .with_span()
299 .map(|((inner, invert), span)| {
300 if invert.is_some() {
301 AtomRInner::Inverse(Box::new(Spanned { inner, span }))
302 } else {
303 inner
304 }
305 })
306 .parse_next(input)
307 }
308}