# parsing.py
import enum
import re

from support.expr_types import *

# Token categories produced by the lexer and consumed by the parser.
# Built with the enum functional API; start=0 keeps the historical
# numbering (T_NUM == 0 ... T_ARROW == 13).
TokenType = enum.Enum(
    'TokenType',
    'T_NUM T_PLUS T_MINUS T_MULT T_DIV T_LPAR T_RPAR '
    'T_END T_VAR T_POW T_FUN T_COMMA T_GT T_ARROW',
    start=0,
)


class Node:
    """A lexer token that doubles as a parse-tree node.

    Attributes:
        token_type: the TokenType of this token/node.
        value: lexeme text for leaves/operators, or a (name, args) tuple
            for T_FUN nodes; None for punctuation-only tokens.
        charno: index of the token's first character in the input
            string, or -1 when unknown.
        children: operand sub-nodes attached during parsing.
    """

    def __init__(self, token_type, value=None, charno=-1):
        self.token_type = token_type
        self.value = value
        self.charno = charno
        self.children = []

    def __str__(self):
        kind = self.token_type
        if kind is TokenType.T_NUM or kind is TokenType.T_VAR:
            return str(self.value)
        if kind is TokenType.T_ARROW:
            return "{} -> {}".format(self.children[0], self.children[1])
        if kind is TokenType.T_FUN:
            # value is (function name, argument list).
            return "{}({})".format(self.value[0], self.value[1])
        if not self.children:
            return str(self.value)
        return "{}({})".format(self.value, ", ".join(str(c) for c in self.children))

    def __repr__(self):
        return self.__str__()

# Single-character operator/bracket lexemes and their token types, kept
# as two parallel sequences so the forward and reverse maps stay in sync.
_OP_CHARS = "+-*/^()"
_OP_TOKENS = (
    TokenType.T_PLUS,
    TokenType.T_MINUS,
    TokenType.T_MULT,
    TokenType.T_DIV,
    TokenType.T_POW,
    TokenType.T_LPAR,
    TokenType.T_RPAR,
)

# Character -> TokenType (used by the lexer).
mappings = dict(zip(_OP_CHARS, _OP_TOKENS))

# TokenType -> character (used when converting the AST back to operators).
rev_map = dict(zip(_OP_TOKENS, _OP_CHARS))

def lexical_analysis(s):
    """Tokenize *s* into a list of Node tokens terminated by a T_END sentinel.

    Runs of digit/'.' characters are merged into a single T_NUM token and
    runs of lowercase letters into a single T_VAR token, so multi-character
    numbers and identifiers come out as one token each.  Whitespace is
    skipped.

    Args:
        s: the input expression string.

    Returns:
        List of Node tokens; the last element is always a T_END sentinel.

    Raises:
        Exception: on any character outside the grammar's alphabet.
    """
    tokens = []
    for i, c in enumerate(s):
        # Plain character comparisons replace the original per-character
        # re.match calls on literal single-char patterns.
        if c == '>':
            token = Node(TokenType.T_GT, charno=i)
        elif c in mappings:
            token = Node(mappings[c], value=c, charno=i)
        elif c == ',':
            token = Node(TokenType.T_COMMA, charno=i)
        elif '0' <= c <= '9' or c == '.':
            token = Node(TokenType.T_NUM, value=c, charno=i)
        elif 'a' <= c <= 'z':
            token = Node(TokenType.T_VAR, value=c, charno=i)
        elif c.isspace():
            continue
        else:
            raise Exception('Invalid token: {}'.format(c))
        # Merge this character into the previous token when both are part
        # of the same number or identifier run.
        if (tokens
                and token.token_type == tokens[-1].token_type
                and token.token_type in (TokenType.T_NUM, TokenType.T_VAR)):
            tokens[-1].value += token.value
        else:
            tokens.append(token)
    tokens.append(Node(TokenType.T_END))
    return tokens


def match(tokens, token):
    """Consume and return the head of *tokens* if it has type *token*.

    Args:
        tokens: mutable list of Node tokens; the head is popped on success.
        token: the TokenType expected at the head of the list.

    Returns:
        The popped Node on success.

    Raises:
        Exception: if the head token has a different type.
    """
    if tokens[0].token_type == token:
        return tokens.pop(0)
    # NOTE(review): removed the stray debug print of the whole token list;
    # the exception message still identifies the offending token type.
    raise Exception('Invalid syntax on token {}'.format(tokens[0].token_type))


def parse_e(tokens):
    """Parse the lowest-precedence level: '+', '-', and the '->' arrow."""
    result = parse_e2(tokens)

    while tokens[0].token_type in (TokenType.T_PLUS, TokenType.T_MINUS):
        op = tokens.pop(0)
        # A '-' immediately followed by '>' forms the two-character
        # arrow operator '->'.
        if op.token_type == TokenType.T_MINUS and tokens[0].token_type == TokenType.T_GT:
            tokens.pop(0)
            op = Node(TokenType.T_ARROW)
        op.children = [result, parse_e2(tokens)]
        result = op
    return result


def parse_e2(tokens):
    """Parse the middle-precedence level: '*' and '/'."""
    result = parse_e3(tokens)

    while tokens[0].token_type in (TokenType.T_MULT, TokenType.T_DIV):
        op = tokens.pop(0)
        op.children = [result, parse_e3(tokens)]
        result = op
    return result

def parse_e3(tokens):
    """Parse exponentiation with '^'.

    NOTE(review): '^' is parsed left-associatively here (a^b^c groups as
    (a^b)^c); confirm this is intended, since exponentiation is
    conventionally right-associative.
    """
    result = parse_e4(tokens)

    while tokens[0].token_type == TokenType.T_POW:
        op = tokens.pop(0)
        op.children = [result, parse_e4(tokens)]
        result = op
    return result

def parse_e4(tokens):
    """Parse atoms: numbers, variables, function calls, unary minus,
    and parenthesized sub-expressions."""
    head = tokens[0].token_type

    if head == TokenType.T_NUM:
        return tokens.pop(0)

    if head == TokenType.T_VAR:
        # A variable directly followed by '(' is a function call; a lone
        # variable is a leaf.  (T_END is always present, but the length
        # check keeps the lookahead safe regardless.)
        if len(tokens) == 1 or tokens[1].token_type != TokenType.T_LPAR:
            return tokens.pop(0)
        name = tokens.pop(0)
        match(tokens, TokenType.T_LPAR)
        if tokens[0].token_type == TokenType.T_RPAR:
            args = []
        else:
            args = [parse_e(tokens)]
        while tokens[0].token_type == TokenType.T_COMMA:
            match(tokens, TokenType.T_COMMA)
            args.append(parse_e(tokens))
        match(tokens, TokenType.T_RPAR)
        return Node(TokenType.T_FUN, value=(name.value, args), charno=name.charno)

    if head == TokenType.T_MINUS:
        # Unary negation: a minus node with a single child.
        tokens.pop(0)
        negation = Node(TokenType.T_MINUS)
        negation.children = [parse_e3(tokens)]
        return negation

    # Anything else must be a parenthesized sub-expression.
    match(tokens, TokenType.T_LPAR)
    inner = parse_e(tokens)
    match(tokens, TokenType.T_RPAR)
    return inner

def _convert_ast(ast):
    """Recursively convert a parse-tree Node into expression objects.

    Operators become BinaryOperation, function calls become Function
    (both from support.expr_types); T_NUM leaves become int or float and
    other leaves (variables) keep their raw value.

    Args:
        ast: root Node of a parse tree produced by parse_e.

    Returns:
        A BinaryOperation, Function, int, float, or the node's raw value.
    """
    if ast.token_type in rev_map:
        # A single child means unary negation: fold it into (-1) * child.
        if len(ast.children) == 1:
            return BinaryOperation('*', -1, _convert_ast(ast.children[0]))
        return BinaryOperation(rev_map[ast.token_type],
                               _convert_ast(ast.children[0]),
                               _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_ARROW:
        return BinaryOperation('->', _convert_ast(ast.children[0]), _convert_ast(ast.children[1]))
    elif ast.token_type == TokenType.T_FUN:
        return Function(ast.value[0], tuple(_convert_ast(x) for x in ast.value[1]))
    elif ast.token_type == TokenType.T_NUM:
        # Prefer int; int() raises ValueError on literals with a decimal
        # point, in which case fall back to float.  The original bare
        # `except:` also swallowed unrelated errors (e.g. KeyboardInterrupt).
        try:
            return int(ast.value)
        except ValueError:
            return float(ast.value)
    else:
        return ast.value


def parse(inputstring):
    """Parse *inputstring* into an expression tree.

    Tokenizes the input, parses it, verifies every token was consumed,
    and converts the resulting parse tree via _convert_ast.
    """
    token_stream = lexical_analysis(inputstring)
    tree = parse_e(token_stream)
    # The whole input must have been consumed, leaving only the sentinel.
    match(token_stream, TokenType.T_END)
    return _convert_ast(tree)