"""Algebraic simplification of expression trees.

Expressions are trees of the node types from ``support.expr_types``
(``BinaryOperation``, ``Term``, ``SumOperation``, ``ProductOperation``,
``Function``) with numbers and strings (variable names) as leaves.
``simplify`` applies ``_simplify`` until the expression stops changing;
``eval_expr`` substitutes variable values and simplifies the result.
"""
from support.parsing import parse
from math import prod
from support.expr_types import *
from collections import Counter
import copy
from support.to_string import expr_to_string
import fractions
import numbers
from functools import cache
# NOTE: star-imported last, so decimal's names shadow same-named ones above.
from decimal import *

getcontext().prec = 20
# Decimal/float divisors smaller than this in magnitude are treated as zero.
EPSILON = 10**(-3)


@cache
def create_comm(expr, l, r, op):
    """Flatten the binary node ``expr`` (``l op r``) into an n-ary node.

    ``op`` is ``'+'`` (-> SumOperation) or ``'*'`` (-> ProductOperation).
    The operands are merged into one n-ary node, which is then passed to
    ``collapse_comm``, when one of these holds:
    - ``is_add_of_values``/``is_mul_of_values`` accepts ``expr``;
    - either operand is a nested application of the same operator;
    - one operand is already the n-ary node and the other is a value.
    Otherwise ``expr`` is returned unchanged.
    """
    is_com_of_vals = is_add_of_values if op == '+' else is_mul_of_values
    Op = SumOperation if op == '+' else ProductOperation
    if is_com_of_vals(expr):
        expr = Op((l, r))
    elif isinstance(l, BinaryOperation) and l.op == op:
        expr = Op((l.left, l.right, r))
    elif isinstance(r, BinaryOperation) and r.op == op:
        expr = Op((l, r.left, r.right))
    elif isinstance(l, Op) and isinstance(r, Op):
        expr = Op(tuple(l.values + r.values))
    elif isinstance(l, Op) and (is_value(r) or is_com_of_vals(r)):
        expr = Op(tuple(l.values + (r,)))
    elif isinstance(r, Op) and (is_value(l) or is_com_of_vals(l)):
        expr = Op((l,) + r.values)
    else:
        return expr
    return collapse_comm(expr.values, op, Op)


@cache
def is_monomial(m):
    """Return True if ``m`` is a single monomial.

    Accepted forms are:
    - a number or a string (variable name);
    - a ``Term`` whose right side is a monomial;
    - ``mono ^ number``;
    - ``number * mono`` or ``mono * number``.
    """
    if isinstance(m, Number) or isinstance(m, str):
        return True
    if isinstance(m, Term) and is_monomial(m.right):
        return True
    if isinstance(m, BinaryOperation):
        if m.op == '^':
            return is_monomial(m.left) and isinstance(m.right, Number)
        if m.op == '*':
            return (is_monomial(m.right) and isinstance(m.left, Number)) or \
                (is_monomial(m.left) and isinstance(m.right, Number))
    return False


@cache
def get_monomial_coefficient(m):
    """Return the numeric coefficient of the monomial ``m``.

    Coefficients of nested products and Terms are multiplied together, e.g.
    ``(x*3)*2`` -> 6 and ``Term(2, 3*x)`` -> 6.  This matches
    ``get_monomial_base``/``get_monomial_power``, which also recurse through
    those wrappers.  A power ``mono ^ n`` has coefficient 1, because its base
    is ``mono`` itself.

    Raises:
        Exception: if ``m`` is not a monomial.
    """
    if isinstance(m, str):
        return 1
    if isinstance(m, Number):
        return m
    if isinstance(m, Term):
        if is_monomial(m.right):
            return _scale(m.left, get_monomial_coefficient(m.right))
        return m.left
    if isinstance(m, BinaryOperation):
        if is_monomial(m.left) and m.op == '^' and isinstance(m.right, Number):
            return 1
        if is_monomial(m.left) and m.op == '*' and isinstance(m.right, Number):
            return _scale(get_monomial_coefficient(m.left), m.right)
        if is_monomial(m.right) and m.op == '*' and isinstance(m.left, Number):
            return _scale(m.left, get_monomial_coefficient(m.right))
    raise Exception("not a monomial")


@cache
def get_monomial_power(m):
    """Return the exponent of the monomial ``m``.

    Returns 1 for a variable, 0 for a number, and ``n`` for ``mono ^ n``.

    Raises:
        Exception: if ``m`` is not a monomial.
    """
    if isinstance(m, str):
        return 1
    if isinstance(m, Number):
        return 0
    if isinstance(m, Term) and is_monomial(m.right):
        return get_monomial_power(m.right)
    if isinstance(m, BinaryOperation):
        if is_monomial(m.left) and m.op == '^' and isinstance(m.right, Number):
            return m.right
        if is_monomial(m.left) and m.op == '*' and isinstance(m.right, Number):
            return get_monomial_power(m.left)
        if is_monomial(m.right) and m.op == '*' and isinstance(m.left, Number):
            return get_monomial_power(m.right)
    raise Exception("not a monomial")


@cache
def get_monomial_base(m, op):
    """Return the base used to group the monomial ``m`` under ``op``.

    For ``op == '*'`` a power ``b ^ n`` is grouped by ``b``, so exponents can
    be added.  For any other op it is grouped by the whole power, so only
    identical powers combine.  Numbers have base 1.

    Raises:
        Exception: if ``m`` is not a monomial.
    """
    if isinstance(m, str):
        return m
    if isinstance(m, Number):
        return 1
    if isinstance(m, Term) and is_monomial(m.right):
        return get_monomial_base(m.right, op)
    if isinstance(m, BinaryOperation):
        if is_monomial(m.left) and m.op == '^' and isinstance(m.right, Number):
            return m.left if op == '*' else m
        if is_monomial(m.left) and m.op == '*' and isinstance(m.right, Number):
            return get_monomial_base(m.left, op)
        if is_monomial(m.right) and m.op == '*' and isinstance(m.left, Number):
            return get_monomial_base(m.right, op)
    raise Exception("not a monomial")


@cache
def get_base(a, op):
    """Return the key ``collect_likes`` uses to group ``a`` with like terms.

    Numbers map to 1, monomials to their monomial base, and ``number ^ expr``
    to the numeric base.  Anything else maps to itself.
    """
    if isinstance(a, Number):
        return 1
    if is_monomial(a):
        return get_monomial_base(a, op)
    if isinstance(a, BinaryOperation) and a.op == '^' and isinstance(a.left, Number):
        return a.left
    return a


def coerce_float(x):
    """Convert a numeric value to ``Decimal`` for inexact arithmetic.

    The conversions are:
    - ``None`` -> ``Decimal('Infinity')``;
    - values accepted by ``is_int`` -> ``Decimal(x)``;
    - Fractions -> numerator / denominator computed in Decimal;
    - floats -> ``Decimal(x)``.
    Anything else is returned unchanged.
    """
    if x is None:
        return Decimal('Infinity')
    if is_int(x):
        x = Decimal(x)
    if isinstance(x, fractions.Fraction):
        x = Decimal(x.numerator) / Decimal(x.denominator)
    if isinstance(x, float):
        x = Decimal(x)
    return x


def _match_numeric_types(a, b):
    """Normalise two numbers so they can be combined arithmetically.

    Integral values become ``int``.  If either value is then a Decimal or a
    float, both are converted with ``coerce_float``, so Decimal is never
    mixed with Fraction/float.
    """
    if int(a) == a:
        a = int(a)
    if int(b) == b:
        b = int(b)
    if isinstance(a, (Decimal, float)) or isinstance(b, (Decimal, float)):
        a = coerce_float(a)
        b = coerce_float(b)
    return a, b


def _scale(a, b):
    """Multiply two numeric coefficients.

    A factor of exactly 1 returns the other factor with its type unchanged.
    """
    if b == 1:
        return a
    if a == 1:
        return b
    a, b = _match_numeric_types(a, b)
    return a * b


def _ratio(a, b):
    """Return ``a / b``: an exact Fraction when both are rational, else a Decimal.

    ``fractions.Fraction`` rejects float/Decimal arguments, so inexact
    operands are divided in Decimal instead.

    Raises:
        ZeroDivisionError: if ``b`` is zero or, for inexact operands, smaller
            than EPSILON in magnitude.  This mirrors the constant-folding rule
            in ``_simplify``.
    """
    if isinstance(a, numbers.Rational) and isinstance(b, numbers.Rational):
        return fractions.Fraction(a, b)
    a = coerce_float(a)
    b = coerce_float(b)
    if abs(b) < EPSILON:
        raise ZeroDivisionError()
    return a / b


def _collect_product_group(func, funcset):
    """Multiply the factors in ``funcset`` that share the base ``func``; return the factors.

    Three kinds of factor are handled:
    - numbers and monomials fold into ``coeff * func ^ power``;
    - ``func ^ expr`` factors (numeric base, variable exponent) fold into
      ``func ^ (sum of exponents)``;
    - every other factor is kept as-is, never dropped.
    """
    coeff = 1
    power = 0
    saw_monomial = False
    exponents = []
    rest = []
    for f in funcset:
        if isinstance(f, Number):
            saw_monomial = True
            coeff, f = _match_numeric_types(coeff, f)
            coeff *= f
        elif is_monomial(f):
            saw_monomial = True
            coeff *= get_monomial_coefficient(f)
            power += get_monomial_power(f)
        elif isinstance(f, BinaryOperation) and f.op == '^' and isinstance(f.left, Number) \
                and len(collect_variables(f.right)) > 0:
            exponents.append(simplify(f.right))
        else:
            rest.append(simplify(f))
    out = []
    if saw_monomial:
        out.append(simplify(Term(coeff, BinaryOperation('^', func, power))))
    if exponents:
        out.append(simplify(BinaryOperation('^', func, SumOperation(tuple(exponents)))))
    out.extend(rest)
    return out


def _collect_sum_group(func, funcset):
    """Add the terms in ``funcset`` that share the base ``func``; return the resulting terms.

    Numbers and monomials fold into ``coeff * func``.  Other terms are kept
    as-is.
    """
    coeff = 0
    out = []
    for f in funcset:
        if isinstance(f, Number):
            coeff, f = _match_numeric_types(coeff, f)
            coeff += f
        elif is_monomial(f):
            coeff += get_monomial_coefficient(f)
        else:
            out.append(simplify(f))
    if coeff == 1:
        out.append(func)
    elif func is None:
        out.append(coeff)
    elif coeff != 0:
        out.append(simplify(Term(coeff, func)))
    return out


@cache
def collect_likes(terms, op):
    """Combine like terms of an n-ary ``op`` (``'+'`` or ``'*'``) node.

    Terms are simplified, then grouped first by their set of variables and
    then by ``get_base``.  Each group is folded by ``_collect_sum_group`` or
    ``_collect_product_group``.  Returns the resulting terms as a tuple.
    """
    by_var = {}
    for t in terms:
        s = simplify(t)
        by_var.setdefault(frozenset(collect_variables(s)), []).append(s)
    result = []
    for exprs in by_var.values():
        by_func = {}
        for a in exprs:
            by_func.setdefault(get_base(a, op), []).append(a)
        for func, funcset in by_func.items():
            if op == '*':
                result.extend(_collect_product_group(func, funcset))
            elif op == '+':
                result.extend(_collect_sum_group(func, funcset))
    return tuple(result)


@cache
def collapse_comm(vals, op, Op):
    """Flatten, simplify and combine the operands ``vals`` of an n-ary ``op`` node.

    Operands that are themselves ``Op`` nodes are spliced in (one level).
    Like terms are then combined with ``collect_likes``.  The result is:
    - the single remaining operand, if there is exactly one;
    - the identity (0 for ``'+'``, 1 otherwise), if none remain;
    - otherwise a new ``Op`` node.
    """
    nested = [tuple(x.values) for x in vals if isinstance(x, Op)]
    vals = [x for x in vals if not isinstance(x, Op)]
    for vs in nested:
        vals += vs
    vals = tuple([simplify(x) for x in vals])
    vals = collect_likes(vals, op)
    vals = [simplify(x) for x in vals]
    expr = Op(tuple(vals))
    if len(expr.values) == 1:
        return expr.values[0]
    if len(expr.values) == 0:
        return 0 if op == '+' else 1
    return expr


@cache
def evaluate_function(f):
    """Evaluate or normalise the ``Function`` node ``f``.

    Returns ``(is_const, value)``.  The supported names behave as follows:
    - ``diff``: differentiated via ``src.differentiation``;
    - ``arctan``: its argument is simplified;
    - ``e``: ``e(0)`` folds to 1;
    - ``sqrt``: left unchanged;
    - ``newton``: solved via ``src.newton``.

    Raises:
        Exception: for any other function name.
    """
    if f.name == 'diff':
        from src.differentiation import evaluate_diff
        x = simplify(evaluate_diff(simplify(f.arguments[0]), f.arguments[1]))
        # presumably is_constant checks independence from the diff variable
        if is_constant(x, f.arguments[1]):
            return (True, x)
        return (False, x)
    elif f.name == 'arctan':
        return (False, Function(f.name, tuple([simplify(f.arguments[0])])))
    elif f.name == 'e':
        if f.arguments[0] == 0:
            return (True, 1)
        return (False, f)
    elif f.name == 'sqrt':
        return (False, f)
    elif f.name == 'newton':
        from src.newton import newton
        return (True, newton(*f.arguments))
    raise Exception("function not implemented")


@cache
def replace_in_expr(expr, x, y):
    """Return ``expr`` with every occurrence of the variable ``x`` replaced by ``y``.

    ``sqrt`` calls are replaced by the value returned by ``newton`` for
    ``x^2 - arg`` (presumably a numeric root; TODO confirm against src.newton).

    Raises:
        Exception: for node types not handled here.
    """
    if isinstance(expr, Number):
        return expr
    if isinstance(expr, Variable):
        if expr == x:
            return y
        if expr != x:
            return expr
    if isinstance(expr, BinaryOperation):
        return BinaryOperation(expr.op, replace_in_expr(expr.left, x, y), replace_in_expr(expr.right, x, y))
    if isinstance(expr, Function):
        f = Function(expr.name, tuple([replace_in_expr(a, x, y) for a in expr.arguments]))
        if f.name == 'sqrt':
            from src.newton import newton
            root = newton(Minus(Pow('x', 2), f.arguments[0]), 2, 'x')
            return root
        return f
    if isinstance(expr, Term):
        return Term(expr.left, replace_in_expr(expr.right, x, y))
    if isinstance(expr, SumOperation):
        return SumOperation(tuple([replace_in_expr(z, x, y) for z in expr.values]))
    if isinstance(expr, ProductOperation):
        return ProductOperation(tuple([replace_in_expr(z, x, y) for z in expr.values]))
    raise Exception("replace not implemented for: " + str(expr))


def eval_expr(expr, args):
    """Substitute the mapping ``args`` (variable -> value) into ``expr`` and simplify.

    Returns None if simplification raises ZeroDivisionError or OverflowError.
    """
    return _eval_expr(expr, tuple(args.items()))


@cache
def _eval_expr(expr, args):
    """Cached worker for ``eval_expr``; ``args`` is a tuple of (variable, value) pairs."""
    try:
        for x, y in args:
            expr = simplify(replace_in_expr(expr, x, y))
        return expr
    except (ZeroDivisionError, OverflowError):
        return None


def _simplify(expr):
    """Run one simplification pass over ``expr``.

    Returns ``(is_const, result)``.  ``is_const`` is True when ``result`` is
    a numeric constant, or a Function that ``evaluate_function`` reported as
    constant.

    Raises:
        ZeroDivisionError: when folding a division by zero, or by a
            Decimal/float divisor smaller than EPSILON in magnitude.
    """
    expr = reduce_expr(expr)
    # The checks below are sequential on purpose (not elif): each rewrite
    # may yield a node of another kind that a later check then handles.
    if isinstance(expr, SumOperation):
        expr = collapse_comm(tuple([simplify(x) for x in expr.values]), '+', SumOperation)
    # A two-factor product with a numeric factor becomes Term(number, other).
    if isinstance(expr, ProductOperation) and len(expr.values) == 2 and isinstance(expr.values[0], Number):
        expr = Term(simplify(expr.values[0]), simplify(expr.values[1]))
    if isinstance(expr, ProductOperation) and len(expr.values) == 2 and isinstance(expr.values[1], Number):
        expr = Term(simplify(expr.values[1]), simplify(expr.values[0]))
    if isinstance(expr, ProductOperation):
        expr = collapse_comm(tuple([simplify(x) for x in expr.values]), '*', ProductOperation)
    if isinstance(expr, Term):
        r = simplify(expr.right)
        if isinstance(r, Number):
            expr = simplify(BinaryOperation('*', expr.left, r))
        elif isinstance(expr.right, Term):
            # c1 * (c2 * x) -> (c1 * c2) * x
            a = expr.left
            b = expr.right.left
            if is_int(a):
                a = int(a)
            if is_int(b):
                b = int(b)
            if isinstance(a, Decimal) or isinstance(b, Decimal):
                a = coerce_float(a)
                b = coerce_float(b)
            expr = Term(simplify(a * b), simplify(expr.right.right))
        else:
            expr = Term(expr.left, simplify(expr.right))
    if isinstance(expr, BinaryOperation):
        l_is_const, l = _simplify(expr.left)
        r_is_const, r = _simplify(expr.right)
        if l_is_const and r_is_const:
            # Fold constants: exact arithmetic unless a Decimal/float is involved.
            if is_int(l):
                l = int(l)
            if is_int(r):
                r = int(r)
            if isinstance(l, (Decimal, float)) or isinstance(r, (Decimal, float)):
                l = coerce_float(l)
                r = coerce_float(r)
            if expr.op == '/':
                if isinstance(l, (Decimal, float)) or isinstance(r, (Decimal, float)):
                    if abs(r) < EPSILON:
                        raise ZeroDivisionError()
                    return (True, l / r)
                return (True, fractions.Fraction(l, r))
            elif expr.op == '+':
                return (True, l + r)
            elif expr.op == '-':
                return (True, l - r)
            elif expr.op == '*':
                return (True, l * r)
            elif expr.op == '^':
                return (True, l ** r)
            raise Exception("binary op not handled: " + str(expr.op))
        op = expr.op
        expr = BinaryOperation(op, l, r)
        if op in ('+', '*'):
            expr = create_comm(expr, l, r, op)
    if isinstance(expr, fractions.Fraction):
        if expr.denominator == 1:
            return (True, expr.numerator)
        return (True, expr)
    elif isinstance(expr, Number) and int(expr) == expr:
        return (True, int(expr))
    elif isinstance(expr, Number) and isinstance(expr, Decimal):
        return (True, Decimal(expr))
    elif isinstance(expr, Number):
        return (True, expr)
    if isinstance(expr, Function):
        return evaluate_function(expr)
    return (False, expr)


def reduce_expr(expr):
    """Apply the top-level rewrite rules in order.

    The order is: distribute, addition, division, multiplication, power.
    """
    expr = distribute(expr)
    expr = addition_reductions(expr)
    expr = division_reductions(expr)
    expr = multiplication_reductions(expr)
    expr = power_reductions(expr)
    return expr


@cache
def simplify(expr):
    """Simplify ``expr`` until a fixed point is reached.

    Numbers are normalised directly: integral values become ``int`` and
    Fractions with denominator 1 become their numerator.  Any other
    expression is passed through ``_simplify`` until it stops changing.
    """
    if isinstance(expr, fractions.Fraction):
        if expr.denominator == 1:
            return expr.numerator
        return expr
    elif isinstance(expr, Number) and int(expr) == expr:
        return int(expr)
    elif isinstance(expr, Number) and isinstance(expr, Decimal):
        return Decimal(expr)
    elif isinstance(expr, Number):
        return expr
    old = None
    while expr != old:
        old = copy.copy(expr)
        _, expr = _simplify(expr)
    return expr


def distribute(expr):
    """Distribute multiplication over addition/subtraction.

    ``a * (x + y)`` becomes ``a*x + a*y``, and likewise ``(x - y) * a`` and
    products with a SumOperation operand.  In an n-ary product that
    contains sums, the non-sum factors are folded into the first sum as
    nested binary products.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '*':
        if isinstance(expr.left, BinaryOperation) and expr.left.op in ('+', '-'):
            expr = BinaryOperation(expr.left.op,
                                   BinaryOperation('*', left=expr.left.left, right=expr.right),
                                   BinaryOperation('*', left=expr.left.right, right=expr.right))
        elif isinstance(expr.right, BinaryOperation) and expr.right.op in ('+', '-'):
            expr = BinaryOperation(expr.right.op,
                                   BinaryOperation('*', left=expr.left, right=expr.right.left),
                                   BinaryOperation('*', left=expr.left, right=expr.right.right))
        elif isinstance(expr.left, SumOperation):
            expr = SumOperation(tuple([BinaryOperation('*', left=expr.right, right=x) for x in expr.left.values]))
        elif isinstance(expr.right, SumOperation):
            expr = SumOperation(tuple([BinaryOperation('*', left=expr.left, right=x) for x in expr.right.values]))
    if isinstance(expr, ProductOperation):
        sums = []
        rest = []
        for x in expr.values:
            if (isinstance(x, BinaryOperation) and x.op == '+') or isinstance(x, SumOperation):
                sums.append(x)
            else:
                rest.append(x)
        if sums:
            while rest:
                sums[0] = BinaryOperation('*', rest.pop(), sums[0])
            expr = ProductOperation(tuple(sums + rest))
    return expr


def power_reductions(expr):
    """Apply trivial power rules.

    The rules are:
    - ``0^a -> 0`` and ``1^a -> 1``;
    - ``a^0 -> 1`` and ``a^1 -> a``;
    - ``Term(c, a^0) -> c`` and ``Term(c, a^1) -> Term(c, a)``.
    The ``0^a`` rule is checked first, so ``0^0`` reduces to 0.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '^':
        if isinstance(expr.left, Number):
            if expr.left == 0:
                return 0
            if expr.left == 1:
                return 1
        if isinstance(expr.right, Number):
            if expr.right == 0:
                return 1
            if expr.right == 1:
                return expr.left
    if isinstance(expr, Term) and isinstance(expr.right, BinaryOperation) and expr.right.op == '^':
        if isinstance(expr.right.right, Number):
            if expr.right.right == 0:
                return expr.left
            if expr.right.right == 1:
                return simplify(Term(expr.left, expr.right.left))
    return expr


def addition_reductions(expr):
    """Apply additive identity rules and rewrite subtraction as addition.

    The rules are:
    - ``0 + a -> a`` and ``a + 0 -> a``;
    - ``a - 0 -> a``;
    - ``a - b -> a + (-1 * b)``.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '+':
        if isinstance(expr.left, Number) and expr.left == 0:
            return expr.right
        if isinstance(expr.right, Number) and expr.right == 0:
            return expr.left
    if isinstance(expr, BinaryOperation) and expr.op == '-':
        if isinstance(expr.right, Number) and expr.right == 0:
            return expr.left
        return BinaryOperation('+', simplify(expr.left), simplify(BinaryOperation('*', -1, expr.right)))
    return expr


def division_reductions(expr):
    """Rewrite a ``/`` node into a simpler equivalent when a rule applies.

    The rules, tried in order, are:
    - ``a / 1 -> a``;
    - monomial / monomial with the same base and rational coefficients
      -> ``(c1/c2) * base ^ (p1 - p2)``;
    - ``Term(c1, x1) / Term(c2, x2) -> Term(c1/c2, x1/x2)``;
    - ``expr_with_variables / number -> (1/number) * expr``.
    Coefficient ratios go through ``_ratio``: exact Fractions for rational
    operands, Decimal otherwise.  Anything else is returned unchanged.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '/':
        if expr.right == 1:
            return expr.left
        if is_monomial(expr.left) and is_monomial(expr.right):
            l_base, l_power, l_coeff = (get_monomial_base(expr.left, '*'),
                                        get_monomial_power(expr.left),
                                        get_monomial_coefficient(expr.left))
            r_base, r_power, r_coeff = (get_monomial_base(expr.right, '*'),
                                        get_monomial_power(expr.right),
                                        get_monomial_coefficient(expr.right))
            if l_base == r_base and isinstance(l_coeff, numbers.Rational) and isinstance(r_coeff, numbers.Rational):
                return simplify(BinaryOperation('*', fractions.Fraction(l_coeff, r_coeff),
                                                BinaryOperation('^', l_base, BinaryOperation('-', l_power, r_power))))
        if isinstance(expr.left, Term) and isinstance(expr.right, Term):
            return Term(simplify(_ratio(expr.left.left, expr.right.left)),
                        simplify(BinaryOperation('/', expr.left.right, expr.right.right)))
        if len(collect_variables(expr.left)) > 0 and isinstance(expr.right, Number):
            return BinaryOperation('*', simplify(_ratio(1, expr.right)), expr.left)
    return expr


def multiplication_reductions(expr):
    """Apply multiplicative identity/zero rules.

    The rules are:
    - ``0 * a -> 0`` and ``1 * a -> a`` (either side);
    - ``Term(0, a) -> 0`` and ``Term(1, a) -> a``;
    - factors equal to 1 are removed from an n-ary product, and a product
      left with a single factor becomes that factor.
    """
    if isinstance(expr, BinaryOperation) and expr.op == '*':
        if isinstance(expr.left, Number):
            if expr.left == 0:
                return 0
            elif expr.left == 1:
                return expr.right
        if isinstance(expr.right, Number):
            if expr.right == 0:
                return 0
            elif expr.right == 1:
                return expr.left
    if isinstance(expr, Term):
        if expr.left == 0:
            return 0
        elif expr.left == 1:
            return expr.right
    if isinstance(expr, ProductOperation):
        expr = ProductOperation(tuple([simplify(x) for x in expr.values if x != 1]))
        if len(expr.values) == 1:
            expr = expr.values[0]
    return expr