phase_rs/normal_syntax/
term.rs1use std::f64::consts::PI;
4
5use faer::{Mat, mat};
6use num_complex::Complex;
7
8use crate::{
9 normal_syntax::PatternN,
10 phase::Phase,
11 typed_syntax::{TermT, TermType},
12};
13
14#[derive(Clone, Debug, PartialEq)]
16pub enum TermN {
17 Comp(Vec<TermN>, TermType),
19 Tensor(Vec<TermN>),
21 Atom(AtomN),
23}
24
25#[derive(Clone, Debug, PartialEq)]
27pub enum AtomN {
28 Phase(f64),
30 IfLet(PatternN, Box<TermN>, TermType),
32}
33
34impl TermN {
35 pub fn to_unitary(&self) -> Mat<Complex<f64>> {
37 match self {
38 TermN::Comp(terms, ty) => {
39 let mut terms_iter = terms.iter().map(TermN::to_unitary);
40 match terms_iter.next() {
41 None => Mat::identity(1 << ty.0, 1 << ty.0),
42 Some(u) => terms_iter.fold(u, |x, y| y * x),
43 }
44 }
45 TermN::Tensor(terms) => {
46 let mut terms_iter = terms.iter().map(TermN::to_unitary);
47 match terms_iter.next() {
48 None => Mat::identity(1, 1),
49 Some(u) => terms_iter.fold(u, |x, y| x.kron(y)),
50 }
51 }
52 TermN::Atom(atom) => atom.to_unitary(),
53 }
54 }
55
56 pub fn quote(&self) -> TermT {
59 match self {
60 TermN::Comp(terms, ty) => {
61 if terms.is_empty() {
62 TermT::Id(*ty)
63 } else {
64 TermT::Comp(terms.iter().map(TermN::quote).collect())
65 }
66 }
67 TermN::Tensor(terms) => TermT::Tensor(terms.iter().map(TermN::quote).collect()),
68 TermN::Atom(atom) => atom.quote(),
69 }
70 }
71
72 fn squash_comp(mut self, acc: &mut Vec<TermN>) {
73 if let TermN::Comp(terms, _) = self {
74 for t in terms {
75 t.squash_comp(acc);
76 }
77 } else {
78 self.squash();
79 acc.push(self);
80 }
81 }
82
83 fn squash_tensor(mut self, acc: &mut Vec<TermN>) {
84 if let TermN::Tensor(terms) = self {
85 for t in terms {
86 t.squash_tensor(acc);
87 }
88 } else {
89 self.squash();
90 acc.push(self);
91 }
92 }
93
94 pub fn squash(&mut self) {
96 match self {
97 TermN::Comp(terms, _) => {
98 let old_terms = std::mem::take(terms);
99 for t in old_terms {
100 t.squash_comp(terms);
101 }
102 if terms.len() == 1 {
103 *self = terms.pop().unwrap();
104 }
105 }
106 TermN::Tensor(terms) => {
107 let old_terms = std::mem::take(terms);
108 for t in old_terms {
109 t.squash_tensor(terms);
110 }
111 if terms.len() == 1 {
112 *self = terms.pop().unwrap();
113 }
114 }
115 TermN::Atom(atom) => atom.squash(),
116 }
117 }
118}
119
120impl AtomN {
121 pub(crate) fn get_type(&self) -> TermType {
122 match self {
123 AtomN::Phase(_) => TermType(0),
124 AtomN::IfLet(_, _, ty) => *ty,
125 }
126 }
127
128 pub fn to_unitary(&self) -> Mat<Complex<f64>> {
130 match self {
131 AtomN::Phase(angle) => mat![[Complex::cis(angle * PI)]],
132 AtomN::IfLet(pattern, inner, _) => {
133 let (inj, proj) = pattern.to_inj_and_proj();
134 let u = inner.to_unitary();
135 proj + &inj * u * inj.adjoint()
136 }
137 }
138 }
139
140 pub(super) fn quote(&self) -> TermT {
141 match self {
142 AtomN::Phase(angle) => TermT::Phase(Phase::from_angle(*angle)),
143 AtomN::IfLet(pattern, inner, _) => TermT::IfLet {
144 pattern: pattern.quote(),
145 inner: Box::new(inner.quote()),
146 },
147 }
148 }
149
150 pub(super) fn squash(&mut self) {
151 if let AtomN::IfLet(pattern, inner, _) = self {
152 pattern.squash();
153 inner.squash();
154 }
155 }
156}