#!/usr/bin/env python3
"""
CLI to process a single molecule.
"""
import logging
import os
import sys
from deeperwin.configuration import Configuration
import ruamel.yaml as yaml
def process_molecule(config_file):
    """Run the full DeepErwin pipeline for a single molecule.

    Loads a YAML config, sets jax-related environment variables *before*
    importing jax, builds the wavefunction model, and then runs up to three
    stages: supervised orbital pre-training, variational optimization, and
    evaluation. Results and checkpoints are emitted through the configured
    loggers (only on process 0 in a multi-host run).

    Args:
        config_file: Path to a YAML file parseable into a ``Configuration``.

    Raises:
        ValueError: If ``computation.require_gpu`` is set but only a CPU
            backend is available.
    """
    with open(config_file, "r") as f:
        raw_config = yaml.YAML().load(f)
    config: Configuration = Configuration.parse_obj(raw_config)

    # Set environment variables to control jax behaviour BEFORE importing jax;
    # they have no effect once the jax backend is initialized.
    if config.computation.disable_tensor_cores:
        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
    if config.computation.force_device_count and config.computation.n_local_devices:
        os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={config.computation.n_local_devices}'
    from jax.config import config as jax_config
    jax_config.update("jax_enable_x64", config.computation.float_precision == "float64")
    import jax
    from deeperwin.utils import getCodeVersion, merge_params, init_multi_host_on_slurm, replicate_across_devices
    if config.computation.n_nodes > 1:
        init_multi_host_on_slurm()

    # These imports can only take place after we have set the jax_config options
    import chex
    if config.computation.disable_jit:
        chex.fake_pmap_and_jit().start()
    from jax.lib import xla_bridge
    from deeperwin.model import build_log_psi_squared
    from deeperwin.optimization import optimize_wavefunction, evaluate_wavefunction, pretrain_orbitals
    from deeperwin.checkpoints import load_data_for_reuse, delete_obsolete_checkpoints
    from deeperwin.loggers import LoggerCollection, build_dpe_root_logger
    import haiku as hk
    import numpy as np

    if not config.computation.n_local_devices:
        config.computation.n_local_devices = jax.local_device_count()
    else:
        assert jax.local_device_count() == config.computation.n_local_devices
    used_hardware = xla_bridge.get_backend().platform
    if config.computation.require_gpu and (used_hardware == "cpu"):
        raise ValueError("Required GPU, but no GPU available: Aborting.")

    # Initialize loggers for logging debug/info messages
    logger = build_dpe_root_logger(config.logging.basic)
    if jax.process_index() > 0:
        # In multi-host runs only process 0 logs; silence the others entirely.
        logger.debug(f"DeepErwin: Disabling logging for process {jax.process_index()}")
        logging.disable()

    if config.computation.rng_seed is None:
        # Draw a seed on one host and broadcast it so all devices/hosts agree.
        rng_seed = np.random.randint(2**31, size=())
        config.computation.rng_seed = int(replicate_across_devices(rng_seed)[0])
    rng_seed = config.computation.rng_seed
    np.random.seed(rng_seed)

    logger.debug(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    logger.debug(f"Used hardware: {used_hardware}; Local device count: {jax.local_device_count()}; Global device count: {jax.device_count()}")

    # When reusing/restarting an old run: Merge configs and load data
    if config.reuse is not None:
        config, params_to_reuse, fixed_params, mcmc_state, opt_state, clipping_state = load_data_for_reuse(config, raw_config)
    else:
        params_to_reuse, fixed_params, mcmc_state, opt_state, clipping_state = None, None, None, None, None

    # Build wavefunction and initialize parameters
    log_psi_squared, orbital_func, params, fixed_params = build_log_psi_squared(config.model, config.physical, fixed_params, rng_seed)
    if params_to_reuse:
        params = merge_params(params, params_to_reuse)

    # Log config and metadata of run (process 0 only; other processes get no loggers)
    if jax.process_index() == 0:
        loggers = LoggerCollection(config.logging, config.experiment_name)
        loggers.on_run_begin()
        if config.logging.wandb:
            import wandb
            # Persist the wandb run id into the config so restarts can resume the run.
            config.logging.wandb.id = wandb.run.id
        loggers.log_config(config)
        loggers.log_tags(config.logging.tags)
        loggers.log_param("code_version", getCodeVersion())
        loggers.log_param("n_params", hk.data_structures.tree_size(params))
        if "baseline_energies" in fixed_params:
            loggers.log_metrics(fixed_params["baseline_energies"])
        config.save("full_config.yml")
    else:
        loggers = None

    # STEP 1: Supervised pre-training of wavefunction orbitals
    if config.pre_training and config.pre_training.n_epochs > 0:
        logger.info("Starting pre-training of orbitals...")
        params, _, mcmc_state = pretrain_orbitals(
            orbital_func, mcmc_state, params, fixed_params, config.pre_training, config.physical, config.model,
            rng_seed, loggers,
        )

    # STEP 2: Unsupervised variational wavefunction optimization
    if config.optimization.n_epochs > 0:
        logger.info("Starting optimization...")
        mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
            log_psi_squared,
            params,
            fixed_params,
            mcmc_state,
            config.optimization,
            config.physical,
            rng_seed,
            loggers,
            opt_state,
            clipping_state,
        )

    # STEP 3: Wavefunction evaluation
    if config.evaluation.n_epochs > 0:
        logger.info("Starting evaluation...")
        eval_history, mcmc_state = evaluate_wavefunction(
            log_psi_squared,
            params,
            fixed_params,
            mcmc_state,
            config.evaluation,
            config.physical,
            rng_seed,
            loggers,
            config.optimization.n_epochs_total,
        )

    # Final checkpoint + cleanup, again only on process 0 (loggers is None elsewhere)
    if jax.process_index() == 0:
        loggers.log_checkpoint(config.optimization.n_epochs_total, params, fixed_params, mcmc_state, opt_state, clipping_state)
        delete_obsolete_checkpoints(config.optimization.n_epochs_total, config.optimization.checkpoints)
        loggers.on_run_end()


if __name__ == '__main__':
    process_molecule(sys.argv[1])