deeperwin.optimization.optimize_wavefunction

deeperwin.optimization.optimize_wavefunction(log_psi_squared, params, fixed_params, mcmc_state: deeperwin.mcmc.MCMCState, opt_config: deeperwin.configuration.OptimizationConfig, phys_config: deeperwin.configuration.PhysicalConfig, rng_seed: int, logger: Optional[deeperwin.loggers.DataLogger] = None, initial_opt_state=None, initial_clipping_state=None)

Minimizes the energy of the wavefunction defined by the callable log_psi_squared by adjusting the trainable parameters.
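The following is a minimal, self-contained sketch of the general technique described above: variational Monte Carlo energy minimization. It does not use or reproduce DeepErwin's implementation. It assumes a toy 1D harmonic oscillator with a hypothetical one-parameter trial wavefunction psi(x) = exp(-a x^2), plain Metropolis-Hastings sampling of |psi|^2, and plain gradient descent. The descent uses the standard estimator dE/da ≈ <(E_L - <E_L>) * d log|psi|^2 / da>. All function names here (metropolis_step, optimize_wavefunction_sketch, etc.) are hypothetical.

```python
import math
import random

# Hypothetical toy model (not DeepErwin's): 1D harmonic oscillator,
# H = -1/2 d^2/dx^2 + 1/2 x^2, trial wavefunction psi(x) = exp(-a x^2).


def log_psi_squared(a, x):
    """log |psi(x)|^2 for the trial wavefunction exp(-a x^2)."""
    return -2.0 * a * x * x


def grad_log_psi_squared(a, x):
    """Derivative of log |psi(x)|^2 with respect to the parameter a."""
    return -2.0 * x * x


def local_energy(a, x):
    """E_L(x) = (H psi)(x) / psi(x) for the toy Hamiltonian."""
    return a + x * x * (0.5 - 2.0 * a * a)


def metropolis_step(a, walkers, rng, step_size=1.0):
    """One Metropolis-Hastings sweep over all walkers, sampling |psi|^2."""
    new_walkers = []
    for x in walkers:
        x_new = x + rng.gauss(0.0, step_size)  # symmetric Gaussian proposal
        log_ratio = log_psi_squared(a, x_new) - log_psi_squared(a, x)
        # Accept with probability min(1, |psi(x_new)|^2 / |psi(x)|^2)
        if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
            new_walkers.append(x_new)
        else:
            new_walkers.append(x)
    return new_walkers


def optimize_wavefunction_sketch(a, n_walkers=300, n_epochs=100, lr=0.1,
                                 n_mcmc_steps=5, n_burn_in=100, rng_seed=0):
    """Minimize the variational energy over a by stochastic gradient descent."""
    rng = random.Random(rng_seed)
    walkers = [rng.gauss(0.0, 1.0) for _ in range(n_walkers)]
    for _ in range(n_burn_in):  # equilibrate the walkers before optimizing
        walkers = metropolis_step(a, walkers, rng)
    energies = []
    for _ in range(n_epochs):
        for _ in range(n_mcmc_steps):
            walkers = metropolis_step(a, walkers, rng)
        e_loc = [local_energy(a, x) for x in walkers]
        e_mean = sum(e_loc) / n_walkers
        # dE/da is estimated as <(E_L - <E_L>) * d log|psi|^2 / da>
        grad = sum((e - e_mean) * grad_log_psi_squared(a, x)
                   for e, x in zip(e_loc, walkers)) / n_walkers
        a -= lr * grad  # plain gradient-descent update
        energies.append(e_mean)
    # Final walker positions, optimized parameter, per-epoch mean energies
    return walkers, a, energies


walkers, a_opt, energies = optimize_wavefunction_sketch(1.0)
```

For this toy model the exact ground state is a = 0.5 with energy 0.5. Starting from a = 1.0, a_opt should therefore end up near 0.5. The returned tuple loosely mirrors the (MCMC state, optimized parameters, ...) structure returned by optimize_wavefunction.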

Parameters
  • log_psi_squared (callable) – A function representing the wavefunction model; it returns the logarithm of the squared wavefunction, log |psi|^2

  • params (dict) – Trainable parameters of the model defined by log_psi_squared

  • fixed_params (dict) – Fixed (non-trainable) parameters of the model defined by log_psi_squared

  • mcmc_state (MCMCState) – Initial state of the MCMC walkers

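  • opt_config (OptimizationConfig) – Optimization hyperparameters

  • phys_config (PhysicalConfig) – Configuration of the physical system being simulated

  • rng_seed (int) – Seed for the random number generator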

  • logger (DataLogger, optional) – Logger used to record information about the optimization process; defaults to None

  • initial_opt_state (optional) – Initial state of the optimizer; defaults to None

  • initial_clipping_state (optional) – Initial clipping state used during optimization; defaults to None

Returns

A tuple (mcmc_state, trainable_parameters, opt_state), where mcmc_state is the final MCMC state, trainable_parameters contains the optimized parameters, and opt_state is the final optimizer state.