
SVD on the jax backend, and thus split_node, cannot be jitted when max_truncation_err is set #953

Open
refraction-ray opened this issue Nov 18, 2021 · 0 comments

@refraction-ray
Contributor

SVD and split_node work fine on the tensorflow backend when wrapped in tf.function:

import tensorflow as tf
import tensornetwork as tn
tn.set_default_backend("tensorflow")
@tf.function
def f(b):
    a = tn.Node(b)
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor
f(tf.ones([2,2,2,2]))

But the same code fails on the jax backend:

import jax
from jax import numpy as jnp
import tensornetwork as tn
tn.set_default_backend("jax")
@jax.jit
def f(b):
    a = tn.Node(b)
    n1, n2, _ = tn.split_node(a, left_edges=a[:2], right_edges=a[2:], max_truncation_err=0.5)
    return n1.tensor
f(jnp.ones([2,2,2,2]))

The error is raised by the svd function in backends/numpy/decompositions.py, at the line num_sing_vals_keep = min(max_singular_values, num_sing_vals_err), with ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
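For reference, the failing pattern can be reproduced with plain JAX. The function below is a hypothetical, simplified imitation of the truncation count in backends/numpy/decompositions.py, written for illustration rather than copied from the library: the number of singular values to keep depends on the values in s, and passing that count to Python's min() forces JAX to turn a tracer into a concrete bool. The exact error subclass reported may differ between JAX versions, but it is caught as jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp


def num_sing_vals_to_keep(s, max_truncation_err, max_singular_values):
    """Hypothetical, simplified imitation of the truncation count in
    backends/numpy/decompositions.py (not the library code itself)."""
    # trunc_errs[i] = norm of the singular values dropped if only the first
    # i are kept (s is assumed to be sorted in descending order).
    trunc_errs = jnp.sqrt(jnp.cumsum(s[::-1] ** 2))[::-1]
    # Count needed to stay within the error budget; it depends on the
    # *values* of s, not just on its shape.
    num_sing_vals_err = jnp.sum(trunc_errs > max_truncation_err)
    # Python's min() calls bool() on a comparison; under jit that comparison
    # is an abstract tracer, so this line cannot be evaluated.
    return min(max_singular_values, num_sing_vals_err)


s = jnp.array([3.0, 1.0, 0.1, 0.01])

# Eager evaluation works: keeping 2 values drops a tail of norm ~0.1005 <= 0.5.
print(int(num_sing_vals_to_keep(s, 0.5, 4)))  # 2

# Under jit it fails, even with both thresholds marked static.
jitted = jax.jit(num_sing_vals_to_keep, static_argnums=(1, 2))
try:
    jitted(s, 0.5, 4)
except jax.errors.ConcretizationTypeError:
    print("ConcretizationTypeError raised")
```

Marking the thresholds static does not help, because the problem is the data-dependent count, not the parameters.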

I actually expected this error before trying it: a jax-jitted function can only accept and return arrays whose shapes are fixed at trace time, so it supports only a subset of what tf.function allows. Since split_node with max_truncation_err returns nodes whose shapes vary (the final shape depends on the singular values), it seems to be incompatible with JAX's jit mechanism.

Any thoughts or workarounds? I believe applying split_node with max_singular_values is very common in tensor network algorithms, and it would be great if such algorithms could be jitted.
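One possible direction, sketched below under stated assumptions, is to keep every shape static: always keep a fixed number of singular values (the bond dimension, passed as a static argument) and zero out the ones that max_truncation_err would discard instead of slicing them away. masked_truncated_svd is a hypothetical helper written directly against jnp.linalg.svd; it is not part of TensorNetwork. It assumes max_singular_values <= min(mat.shape) and that the node tensor has already been reshaped into a matrix (left edges as rows, right edges as columns).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def masked_truncated_svd(mat, max_singular_values, max_truncation_err):
    """Fixed-shape truncated SVD (hypothetical helper, not TensorNetwork API).

    Always returns exactly max_singular_values columns/rows (assumed to be
    <= min(mat.shape)); singular values the error threshold would discard are
    zeroed rather than sliced away, so output shapes never depend on data.
    """
    u, s, vh = jnp.linalg.svd(mat, full_matrices=False)
    # trunc_errs[i]: norm of the tail s[i:], i.e. the error of keeping i values.
    trunc_errs = jnp.sqrt(jnp.cumsum(s[::-1] ** 2))[::-1]
    # Index i must be kept if dropping s[i:] would exceed the error budget.
    keep = trunc_errs > max_truncation_err
    # Static slice: the bond dimension is fixed at trace time.
    chi = max_singular_values
    u, s, vh, keep = u[:, :chi], s[:chi], vh[:chi, :], keep[:chi]
    # Zero the discarded values instead of removing them.
    s_masked = jnp.where(keep, s, 0.0)
    # Effective rank as a traced scalar (allowed as a jit output).
    num_kept = jnp.sum(keep)
    return u, s_masked, vh, num_kept


mat = jnp.diag(jnp.array([3.0, 1.0, 0.1, 0.01]))
u, s, vh, n = masked_truncated_svd(mat, 3, 0.5)
print(u.shape, s.shape, vh.shape, int(n))  # (4, 3) (3,) (3, 4) 2
```

Because max_truncation_err is only used in a comparison, it can stay a traced argument, so changing the threshold should not trigger recompilation. The trade-off is that the bond dimension stays at max_singular_values even when fewer values are needed, so memory and compute are not reduced; num_kept still reports the effective rank.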
