# differentiation.py
from support.expr_types import *
from functools import cache
from support.canonicalize import simplify

@cache
def evaluate_diff(expr, var):
    """
    Return the derivative of ``expr`` with respect to the variable ``var``.

    The result is an *unsimplified* expression tree built with the
    constructors from ``support.expr_types``. Pass it through ``simplify``
    if a canonical form is wanted. Results are memoised with ``@cache``, so
    ``expr`` must be hashable.

    Relevant Expression Types (and fields):
        - BinaryOperation b: b.left, b.right
        - Function f: f.name, f.arguments
        - SummationOperation s: s.values
        - ProductOperation p: p.values

    Rules implemented (dX means d/dvar of X):
        - constant c w.r.t. var    -> 0
        - var                      -> 1
        - x + y                    -> dx + dy
        - x - y                    -> dx - dy
        - x * y                    -> dx*y + x*dy
        - x / y                    -> d(x * y^-1)   (rewritten; no quotient rule)
        - x ^ n, n const w.r.t var -> n * x^(n-1) * dx
        - e(x)                     -> e(x) * dx
        - arctan(x)                -> dx / (1 + x^2)
        - Summation([x1, ..., xn]) -> Summation([dx1, ..., dxn])
        - Product([x1, ..., xn])   -> sum over i of Product(x1, .., dxi, .., xn)

    Args:
        expr: Expression to differentiate.
        var: Name of the variable to differentiate with respect to.

    Returns:
        The derivative as an Expression. The two base cases return the plain
        numbers 0 and 1.

    Raises:
        NotImplementedError: if ``expr`` is outside the rules above. This
            includes a power whose exponent depends on ``var`` and any
            function other than ``e`` and ``arctan``.
    """
    # Base cases. NOTE(review): 0 and 1 are returned as plain ints. This
    # assumes the expression constructors accept numeric literals, just as
    # variables are compared directly against the str ``var``. Confirm
    # against support.expr_types.
    if is_constant(expr, var):
        return 0
    if expr == var:
        return 1

    if is_op(expr, '+'):
        left, right = expr.left, expr.right
        return Plus(evaluate_diff(left, var), evaluate_diff(right, var))

    if is_op(expr, '-'):
        left, right = expr.left, expr.right
        return Minus(evaluate_diff(left, var), evaluate_diff(right, var))

    if is_op(expr, '*'):
        left, right = expr.left, expr.right
        return Plus(
            Mul(evaluate_diff(left, var), right),
            Mul(left, evaluate_diff(right, var)),
        )

    if is_op(expr, '/'):
        # Rewrite x / y as x * y^-1 so the product and power rules handle it.
        return evaluate_diff(Mul(expr.left, Pow(expr.right, -1)), var)

    if is_op(expr, '^'):
        base, exponent = expr.left, expr.right
        if not is_constant(exponent, var):
            raise NotImplementedError(
                "not implemented: power with non-constant exponent: "
                + str(expr)
            )
        # Fold the decrement for plain numbers. Otherwise leave it symbolic
        # for simplify to resolve.
        if isinstance(exponent, (int, float)):
            lowered = exponent - 1
        else:
            lowered = Minus(exponent, 1)
        return Mul(Mul(exponent, Pow(base, lowered)), evaluate_diff(base, var))

    if is_function(expr, 'e'):
        # e(u)' = e(u) * u', so the original node is reused as the outer factor.
        (inner,) = expr.arguments
        return Mul(expr, evaluate_diff(inner, var))

    if is_function(expr, 'arctan'):
        (inner,) = expr.arguments
        return Div(evaluate_diff(inner, var), Plus(1, Pow(inner, 2)))

    if is_summation(expr):
        return Summation([evaluate_diff(term, var) for term in expr.values])

    if is_product(expr):
        # Generalised product rule. Each term differentiates exactly one
        # factor and keeps the others unchanged.
        factors = list(expr.values)
        return Summation([
            Product(factors[:i] + [evaluate_diff(factor, var)] + factors[i + 1:])
            for i, factor in enumerate(factors)
        ])

    raise NotImplementedError("not implemented: " + str(expr))