phase_rs/raw_syntax/
term.rs

1//! Raw syntax terms.
2
3use std::ops::Range;
4
5use pretty::RcDoc;
6use winnow::{
7    LocatingSlice, ModalResult, Parser,
8    ascii::{dec_uint, multispace0, multispace1},
9    combinator::{alt, cut_err, delimited, opt, preceded, separated, seq},
10    error::{StrContext, StrContextValue},
11};
12
13use crate::{
14    phase::Phase,
15    raw_syntax::PatternR,
16    text::{HasParser, Name, Span, Spanned, ToDoc},
17    typecheck::{Env, TypeCheckError},
18    typed_syntax::{TermT, TermType},
19};
20
21/// Raw syntax term with text span.
22/// Represents a list of tensored terms composed together.
23pub type TermR<S> = Spanned<S, TermRInner<S>>;
24
25/// Raw syntax term without text span.
26/// Represents a list of tensored terms composed together.
27#[derive(Clone, Debug, PartialEq)]
28pub struct TermRInner<S> {
29    pub(crate) terms: Vec<TensorR<S>>,
30}
31
32impl<S> ToDoc for TermRInner<S> {
33    fn to_doc(&self) -> RcDoc<'_> {
34        RcDoc::intersperse(
35            self.terms.iter().map(TensorR::to_doc),
36            RcDoc::text(";").append(RcDoc::line()),
37        )
38        .group()
39    }
40}
41
42/// Raw syntax tensored term with text span.
43/// Represents a list of atoms tensored together.
44pub type TensorR<S> = Spanned<S, TensorRInner<S>>;
45
46/// Raw syntax tensored term without text span.
47/// Represents a list of atoms tensored together.
48#[derive(Clone, Debug, PartialEq)]
49pub struct TensorRInner<S> {
50    pub(crate) terms: Vec<AtomR<S>>,
51}
52
53impl<S> ToDoc for TensorRInner<S> {
54    fn to_doc(&self) -> RcDoc<'_> {
55        RcDoc::intersperse(
56            self.terms.iter().map(AtomR::to_doc),
57            RcDoc::line().append("x "),
58        )
59        .group()
60    }
61}
62
63/// Raw syntax atom with text span.
64/// Represents a term other than a tensor or composition (or a composition/tensor in brackets)
65pub type AtomR<S> = Spanned<S, AtomRInner<S>>;
66
67/// Raw syntax atom without text span.
68/// Represents a term other than a tensor or composition (or a composition/tensor in brackets)
69#[derive(Clone, Debug, PartialEq)]
70pub enum AtomRInner<S> {
71    /// A term enclosed in parentheses
72    Brackets(TermR<S>),
73    /// An identity term "id(n)"
74    Id(usize),
75    /// A (global) phase operator, e.g. "-1" or "ph(0.1pi)"
76    Phase(Phase),
77    /// An "if let" statement, "if let pattern then inner"
78    IfLet {
79        /// Pattern to match on in "if let"
80        pattern: PatternR<S>,
81        /// Body of the "if let"
82        inner: Box<TensorR<S>>,
83    },
84    /// Top level symbol, a named gate
85    Gate(Name),
86    /// Inverse of a term "t ^ -1"
87    Inverse(Box<AtomR<S>>),
88    /// Square root of a term "sqrt(t)"
89    Sqrt(Box<AtomR<S>>),
90}
91
92impl<S> ToDoc for AtomRInner<S> {
93    fn to_doc(&self) -> RcDoc<'_> {
94        match self {
95            AtomRInner::Brackets(term) => RcDoc::text("(")
96                .append(RcDoc::line().append(term.to_doc()).nest(2))
97                .append(RcDoc::line())
98                .append(")")
99                .group(),
100            AtomRInner::Id(qubits) => RcDoc::text("id").append(if *qubits == 1 {
101                RcDoc::nil()
102            } else {
103                RcDoc::as_string(qubits)
104            }),
105            AtomRInner::Phase(phase) => phase.to_doc(),
106            AtomRInner::IfLet { pattern, inner, .. } => RcDoc::text("if let")
107                .append(RcDoc::line().append(pattern.to_doc()).nest(2))
108                .append(RcDoc::line())
109                .append("then")
110                .group()
111                .append(RcDoc::line().append(inner.to_doc()).nest(2))
112                .group(),
113            AtomRInner::Gate(name) => name.to_doc(),
114            AtomRInner::Inverse(inner) => inner.to_doc().append(" ^ -1"),
115            AtomRInner::Sqrt(inner) => RcDoc::text("sqrt(")
116                .append(RcDoc::line().append(inner.to_doc()).nest(2))
117                .append(RcDoc::line())
118                .append(")")
119                .group(),
120        }
121    }
122}
123
124impl<S: Span> TermR<S> {
125    /// Typecheck a raw term in given environment
126    /// If `check_sqrt` is not `None`, then checks that the term is "composition free"
127    pub fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
128        if let Some(span) = check_sqrt
129            && self.inner.terms.len() != 1
130        {
131            return Err(TypeCheckError::TermNotRootable {
132                tm: self.clone(),
133                span_of_root: span.clone(),
134            });
135        }
136        let mut term_iter = self.inner.terms.iter();
137        let mut raw = term_iter.next().unwrap();
138        let t = raw.check(env, check_sqrt)?;
139        let ty1 = t.get_type();
140        let mut v = vec![t];
141        for r in term_iter {
142            let term = r.check(env, check_sqrt)?;
143            let ty2 = term.get_type();
144            if ty1 != ty2 {
145                return Err(TypeCheckError::TypeMismatch {
146                    t1: raw.clone(),
147                    ty1,
148                    t2: r.clone(),
149                    ty2,
150                });
151            }
152            raw = r;
153            v.push(term);
154        }
155        Ok(TermT::Comp(v))
156    }
157}
158
159impl<S: Span> TensorR<S> {
160    fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
161        Ok(TermT::Tensor(
162            self.inner
163                .terms
164                .iter()
165                .map(|t| t.check(env, check_sqrt))
166                .collect::<Result<_, _>>()?,
167        ))
168    }
169}
170
171impl<S: Span> AtomR<S> {
172    fn check(&self, env: &Env, check_sqrt: Option<&S>) -> Result<TermT, TypeCheckError<S>> {
173        match &self.inner {
174            AtomRInner::Brackets(term) => term.check(env, check_sqrt),
175            AtomRInner::Id(qubits) => Ok(TermT::Id(TermType(*qubits))),
176            AtomRInner::Phase(phase) => Ok(TermT::Phase(*phase)),
177            AtomRInner::IfLet { pattern, inner, .. } => {
178                let p = pattern.check(env)?;
179                let t = inner.check(env, check_sqrt)?;
180                let pty = p.get_type();
181                let tty = t.get_type();
182                if pty.1 != tty.0 {
183                    Err(TypeCheckError::IfTypeMismatch {
184                        p: pattern.clone(),
185                        pty,
186                        t: inner.as_ref().clone(),
187                        tty,
188                    })
189                } else {
190                    Ok(TermT::IfLet {
191                        pattern: p,
192                        inner: Box::new(t),
193                    })
194                }
195            }
196            AtomRInner::Gate(name) => {
197                if let Some(def) = env.0.get(name) {
198                    Ok(TermT::Gate {
199                        name: name.clone(),
200                        def: Box::new(def.clone()),
201                    })
202                } else {
203                    Err(TypeCheckError::UnknownSymbol {
204                        name: name.clone(),
205                        span: self.span.clone(),
206                    })
207                }
208            }
209            AtomRInner::Inverse(inner) => {
210                let inner_t = inner.check(env, check_sqrt)?;
211                Ok(TermT::Inverse(Box::new(inner_t)))
212            }
213            AtomRInner::Sqrt(inner) => {
214                let inner_t = if check_sqrt.is_some() {
215                    inner.check(env, None)?
216                } else {
217                    inner.check(env, Some(&self.span))?
218                };
219
220                Ok(TermT::Sqrt(Box::new(inner_t)))
221            }
222        }
223    }
224}
225
226impl HasParser for TermRInner<Range<usize>> {
227    /// Parser for terms.
228    fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
229        separated(1.., TensorR::parser, (multispace0, ';', multispace0))
230            .context(StrContext::Label("term"))
231            .map(|terms| TermRInner { terms })
232            .parse_next(input)
233    }
234}
235
236impl HasParser for TensorRInner<Range<usize>> {
237    fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
238        separated(1.., AtomR::parser, (multispace0, 'x', multispace0))
239            .context(StrContext::Label("term"))
240            .map(|terms| TensorRInner { terms })
241            .parse_next(input)
242    }
243}
244
245impl HasParser for AtomRInner<Range<usize>> {
246    fn parser(input: &mut LocatingSlice<&str>) -> ModalResult<Self> {
247        let without_inverse = alt((
248            delimited(
249                ("(", multispace0),
250                cut_err(TermR::parser),
251                cut_err(
252                    (multispace0, ")")
253                        .context(StrContext::Expected(StrContextValue::CharLiteral(')'))),
254                ),
255            )
256            .map(AtomRInner::Brackets),
257            preceded(("sqrt", multispace0), cut_err(AtomR::parser))
258                .map(|inner| AtomRInner::Sqrt(Box::new(inner))),
259            preceded("id", opt(dec_uint)).map(|qubits| AtomRInner::Id(qubits.unwrap_or(1))),
260            preceded(
261                "if",
262                cut_err(seq!(
263		_: multispace1,
264		_: "let".context(StrContext::Expected(StrContextValue::StringLiteral("let"))),
265		_: multispace1,
266		PatternR::parser,
267		_: multispace1,
268		_: "then".context(StrContext::Expected(StrContextValue::StringLiteral("then"))),
269		_: multispace1,
270		TensorR::parser)),
271            )
272            .map(|(pattern, inner)| AtomRInner::IfLet {
273                pattern,
274                inner: Box::new(inner),
275            }),
276            Phase::parser.map(AtomRInner::Phase),
277            Name::parser.map(AtomRInner::Gate),
278        ))
279        .context(StrContext::Expected(StrContextValue::CharLiteral('(')))
280        .context(StrContext::Expected(StrContextValue::StringLiteral("sqrt")))
281        .context(StrContext::Expected(StrContextValue::StringLiteral("id")))
282        .context(StrContext::Expected(StrContextValue::StringLiteral("if")))
283        .context(StrContext::Expected(StrContextValue::CharLiteral('H')))
284        .context(StrContext::Expected(StrContextValue::Description(
285            "identifier",
286        )));
287
288        (
289            without_inverse,
290            opt((
291                multispace0,
292                "^",
293                multispace0,
294                cut_err("-1").context(StrContext::Expected(StrContextValue::StringLiteral("-1"))),
295            ))
296            .context(StrContext::Label("term")),
297        )
298            .with_span()
299            .map(|((inner, invert), span)| {
300                if invert.is_some() {
301                    AtomRInner::Inverse(Box::new(Spanned { inner, span }))
302                } else {
303                    inner
304                }
305            })
306            .parse_next(input)
307    }
308}