Source code for deeperwin.process_molecule

#!/usr/bin/env python3
"""
CLI to process a single molecule.
"""
import logging
import os
import sys
from deeperwin.configuration import Configuration
import ruamel.yaml as yaml

def process_molecule(config_file):
    with open(config_file, "r") as f:
        raw_config = yaml.YAML().load(f)
    config: Configuration = Configuration.parse_obj(raw_config)

    # Set environment variables to control jax behaviour before importing jax
    if config.computation.disable_tensor_cores:
        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
    if config.computation.force_device_count and config.computation.n_local_devices:
        os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={config.computation.n_local_devices}'
    from jax.config import config as jax_config
    jax_config.update("jax_enable_x64", config.computation.float_precision == "float64")
    import jax
    from deeperwin.utils import getCodeVersion, merge_params, init_multi_host_on_slurm, replicate_across_devices

    if config.computation.n_nodes > 1:
        init_multi_host_on_slurm()

    # These imports can only take place after we have set the jax_config options
    import chex
    if config.computation.disable_jit:
        chex.fake_pmap_and_jit().start()
    from jax.lib import xla_bridge
    from deeperwin.model import build_log_psi_squared
    from deeperwin.optimization import optimize_wavefunction, evaluate_wavefunction, pretrain_orbitals
    from deeperwin.checkpoints import load_data_for_reuse, delete_obsolete_checkpoints
    from deeperwin.loggers import LoggerCollection, build_dpe_root_logger
    import haiku as hk
    import numpy as np

    if not config.computation.n_local_devices:
        config.computation.n_local_devices = jax.local_device_count()
    else:
        assert jax.local_device_count() == config.computation.n_local_devices
    used_hardware = xla_bridge.get_backend().platform
    if config.computation.require_gpu and (used_hardware == "cpu"):
        raise ValueError("Required GPU, but no GPU available: Aborting.")

    # Initialize loggers for logging debug/info messages
    logger = build_dpe_root_logger(config.logging.basic)
    if jax.process_index() > 0:
        logger.debug(f"DeepErwin: Disabling logging for process {jax.process_index()}")
        logging.disable()

    if config.computation.rng_seed is None:
        rng_seed = np.random.randint(2**31, size=())
        config.computation.rng_seed = int(replicate_across_devices(rng_seed)[0])
    rng_seed = config.computation.rng_seed
    np.random.seed(rng_seed)

    logger.debug(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    logger.debug(f"Used hardware: {used_hardware}; Local device count: {jax.local_device_count()}; Global device count: {jax.device_count()}")

    # When reusing/restarting an old run: merge configs and load data
    if config.reuse is not None:
        config, params_to_reuse, fixed_params, mcmc_state, opt_state, clipping_state = load_data_for_reuse(config, raw_config)
    else:
        params_to_reuse, fixed_params, mcmc_state, opt_state, clipping_state = None, None, None, None, None

    # Build wavefunction and initialize parameters
    log_psi_squared, orbital_func, params, fixed_params = build_log_psi_squared(config.model, config.physical, fixed_params, rng_seed)
    if params_to_reuse:
        params = merge_params(params, params_to_reuse)

    # Log config and metadata of run (only on the main process)
    if jax.process_index() == 0:
        loggers = LoggerCollection(config.logging, config.experiment_name)
        loggers.on_run_begin()
        if config.logging.wandb:
            import wandb
            config.logging.wandb.id = wandb.run.id
        loggers.log_config(config)
        loggers.log_tags(config.logging.tags)
        loggers.log_param("code_version", getCodeVersion())
        loggers.log_param("n_params", hk.data_structures.tree_size(params))
        if "baseline_energies" in fixed_params:
            loggers.log_metrics(fixed_params["baseline_energies"])
        config.save("full_config.yml")
    else:
        loggers = None

    # STEP 1: Supervised pre-training of wavefunction orbitals
    if config.pre_training and config.pre_training.n_epochs > 0:
        logger.info("Starting pre-training of orbitals...")
        params, _, mcmc_state = pretrain_orbitals(
            orbital_func, mcmc_state, params, fixed_params, config.pre_training, config.physical, config.model, rng_seed, loggers,
        )

    # STEP 2: Unsupervised variational wavefunction optimization
    if config.optimization.n_epochs > 0:
        logger.info("Starting optimization...")
        mcmc_state, params, opt_state, clipping_state = optimize_wavefunction(
            log_psi_squared, params, fixed_params, mcmc_state, config.optimization, config.physical, rng_seed, loggers, opt_state, clipping_state,
        )

    # STEP 3: Wavefunction evaluation
    if config.evaluation.n_epochs > 0:
        logger.info("Starting evaluation...")
        eval_history, mcmc_state = evaluate_wavefunction(
            log_psi_squared, params, fixed_params, mcmc_state, config.evaluation, config.physical, rng_seed, loggers, config.optimization.n_epochs_total,
        )

    # Save the final checkpoint and finish logging (only on the main process)
    if jax.process_index() == 0:
        loggers.log_checkpoint(config.optimization.n_epochs_total, params, fixed_params, mcmc_state, opt_state, clipping_state)
        delete_obsolete_checkpoints(config.optimization.n_epochs_total, config.optimization.checkpoints)
        loggers.on_run_end()
if __name__ == '__main__':
    process_molecule(sys.argv[1])
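
Besides the CLI invocation above (which takes the configuration file path as its single argument), the same entry point can be called programmatically. The following is a minimal sketch under the assumption that a valid DeepErwin YAML configuration exists in the working directory; the filename "config.yml" is only illustrative.

# Minimal sketch: run the full pipeline from Python instead of the command line.
# "config.yml" is an illustrative path to a DeepErwin configuration file.
from deeperwin.process_molecule import process_molecule

process_molecule("config.yml")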