deeperwin.mcmc.MCMCState

class deeperwin.mcmc.MCMCState(r: jax._src.numpy.lax_numpy.array, R: jax._src.numpy.lax_numpy.array, Z: jax._src.numpy.lax_numpy.array, log_psi_sqr: jax._src.numpy.lax_numpy.array = None, walker_age: jax._src.numpy.lax_numpy.array = None, rng_state: jax._src.numpy.lax_numpy.array = None, stepsize: jax._src.numpy.lax_numpy.array = DeviceArray(0.01, dtype=float32, weak_type=True), step_nr: jax._src.numpy.lax_numpy.array = DeviceArray(0, dtype=int32), acc_rate: jax._src.numpy.lax_numpy.array = DeviceArray(0., dtype=float32, weak_type=True))[source]

Bases: collections.abc.Mapping

Dataclass that holds an electronic configuration and the metadata required for MCMC.

__init__(r: jax._src.numpy.lax_numpy.array, R: jax._src.numpy.lax_numpy.array, Z: jax._src.numpy.lax_numpy.array, log_psi_sqr: jax._src.numpy.lax_numpy.array = None, walker_age: jax._src.numpy.lax_numpy.array = None, rng_state: jax._src.numpy.lax_numpy.array = None, stepsize: jax._src.numpy.lax_numpy.array = DeviceArray(0.01, dtype=float32, weak_type=True), step_nr: jax._src.numpy.lax_numpy.array = DeviceArray(0, dtype=int32), acc_rate: jax._src.numpy.lax_numpy.array = DeviceArray(0., dtype=float32, weak_type=True)) → None

Methods

__init__(r, R, Z[, log_psi_sqr, walker_age, ...])

build_batch(fixed_params)

from_tuple()

initialize_around_nuclei(n_walkers, ...)

merge_devices()

replace(**kwargs)

resize_or_init(mcmc_state, n_walkers, ...)

split_across_devices()

to_tuple()

Attributes

acc_rate

log_psi_sqr

rng_state

step_nr

stepsize

walker_age

r

R

Z