Convert an array to a specified dtype.
JAX implementation of numpy.astype().
This is implemented via jax.lax.convert_element_type(), which may have slightly different behavior than numpy.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
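As a minimal sketch of the relationship described above, the snippet below casts the same array with the astype method and with jax.lax.convert_element_type() directly, then compares the results. The input values are illustrative and deliberately small and non-negative, so the cast stays within the range where the result is unambiguous. Out-of-range or non-finite inputs are exactly the cases where behavior may differ from NumPy, and they are not exercised here.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input: small, non-negative floats that fit easily in int32.
x = jnp.array([0.0, 0.5, 1.0, 2.75])

# Cast through the array method.
via_astype = x.astype(jnp.int32)

# Cast through the lower-level lax primitive directly.
via_lax = lax.convert_element_type(x, jnp.int32)

print(via_astype)
print(via_lax)
print(bool((via_astype == via_lax).all()))
```

For in-range values like these, both paths should agree and drop the fractional part.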
Returns an array with the same shape as x, containing values of the specified dtype.
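The following sketch illustrates the return value on a two-dimensional input: the shape is preserved and only the dtype changes. The array contents and the choice of float32 are arbitrary and used purely for illustration.

```python
import jax.numpy as jnp

# Illustrative 2-D integer array.
x = jnp.arange(6).reshape(2, 3)

# Convert to float32; the result is a new array with the same shape.
y = x.astype(jnp.float32)

print(x.shape, x.dtype)
print(y.shape, y.dtype)
```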
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0., 1., 2., 3.], dtype=float32)

>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)