deeperwin.optimizers.OptaxWrapper
- class deeperwin.optimizers.OptaxWrapper(value_and_grad_func: Callable[[...], Tuple[Union[jax._src.numpy.ndarray.ndarray, Tuple[jax._src.numpy.ndarray.ndarray, Any], Tuple[jax._src.numpy.ndarray.ndarray, Tuple[Any, Any]]], Any]], value_func_has_aux: bool, value_func_has_state: bool, value_func_has_rng: bool, optax_optimizer: optax._src.base.GradientTransformation, multi_device: bool = False, pmap_axis_name='devices', batch_process_func: Optional[Callable[[Any], Any]] = <function OptaxWrapper.<lambda>>)
Bases: object
Wrapper class for Optax optimizers to have the same interface as KFAC.
- __init__(value_and_grad_func: Callable[[...], Tuple[Union[jax._src.numpy.ndarray.ndarray, Tuple[jax._src.numpy.ndarray.ndarray, Any], Tuple[jax._src.numpy.ndarray.ndarray, Tuple[Any, Any]]], Any]], value_func_has_aux: bool, value_func_has_state: bool, value_func_has_rng: bool, optax_optimizer: optax._src.base.GradientTransformation, multi_device: bool = False, pmap_axis_name='devices', batch_process_func: Optional[Callable[[Any], Any]] = <function OptaxWrapper.<lambda>>)
Initializes the Optax wrapper.
- Parameters
value_and_grad_func – Python callable. The function should return the value of the loss to be optimized and its gradients. If the argument value_func_has_aux is False, the interface should be:
loss, loss_grads = value_and_grad_func(params, batch)
If value_func_has_aux is True, the interface should be:
(loss, aux), loss_grads = value_and_grad_func(params, batch)
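For illustration, both interfaces can be produced with jax.value_and_grad. The sketch below assumes a toy least-squares loss on a params dictionary with hypothetical keys "w" and "b" and a batch given as an (x, y) tuple; none of these names come from deeperwin.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy least-squares loss; "w", "b" and the (x, y) batch layout are hypothetical.
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def loss_with_aux(params, batch):
    # Same loss, but also returns auxiliary data alongside it.
    x, y = batch
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    return loss, {"pred_mean": jnp.mean(pred)}


# value_func_has_aux=False:  loss, loss_grads = value_and_grad_func(params, batch)
value_and_grad_func = jax.value_and_grad(loss_fn)

# value_func_has_aux=True:  (loss, aux), loss_grads = value_and_grad_func(params, batch)
value_and_grad_func_aux = jax.value_and_grad(loss_with_aux, has_aux=True)

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
batch = (jnp.ones((3, 2)), jnp.ones(3))
loss, loss_grads = value_and_grad_func(params, batch)
(loss_a, aux), loss_grads_a = value_and_grad_func_aux(params, batch)
```

The gradients are returned with the same pytree structure as params, which is what an Optax optimizer expects as input to its update.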
value_func_has_aux – Boolean. Specifies whether the provided callable value_and_grad_func returns only the loss value or also some auxiliary data. (Required; this argument has no default value.)
value_func_has_state – Boolean. Specifies whether the provided callable value_and_grad_func takes a persistent state as input and also returns an updated version of it. (Required; this argument has no default value.)
value_func_has_rng – Boolean. Specifies whether the provided callable value_and_grad_func additionally takes an rng key as input. (Required; this argument has no default value.)
optax_optimizer – The optax optimizer to be wrapped.
batch_process_func – Callable. A function to be called on each batch before it is fed to the optimizer on device. This can be useful for device-specific input optimizations. (Default: lambda x: x)
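As a hypothetical example of a non-default batch_process_func, the sketch below casts every array leaf of a batch pytree to float32 using jax.tree_util.tree_map. The function name and the batch layout are assumptions for illustration only; the documented default simply returns the batch unchanged.

```python
import jax
import jax.numpy as jnp


def cast_batch_to_float32(batch):
    # Hypothetical batch_process_func: cast every array leaf of the batch to float32.
    return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype=jnp.float32), batch)


# The documented default: pass the batch through unchanged.
identity_batch_process_func = lambda x: x

batch = {"x": jnp.arange(6).reshape(3, 2), "y": jnp.array([1, 2, 3])}
processed = cast_batch_to_float32(batch)
unchanged = identity_batch_process_func(batch)
```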
Methods
- __init__(value_and_grad_func, ...[, ...]): Initializes the Optax wrapper.
- init(params, rng, batch[, func_state]): Initializes the optimizer and returns the appropriate optimizer state.
- step(params, state, rng, batch[, func_state]): Performs one optimization step, with an interface similar to KFAC's.
- init(params: Any, rng: jax._src.numpy.ndarray.ndarray, batch: Any, func_state: Optional[Any] = None) → Any
Initializes the optimizer and returns the appropriate optimizer state.