
Jax autodiff incompatible with SVD? #147

Open
nikitn2 opened this issue Oct 17, 2022 · 9 comments

nikitn2 commented Oct 17, 2022

What is your issue?

Hello,

Lately I've been trying to use JAX autodifferentiation to minimise a loss function of mine, but I keep running into errors like “unhashable type: 'DeviceArray'” or “ConcretizationTypeError”. They happen when I try to use compression or tensor.split(), which makes me think that JAX has some sort of incompatibility with SVDs, since SVDs are typically used to compress bonds and split tensors.

Below is a reproduction of the issue, based on Chapter 9 of your user guide. With chi = None this code runs, but with chi = 8 (so that compression is performed) it crashes with the following error message:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the bool function.
The error occurred while tracing the function for jit. This concrete value was not available in Python because it depends on the values of the argument 'arrays'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.optimize import TNOptimizer

chi = 8
#chi = None
L = 16
H = qu.ham_heis(L, sparse=True, cyclic=True)
gs = qu.groundstate(H)
target = qtn.Dense1D(gs)

bond_dim = 32
mps = qtn.MPS_rand_state(L, bond_dim, cyclic=True)

def normalize_state(psi):
    return psi / (psi.H @ psi) ** 0.5

def negative_overlap(psi, target):
    tn = psi.H & target
    innerProd = tn.contract(tags=..., max_bond=chi)  # <---- note the compression here
    return - innerProd**2

optmzr = TNOptimizer(
    mps,                               
    loss_fn=negative_overlap,
    norm_fn=normalize_state,
    loss_constants={'target': target},
    autodiff_backend='jax',
    optimizer='L-BFGS-B',
)

mps_opt = optmzr.optimize(100)

I find this particularly strange given that the code in Chapter 4.8 works just fine for me, despite it also making use of JAX autodifferentiation along with compression (max_bond is limited to 32 in compute_local_expectation()).

Do you have any idea what might be wrong?

And thanks very much for this excellent numerical library – I LOVE quimb!

Cheers

jcmgray (Owner) commented Oct 17, 2022

Hi @nikitn2, turning on compressed contraction through the usual (exact) contract interface is still quite experimental. I believe the issue is that when you supply max_bond it switches to the compressed contraction method, which has a non-zero default cutoff. With such a cutoff the generated shapes can be dynamic, which tracing libraries like jax don't like, so you just need to supply cutoff=0.0.
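
For reference, a minimal sketch of that change applied to the loss function in your snippet (same names and chi as above):

def negative_overlap(psi, target):
    tn = psi.H & target
    # an explicit zero cutoff keeps the compressed contraction's shapes
    # static, so jax can trace the function
    innerProd = tn.contract(tags=..., max_bond=chi, cutoff=0.0)
    return -innerProd**2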

The general problem here is the difficulty of choosing minimal default arguments/options for these advanced algorithms which don't hide the details. E.g. compressed contraction shouldn't be used with the current default exact path optimizer (optimize='greedy', etc.), but it would eventually be nice for it to be a simple switch...

Thanks for the kind words about quimb!

nikitn2 (Author) commented Oct 18, 2022

Hi @jcmgray,

Thanks for the reply – your explanation makes perfect sense. It's probably too much to expect an autodifferentiation framework to easily handle dynamically-sized arrays, and I can indeed see how difficult it is to integrate these considerations into a numerical library as advanced as yours while still keeping quimb simple to use. Dilemmas...!

Your cutoff=0.0 trick does indeed fix the issue in the example I provided; however, what about the other compression methods? I need to use tensor_network_apply_op_vec() or MatrixProductState.compress() in my code, and there your solution doesn't work for me. For example, running the code snippet below still results in the same error as before:

%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn
from quimb.tensor.optimize import TNOptimizer

chi = 8
L = 16

builder = qtn.SpinHam1D(S=1)
builder += 1/2, '+', '-'
builder += 1/2, '-', '+'
builder += 1, 'Z', 'Z'
H = builder.build_mpo(L)

bond_dim = 16
mps = qtn.MPS_rand_state(L, bond_dim, phys_dim=3, cyclic=False)

def normalize_state(psi):
    return psi / (psi.H @ psi) ** 0.5

def energy(psi, H):

    # Alternative 1: apply the operator with compression
    Hpsi = qtn.tensor_arbgeom.tensor_network_apply_op_vec(
        tn_op=H, tn_vec=psi, compress=True, max_bond=chi, cutoff=0.0)

    # Alternative 2: apply exactly, then compress the result
    # Hpsi = qtn.tensor_arbgeom.tensor_network_apply_op_vec(
    #     tn_op=H, tn_vec=psi, compress=False)
    # Hpsi.compress(max_bond=chi, cutoff=0.0)

    tn = psi.H & Hpsi
    return tn ^ ...

optmzr = TNOptimizer(
    mps,                               
    loss_fn=energy,
    norm_fn=normalize_state,
    loss_constants={'H': H},
    autodiff_backend='jax',
    optimizer='L-BFGS-B',
)

mps_opt = optmzr.optimize(100)

Is there perhaps a way for me to just globally turn off dynamic shaping of tensors when using autodiff?

nikitn2 (Author) commented Oct 20, 2022

Hi @jcmgray,

I've been playing around a bit with this issue in my code, and I find that JAX basically doesn't like it when the bond dimensions are in any way dynamically allocated.

I've found two fixes so far. In my example above, if you pass renorm=0 along with cutoff=0.0, then the autodiff will run; sadly, however, it also becomes numerically unstable... Furthermore, in my code in general I also observe that the line opts['max_bond'] = _MAX_BOND_LOOKUP.get(max_bond, max_bond) in tensor_core.py causes problems. Changing it to opts['max_bond'] = -1 if max_bond is None else max_bond seems to mesh better with JAX.

Do you think there is anything I can do to get JAX autodifferentiation to work with compression in a more stable manner?

jcmgray (Owner) commented Oct 20, 2022

Just a couple of things:

  • Is the 1D setup above your main use case? Generally compression is not needed for such cases (e.g. in the case above, contracting the MPO into the MPS and then compressing is more expensive than contracting the MPO expectation directly; see the sketch after this list). That is why renorm is indeed set to 0 for non-1D tensor networks in quimb: the standard way to contract these is with a fixed bond dimension (and since there is no canonical form you can't rely on the singular values as much for truncation).

  • Autodiff combined with compressed contraction is inherently unstable in many cases! That's because unless the approximate contraction is numerically close to the exact contraction, the optimizer learns to exploit the difference to produce unphysical results: for example, above you are really computing <psi|H'|psi'> where H' might not be hermitian and psi' might differ from psi, etc.
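
To illustrate the first point, a rough sketch of contracting the MPO expectation directly, with no compression at all (assuming quimb's expec_TN_1D helper; check the exact name and signature for your version):

def energy_exact(psi, H):
    # exactly contract <psi|H|psi> for an MPS psi and MPO H, with no
    # compression and hence no dynamic shapes for jax to trip over
    return qtn.expec_TN_1D(psi.H, H, psi)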

nikitn2 (Author) commented Oct 21, 2022

Hi @jcmgray,

Thanks very much for the reply! You're right, it does indeed make perfect sense that the above example should be numerically unstable.

The above example is not my main use case, though. My main use case involves applying operators O = O(|psi>) with very large bond dimensions to wavefunctions |psi> in order to compute a loss function of the type L = | O|psi> - |phi> |^2 = <O psi | O psi> + <phi | phi> - <O psi | phi> - <phi | O psi>. In my code, O is a tensor network operator (but I could contract it into MPO form), and |psi> and |phi> are either in MPS or tree tensor network form.
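
(For clarity, the same loss in LaTeX notation, writing $O^\dagger$ for the adjoint of $O$: $L = \| O|\psi\rangle - |\phi\rangle \|^2 = \langle\psi|O^\dagger O|\psi\rangle + \langle\phi|\phi\rangle - \langle\psi|O^\dagger|\phi\rangle - \langle\phi|O|\psi\rangle$.)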

My first idea was to contract O into an MPO, then use the zip-up algorithm to compute the |O psi> terms in a reasonably cheap manner, after which L can be calculated by just adding up overlaps of various TNs. However, JAX didn't like my zip-up implementation... It gave the error:

File "/Users/ngourianov/opt/anaconda3/lib/python3.9/site-packages/quimb/tensor/tensor_core.py", line 317, in _parse_split_opts
    opts['max_bond'] = _MAX_BOND_LOOKUP.get(max_bond, max_bond)

TypeError: unhashable type: 'DynamicJaxprTracer'

After applying the fix I mentioned in my previous post, I now get a Jax "ConcretizationTypeError", originating at the elif max_bond > 0: line in the _trim_and_renorm_SVD function. Not really sure how to proceed with that.

Giving up on this, I tried just computing the four terms of L using TensorNetwork.contract() with compression switched on, in the hope that it'd be able to compute L with the same complexity as if I used the zip-up algorithm. However, this also doesn't work. Even when I set cutoff=0.0, I get the error:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function <unknown> for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = lt b c
    from line /Users/nik/opt/anaconda3/lib/python3.9/site-packages/quimb/tensor/tensor_core.py:6149 (_compress_neighbors)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The only way I can compute and minimise L with JAX is to use TensorNetwork.contract() without any compression. But this isn't really scalable due to the large bond dimension of O...

Do you have any idea what I could do to proceed? I tried using tensorflow instead of JAX, and it actually worked with compression turned on, though it's almost two orders of magnitude slower than JAX...

EDIT:

I'm thinking of just restricting O to be an MPO and |psi> and |phi> to be MPSs. Presumably this should be simple enough for the compressed contraction to work. Would the compressed contraction algorithm in this case just automatically implement the zip-up algorithm?

jcmgray (Owner) commented Oct 21, 2022

I see, so it's a fitting task (quimb does actually have TN fitting functionality, but for full control you probably want to handle things manually).

After applying the fix I mentioned in my previous post, I now get a Jax "ConcretizationTypeError", originating at the elif max_bond > 0: line in the _trim_and_renorm_SVD function. Not really sure how to proceed with that.

This is saying that max_bond is being dynamically set somewhere, so the concrete value is not available when tracing (which is also why it's not hashable; changing away from the dict lookup probably does make sense too, however). I suppose you need to check the algorithm to make sure max_bond is always constant.

The only way I can compute and minimise L with JAX is to use TensorNetwork.contract() without any compression. But this isn't really scalable due to the large bond dimension of O...

Are you sure the complexity of even the zip-up algorithm is not similar to the exact contraction <phi|O|psi>? Since I guess you are not optimizing O itself (?), you don't need the constant O^2 overlap term. Or is it the memory overhead from back-propagation that is too much?

I'm thinking of just restricting O to be an MPO and |psi> and |phi> to be MPSs. Presumably this should be simple enough for the compressed contraction to work. Would the compressed contraction algorithm in this case just automatically implement the zip-up algorithm?

What the compressed contraction algorithm does depends entirely on the contraction path (the optimize kwarg). For 1D and tree-like TNs no compression is required, as you can contract exactly without increasing the intermediate sizes. My understanding of the 'apply MPO to MPS and compress' algorithms is that they are useful when you are applying several or many MPOs, i.e. you effectively have a 2D geometry.

That being said, if the MPO has a really large bond dimension, maybe that's different enough from '1D and tree like' to benefit from some compression somewhere; you could run a cotengra HyperCompressedOptimizer to search. But I would stress that this stuff is not 'officially' supported in quimb yet.
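
A rough sketch of what that search might look like (the exact keyword arguments here are assumptions on my part, so check the cotengra docs):

import cotengra as ctg

# hyper-optimizer that searches over compressed contraction paths,
# scoring them at a fixed target bond dimension
opt = ctg.HyperCompressedOptimizer(chi=chi)

# use it as the path optimizer for the compressed contraction
value = tn.contract(tags=..., max_bond=chi, cutoff=0.0, optimize=opt)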

nikitn2 (Author) commented Oct 22, 2022

Hi @jcmgray,

Thanks so much for your reply!

Are you sure the complexity of even the zip-up algorithm is not similar to the exact contraction <phi|O|psi>? Since I guess you are not optimizing O itself (?), you don't need the constant O^2 overlap term. Or is it the memory overhead from back-propagation that is too much?

It's a bit worse than that, I'm afraid, as I'm basically dealing with a highly nonlinear problem. In my case O is itself a sum of operators O_1, O_2, ..., with one of the operators, say O_N, having a very high bond dimension and depending on the variational function itself, O_N = O_N(|psi>). So I can't really use your fit function, unfortunately. And I also do think I need to calculate the overlap term < psi | O' O | psi > when calculating the loss function L, unless there's something I've missed?

Therefore, even if I represent |psi> as an MPS and O as an MPO, the resulting loss function L = < psi | O' O psi> + <phi|phi> - < psi | O' phi > - <phi| O psi> will still be a “2D geometry”, as you put it, which is why I'd like to use compressed contraction.

So to calculate L, it would be nice to first pre-compute |O psi> = O|psi> = O_1|psi> + ... + O_N|psi> using the zip-up algorithm, such that the bond dimension is kept in check. Do you have any ideas I could try to fix the ConcretizationTypeError when I try to use zip-up?

That being said, if the MPO has a really large bond dimension, maybe that's different enough from '1D and tree like' to benefit from some compression somewhere; you could run a cotengra HyperCompressedOptimizer to search. But I would stress that this stuff is not 'officially' supported in quimb yet.

I'll try to investigate this cotengra HyperCompressedOptimizer concept. Thank you for telling me about it.

Thanks again for your replies – you've no idea how much time they save for me, and for that I'm incredibly grateful :)

jcmgray (Owner) commented Oct 23, 2022

... And I also do think I need to calculate the overlap term < psi | O' O | psi > when calculating the loss function L, unless there's something I've missed?

I see, yes I just meant that if you were only interested in finding $\min_{\phi} \| |\phi\rangle - O|\psi\rangle \|$, then the $\langle \psi | O^{\dagger} O | \psi \rangle$ term is constant and doesn't figure in the optimization, though you might still want it for the actual value of $L$.

... Do you have any ideas I could try to fix the ConcretizationTypeError when I try to use zip-up?

Only that you need to find where your implementation of the algorithm calls tensor_split with cutoff != 0.0 or max_bond is None; possibly there is a call to MatrixProductState.compress or some such function. Personally I'd just set a breakpoint with the above condition to find it. You could also try torch, though eventual performance might not be as good.
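
Something along these lines, for example (a rough sketch of that breakpoint idea; it only intercepts calls that go through the module attribute, and the wrapper name is just for illustration):

import quimb.tensor.tensor_core as tc

_orig_tensor_split = tc.tensor_split

def checked_tensor_split(*args, **kwargs):
    # break whenever a split is requested with a non-zero (or default)
    # cutoff, or with no fixed max_bond, i.e. anything that could
    # produce dynamically sized outputs
    if kwargs.get('cutoff', None) != 0.0 or kwargs.get('max_bond', None) is None:
        breakpoint()
    return _orig_tensor_split(*args, **kwargs)

tc.tensor_split = checked_tensor_split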

nikitn2 (Author) commented Oct 25, 2022

Ah, I see now the confusion regarding the loss function: my variational function is Psi, not Phi! Sorry, I forgot to mention that.

So I need to always hard-set the bond dimension, or it might perform compression even if cutoff=0.0? That makes sense, actually.
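
I.e. always something like this in every compression call (a minimal sketch, reusing chi and Hpsi from my earlier snippet):

# a fixed max_bond plus an exactly-zero cutoff means no shape ever
# depends on the traced tensor values
Hpsi.compress(max_bond=chi, cutoff=0.0)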

Thanks a lot!
