Return number of elements along a given axis.
JAX implementation of numpy.size(). Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.
a (ArrayLike | SupportsSize | SupportsShape) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.
axis (int | None) – optional integer along which to count elements. By default, return the total number of elements.
Returns an integer specifying the number of elements in a.
Examples
Size for arrays:
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
This also works for scalars:
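For instance, a Python float counts as a single element. This is a small illustrative sketch; the variable name is arbitrary:

```python
import jax.numpy as jnp

# A scalar is treated like a zero-dimensional array holding one element.
scalar_size = jnp.size(3.14)
print(scalar_size)  # 1
```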
For arrays, this can also be accessed via the jax.Array.size
property:
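A short sketch comparing the property with the function, using an illustrative 2x3 array:

```python
import jax.numpy as jnp

y = jnp.ones((2, 3))

# The property and the function report the same total element count.
print(y.size)       # 6
print(jnp.size(y))  # 6
```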