
[Feature]: Recover up-to-date expression from exported PyTorch model (SingleSymPyModule) #574

Open
fburic opened this issue Mar 20, 2024 · 0 comments
Labels
enhancement New feature or request

fburic commented Mar 20, 2024

Feature Request

Hi!

Very cool project!

I just wanted to suggest the following feature, which I've implemented for myself:

The export_torch.SingleSymPyModule object only holds the expression string as it was at the time of export from a PySR model (a SymPy expression).

I'm considering the use case in which the PyTorch model is further trained to tweak its parameters; it would then be very helpful to inspect the resulting expression with the updated parameter values. However, there's currently no method to recover such an updated string.
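To make the use case concrete, here's a minimal sketch of why the exported string goes stale. It uses a hypothetical stand-in module (so it doesn't depend on PySR or sympytorch); the class and parameter names are illustrative only:

```python
import torch

# Hypothetical stand-in for a SingleSymPyModule with one trainable
# constant `c` in the expression c * x**2.
class TinyExprModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.c = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        return self.c * x ** 2

model = TinyExprModule()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.linspace(0.0, 1.0, 20)
y = 2.0 * x ** 2  # target data generated with c = 2

for _ in range(200):
    opt.zero_grad()
    loss = torch.mean((model(x) - y) ** 2)
    loss.backward()
    opt.step()

# After fine-tuning, c has drifted away from its exported value of 1.0,
# so a stored expression string no longer reflects the model.
print(round(model.c.item(), 2))  # → 2.0
```

Any expression string captured at export time would still say `1.0*x**2` here, which is exactly the gap the proposed method fills.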

If this is interesting (either here or in the base sympytorch project, @patrick-kidger), I'm pasting my implementation below. It returns a SymPy expression representing the current state of a SingleSymPyModule (as a standalone function, though it could become a proper class __repr__).

It's based on reverse engineering the recursive node structure, so I'm not entirely sure of my parsing (specifically, the conditions for deciding the type of a node). The way I map function names back to sympy.core also feels a bit too broad (it uses all public content of that module), but I was working under a time budget :(
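As a small illustration of that mapping step (pure SymPy, no PySR needed): the node walk produces a function-call style string, and parse_expr resolves those names through the sanitized sympy.core namespace:

```python
import sympy
from sympy.parsing.sympy_parser import parse_expr

# Public names from sympy.core, e.g. {'Mul': sympy.core.mul.Mul, ...}
mapping = {name: obj for name, obj in sympy.core.__dict__.items()
           if not name.startswith('_')}

# A function-call style string like the one built from the node tree
str_repr = "Add(Mul(Integer(2), Pow(Symbol('x0'), Integer(2))), Integer(-2))"
expr = parse_expr(str_repr, local_dict=mapping)
print(expr)  # → 2*x0**2 - 2
```

Exposing all of sympy.core this way is broader than strictly necessary, but it covers whatever operation classes the node tree happens to contain.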

I did, however, test it with property-based testing (using hypothesis), making sure the round trip PySR -> SingleSymPyModule -> SymPy expression always agrees on output with the initial PySR model, given random choices of input operators for PySRRegressor. It seems to work fine; I'm pasting that below as well.

The only thing that's not guaranteed to match is the order of terms. That's fine for my work, and I'm not sure how much time fixing it would take, but it would definitely help to have that as well.

Suggestion

import sympy
import torch
from sympy.parsing.sympy_parser import parse_expr

from pysr import export_torch


def sympytorch_expr(model: export_torch.SingleSymPyModule) -> sympy.Expr:
    """
    Retrieve the Sympy expression of a SingleSymPyModule.

    Relies on mapping of Sympy operations in sympy.core.__dict__,
    e.g. {'Mul': sympy.core.mul.Mul}

    :param model: SingleSymPyModule instance
    :return: Sympy expression
    """
    str_repr = _sympytorch_node_repr(model._node)
    sympy_op_mapping = sympy.core.__dict__
    # A modicum of sanitizing
    sympy_op_mapping = {op_name: op for op_name, op in sympy_op_mapping.items()
                        if not op_name.startswith('_')}
    return parse_expr(str_repr, local_dict=sympy_op_mapping)


def _sympytorch_node_repr(node) -> str:
    if _sympytorch_node_is_variable(node):
        return node._name

    if _sympytorch_node_is_function(node):
        return str(node)

    if _sympytorch_node_is_parameter(node):
        if isinstance(node._value, torch.nn.Parameter):
            return str(node._value.data.item())
        return str(node._value)

    else:
        # Remove the qualifier from class name for later parsing from sympy.core
        # e.g. sympy.core.mul.Mul -> Mul
        func_repr = str(node._sympy_func).split('.')[-1].split("'")[0]
        args_repr = [_sympytorch_node_repr(arg) for arg in node._args]
        args_repr = '(' + ', '.join(args_repr) + ')'
        args_repr = func_repr + args_repr
        return args_repr


def _sympytorch_node_is_variable(node) -> bool:
    return hasattr(node, '_name')


def _sympytorch_node_is_function(node) -> bool:
    return issubclass(type(node), sympy.core.function.FunctionClass)


def _sympytorch_node_is_parameter(node) -> bool:
    return not hasattr(node, '_args') or not node._args

Property-based test

import numpy as np
from pysr import PySRRegressor
from sympy import lambdify

from hypothesis import given, seed, settings, HealthCheck
import hypothesis.strategies as strat
import pytest


@pytest.fixture
def data_for_test_sympytorch_repr():
    """Cache the data generation to save a little time per example."""
    rng = np.random.default_rng(42)
    X = rng.uniform(low=0, high=2, size=(10, 1))
    X_test = rng.uniform(low=2, high=4, size=(10, 1))
    y = 2 * np.cos(X) + X ** 2 - 2
    return X, X_test, y


@given(
    binary_operators=strat.sets(strat.sampled_from(['+', '*', '/']),
                                min_size=1, max_size=2),
    unary_operators=strat.sets(strat.sampled_from(['sin', 'log', 'sqrt', 'square']),
                               min_size=1, max_size=2),
)
@settings(max_examples=100,
          deadline=10000,   # Account for variation between example times
          suppress_health_check=[HealthCheck.function_scoped_fixture])
@seed(42)
def test_sympytorch_repr(binary_operators, unary_operators, data_for_test_sympytorch_repr):
    """
    Test invariant = outputs match after round-trip
        PySR model (expression) -> SingleSympyModule -> SymPy expression

    under random PySR input operators.
    """
    X, X_test, y = data_for_test_sympytorch_repr

    model = PySRRegressor(binary_operators=list(binary_operators),
                          unary_operators=list(unary_operators),
                          niterations=2,
                          deterministic=True,
                          procs=0,
                          random_state=42,
                          temp_equation_file=True,
                          verbosity=0)
    model.fit(X, y)
    torch_model = model.pytorch()
    torch_model_expr = sympytorch_expr(torch_model)
    torch_model_expr_func = lambdify('x0', torch_model_expr, 'numpy')

    output_expr = torch_model_expr_func(X_test).ravel()
    output_pysr = model.predict(X_test).ravel()
    assert np.allclose(output_expr, output_pysr)

fburic added the enhancement label Mar 20, 2024