deeperwin.optimizers.OptaxWrapper

class deeperwin.optimizers.OptaxWrapper(value_and_grad_func: Callable[[...], Tuple[Union[jax._src.numpy.ndarray.ndarray, Tuple[jax._src.numpy.ndarray.ndarray, Any], Tuple[jax._src.numpy.ndarray.ndarray, Tuple[Any, Any]]], Any]], value_func_has_aux: bool, value_func_has_state: bool, value_func_has_rng: bool, optax_optimizer: optax._src.base.GradientTransformation, multi_device: bool = False, pmap_axis_name='devices', batch_process_func: Optional[Callable[[Any], Any]] = <function OptaxWrapper.<lambda>>)[source]

Bases: object

Wrapper class that gives Optax optimizers the same interface as the KFAC optimizer.

__init__(value_and_grad_func: Callable[[...], Tuple[Union[jax._src.numpy.ndarray.ndarray, Tuple[jax._src.numpy.ndarray.ndarray, Any], Tuple[jax._src.numpy.ndarray.ndarray, Tuple[Any, Any]]], Any]], value_func_has_aux: bool, value_func_has_state: bool, value_func_has_rng: bool, optax_optimizer: optax._src.base.GradientTransformation, multi_device: bool = False, pmap_axis_name='devices', batch_process_func: Optional[Callable[[Any], Any]] = <function OptaxWrapper.<lambda>>)[source]

Initializes the Optax wrapper.

Parameters
  • value_and_grad_func

    Python callable. The function should return the value of the loss to be optimized and its gradients. If the argument value_func_has_aux is False then the interface should be:

    loss, loss_grads = value_and_grad_func(params, batch)

    If value_func_has_aux is True then the interface should be:

    (loss, aux), loss_grads = value_and_grad_func(params, batch)

  • value_func_has_aux – Boolean. Specifies whether the provided callable value_and_grad_func returns the loss value only, or also some auxiliary data. (Default: False)

  • value_func_has_state – Boolean. Specifies whether the provided callable value_and_grad_func takes a persistent state as input and also outputs an updated version of it. (Default: False)

  • value_func_has_rng – Boolean. Specifies whether the provided callable value_and_grad_func additionally takes an rng key as input. (Default: False)

  • optax_optimizer – The optax optimizer to be wrapped.

  • batch_process_func – Callable. A function to be called on each batch before it is fed to the optimizer on device. This can be useful for device-specific input optimizations. (Default: lambda x: x)
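A value_and_grad_func with the interface described above can be obtained directly from jax.value_and_grad. The sketch below shows the value_func_has_aux=True case; the quadratic loss, parameter pytree, and batch shapes are illustrative assumptions, not part of deeperwin.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    """Illustrative loss: mean squared error of a linear model."""
    inputs, targets = batch
    preds = inputs @ params["w"]
    loss = jnp.mean((preds - targets) ** 2)
    aux = {"rmse": jnp.sqrt(loss)}  # auxiliary data returned alongside the loss
    return loss, aux

# has_aux=True yields the "(loss, aux), loss_grads" structure described above.
value_and_grad_func = jax.value_and_grad(loss_fn, has_aux=True)

params = {"w": jnp.ones((3,))}
batch = (jnp.ones((4, 3)), jnp.zeros((4,)))
(loss, aux), grads = value_and_grad_func(params, batch)
```

With value_func_has_aux=False, loss_fn would return the loss alone and the call would unpack as `loss, grads = value_and_grad_func(params, batch)`.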

Methods

__init__(value_and_grad_func, ...[, ...])

Initializes the Optax wrapper.

init(params, rng, batch[, func_state])

Initializes the optimizer and returns the appropriate optimizer state.

step(params, state, rng, batch[, func_state])

Performs a single optimization step, with an interface similar to the KFAC optimizer.

init(params: Any, rng: jax._src.numpy.ndarray.ndarray, batch: Any, func_state: Optional[Any] = None) Any[source]

Initializes the optimizer and returns the appropriate optimizer state.

_step(params: Any, state: Any, rng, batch: Any, func_state: Optional[Any] = None) Union[jax._src.numpy.ndarray.ndarray, Tuple[jax._src.numpy.ndarray.ndarray, Any], Tuple[jax._src.numpy.ndarray.ndarray, Tuple[Any, Any]]][source]

A single step of optax.

step(params: Any, state: Any, rng: jax._src.numpy.ndarray.ndarray, batch: Any, func_state: Optional[Any] = None) Union[Tuple[Any, Any, Any, Mapping[str, jax._src.numpy.ndarray.ndarray]], Tuple[Any, Any, Mapping[str, jax._src.numpy.ndarray.ndarray]]][source]

Performs a single optimization step, with an interface similar to the KFAC optimizer.