# test_differentiation.py 1.96 KB
import pytest

from sympy import sympify, simplify, expand, nsimplify, Eq, solve
from fractions import Fraction
from termcolor import colored

from support.parsing import parse
    
from support.to_string import expr_to_string
from support.canonicalize import simplify as oursimplify
from support.canonicalize import collect_variables

from random import seed, uniform, shuffle
from tests.helpers.tests import cases, cases_extend
from tests.helpers.tests2 import cases2, cases2_extend

# Seed Python's global random generator with a fixed value so that anything
# drawn through the `random` module is repeatable from one run to the next.
seed(1000)

def us_to_sympy(expr):
    """Translate an expression string from this project's syntax to SymPy's.

    Three plain substring rewrites are applied, in this order:
    ``^`` -> ``**``, ``e(`` -> ``exp(`` and ``arctan(`` -> ``atan(``.
    The rewrites are purely textual; no parsing is performed.
    """
    rewrites = (
        ("^", "**"),
        ("e(", "exp("),
        ("arctan(", "atan("),
    )
    converted = expr
    for ours, theirs in rewrites:
        converted = converted.replace(ours, theirs)
    return converted

def sympy_to_us(expr):
    """Translate an expression string from SymPy's syntax to this project's.

    Applies, in order, the textual rewrites ``**`` -> ``^``,
    ``exp(`` -> ``e(`` and ``atan(`` -> ``arctan(``.
    """
    with_carets = expr.replace("**", "^")
    with_short_exp = with_carets.replace("exp(", "e(")
    return with_short_exp.replace("atan(", "arctan(")

def pprint(expr, pr=True):
    """Simplify ``expr`` with the project simplifier and compare it to SymPy.

    The input string is parsed, simplified and rendered back to text by the
    project's own pipeline. Both that rendering and the original input are
    then translated to SymPy syntax and passed through
    ``sympify`` -> ``expand`` -> ``nsimplify``; the two results must be equal.

    Args:
        expr: Expression string in the project's syntax.
        pr: When true, additionally print a coloured
            ``input = ours ?= reference`` line.

    Raises:
        AssertionError: If the normalised project result differs from the
            normalised reference.
    """

    def normalize(text):
        # Shared pipeline used for both the reference and our own output.
        return nsimplify(expand(sympify(us_to_sympy(text))))

    print(expr)
    source = expr
    ours = expr_to_string(oursimplify(parse(expr)), use_parens=False)
    # The parenthesised rendering is not compared against anything; it is
    # still produced so that the ``use_parens=True`` path gets exercised.
    ours_parens = expr_to_string(oursimplify(parse(expr)), use_parens=True)
    expected = normalize(source)
    print(ours, expected, source)
    if pr:
        equals = colored("=", 'yellow')
        maybe_equals = colored("?=", 'yellow')
        # One space-separated line: "<input> = <ours> ?= <reference>".
        print(source, equals, ours, maybe_equals, str(expected))
    actual = normalize(ours)
    assert actual == expected

@pytest.mark.parametrize(
    "expr",
    [
        "diff(x + 1, x)",
        "3*diff((4*x)^2 + 1, x)",
    ],
)
def test_basic_differentiation(expr):
    """Run two hand-written ``diff(...)`` expressions through ``pprint``."""
    pprint(expr, pr=False)

@pytest.mark.parametrize("expr", cases + cases_extend)
def test_more_differentiation(expr):
    """Wrap each ``cases`` expression in ``diff(..., x)`` and check it."""
    first_derivative = "".join(("diff(", expr, ", x)"))
    pprint(first_derivative, pr=False)

@pytest.mark.parametrize("expr", cases2 + cases2_extend)
def test_moremore_differentiation(expr):
    """Wrap each ``cases2`` expression in ``diff(..., x)`` and check it."""
    first_derivative = "".join(("diff(", expr, ", x)"))
    pprint(first_derivative, pr=False)


@pytest.mark.parametrize("expr", cases + cases_extend)
def test_moremoremore_differentiation(expr):
    """Wrap each ``cases`` expression in a nested ``diff(diff(..., x), x)``."""
    second_derivative = "".join(("diff(diff(", expr, ", x), x)"))
    pprint(second_derivative, pr=False)

@pytest.mark.parametrize('expr', cases2 + cases2_extend)
def test_moremoremoremore_differentiation(expr):
    """Wrap each ``cases2`` expression in ``diff(diff(..., x), x)`` and check it.

    This test deliberately has its own name. It used to be called
    ``test_moremoremore_differentiation``, the same as the ``cases``-based
    second-derivative test defined before it; the later ``def`` rebound that
    module attribute, so pytest only ever collected one of the two.
    """
    pprint('diff(diff(' + expr + ', x), x)', False)