

jax.numpy.load — JAX documentation

jax.numpy.load
jax.numpy.load(file, *args, **kwargs)

Load JAX arrays from npy files.

JAX wrapper of numpy.load().

This function is a thin wrapper around numpy.load(). For .npy files created with numpy.save() or jax.numpy.save(), the output is returned as a jax.Array and bfloat16 data types are restored. For .npz files, results are returned as ordinary NumPy arrays.
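The difference between the two file formats can be seen with a small sketch. It uses in-memory buffers in place of real files, and the names `npy_buf`, `npz_buf`, and `archive` are illustrative only. For an .npz archive, `numpy.load()` returns NumPy's lazy `NpzFile` mapping, and indexing it yields plain `numpy.ndarray` values.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.5, 2.5, 3.5], dtype=jnp.bfloat16)

# .npy file: a single array. jnp.load returns a jax.Array with bfloat16 restored.
npy_buf = io.BytesIO()
jnp.save(npy_buf, x)
npy_buf.seek(0)
from_npy = jnp.load(npy_buf)

# .npz archive: numpy.load() returns an NpzFile mapping, which is passed through;
# each entry read from it is a plain numpy.ndarray, not a jax.Array.
npz_buf = io.BytesIO()
np.savez(npz_buf, a=np.arange(3), b=np.ones(2))
npz_buf.seek(0)
archive = jnp.load(npz_buf)
entry = archive["a"]

print(isinstance(from_npy, jax.Array), from_npy.dtype)
print(type(entry))
```

If .npz contents are needed on device, each entry can be converted explicitly, for example with `jnp.asarray(archive["a"])`.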

This function requires concrete array inputs and is not compatible with transformations such as jax.jit() or jax.vmap().
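A common pattern that respects this restriction is to do the file I/O eagerly and pass the loaded array into the transformed function as an argument. This is a minimal sketch; `weights` and `scale` are hypothetical names, and the buffer stands in for a real .npy file.

```python
import io

import jax
import jax.numpy as jnp

buf = io.BytesIO()
jnp.save(buf, jnp.arange(4.0))
buf.seek(0)

# Load eagerly, outside any transformation: this reads real bytes and
# produces a concrete jax.Array.
weights = jnp.load(buf)


@jax.jit
def scale(w, factor):
    # Only pure array computation happens inside the jitted function.
    return w * factor


out = scale(weights, 2.0)
print(out)
```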

Parameters:
  • file (IO[bytes] | str | os.PathLike[Any]) – file-like object opened in binary mode, string path, or path-like object identifying the file to read.

  • args (Any) – additional positional arguments, passed through to numpy.load().

  • kwargs (Any) – additional keyword arguments, passed through to numpy.load().
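The sketch below passes a `pathlib.Path` as `file` and forwards a keyword argument to `numpy.load()`. `allow_pickle` is a real `numpy.load()` parameter; the temporary directory and the file name `x.npy` are illustrative assumptions.

```python
import pathlib
import tempfile

import jax
import jax.numpy as jnp

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "x.npy"  # hypothetical file name
    jnp.save(path, jnp.array([[1, 2], [3, 4]], dtype=jnp.int32))

    # A path-like object is accepted as `file`; allow_pickle is forwarded
    # unchanged to numpy.load() through **kwargs.
    y = jnp.load(path, allow_pickle=False)

print(y)
```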

Returns:

the array stored in the file, as a jax.Array. For .npz files, the NumPy result of numpy.load() is returned unchanged.

Return type:

Array

Examples

>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
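To see what the wrapper adds, the same bytes can be read with plain `numpy.load()`. NumPy has no native bfloat16, so the loaded array is assumed here to come back with an opaque 2-byte dtype (a void type such as `V2`). Reinterpreting those bytes as bfloat16 recovers the values. This sketch assumes that `ml_dtypes`, which provides JAX's bfloat16 type, is installed alongside JAX, and it only approximates what `jnp.load` does internally.

```python
import io

import jax.numpy as jnp
import ml_dtypes
import numpy as np

f = io.BytesIO()
jnp.save(f, jnp.array([2, 4, 6, 8], dtype="bfloat16"))

# Plain numpy.load: the element type is assumed to come back as an opaque
# 2-byte dtype rather than bfloat16.
f.seek(0)
raw = np.load(f)
print(raw.dtype)

# View the same 2-byte elements as bfloat16, then convert to a jax.Array.
restored = jnp.asarray(raw.view(ml_dtypes.bfloat16))
print(restored)
```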
