from support.parsing import parse
from math import prod
from support.expr_types import *
from collections import Counter
import copy
from support.to_string import expr_to_string
import fractions
import numbers
from functools import cache

from decimal import *
# Use 20 significant digits for Decimal arithmetic in the current decimal context.
getcontext().prec = 20

# Magnitude threshold (0.001): _simplify treats a Decimal/float divisor whose
# absolute value is below this as zero and raises ZeroDivisionError.
EPSILON = 10**(-3)


@cache
def create_comm(expr, l, r, op):
    """Fold the binary operation ``l op r`` into an n-ary commutative node.

    ``op`` is ``'+'`` (builds a ``SumOperation``) or anything else (builds a
    ``ProductOperation``).  ``expr`` is the binary node that combines ``l``
    and ``r``; it is returned untouched when none of the merge patterns
    below applies.  Otherwise the merged operands are handed to
    ``collapse_comm`` and its result is returned.
    """
    if op == '+':
        is_com_of_vals, Op = is_add_of_values, SumOperation
    else:
        is_com_of_vals, Op = is_mul_of_values, ProductOperation

    def same_binary(node):
        # A binary node using the very operator we are flattening.
        return isinstance(node, BinaryOperation) and node.op == op

    def absorbable(node):
        # Something that can be appended to an existing n-ary node as-is.
        return is_value(node) or is_com_of_vals(node)

    if is_com_of_vals(expr):
        merged = Op((l, r))
    elif same_binary(l):
        merged = Op((l.left, l.right, r))
    elif same_binary(r):
        merged = Op((l, r.left, r.right))
    elif isinstance(l, Op) and isinstance(r, Op):
        merged = Op(tuple(l.values + r.values))
    elif isinstance(l, Op) and absorbable(r):
        merged = Op(tuple(l.values + (r,)))
    elif isinstance(r, Op) and absorbable(l):
        merged = Op((l,) + r.values)
    else:
        return expr
    return collapse_comm(merged.values, op, Op)

@cache
def is_monomial(m):
    """Return True when ``m`` has the shape of a monomial.

    Accepted shapes: a number, a string (variable name), a ``Term`` whose
    right side is a monomial, ``monomial ^ Number``, and a ``*`` between a
    monomial and a ``Number`` (in either order).
    """
    if isinstance(m, Number) or isinstance(m, str):
        return True
    if isinstance(m, Term) and is_monomial(m.right):
        return True
    if not isinstance(m, BinaryOperation):
        return False
    if m.op == '^':
        return is_monomial(m.left) and isinstance(m.right, Number)
    if m.op == '*':
        number_on_left = is_monomial(m.right) and isinstance(m.left, Number)
        return number_on_left or (is_monomial(m.left) and isinstance(m.right, Number))
    return False

@cache
def get_monomial_coefficient(m):
    """Return the numeric coefficient of the monomial ``m``.

    A bare variable name has coefficient 1, a number is its own
    coefficient, a ``Term`` yields its left side, ``monomial ^ Number``
    yields 1 and a ``*`` with a ``Number`` yields that number.

    NOTE(review): for ``monomial * Number`` only the number is returned; a
    coefficient carried by the monomial operand is not multiplied in --
    confirm callers only pass already-simplified input.

    Raises:
        Exception: ``"not a monomial"`` when no shape matches.
    """
    if isinstance(m, str):
        return 1
    if isinstance(m, Number):
        return m
    if isinstance(m, Term):
        return m.left
    if isinstance(m, BinaryOperation):
        lhs, rhs = m.left, m.right
        if isinstance(rhs, Number) and is_monomial(lhs):
            if m.op == '^':
                return 1
            if m.op == '*':
                return rhs
        if m.op == '*' and isinstance(lhs, Number) and is_monomial(rhs):
            return lhs
    raise Exception("not a monomial")

@cache
def get_monomial_power(m):
    """Return the exponent of the monomial ``m``.

    A bare variable name has power 1 and a number has power 0.  ``Term``
    and ``*`` nodes defer to their monomial operand; ``monomial ^ Number``
    yields the exponent itself.

    Raises:
        Exception: ``"not a monomial"`` when no shape matches.
    """
    if isinstance(m, str):
        return 1
    if isinstance(m, Number):
        return 0
    if isinstance(m, Term) and is_monomial(m.right):
        return get_monomial_power(m.right)
    if isinstance(m, BinaryOperation):
        lhs, rhs = m.left, m.right
        if isinstance(rhs, Number) and is_monomial(lhs):
            if m.op == '^':
                return rhs
            if m.op == '*':
                return get_monomial_power(lhs)
        if m.op == '*' and isinstance(lhs, Number) and is_monomial(rhs):
            return get_monomial_power(rhs)
    raise Exception("not a monomial")

@cache
def get_monomial_base(m, op):
    """Return the key under which the monomial ``m`` is grouped.

    A variable name is its own base and a number has base 1.  For
    ``monomial ^ Number`` the result depends on ``op``: with ``'*'`` the
    base under the exponent is returned (so powers of it can be merged),
    otherwise the whole power expression is returned.  ``Term`` and ``*``
    nodes defer to their monomial operand.

    Raises:
        Exception: ``"not a monomial"`` when no shape matches.
    """
    if isinstance(m, str):
        return m
    if isinstance(m, Number):
        return 1
    if isinstance(m, Term) and is_monomial(m.right):
        return get_monomial_base(m.right, op)
    if isinstance(m, BinaryOperation):
        lhs, rhs = m.left, m.right
        if isinstance(rhs, Number) and is_monomial(lhs):
            if m.op == '^':
                if op == '*':
                    return lhs
                return m
            if m.op == '*':
                return get_monomial_base(lhs, op)
        if m.op == '*' and isinstance(lhs, Number) and is_monomial(rhs):
            return get_monomial_base(rhs, op)
    raise Exception("not a monomial")

@cache
def get_base(a, op):
    """Return the grouping key used by ``collect_likes`` for ``a``.

    Numbers map to 1, monomials to ``get_monomial_base(a, op)``, a power
    whose base is a number maps to that number, and anything else is its
    own key.
    """
    if isinstance(a, Number):
        return 1
    if is_monomial(a):
        return get_monomial_base(a, op)
    numeric_base_power = (
        isinstance(a, BinaryOperation)
        and a.op == '^'
        and isinstance(a.left, Number)
    )
    return a.left if numeric_base_power else a

def coerce_float(x):
    """Convert a numeric value to ``Decimal`` where a conversion is known.

    ``None`` becomes ``Decimal('Infinity')``.  Values accepted by
    ``is_int``, ``Fraction`` instances and floats are converted to
    ``Decimal``; any other value is returned unchanged.
    """
    if x is None:
        return Decimal('Infinity')
    if is_int(x):
        return Decimal(x)
    if isinstance(x, fractions.Fraction):
        # Decimal cannot be built from a Fraction directly, so divide the parts.
        return Decimal(x.numerator) / Decimal(x.denominator)
    if isinstance(x, float):
        return Decimal(x)
    return x


@cache
def collect_likes(terms, op):
    """Combine like members of a commutative operation.

    Args:
        terms: iterable of expressions (hashable, since the call is cached).
        op: ``'*'`` to merge factors of a product, ``'+'`` to merge terms of
            a sum.  Any other value produces an empty result because neither
            branch below runs.

    Returns:
        A tuple of expressions.  Every input is simplified, bucketed by the
        set of variables it mentions and then by ``get_base(a, op)``; each
        bucket is collapsed as described in the inline comments.
    """
    # Pass 1: simplify each term and bucket it by the variables it contains.
    by_var = {}
    for t in terms:
        s = simplify(t)
        v = frozenset(collect_variables(s))
        if v not in by_var:
            by_var[v] = []
        by_var[v].append(s)
    result = []
    for varset in by_var:
        exprs = by_var[varset]
        # Pass 2: within one variable set, bucket by grouping key (base).
        by_func = {}
        for a in exprs:
            f = get_base(a, op)
            if f not in by_func:
                by_func[f] = []
            by_func[f].append(a)
        for func in by_func:
            funcset = by_func[func]
            if op == '*':
                coeff = 1   # product of numeric factors / monomial coefficients
                power = 0   # sum of monomial exponents
                col = []    # exponents of `Number ^ <expr with variables>` factors
                rest = []   # factors that cannot be merged
                for f in funcset:
                    if isinstance(f, Number):
                        # Normalise integral values to int before multiplying.
                        if int(f) == f:
                            f = int(f)
                        if int(coeff) == coeff:
                            coeff = int(coeff)
                        # If either side is Decimal/float, move both to Decimal.
                        if isinstance(coeff, Decimal) or isinstance(f, Decimal) or isinstance(coeff, float) or isinstance(f, float):
                            coeff = coerce_float(coeff)
                            f = coerce_float(f)
                        coeff *= f
                    elif is_monomial(f):
                        coeff *= get_monomial_coefficient(f)
                        power += get_monomial_power(f)
                    elif isinstance(f, BinaryOperation) and f.op == '^' and isinstance(f.left, Number) and len(collect_variables(f.right)) > 0:
                        col.append(simplify(f.right)) 
                    else:
                        rest.append(simplify(f))
                # NOTE(review): `f` here is whatever the loop left behind (the last
                # member of the bucket, possibly rebound to an int above), so the
                # choice of output depends on the last member only.  Presumably
                # buckets are homogeneous -- confirm.
                if is_monomial(f) or isinstance(f, Number):
                    result.append(simplify(Term(coeff, BinaryOperation('^', func, power))))
                elif len(col) > 0:
                    # Same numeric base: add the collected exponents.
                    result.append(simplify(BinaryOperation('^', func, SumOperation(tuple(col)))))
                result += tuple(rest)
            elif op == '+':
                coeff = 0   # sum of numeric terms / monomial coefficients
                for f in funcset:
                    if isinstance(f, Number):
                        if int(f) == f:
                            f = int(f)
                        if int(coeff) == coeff:
                            coeff = int(coeff)
                        if isinstance(coeff, Decimal) or isinstance(f, Decimal) or isinstance(coeff, float) or isinstance(f, float):
                            coeff = coerce_float(coeff)
                            f = coerce_float(f)
                        coeff += f
                    elif is_monomial(f):
                        coeff += get_monomial_coefficient(f)
                    else:
                        # Non-monomial terms are passed through one by one.
                        result.append(simplify(f))
                if coeff == 1:
                    # Coefficient 1: emit the base itself rather than Term(1, base).
                    result.append(func)
                elif func is None:
                    # NOTE(review): get_base returns 1 for numbers, so a None key
                    # looks unreachable unless a term is itself None -- confirm.
                    result.append(coeff)
                elif coeff != 0:
                    result.append(simplify(Term(coeff, func)))
    return tuple(result)
    
@cache
def collapse_comm(vals, op, Op):
    """Flatten, simplify and merge the operands of a commutative operation.

    Args:
        vals: iterable of operand expressions.
        op: ``'+'`` or ``'*'``; forwarded to ``collect_likes`` and used to
            pick the identity element for an empty result.
        Op: the n-ary node class (``SumOperation`` or ``ProductOperation``)
            matching ``op``.

    Returns:
        The single remaining operand when only one is left, the identity
        (0 for ``'+'``, 1 otherwise) when none is left, else an ``Op`` node.
    """
    # One level of flattening: plain operands first, then the members of any
    # nested Op nodes, in that order.
    flat = [x for x in vals if not isinstance(x, Op)]
    for nested in (x for x in vals if isinstance(x, Op)):
        flat.extend(nested.values)

    simplified = tuple(simplify(x) for x in flat)
    merged = [simplify(x) for x in collect_likes(simplified, op)]

    expr = Op(tuple(merged))

    if len(expr.values) == 1:
        return expr.values[0]
    if len(expr.values) == 0:
        return 0 if op == '+' else 1
    return expr

@cache
def evaluate_function(f):
    """Evaluate or normalise a function-call node.

    ``f`` must expose ``name`` and ``arguments``.  Returns a pair
    ``(is_const, value)``; the flag is what ``_simplify`` hands back to its
    caller for function nodes.

    Raises:
        Exception: ``"function not implemented"`` for an unknown name.
    """
    name = f.name
    if name == 'diff':
        from src.differentiation import evaluate_diff
        derivative = simplify(evaluate_diff(simplify(f.arguments[0]), f.arguments[1]))
        # The flag is True only when the derivative no longer depends on the variable.
        return (True, derivative) if is_constant(derivative, f.arguments[1]) else (False, derivative)
    if name == 'arctan':
        return (False, Function(f.name, (simplify(f.arguments[0]),)))
    if name == 'e':
        return (True, 1) if f.arguments[0] == 0 else (False, f)
    if name == 'sqrt':
        return (False, f)
    if name == 'newton':
        from src.newton import newton
        return (True, newton(*f.arguments))

    raise Exception("function not implemented")

@cache
def replace_in_expr(expr, x, y):
    """Return ``expr`` with every occurrence of variable ``x`` replaced by ``y``.

    Numbers are returned unchanged and composite nodes are rebuilt with the
    substitution applied to their children.  A ``sqrt`` call is not kept
    symbolic: after substitution it is replaced by the result of
    ``newton`` applied to ``x^2 - argument`` (called with ``2`` and ``'x'``).

    Raises:
        Exception: when ``expr`` is of a node type not handled here.
    """
    def substitute(node):
        return replace_in_expr(node, x, y)

    if isinstance(expr, Number):
        return expr
    if isinstance(expr, Variable):
        return y if expr == x else expr
    if isinstance(expr, BinaryOperation):
        return BinaryOperation(expr.op, substitute(expr.left), substitute(expr.right))
    if isinstance(expr, Function):
        replaced = Function(expr.name, tuple(substitute(arg) for arg in expr.arguments))
        if replaced.name != 'sqrt':
            return replaced
        from src.newton import newton
        return newton(Minus(Pow('x', 2), replaced.arguments[0]), 2, 'x')
    if isinstance(expr, Term):
        return Term(expr.left, substitute(expr.right))
    if isinstance(expr, SumOperation):
        return SumOperation(tuple(substitute(member) for member in expr.values))
    if isinstance(expr, ProductOperation):
        return ProductOperation(tuple(substitute(member) for member in expr.values))
    raise Exception("replace not implemented for: " + str(expr))

def eval_expr(expr, args):
    """Substitute the bindings in ``args`` into ``expr`` and simplify.

    ``args`` is a mapping from variable to value.  It is flattened into a
    tuple of pairs so the cached worker can hash it.  Returns ``None`` when
    evaluation raises ``ZeroDivisionError`` or ``OverflowError``.
    """
    bindings = tuple(args.items())
    return _eval_expr(expr, bindings)


@cache
def _eval_expr(expr, args):
    """Cached worker for ``eval_expr``; ``args`` is a tuple of (variable, value) pairs."""
    result = expr
    try:
        # Bindings are applied one at a time, simplifying after each substitution.
        for name, value in args:
            result = simplify(replace_in_expr(result, name, value))
    except (ZeroDivisionError, OverflowError):
        return None
    return result
 
def _simplify(expr):
    """Run one simplification pass over ``expr``.

    Returns:
        A pair ``(is_const, expr)``.  ``is_const`` is True when the result is
        a number; for ``Function`` nodes the pair comes straight from
        ``evaluate_function``.

    Raises:
        ZeroDivisionError: when dividing constants and the Decimal/float
            divisor is smaller in magnitude than ``EPSILON`` (or, for exact
            rationals, when ``Fraction`` rejects a zero denominator).
        Exception: for a constant binary operation with an unknown operator.
    """
    # Local rewrite rules (distribution, identities for + - / * ^) come first.
    expr = reduce_expr(expr)
    if isinstance(expr, SumOperation):
        expr = collapse_comm(tuple([simplify(x) for x in expr.values]), '+', SumOperation)
    # A two-factor product with a number on either side becomes Term(number, other).
    if isinstance(expr, ProductOperation) and len(expr.values) == 2 and isinstance(expr.values[0], Number):
        expr = Term(simplify(expr.values[0]), simplify(expr.values[1]))
    if isinstance(expr, ProductOperation) and len(expr.values) == 2 and isinstance(expr.values[1], Number):
        expr = Term(simplify(expr.values[1]), simplify(expr.values[0]))
    if isinstance(expr, ProductOperation):
        expr = collapse_comm(tuple([simplify(x) for x in expr.values]), '*', ProductOperation)
    if isinstance(expr, Term):
        r = simplify(expr.right)
        if isinstance(r, Number):
            # Coefficient times a number: hand it to the binary '*' path.
            expr = simplify(BinaryOperation('*', expr.left, r))
        elif isinstance(expr.right, Term):
            # Nested Term: multiply the two coefficients into one.
            a = expr.left
            b = expr.right.left
            if is_int(a): a = int(a)
            if is_int(b): b = int(b)
            # NOTE(review): only Decimal is checked here, whereas the binary
            # branch below also checks float -- confirm this is intended.
            if isinstance(a, Decimal) or isinstance(b, Decimal):
                a = coerce_float(a)
                b = coerce_float(b)

            expr = Term(simplify(a*b), simplify(expr.right.right))
        else:
            expr = Term(expr.left, simplify(expr.right))
    if isinstance(expr, BinaryOperation):
        l_is_const, l = _simplify(expr.left)
        r_is_const, r = _simplify(expr.right)
        if l_is_const and r_is_const:
            # Both operands are numbers: compute the value directly.
            if is_int(l): l = int(l)
            if is_int(r): r = int(r)

            if isinstance(l, Decimal) or isinstance(r, Decimal) or isinstance(l, float) or isinstance(r, float):
                l = coerce_float(l)
                r = coerce_float(r)
            if expr.op == '/': 
                if isinstance(l, Decimal) or isinstance(r, Decimal) or isinstance(l, float) or isinstance(r, float):
                    # Treat a near-zero inexact divisor as division by zero.
                    if abs(r) < EPSILON:
                        raise ZeroDivisionError()
                    return (True, l / r)
                # Exact operands stay exact.
                return (True, fractions.Fraction(l, r))
            elif expr.op == '+': return (True, l + r)
            elif expr.op == '-': return (True, l - r)
            elif expr.op == '*': return (True, l * r)
            elif expr.op == '^': return (True, l ** r)
            raise Exception("binary op not handled: " + str(expr.op))

        # Not constant: rebuild with simplified children, then try to fold
        # '+' / '*' into the n-ary commutative form.
        old_expr = expr
        expr = BinaryOperation(old_expr.op, l, r) 
        if old_expr.op == '+': expr = create_comm(expr, l, r, '+')
        elif old_expr.op == '*': expr = create_comm(expr, l, r, '*')

    # Normalise numeric results: integral Fractions and integral numbers to int.
    if isinstance(expr, fractions.Fraction):
        if expr.denominator == 1:
            return (True, expr.numerator)
        return (True, expr)
    elif isinstance(expr, Number) and int(expr) == expr:
        return (True, int(expr))
    elif isinstance(expr, Number) and isinstance(expr, Decimal):
        return (True, Decimal(expr))
    elif isinstance(expr, Number):
        return (True, expr)

    if isinstance(expr, Function):
        return evaluate_function(expr)

    return (False, expr)

def reduce_expr(expr):
    """Apply each local rewrite rule once, in a fixed order.

    The order is: distribution, addition/subtraction identities, division
    rules, multiplication identities, power identities.
    """
    rules = (
        distribute,
        addition_reductions,
        division_reductions,
        multiplication_reductions,
        power_reductions,
    )
    for rule in rules:
        expr = rule(expr)
    return expr

@cache
def simplify(expr):
    """Simplify ``expr`` until ``_simplify`` stops changing it.

    Numbers short-circuit: an integral ``Fraction`` or other integral number
    becomes an ``int``, everything else numeric is returned as it is (a
    ``Decimal`` is rebuilt via ``Decimal(expr)``).  Any other expression is
    passed through ``_simplify`` repeatedly until a pass leaves it equal to
    its previous value.
    """
    if isinstance(expr, fractions.Fraction):
        return expr.numerator if expr.denominator == 1 else expr
    if isinstance(expr, Number):
        if int(expr) == expr:
            return int(expr)
        if isinstance(expr, Decimal):
            return Decimal(expr)
        return expr

    previous = None
    while expr != previous:
        previous = copy.copy(expr)
        flag, expr = _simplify(expr)
        if flag is None:
            return expr
    return expr

# a * (x + y)  ->  a * x + a * y
def distribute(expr):
    """Distribute multiplication over addition/subtraction, one level deep.

    A binary ``*`` with a binary ``+``/``-`` or a ``SumOperation`` on either
    side is expanded into a sum of products.  For a ``ProductOperation``
    that contains at least one sum, all non-sum factors are multiplied into
    the first sum as nested binary ``*`` nodes; the product is rebuilt in
    every case.  Other expressions are returned unchanged.
    """
    additive_ops = ('+', '-')

    def product(lhs, rhs):
        return BinaryOperation('*', left=lhs, right=rhs)

    if isinstance(expr, BinaryOperation) and expr.op == '*':
        lhs, rhs = expr.left, expr.right
        if isinstance(lhs, BinaryOperation) and lhs.op in additive_ops:
            expr = BinaryOperation(lhs.op, product(lhs.left, rhs), product(lhs.right, rhs))
        elif isinstance(rhs, BinaryOperation) and rhs.op in additive_ops:
            expr = BinaryOperation(rhs.op, product(lhs, rhs.left), product(lhs, rhs.right))
        elif isinstance(lhs, SumOperation):
            expr = SumOperation(tuple(product(rhs, member) for member in lhs.values))
        elif isinstance(rhs, SumOperation):
            expr = SumOperation(tuple(product(lhs, member) for member in rhs.values))

    if isinstance(expr, ProductOperation):
        sums = []
        others = []
        for factor in expr.values:
            is_sum = (isinstance(factor, BinaryOperation) and factor.op == '+') or isinstance(factor, SumOperation)
            (sums if is_sum else others).append(factor)
        if sums:
            # Fold the plain factors into the first sum, last factor first.
            for factor in reversed(others):
                sums[0] = BinaryOperation('*', factor, sums[0])
            others = []
        expr = ProductOperation(tuple(sums + others))
    return expr

def power_reductions(expr):
    """Apply the identities for powers with a trivial base or exponent.

    For a binary ``^``: a numeric base of 0 or 1 yields 0 or 1; otherwise a
    numeric exponent of 0 yields 1 and an exponent of 1 yields the base.
    For a ``Term`` whose right side is such a power, exponent 0 leaves just
    the coefficient and exponent 1 drops the power.  Anything else is
    returned unchanged.

    NOTE(review): because the base is tested first, ``0 ^ 0`` reduces to 0
    here, whereas Python's ``0 ** 0`` is 1 -- confirm the intended value.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '^':
        base, exponent = expr.left, expr.right
        if isinstance(base, Number):
            for fixed in (0, 1):
                if base == fixed:
                    return fixed
        if isinstance(exponent, Number):
            if exponent == 0:
                return 1
            if exponent == 1:
                return base
    if isinstance(expr, Term):
        powered = expr.right
        if (
            isinstance(powered, BinaryOperation)
            and powered.op == '^'
            and isinstance(powered.right, Number)
        ):
            if powered.right == 0:
                return expr.left
            if powered.right == 1:
                return simplify(Term(expr.left, powered.left))
    return expr

def addition_reductions(expr):
    """Apply the additive identities and rewrite subtraction as addition.

    * ``0 + a`` and ``a + 0`` reduce to ``a``.
    * ``a - 0`` reduces to ``a``.
    * Any other ``a - b`` becomes ``simplify(a) + simplify(-1 * b)`` so that
      later passes only have to deal with sums.

    Anything that is not a binary ``+``/``-`` is returned unchanged.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '+':
        if isinstance(expr.left, Number) and expr.left == 0:
            return expr.right
        if isinstance(expr.right, Number) and expr.right == 0:
            return expr.left
    if isinstance(expr, BinaryOperation) and expr.op == '-':
        # The guard must look at the operand being compared (the right one);
        # testing the left operand's type skipped this shortcut for `x - 0`.
        if isinstance(expr.right, Number) and expr.right == 0:
            return expr.left
        return BinaryOperation('+', simplify(expr.left), simplify(BinaryOperation('*', -1, expr.right)))
    return expr

def _coefficient_quotient(num, den):
    """Divide two numeric coefficients, staying exact when both are rational.

    ``fractions.Fraction(a, b)`` only accepts rational arguments, so
    Decimal/float inputs are divided as ``Decimal`` values instead.
    """
    if isinstance(num, numbers.Rational) and isinstance(den, numbers.Rational):
        return fractions.Fraction(num, den)
    return coerce_float(num) / coerce_float(den)


def division_reductions(expr):
    """Rewrite a binary ``/`` into a simpler or more canonical form.

    Rules, tried in order (the first that applies wins):

    * ``a / 1`` -> ``a``
    * two monomials over the same base with rational coefficients ->
      ``(c1/c2) * base ^ (p1 - p2)``, simplified
    * ``Term / Term`` -> ``Term(c1/c2, simplify(body1 / body2))``
    * ``expr / number`` (``expr`` has variables) -> ``(1/number) * expr``
    * ``number / expr`` (``expr`` has variables, ``number`` is not 1) ->
      ``number * (1 / expr)``

    Anything else, including non-division input, is returned unchanged.

    Raises:
        ZeroDivisionError: when a numeric divisor is zero.
    """
    if not (isinstance(expr, BinaryOperation) and expr.op == '/'):
        return expr

    numerator, denominator = expr.left, expr.right
    if denominator == 1:
        return numerator

    if is_monomial(numerator) and is_monomial(denominator):
        top_base = get_monomial_base(numerator, '*')
        top_power = get_monomial_power(numerator)
        top_coeff = get_monomial_coefficient(numerator)
        bottom_base = get_monomial_base(denominator, '*')
        bottom_power = get_monomial_power(denominator)
        bottom_coeff = get_monomial_coefficient(denominator)
        if (
            top_base == bottom_base
            and isinstance(top_coeff, numbers.Rational)
            and isinstance(bottom_coeff, numbers.Rational)
        ):
            return simplify(BinaryOperation(
                '*',
                fractions.Fraction(top_coeff, bottom_coeff),
                BinaryOperation('^', top_base, BinaryOperation('-', top_power, bottom_power)),
            ))

    if isinstance(numerator, Term) and isinstance(denominator, Term):
        return Term(
            simplify(_coefficient_quotient(numerator.left, denominator.left)),
            simplify(BinaryOperation('/', numerator.right, denominator.right)),
        )

    if len(collect_variables(numerator)) > 0 and isinstance(denominator, Number):
        return BinaryOperation('*', simplify(_coefficient_quotient(1, denominator)), numerator)

    if len(collect_variables(denominator)) > 0 and isinstance(numerator, Number):
        # `1 / expr` is already in the target shape; rewriting it would only
        # produce `1 * (1 / expr)`.
        if numerator != 1:
            # This rewrite used to be built and then dropped (assigned to an
            # unused local); return it so the rule actually takes effect.
            return BinaryOperation('*', numerator, BinaryOperation('/', 1, denominator))

    return expr

def multiplication_reductions(expr):
    """Apply the identities for multiplication by 0 and by 1.

    * binary ``*``: a numeric 0 on either side yields 0, a numeric 1 yields
      the other operand (left operand is examined first);
    * ``Term``: coefficient 0 yields 0, coefficient 1 yields the body;
    * ``ProductOperation``: factors equal to 1 are dropped, the remaining
      ones are simplified, and a single survivor is returned bare.

    Anything else is returned unchanged.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '*':
        for factor, other in ((expr.left, expr.right), (expr.right, expr.left)):
            if isinstance(factor, Number):
                if factor == 0:
                    return 0
                if factor == 1:
                    return other
    if isinstance(expr, Term):
        coefficient = expr.left
        if coefficient == 0:
            return 0
        if coefficient == 1:
            return expr.right
    if isinstance(expr, ProductOperation):
        kept = tuple(simplify(factor) for factor in expr.values if factor != 1)
        reduced = ProductOperation(kept)
        return reduced.values[0] if len(reduced.values) == 1 else reduced
    return expr