Calculate the element-wise square of the input array.
JAX implementation of numpy.square.
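Because jnp.square is the JAX counterpart of numpy.square, the two should agree on the same input. The sketch below assumes both numpy and jax are installed. It compares with a tolerance rather than exact equality, which is our own choice: results computed on an accelerator are not guaranteed to be bit-identical to NumPy's.

```python
import numpy as np
import jax.numpy as jnp

# float32 values, including a negative entry and a non-integer entry.
x_np = np.array([3, -2, 5.3, 1], dtype=np.float32)

ref = np.square(x_np)   # NumPy reference result
out = jnp.square(x_np)  # JAX result; a NumPy array is a valid ArrayLike input

# Compare with a tolerance rather than exact equality.
print(np.allclose(np.asarray(out), ref))  # True
```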
Parameters:
x (ArrayLike) – input array or scalar.
Returns:
An array containing the square of the elements of x.
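Since x may be a scalar or an array, a short sketch of both cases: a Python scalar produces a 0-d JAX array, and array input keeps its shape. The variable names are illustrative only.

```python
import jax.numpy as jnp

# A Python scalar is accepted; the result is a 0-d JAX array.
s = jnp.square(3)
print(s.shape, int(s))  # () 9

# For array input, the output has the same shape as the input.
m = jnp.arange(6).reshape(2, 3)
sq = jnp.square(m)
print(sq.shape)  # (2, 3)
print(sq)
```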
Note
jnp.square is equivalent to computing jnp.power(x, 2).
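The equivalence in the note can be checked directly. This sketch compares jnp.square with jnp.power(x, 2) and with plain multiplication x * x; using exact equality for integers and a tolerance for floats is our own choice, not something the reference prescribes.

```python
import jax.numpy as jnp

x = jnp.array([3, -2, 5, 1])            # integer input
xf = jnp.array([3.0, -2.0, 5.3, 1.0])   # floating-point input

# Integer results are exact, so compare for equality.
print(jnp.array_equal(jnp.square(x), jnp.power(x, 2)))  # True

# Floating-point results also match power(x, 2) and x * x.
print(jnp.allclose(jnp.square(xf), jnp.power(xf, 2)))   # True
print(jnp.allclose(jnp.square(xf), xf * xf))            # True
```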
See also
jax.numpy.sqrt(): Calculates the element-wise non-negative square root of the input array.
jax.numpy.power(): Calculates the element-wise power, raising the base x1 to the exponent x2.
jax.lax.integer_pow(): Computes the element-wise power \(x^y\), where \(y\) is a fixed integer.
jax.numpy.float_power(): Computes the first array raised to the power of the second array, element-wise, after promoting to an inexact dtype.
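A small sketch of how jnp.square relates to the functions listed above: sqrt recovers the magnitude but not the sign, lax.integer_pow with a fixed exponent of 2 gives the same values, and float_power returns a floating dtype where square keeps the integer dtype. The exact integer and float widths depend on whether 64-bit mode is enabled, so the sketch checks dtype kinds rather than specific widths.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3.0, -2.0, 5.3, 1.0])

# sqrt returns the non-negative root, so sqrt(square(x)) equals |x|, not x.
print(jnp.allclose(jnp.sqrt(jnp.square(x)), jnp.abs(x)))  # True

# lax.integer_pow takes the exponent as a fixed Python int.
print(jnp.allclose(jax.lax.integer_pow(x, 2), jnp.square(x)))  # True

# square keeps an integer dtype; float_power promotes to a floating dtype.
xi = jnp.array([2, 4, 5, 6])
print(jnp.square(xi).dtype)
print(jnp.float_power(xi, 2).dtype)
```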
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
For integer inputs:
>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)
For complex-valued inputs:
>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)
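For complex input, jnp.square multiplies each value by itself, so (1-3j)**2 = 1 - 6j + 9j**2 = -8-6j. This differs from the squared magnitude |z|**2, which is z times its conjugate. The sketch below contrasts the two; computing the magnitude via jnp.real(z * jnp.conj(z)) is our illustrative choice, and jnp.abs(z)**2 gives the same values up to rounding.

```python
import jax.numpy as jnp

x2 = jnp.array([1-3j, -1j, 2])

# square: z * z, which keeps the phase information.
print(jnp.square(x2))

# Squared magnitude: z * conj(z) is real, giving the values 10, 1, 4 here.
mag2 = jnp.real(x2 * jnp.conj(x2))
print(mag2)
```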