"""Recursive-descent parser for simple arithmetic expressions.

Tokenizes a string, builds an AST of Node objects, and converts the AST
into the expression types (BinaryOperation, Function) from support.expr_types.
"""
import enum
import re

from support.expr_types import *
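
# Informal grammar implemented below (one parse_* function per precedence level):
#   e  : e2 (('+' | '-' | '->') e2)*     lowest precedence
#   e2 : e3 (('*' | '/') e3)*
#   e3 : e4 ('^' e4)*
#   e4 : NUM | VAR | VAR '(' e (',' e)* ')' | '-' e3 | '(' e ')'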

class TokenType(enum.Enum):
    """Token categories produced by lexical_analysis."""
    T_NUM = 0
    T_PLUS = 1
    T_MINUS = 2
    T_MULT = 3
    T_DIV = 4
    T_LPAR = 5
    T_RPAR = 6
    T_END = 7
    T_VAR = 8
    T_POW = 9
    T_FUN = 10
    T_COMMA = 11
    T_GT = 12
    T_ARROW = 13

class Node:
    """A token or AST node; the parser reuses operator tokens as tree nodes."""

    def __init__(self, token_type, value=None, charno=-1):
        self.token_type = token_type
        self.value = value
        self.charno = charno  # index of the token in the input string
        self.children = []

    def __str__(self):
        if self.token_type in [TokenType.T_NUM, TokenType.T_VAR]:
            return str(self.value)
        elif self.token_type == TokenType.T_ARROW:
            return str(self.children[0]) + " -> " + str(self.children[1])
        elif self.token_type == TokenType.T_FUN:
            # For function calls, value holds a (name, arguments) pair.
            return self.value[0] + "(" + str(self.value[1]) + ")"
        elif self.children:
            return str(self.value) + "(" + ", ".join(str(x) for x in self.children) + ")"
        else:
            return str(self.value)

    def __repr__(self):
        return str(self)

# Single-character operator tokens, plus the reverse map used when
# converting AST nodes back into operator symbols.
mappings = {
    '+': TokenType.T_PLUS,
    '-': TokenType.T_MINUS,
    '*': TokenType.T_MULT,
    '/': TokenType.T_DIV,
    '^': TokenType.T_POW,
    '(': TokenType.T_LPAR,
    ')': TokenType.T_RPAR,
}
rev_map = {v: k for k, v in mappings.items()}

def lexical_analysis(s):
    """Convert the input string into a flat list of tokens, ending with T_END."""
    tokens = []
    for i, c in enumerate(s):
        if c == '>':
            token = Node(TokenType.T_GT, value=c, charno=i)
        elif c in mappings:
            token = Node(mappings[c], value=c, charno=i)
        elif c == ',':
            token = Node(TokenType.T_COMMA, value=c, charno=i)
        elif re.match(r'[0-9.]', c):
            token = Node(TokenType.T_NUM, value=c, charno=i)
        elif re.match(r'[a-z]', c):
            token = Node(TokenType.T_VAR, value=c, charno=i)
        elif re.match(r'\s', c):
            continue
        else:
            raise Exception('Invalid token: {}'.format(c))
        # Merge runs of digits/letters into multi-character numbers and names.
        if (tokens and token.token_type == tokens[-1].token_type
                and token.token_type in [TokenType.T_NUM, TokenType.T_VAR]):
            tokens[-1].value += token.value
        else:
            tokens.append(token)
    tokens.append(Node(TokenType.T_END))
    return tokens
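
# For example, lexical_analysis("12*x") yields [T_NUM '12', T_MULT '*',
# T_VAR 'x', T_END]: the digit tokens '1' and '2' are merged into a single
# number by the run-merging step above.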

def match(tokens, token):
    """Consume and return the next token, which must have the given type."""
    if tokens[0].token_type == token:
        return tokens.pop(0)
    raise Exception('Invalid syntax on token {} at position {}'.format(
        tokens[0].token_type, tokens[0].charno))

def parse_e(tokens):
    """Lowest precedence level: +, -, and the '->' arrow (lexed as '-' '>')."""
    left_node = parse_e2(tokens)
    while tokens[0].token_type in [TokenType.T_PLUS, TokenType.T_MINUS]:
        node = tokens.pop(0)
        # A '-' immediately followed by '>' is the arrow operator.
        if node.token_type == TokenType.T_MINUS and tokens[0].token_type == TokenType.T_GT:
            _ = tokens.pop(0)
            node = Node(TokenType.T_ARROW)
        node.children.append(left_node)
        node.children.append(parse_e2(tokens))
        left_node = node
    return left_node

def parse_e2(tokens):
    """Middle precedence level: multiplication and division."""
    left_node = parse_e3(tokens)
    while tokens[0].token_type in [TokenType.T_MULT, TokenType.T_DIV]:
        node = tokens.pop(0)
        node.children.append(left_node)
        node.children.append(parse_e3(tokens))
        left_node = node
    return left_node

def parse_e3(tokens):
    """Exponentiation; note the loop makes '^' left-associative as written."""
    left_node = parse_e4(tokens)
    while tokens[0].token_type == TokenType.T_POW:
        node = tokens.pop(0)
        node.children.append(left_node)
        node.children.append(parse_e4(tokens))
        left_node = node
    return left_node

def parse_e4(tokens):
    """Atoms: numbers, variables, function calls, unary minus, parentheses."""
    if tokens[0].token_type == TokenType.T_NUM:
        return tokens.pop(0)
    elif tokens[0].token_type == TokenType.T_VAR:
        if len(tokens) == 1 or tokens[1].token_type != TokenType.T_LPAR:
            return tokens.pop(0)
        # A variable immediately followed by '(' is a function call.
        f = tokens.pop(0)
        match(tokens, TokenType.T_LPAR)
        if tokens[0].token_type == TokenType.T_RPAR:
            expressions = []
        else:
            expressions = [parse_e(tokens)]
            while tokens[0].token_type == TokenType.T_COMMA:
                match(tokens, TokenType.T_COMMA)
                expressions.append(parse_e(tokens))
        match(tokens, TokenType.T_RPAR)
        return Node(TokenType.T_FUN, value=(f.value, expressions), charno=f.charno)
    elif tokens[0].token_type == TokenType.T_MINUS:
        # Unary minus binds looser than '^': -x^2 parses as -(x^2).
        tokens.pop(0)
        node = Node(TokenType.T_MINUS)
        node.children = [parse_e3(tokens)]
        return node
    # Anything else must be a parenthesized expression.
    match(tokens, TokenType.T_LPAR)
    expression = parse_e(tokens)
    match(tokens, TokenType.T_RPAR)
    return expression

def _convert_ast(ast):
    """Convert the token tree into expression objects from support.expr_types."""
    if ast.token_type in rev_map:
        if len(ast.children) == 1:
            # A single child means unary negation; represent it as (-1) * x.
            return BinaryOperation('*', -1, _convert_ast(ast.children[0]))
        return BinaryOperation(rev_map[ast.token_type],
                               _convert_ast(ast.children[0]),
                               _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_ARROW:
        return BinaryOperation('->', _convert_ast(ast.children[0]), _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_FUN:
        return Function(ast.value[0], tuple(_convert_ast(x) for x in ast.value[1]))
    elif ast.token_type == TokenType.T_NUM:
        # Prefer ints; fall back to float for values like '1.5'.
        try:
            return int(ast.value)
        except ValueError:
            return float(ast.value)
    else:
        # T_VAR: keep the variable name as a plain string.
        return ast.value

def parse(inputstring):
    """Parse a string into an expression object, e.g. parse('1 + 2*x')."""
    tokens = lexical_analysis(inputstring)
    ast = parse_e(tokens)
    match(tokens, TokenType.T_END)  # ensure all input was consumed
    return _convert_ast(ast)
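
# A minimal smoke test; it assumes the BinaryOperation and Function types
# from support.expr_types have a readable repr. Run this file directly to
# try the parser on a few expressions.
if __name__ == '__main__':
    for expr in ['1 + 2 * x', '-x^2', 'f(a, b) -> g()']:
        print(expr, '=>', repr(parse(expr)))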