Remove one or more length-1 axes from an array.
JAX implementation of numpy.squeeze(), implemented via jax.lax.squeeze().
Returns: copy of a with length-1 axes removed.
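The description above says jnp.squeeze is built on jax.lax.squeeze, so a minimal side-by-side sketch may help. It assumes the axes to remove are known up front: lax.squeeze requires them to be listed explicitly in its `dimensions` argument, whereas jnp.squeeze also accepts axis=None to remove every length-1 axis.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3).reshape(3, 1, 1)  # shape (3, 1, 1)

# NumPy-style API: axes given via `axis` (None would mean "all length-1 axes").
a = jnp.squeeze(x, axis=(1, 2))

# Lower-level primitive: the axes to remove must always be listed explicitly.
b = lax.squeeze(x, dimensions=(1, 2))

print(a.shape, b.shape)  # (3,) (3,)
```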
Notes
Unlike numpy.squeeze(), jax.numpy.squeeze() returns a copy rather than a view of the input array. Because JAX arrays are immutable, this difference cannot be observed through in-place mutation. Under JIT, the compiler will also optimize away such copies when possible, so this doesn't have performance impacts in practice.
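The copy-versus-view difference can be illustrated side by side. This is a sketch assuming both NumPy and JAX are installed: it shows NumPy's view semantics (writing through the squeezed result changes the original) and JAX's immutable update style via `.at[...].set(...)`. It does not inspect whether XLA actually elides the copy under jax.jit; that part is only noted in a comment.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: squeeze returns a view, so writing through it mutates the original.
a_np = np.zeros((3, 1, 1))
view = np.squeeze(a_np)
view[0] = 5.0
print(a_np[0, 0, 0])  # 5.0: the original array changed

# JAX: arrays are immutable; an "update" builds a new array instead.
x = jnp.zeros((3, 1, 1))
y = jnp.squeeze(x)
y_updated = y.at[0].set(5.0)
print(x[0, 0, 0], y_updated[0])  # x is unchanged; only y_updated holds 5.0

# Under jax.jit the squeeze is traced into the compiled computation, where
# XLA is free to treat it as a reshape rather than a materialized copy.
double_squeezed = jax.jit(lambda arr: jnp.squeeze(arr) * 2)
print(double_squeezed(x).shape)  # (3,)
```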
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)
Squeeze all length-1 dimensions:
>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)
Equivalently, specifying the axes explicitly:
>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)
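The axis argument can also be a single integer, which removes just that one axis, and negative values count from the end as in NumPy. A minimal sketch using the same x:

```python
import jax.numpy as jnp

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)

# A single int removes only that axis, leaving the other length-1 axis.
only_axis_1 = jnp.squeeze(x, axis=1)
# Negative axes count from the end: -1 refers to axis 2 here.
only_last = jnp.squeeze(x, axis=-1)

print(only_axis_1.shape, only_last.shape)  # (3, 1) (3, 1)
```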
Attempting to squeeze a non-unit axis results in an error:
>>> jnp.squeeze(x, axis=0)
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
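If you want a lenient variant that skips non-unit axes instead of raising, one option is a small wrapper that filters the requested axes by length first. `squeeze_unit_axes` below is a hypothetical name for illustration, not a JAX function.

```python
import jax.numpy as jnp

def squeeze_unit_axes(x, axes):
    """Hypothetical helper, not part of JAX: remove only those of `axes`
    whose length is 1, and leave any other listed axes in place."""
    # Normalize negative axes and drop duplicates before squeezing.
    unit = sorted({a % x.ndim for a in axes if x.shape[a] == 1})
    if not unit:
        return x  # nothing to remove
    return jnp.squeeze(x, axis=tuple(unit))

x = jnp.array([[[0]], [[1]], [[2]]])  # shape (3, 1, 1)
print(squeeze_unit_axes(x, (0, 1)).shape)  # (3, 1): axis 0 is skipped
```

Silently skipping axes trades the built-in error for convenience; when a non-unit axis indicates a shape bug, the strict behavior of jnp.squeeze is usually the safer choice.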
For convenience, this functionality is also available via the jax.Array.squeeze() method:
>>> x.squeeze()
Array([0, 1, 2], dtype=int32)