Load JAX arrays from npy files.
JAX wrapper of numpy.load().
This function is a simple wrapper of numpy.load(), but in the case of .npy files created with numpy.save() or jax.numpy.save(), the output will be returned as a jax.Array, and bfloat16 data types will be restored. For .npz files, results will be returned as normal NumPy arrays.
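The sketch below contrasts the two cases using in-memory buffers: a .npy file written with jnp.save comes back as a jax.Array with its bfloat16 dtype intact, while an .npz archive (here written with np.savez, with hypothetical entry names a and b) yields plain NumPy arrays. It assumes a recent JAX release in which jax.Array is the public array type.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# .npy case: jnp.load returns a jax.Array and restores bfloat16.
npy_buf = io.BytesIO()
jnp.save(npy_buf, jnp.arange(4, dtype=jnp.bfloat16))
npy_buf.seek(0)
npy_out = jnp.load(npy_buf)
print(isinstance(npy_out, jax.Array), npy_out.dtype)

# .npz case: entries are ordinary NumPy arrays, not jax.Array objects.
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3), b=np.ones(2))  # hypothetical entry names
npz_buf.seek(0)
with jnp.load(npz_buf) as npz:
    a = npz["a"]  # reading an entry materializes it as np.ndarray
print(type(a))
```

If .npz contents are needed as JAX arrays, they can be converted explicitly, for example with jnp.asarray(a).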
This function requires concrete array inputs, and is not compatible with transformations like jax.jit() or jax.vmap().
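One way to work within this restriction is to call jnp.load eagerly, outside any transformation, and pass the resulting jax.Array into the transformed function. A minimal sketch; the function name scale is hypothetical and stands in for any jitted computation.

```python
import io

import jax
import jax.numpy as jnp


@jax.jit
def scale(x):
    # Hypothetical transformed function: it only sees an already-loaded array.
    return x * 2


buf = io.BytesIO()
jnp.save(buf, jnp.array([1.0, 2.0, 3.0]))
buf.seek(0)

x = jnp.load(buf)  # eager load, outside jit/vmap
y = scale(x)       # the transformation is applied to the loaded data
print(y)
```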
Parameters:
file (IO[bytes] | str | os.PathLike[Any]) – binary file-like object, string, or path-like object containing the array data.
args (Any) – additional positional arguments, passed through to numpy.load().
kwargs (Any) – additional keyword arguments, passed through to numpy.load().
Returns:
the array stored in the file.
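Because extra arguments are forwarded to numpy.load(), NumPy options such as mmap_mode can be passed directly to jnp.load. The sketch below assumes a writable temporary directory; the file name weights.npy is hypothetical. Since the loaded data is an ndarray (a memory map), it is still converted to a jax.Array before being returned.

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "weights.npy"  # hypothetical file name
    jnp.save(path, jnp.arange(6, dtype=jnp.float32).reshape(2, 3))

    # mmap_mode is not a parameter of jnp.load itself; it is forwarded
    # to numpy.load through **kwargs.
    arr = jnp.load(path, mmap_mode="r")

    # Use the data while the file still exists.
    is_jax = isinstance(arr, jax.Array)
    shape, dtype = arr.shape, arr.dtype
    total = float(arr.sum())

print(is_jax, shape, dtype, total)
```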
Examples
>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
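For comparison, the sketch below reads the same kind of file with plain numpy.load. NumPy has no built-in bfloat16 type, so the 2-byte elements come back as raw void data (dtype V2); reinterpreting those bytes as bfloat16 recovers the values. This manual view only illustrates the idea behind the dtype restoration and is not claimed to be JAX's exact code path.

```python
import io

import jax.numpy as jnp
import numpy as np

buf = io.BytesIO()
jnp.save(buf, jnp.array([2, 4, 6, 8], dtype=jnp.bfloat16))

# Plain NumPy: the bfloat16 elements are read back as opaque 2-byte void data.
buf.seek(0)
raw = np.load(buf)

# Reinterpret the same bytes as bfloat16 (same itemsize, no copy of values).
restored = raw.view(jnp.bfloat16)

# jnp.load handles this automatically for .npy files.
buf.seek(0)
loaded = jnp.load(buf)
print(raw.dtype, restored.dtype, loaded.dtype)
```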