Best viewed here. For the changes specific to the experimental Pallas APIs, see Pallas Changelog.
JAX follows Effort-based versioning; for a discussion of this and JAX’s API compatibility policy, refer to API compatibility. For the Python and NumPy version support policy, refer to Python and NumPy version support policy.
Unreleased

New features:

Added jax.lax.axis_size(), which returns the size of the mapped axis given its name.
Changes
JAX nightly packages are now published to artifact registry. To install these packages, see the JAX installation guide.
jax.sharding.PartitionSpec no longer inherits from a tuple.
jax.ShapeDtypeStruct is now immutable. Please use the .update method to update a ShapeDtypeStruct instead of doing in-place updates.
Breaking changes
jax.numpy.array() no longer accepts None. This behavior was deprecated since November 2023 and is now removed.
Removed the config.jax_data_dependent_tracing_fallback config option, which was added temporarily in v0.4.36 to allow users to opt out of the new “stackless” tracing machinery.
Removed the config.jax_eager_pmap config option.
Calling the lower and trace AOT APIs on the result of jax.jit is now disallowed if subsequent wrappers have been applied. Previously this worked, but silently ignored the wrappers. The workaround is to apply jax.jit last among the wrappers, and similarly for jax.pmap. See #27873.
The cuda12_pip extra for jax has been removed; use pip install jax[cuda12] instead.
Changes
The minimum CuDNN version is v9.8.
JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
JAX package extras are now updated to use dashes instead of underscores to align with PEP 685. For instance, if you were previously using pip install jax[cuda12_local] to install JAX, run pip install jax[cuda12-local] instead.
jax.jit() now requires fun to be passed by position, and additional arguments to be passed by keyword. Doing otherwise will result in a DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
Deprecations
jax.tree_util.build_tree() is deprecated. Use jax.tree.unflatten() instead.
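For reference, a minimal sketch (not part of the release notes) of the replacement API, which rebuilds a pytree from a treedef and a list of leaves:

```python
import jax

# Flatten a pytree into its leaves plus a treedef describing the structure.
tree = {"a": 1, "b": (2, 3)}
leaves, treedef = jax.tree.flatten(tree)

# jax.tree.unflatten reassembles the same structure from new leaves.
rebuilt = jax.tree.unflatten(treedef, [x * 10 for x in leaves])
# rebuilt == {"a": 10, "b": (20, 30)}
```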
Implemented host callback handlers for CPU and GPU devices using XLA’s FFI and removed existing CPU/GPU handlers using XLA’s custom call.
All APIs in jax.lib.xla_extension are now deprecated.
jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect, which were accidental exports, have been removed. If needed, they are available from jax.extend.mlir.
jax.interpreters.mlir.custom_call is deprecated. The APIs provided by jax.ffi should be used instead.
The deprecated use of jax.ffi.ffi_call() with inline arguments is no longer supported. ffi_call() now unconditionally returns a callable.
The following exports in jax.lib.xla_client are deprecated: get_topology_for_devices, heap_profile, mlir_api_version, Client, CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding, Traceback.
The following internal APIs in jax.util are deprecated: HashableFunction, as_hashable_function, cache, safe_map, safe_zip, split_dict, split_list, split_list_checked, split_merge, subvals, toposort, unzip2, wrap_name, and wraps.
jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX Array directly to the from_dlpack function of another framework. If you need the functionality of to_dlpack, use the __dlpack__ attribute of an array.
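In practice this usually means no code change beyond deleting the to_dlpack call; a minimal sketch, assuming a CPU-backed array and NumPy 1.22+:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4, dtype=jnp.float32)

# No jax.dlpack.to_dlpack needed: np.from_dlpack consumes x.__dlpack__ directly.
y = np.from_dlpack(x)
```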
jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
Several previously-deprecated APIs have been removed, including:
From jax.lib.xla_client: ArrayImpl, FftType, PaddingType, PrimitiveType, XlaBuilder, dtype_to_etype, ops, register_custom_call_target, shape_from_pyval, Shape, XlaComputation.
From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map, jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and jax.tree_unflatten. Replacements can be found in jax.tree or jax.tree_util.
From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType, Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx, Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval, dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower, get_referent, jaxpr_as_fun, join_effects, lattice_join, leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped, raise_to_shaped_mappings, reset_trace_state, str_eqn_compact, substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most have no public replacement, though a few are available at jax.extend.core.
The vectorized argument to pure_callback() and ffi_call() has been removed. Use the vmap_method parameter instead.
New Features
Added an allow_negative_indices option to jax.lax.dynamic_slice(), jax.lax.dynamic_update_slice() and related functions. The default is true, matching the current behavior. If set to false, JAX does not need to emit code clamping negative indices, which improves code size.
Added a replace option to jax.random.categorical() to enable sampling without replacement.
Patch release of 0.5.1
Bug fixes
Fixes TPU metric logging and tpu-info, which was broken in 0.5.1.
Breaking changes
The jit tracing cache now keys on input NamedShardings. Previously, the tracing cache did not include sharding information at all (although later stages of jit, such as the lowering and compilation caches, did), so two equivalent shardings of different types would not retrace; now they do. For example:
```python
@jax.jit
def f(x):
    return x

# inp1.sharding is of type SingleDeviceSharding
inp1 = jnp.arange(8)
f(inp1)

mesh = jax.make_mesh((1,), ('x',))
# inp2.sharding is of type NamedSharding
inp2 = jax.device_put(jnp.arange(8), NamedSharding(mesh, P('x')))
f(inp2)  # tracing cache miss
```
In the above example, calling f(inp1) and then f(inp2) will lead to a tracing cache miss because the shardings on the abstract values have changed while tracing.
New Features
Added an experimental jax.experimental.custom_dce.custom_dce() decorator to support customizing the behavior of opaque functions under JAX-level dead code elimination (DCE). See #25956 for more details.
Added low-level reduction APIs in jax.lax: jax.lax.reduce_sum(), jax.lax.reduce_prod(), jax.lax.reduce_max(), jax.lax.reduce_min(), jax.lax.reduce_and(), jax.lax.reduce_or(), and jax.lax.reduce_xor().
jax.lax.linalg.qr() and jax.scipy.linalg.qr() now support column-pivoting on CPU and GPU. See #20282 and #25955 for more details.
Added jax.random.multinomial().
Changes
JAX_CPU_COLLECTIVES_IMPLEMENTATION and JAX_NUM_CPU_DEVICES now work as environment variables. Before, they could only be specified via jax.config or flags.
JAX_CPU_COLLECTIVES_IMPLEMENTATION now defaults to 'gloo', meaning multi-process CPU communication works out of the box.
The jax[tpu] TPU extra no longer depends on the libtpu-nightly package. This package may safely be removed if it is present on your machine; JAX now uses libtpu instead.
Deprecations
The internal function linear_util.wrap_init and the constructor core.Jaxpr now must take a non-empty core.DebugInfo kwarg. For a limited time, a DeprecationWarning is printed if jax.extend.linear_util.wrap_init is used without debugging info. A downstream effect of this is that several other internal functions now require debug info. This change does not affect public APIs. See https://github.com/jax-ml/jax/issues/26480 for more detail.
In jax.numpy.ndim(), jax.numpy.shape(), and jax.numpy.size(), non-arraylike inputs (such as lists, tuples, etc.) are now deprecated.
Bug fixes
TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer (from around 17s to around 8s). If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'). We hope to improve this further in future releases.
Persistent compilation cache no longer writes access time file if JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU eviction policy isn’t enabled. This should improve performance when using the cache with large-scale network storage.
As of this release, JAX now uses effort-based versioning. Since this release makes a breaking change to PRNG key semantics that may require users to update their code, we are bumping the “meso” version of JAX to signify this.
Breaking changes
Enabled jax_threefry_partitionable by default (see the update note).
This release drops support for Mac x86 wheels. Mac ARM of course remains supported. For a recent discussion, see https://github.com/jax-ml/jax/discussions/22936.
Two key factors motivated this decision:
The Mac x86 build (only) has a number of test failures and crashes. We would rather ship no release than a broken one.
Mac x86 hardware is end-of-life and cannot easily be obtained by developers at this point, so it is difficult for us to fix this kind of problem even if we wanted to.
We are open to re-adding support for Mac x86 if the community is willing to help support that platform: in particular, we would need the JAX test suite to pass cleanly on Mac x86 before we could ship releases again.
Changes:
The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum supported version until June 2025.
The minimum SciPy version is now 1.11. SciPy 1.11 will remain the minimum supported version until June 2025.
jax.numpy.einsum() now defaults to optimize='auto' rather than optimize='optimal'. This avoids exponentially-scaling trace-time in the case of many arguments (#25214).
jax.numpy.linalg.solve() no longer supports batched 1D arguments on the right-hand side. To recover the previous behavior in these cases, use solve(a, b[..., None]).squeeze(-1).
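The suggested workaround for batched 1D right-hand sides can be sketched as:

```python
import jax.numpy as jnp

a = jnp.stack([jnp.eye(3), 2.0 * jnp.eye(3)])  # batch of two 3x3 matrices
b = jnp.ones((2, 3))                           # batched 1D right-hand sides

# Promote b to a stack of column vectors, solve, then drop the trailing axis.
x = jnp.linalg.solve(a, b[..., None]).squeeze(-1)
```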
New Features
jax.numpy.fft.fftn(), jax.numpy.fft.rfftn(), jax.numpy.fft.ifftn(), and jax.numpy.fft.irfftn() now support transforms in more than 3 dimensions, which was previously the limit. See #25606 for more details.
Support added for user-defined state in the FFI via the new jax.ffi.register_ffi_type_id() function.
The AOT lowering .as_text() method now supports the debug_info option to include debugging information, e.g., source location, in the output.
Deprecations
From jax.interpreters.xla, abstractify and pytype_aval_mappings are now deprecated, having been replaced by symbols of the same name in jax.core.
jax.scipy.special.lpmn() and jax.scipy.special.lpmn_values() are deprecated, following their deprecation in SciPy v1.15.0. There are no plans to replace these deprecated functions with new APIs.
The jax.extend.ffi submodule was moved to jax.ffi, and the previous import path is deprecated.
Deletions
The jax_enable_memories flag has been deleted; the behavior it enabled is now on by default.
From jax.lib.xla_client, the previously-deprecated Device and XlaRuntimeError symbols have been removed; instead use jax.Device and jax.errors.JaxRuntimeError respectively.
The jax.experimental.array_api module has been removed after being deprecated in JAX v0.4.32. Since that release, jax.numpy supports the array API directly.
Breaking Changes
XlaExecutable.cost_analysis now returns a dict[str, float] (instead of a single-element list[dict[str, float]]).
Changes:
jax.tree.flatten_with_path and jax.tree.map_with_path are added as shortcuts for the corresponding tree_util functions.
Deprecations
A number of APIs in the internal jax.core namespace have been deprecated. Most were no-ops, were little-used, or can be replaced by APIs of the same name in jax.extend.core; see the documentation for jax.extend for information on the compatibility guarantees of these semi-public extensions.
Several previously-deprecated APIs have been removed, including:
from jax.core: check_eqn, check_type, check_valid_jaxtype, and non_negative_dim.
from jax.lib.xla_bridge: xla_client and default_backend.
from jax.lib.xla_client: _xla and bfloat16.
from jax.numpy: round_.
New Features
jax.export.export() can be used for device-polymorphic export with shardings constructed with jax.sharding.AbstractMesh(). See the jax.export documentation.
Added jax.lax.split(). This is a primitive version of jax.numpy.split(), added because it yields a more compact transpose during automatic differentiation.
This is a patch release of jax 0.4.36. Only “jax” was released at this version.
Bug fixes
Fixed a bug where jit would error if an argument was named f (#25329).
Fixed a bug that threw an index out of range error in jax.lax.while_loop() if the user registered a pytree node class with different aux data for flatten and flatten_with_path.
Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
Breaking Changes
This release lands “stackless”, an internal change to JAX’s tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind, and so on. The change should only affect users that use JAX internals.
If you do use JAX internals then you may need to update your code (see https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find this change breaks your non-JAX-internals-using code then try the config.jax_data_dependent_tracing_fallback flag as a workaround, and if you need help updating your code then please file a bug.
jax.experimental.jax2tf.convert() with native_serialization=False or with enable_xla=False has been deprecated since July 2024, with JAX version 0.4.31. Support for these use cases has now been removed. jax2tf with native serialization will still be supported.
In jax.interpreters.xla, the xb, xc, and xe symbols have been removed after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge, xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.
The deprecated module jax.experimental.export has been removed. It was replaced by jax.export in JAX v0.4.30. See the migration guide for information on migrating to the new API.
The initial argument to jax.nn.softmax() and jax.nn.log_softmax() has been removed, after being deprecated in v0.4.27.
Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key) now raises an error. Previously, this returned a scalar object array.
The following deprecated methods and functions in jax.export have been removed:
jax.export.DisabledSafetyCheck.shape_assertions: it had no effect already.
jax.export.Exported.lowering_platforms: use platforms.
jax.export.Exported.mlir_module_serialization_version: use calling_convention_version.
jax.export.Exported.uses_shape_polymorphism: use uses_global_constants.
the lowering_platforms kwarg for jax.export.export(): use platforms instead.
The kwargs symbolic_scope and symbolic_constraints from jax.export.symbolic_args_specs() have been removed. They were deprecated in June 2024. Use scope and constraints instead.
Hashing of tracers, which has been deprecated since version 0.4.30, now results in a TypeError.
Refactor: the JAX build CLI (build/build.py) now uses a subcommand structure and replaces previous build.py usage. Run python build/build.py --help for more details. Brief overview of the new subcommand options:
build: builds JAX wheel packages. For example: python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update: updates requirements_lock.txt files.
jax.scipy.linalg.toeplitz() now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call jax.numpy.ravel() on the function inputs.
jax.scipy.special.gamma() and jax.scipy.special.gammasgn() now return NaN for negative integer inputs, to match the behavior of SciPy from https://github.com/scipy/scipy/pull/21827.
jax.clear_backends was removed after being deprecated in v0.4.26.
We removed the custom call “__gpu$xla.gpu.triton” from the list of custom calls whose export stability we guarantee. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the disabled_checks parameter. See more details in the documentation.
New Features
jax.jit() got a new compiler_options: dict[str, Any] argument, for passing compilation options to XLA. For the moment it’s undocumented and may be in flux.
jax.tree_util.register_dataclass() now allows metadata fields to be declared inline via dataclasses.field(). See the function documentation for examples.
Added jax.numpy.put_along_axis().
jax.lax.linalg.eig() and the related jax.numpy functions (jax.numpy.linalg.eig() and jax.numpy.linalg.eigvals()) are now supported on GPU. See #24663 for more details.
Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
Bug fixes
Fixed a bug where the GPU implementations of LU and QR decomposition would result in an indexing overflow for batch sizes close to int32 max. See #24843 for more details.
Deprecations
jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated; use jax.Array instead.
jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.
Breaking Changes
jax.numpy.isscalar() now returns True for any array-like object with zero dimensions. Previously it only returned True for zero-dimensional array-like objects with a weak dtype.
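A minimal sketch of the distinction (the zero-dimensional array case assumes a JAX version with this change, so it is shown without an expected result):

```python
import jax.numpy as jnp

a = jnp.isscalar(1.0)              # True: Python scalars are scalars
b = jnp.isscalar(jnp.float32(0))   # True under the new behavior (0-d, strong dtype)
c = jnp.isscalar(jnp.zeros((2,)))  # False: one or more dimensions
```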
jax.experimental.host_callback has been deprecated since March 2024, with JAX version 0.4.26, and has now been removed. See #20385 for a discussion of alternatives.
Changes:
jax.lax.FftType was introduced as a public name for the enum of FFT operations. The semi-public API jax.lib.xla_client.FftType has been deprecated.
TPU: JAX now installs TPU support from the libtpu package rather than libtpu-nightly. For the next few releases JAX will pin an empty version of libtpu-nightly as well as libtpu to ease the transition; that dependency will be removed in Q1 2025.
Deprecations:
The semi-public API jax.lib.xla_client.PaddingType has been deprecated. No JAX APIs consume this type, so there is no replacement.
The default behavior of jax.pure_callback() and jax.extend.ffi.ffi_call() under vmap has been deprecated, as has the vectorized parameter to those functions. The vmap_method parameter should be used instead for better-defined behavior. See the discussion in #23881 for more details.
The semi-public API jax.lib.xla_client.register_custom_call_target has been deprecated. Use the JAX FFI instead.
The semi-public APIs jax.lib.xla_client.dtype_to_etype, jax.lib.xla_client.ops, jax.lib.xla_client.shape_from_pyval, jax.lib.xla_client.PrimitiveType, jax.lib.xla_client.Shape, jax.lib.xla_client.XlaBuilder, and jax.lib.xla_client.XlaComputation have been deprecated. Use StableHLO instead.
New Functionality
This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
jax.errors.JaxRuntimeError has been added as a public alias for the formerly private XlaRuntimeError type.
Breaking changes
The jax_pmap_no_rank_reduction flag is set to True by default.
array[0] on a pmap result now introduces a reshape (use array[0:1] instead).
The per-shard shape (accessible via jax_array.addressable_shards or jax_array.addressable_data(0)) now has a leading (1, …). Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as jit. This avoids costly reshapes when passing results from pmap into jit.
jax.experimental.host_callback has been deprecated since March 2024, with JAX version 0.4.26. Now we set the default value of the --jax_host_callback_legacy configuration value to True, which means that if your code uses jax.experimental.host_callback APIs, those API calls will be implemented in terms of the new jax.experimental.io_callback API. If this breaks your code, for a very limited time, you can set --jax_host_callback_legacy to False. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See #20385 for a discussion.
Deprecations
In jax.numpy.trim_zeros(), non-arraylike arguments or arraylike arguments with ndim != 1 are now deprecated, and in the future will result in an error.
Internal pretty-printing tools jax.core.pp_* have been removed, after being deprecated in JAX v0.4.30.
jax.lib.xla_client.Device is deprecated; use jax.Device instead.
jax.lib.xla_client.XlaRuntimeError has been deprecated. Use jax.errors.JaxRuntimeError instead.
Deletion:
jax.xla_computation is deleted. It has been 3 months since its deprecation in the 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as jax.xla_computation.
jax.xla_computation(fn)(*args, **kwargs) can be replaced with jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
You can also use the .out_info property of jax.stages.Lowered to get the output information (like tree structure, shape and dtype).
For cross-backend lowering, you can replace jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
jax.ShapeDtypeStruct no longer accepts the named_shape argument. The argument was only used by xmap, which was removed in 0.4.31.
jax.tree.map(f, None, non-None), which previously emitted a DeprecationWarning, now raises an error. None is only a tree-prefix of itself. To preserve the previous behavior, you can ask jax.tree.map to treat None as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
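The suggested workaround can be sketched as:

```python
import jax

def f(x, y):
    return x + y

a = {"x": None, "y": 1}
b = {"x": None, "y": 2}

# Treat None as a leaf so that f is skipped wherever a holds None.
out = jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b,
                   is_leaf=lambda x: x is None)
# out == {"x": None, "y": 3}
```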
jax.sharding.XLACompatibleSharding has been removed. Please use jax.sharding.Sharding.
Bug fixes
Fixed a bug where jax.numpy.cumsum() would produce incorrect outputs if a non-boolean input was provided and dtype=bool was specified.
Fixed the implementation of jax.numpy.ldexp() to produce a correct gradient.
This is a patch release on top of jax 0.4.32, that fixes two bugs found in that release.
A TPU-only data corruption bug was found in the version of libtpu pinned by JAX 0.4.32, which manifested only if multiple TPU slices were present in the same job, for example, if training on multiple v5e slices. This release fixes that issue by pinning a fixed version of libtpu.
This release fixes an inaccurate result for F64 tanh on CPU (#23590).
jax 0.4.32 (September 11, 2024)

Note: This release was yanked from PyPI because of a data corruption bug on TPU. See the 0.4.33 release notes for more details.
New Functionality
Added jax.extend.ffi.ffi_call() and jax.extend.ffi.ffi_lowering() to support the use of the new Foreign Function Interface (FFI) to call custom C++ and CUDA code from JAX.
Changes
The jax_enable_memories flag is set to True by default.
jax.numpy now supports v2023.12 of the Python Array API Standard. See Python Array API standard for more information.
Computations on the CPU backend may now be dispatched asynchronously in more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting jax.config.update('jax_cpu_enable_async_dispatch', False).
Added a new jax.process_indices() function to replace the jax.host_ids() function that was deprecated in JAX v0.2.13.
To align with the behavior of numpy.fabs, jax.numpy.fabs has been modified to no longer support complex dtypes.
jax.tree_util.register_dataclass now checks that data_fields and meta_fields together include all dataclass fields with init=True, and only those, if nodetype is a dataclass.
Several jax.numpy functions now have full ufunc interfaces, including add, multiply, bitwise_and, bitwise_or, bitwise_xor, logical_and, logical_or, and logical_xor. In future releases we plan to expand these to other ufuncs.
Added jax.lax.optimization_barrier(), which allows users to prevent compiler optimizations such as common-subexpression elimination and to control scheduling.
Breaking changes
The MHLO MLIR dialect (jax.extend.mlir.mhlo) has been removed. Use the stablehlo dialect instead.
Deprecations
Complex inputs to jax.numpy.clip() and jax.numpy.hypot() are no longer allowed, after being deprecated since JAX v0.4.27.
Deprecated the following APIs:
jax.lib.xla_bridge.xla_client: use jax.lib.xla_client directly.
jax.lib.xla_bridge.get_backend: use jax.extend.backend.get_backend().
jax.lib.xla_bridge.default_backend: use jax.extend.backend.default_backend().
The jax.experimental.array_api module is deprecated, and importing it is no longer required to use the Array API. jax.numpy supports the array API directly; see Python Array API standard for more information.
The internal utilities jax.core.check_eqn, jax.core.check_type, and jax.core.check_valid_jaxtype are now deprecated, and will be removed in the future.
jax.numpy.round_ has been deprecated, following removal of the corresponding API in NumPy 2.0. Use jax.numpy.round() instead.
Passing a DLPack capsule to jax.dlpack.from_dlpack() is deprecated. The argument to jax.dlpack.from_dlpack() should be an array from another framework that implements the __dlpack__ protocol.
Note: This release was yanked from PyPI because of a data corruption bug on TPU. See the 0.4.33 release notes for more details.
Breaking changes
This release of jaxlib switched to a new version of the CPU backend, which should compile faster and leverage parallelism better. If you experience any problems due to this change, you can temporarily enable the old CPU backend by setting the environment variable XLA_FLAGS=--xla_cpu_use_thunk_runtime=false. If you need to do this, please file a JAX bug with instructions to reproduce.
Hermetic CUDA support is added. Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries and tools as dependencies in various Bazel targets. This enables more reproducible builds for JAX and its supported CUDA versions.
Changes
SparseCore profiling is added.
JAX now supports profiling SparseCore on TPUv5p chips. These traces will be viewable in Tensorboard Profiler’s TraceViewer.
Deletion
xmap has been deleted. Please use shard_map() as the replacement.
Changes
The minimum CuDNN version is v9.1. This was true in previous releases also, but we now declare this version constraint formally.
The minimum Python version is now 3.10. 3.10 will remain the minimum supported version until July 2025.
The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum supported version until December 2024.
The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum supported version until January 2025.
jax.numpy.ceil(), jax.numpy.floor() and jax.numpy.trunc() now return the output of the same dtype as the input, i.e. they no longer upcast integer or boolean inputs to floating point.
libdevice.10.bc is no longer bundled with CUDA wheels. It must be installed either as a part of a local CUDA installation, or via NVIDIA’s CUDA pip wheels.
jax.experimental.pallas.BlockSpec now expects block_shape to be passed before index_map. The old argument order is deprecated and will be removed in a future release.
Updated the repr of GPU devices to be more consistent with TPUs/CPUs. For example, cuda(id=0) will now be CudaDevice(id=0).
Added the device property and to_device method to jax.Array, as part of JAX’s Array API support.
Deprecations
Removed a number of previously-deprecated internal APIs related to polymorphic shapes. From jax.core: removed canonicalize_shape, dimension_as_value, definitely_equal, and symbolic_equal_dim.
HLO lowering rules should no longer wrap singleton ir.Values in tuples. Instead, return singleton ir.Values unwrapped. Support for wrapped values will be removed in a future version of JAX.
jax.experimental.jax2tf.convert() with native_serialization=False or enable_xla=False is now deprecated and this support will be removed in a future version. Native serialization has been the default since JAX 0.4.16 (September 2023).
The previously-deprecated function jax.random.shuffle has been removed; instead use jax.random.permutation with independent=True.
Bug fixes
Fixed a bug that meant that negative static_argnums to a jit were mishandled by the jit dispatch fast path.
Fixed a bug that meant triangular solves of batches of singular matrices produce nonsensical finite values, instead of inf or nan (#3589, #15429).
Changes
JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version was bumped to 0.4.0, but this has been rolled back in this release to give users of both TensorFlow and JAX more time to migrate to a newer TensorFlow release.
jax.experimental.mesh_utils can now create an efficient mesh for TPU v5e.
jax now depends on jaxlib directly. This change was enabled by the CUDA plugin switch: there are no longer multiple jaxlib variants. You can install a CPU-only jax with pip install jax, no extras required.
Added an API for exporting and serializing JAX functions. This used to exist in jax.experimental.export (which is being deprecated), and will now live in jax.export. See the documentation.
Deprecations
Internal pretty-printing tools jax.core.pp_* are deprecated, and will be removed in a future release.
Hashing of tracers is deprecated, and will lead to a TypeError in a future JAX release. This previously was the case, but there was an inadvertent regression in the last several JAX releases.
jax.experimental.export is deprecated. Use jax.export instead. See the migration guide.
Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays x and y, x.astype(y) will raise a warning. To silence it use x.astype(y.dtype).
jax.xla_computation is deprecated and will be removed in a future release. Please use the AOT APIs to get the same functionality as jax.xla_computation.
jax.xla_computation(fn)(*args, **kwargs) can be replaced with jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
You can also use the .out_info property of jax.stages.Lowered to get the output information (like tree structure, shape and dtype).
For cross-backend lowering, you can replace jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
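A minimal sketch of the AOT replacement pattern; as_text() is used here instead of compiler_ir('hlo') merely to keep the example backend-agnostic:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x * 2

# Lower ahead of time, without executing f.
lowered = jax.jit(f).lower(jnp.ones((2,), jnp.float32))

ir_text = lowered.as_text()   # human-readable lowered IR
out_info = lowered.out_info   # output tree structure, shape, and dtype
```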
Support for monolithic CUDA jaxlibs has been dropped. You must use the plugin-based installation (pip install jax[cuda12] or pip install jax[cuda12_local]).
Changes
We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. pip install jax[cuda12]).
JAX now requires ml_dtypes version 0.4.0 or newer.
Removed backwards-compatibility support for old usage of the jax.experimental.export API. It is no longer possible to use from jax.experimental.export import export; instead use from jax.experimental import export. The removed functionality had been deprecated since 0.4.24.
Added an is_leaf argument to jax.tree.all() & jax.tree_util.tree_all().
Deprecations
jax.sharding.XLACompatibleSharding is deprecated. Please use jax.sharding.Sharding.
jax.experimental.Exported.in_shardings has been renamed as jax.experimental.Exported.in_shardings_hlo. Same for out_shardings. The old names will be removed after 3 months.
Removed a number of previously-deprecated APIs:
The tol
argument of jax.numpy.linalg.matrix_rank()
is being deprecated and will soon be removed. Use rtol
instead.
The rcond
argument of jax.numpy.linalg.pinv()
is being deprecated and will soon be removed. Use rtol
instead.
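Both renames follow the same pattern; a small illustrative sketch:

```python
import jax.numpy as jnp

a = jnp.eye(3)
# Previously: jnp.linalg.matrix_rank(a, tol=1e-6) and jnp.linalg.pinv(a, rcond=1e-6)
rank = jnp.linalg.matrix_rank(a, rtol=1e-6)
a_pinv = jnp.linalg.pinv(a, rtol=1e-6)
```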
The deprecated jax.config
submodule has been removed. To configure JAX use import jax
and then reference the config object via jax.config
.
jax.random
APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of jax.vmap()
in such cases.
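For example, instead of passing a batch of keys directly, map the sampler over the key batch explicitly (the shapes here are illustrative):

```python
import jax

keys = jax.random.split(jax.random.key(0), 4)   # a batch of 4 keys
# Previously some APIs accepted `keys` directly; now vmap explicitly:
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```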
In jax.scipy.special.beta()
, the x
and y
parameters have been renamed to a
and b
for consistency with other beta
APIs.
New Functionality
Added jax.experimental.Exported.in_shardings_jax()
to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the Exported
objects.
Bug fixes
Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301).
Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396).
Deprecations
jax.tree.map(f, None, non-None)
now emits a DeprecationWarning
, and will raise an error in a future version of jax. None
is only a tree-prefix of itself. To preserve the current behavior, you can ask jax.tree.map
to treat None
as a leaf value by writing: jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
.
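A runnable demonstration of the workaround above, using illustrative pytrees:

```python
import jax

a = {'w': None, 'b': 1.0}
b = {'w': None, 'b': 2.0}
# Treat None as a leaf so the mapped function can skip it explicitly:
out = jax.tree.map(lambda x, y: None if x is None else x + y,
                   a, b, is_leaf=lambda x: x is None)
```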
Bug fixes
Reverted a change to make_jaxpr
that was breaking Equinox (#21116).
Deprecations & removals
The kind
argument to jax.numpy.sort()
and jax.numpy.argsort()
is now removed. Use stable=True
or stable=False
instead.
Removed get_compute_capability
from the jax.experimental.pallas.gpu
module. Use the compute_capability
attribute of a GPU device, returned by jax.devices()
or jax.local_devices()
, instead.
The newshape
argument to jax.numpy.reshape()
is being deprecated and will soon be removed. Use shape
instead.
Changes
The minimum jaxlib version of this release is 0.4.27.
Bug fixes
Fixes a memory corruption bug in the type name of Array and JIT Python objects in Python 3.10 or earlier.
Fixed a warning '+ptx84' is not a recognized feature for this target
under CUDA 12.4.
Fixed a slow compilation problem on CPU.
Changes
The Windows build is now built with Clang instead of MSVC.
New Functionality
Added jax.numpy.unstack()
and jax.numpy.cumulative_sum()
, following their addition in the array API 2023 standard, soon to be adopted by NumPy.
Added a new config option jax_cpu_collectives_implementation
to select the implementation of cross-process collective operations used by the CPU backend. Choices available are 'none'
(default), 'gloo'
and 'mpi'
(requires jaxlib 0.4.26). If set to 'none'
, cross-process collective operations are disabled.
Changes
jax.pure_callback()
, jax.experimental.io_callback()
and jax.debug.callback()
now use jax.Array
instead of np.ndarray
. You can recover the old behavior by transforming the arguments via jax.tree.map(np.asarray, args)
before passing them to the callback.
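A sketch of recovering the old np.ndarray behavior in a callback (host_fn stands in for any code that requires true NumPy arrays):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    # Stand-in for third-party code that requires a true np.ndarray:
    assert isinstance(x, np.ndarray)
    return x * 2

@jax.jit
def f(x):
    # Arguments now arrive as jax.Array; convert them back explicitly:
    return jax.pure_callback(
        lambda *args: host_fn(*jax.tree.map(np.asarray, args)),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x)

y = f(jnp.arange(3.0))
```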
complex_arr.astype(bool)
now follows the same semantics as NumPy, returning False where complex_arr
is equal to 0 + 0j
, and True otherwise.
core.Token
is now a non-trivial class that wraps a jax.Array. It can be created and threaded in and out of computations to build up dependencies. The singleton object core.token has been removed; users should now create and use fresh core.Token objects instead.
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with jax.config.update('jax_threefry_gpu_kernel_lowering', True)
. If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.
Deprecations & Removals
Pallas now exclusively uses XLA for compiling kernels on GPU. The old lowering pass via Triton Python APIs has been removed and the JAX_TRITON_COMPILE_VIA_XLA
environment variable no longer has any effect.
jax.numpy.clip()
has a new argument signature: a
, a_min
, and a_max
are deprecated in favor of x
(positional only), min
, and max
(#20550).
The device()
method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use arr.devices()
instead.
The initial
argument to jax.nn.softmax()
and jax.nn.log_softmax()
is deprecated; empty inputs to softmax are now supported without setting this.
In jax.jit()
, passing invalid static_argnums
or static_argnames
now leads to an error rather than a warning.
The minimum jaxlib version is now 0.4.23.
The jax.numpy.hypot()
function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed.
Scalar arguments to jax.numpy.nonzero()
, jax.numpy.where()
, and related functions now raise an error, following a similar change in NumPy.
The config option jax_cpu_enable_gloo_collectives
is deprecated. Use jax.config.update('jax_cpu_collectives_implementation', 'gloo')
instead.
The jax.Array.device_buffer
and jax.Array.device_buffers
methods have been removed after being deprecated in JAX v0.4.22. Instead use jax.Array.addressable_shards
and jax.Array.addressable_data()
.
The condition
, x
, and y
parameters of jax.numpy.where
are now positional-only, following deprecation of the keywords in JAX v0.4.21.
Non-array arguments to functions in jax.lax.linalg
now must be specified by keyword. Previously, this raised a DeprecationWarning.
Array-like arguments are now required in several jax.numpy
APIs, including apply_along_axis()
, apply_over_axes()
, inner()
, outer()
, cross()
, kron()
, and lexsort()
.
Bug fixes
jax.numpy.astype()
will now always return a copy when copy=True
. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to copy=False
to preserve backwards compatibility.
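A small sketch of the two modes (same-dtype cast, so only the copy behavior differs):

```python
import jax.numpy as jnp

x = jnp.arange(3)
y = jnp.astype(x, x.dtype)              # copy=False (default): may return x itself
z = jnp.astype(x, x.dtype, copy=True)   # always returns a copy now
```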
New Functionality
Added jax.numpy.trapezoid()
, following the addition of this function in NumPy 2.0.
Changes
Complex-valued jax.numpy.geomspace()
now chooses the logarithmic spiral branch consistent with that of NumPy 2.0.
The behavior of lax.rng_bit_generator
, and in turn the 'rbg'
and 'unsafe_rbg'
PRNG implementations, under jax.vmap
has changed so that mapping over keys results in random generation only from the first key in the batch.
Docs now use jax.random.key
for construction of PRNG key arrays rather than jax.random.PRNGKey
.
Deprecations & Removals
jax.tree_map()
is deprecated; use jax.tree.map
instead, or for backward compatibility with older JAX versions, use jax.tree_util.tree_map()
.
jax.clear_backends()
is deprecated as it does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use jax.clear_caches()
if you only want to clean up compilation caches. For backward compatibility or you really need to switch/reinitialize the default backend, use jax.extend.backend.clear_backends()
.
The jax.experimental.maps
module and jax.experimental.maps.xmap
are deprecated. Use jax.experimental.shard_map
or jax.vmap
with the spmd_axis_name
argument for expressing SPMD device-parallel computations.
The jax.experimental.host_callback
module is deprecated. Use instead the new JAX external callbacks. Added JAX_HOST_CALLBACK_LEGACY
flag to assist in the transition to the new callbacks. See #20385 for a discussion.
Passing arguments to jax.numpy.array_equal()
and jax.numpy.array_equiv()
that cannot be converted to a JAX array now results in an exception.
The deprecated flag jax_parallel_functions_output_gda
has been removed. It was long deprecated and had no effect.
The previously-deprecated imports jax.interpreters.ad.config
and jax.interpreters.ad.source_info_util
have now been removed. Use jax.config
and jax.extend.source_info_util
instead.
JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. See a description of the versions. This change could break clients that set a specific JAX serialization version lower than 9.
Changes
JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been dropped.
JAX now supports NumPy 2.0.
New Features
Added CUDA Array Interface import support (requires jaxlib 0.4.24).
JAX arrays now support NumPy-style scalar boolean indexing, e.g. x[True]
or x[False]
.
Added jax.tree
module, with a more convenient interface for referencing functions in jax.tree_util
.
jax.tree.transpose()
(i.e. jax.tree_util.tree_transpose()
) now accepts inner_treedef=None
, in which case the inner treedef will be automatically inferred.
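A minimal sketch of inference of the inner treedef (the pytree shown is illustrative):

```python
import jax

pairs = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]
outer_treedef = jax.tree.structure([0, 0])
# inner_treedef=None: the inner structure ({'a': ..., 'b': ...}) is inferred.
flipped = jax.tree.transpose(outer_treedef, None, pairs)
```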
Changes
Pallas now uses XLA instead of the Triton Python APIs to compile Triton kernels. You can revert to the old behavior by setting the JAX_TRITON_COMPILE_VIA_XLA
environment variable to "0"
.
Several deprecated APIs in jax.interpreters.xla
that were removed in v0.4.24 have been re-added in v0.4.25, including backend_specific_translations
, translations
, register_translation
, xla_destructure
, TranslationRule
, TranslationContext
, and XLAOp
. These are still considered deprecated, and will be removed again in the future when better replacements are available. Refer to #19816 for discussion.
Deprecations & Removals
jax.numpy.linalg.solve()
now shows a deprecation warning for batched 1D solves with b.ndim > 1
. In the future these will be treated as batched 2D solves.
Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
The previously deprecated configuration APIs have been removed following a standard 3 months deprecation cycle (see API compatibility). These include
the jax.config.config
object and
the define_*_state
and DEFINE_*
methods of jax.config
.
Importing the jax.config
submodule via import jax.config
is deprecated. To configure JAX use import jax
and then reference the config object via jax.config
.
The minimum jaxlib version is now 0.4.20.
Changes
JAX lowering to StableHLO no longer depends on physical devices. If your primitive wraps custom_partitioning or JAX callbacks in its lowering rule (i.e., the function passed to the rule
parameter of mlir.register_lowering
), add your primitive to the jax._src.dispatch.prim_requires_devices_during_lowering
set. This is needed because custom_partitioning and JAX callbacks need physical devices to create Sharding
s during lowering. This is a temporary state until we can create Sharding
s without physical devices.
jax.numpy.argsort()
and jax.numpy.sort()
now support the stable
and descending
arguments.
Several changes to the handling of shape polymorphism (used in jax.experimental.jax2tf
and jax.experimental.export
):
cleaner pretty-printing of symbolic expressions (#19227)
added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
with the addition of symbolic constraints (#19235) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes cannot interact, e.g., in arithmetic operations. Scopes are introduced by jax.experimental.jax2tf.convert()
, jax.experimental.export.symbolic_shape()
, jax.experimental.export.symbolic_args_specs()
. The scope of a symbolic expression e
can be read with e.scope
and passed into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 (#19231; note that this may result in user-visible behavior changes)
improved the error messages for inconclusive inequality comparisons (#19235).
the core.non_negative_dim
API (introduced recently) was deprecated and core.max_dim
and core.min_dim
were introduced (#18953) to express max
and min
for symbolic dimensions. You can use core.max_dim(d, 0)
instead of core.non_negative_dim(d)
.
the shape_poly.is_poly_dim
is deprecated in favor of export.is_symbolic_dim
(#19282).
the export.args_specs
is deprecated in favor of export.symbolic_args_specs (#19283).
the shape_poly.PolyShape
and jax2tf.PolyShape
are deprecated, use strings for polymorphic shapes specifications (#19284).
JAX default native serialization version is now 9. This is relevant for jax.experimental.jax2tf
and jax.experimental.export
. See description of version numbers.
Refactored the API for jax.experimental.export
. Instead of from jax.experimental.export import export
you should use now from jax.experimental import export
. The old way of importing will continue to work for a deprecation period of 3 months.
Added jax.scipy.stats.sem()
.
jax.numpy.unique()
with return_inverse = True
returns inverse indices reshaped to the dimension of the input, following a similar change to numpy.unique()
in NumPy 2.0.
jax.numpy.sign()
now returns x / abs(x)
for nonzero complex inputs. This is consistent with the behavior of numpy.sign()
in NumPy version 2.0.
jax.scipy.special.logsumexp()
with return_sign=True
now uses the NumPy 2.0 convention for the complex sign, x / abs(x)
. This is consistent with the behavior of scipy.special.logsumexp()
in SciPy v1.13.
JAX now supports the bool DLPack type for both import and export. Previously bool values could not be imported and were exported as integers.
Deprecations & Removals
A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see API compatibility). This includes:
From jax.core
: TracerArrayConversionError
, TracerIntegerConversionError
, UnexpectedTracerError
, as_hashable_function
, collections
, dtypes
, lu
, map
, namedtuple
, partial
, pp
, ref
, safe_zip
, safe_map
, source_info_util
, total_ordering
, traceback_util
, tuple_delete
, tuple_insert
, and zip
.
From jax.lax
: dtypes
, itertools
, naryop
, naryop_dtype_rule
, standard_abstract_eval
, standard_naryop
, standard_primitive
, standard_unop
, unop
, and unop_dtype_rule
.
The jax.linear_util
submodule and all its contents.
The jax.prng
submodule and all its contents.
From jax.random
: PRNGKeyArray
, KeyArray
, default_prng_impl
, threefry_2x32
, threefry2x32_key
, threefry2x32_p
, rbg_key
, and unsafe_rbg_key
.
From jax.tree_util
: register_keypaths
, AttributeKeyPathEntry
, and GetItemKeyPathEntry
.
from jax.interpreters.xla
: backend_specific_translations
, translations
, register_translation
, xla_destructure
, TranslationRule
, TranslationContext
, axis_groups
, ShapedArray
, ConcreteArray
, AxisEnv
, backend_compile
, and XLAOp
.
from jax.numpy
: NINF
, NZERO
, PZERO
, row_stack
, issubsctype
, trapz
, and in1d
.
from jax.scipy.linalg
: tril
and triu
.
The previously-deprecated method PRNGKeyArray.unsafe_raw_array
has been removed. Use jax.random.key_data()
instead.
bool(empty_array)
now raises an error rather than returning False
. This previously raised a deprecation warning, and follows a similar change in NumPy.
Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses the mhlo dialect, in favor of stablehlo. APIs that refer to “mhlo” will be removed in the future. Use the “stablehlo” dialect instead.
jax.random
: passing batched keys directly to random number generation functions, such as bits()
, gamma()
, and others, is deprecated and will emit a FutureWarning
. Use jax.vmap
for explicit batching.
jax.lax.tie_in()
is deprecated: it has been a no-op since JAX v0.2.0.
Changes
JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been dropped.
cost_analysis
now works with cross-compiled Compiled
objects (i.e. when using .lower().compile()
with a topology object, e.g., to compile for Cloud TPU from a non-TPU computer).
Added CUDA Array Interface import support (requires jax 0.4.25).
Fixed a bug that caused verbose logging from the GPU compiler during compilation.
Deprecations
The device_buffer
and device_buffers
properties of JAX arrays are deprecated. Explicit buffers have been replaced by the more flexible array sharding interface, but the previous outputs can be recovered this way:
arr.device_buffer
becomes arr.addressable_data(0)
arr.device_buffers
becomes [x.data for x in arr.addressable_shards]
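The replacements above in a runnable form (the array is illustrative; on a single device there is one shard):

```python
import jax.numpy as jnp

arr = jnp.arange(4.0)
first = arr.addressable_data(0)                     # was arr.device_buffer
shards = [s.data for s in arr.addressable_shards]   # was arr.device_buffers
```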
New Features
Added jax.nn.squareplus
.
Changes
The minimum jaxlib version is now 0.4.19.
Released wheels are built now with clang instead of gcc.
Enforce that the device backend has not been initialized prior to calling jax.distributed.initialize()
.
Automate arguments to jax.distributed.initialize()
in cloud TPU environments.
Deprecations
The previously-deprecated sym_pos
argument has been removed from jax.scipy.linalg.solve()
. Use assume_a='pos'
instead.
Passing None
to jax.array()
or jax.asarray()
, either directly or within a list or tuple, is deprecated and now raises a FutureWarning
. It currently is converted to NaN, and in the future will raise a TypeError
.
Passing the condition
, x
, and y
parameters to jax.numpy.where
by keyword arguments has been deprecated, to match numpy.where
.
Passing arguments to jax.numpy.array_equal()
and jax.numpy.array_equiv()
that cannot be converted to a JAX array is deprecated and now raises a DeprecationWarning. Currently the functions return False; in the future this will raise an exception.
The device()
method of JAX arrays is deprecated. Depending on the context, it may be replaced with one of the following:
jax.Array.devices()
returns the set of all devices used by the array.
jax.Array.sharding
gives the sharding configuration used by the array.
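Both replacements in use (the array is illustrative):

```python
import jax.numpy as jnp

arr = jnp.arange(3)
devs = arr.devices()       # set of all devices holding the array
sharding = arr.sharding    # the array's sharding configuration
```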
Changes
In preparation for adding distributed CPU support, JAX now treats CPU devices identically to GPU and TPU devices, that is:
jax.devices()
includes all devices present in a distributed job, even those not local to the current process. jax.local_devices()
still only includes devices local to the current process, so if the change to jax.devices()
breaks you, you most likely want to use jax.local_devices()
instead.
CPU devices now receive a globally unique ID number within a distributed job; previously CPU devices would receive a process-local ID number.
The process_index
of each CPU device will now match any GPU or TPU devices within the same process; previously the process_index
of a CPU device was always 0.
On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
Bug fixes
Fixed error/hang when an array with non-finite values is passed to a non-symmetric eigendecomposition (#18226). Arrays with non-finite values now produce arrays full of NaNs as outputs.
Bug fixes
Fixed some type confusion between E4M3 and E5M2 float8 types.
New Features
Added jax.typing.DTypeLike
, which can be used to annotate objects that are convertible to JAX dtypes.
Added jax.numpy.fill_diagonal
.
Changes
JAX now requires SciPy 1.9 or newer.
Bug fixes
Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network file system such as GCS.
The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
Changes
jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-… packages) over any other CUDA installation if they are installed, including installations named in LD_LIBRARY_PATH
. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip installed CUDA library packages.
Changes
CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended cuda12_pip
installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.
We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
jax.Array.item()
now supports optional index arguments.
Deprecations
A number of internal utilities and inadvertent exports in jax.lax
have been deprecated, and will be removed in a future release.
jax.lax.dtypes
: use jax.dtypes
instead.
jax.lax.itertools
: use itertools
instead.
naryop
, naryop_dtype_rule
, standard_abstract_eval
, standard_naryop
, standard_primitive
, standard_unop
, unop
, and unop_dtype_rule
are internal utilities, now deprecated without replacement.
Bug fixes
Fixed Cloud TPU regression where compilation would OOM due to smem.
New features
Added new jax.numpy.bitwise_count()
function, matching the API of the similar function recently added to NumPy.
Deprecations
Removed the deprecated module jax.abstract_arrays
and all its contents.
Named key constructors in jax.random
are deprecated. Pass the impl
argument to jax.random.PRNGKey()
or jax.random.key()
instead:
random.threefry2x32_key(seed)
becomes random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
becomes random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
becomes random.PRNGKey(seed, impl='unsafe_rbg')
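A sketch of the replacement (the seed and impl choice are illustrative):

```python
import jax

# Previously: jax.random.rbg_key(0)
key = jax.random.PRNGKey(0, impl='rbg')
# Or, with the newer typed key API:
tkey = jax.random.key(0, impl='rbg')
```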
Changes:
CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
Removed the “No GPU/TPU” found warning. Instead warn if, on Linux, an NVIDIA GPU or a Google TPU are found but not used and --jax_platforms
was not specified.
jax.scipy.stats.mode()
now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of scipy.stats.mode
in SciPy 1.11.
Most jax.numpy
functions and attributes now have fully-defined type stubs. Previously many of these were treated as Any
by static type checkers like mypy
and pytype
.
Changes:
Python 3.12 wheels were added in this release.
The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
Bug fixes:
Fixed log spam from ABSL when the JAX CPU backend was initialized.
Changes
Added jax.numpy.ufunc
, as well as jax.numpy.frompyfunc()
, which can convert any scalar-valued function into a numpy.ufunc()
-like object, with methods such as outer()
, reduce()
, accumulate()
, at()
, and reduceat()
(#17054).
When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the “unfiltered stack trace” that previously appeared.) This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting JAX_TRACEBACK_FILTERING=remove_frames
(for two separate unfiltered/filtered tracebacks, which was the old behavior) or JAX_TRACEBACK_FILTERING=off
(for one unfiltered traceback).
jax2tf default serialization version is now 7, which introduces new shape safety assertions.
Devices passed to jax.sharding.Mesh
should be hashable. This specifically applies to mock devices or user created devices. jax.devices()
are already hashable.
Breaking changes:
jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
The option --jax_coordination_service
has been removed. It is now always True
.
jax.jaxpr_util
has been removed from the public JAX namespace.
JAX_USE_PJRT_C_API_ON_TPU
no longer has an effect (i.e. it always defaults to true).
The backwards compatibility flag --jax_host_callback_ad_transforms
introduced in December 2021, has been removed.
Deprecations:
Several jax.numpy
APIs have been deprecated following NumPy NEP-52:
jax.numpy.NINF
has been deprecated. Use -jax.numpy.inf
instead.
jax.numpy.PZERO
has been deprecated. Use 0.0
instead.
jax.numpy.NZERO
has been deprecated. Use -0.0
instead.
jax.numpy.issubsctype(x, t)
has been deprecated. Use jax.numpy.issubdtype(x.dtype, t)
.
jax.numpy.row_stack
has been deprecated. Use jax.numpy.vstack
instead.
jax.numpy.in1d
has been deprecated. Use jax.numpy.isin
instead.
jax.numpy.trapz
has been deprecated. Use jax.scipy.integrate.trapezoid
instead.
jax.scipy.linalg.tril
and jax.scipy.linalg.triu
have been deprecated, following SciPy. Use jax.numpy.tril
and jax.numpy.triu
instead.
jax.lax.prod
has been removed after being deprecated in JAX v0.4.11. Use the built-in math.prod
instead.
A number of exports from jax.interpreters.xla
related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in jax.interpreters.mlir
instead.
The following previously-deprecated functions have been removed after a three-month deprecation period:
jax.abstract_arrays.ShapedArray
: use jax.core.ShapedArray
.
jax.abstract_arrays.raise_to_shaped
: use jax.core.raise_to_shaped
.
jax.numpy.alltrue
: use jax.numpy.all
.
jax.numpy.sometrue
: use jax.numpy.any
.
jax.numpy.product
: use jax.numpy.prod
.
jax.numpy.cumproduct
: use jax.numpy.cumprod
.
Deprecations/removals:
The internal submodule jax.prng
is now deprecated. Its contents are available at jax.extend.random
.
The internal submodule path jax.linear_util
has been deprecated. Use jax.extend.linear_util
instead (Part of jax.extend: a module for extensions)
jax.random.PRNGKeyArray
and jax.random.KeyArray
are deprecated. Use jax.Array
for type annotations, and jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
for runtime detection of typed prng keys.
The method PRNGKeyArray.unsafe_raw_array
is deprecated. Use jax.random.key_data()
instead.
jax.experimental.pjit.with_sharding_constraint
is deprecated. Use jax.lax.with_sharding_constraint
instead.
The internal utilities jax.core.is_opaque_dtype
and jax.core.has_opaque_dtype
have been removed. Opaque dtypes have been renamed to Extended dtypes; use jnp.issubdtype(dtype, jax.dtypes.extended)
instead (available since jax v0.4.14).
The utility jax.interpreters.xla.register_collective_primitive
has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.
The internal submodule path jax.linear_util
has been deprecated. Use jax.extend.linear_util
instead (Part of jax.extend: a module for extensions)
Changes:
Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer uses a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.
Bug fixes:
Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
Changes
jax.jit
takes donate_argnames
as an argument. Its semantics are similar to static_argnames
. If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun)
to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
jax.random.gamma()
has been re-factored to a more efficient algorithm with more robust endpoint behavior (#16779). This means that the sequence of values returned for a given key
will change between JAX v0.4.13 and v0.4.14 for gamma
and related samplers (including jax.random.ball()
, jax.random.beta()
, jax.random.chisquare()
, jax.random.dirichlet()
, jax.random.generalized_normal()
, jax.random.loggamma()
, jax.random.t()
).
Deletions
in_axis_resources
and out_axis_resources
have been deleted from pjit since it has been more than 3 months since their deprecation. Please use in_shardings
and out_shardings
as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn’t break any code. You can still pass in PartitionSpecs
to in_shardings and out_shardings.
Deprecations
Python 3.8 support has been dropped as per https://docs.jax.dev/en/latest/deprecation.html
JAX now requires NumPy 1.22 or newer as per https://docs.jax.dev/en/latest/deprecation.html
Passing optional arguments to jax.numpy.ndarray.at()
by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of x.at[i].get(True)
, use x.at[i].get(indices_are_sorted=True)
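The keyword form in use (the array and index are illustrative):

```python
import jax.numpy as jnp

x = jnp.arange(5.0)
# Previously: x.at[2].get(True)
v = x.at[2].get(indices_are_sorted=True)
```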
The following jax.Array
methods have been removed, after being deprecated in JAX v0.4.5:
jax.Array.broadcast
: use jax.lax.broadcast()
instead.
jax.Array.broadcast_in_dim
: use jax.lax.broadcast_in_dim()
instead.
jax.Array.split
: use jax.numpy.split()
instead.
The following APIs have been removed after previous deprecation:
jax.ad
: use jax.interpreters.ad
.
jax.curry
: use curry = lambda f: partial(partial, f)
.
jax.partial_eval
: use jax.interpreters.partial_eval
.
jax.pxla
: use jax.interpreters.pxla
.
jax.xla
: use jax.interpreters.xla
.
jax.ShapedArray
: use jax.core.ShapedArray
.
jax.interpreters.pxla.device_put
: use jax.device_put()
.
jax.interpreters.pxla.make_sharded_device_array
: use jax.make_array_from_single_device_arrays()
.
jax.interpreters.pxla.ShardedDeviceArray
: use jax.Array
.
jax.numpy.DeviceArray
: use jax.Array
.
jax.stages.Compiled.compiler_ir
: use jax.stages.Compiled.as_text()
.
Breaking changes
JAX now requires ml_dtypes version 0.2.0 or newer.
To fix a corner case, calls to jax.lax.cond()
with five arguments will always resolve to the “common operands” cond
behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.
The deprecated config options jax_array
and jax_jit_pjit_api_merge
, which did nothing, have been removed. These options have been true by default for many releases.
New features
JAX now supports a configuration flag --jax_serialization_version and a JAX_SERIALIZATION_VERSION environment variable to control the serialization version (#16746).
jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
Deprecations
Python 3.8 support has been dropped as per https://docs.jax.dev/en/latest/deprecation.html
Changes
jax.jit
now allows None
to be passed to in_shardings
and out_shardings
. The semantics are as follows:
For in_shardings, JAX will mark them as replicated, but this behavior can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
jax.experimental.pjit.pjit
also allows None
to be passed to in_shardings
and out_shardings
. The semantics are as follows:
If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
For in_shardings, JAX will mark them as replicated, but this behavior can change in the future.
For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
Executable.cost_analysis() works on Cloud TPU
Added a warning if a non-allowlisted jaxlib
plugin is in use.
Added jax.tree_util.tree_leaves_with_path
.
None
is not a valid input to jax.experimental.multihost_utils.host_local_array_to_global_array
or jax.experimental.multihost_utils.global_array_to_host_local_array
. Please use jax.sharding.PartitionSpec()
if you wanted to replicate your input.
Bug fixes
Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named cudnn89
instead of cudnn88
.
Deprecations
The native_serialization_strict_checks
parameter to jax.experimental.jax2tf.convert()
is deprecated in favor of the new native_serialization_disabled_checks
(#16347).
Changes
Added Windows CPU-only wheels to the jaxlib
PyPI release.
Bug fixes
__cuda_array_interface__
was broken in previous jaxlib versions and is now fixed (#16440).
Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
Changes
Deprecations
jax.abstract_arrays
and its contents are now deprecated. See related functionality in :mod:jax.core
.
jax.numpy.alltrue
: use jax.numpy.all
. This follows the deprecation of numpy.alltrue
in NumPy version 1.25.0.
jax.numpy.sometrue
: use jax.numpy.any
. This follows the deprecation of numpy.sometrue
in NumPy version 1.25.0.
jax.numpy.product
: use jax.numpy.prod
. This follows the deprecation of numpy.product
in NumPy version 1.25.0.
jax.numpy.cumproduct
: use jax.numpy.cumprod
. This follows the deprecation of numpy.cumproduct
in NumPy version 1.25.0.
jax.sharding.OpShardingSharding
has been removed, three months after it was deprecated.
Changes
Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.
Bug fixes
Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
Deprecations
The following APIs have been removed after a 3 month deprecation period, in accordance with the API compatibility policy:
jax.experimental.PartitionSpec
: use jax.sharding.PartitionSpec
.
jax.experimental.maps.Mesh
: use jax.sharding.Mesh
jax.experimental.pjit.NamedSharding
: use jax.sharding.NamedSharding
.
jax.experimental.pjit.PartitionSpec
: use jax.sharding.PartitionSpec
.
jax.experimental.pjit.FROM_GDA
. Instead pass sharded jax.Array
objects as input and remove the optional in_shardings
argument to pjit
.
jax.interpreters.pxla.PartitionSpec
: use jax.sharding.PartitionSpec
.
jax.interpreters.pxla.Mesh
: use jax.sharding.Mesh
jax.interpreters.xla.Buffer
: use jax.Array
.
jax.interpreters.xla.Device
: use jax.Device
.
jax.interpreters.xla.DeviceArray
: use jax.Array
.
jax.interpreters.xla.device_put
: use jax.device_put
.
jax.interpreters.xla.xla_call_p
: use jax.experimental.pjit.pjit_p
.
axis_resources
argument of with_sharding_constraint
is removed. Please use shardings
instead.
Changes
Added memory_stats()
method to Device
s. If supported, this returns a dict of string stat names with int values, e.g. "bytes_in_use"
, or None if the platform doesn’t support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU.
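A minimal sketch of how the stats might be queried; the exact keys in the returned dict vary by platform, and on platforms without support the method returns None:

```python
import jax

# Query memory statistics for the first available device. If supported,
# this returns a dict mapping stat names to integer values (e.g.
# "bytes_in_use"); otherwise it returns None.
dev = jax.devices()[0]
stats = dev.memory_stats()
```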
Readded support for the Python buffer protocol (memoryview
) on CPU devices.
Changes
Fixed the 'apple-m1' is not a recognized processor for this target (ignoring processor)
issue that prevented the previous release from running on Mac M1.
Changes
The flags experimental_cpp_jit, experimental_cpp_pjit and experimental_cpp_pmap have been removed. They are now always on.
Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
Deprecations
jax.experimental.gda_serialization
is deprecated and has been renamed to jax.experimental.array_serialization
. Please change your imports to use jax.experimental.array_serialization
.
The in_axis_resources
and out_axis_resources
arguments of pjit have been deprecated. Please use in_shardings
and out_shardings
respectively.
The function jax.numpy.msort
has been removed. It has been deprecated since JAX v0.4.1. Use jnp.sort(a, axis=0)
instead.
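As a quick illustration of the suggested replacement, sorting along the first axis reproduces what msort did:

```python
import jax.numpy as jnp

# jnp.msort(a) sorted an array along its first axis;
# jnp.sort(a, axis=0) is the drop-in replacement.
a = jnp.array([[3, 1],
               [2, 4]])
sorted_a = jnp.sort(a, axis=0)  # each column sorted independently
```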
in_parts
and out_parts
arguments have been removed from jax.xla_computation
since they were only used with sharded_jit, which was removed long ago.
instantiate_const_outputs
argument has been removed from jax.xla_computation
since it has been unused for a very long time.
Breaking changes
A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
jax.debug.print()
, jax.debug.callback()
, and jax.debug.breakpoint()
now work on Cloud TPU
Automatic TPU memory defragmentation
jax.experimental.host_callback()
is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new jax.debug
APIs are insufficient for your use case.
The old runtime component will be available for at least the next three months by setting the environment variable JAX_USE_PJRT_C_API_ON_TPU=false
. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
Changes
The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
Deprecations
CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
global_arg_shapes
argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove global_arg_shapes from pmap.
Changes
As per https://docs.jax.dev/en/latest/jax_array_migration.html#jax-array-migration jax.config.jax_array
cannot be disabled anymore.
jax.config.jax_jit_pjit_api_merge
cannot be disabled anymore.
jax.experimental.jax2tf.convert()
now supports the native_serialization
parameter to use JAX’s native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag --jax2tf_default_experimental_native_lowering
has been renamed to --jax2tf_native_serialization
.
JAX now depends on ml_dtypes
, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
Deprecations
The type jax.numpy.DeviceArray
is deprecated. Use jax.Array
instead, for which it is an alias.
The type jax.interpreters.pxla.ShardedDeviceArray
is deprecated. Use jax.Array
instead.
Passing additional arguments to jax.numpy.ndarray.at()
by position is deprecated. For example, instead of x.at[i].get(True)
, use x.at[i].get(indices_are_sorted=True)
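A small sketch of the keyword-only style; the array values here are arbitrary:

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
# Pass optional arguments to .at[...].get() by keyword, not by position:
v = x.at[1].get(indices_are_sorted=True, unique_indices=True)
```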
jax.interpreters.xla.device_put
is deprecated. Please use jax.device_put
.
jax.interpreters.pxla.device_put
is deprecated. Please use jax.device_put
.
jax.experimental.pjit.FROM_GDA
is deprecated. Please pass in sharded jax.Arrays as input and remove the in_shardings
argument to pjit since it is optional.
Changes:
jaxlib now depends on ml_dtypes
, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
Changes
jax.tree_util
now contains a set of APIs that allow users to define keys for their custom pytree nodes. This includes:
tree_flatten_with_path
that flattens a tree and returns not only each leaf but also its key path.
tree_map_with_path
that can map a function that takes the key path as an argument.
register_pytree_with_keys
to register how the key paths and leaves should look for a custom pytree node.
keystr
that pretty-prints a key path.
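A minimal sketch of the new key-path APIs on a built-in pytree (the example tree is arbitrary):

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# Flatten while recording the path to each leaf.
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
paths = [jax.tree_util.keystr(path) for path, leaf in leaves_with_paths]

# Map a function that also receives the key path of each leaf.
labeled = jax.tree_util.tree_map_with_path(
    lambda path, leaf: (jax.tree_util.keystr(path), leaf), tree)
```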
jax2tf.call_tf()
has a new parameter output_shape_dtype
(default None
) that can be used to declare the output shape and type of the result. This enables jax2tf.call_tf()
to work in the presence of shape polymorphism. (#14734).
Deprecations
The old key-path APIs in jax.tree_util
are deprecated and will be removed 3 months from Mar 10 2023:
register_keypaths
: use jax.tree_util.register_pytree_with_keys()
instead.
AttributeKeyPathEntry
: use GetAttrKey
instead.
GetitemKeyPathEntry
: use SequenceKey
or DictKey
instead.
Deprecations
jax.sharding.OpShardingSharding
has been renamed to jax.sharding.GSPMDSharding
. jax.sharding.OpShardingSharding
will be removed in 3 months from Feb 17, 2023.
The following jax.Array
methods are deprecated and will be removed 3 months from Feb 23 2023:
jax.Array.broadcast
: use jax.lax.broadcast()
instead.
jax.Array.broadcast_in_dim
: use jax.lax.broadcast_in_dim()
instead.
jax.Array.split
: use jax.numpy.split()
instead.
Changes
The implementation of jit
and pjit
has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, jit
was a final style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the jit
-pjit
implementation merge, jit
becomes an initial style primitive, which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX’s internals and make development of features like dynamic shapes easier. You can disable the merge only via the environment variable, i.e. os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
. The merge must be disabled via an environment variable because it affects JAX at import time, so it needs to be set before jax is imported.
axis_resources
argument of with_sharding_constraint
is deprecated. Please use shardings
instead. There is no change needed if you were using axis_resources
as an arg. If you were using it as a kwarg, then please use shardings
instead. axis_resources
will be removed after 3 months from Feb 13, 2023.
added the jax.typing
module, with tools for type annotations of JAX functions.
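An illustrative sketch of annotating a JAX function with these tools; the function `scale` is hypothetical:

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def scale(x: ArrayLike, factor: float) -> Array:
    # ArrayLike covers jax.Array, NumPy arrays, and Python scalars;
    # the return annotation is the concrete jax.Array type.
    return jnp.asarray(x) * factor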
The following names have been deprecated:
jax.xla.Device
and jax.interpreters.xla.Device
: use jax.Device
.
jax.experimental.maps.Mesh
. Use jax.sharding.Mesh
instead.
jax.experimental.pjit.NamedSharding
: use jax.sharding.NamedSharding
.
jax.experimental.pjit.PartitionSpec
: use jax.sharding.PartitionSpec
.
jax.interpreters.pxla.Mesh
: use jax.sharding.Mesh
.
jax.interpreters.pxla.PartitionSpec
: use jax.sharding.PartitionSpec
.
Breaking Changes
the initial
argument to reduction functions like :func:jax.numpy.sum
is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar initial
values was an unintentional implementation detail (#14446).
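A short sketch of the now-required scalar `initial`, which seeds the reduction as in NumPy:

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
# `initial` must now be a scalar, matching the NumPy API.
total = jnp.sum(x, initial=10.0)
```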
Breaking changes
Support for NVIDIA Kepler series GPUs has been removed from the default jaxlib
builds. If Kepler support is needed, it is still possible to build jaxlib
from source with Kepler support (via the --cuda_compute_capabilities=sm_35
option to build.py
), however note that CUDA 12 has completely dropped support for Kepler GPUs.
Breaking changes
Deleted jax.scipy.linalg.polar_unitary()
, which was a deprecated JAX extension to the scipy API. Use jax.scipy.linalg.polar()
instead.
Changes
Added jax.scipy.stats.rankdata()
.
jax.Array
now has the non-blocking is_ready()
method, which returns True
if the array is ready (see also jax.block_until_ready()
).
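A minimal sketch contrasting the non-blocking check with the blocking wait (the matmul is an arbitrary example computation):

```python
import jax
import jax.numpy as jnp

y = jnp.ones((256, 256)) @ jnp.ones((256, 256))
ready_now = y.is_ready()      # non-blocking: True only if already finished
y = jax.block_until_ready(y)  # blocks until the result is available
```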
Breaking changes
Deleted jax.experimental.callback
Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and np.ndarray
can now raise errors when the result is used as a shape value (#14106).
jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations (#14102)
Changes
jax2tf.call_tf()
has a new parameter has_side_effects
(default True
) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination (#13980).
Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions (#14108).
Changes
Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable new Cloud TPU runtime, featuring automatic device memory defragmentation.
Changes
Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.
We introduce jax.Array
which is a unified array type that subsumes DeviceArray
, ShardedDeviceArray
, and GlobalDeviceArray
types in JAX. The jax.Array
type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit
and pjit
. jax.Array
has been enabled by default in JAX 0.4 and makes some breaking changes to the pjit
API. The jax.Array migration guide can help you migrate your codebase to jax.Array
. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
PartitionSpec
and Mesh
are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec
and jax.sharding.Mesh
. jax.experimental.maps.Mesh
and jax.experimental.PartitionSpec
are deprecated and will be removed in 3 months.
with_sharding_constraint
s new public endpoint is jax.lax.with_sharding_constraint
.
If using ABSL flags together with jax.config
, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config
options, which are used pervasively in JAX.
The jax2tf.call_tf function now uses, for TF lowering, the first TF device on the same platform as the embedding JAX computation. Previously it used the 0th device of the JAX-default backend.
A number of jax.numpy
functions now have their arguments marked as positional-only, matching NumPy.
jnp.msort
is now deprecated, following the deprecation of np.msort
in numpy 1.24. It will be removed in a future release, in accordance with the API compatibility policy. It can be replaced with jnp.sort(a, axis=0)
.
Changes
Support for Python 3.7 has been dropped, in accordance with JAX’s Python and NumPy version support policy.
The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
The deprecated method .block_host_until_ready()
has been removed. Use .block_until_ready()
instead.
The release was yanked.
The release was yanked.
Changes
jax.numpy.linalg.pinv()
now supports the hermitian
option.
jax.scipy.linalg.hessenberg()
is now supported on CPU only. Requires jaxlib > 0.3.24.
New functions jax.lax.linalg.hessenberg()
, jax.lax.linalg.tridiagonal()
, and jax.lax.linalg.householder_product()
were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
The gradients of svd
and jax.numpy.linalg.pinv
are now computed more economically for non-square matrices.
Breaking Changes
Deleted the jax_experimental_name_stack
config option.
A string axis_names
argument to the jax.experimental.maps.Mesh
constructor is now converted into a singleton tuple instead of being unpacked into a sequence of character axis names.
Changes
Added support for tridiagonal reductions on CPU and GPU.
Added support for upper Hessenberg reductions on CPU.
Bugs
Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+
Changes
JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX’s import time.
Setting the env var JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
Added jax.scipy.stats.mode()
.
The default device order used by pmap
on TPU if no order is specified now matches jax.devices()
for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
Breaking Changes
jax.numpy.gradient()
now behaves like most other functions in jax.numpy
, and forbids passing lists or tuples in place of arrays (#12958)
Functions in jax.numpy.linalg
and jax.numpy.fft
now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of #7737.
Deprecations
jax.sharding.MeshPspecSharding
has been renamed to jax.sharding.NamedSharding
. jax.sharding.MeshPspecSharding
name will be removed in 3 months.
Changes
Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.
Changes
Update Colab TPU driver version for new jaxlib release.
Changes
Add JAX_PLATFORMS=tpu,cpu
as default setting in TPU initialization, so JAX will raise an error if TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS=''
to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu
to always use CPU regardless of whether a TPU is available.
Deprecations
Several test utilities deprecated in JAX v0.3.8 are now removed from jax.test_util
.
Changes
The persistent compilation cache will now warn instead of raising an exception on error (#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true
to revert this behavior.
Bug fixes:
Bug fixes
Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices
in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
Fixes required jaxlib version.
Changes
Ahead-of-time lowering and compilation functionality (tracked in #7733) is stable and public. See the overview and the API docs for jax.stages
.
Introduced jax.Array
, intended to be used for both isinstance
checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance
works for jax.numpy.ndarray
for jax-internal objects, as jax.numpy.ndarray
is now a simple alias of jax.Array
.
Breaking changes
jax._src
is no longer imported into the public jax
namespace. This may break users that were using JAX internals.
jax.soft_pmap
has been deleted. Please use pjit
or xmap
instead. jax.soft_pmap
is undocumented. If it were documented, a deprecation period would have been provided.
Bugs
Fix corner case issue in gradient of lax.pow
with an exponent of zero (#12041)
Breaking changes
jax.checkpoint()
, also known as jax.remat()
, no longer supports the concrete
option, following the previous version’s deprecation; see JEP 11830.
Changes
Added jax.pure_callback()
that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit
or jax.pmap
).
Deprecations:
The deprecated DeviceArray.tile()
method has been removed. Use jax.numpy.tile()
(#11944).
DeviceArray.to_py()
has been deprecated. Use np.asarray(x)
instead.
Breaking changes
Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
Changes
Added jax.debug
that includes utilities for runtime value debugging such as jax.debug.print()
and jax.debug.breakpoint()
.
Added new documentation for runtime value debugging
Deprecations
jax.mask()
jax.shapecheck()
APIs have been removed. See #11557.
jax.experimental.loops
has been removed. See #10278 for an alternative API.
jax.tree_util.tree_multimap()
has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map()
is a direct replacement.
Removed jax.experimental.stax
; it has long been a deprecated alias of jax.example_libraries.stax
.
Removed jax.experimental.optimizers
; it has long been a deprecated alias of jax.example_libraries.optimizers
.
jax.checkpoint()
, also known as jax.remat()
, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
Changes
JaxTestCase
and JaxTestLoader
have been removed from jax.test_util
. These classes have been deprecated since v0.3.1 (#11248).
Added jax.scipy.gaussian_kde
(#11237).
Binary operations between JAX arrays and built-in collections (dict
, list
, set
, tuple
) now raise a TypeError
in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy (#11234).
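A quick sketch of the new behavior: operating on a raw list raises, and the fix is an explicit conversion (the values here are arbitrary):

```python
import jax.numpy as jnp

x = jnp.arange(3)
try:
    x + [1, 2, 3]  # binary ops with built-in collections now raise
    raised = False
except TypeError:
    raised = True

# Convert explicitly instead:
y = x + jnp.asarray([1, 2, 3])
```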
Several jax.tree_util
routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the API compatibility policy:
jax.treedef_is_leaf()
is deprecated in favor of jax.tree_util.treedef_is_leaf()
jax.tree_flatten()
is deprecated in favor of jax.tree_util.tree_flatten()
jax.tree_leaves()
is deprecated in favor of jax.tree_util.tree_leaves()
jax.tree_structure()
is deprecated in favor of jax.tree_util.tree_structure()
jax.tree_transpose()
is deprecated in favor of jax.tree_util.tree_transpose()
jax.tree_unflatten()
is deprecated in favor of jax.tree_util.tree_unflatten()
The sym_pos
argument of jax.scipy.linalg.solve()
is deprecated in favor of assume_a='pos'
, following a similar deprecation in scipy.linalg.solve()
.
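A minimal sketch of the replacement, using an arbitrary symmetric positive definite matrix:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve

a = jnp.array([[2.0, 0.0],
               [0.0, 4.0]])        # symmetric positive definite
b = jnp.array([2.0, 8.0])
x = solve(a, b, assume_a="pos")    # replaces the deprecated sym_pos=True
```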
Breaking changes
jax.experimental.compilation_cache.initialize_cache()
no longer supports the max_cache_size_bytes argument.
JAX_PLATFORMS
now raises an exception when platform initialization fails.
Changes
Fixed compatibility problems with NumPy 1.23.
jax.numpy.linalg.slogdet()
now accepts an optional method
argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
jax.numpy.linalg.qr()
now supports mode="raw"
.
pickle
, copy.copy
, and copy.deepcopy
now have more complete support when used on jax arrays (#10659). In particular:
pickle
and deepcopy
previously returned np.ndarray
objects when used on a DeviceArray
; now DeviceArray
objects are returned. For deepcopy
, the copied array is on the same device as the original. For pickle
the deserialized array will be on the default device.
Within function transformations (i.e. traced code), deepcopy
and copy
previously were no-ops. Now they use the same mechanism as DeviceArray.copy()
.
Calling pickle
on a traced array now results in an explicit ConcretizationTypeError
.
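A short sketch of the improved copy/pickle behavior on an arbitrary array:

```python
import copy
import pickle
import jax
import jax.numpy as jnp

x = jnp.arange(4)
y = copy.deepcopy(x)               # now a jax.Array on the same device
z = pickle.loads(pickle.dumps(x))  # deserializes to a jax.Array on the default device
```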
The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
jax.numpy.ldexp()
no longer silently promotes all inputs to float64, instead it promotes to float32 for integer inputs of size int32 or smaller (#10921).
Add a create_perfetto_link
option to jax.profiler.start_trace()
. When used, the profiler will generate a link to the Perfetto UI to view the trace.
Changed the semantics of jax.profiler.start_server(...)
to store the keepalive globally, rather than requiring the user to keep a reference to it.
Added jax.random.ball()
.
Added jax.default_device()
.
Added a python -m jax.collect_profile
script to manually capture program traces as an alternative to the TensorBoard UI.
Added a jax.named_scope
context manager that adds profiler metadata to Python programs (similar to jax.named_call
).
In scatter-update operations (i.e. :attr:jax.numpy.ndarray.at
), unsafe implicit dtype casts are deprecated, and now result in a FutureWarning
. In a future release, this will become an error. An example of an unsafe implicit cast is jnp.zeros(4, dtype=int).at[0].set(1.5)
, in which 1.5
previously was silently truncated to 1
.
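A sketch of making such a cast explicit so no implicit truncation occurs (values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=jnp.int32)
# Implicitly scattering 1.5 into an int array now warns (and will
# eventually error); cast explicitly to make the truncation intentional:
y = x.at[0].set(jnp.int32(1.5))
```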
jax.experimental.compilation_cache.initialize_cache()
now supports gcs bucket path as input.
Added jax.scipy.stats.gennorm()
.
jax.numpy.roots()
is now better behaved when strip_zeros=False
when coefficients have leading zeros (#11215).
x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
The Python flatbuffers package is no longer a dependency of jaxlib.
Changes
Fixes #10717.
Changes
jax.lax.eigh()
now accepts an optional sort_eigenvalues
argument that allows users to opt out of eigenvalue sorting on TPU.
Deprecations
Non-array arguments to functions in jax.lax.linalg
are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use jax.numpy.linalg
instead.
jax.scipy.linalg.polar_unitary()
, which was a JAX extension to the scipy API, is deprecated. Use jax.scipy.linalg.polar()
instead.
Changes
TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
Changes
Added support for fully asynchronous checkpointing for GlobalDeviceArray.
Changes
jax.numpy.linalg.svd()
on TPUs uses a qdwh-svd solver.
jax.numpy.linalg.cond()
on TPUs now accepts complex input.
jax.numpy.linalg.pinv()
on TPUs now accepts complex input.
jax.numpy.linalg.matrix_rank()
on TPUs now accepts complex input.
jax.scipy.cluster.vq.vq()
has been added.
jax.experimental.maps.mesh
has been deleted. Please use jax.experimental.maps.Mesh
. Please see https://docs.jax.dev/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
jax.scipy.linalg.qr()
now returns a length-1 tuple rather than the raw array when mode='r'
, in order to match the behavior of scipy.linalg.qr
(#10452)
jax.numpy.take_along_axis()
now takes an optional mode
parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip"
.
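A minimal sketch contrasting the new default fill behavior with the old clamping behavior, on an arbitrary 1-D array:

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds

filled = jnp.take_along_axis(x, idx, axis=0)                # OOB -> NaN by default
clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")  # old clamping behavior
```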
jax.numpy.take()
now defaults to mode="fill"
, which returns invalid values (e.g., NaN) for out-of-bounds indices.
Scatter operations, such as x.at[...].set(...)
, now have "drop"
semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
jax.numpy.take_along_axis()
now raises a TypeError
if its indices are not of an integer type, matching the behavior of numpy.take_along_axis()
. Previously non-integer indices were silently cast to integers.
jax.numpy.ravel_multi_index()
now raises a TypeError
if its dims
argument is not of an integer type, matching the behavior of numpy.ravel_multi_index()
. Previously non-integer dims
was silently cast to integers.
jax.numpy.split()
now raises a TypeError
if its axis
argument is not of an integer type, matching the behavior of numpy.split()
. Previously non-integer axis
was silently cast to integers.
jax.numpy.indices()
now raises a TypeError
if its dimensions are not of an integer type, matching the behavior of numpy.indices()
. Previously non-integer dimensions were silently cast to integers.
jax.numpy.diag()
now raises a TypeError
if its k
argument is not of an integer type, matching the behavior of numpy.diag()
. Previously non-integer k
was silently cast to integers.
Added jax.random.orthogonal()
.
Deprecations
Many functions and objects available in jax.test_util
are now deprecated and will raise a warning on import. This includes cases_from_list
, check_close
, check_eq
, device_under_test
, format_shape_dtype_string
, rand_uniform
, skip_on_devices
, with_config
, xla_bridge
, and _default_tolerance
(#10389). These, along with previously-deprecated JaxTestCase
, JaxTestLoader
, and BufferDonationTestCase
, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. unittest
, absl.testing
, numpy.testing
, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices()
. Many of the deprecated utilities will still exist in jax._src.test_util
, but these are not public APIs and as such may be changed or removed without notice in future releases.
Changes:
Fixed a performance problem if the indices passed to jax.numpy.take_along_axis()
were broadcasted (#10281).
jax.scipy.special.expit()
and jax.scipy.special.logit()
now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
The DeviceArray.tile()
method is deprecated, because numpy arrays do not have a tile()
method. As a replacement for this, use jax.numpy.tile()
(#10266).
Changes:
Linux wheels are now built conforming to the manylinux2014
standard, instead of manylinux2010
.
Changes:
Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
Deprecations:
jax.experimental.loops
is being deprecated. See #10278 for an alternative API.
Changes:
added jax.random.loggamma()
& improved behavior of jax.random.beta()
and jax.random.dirichlet()
for small parameter values (#9906).
the private lax_numpy
submodule is no longer exposed in the jax.numpy
namespace (#10029).
added array creation routines jax.numpy.frombuffer()
, jax.numpy.fromfunction()
, and jax.numpy.fromstring()
(#10049).
DeviceArray.copy()
now returns a DeviceArray
rather than a np.ndarray
(#10069)
jax.experimental.sharded_jit
has been deprecated and will be removed soon.
Deprecations:
jax.nn.normalize()
is being deprecated. Use jax.nn.standardize()
instead (#9899).
jax.tree_util.tree_multimap()
is deprecated. Use jax.tree_util.tree_map()
instead (#5746).
jax.experimental.sharded_jit
is deprecated. Use pjit
instead.
Bug fixes
Changes:
The functions jax.ops.index_update
, jax.ops.index_add
, which were deprecated in 0.2.22, have been removed. Please use the .at
property on JAX arrays instead, e.g., x.at[idx].set(y)
.
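A short sketch of the replacements; note that `.at` operations are functional and leave the original array unchanged:

```python
import jax.numpy as jnp

x = jnp.zeros(4)
y = x.at[1].set(5.0)  # was: jax.ops.index_update(x, 1, 5.0)
z = x.at[1].add(2.0)  # was: jax.ops.index_add(x, 1, 2.0)
```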
Moved jax.experimental.ann.approx_*_k
into jax.lax
. These functions are optimized alternatives to jax.lax.top_k
.
jax.numpy.broadcast_arrays()
and jax.numpy.broadcast_to()
now require scalar or array-like inputs, and will fail if they are passed lists (part of #7737).
The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
pjit
now works on CPU (in addition to previous TPU and GPU support).
Changes
XlaComputation.as_hlo_text()
now supports printing large constants by passing boolean flag print_large_constants=True
.
Deprecations:
The .block_host_until_ready()
method on JAX arrays has been deprecated. Use .block_until_ready()
instead.
Changes:
jax.test_util.JaxTestCase
and jax.test_util.JaxTestLoader
are now deprecated. The suggested replacement is to use parametrized.TestCase
directly. For tests that rely on custom asserts such as JaxTestCase.assertAllClose()
, the suggested replacement is to use standard numpy testing utilities such as numpy.testing.assert_allclose()
, which work directly with JAX arrays (#9620).
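An illustrative sketch of the suggested replacement in a plain assertion, outside any test class:

```python
import numpy as np
import jax.numpy as jnp

actual = jnp.linspace(0.0, 1.0, 3)
expected = np.array([0.0, 0.5, 1.0])
# numpy.testing utilities accept jax.Array values directly.
np.testing.assert_allclose(actual, expected)
ok = True
```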
jax.test_util.JaxTestCase
now sets jax_numpy_rank_promotion='raise'
by default (#9562). To recover the previous behavior, use the new jax.test_util.with_config
decorator:
@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
Added jax.scipy.linalg.schur()
, jax.scipy.linalg.sqrtm()
, jax.scipy.signal.csd()
, jax.scipy.signal.stft()
, jax.scipy.signal.welch()
.
Changes
jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
Changes
Bazel 5.0.0 is now required to build jaxlib.
jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
New features
Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
Breaking changes
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
Bug fixes
Fixed a bug where apparently identical PyTreeDef objects constructed by different routes did not compare as equal (#9066).
The JAX jit cache now requires two static arguments to have identical types for a cache hit (#9311).
Breaking changes:
Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS environment variable, or the --jax_host_callback_ad_transforms flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs (#8678).
Sorting now matches the behavior of NumPy for 0.0 and NaN regardless of the bit representation. In particular, 0.0 and -0.0 are now treated as equivalent, where previously -0.0 was treated as less than 0.0. Additionally, all NaN representations are now treated as equivalent and sorted to the end of the array. Previously negative NaN values were sorted to the front of the array, and NaN values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns (#9178).
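The new semantics can be seen in a small example (illustrative; assumes a recent JAX install):

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, 0.0, -0.0])
s = jnp.sort(x)
# NaNs sort to the end regardless of bit pattern; -0.0 and 0.0 compare equal.
```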
jax.numpy.unique() now treats NaN values in the same way as np.unique in NumPy versions 1.21 and newer: at most one NaN value will appear in the uniquified output (#9184).
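A quick sketch of the new behavior (illustrative; assumes a recent JAX install):

```python
import jax.numpy as jnp

# Duplicated NaNs collapse to a single NaN in the sorted, uniquified output.
u = jnp.unique(jnp.array([1.0, jnp.nan, jnp.nan, 1.0]))
```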
Bug fixes:
host_callback now supports ad_checkpoint.checkpoint (#8907).
New features:
Added jax.block_until_ready (#8941).
Added a new debugging flag/environment variable JAX_DUMP_IR_TO=/path. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
Added jax.ensure_compile_time_eval to the public API (#7987).
jax2tf now supports a flag jax2tf_associative_scan_reductions to change the lowering for associative reductions, e.g., jnp.cumsum, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details (#9189).
New features:
Support for python 3.10.
Bug fixes:
Out-of-bounds indices to jax.ops.segment_sum will now be handled with FILL_OR_DROP semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions (#7839).
Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
Added experimental MLIR Python bindings for use by JAX.
New features:
(Experimental) jax.distributed.initialize exposes a multi-host GPU backend.
jax.random.permutation supports a new independent keyword argument (#8430).
Breaking changes
Moved jax.experimental.stax to jax.example_libraries.stax.
Moved jax.experimental.optimizers to jax.example_libraries.optimizers.
New features:
Added jax.lax.linalg.qdwh.
New features:
jax.random.choice and jax.random.permutation now support multidimensional arrays and an optional axis argument (#8158).
Breaking changes:
jax.numpy.take and jax.numpy.take_along_axis now require array-like inputs (see #7737).
Multiple cuDNN versions are now supported for jaxlib GPU cuda11 wheels:
cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
cuDNN 8.0.5 or newer.
Breaking changes:
The install commands for GPU jaxlib are as follows:
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
Breaking Changes
Static arguments to jax.pmap must now be hashable. Unhashable static arguments have long been disallowed on jax.jit, but they were still permitted on jax.pmap; jax.pmap compared unhashable static arguments using object identity.
This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of jax.pmap wants to compare static arguments by object identity, they can define __hash__ and __eq__ methods on their objects that do that, or wrap their objects in an object that has those operations with object-identity semantics. Another option is to use functools.partial to encapsulate the unhashable static arguments into the function object.
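The functools.partial option looks like this (a minimal sketch with a hypothetical config dict; shown with jax.jit, which has the same constraint):

```python
import functools
import jax
import jax.numpy as jnp

config = {"scale": 2.0}  # dicts are unhashable, so not valid static arguments

def apply(config, x):
    return config["scale"] * x

# Bind the unhashable object into the function instead of passing it as static.
f = jax.jit(functools.partial(apply, config))
```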
jax.util.partial was an accidental export that has now been removed. Use functools.partial from the Python standard library instead.
Deprecations
The functions jax.ops.index_update, jax.ops.index_add, etc. are deprecated and will be removed in a future JAX release. Please use the .at property on JAX arrays instead, e.g., x.at[idx].set(y). For now, these functions produce a DeprecationWarning.
New features:
An optimized C++ code path improving the dispatch time for pmap is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the --experimental_cpp_pmap flag (or the JAX_CPP_PMAP environment variable).
jax.numpy.unique now supports an optional fill_value argument (#8121).
Breaking changes:
Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
Bug fixes:
Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
Breaking Changes
jax.api has been removed. Functions that were available as jax.api.* were aliases for functions in jax.*; please use the functions in jax.* instead.
jax.partial and jax.lax.partial were accidental exports that have now been removed. Use functools.partial from the Python standard library instead.
Boolean scalar indices now raise a TypeError; previously this silently returned wrong results (#7925).
Many more jax.numpy functions now require array-like inputs, and will error if passed a list (#7747, #7802, #7907). See #7737 for a discussion of the rationale behind this change.
When inside a transformation such as jax.jit, jax.numpy.array always stages the array it produces into the traced computation. Previously jax.numpy.array would sometimes produce an on-device array, even under a jax.jit decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
jnp.ndarray is now a true base class for JAX arrays. In particular, this means that for a standard NumPy array x, isinstance(x, jnp.ndarray) will now return False (#7927).
New features:
Added jax.numpy.insert() implementation (#7936).
Breaking Changes
Breaking changes:
Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
Breaking changes:
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The jit decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as +.
This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., x + 2**40). The workaround is to cast the constant to an explicit type (e.g., np.float64(2**40)).
New features:
Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., jnp.mean (#7317).
Bug fixes:
Fixed some leaked trace errors from the previous release (#7613).
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
Breaking changes:
Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
The minimum jaxlib version is now 0.1.69.
The backend argument to jax.dlpack.from_dlpack() has been removed.
New features:
Added a polar decomposition (jax.scipy.linalg.polar()).
Bug fixes:
Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid axis value, or with an empty reduction dimension (#7196).
Fixed bugs in the TFRT CPU backend that resulted in incorrect results.
Bug fixes:
Default to the older “stream_executor” CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
New features:
New SciPy function jax.scipy.special.sph_harm().
Reverse-mode autodiff functions (jax.grad(), jax.value_and_grad(), jax.vjp(), and jax.linear_transpose()) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only jax.experimental.maps.xmap()) (#6950).
New features:
#7042: Turned on the TFRT CPU backend, with significant dispatch performance improvements on CPU.
jax2tf.convert() supports inequalities and min/max for booleans (#6956).
New SciPy function jax.scipy.special.lpmn_values().
Breaking changes:
Support for NumPy 1.16 has been dropped, per the deprecation policy.
Bug fixes:
Fixed a bug that prevented round-tripping from JAX to TF and back: jax2tf.call_tf(jax2tf.convert) (#6947).
Bug fixes:
Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
New features:
jax2tf.convert() now has support for pjit and sharded_jit.
A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
A new traceback filtering mode using __tracebackhide__ is now enabled by default in sufficiently recent versions of IPython.
jax2tf.convert() supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., jnp.reshape(-1) (#6827).
jax2tf.convert() generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
New SciPy function jax.scipy.special.lpmn().
Bug fixes:
jax2tf.convert() now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX (#6883).
jax2tf.convert() now scopes the enable_xla conversion parameter properly to apply only during the just-in-time conversion (#6720).
jax2tf.convert() now converts lax.dot_general using the XlaDot TensorFlow op, for better fidelity w.r.t. JAX numerical precision (#6717).
jax2tf.convert() now has support for inequality comparisons and min/max for complex numbers (#6892).
New features:
CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher. NVIDIA now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3. There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
Jaxlib now bundles libdevice.10.bc in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
Added automatic support for static keyword arguments to the jit() implementation.
Added support for pre-transformation exception traces.
Initial support for pruning unused arguments from jit()-transformed computations. Pruning is still a work in progress.
Improved the string representation of PyTreeDef objects.
Added support for XLA’s variadic ReduceWindow.
Bug fixes:
Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
Fixed a bug that meant JAX garbage collection was not triggered by jit()-transformed functions.
New features:
When combined with jaxlib 0.1.66, jax.jit() now supports static keyword arguments. A new static_argnames option has been added to specify keyword arguments as static.
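For example (an illustrative sketch with a hypothetical reduce_by function; assumes a recent JAX install):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="mode")
def reduce_by(x, mode):
    # `mode` is a compile-time constant, so ordinary Python branching works;
    # each distinct mode triggers a separate compilation.
    if mode == "sum":
        return jnp.sum(x)
    return jnp.max(x)
```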
jax.nonzero() has a new optional size argument that allows it to be used within jit (#6501).
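The size argument fixes the output shape, which is what makes the call traceable. A minimal sketch (illustrative; assumes a recent JAX install):

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_two_nonzero(x):
    # `size` makes the output shape static, so nonzero is usable under jit;
    # missing entries are padded with a fill value (0 by default).
    (idx,) = jnp.nonzero(x, size=2)
    return idx

idx = first_two_nonzero(jnp.array([0, 3, 0, 5, 7]))
```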
jax.numpy.unique() now supports the axis argument (#6532).
jax.experimental.host_callback.call() now supports pjit.pjit (#6569).
Added jax.scipy.linalg.eigh_tridiagonal(), which computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an UnfilteredStackTrace exception containing the original trace as the __cause__ of the filtered exception. Filtered stack traces now also work with Python 3.6.
If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a __cause__ of the exception a JaxStackTraceBeforeTransformation object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
Breaking changes:
The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
host_id –> process_index()
host_count –> process_count()
host_ids –> range(jax.process_count())
Similarly, the argument to local_devices() has been renamed from host_id to process_index.
Arguments to jax.jit() other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to jit.
Bug fixes:
New features
New profiling APIs: jax.profiler.start_trace(), jax.profiler.stop_trace(), and jax.profiler.trace().
jax.lax.reduce() is now differentiable.
Breaking changes:
The minimum jaxlib version is now 0.1.64.
Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
TraceContext –> TraceAnnotation()
StepTraceContext –> StepTraceAnnotation()
trace_function –> annotate_function()
Omnistaging can no longer be disabled. See omnistaging for more information.
Python integers larger than the maximum int64 value will now lead to an overflow in all cases, rather than being silently converted to uint64 in some cases (#6047).
Outside X64 mode, Python integers outside the range representable by int32 will now lead to an OverflowError rather than having their value silently truncated.
Bug fixes:
host_callback now supports empty arrays in arguments and results (#6262).
jax.random.randint() clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype (#5868).
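For reference, a small sketch of the sampling bounds (illustrative; assumes a recent JAX install):

```python
import jax

key = jax.random.PRNGKey(0)
# Samples are drawn from [minval, maxval); limits outside the dtype's
# representable range are clipped to it rather than wrapped.
samples = jax.random.randint(key, shape=(1000,), minval=0, maxval=10)
```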
New features:
Bug fixes:
#6136 generalized jax.flatten_util.ravel_pytree to handle integer dtypes.
#6129 fixed a bug with handling some constants like enum.IntEnums.
#6145 fixed batching issues with incomplete beta functions.
#6014 fixed H2D transfers during tracing.
#6165 avoids OverflowErrors when converting some large Python integers to floats.
Breaking changes:
The minimum jaxlib version is now 0.1.62.
New features:
jax.scipy.stats.chi2() is now available as a distribution with logpdf and pdf methods.
jax.scipy.stats.betabinom() is now available as a distribution with logpmf and pmf methods.
Added jax.experimental.jax2tf.call_tf() to call TensorFlow functions from JAX (#5627 and README).
Extended the batching rule for lax.pad to support batching of the padding values.
Bug fixes:
jax.numpy.take() properly handles negative indices (#5768).
Breaking changes:
JAX’s promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression jnp.bfloat16(1) + 0.1 * jnp.arange(10) previously returned a float64 array, and now returns a bfloat16 array. JAX’s type promotion behavior is described at Type promotion semantics.
jax.numpy.linspace() now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
jax.numpy.i0() no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
Several jax.numpy functions no longer accept tuples or lists in place of array arguments: jax.numpy.pad(), jax.numpy.ravel(), jax.numpy.repeat(), jax.numpy.reshape(). In general, jax.numpy functions should be used with scalars or array arguments.
New features:
jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn't support AVX, you can build a jaxlib from source using the --target_cpu_features flag to build.py. --target_cpu_features also replaces --enable_march_native.
Bug fixes:
Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
bool, int8, and uint8 are now considered safe to cast to the bfloat16 NumPy extension type.
New features:
Extend the jax.experimental.loops module with support for pytrees. Improved error checking and error messages.
Add jax.experimental.enable_x64() and jax.experimental.disable_x64(). These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
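For illustration, a hedged sketch of the context-manager usage (the exact import location may vary across JAX versions):

```python
import jax.numpy as jnp
from jax.experimental import enable_x64

with enable_x64():
    # Arrays created inside the context use 64-bit defaults.
    x = jnp.array(1.0)
```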
Breaking changes:
jax.ops.segment_sum() now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
New features:
Add jax.closure_convert() for use with higher-order custom derivative functions (#5244).
Add jax.experimental.host_callback.call() to call a custom Python function on the host and return a result to the device computation (#5243).
Bug fixes:
jax.numpy.arccosh now returns the same branch as numpy.arccosh for complex inputs (#5156).
host_callback.id_tap now works for jax.pmap also. There is an optional parameter for id_tap and id_print to request that the device from which the value is tapped be passed as a keyword argument to the tap function (#5182).
Breaking changes:
jax.numpy.pad now takes keyword arguments. The positional argument constant_values has been removed. In addition, passing unsupported keyword arguments raises an error.
Changes for jax.experimental.host_callback.id_tap() (#5243):
Removed support for kwargs for jax.experimental.host_callback.id_tap(). (This support has been deprecated for a few months.)
Changed the printing of tuples for jax.experimental.host_callback.id_print() to use '(' instead of '['.
Changed jax.experimental.host_callback.id_print() in the presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
host_callback.outfeed_receiver has been removed (it is not necessary, and was deprecated a few months ago).
New features:
New flag for debugging inf, analogous to that for NaN (#5224).
New features:
Add jax.device_put_replicated
Add multi-host support to jax.experimental.sharded_jit
Add support for differentiating eigenvalues computed by jax.numpy.linalg.eig
Add support for building on Windows platforms
Add support for general in_axes and out_axes in jax.pmap
Add complex support for jax.numpy.linalg.slogdet
Bug fixes:
Fix higher-than-second-order derivatives of jax.numpy.sinc at zero.
Fix some hard-to-hit bugs around symbolic zeros in transpose rules
Breaking changes:
jax.experimental.optix has been deleted, in favor of the standalone optax Python package.
Indexing of JAX arrays with non-tuple sequences now raises a TypeError. This type of indexing has been deprecated in NumPy since v1.16, and in JAX since v0.2.4. See #4564.
New Features:
Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
Breaking change cleanup
Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
Improve consistency of type promotion behavior (#4744):
Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, jnp.float32(1) + 1j now returns complex64, where previously it returned complex128.
Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: jnp.result_type(jnp.uint64, jnp.int64, jnp.float16) and jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) both return float16, where previously the first returned float64 and the second returned float16.
The contents of the (undocumented) jax.lax_linalg linear algebra module are now exposed publicly as jax.lax.linalg.
jax.random.PRNGKey now produces the same results in and out of JIT compilation (#4877). This required changing the result for a given seed in a few particular cases:
With jax_enable_x64=False, negative seeds passed as Python integers now return a different result outside JIT mode. For example, jax.random.PRNGKey(-1) previously returned [4294967295, 4294967295], and now returns [0, 4294967295]. This matches the behavior in JIT.
Seeds outside the range representable by int64 outside JIT now result in an OverflowError rather than a TypeError. This matches the behavior in JIT.
To recover the keys returned previously for negative integers with jax_enable_x64=False outside JIT, you can use:
key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray now raises RuntimeError instead of ValueError when trying to access its value after it has been deleted.
Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., np.cint) instead of standard types (e.g., np.int32) (#4903).
Fixed a crash when constant-folding certain int16 operations. (#4971)
Added an is_leaf predicate to pytree.flatten().
Fixed manylinux2010 compliance issues in GPU wheels.
Switched the CPU FFT implementation from Eigen to PocketFFT.
Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
Add support for retaining ownership when passing arrays to DLPack (#4636).
Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
Fixed a bug in profiler where tools are missing (#4427).
Dropped support for CUDA 10.0.
Improvements:
Ensure that check_jaxpr does not perform FLOPS. See #4650.
Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
Improvements:
Add support for remat to jax.experimental.host_callback. See #4608.
Deprecations
Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy. In a future release, this will result in a TypeError. See #4564.
The reason for another release so soon is that we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
Improvements:
As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the jax.experimental.host_callback.id_print() / jax.experimental.host_callback.id_tap() is not used in the computation.
Improvements:
Omnistaging on by default. See #3370 and omnistaging
Breaking changes:
New simplified interface for jax.experimental.host_callback.id_tap()
(#4101)
Update XLA:
Fix bug in DLPackManagedTensorToBuffer (#4196)
Bug Fixes:
make jnp.abs() work for unsigned inputs (#3914)
Improvements:
“Omnistaging” behavior added behind a flag, disabled by default (#3370)
New Features:
BFGS (#3101)
TPU support for half-precision arithmetic (#3878)
Bug Fixes:
Prevent some accidental dtype warnings (#3874)
Fix a multi-threading bug in custom derivatives (#3845, #3869)
Improvements:
Faster searchsorted implementation (#3873)
Better test coverage for jax.numpy sorting algorithms (#3836)
Update XLA.
The minimum jaxlib version is now 0.1.51.
New Features:
jax.image.resize. (#3703)
hfft and ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan
and the scan
primitive support an unroll
parameter for loop unrolling when lowering to XLA (#3738).
Bug Fixes:
Fix reduction repeated axis error (#3618)
Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
make psum transpose handle zero cotangents (#3653)
Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
Support differentiation through jax.lax.all_to_all (#3733)
address nan issue in jax.scipy.special.zeta (#3777)
Improvements:
Many improvements to jax2tf
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
Enable XLA SPMD partitioning by default. (#3151)
Add support for 0d transpose convolution (#3643)
Make LU gradient work for low-rank matrices (#3610)
support multiple_results and custom JVPs in jet (#3657)
Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
Implement complex convolutions on CPU and GPU. (#3735)
Make jnp.take work for empty slices of empty arrays. (#3751)
Relax dimension ordering rules for dot_general. (#3778)
Enable buffer donation for GPU. (#3800)
Add support for base dilation and window dilation to reduce window op… (#3803)
Update XLA.
Add new runtime support for host_callback.
Bug fixes:
Fix an odeint bug introduced in the previous release, see #3587.
The minimum jaxlib version is now 0.1.48.
Bug fixes:
Allow jax.experimental.ode.odeint dynamics functions to close over values with respect to which we're differentiating (#3562).
Add support for CUDA 11.0.
Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
Update XLA.
Bug fixes:
Fix build issue that could result in slow compiles (tensorflow/tensorflow)
New features:
Adds support for fast traceback collection.
Adds preliminary support for on-device heap profiling.
Implements np.nextafter for bfloat16 types.
Complex128 support for FFTs on CPU and GPU.
Bug fixes:
Improved float64 tanh accuracy on GPU.
float64 scatters on GPU are much faster.
Complex matrix multiplication on CPU should be much faster.
Stable sorts on CPU should actually be stable now.
Concurrency bug fix in CPU backend.
New features:
lax.switch introduces indexed conditionals with multiple branches, together with a generalization of the cond primitive (#3318).
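A minimal sketch of the indexed conditional (illustrative; assumes a recent JAX install):

```python
import jax.numpy as jnp
from jax import lax

branches = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]
# The integer index selects a branch; out-of-range indices are clamped.
y = lax.switch(1, branches, jnp.array(3.0))
```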
New features:
lax.cond() supports a single-operand form, taken as the argument to both branches (#2993).
Notable changes:
The format of the transforms keyword for the jax.experimental.host_callback.id_tap() primitive has changed (#3132).
New features:
Support for reduction over subsets of a pmapped axis using axis_index_groups (#2382).
Experimental support for printing and calling host-side Python function from compiled code. See id_print and id_tap (#3006).
Notable changes:
The visibility of names exported from jax.numpy has been tightened. This may break code that was making use of names that were previously exported accidentally.
Fixes crash for outfeed.
New features:
Support for in_axes=None on pmap() (#2896).
Fixes crash for linear algebra functions on Mac OS X (#432).
Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
New features:
Differentiation of determinants of singular matrices #2809.
Bug fixes:
Fixes segfault: #2755
Plumb is_stable option on Sort HLO through to Python.
New features:
Add syntactic sugar for functional indexed updates (#2684).
Add jax.numpy.unique() (#2760).
Add jax.numpy.rint() (#2724).
Add more primitive rules for jax.experimental.jet.
Bug fixes:
Better errors:
Improves error message for reverse-mode differentiation of lax.while_loop() (#2129).
Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
Bugfix for batch_group_count convolutions.
Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
Added jax.custom_jvp and jax.custom_vjp from #2026, see the tutorial notebook. Deprecated jax.custom_transforms and removed it from the docs (though it still works).
Add scipy.sparse.linalg.cg (#2566).
Changed how Tracers are printed to show more useful information for debugging (#2591).
Made jax.numpy.isclose handle nan and inf correctly (#2501).
Added several new rules for jax.experimental.jet (#2537).
Fixed jax.experimental.stax.BatchNorm when scale/center isn't provided.
Fix some missing cases of broadcasting in jax.numpy.einsum (#2512).
Implement jax.numpy.cumsum and jax.numpy.cumprod in terms of a parallel prefix scan (#2596) and make reduce_prod differentiable to arbitrary order (#2597).
Add batch_group_count to conv_general_dilated (#2635).
Add docstring for test_util.check_grads (#2656).
Add callback_transform (#2665).
Implement rollaxis, convolve/correlate 1d & 2d, copysign, trunc, roots, and quantile/percentile interpolation options.
Fixed a performance regression for Resnet-50 on GPU.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Removed the internal function lax._safe_mul, which implemented the convention 0. * nan == 0. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
Added an all_gather parallel convenience function.
More type annotations in core code.
jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
New features:
jax.pmap() has a static_broadcasted_argnums argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcast to all devices. It works analogously to static_argnums in jax.jit().
Improved error messages for when tracers are mistakenly saved in global state.
Added jax.nn.one_hot()
utility function.
Added jax.experimental.jet
for exponentially faster higher-order automatic differentiation.
Added more correctness checking to arguments of jax.lax.broadcast_in_dim().
The minimum jaxlib version is now 0.1.41.
Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
Includes prototype support for multihost GPU computations that communicate via NCCL.
Improves performance of NCCL collectives on GPU.
Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
Supports device assignments known at XLA compilation time.
Breaking changes
The minimum jaxlib version is now 0.1.38.
Simplified Jaxpr by removing the Jaxpr.freevars and Jaxpr.bound_subjaxprs fields. The call primitives (xla_call, xla_pmap, sharded_call, and remat_call) get a new parameter call_jaxpr with a fully closed (no constvars) jaxpr. Also added a new call_primitive field to primitives.
New features:
Reverse-mode automatic differentiation (e.g. grad) of lax.cond, making it now differentiable in both modes (#2091).
JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
JAX GPU DeviceArrays now support __cuda_array_interface__, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
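A minimal sketch of the buffer-protocol interop: NumPy can view a JAX CPU array's buffer directly, so converting between the two libraries needs no copy:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.)
# np.asarray reads the device buffer through the buffer protocol.
y = np.asarray(x)
print(y)  # [0. 1. 2. 3.]
```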
Added the JAX_SKIP_SLOW_TESTS environment variable to skip tests known to be slow.
Updates XLA.
CUDA 9.0 is no longer supported.
CUDA 10.2 wheels are now built by default.
Breaking changes
JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
New features
With the Python 3 upgrade, JAX no longer depends on fastcache, which should help with installation.