
Differentiate MPS-type calculations #183

Answered by jcmgray
vincentmr asked this question in Q&A

I think your snippet is missing some imports needed to run it directly, but to answer your questions:

  1. It takes the gradient with respect to the tensors in the tensor network(s) supplied to it. In this case, the computational state and parametrized tensors have already been contracted into intermediates (i.e. the tensors in `circ`). With jax and its functional style, `jax.grad(f)(theta)`, where `theta` is any 'pytree', computes df/dtheta. So to get the gradient with respect to `params` (or `PTensor` instances), you need to write a function that takes them as its inputs.
  2. Yes, sadly, dynamic cutoffs will cause problems for jax and other libraries with a 'static' computational graph, since the tensor shapes are not known in advance. But t…
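As a toy illustration of point 1 in plain JAX (no quimb calls; `ry` and `expectation` are hypothetical names standing in for a parametrized circuit), the gate angles, held in a dict pytree, are the function's argument, so `jax.grad` returns gradients with the same dict structure:

```python
import jax
import jax.numpy as jnp


def ry(theta):
    # Single-qubit Y-rotation gate as a 2x2 real matrix
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


def expectation(params):
    # params is a pytree (a dict of scalar angles), so it is what jax differentiates.
    # Build |psi> = RY(b) RY(a) |0> and return <psi|Z|psi>.
    zero = jnp.array([1.0, 0.0])
    psi = ry(params["b"]) @ (ry(params["a"]) @ zero)
    z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    # Amplitudes are real here, so no complex conjugate is needed
    return psi @ (z @ psi)


params = {"a": 0.3, "b": 0.5}
# grads has the same structure as params: {"a": ..., "b": ...}
grads = jax.grad(expectation)(params)
```

Analytically the expectation is cos(a + b), so both entries of `grads` should be close to -sin(0.8). If the angles were instead baked into already-contracted tensors outside the function, there would be nothing left for `jax.grad` to differentiate them against.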
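A minimal sketch of why point 2 matters under `jax.jit`, using a plain truncated SVD rather than quimb's own splitting code (`MAX_BOND`, `truncated_split` and `cutoff_split` are hypothetical names). Fixing the maximum bond dimension ahead of time is a general workaround assumed here, not necessarily what the truncated answer goes on to suggest:

```python
import jax
import jax.numpy as jnp

MAX_BOND = 2  # hypothetical bond dimension, fixed before tracing


@jax.jit
def truncated_split(m):
    # Split m ~ left @ right, keeping exactly MAX_BOND singular values.
    # MAX_BOND is a Python constant, so output shapes are known at trace time.
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    left = u[:, :MAX_BOND] * s[:MAX_BOND]  # scale columns of U by singular values
    right = vh[:MAX_BOND, :]
    return left, right


@jax.jit
def cutoff_split(m, cutoff=1e-6):
    # Keep only singular values above a cutoff: the number kept depends on
    # the data, so it is a traced value and cannot be used as a slice bound.
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    keep = jnp.sum(s > cutoff)
    return u[:, :keep] * s[:keep], vh[:keep, :]


m = jnp.arange(16.0).reshape(4, 4)  # a rank-2 matrix
left, right = truncated_split(m)

try:
    cutoff_split(m)
    dynamic_ok = True
except Exception:
    # Tracing fails because the output shapes would depend on the values in m
    dynamic_ok = False
```

The fixed-size version traces once and its outputs always have shapes `(4, 2)` and `(2, 4)`; the cutoff version fails at trace time, which is the 'static graph' limitation described above.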

Replies: 2 comments 2 replies

1 reply
@vincentmr

Answer selected by vincentmr
1 reply
@jcmgray

Category
Q&A
Labels
None yet
3 participants