backend.item in MPS calculation is incompatible with autograd in jax #959

Open
SUSYUSTC opened this issue Jan 21, 2022 · 2 comments

Comments

@SUSYUSTC

In https://github.com/google/TensorNetwork/blob/master/tensornetwork/matrixproductstates/base_mps.py, line 319 (`res.append(self.backend.item(result.tensor))`) and line 479 (`return [self.backend.item(o) for o in c]`) call `self.backend.item`, which is incompatible with automatic differentiation in JAX (and possibly in other backends as well).
I haven't checked other files, so they might have similar issues.
Here's a simple example:

```python
import tensornetwork as tn
import numpy as np
import jax
tn.set_default_backend('jax')
Z = jax.numpy.asarray(np.array([[1.0, 0.0], [0.0, -1.0]], dtype=np.complex64))


def func(x):
    mps = tn.FiniteMPS.random([2, 2, 2, 2], [4, 4, 4], dtype=np.complex64)
    gate = jax.scipy.linalg.expm(Z * x)
    e = mps.measure_local_operator([gate], [0])
    return e[0]


print(func(1.0))                 # output (varies, since the MPS is random): (1.2248424291610718-2.9802322387695312e-08j)
vg = jax.value_and_grad(func)
print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
```
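A minimal pure-JAX sketch of why keeping the result as an array matters. It assumes nothing about TensorNetwork's internals: the hypothetical `expectation` helper below only mimics a single-site `measure_local_operator` on a fixed two-level state, and returns `<psi|gate|psi>` as a 0-d JAX array instead of calling `.item()`, so `jax.value_and_grad` can trace through it. Note that `jax.grad` also requires a real-valued output (unless `holomorphic=True`), so the sketch takes the real part; the example above returns a complex `e[0]`, so it would run into that check once `item` is no longer the problem.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)


def expectation(psi, op):
    # Hypothetical helper: <psi| op |psi> kept as a 0-d array.
    # No .item() / Python-scalar conversion, so it stays traceable.
    return jnp.vdot(psi, op @ psi)


def func(x):
    # Fixed, normalized state instead of a random MPS, so results are reproducible.
    psi = jnp.array([0.6, 0.8], dtype=jnp.complex64)
    gate = jax.scipy.linalg.expm(Z * x)  # diag(e^x, e^-x)
    # jax.grad needs a real scalar output, so take the real part.
    return jnp.real(expectation(psi, gate))


# value = 0.36 e^x + 0.64 e^-x, derivative = 0.36 e^x - 0.64 e^-x
value, grad_x = jax.value_and_grad(func)(1.0)
print(value, grad_x)
```

Here the analytic answer at `x = 1.0` is roughly 1.2140 for the value and 0.7431 for the derivative, which makes it easy to check that the gradient actually flows through `expm`.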
@mganahl
Collaborator

mganahl commented Jan 21, 2022

Hi, and thanks for the message!
Can you post the full error message as well? Thanks!

@SUSYUSTC
Author

> hi, and thanks for the message! Can you post the full error message as well? thanks!

The full output is:

```
(0.7063742876052856-1.4842953532934189e-08j)
Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/_src/api.py", line 2313, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/jax/core.py", line 568, in __getattr__
    attr = getattr(self.aval, name)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ConcreteArray' object has no attribute 'item'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "a.py", line 17, in <module>
    print(vg(1.0))                   # error: AttributeError: 'ConcreteArray' object has no attribute 'item'
  File "a.py", line 11, in func
    e = mps.measure_local_operator([gate], [0])
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/matrixproductstates/base_mps.py", line 319, in measure_local_operator
    res.append(self.backend.item(result.tensor))
  File "/home/jiace/anaconda3/lib/python3.7/site-packages/tensornetwork/backends/jax/jax_backend.py", line 878, in item
    return tensor.item()
AttributeError: 'ConcreteArray' object has no attribute 'item'
```
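The traceback shows that the JAX backend's `item` just calls `tensor.item()`. During `jax.value_and_grad`, `tensor` is a tracer, which forwards the attribute lookup to its abstract value, and that lookup fails. One possible direction is a tracer-aware helper, sketched below. `item_or_array` is hypothetical and not part of TensorNetwork. It assumes that `jax.core.Tracer` is the base class of the values JAX passes through traced functions. While tracing, it returns a 0-d array. Only concrete values are converted to Python scalars.

```python
import jax
import jax.numpy as jnp
import numpy as np


def item_or_array(tensor):
    """Hypothetical tracer-aware replacement for backend.item.

    Under tracing (jax.grad, jax.jit, ...) the value is a jax.core.Tracer
    without a usable .item(), so return it as a 0-d array to keep it
    differentiable. Concrete values are still converted to Python scalars.
    """
    if isinstance(tensor, jax.core.Tracer):
        return jnp.reshape(tensor, ())
    return np.asarray(tensor).item()


# Eager call: behaves like the current item(), returning a Python scalar.
print(item_or_array(jnp.array(2.5)))

# Traced call: the result stays a tracer, so the gradient flows through.
print(jax.grad(lambda x: item_or_array(x * x))(3.0))
```

The trade-off is that callers get a Python scalar in eager mode but a 0-d array under `jax.grad` or `jax.jit`. Always returning 0-d arrays from `measure_local_operator` would be simpler. However, it would change the return type that current users get from `item`.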
