# test_simplify.py
import pytest

from sympy import sympify, simplify, expand, nsimplify, Eq, solve
from fractions import Fraction
from termcolor import colored

from support.parsing import parse
    
from support.to_string import expr_to_string
from support.canonicalize import simplify as oursimplify
from support.canonicalize import collect_variables

from random import seed, uniform, shuffle
from tests.helpers.tests import cases, cases_extend
from tests.helpers.tests2 import cases2, cases2_extend

# Fix the global `random` generator's seed at import time so that anything
# in this module drawing from it (presumably helpers using uniform/shuffle)
# is reproducible across runs.
seed(a=1000)

def us_to_sympy(expr):
    """Translate an expression string from our notation into SymPy syntax.

    Applies, in order, the plain substring rewrites ``^`` -> ``**``,
    ``e(`` -> ``exp(`` and ``arctan(`` -> ``atan(``.

    NOTE: these are raw substring replacements, so any ``e(`` is rewritten,
    including one that ends a longer identifier.
    """
    rewrites = (
        ("^", "**"),
        ("e(", "exp("),
        ("arctan(", "atan("),
    )
    converted = expr
    for ours, theirs in rewrites:
        converted = converted.replace(ours, theirs)
    return converted

def sympy_to_us(expr):
    """Translate a SymPy-syntax expression string back into our notation.

    Applies, in order, the plain substring rewrites ``**`` -> ``^``,
    ``exp(`` -> ``e(`` and ``atan(`` -> ``arctan(``.
    """
    with_carets = expr.replace("**", "^")
    with_short_exp = with_carets.replace("exp(", "e(")
    return with_short_exp.replace("atan(", "arctan(")

def pprint(expr, pr=True):
    """Check that our simplifier preserves the value of *expr*.

    *expr* is rendered to a string and simplified with ``oursimplify``; the
    simplified expression is rendered again. Both strings are converted to
    SymPy syntax, expanded and passed through ``nsimplify``, and the two
    SymPy results must compare equal.

    Args:
        expr: a parsed expression, as returned by ``parse``.
        pr: when true, print ``orig = result ?= reference`` before checking.

    Raises:
        AssertionError: if the simplified expression's SymPy form differs
            from the SymPy form of the original expression.
    """
    orig = expr_to_string(expr)
    # Simplify once and render that single result both ways, so the two
    # renderings always describe the same simplified expression.
    simplified = oursimplify(expr)
    result = expr_to_string(simplified, use_parens=False)
    # The parenthesized rendering is not compared, but producing it keeps the
    # use_parens=True path exercised: an exception there fails the test.
    expr_to_string(simplified, use_parens=True)
    ref = nsimplify(expand(sympify(us_to_sympy(orig))))
    if pr:
        # print() joins with single spaces and str()s ref, matching the
        # previous piecewise output exactly.
        print(orig, colored("=", 'yellow'), result, colored("?=", 'yellow'), ref)
    actual = nsimplify(expand(sympify(us_to_sympy(result))))
    assert actual == ref, f"{orig} simplified to {result}: {actual} != {ref}"


# Hand-written expressions checked by test_simplification.
_HANDPICKED_EXPRS = [
    "x + (1.5 + 2) + x+24",
    "x + 2*x*(1.5+2)+x+24",
    "x*2*x + 2*x*(1.5+2)+x+24",
    "(x + 1) * (x + - 1)",
    "x *3*( (x + 1) * (x + - 1))",
    "x *3*( (2*x + 1) * (x/4 + - 1))",
    "2^x *3*( (2*x + 1) * (x/4 + - 1))",
    "(3*x + 5)/(2 * x^2 - 5*x - 3)",
]


@pytest.mark.parametrize("expr", _HANDPICKED_EXPRS)
def test_simplification(expr):
    """Simplifying each hand-picked expression must not change its value."""
    parsed = parse(expr)
    pprint(parsed, pr=False)

@pytest.mark.parametrize("expr", cases + cases_extend)
def test_simplification2(expr):
    """Simplifying each case from tests.helpers.tests must preserve its value."""
    parsed = parse(expr)
    pprint(parsed, pr=False)


@pytest.mark.parametrize("expr", cases2 + cases2_extend)
def test_simplification3(expr):
    """Simplifying each case from tests.helpers.tests2 must preserve its value."""
    parsed = parse(expr)
    pprint(parsed, pr=False)