Wrong determinant results for large batch · Issue #24843 · jax-ml/jax · GitHub

Description

The bug can be reproduced using the following code:

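```python
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

# 1,500,000 random 20x20 matrices: 600,000,000 float32 elements in total
a = jr.normal(jr.key(0), (1500000, 20, 20))
d = jnp.linalg.det(a)  # batched determinant, shape (1500000,)
plt.plot(d)
plt.show()
```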

The determinant values at the end of the batch are clearly incorrect.
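To pin down where the results start to go wrong, the batched output can be compared entry by entry against a trusted reference, for example NumPy's `np.linalg.det` or the same JAX call run on the CPU. The helper below is a minimal sketch; `first_mismatch` is a hypothetical name, and the default tolerance `rtol=1e-3` is an assumption chosen loosely enough to absorb ordinary float32 rounding differences between backends.

```python
import numpy as np

def first_mismatch(values, reference, rtol=1e-3, atol=0.0):
    """Return the index of the first entry of `values` that disagrees with
    `reference` beyond the given tolerances, or None if all entries agree.
    NaN entries always count as mismatches."""
    values = np.asarray(values, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    bad = ~np.isclose(values, reference, rtol=rtol, atol=atol)
    idx = np.flatnonzero(bad)
    return int(idx[0]) if idx.size else None

# Hypothetical usage with `a` and `d` from the reproduction script. Computing
# the NumPy reference copies the whole batch to host memory (about 2.4 GB).
#   ref = np.linalg.det(np.asarray(a))
#   print(first_mismatch(d, ref))
```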

I have tested different batch and matrix sizes. The bug occurs only when the total number of elements in the input (1500000*20*20 = 600,000,000 in the example) exceeds 2**29 = 536,870,912. As the figure also shows, the determinant values are wrong for batch indices greater than 2**29 / (20*20) ≈ 1,342,177.
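The numbers in the paragraph above can be checked with a few lines of plain Python. The sketch assumes the default float32 dtype (4 bytes per element). Under that assumption, 2**29 elements correspond to exactly 2**31 bytes, one past the largest signed 32-bit integer. That coincidence hints at a 32-bit offset overflow somewhere in the GPU code path, although the report does not establish this.

```python
# Arithmetic behind the reported threshold for the 1500000 x 20 x 20 example.
BATCH, N = 1_500_000, 20
LIMIT = 2**29                        # element count above which results go wrong
INT32_MAX = 2**31 - 1                # largest signed 32-bit integer

total_elements = BATCH * N * N       # elements in the whole input array
first_affected = LIMIT // (N * N)    # batch entry that contains element 2**29
bytes_at_limit = LIMIT * 4           # byte offset of element 2**29 (float32)

print(total_elements, total_elements > LIMIT)
print(first_affected)
print(bytes_at_limit, bytes_at_limit > INT32_MAX)
```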

I also tested on different devices: the bug occurs on several types of GPU, but not on the CPU.
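If, as the report suggests, the wrong results depend only on how many elements a single call processes, one possible workaround is to split the batch so that each call stays at or below 2**29 elements. The helper below is a hypothetical sketch, not part of JAX: `chunked_det` and its `max_elements` parameter are names introduced here, and the assumption that chunking avoids the bug has not been verified against the report.

```python
import numpy as np

def chunked_det(a, det_fn, max_elements=2**29):
    """Apply `det_fn` to `a` (shape (batch, n, m)) in slices along the batch
    axis so that no single call sees more than `max_elements` elements
    (unless one matrix alone is larger, in which case one matrix per call).
    Results are gathered into a host-side NumPy array."""
    batch, n, m = a.shape
    if batch == 0:
        return np.empty((0,))
    per_chunk = max(1, max_elements // (n * m))  # matrices per call
    parts = [np.asarray(det_fn(a[i:i + per_chunk]))
             for i in range(0, batch, per_chunk)]
    return np.concatenate(parts)
```

For the example above this would be `d = chunked_det(a, jnp.linalg.det)`. Passing `det_fn` as a parameter keeps the helper backend-agnostic, so it can also be checked against `np.linalg.det` on small inputs.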

System info (python version, jaxlib version, accelerator, etc.)
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.0.2
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
device info: NVIDIA A100 80GB PCIe-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='alcc145', release='5.15.0-94-generic', version='#104-Ubuntu SMP Tue Jan 9 15:25:40 UTC 2024', machine='x86_64')
