Showing content from https://docs.jax.dev/en/latest/installation.html below:

Installation — JAX documentation

Installation#

Using JAX requires installing two packages: jax, which is pure Python and cross-platform, and jaxlib, which contains compiled binaries and requires different builds for different operating systems and accelerators.

Summary: For most users, a typical JAX installation may look something like this:

Supported platforms#

The table below shows all supported platforms and installation options. Check whether your setup is supported; if it says “yes” or “experimental”, click the corresponding link to learn how to install JAX in greater detail.

CPU#

pip installation: CPU#

Currently, the JAX team releases jaxlib wheels for the following operating systems and architectures:

To install a CPU-only version of JAX, which might be useful for doing local development on a laptop, you can run:

pip install --upgrade pip
pip install --upgrade jax

On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine.

Other operating systems and architectures require building from source. Trying to pip install on other operating systems and architectures may lead to jaxlib not being installed alongside jax, although jax may successfully install (but fail at runtime).
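Because jax can install without jaxlib on unsupported platforms, it can help to check that both packages are present before running anything. The following is a minimal sketch using only the Python standard library; the helper names `installed_version` and `check_jax_install` are illustrative and are not part of JAX.

```python
from __future__ import annotations

from importlib import metadata


def installed_version(package: str) -> str | None:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_jax_install() -> dict[str, str | None]:
    """Report the installed versions of jax and jaxlib (None means missing)."""
    return {name: installed_version(name) for name in ("jax", "jaxlib")}


if __name__ == "__main__":
    for name, version in check_jax_install().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

If jaxlib is reported as missing while jax is present, the platform most likely has no pre-built jaxlib wheel and needs a build from source.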

NVIDIA GPU#

On CUDA 12, JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX since NVIDIA has dropped support for Kepler GPUs in its software. On CUDA 13, JAX supports NVIDIA GPUs that have SM version 7.5 or newer. NVIDIA dropped support for previous GPUs in CUDA 13.

You must first install the NVIDIA driver. We recommend installing the newest driver available from NVIDIA; on Linux, the driver version must be >= 525 for CUDA 12 and >= 580 for CUDA 13.

If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
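The compute-capability and driver minimums in this section can be written down as a small lookup. This is a sketch for illustration only, not the check JAX itself performs; the function names are hypothetical and the thresholds are copied from the text above.

```python
from __future__ import annotations

# Minimums quoted in the text above, keyed by CUDA major version.
MIN_COMPUTE_CAPABILITY = {12: (5, 2), 13: (7, 5)}
MIN_LINUX_DRIVER = {12: 525, 13: 580}


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a dotted version such as '580.65.06' into (580, 65, 6)."""
    return tuple(int(part) for part in text.strip().split("."))


def gpu_supported(compute_capability: str, cuda_major: int) -> bool:
    """True if the GPU's SM version meets the minimum for this CUDA major."""
    return parse_version(compute_capability) >= MIN_COMPUTE_CAPABILITY[cuda_major]


def driver_supported(driver_version: str, cuda_major: int) -> bool:
    """True if the Linux driver's major version meets the minimum."""
    return parse_version(driver_version)[0] >= MIN_LINUX_DRIVER[cuda_major]
```

On recent drivers, `nvidia-smi --query-gpu=driver_version,compute_cap --format=csv,noheader` should print both inputs; older `nvidia-smi` builds may not support the `compute_cap` field.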

pip installation: NVIDIA GPU (CUDA, installed via pip, easier)#

There are two ways to install JAX with NVIDIA GPU support:

The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier!

NVIDIA has released CUDA packages only for x86_64 and aarch64.

pip install --upgrade pip

# NVIDIA CUDA 13 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13]"

# Alternatively, for CUDA 12, use
# pip install --upgrade "jax[cuda12]"

We recommend migrating to the CUDA 13 wheels; at some point in the future we will drop CUDA 12 support.
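After installing, a quick way to confirm that JAX picked up the GPU is to ask it which backend and devices it sees. The sketch below uses `jax.default_backend()` and `jax.devices()`; the wrapper function `describe_backend` is a hypothetical helper and returns None if jax cannot be imported.

```python
from __future__ import annotations


def describe_backend() -> dict | None:
    """Return the default backend and device list, or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    return {
        "backend": jax.default_backend(),  # platform name, e.g. "cpu" or "gpu"
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_backend()
    if info is None:
        print("jax is not installed in this environment")
    else:
        print(f"default backend: {info['backend']}")
        print(f"devices: {info['devices']}")
```

If the backend is reported as "cpu" on a machine with an NVIDIA GPU, the CUDA plugin was probably not installed or could not be loaded.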

If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things you need to check:

pip installation: NVIDIA GPU (CUDA, installed locally, harder)#

If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first install NVIDIA CUDA and cuDNN.

JAX provides pre-built CUDA-compatible wheels for Linux x86_64 and Linux aarch64 only. Other combinations of operating system and architecture are possible, but require building from source (refer to Building from source to learn more).

You should use an NVIDIA driver version that is at least as new as your NVIDIA CUDA toolkit’s corresponding driver version. If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.

JAX currently ships two CUDA wheel variants, one for CUDA 12 and one for CUDA 13.

The CUDA 12 wheel is:

Built with    Compatible with
CUDA 12.3     CUDA >=12.1
CUDNN 9.8     CUDNN >=9.8, <10.0
NCCL 2.19     NCCL >=2.18

The CUDA 13 wheel is:

Built with    Compatible with
CUDA 13.0     CUDA >=13.0
CUDNN 9.8     CUDNN >=9.8, <10.0
NCCL 2.19     NCCL >=2.18

JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. Setting the JAX_SKIP_CUDA_CONSTRAINTS_CHECK environment variable will disable the check, but using older versions of CUDA may lead to errors, or incorrect results.

NCCL is an optional dependency, required only if you are performing multi-GPU computations.
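The "Compatible with" columns above can be expressed as version ranges and checked against a local installation. The following is an illustrative sketch only: it is not the check JAX runs internally, and `CONSTRAINTS` and `incompatible_libraries` are hypothetical names. The tables give no upper bound for CUDA or NCCL, so none is enforced here.

```python
from __future__ import annotations

# (inclusive minimum, exclusive maximum or None), copied from the tables above.
CONSTRAINTS = {
    12: {"cuda": ((12, 1), None), "cudnn": ((9, 8), (10, 0)), "nccl": ((2, 18), None)},
    13: {"cuda": ((13, 0), None), "cudnn": ((9, 8), (10, 0)), "nccl": ((2, 18), None)},
}


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a dotted version such as '9.8.0' into (9, 8, 0)."""
    return tuple(int(part) for part in text.strip().split("."))


def incompatible_libraries(wheel: int, versions: dict[str, str]) -> list[str]:
    """Return the libraries in `versions` that fall outside the wheel's range.

    NCCL may be left out of `versions`, since it is an optional dependency.
    """
    problems = []
    for name, text in versions.items():
        minimum, maximum = CONSTRAINTS[wheel][name]
        version = parse_version(text)
        if version < minimum or (maximum is not None and version >= maximum):
            problems.append(name)
    return problems
```

For example, `incompatible_libraries(12, {"cuda": "12.0", "cudnn": "10.0"})` flags both libraries, because CUDA 12.0 is below the 12.1 minimum and cuDNN 10.0 is outside the `<10.0` bound.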

To install, run:

pip install --upgrade pip


# Installs the wheel compatible with NVIDIA CUDA 13 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13-local]"

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
# pip install --upgrade "jax[cuda12-local]"

These pip installations do not work with Windows, and may fail silently; refer to the table above.

You can find your CUDA version with the command:

nvcc --version

JAX uses LD_LIBRARY_PATH to find CUDA libraries and PATH to find binaries (ptxas, nvlink). Please make sure that these paths point to the correct CUDA installation.
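To see what those two environment variables currently resolve to, a short standard-library script is enough. This is a sketch under the assumption that inspecting `PATH` and `LD_LIBRARY_PATH` from Python reflects the environment JAX will run in; the helper names are illustrative.

```python
from __future__ import annotations

import os
import shutil


def locate_cuda_tools(tools: tuple[str, ...] = ("ptxas", "nvlink")) -> dict[str, str | None]:
    """Map each tool name to the path found on PATH, or None if it is missing."""
    return {tool: shutil.which(tool) for tool in tools}


def library_search_dirs() -> list[str]:
    """Return the non-empty entries of LD_LIBRARY_PATH, in search order."""
    raw = os.environ.get("LD_LIBRARY_PATH", "")
    return [entry for entry in raw.split(os.pathsep) if entry]


if __name__ == "__main__":
    for tool, path in locate_cuda_tools().items():
        print(f"{tool}: {path or 'not found on PATH'}")
    for directory in library_search_dirs():
        status = "ok" if os.path.isdir(directory) else "missing"
        print(f"LD_LIBRARY_PATH entry ({status}): {directory}")
```

If `ptxas` and the library directories point into different CUDA installations, that mismatch is a likely source of version errors.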

JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. Make sure that it is present in your CUDA installation.

Please let the JAX team know on the GitHub issue tracker if you run into any errors or problems with the pre-built wheels.

NVIDIA GPU Docker containers#

NVIDIA provides the JAX Toolbox containers, which are bleeding edge containers containing nightly releases of jax and some models/frameworks.

Google Cloud TPU#

pip installation: Google Cloud TPU#

JAX provides pre-built wheels for Google Cloud TPU. To install JAX along with appropriate versions of jaxlib and libtpu, you can run the following in your cloud TPU VM:

For users of Colab (https://colab.research.google.com/), be sure you are using TPU v2 and not the older, deprecated TPU runtime.

Mac GPU#

pip installation#

Apple provides an experimental Metal plugin. For details, refer to Apple’s JAX on Metal documentation.

Note: There are several caveats with the Metal plugin:

AMD GPU (Linux)#

AMD GPU support is provided by a ROCm JAX plugin supported by AMD.

There are several ways to use JAX on AMDGPU devices. Please see AMD’s instructions for details.

Note: ROCm support on Windows WSL2 is experimental. For WSL installation, you may need to:

  1. Install ROCm for WSL following AMD’s official guide

  2. Follow the standard Linux ROCm JAX installation steps within your WSL environment

  3. Be aware that performance and stability may differ from native Linux installations

Intel GPU#

Intel provides an experimental OneAPI plugin: intel-extension-for-openxla for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods:

  1. Pip installation: JAX acceleration on Intel GPU.

  2. Using Intel’s XLA Docker container.

Please report any issues related to:

JAX nightly installation#

Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite.

Unlike the instructions for installing a JAX release, here we name all of JAX’s packages explicitly on the command line, so pip will upgrade them if a newer version is available.

JAX publishes nightlies, release candidates (RCs), and releases to several non-PyPI PEP 503 indexes.

All JAX packages, as well as mirrored PyPI packages, can be reached from the index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/. This additional mirroring lets a nightly installation use --index-url (-i) as the install method with pip.

Note: Immediately after a release, before the newest nightly has been rebuilt, the unified index may return an RC or release as the newest version even with --pre. If automation or testing must run against nightlies, or you cannot use our full index, use the extra index https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/, which contains only nightly artifacts.
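To tell which kind of build pip actually resolved, the installed version string can be classified. The sketch below assumes that JAX nightlies carry a PEP 440 development segment such as `.dev20251101` and that release candidates carry an `rcN` suffix; the function name and the sample version strings in the usage note are hypothetical.

```python
import re


def classify_release(version: str) -> str:
    """Classify a PEP 440 version string as 'nightly', 'rc' or 'release'."""
    public = version.split("+", 1)[0]  # drop any local label such as +cuda11
    if re.search(r"\.?dev\d*$", public):
        return "nightly"  # assumption: nightlies use a .devN segment
    if re.search(r"rc\d+", public):
        return "rc"
    return "release"
```

Feeding it the output of `importlib.metadata.version("jax")` would show, for instance, whether an automated job picked up a release such as `0.8.1` when a `0.8.2.dev...` nightly was expected.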

# CPU only
pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Google Cloud TPU
pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# NVIDIA GPU (CUDA 13)
pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# NVIDIA GPU (CUDA 12)
pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

Building JAX from source#

Refer to Building from source.

Installing older jaxlib wheels#

Due to storage limitations on the Python package index, the JAX team periodically removes older jaxlib wheels from the releases on http://pypi.org/project/jax. These can still be installed directly via the URLs here. For example:

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

For specific older GPU wheels, be sure to use the jax_cuda_releases.html URL; for example:

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
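Older GPU wheels encode their CUDA and cuDNN targets in the local version label after the `+`. The following sketch splits such a string into its parts; the function name is hypothetical, and reading `cudnn82` as "cuDNN 8.2" is an interpretation rather than something the label itself states.

```python
from __future__ import annotations

import re


def split_local_version(version: str) -> tuple[str, dict[str, str]]:
    """Split '0.3.25+cuda11.cudnn82' into ('0.3.25', {'cuda': '11', 'cudnn': '82'})."""
    public, _, local = version.partition("+")
    tags: dict[str, str] = {}
    for segment in filter(None, local.split(".")):
        match = re.fullmatch(r"([a-z]+)(\d+)", segment)
        if match:
            tags[match.group(1)] = match.group(2)
    return public, tags
```

This can be handy when scanning a requirements file for wheels that pin a CUDA generation no longer present on the machine.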
