deeperwin.optimization.optimize_wavefunction
- deeperwin.optimization.optimize_wavefunction(log_psi_squared, params, fixed_params, mcmc_state: deeperwin.mcmc.MCMCState, opt_config: deeperwin.configuration.OptimizationConfig, phys_config: deeperwin.configuration.PhysicalConfig, rng_seed: int, logger: Optional[deeperwin.loggers.DataLogger] = None, initial_opt_state=None, initial_clipping_state=None)
Minimizes the energy of the wavefunction defined by the callable log_psi_squared by adjusting the trainable parameters.
- Parameters
log_psi_squared (callable) – A function returning the logarithm of the squared wavefunction, log|ψ|², of the model
params (dict) – Trainable parameters of the model defined by log_psi_squared
fixed_params (dict) – Fixed (non-trainable) parameters of the model defined by log_psi_squared
mcmc_state (MCMCState) – Initial state of the MCMC walkers
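opt_config (OptimizationConfig) – Optimization hyperparameters
phys_config (PhysicalConfig) – Configuration of the physical system being simulated
rng_seed (int) – Seed for the random number generator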
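logger (DataLogger, optional) – A logger used to record information about the optimization process. Defaults to None.
initial_opt_state (optional) – Initial optimizer state to start from. Defaults to None.
initial_clipping_state (optional) – Initial clipping state to start from. Defaults to None.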
- Returns
A tuple (mcmc_state, trainable_parameters, opt_state), where mcmc_state is the final MCMC state, trainable_parameters contains the optimized parameters, and opt_state is the final optimizer state.
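The energy minimization described above can be illustrated with a minimal, self-contained variational Monte Carlo sketch. It is not deeperwin's implementation: the toy model (a 1D harmonic oscillator with trial wavefunction ψ_a(x) = exp(−a x²)), the names optimize_wavefunction_sketch, toy_log_psi_squared, toy_grad_log_psi_squared, toy_local_energy and mcmc_step, and the plain gradient-descent update are all hypothetical stand-ins (in deeperwin the optimizer is selected through opt_config). Walkers are sampled from |ψ|² by Metropolis-Hastings, the gradient uses the standard estimator ∂E/∂θ ≈ mean over walkers of (E_L − ⟨E_L⟩)·∂ log|ψ|²/∂θ, and the function returns a tuple shaped like the one documented above.

```python
import math
import random


# Hypothetical toy model: 1D harmonic oscillator H = -1/2 d^2/dx^2 + 1/2 x^2
# with trial wavefunction psi_a(x) = exp(-a x^2); params = {"a": a}.
def toy_log_psi_squared(params, x):
    # log|psi|^2 = -2 a x^2
    return -2.0 * params["a"] * x * x


def toy_grad_log_psi_squared(params, x):
    # d log|psi|^2 / da = -2 x^2
    return {"a": -2.0 * x * x}


def toy_local_energy(params, x):
    # E_L = (H psi) / psi = a + x^2 (1/2 - 2 a^2); equals 0.5 everywhere when a = 0.5
    a = params["a"]
    return a + x * x * (0.5 - 2.0 * a * a)


def mcmc_step(log_psi_squared, params, walkers, rng, step_size=0.5):
    # One Metropolis-Hastings sweep sampling |psi|^2 with symmetric Gaussian proposals
    new_walkers = []
    for x in walkers:
        x_prop = x + step_size * rng.gauss(0.0, 1.0)
        log_ratio = log_psi_squared(params, x_prop) - log_psi_squared(params, x)
        # Accept with probability min(1, |psi(x_prop)|^2 / |psi(x)|^2)
        if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
            x = x_prop
        new_walkers.append(x)
    return new_walkers


def optimize_wavefunction_sketch(log_psi_squared, grad_log_psi_squared, local_energy,
                                 params, walkers, n_epochs, lr, rng_seed):
    rng = random.Random(rng_seed)
    params = dict(params)  # do not mutate the caller's dict
    n = len(walkers)
    for _ in range(n_epochs):
        walkers = mcmc_step(log_psi_squared, params, walkers, rng)
        e_loc = [local_energy(params, x) for x in walkers]
        e_mean = sum(e_loc) / n
        grads = [grad_log_psi_squared(params, x) for x in walkers]
        # Compute every gradient component before updating any parameter
        update = {}
        for key in params:
            # dE/dtheta ~= mean[(E_L - <E_L>) * d log|psi|^2 / dtheta]
            update[key] = sum((e - e_mean) * g[key] for e, g in zip(e_loc, grads)) / n
        for key, g in update.items():
            params[key] -= lr * g  # plain gradient descent
    opt_state = {"step": n_epochs}  # placeholder optimizer state
    return walkers, params, opt_state


# Usage: start away from the exact ground state a = 0.5
init_rng = random.Random(1)
walkers0 = [init_rng.gauss(0.0, 0.5) for _ in range(500)]
final_walkers, final_params, final_opt_state = optimize_wavefunction_sketch(
    toy_log_psi_squared, toy_grad_log_psi_squared, toy_local_energy,
    {"a": 1.0}, walkers0, n_epochs=200, lr=0.3, rng_seed=0)
print(final_params)
```

Starting from a = 1.0, the parameter should move toward the exact ground state a = 0.5, where the local energy is constant at 0.5 and the gradient estimate vanishes for every sample.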