Group normalization (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are computed over equally sized groups of channels within each example and are not shared across the batch dimension. As a result, group normalization does not depend on the batch composition and does not need to maintain internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.
Example usage:

    >>> from flax import nnx
    >>> import jax
    >>> import numpy as np
    ...
    >>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
    >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
    >>> nnx.state(layer)
    State({
      'bias': VariableState( # 6 (24 B)
        type=Param,
        value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
      ),
      'scale': VariableState( # 6 (24 B)
        type=Param,
        value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
      )
    })
    >>> y = layer(x)
    ...
    >>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
    >>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
    >>> np.testing.assert_allclose(y, y2)
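To make the grouping concrete, the normalization described above can be sketched by hand in plain JAX. This is an illustrative reference only, not the library's implementation: the function name `group_norm_reference` is hypothetical, the `epsilon` default of `1e-6` is an assumption, the input is assumed to be channels-last, and the learned `scale` and `bias` are omitted (they start at 1 and 0, as the state printed above shows).

```python
import jax
import jax.numpy as jnp


def group_norm_reference(x, num_groups, epsilon=1e-6):
    """Hypothetical reference: normalize a channels-last array of shape
    (batch, ..., channels) per example and per channel group."""
    channels = x.shape[-1]
    if channels % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    group_size = channels // num_groups

    # Split the channel axis into (num_groups, group_size).
    grouped = x.reshape(x.shape[:-1] + (num_groups, group_size))

    # Reduce over all axes except the batch axis (0) and the group axis (-2),
    # so statistics are never shared across examples in the batch.
    reduce_axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = grouped.mean(axis=reduce_axes, keepdims=True)
    var = grouped.var(axis=reduce_axes, keepdims=True)

    normalized = (grouped - mean) / jnp.sqrt(var + epsilon)
    return normalized.reshape(x.shape)


x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
y = group_norm_reference(x, num_groups=3)
```

Because the reduction never touches axis 0, normalizing a single example on its own gives the same values as normalizing it as part of a larger batch, which is the batch independence the description refers to.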