import enum
import re

from support.expr_types import *


class TokenType(enum.Enum):
    T_NUM = 0
    T_PLUS = 1
    T_MINUS = 2
    T_MULT = 3
    T_DIV = 4
    T_LPAR = 5
    T_RPAR = 6
    T_END = 7
    T_VAR = 8
    T_POW = 9
    T_FUN = 10
    T_COMMA = 11
    T_GT = 12
    T_ARROW = 13


class Node:
    """A token produced by the lexer, reused as an AST node by the parser."""

    def __init__(self, token_type, value=None, charno=-1):
        self.token_type = token_type
        self.value = value
        self.charno = charno
        self.children = []

    def __str__(self):
        if self.token_type in [TokenType.T_NUM, TokenType.T_VAR]:
            return str(self.value)
        elif self.token_type == TokenType.T_ARROW:
            return str(self.children[0]) + " -> " + str(self.children[1])
        elif self.token_type == TokenType.T_FUN:
            return self.value[0] + "(" + str(self.value[1]) + ")"
        else:
            args = ", ".join([str(x) for x in self.children])
            return str(self.value) + (("(" + args + ")") if len(self.children) > 0 else "")

    def __repr__(self):
        return str(self)


# Single-character operator tokens and their types.
mappings = {
    '+': TokenType.T_PLUS,
    '-': TokenType.T_MINUS,
    '*': TokenType.T_MULT,
    '/': TokenType.T_DIV,
    '^': TokenType.T_POW,
    '(': TokenType.T_LPAR,
    ')': TokenType.T_RPAR,
}

# Reverse lookup: token type back to its operator character.
rev_map = {}
for k in mappings:
    rev_map[mappings[k]] = k


def lexical_analysis(s):
    """Turn the input string into a list of tokens, terminated by T_END."""
    tokens = []
    for i, c in enumerate(s):
        if c == '>':
            token = Node(TokenType.T_GT)
        elif c in mappings:
            token_type = mappings[c]
            token = Node(token_type, value=c, charno=i)
        elif c == ',':
            token = Node(TokenType.T_COMMA)
        elif re.match(r'[0-9.]', c):
            token = Node(TokenType.T_NUM, value=c, charno=i)
        elif re.match(r'[a-z]', c):
            token = Node(TokenType.T_VAR, value=c, charno=i)
        elif re.match(r'\s', c):
            continue
        else:
            raise Exception('Invalid token: {}'.format(c))
        # Merge consecutive digits/letters into multi-character numbers
        # and identifiers.
        if (len(tokens) > 0 and token.token_type == tokens[-1].token_type
                and token.token_type in [TokenType.T_NUM, TokenType.T_VAR]):
            tokens[-1].value += token.value
        else:
            tokens.append(token)
    tokens.append(Node(TokenType.T_END))
    return tokens


def match(tokens, token):
    """Consume the next token, requiring it to be of the given type."""
    if tokens[0].token_type == token:
        return tokens.pop(0)
    else:
        print(tokens)
        raise Exception('Invalid syntax on token {}'.format(tokens[0].token_type))


# Recursive-descent parser with one function per precedence level:
#   parse_e   handles +, - and the two-token arrow '->'
#   parse_e2  handles * and /
#   parse_e3  handles ^ (parsed left-associatively here)
#   parse_e4  handles numbers, variables, function calls, unary minus,
#             and parenthesised expressions

def parse_e(tokens):
    left_node = parse_e2(tokens)
    while tokens[0].token_type in [TokenType.T_PLUS, TokenType.T_MINUS]:
        node = tokens.pop(0)
        # A '-' immediately followed by '>' is the arrow operator '->'.
        if node.token_type == TokenType.T_MINUS and tokens[0].token_type == TokenType.T_GT:
            _ = tokens.pop(0)
            node = Node(TokenType.T_ARROW)
        node.children.append(left_node)
        node.children.append(parse_e2(tokens))
        left_node = node
    return left_node


def parse_e2(tokens):
    left_node = parse_e3(tokens)
    while tokens[0].token_type in [TokenType.T_MULT, TokenType.T_DIV]:
        node = tokens.pop(0)
        node.children.append(left_node)
        node.children.append(parse_e3(tokens))
        left_node = node
    return left_node


def parse_e3(tokens):
    left_node = parse_e4(tokens)
    while tokens[0].token_type in [TokenType.T_POW]:
        node = tokens.pop(0)
        node.children.append(left_node)
        node.children.append(parse_e4(tokens))
        left_node = node
    return left_node


def parse_e4(tokens):
    if tokens[0].token_type in [TokenType.T_NUM]:
        return tokens.pop(0)
    elif tokens[0].token_type in [TokenType.T_VAR]:
        # A variable followed by '(' is a function call; otherwise it is
        # a plain variable reference.
        if len(tokens) == 1 or tokens[1].token_type != TokenType.T_LPAR:
            return tokens.pop(0)
        else:
            f = tokens.pop(0)
            match(tokens, TokenType.T_LPAR)
            if tokens[0].token_type == TokenType.T_RPAR:
                expressions = []
            else:
                expressions = [parse_e(tokens)]
                while tokens[0].token_type == TokenType.T_COMMA:
                    match(tokens, TokenType.T_COMMA)
                    expressions.append(parse_e(tokens))
            match(tokens, TokenType.T_RPAR)
            return Node(TokenType.T_FUN, value=(f.value, expressions), charno=f.charno)
    elif tokens[0].token_type == TokenType.T_MINUS:
        # Unary minus: bind it to a full power-level expression.
        tokens.pop(0)
        node = Node(TokenType.T_MINUS)
        node.children = [parse_e3(tokens)]
        return node

    # Fall through: a parenthesised subexpression.
    match(tokens, TokenType.T_LPAR)
    expression = parse_e(tokens)
    match(tokens, TokenType.T_RPAR)
    return expression


def _convert_ast(ast):
    """Convert the token tree into expression objects from support.expr_types."""
    if ast.token_type in rev_map:
        # Assuming this is unary negation for now....
        if len(ast.children) == 1:
            return BinaryOperation('*', -1, _convert_ast(ast.children[0]))
        return BinaryOperation(rev_map[ast.token_type],
                               _convert_ast(ast.children[0]),
                               _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_ARROW:
        return BinaryOperation('->', _convert_ast(ast.children[0]), _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_FUN:
        return Function(ast.value[0], tuple([_convert_ast(x) for x in ast.value[1]]))
    elif ast.token_type == TokenType.T_NUM:
        # Prefer an int when the literal allows it; otherwise fall back to float.
        try:
            return int(ast.value)
        except ValueError:
            return float(ast.value)
    else:
        return ast.value


def parse(inputstring):
    """Lex and parse the input string, returning a converted expression tree."""
    tokens = lexical_analysis(inputstring)
    ast = parse_e(tokens)
    match(tokens, TokenType.T_END)
    ast = _convert_ast(ast)
    return ast
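

# --- Usage sketch (added illustration, not part of the original module) ---
# A minimal demo of parse(), assuming support.expr_types defines
# BinaryOperation and Function with printable representations, as implied
# by their use in _convert_ast above.
if __name__ == '__main__':
    for expr in ['3 + 4 * 2', '-(x + 1) ^ 2', 'f(x, 2) -> g(x)']:
        print(expr, '=>', parse(expr))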