Error with VMC_SRt #1642
-
Hello! First of all, thank you to all of the NetKet developers. What an amazing tool! I'm trying to get VMC_SRt working (amazing that it has been added!). Unfortunately, it currently raises an error when I run the driver, while standard SR works just fine. The model is built on a discrete spin Hilbert space, with complex parameters and complex output. Using

```python
gs = nk.driver.VMC(H, optimizer, variational_state=vstate, preconditioner=nk.optimizer.SR(diag_shift=0.1, holomorphic=False), holomorphic=False)
```

the model trains and converges with no errors (and reproduces known results from the literature). Changing this to

```python
import netket.experimental as nkx

gs = nkx.driver.VMC_SRt(H, optimizer, diag_shift=0.01, variational_state=vstate, jacobian_mode="complex")
```

leads to the following error:

```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-33-d2c8dffbbffd> in <cell line: 3>()
1 log = nk.logging.RuntimeLog()
2
----> 3 gs.run(n_iter=100, out=log)
4
5 ffn_energy = vstate.expect(H)
/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py in run(self, n_iter, out, obs, show_progress, save_params_every, write_every, step_size, callback)
257 first_step = True
258
--> 259 for step in self.iter(n_iter, step_size):
260 log_data = self.estimate(obs)
261 self._log_additional_data(log_data, step)
/usr/local/lib/python3.10/dist-packages/netket/driver/abstract_variational_driver.py in iter(self, n_steps, step)
167 for _ in range(0, n_steps, step):
168 for i in range(0, step):
--> 169 dp = self._forward_and_backward()
170 if i == 0:
171 yield self.step_count
/usr/local/lib/python3.10/dist-packages/netket/experimental/driver/vmc_srt.py in _forward_and_backward(self)
244 )
245
--> 246 self._dp = self._unravel_params_fn(updates)
247
248 return self._dp
[... skipping hidden 13 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in unravel_pytree(treedef, unravel_list, flat)
51
52 def unravel_pytree(treedef, unravel_list, flat):
---> 53 return tree_unflatten(treedef, unravel_list(flat))
54
55 def _ravel_list(lst):
[... skipping hidden 1 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in _unravel_list_single_dtype(indices, shapes, arr)
76 def _unravel_list_single_dtype(indices, shapes, arr):
77 chunks = jnp.split(arr, indices[:-1])
---> 78 return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
79
80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
/usr/local/lib/python3.10/dist-packages/jax/_src/flatten_util.py in <listcomp>(.0)
76 def _unravel_list_single_dtype(indices, shapes, arr):
77 chunks = jnp.split(arr, indices[:-1])
---> 78 return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
79
80 def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in meth(self, *args, **kwargs)
731 def _forward_method_to_aval(name):
732 def meth(self, *args, **kwargs):
--> 733 return getattr(self.aval, name).fun(self, *args, **kwargs)
734 return meth
735
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in _reshape(a, order, *args)
141 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
142 if order == "C":
--> 143 return lax.reshape(a, newshape, None)
144 elif order == "F":
145 dims = list(range(a.ndim)[::-1])
[... skipping hidden 8 frame]
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in _reshape_shape_rule(operand, new_sizes, dimensions)
3376 not math.prod(np.shape(operand)) == math.prod(new_sizes)):
3377 msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
-> 3378 raise TypeError(msg.format(new_sizes, np.shape(operand)))
3379 if dimensions is not None:
3380 if set(dimensions) != set(range(np.ndim(operand))):
TypeError: reshape total size must be unchanged, got new_sizes (2, 3, 7) for shape (84,).
```

From my (amateur) reading of the trace, it looks like something is going wrong with the chunking? The model itself is rather complicated, so if this isn't a problem with VMC_SRt itself, I'll put together a minimal working example and edit my question. I'm just a little confused as to why it would work with standard VMC but not with VMC_SRt, since my understanding is that the difference is just in how one computes a certain matrix. Thank you in advance!
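The intuition that SR and SRt differ only in which matrix gets inverted can be made concrete with the push-through identity `(X^H X + λ I)^{-1} X^H = X^H (X X^H + λ I)^{-1}`: SR solves a system whose size is the number of parameters, while the SRt formulation solves one whose size is the number of samples. The following is a minimal sketch with random matrices, assuming `X` stands in for the (centered, rescaled) Jacobian of the log-amplitudes and `eps` for the corresponding local-energy vector; the functions `sr_update` and `srt_update` are hypothetical illustrations, not NetKet's implementation, and the exact normalization NetKet uses is not reproduced here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for a tight comparison


def sr_update(X, eps, lam):
    """Parameter-space solve: (X^H X + lam*I)^{-1} X^H eps, an n_params x n_params system."""
    n_params = X.shape[1]
    S = X.conj().T @ X + lam * jnp.eye(n_params, dtype=X.dtype)
    return jnp.linalg.solve(S, X.conj().T @ eps)


def srt_update(X, eps, lam):
    """Sample-space solve: X^H (X X^H + lam*I)^{-1} eps, an n_samples x n_samples system."""
    n_samples = X.shape[0]
    T = X @ X.conj().T + lam * jnp.eye(n_samples, dtype=X.dtype)
    return X.conj().T @ jnp.linalg.solve(T, eps)


# Random complex stand-ins (hypothetical data) with fewer samples than parameters.
n_samples, n_params, lam = 8, 20, 0.1
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
X = jax.random.normal(k1, (n_samples, n_params)) + 1j * jax.random.normal(k2, (n_samples, n_params))
eps = jax.random.normal(k3, (n_samples,)) + 1j * jax.random.normal(k4, (n_samples,))

sr = sr_update(X, eps, lam)
srt = srt_update(X, eps, lam)
print("max |SR - SRt| =", float(jnp.max(jnp.abs(sr - srt))))
```

Both solves give the same update up to floating-point error; the SRt form is cheaper when the number of samples is much smaller than the number of parameters, which is the regime it targets.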
-
Ah, thanks for the report! Indeed, we were only testing SRt with real parameters, and complex parameters are what break it. Let me see if I can push a hotfix very quickly...
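The numbers in the error message are consistent with that diagnosis: 2 × 3 × 7 = 42, and 84 = 2 × 42, which is what you would get if a complex parameter of shape (2, 3, 7) were turned into a real vector (all real parts, then all imaginary parts) and that real vector were then unraveled as if it were still complex. The sketch below reproduces the same mismatch with `jax.flatten_util.ravel_pytree`. It assumes this real/imaginary stacking is what happened inside VMC_SRt, which the thread does not confirm, and `split_real_imag` / `merge_real_imag` are hypothetical helpers, not NetKet functions.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def split_real_imag(z):
    """Complex vector of length n -> real vector of length 2n (real parts, then imaginary parts)."""
    return jnp.concatenate([z.real, z.imag])


def merge_real_imag(x):
    """Inverse of split_real_imag: real vector of length 2n -> complex vector of length n."""
    n = x.shape[0] // 2
    return x[:n] + 1j * x[n:]


# Hypothetical complex parameter with the shape from the error message.
params = {"W": (jnp.arange(42.0).reshape(2, 3, 7) * (1 + 2j)).astype(jnp.complex64)}

flat, unravel = ravel_pytree(params)  # flat: 42 complex entries
updates = split_real_imag(flat)       # updates: 84 real entries

# Unraveling the 84-long real vector directly hits the same size mismatch.
try:
    unravel(updates)
    unravel_failed = False
except (TypeError, ValueError) as err:
    unravel_failed = True
    print("unravel failed:", err)

# Recombining the real and imaginary halves first restores the original pytree.
restored = unravel(merge_real_imag(updates))
print(restored["W"].shape, bool(jnp.allclose(restored["W"], params["W"])))
```

Recombining the halves before unraveling is one way such a code path can support complex parameters; the actual fix shipped in NetKet v3.10.1 may be implemented differently.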
@aashmore thanks again for the report. I just tagged NetKet v3.10.1, which should fix the issue. Let us know if it doesn't.