Showing content from https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html below:

Normalization

Group normalization (arxiv.org/abs/1803.08494).

This op is similar to batch normalization, but statistics are shared across equally sized groups of channels rather than across the batch dimension. Group normalization therefore does not depend on the batch composition and does not need to maintain internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.

>>> from flax import nnx
>>> import jax
>>> import numpy as np
...
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': VariableState( # 6 (24 B)
    type=Param,
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': VariableState( # 6 (24 B)
    type=Param,
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
...
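>>> # With a single group, the statistics cover every non-batch axis, so the
>>> # result matches LayerNorm reduced over axes (1, 2, 3).
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)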
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
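
The grouping described above can be sketched in plain NumPy. This is a minimal illustration rather than the Flax implementation: the helper name `group_norm_reference` is hypothetical, the input is assumed to be channels-last with a leading batch axis, the `epsilon` default of `1e-6` is an assumption, and no learned scale or bias is applied.

```python
import numpy as np


def group_norm_reference(x, num_groups, epsilon=1e-6):
    """Normalize a channels-last array per example and per channel group.

    Hypothetical helper for illustration only; epsilon is an assumed default
    and the learned scale/bias parameters are omitted.
    """
    batch, channels = x.shape[0], x.shape[-1]
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    # Split the channel axis into (num_groups, channels_per_group) and
    # flatten all other non-batch axes into one axis.
    grouped = x.reshape(batch, -1, num_groups, channels // num_groups)
    # Statistics are taken per example and per group, never across the batch.
    mean = grouped.mean(axis=(1, 3), keepdims=True)
    var = grouped.var(axis=(1, 3), keepdims=True)
    normalized = (grouped - mean) / np.sqrt(var + epsilon)
    return normalized.reshape(x.shape)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 4, 5, 6)).astype(np.float32)
y = group_norm_reference(x, num_groups=3)

# Normalizing a single example alone gives the same values as normalizing
# it inside the full batch, since no statistic spans the batch axis.
y_single = group_norm_reference(x[:1], num_groups=3)
np.testing.assert_allclose(y_single, y[:1], atol=1e-5)
```

Because the mean and variance are computed separately for each example, the output for one example does not change when the rest of the batch changes, which is why no running statistics need to be stored.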
