deeperwin.mcmc.MCMCState
- class deeperwin.mcmc.MCMCState(r: jax._src.numpy.lax_numpy.array, R: jax._src.numpy.lax_numpy.array, Z: jax._src.numpy.lax_numpy.array, log_psi_sqr: jax._src.numpy.lax_numpy.array = None, walker_age: jax._src.numpy.lax_numpy.array = None, rng_state: jax._src.numpy.lax_numpy.array = None, stepsize: jax._src.numpy.lax_numpy.array = DeviceArray(0.01, dtype=float32, weak_type=True), step_nr: jax._src.numpy.lax_numpy.array = DeviceArray(0, dtype=int32), acc_rate: jax._src.numpy.lax_numpy.array = DeviceArray(0., dtype=float32, weak_type=True))
Bases: collections.abc.Mapping
Dataclass that holds an electronic configuration and the metadata required for MCMC sampling.
- __init__(r: jax._src.numpy.lax_numpy.array, R: jax._src.numpy.lax_numpy.array, Z: jax._src.numpy.lax_numpy.array, log_psi_sqr: jax._src.numpy.lax_numpy.array = None, walker_age: jax._src.numpy.lax_numpy.array = None, rng_state: jax._src.numpy.lax_numpy.array = None, stepsize: jax._src.numpy.lax_numpy.array = DeviceArray(0.01, dtype=float32, weak_type=True), step_nr: jax._src.numpy.lax_numpy.array = DeviceArray(0, dtype=int32), acc_rate: jax._src.numpy.lax_numpy.array = DeviceArray(0., dtype=float32, weak_type=True)) -> None
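To illustrate what a dataclass that also subclasses collections.abc.Mapping looks like, here is a minimal, hypothetical sketch. `MCMCStateSketch` is not the DeepErwin class: it uses plain Python values instead of JAX arrays and mirrors only the field names and defaults from the signature above. Its `replace`, `to_tuple` and `from_tuple` are written under the assumption that they behave like `dataclasses.replace` and a field-order round trip; the real methods may differ.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any, Iterator, Optional


@dataclasses.dataclass
class MCMCStateSketch(Mapping):
    """Hypothetical stand-in mirroring the documented fields of MCMCState."""
    r: Any
    R: Any
    Z: Any
    log_psi_sqr: Optional[Any] = None
    walker_age: Optional[Any] = None
    rng_state: Optional[Any] = None
    stepsize: float = 0.01
    step_nr: int = 0
    acc_rate: float = 0.0

    @classmethod
    def _field_names(cls) -> tuple:
        # Dataclass field names, in declaration order, serve as the Mapping keys.
        return tuple(f.name for f in dataclasses.fields(cls))

    # Minimal Mapping interface: __getitem__, __iter__ and __len__.
    def __getitem__(self, key: str) -> Any:
        if key not in self._field_names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._field_names())

    def __len__(self) -> int:
        return len(self._field_names())

    def replace(self, **kwargs: Any) -> "MCMCStateSketch":
        # Assumed semantics: return a new instance, leave self unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_tuple(self) -> tuple:
        # Assumed semantics: field values in declaration order.
        return tuple(getattr(self, name) for name in self._field_names())

    @classmethod
    def from_tuple(cls, values: tuple) -> "MCMCStateSketch":
        # Assumed inverse of to_tuple.
        return cls(*values)


state = MCMCStateSketch(r=[[0.0, 0.0, 0.1]], R=[[0.0, 0.0, 0.0]], Z=[1])
print(dict(state)["stepsize"])  # 0.01
new_state = state.replace(step_nr=1)
print(state.step_nr, new_state.step_nr)  # 0 1
```

Because the sketch implements `__getitem__`, `__iter__` and `__len__`, it can be converted with `dict(state)` or unpacked with `**state`; whether DeepErwin relies on this is not stated in this reference.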
Methods
__init__(r, R, Z[, log_psi_sqr, walker_age, ...])
build_batch(fixed_params)
from_tuple()
initialize_around_nuclei(n_walkers, ...)
merge_devices()
replace(**kwargs)
resize_or_init(mcmc_state, n_walkers, ...)
split_across_devices()
to_tuple()
Attributes
acc_rate
log_psi_sqr
rng_state
step_nr
stepsize
walker_age
r
R
Z
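The `split_across_devices` and `merge_devices` methods are listed without descriptions. A common pattern in multi-device JAX code (for example with `jax.pmap`) is to reshape the leading walker axis into a device axis and back again. The sketch below shows only that reshape, using NumPy and assuming a leading `n_walkers` axis that divides evenly by the device count. The function names are hypothetical, and which `MCMCState` fields actually carry a per-walker axis is not documented here.

```python
import numpy as np


def split_leading_axis(x: np.ndarray, n_devices: int) -> np.ndarray:
    """Hypothetical helper: (n_walkers, ...) -> (n_devices, n_walkers // n_devices, ...)."""
    n_walkers = x.shape[0]
    if n_walkers % n_devices != 0:
        raise ValueError(
            f"n_walkers={n_walkers} is not divisible by n_devices={n_devices}"
        )
    return x.reshape((n_devices, n_walkers // n_devices) + x.shape[1:])


def merge_leading_axis(x: np.ndarray) -> np.ndarray:
    """Hypothetical inverse: (n_devices, k, ...) -> (n_devices * k, ...)."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


# 8 walkers, 4 electrons, 3 spatial coordinates each.
r = np.zeros((8, 4, 3))
r_split = split_leading_axis(r, n_devices=2)
print(r_split.shape)  # (2, 4, 4, 3)
print(merge_leading_axis(r_split).shape)  # (8, 4, 3)
```

Because both functions use C-order reshapes, merging after splitting restores the original walker order.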