phase_rs/normal_syntax/
pattern.rs1use faer::Mat;
4use num_complex::Complex;
5
6use crate::{
7 ket::{CompKetState, KetState},
8 normal_syntax::term::AtomN,
9 typed_syntax::{PatternT, PatternType, TermT, TermType},
10};
11
12#[derive(Clone, Debug, PartialEq)]
14pub enum PatternN {
15 Comp(Vec<PatternN>, PatternType),
17 Tensor(Vec<PatternN>),
19 Ket(KetState),
21 Unitary(Box<AtomN>),
23}
24
25impl PatternN {
26 pub fn to_inj_and_proj(&self) -> (Mat<Complex<f64>>, Mat<Complex<f64>>) {
30 match self {
31 PatternN::Comp(patterns, ty) => {
32 let mut patterns_iter = patterns.iter().map(PatternN::to_inj_and_proj);
33 if let Some(i) = patterns_iter.next() {
34 patterns_iter.fold(i, |(i1, p1), (i2, p2)| {
35 (&i1 * i2, p1 + &i1 * p2 * i1.adjoint())
36 })
37 } else {
38 (
39 Mat::identity(1 << ty.0, 1 << ty.0),
40 Mat::zeros(1 << ty.0, 1 << ty.0),
41 )
42 }
43 }
44 PatternN::Tensor(patterns) => {
45 let mut patterns_iter = patterns.iter().map(PatternN::to_inj_and_proj);
46 let i = patterns_iter.next().unwrap();
47 patterns_iter.fold(i, |(i1, p1), (i2, p2)| {
48 (
49 i1.kron(i2),
50 p1.kron(Mat::<Complex<f64>>::identity(p2.nrows(), p2.nrows()))
51 + (&i1 * i1.adjoint()).kron(p2),
52 )
53 })
54 }
55 PatternN::Ket(state) => {
56 let m = state.to_state();
57 let cm = state.compl().to_state();
58 (m, cm.as_ref() * cm.adjoint())
59 }
60 PatternN::Unitary(inner) => {
61 let size = inner.get_type().0;
62 (inner.to_unitary(), Mat::zeros(1 << size, 1 << size))
63 }
64 }
65 }
66
67 pub fn quote(&self) -> PatternT {
70 match self {
71 PatternN::Comp(patterns, ty) => {
72 if patterns.is_empty() {
73 PatternT::Unitary(Box::new(TermT::Id(TermType(ty.0))))
74 } else {
75 PatternT::Comp(patterns.iter().map(PatternN::quote).collect())
76 }
77 }
78 PatternN::Tensor(patterns) => {
79 PatternT::Tensor(patterns.iter().map(PatternN::quote).collect())
80 }
81 PatternN::Ket(state) => PatternT::Ket(CompKetState::single(*state)),
82 PatternN::Unitary(inner) => PatternT::Unitary(Box::new(inner.quote())),
83 }
84 }
85
86 fn squash_comp(mut self, acc: &mut Vec<PatternN>) {
87 if let PatternN::Comp(patterns, _) = self {
88 for p in patterns {
89 p.squash_comp(acc);
90 }
91 } else {
92 self.squash();
93 acc.push(self);
94 }
95 }
96
97 fn squash_tensor(mut self, acc: &mut Vec<PatternN>) {
98 if let PatternN::Tensor(patterns) = self {
99 for p in patterns {
100 p.squash_tensor(acc);
101 }
102 } else {
103 self.squash();
104 acc.push(self);
105 }
106 }
107
108 pub fn squash(&mut self) {
110 match self {
111 PatternN::Comp(patterns, _) => {
112 let old_patterns = std::mem::take(patterns);
113 for p in old_patterns {
114 p.squash_comp(patterns);
115 }
116 if patterns.len() == 1 {
117 *self = patterns.pop().unwrap();
118 }
119 }
120 PatternN::Tensor(patterns) => {
121 let old_patterns = std::mem::take(patterns);
122 for p in old_patterns {
123 p.squash_tensor(patterns);
124 }
125 if patterns.len() == 1 {
126 *self = patterns.pop().unwrap();
127 }
128 }
129 PatternN::Ket(_) => {}
130 PatternN::Unitary(inner) => inner.squash(),
131 }
132 }
133}