`optax` and `tensorflow`'s Adam optimizer's setting. · Issue #571 · google-deepmind/optax · GitHub

Source: https://github.com/deepmind/optax/issues/571

Currently, `optax.scale_by_adam` should be equivalent to `torch.optim.Adam`. However, TensorFlow has a different implementation.

In short, if we change https://github.com/deepmind/optax/blob/cebdeff4a1922113a96c520e7a81b5bf79825b77/optax/_src/transform.py#L345-L348 to the following, optax's Adam would match TensorFlow's implementation:

```python
updates = jax.tree_util.tree_map(
    lambda m, v: (jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc)) * m / (jnp.sqrt(v + eps_root) + eps),
    mu, nu)
```
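A minimal, dependency-free sketch of what this one-line change does to a single update may help. It uses plain Python on one scalar parameter instead of JAX pytrees, and the helper names `moments`, `adam_update_optax`, and `adam_update_tf` are hypothetical. Both variants share the same moment recursions; they differ only in whether bias correction is applied before `eps` is added (current optax) or folded into a step-size factor with `eps` applied to the raw moments (TensorFlow style):

```python
import math


def moments(grads, b1=0.9, b2=0.999):
    """Run the shared first/second moment recursions over a gradient sequence."""
    m, v = 0.0, 0.0
    for g in grads:
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
    return m, v


def adam_update_optax(m, v, t, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """Current optax / PyTorch form: bias-correct m and v, then add eps."""
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return m_hat / (math.sqrt(v_hat + eps_root) + eps)


def adam_update_tf(m, v, t, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """TensorFlow form: one step-size factor; eps ("epsilon hat") on raw moments."""
    scale = math.sqrt(1 - b2**t) / (1 - b1**t)
    return scale * m / (math.sqrt(v + eps_root) + eps)


# Hypothetical example: three gradient steps on one scalar parameter.
grads = [0.5, -0.2, 0.3]
m, v = moments(grads)
t = len(grads)
```

With `eps=0` the two functions return the same value. With `eps > 0` and the same `eps` for both, the TensorFlow-style update is smaller in magnitude, and the two agree again if the TensorFlow variant is given `eps * math.sqrt(1 - b2**t)` instead.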
More context

Basically, PyTorch's and optax's Adam follow Algorithm 1 of Kingma and Ba's Adam paper (arXiv:1412.6980), whereas TensorFlow uses the formulation given just before Section 2.1 of the paper, and the epsilon TensorFlow refers to is the paper's "epsilon hat".
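The relationship between the two formulations can be made explicit. Writing $\alpha$ for the learning rate, $t$ for the step count, and taking `eps_root` as 0, Algorithm 1 bias-corrects both moments before adding $\epsilon$; multiplying the numerator and denominator by $\sqrt{1-\beta_2^t}$ gives the TensorFlow form:

```latex
\begin{aligned}
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t},\\
\Delta\theta_t^{\text{Alg. 1}}
  &= -\alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}
   = -\alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}\cdot
      \frac{m_t}{\sqrt{v_t}+\epsilon\sqrt{1-\beta_2^t}},\\
\Delta\theta_t^{\text{TF}}
  &= -\alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}\cdot
      \frac{m_t}{\sqrt{v_t}+\hat\epsilon}.
\end{aligned}
```

So the two updates coincide exactly when $\hat\epsilon = \epsilon\sqrt{1-\beta_2^t}$. Using the same numerical value for both constants means the TensorFlow variant behaves like Algorithm 1 with a step-dependent epsilon $\hat\epsilon/\sqrt{1-\beta_2^t}$, which is about $31.6\,\hat\epsilon$ at $t=1$ with $\beta_2 = 0.999$ and decays towards $\hat\epsilon$ as $t$ grows. The difference only matters when $\epsilon$ is not negligible next to $\sqrt{\hat v_t}$.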

This was relevant to my recent reproduction of OpenAI's work in https://github.com/openai/lm-human-preferences. Long story short, I ran an end-to-end experiment comparing torch's Adam (`adam_pt`) with a TensorFlow-style Adam (`adam_tf`). While the final performance (`objective/scores`) looks the same, the learning curves differ in a non-trivial way. For example, the torch Adam version had a much higher `clipfrac` initially, indicating a more significant initial update.

The "initial aggressive update" issue gets aggravated in larger models (e.g., gpt2-large). You can see that objective/kl had a spike with adam_tf, so this could be a reproducibility issue.

Desired solution

Include a TensorFlow-style Adam variant, for example:

```python
import jax
import jax.numpy as jnp
from optax import ScaleByAdamState, update_moment, update_moment_per_elem_norm
from optax._src.alias import _scale_by_learning_rate
from optax._src import base, utils, combine, numerics


def scale_by_adam_tf_style(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype = None,
) -> base.GradientTransformation:
  """Rescale updates according to the Adam algorithm.

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  WARNING: This is a TensorFlow-style Adam optimizer that uses the
    formulation just before Section 2.1 of the Kingma and Ba paper
    rather than the formulation in Algorithm 1; the "epsilon" referred
    to here is "epsilon hat" in the paper.

  Args:
    b1: Decay rate for the exponentially weighted average of grads.
    b2: Decay rate for the exponentially weighted average of squared grads.
    eps: Term added to the denominator to improve numerical stability
      ("epsilon hat" in the paper).
    eps_root: Term added to the denominator inside the square-root to improve
      numerical stability when backpropagating gradients through the rescaling.
    mu_dtype: Optional `dtype` to be used for the first order accumulator; if
      `None` then the `dtype` is inferred from `params` and `updates`.

  Returns:
    A `GradientTransformation` object.
  """

  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = jax.tree_util.tree_map(  # First moment
        lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
    nu = jax.tree_util.tree_map(jnp.zeros_like, params)  # Second moment
    return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

  def update_fn(updates, state, params=None):
    del params
    mu = update_moment(updates, state.mu, b1, 1)
    nu = update_moment_per_elem_norm(updates, state.nu, b2, 2)
    count_inc = numerics.safe_int32_increment(state.count)

    ### `optax` default adam implementation
    # mu_hat = bias_correction(mu, b1, count_inc)
    # nu_hat = bias_correction(nu, b2, count_inc)
    # updates = jax.tree_util.tree_map(
    #     lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
    ### TensorFlow-style adam implementation
    # Fold both bias corrections into a single step-size factor and apply
    # eps ("epsilon hat") to the uncorrected moments.
    step_scale = jnp.sqrt(1 - b2**count_inc) / (1 - b1**count_inc)
    updates = jax.tree_util.tree_map(
        lambda m, v: step_scale * m / (jnp.sqrt(v + eps_root) + eps), mu, nu)
    mu = utils.cast_tree(mu, mu_dtype)
    return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

  return base.GradientTransformation(init_fn, update_fn)


def adam_tf_style(
    learning_rate,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype = None,
):
  return combine.chain(
      scale_by_adam_tf_style(
          b1=b1, b2=b2, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype),
      _scale_by_learning_rate(learning_rate),
  )
```

Obviously the naming is not great, but I figure you'd have much better ideas :)
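For checking a TensorFlow-style Adam like the one above elementwise without JAX, a plain-Python reference that mirrors the `init_fn`/`update_fn` structure of `adam_tf_style` might be useful. This is a sketch under stated assumptions: the names `AdamTFState`, `init`, and `update` are hypothetical, pytrees are replaced by flat lists of floats, `mu_dtype` is ignored, and the final `-learning_rate` factor stands in for what `_scale_by_learning_rate` is assumed to do with a constant learning rate.

```python
import math
from dataclasses import dataclass, field


@dataclass
class AdamTFState:
    """Hypothetical stand-in for ScaleByAdamState: step count and both moments."""
    count: int = 0
    mu: list = field(default_factory=list)
    nu: list = field(default_factory=list)


def init(params):
    # Both moments start at zero, one entry per parameter (like init_fn).
    return AdamTFState(count=0, mu=[0.0] * len(params), nu=[0.0] * len(params))


def update(grads, state, learning_rate, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One TensorFlow-style Adam step, then scaling by -learning_rate (assumed)."""
    mu = [b1 * m + (1 - b1) * g for m, g in zip(state.mu, grads)]
    nu = [b2 * v + (1 - b2) * g * g for v, g in zip(state.nu, grads)]
    count = state.count + 1
    # Both bias corrections folded into one step-size factor; eps acts as
    # "epsilon hat" on the uncorrected moments.
    scale = math.sqrt(1 - b2**count) / (1 - b1**count)
    updates = [-learning_rate * scale * m / (math.sqrt(v + eps_root) + eps)
               for m, v in zip(mu, nu)]
    return updates, AdamTFState(count=count, mu=mu, nu=nu)


# Usage: minimize f(x) = x**2 starting from x = 1.
params = [1.0]
state = init(params)
for _ in range(200):
    grads = [2 * p for p in params]
    updates, state = update(grads, state, learning_rate=0.05)
    params = [p + u for p, u in zip(params, updates)]
```

On the first step with `eps=0`, each update is exactly `-learning_rate * sign(g)`, since the bias-correction factor cancels the `(1 - b1)` and `sqrt(1 - b2)` scaling of the raw moments; that makes it an easy value to compare against the JAX version.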