Return True if the input is a scalar.
JAX implementation of numpy.isscalar(). JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.
Parameters:
element (Any) – input object to check; any type is valid input.
Returns:
True if element is a scalar value or an array-like object with zero dimensions, False otherwise.
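The rule above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not JAX's actual implementation: the function name `isscalar_sketch` is hypothetical, it defers to numpy.isscalar() for non-array inputs, and it treats any object exposing an integer `ndim` and a `dtype` attribute as array-like (an assumption that happens to cover both NumPy and JAX arrays).

```python
import numpy as np


def isscalar_sketch(element):
    """Hypothetical helper approximating the documented jax.numpy.isscalar rule."""
    # Python numbers and NumPy scalar objects (e.g. np.int32(0)) are scalars.
    if np.isscalar(element):
        return True
    # Assumption: anything with an integer `ndim` and a `dtype` is array-like;
    # such objects count as scalars only when they have zero dimensions.
    ndim = getattr(element, "ndim", None)
    if isinstance(ndim, int) and hasattr(element, "dtype"):
        return ndim == 0
    # Lists, tuples, None, slices, etc. are not scalars.
    return False


print(isscalar_sketch(1.0), isscalar_sketch(np.array(0.0)), isscalar_sketch(np.array([1])))  # True True False
```

Edge cases not covered in this page (for example, other object types that numpy.isscalar() happens to accept) simply inherit NumPy's behavior in this sketch.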
Note
JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar() returns True for the former and False for the latter.
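The NumPy side of this distinction can be checked directly with standard NumPy calls. Indexing a zero-dimensional array with an empty tuple is one way to get back the corresponding scalar object.

```python
import numpy as np

scalar_obj = np.int32(0)  # NumPy scalar object
zero_d = np.array(0)      # zero-dimensional NumPy array

# The scalar object is an np.generic instance; the array is an ndarray with ndim 0.
print(isinstance(scalar_obj, np.generic), isinstance(zero_d, np.ndarray), zero_d.ndim)  # True True 0

# numpy.isscalar distinguishes the two even though both hold the value 0.
print(np.isscalar(scalar_obj), np.isscalar(zero_d))  # True False

# Indexing a zero-dimensional array with () returns a NumPy scalar object.
print(isinstance(zero_d[()], np.generic), np.isscalar(zero_d[()]))  # True True
```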
JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar() returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).
One reason for the different conventions in isscalar is to maintain JIT-invariance, i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar() are such that the result changes under JIT:
>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
By treating zero-dimensional arrays as scalars, jax.numpy.isscalar() avoids this issue:
>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
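The difference comes from what the function sees at trace time. In the sketch below (the function name `compare` is illustrative), the same Python function is called eagerly and under jax.jit(); under JIT its argument is a tracer standing in for a zero-dimensional array, so numpy.isscalar() no longer recognizes it while jax.numpy.isscalar() still does.

```python
import jax
import jax.numpy as jnp
import numpy as np


def compare(x):
    # Eagerly, x is the Python float passed in; under jit it is a tracer
    # representing a zero-dimensional float array.
    return np.isscalar(x), jnp.isscalar(x)


eager = compare(1.0)            # plain Python bools
jitted = jax.jit(compare)(1.0)  # bools returned as zero-dimensional boolean Arrays

print("numpy.isscalar:", eager[0], bool(jitted[0]))      # True False
print("jax.numpy.isscalar:", eager[1], bool(jitted[1]))  # True True
```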
Examples
In JAX, both scalars and zero-dimensional array-like objects are considered scalars:
>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True
Arrays with one or more dimensions are not considered scalars:
>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
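Because jax.numpy.isscalar() depends only on the number of dimensions, which is known at trace time, its result is a concrete Python bool even inside jax.jit() and can drive ordinary Python control flow. The helper `scale` below is hypothetical and only illustrates this pattern:

```python
import jax
import jax.numpy as jnp


def scale(x, factor):
    """Hypothetical helper: multiply x by a scalar factor, rejecting arrays."""
    # jnp.isscalar checks only the (static) number of dimensions, so this
    # branch is evaluated in Python during tracing and also works under jit.
    if not jnp.isscalar(factor):
        raise ValueError("factor must be a scalar or zero-dimensional array")
    return x * factor


x = jnp.arange(3.0)
print(scale(x, 2.0))                      # [0. 2. 4.]
print(jax.jit(scale)(x, jnp.array(2.0)))  # [0. 2. 4.]
try:
    scale(x, jnp.array([2.0]))            # one-dimensional factor is rejected
except ValueError as err:
    print("ValueError:", err)
```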
Compare this to numpy.isscalar(), which returns True for scalar-typed objects, and False for all arrays, even those with zero dimensions:
>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False
In JAX, as in NumPy, objects which are not array-like are not considered scalars:
>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False
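This is one reason a bare shape check is not a substitute for isscalar: numpy.ndim() falls back to converting its argument with numpy.asarray(), so non-array objects such as None or a slice become zero-dimensional object arrays. A small comparison:

```python
import jax.numpy as jnp
import numpy as np

for value in [None, slice(10), [1], ()]:
    # np.ndim wraps non-arrays in an array first; None and slice(10) become
    # zero-dimensional object arrays, so the naive check accepts them.
    naive = np.ndim(value) == 0
    print(f"{value!r:<22} ndim==0: {naive!s:<5}  jnp.isscalar: {jnp.isscalar(value)}")
```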