# test_newton.py (2.2 KB)
import pytest

from sympy import sympify, simplify, expand, nsimplify, Eq, solve
from fractions import Fraction
from termcolor import colored

from support.parsing import parse
    
from support.to_string import expr_to_string
from support.canonicalize import simplify as oursimplify
from support.canonicalize import collect_variables

from random import seed, uniform, shuffle
from tests.helpers.tests import cases, cases_extend
from tests.helpers.tests2 import cases2, cases2_extend

# Seed Python's `random` module with a fixed value so that any random draws
# made after this point are repeatable from run to run.
seed(1000)

def us_to_sympy(expr):
    """Translate an expression string from this project's syntax to SymPy's.

    Plain substring replacement is applied in this order: ``^`` becomes
    ``**``, ``e(`` becomes ``exp(`` and ``arctan(`` becomes ``atan(``.

    Args:
        expr: Expression text written in the project's notation.

    Returns:
        The same text with the three substitutions applied.
    """
    rewrites = (
        ("^", "**"),
        ("e(", "exp("),
        ("arctan(", "atan("),
    )
    converted = expr
    for ours, theirs in rewrites:
        converted = converted.replace(ours, theirs)
    return converted

def sympy_to_us(expr):
    """Translate an expression string from SymPy's syntax to this project's.

    Plain substring replacement is applied in this order: ``**`` becomes
    ``^``, ``exp(`` becomes ``e(`` and ``atan(`` becomes ``arctan(``.

    Args:
        expr: Expression text written in SymPy's notation.

    Returns:
        The same text with the three substitutions applied.
    """
    rewrites = (
        ("**", "^"),
        ("exp(", "e("),
        ("atan(", "arctan("),
    )
    converted = expr
    for theirs, ours in rewrites:
        converted = converted.replace(theirs, ours)
    return converted

def pprint(expr, pr=True):
    """Simplify ``expr`` and assert the result agrees with SymPy.

    The expression is rendered with ``expr_to_string`` both before and after
    running the project's ``simplify``. Each rendering is converted to SymPy
    syntax, expanded and passed through ``nsimplify``; the two SymPy results
    must compare equal.

    Args:
        expr: Expression object accepted by ``expr_to_string`` and the
            project's ``simplify``.
        pr: When true, print ``<original> = <simplified> ?= <reference>``
            with the separators coloured yellow.

    Raises:
        AssertionError: If the simplified form differs from the reference.
    """
    def canonical(text):
        # SymPy's expanded, nsimplify'd view of an expression string.
        return nsimplify(expand(sympify(us_to_sympy(text))))

    orig = expr_to_string(expr)
    result = expr_to_string(oursimplify(expr), use_parens=False)
    # Rendered with parentheses as well; the string itself is not used, but
    # the call still runs the parenthesised printing path.
    expr_to_string(oursimplify(expr), use_parens=True)
    ref = canonical(orig)

    if pr:
        equals = colored("=", 'yellow')
        check = colored("?=", 'yellow')
        # print() joins its arguments with single spaces and ends the line.
        print(orig, equals, result, check, str(ref))

    ours = canonical(result)
    assert ours == ref

from src.newton import newton
# Starting guesses that test_newton tries in order when its first attempt
# returns nothing: a coarse sweep of [-5.0, 4.9] in steps of 0.1, a fine
# sweep of [-0.2, 0.18] in steps of 0.02, then two extra hand-picked values.
x0s = [
    *(tenth / 10 for tenth in range(-50, 50)),
    *(fiftieth / 50 for fiftieth in range(-10, 10)),
    3.35,
    1200,
]

def _matches_reference(us, ref):
    """Return True if ``us`` is within 0.1% (relative) of any value in ``ref``."""
    candidate = pytest.approx(float(us), 0.001)
    return any(candidate == root for root in ref)


@pytest.mark.parametrize('expr', cases + cases2)
def test_newton(expr):
    """Check ``newton`` against the real roots SymPy finds for ``expr``.

    The case is parsed and simplified, rendered to text, and handed to
    SymPy's ``solve`` as ``expression == 0``; the real solutions are the
    reference roots. ``newton`` is first run from the starting point 1 on the
    first collected variable. The outcomes are:

    * SymPy finds no real root: ``newton`` must have returned ``None``.
    * ``newton`` returned a value: it must match one of the reference roots.
    * ``newton`` returned ``None`` although real roots exist: each starting
      point in ``x0s`` is tried in turn, skipping those that raise
      ``ZeroDivisionError``. The first non-``None`` result must match a
      reference root, and the test fails if no starting point yields one.

    Args:
        expr: One test case from ``cases + cases2``, in the form accepted by
            ``parse``.
    """
    # Reference answer: real solutions of "simplified expression == 0".
    orig = expr_to_string(oursimplify(parse(expr)))
    solutions = solve(Eq(expand(sympify(us_to_sympy(orig))), 0))
    ref = [float(root) for root in solutions if root.is_real] or None

    # Parsed afresh rather than reusing the tree rendered above, so the
    # sequence of parse/simplify calls stays as it always was.
    tree = oursimplify(parse(expr))
    var = list(collect_variables(tree))
    us = newton(tree, 1, var[0]) if var else None

    if ref is None:
        assert us is None, (
            "newton returned {!r} for {!r}, but SymPy finds no real root"
            .format(us, expr)
        )
        return

    if us is not None:
        assert _matches_reference(us, ref), (
            "newton returned {!r} for {!r}; expected one of {!r}"
            .format(us, expr, ref)
        )
        return

    # The first attempt found nothing although real roots exist: retry from
    # other starting points.
    for x0 in x0s:
        try:
            us = newton(tree, x0, var[0])
        except ZeroDivisionError:
            # This starting point is unusable; move on to the next one.
            continue
        if us is not None:
            assert _matches_reference(us, ref), (
                "newton returned {!r} for {!r} from x0={!r}; "
                "expected one of {!r}".format(us, expr, x0, ref)
            )
            return

    pytest.fail(
        "newton found no root for {!r} from any starting point; "
        "expected one of {!r}".format(expr, ref)
    )