Different outputs on subsequent calls of irfft2 with the same input on GPU with double precision · Issue #9946 · jax-ml/jax
Source: https://github.com/jax-ml/jax/issues/9946

Hello,

I found a strange issue when calling irfft2 on a complex array on a GPU (e.g. a V100, but not limited to it) with double precision enabled. Calling irfft2(x) appears to change the value of x as seen by subsequent irfft2() or jnp.real() calls, while evaluating x on its own still shows the original values (as expected). See the session below. I'm running JAX 0.3.2 with CUDA/11.1.1-GCC-10.2.0 and cuDNN/8.0.5.39-CUDA-11.1.1.

Thanks!
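A minimal standalone sketch of the same check, assuming a current JAX install (it calls jax.config.update directly rather than importing config from jax.config as the session does). It calls jnp.fft.irfft2 several times on one device array, compares each result with NumPy's host-side numpy.fft.irfft2, and then inspects the input through jnp.real(x), as the session does. The helper name check_irfft2_repeatable is hypothetical and exists only for this sketch.

```python
import jax

# Enable double precision before creating any arrays, as in the report.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def check_irfft2_repeatable(x_np, n_calls=3):
    """Hypothetical helper: call jnp.fft.irfft2 n_calls times on the same
    device array and compare each result with NumPy's host-side reference."""
    expected = np.fft.irfft2(x_np)
    x = jnp.asarray(x_np)  # complex128 device array, reused for every call
    results = [np.asarray(jnp.fft.irfft2(x)) for _ in range(n_calls)]
    matches = [bool(np.allclose(r, expected)) for r in results]
    # Inspect the input through jnp.real, as in the report, and compare it
    # with the original host values.
    input_intact = bool(np.array_equal(np.asarray(jnp.real(x)), x_np.real))
    return expected, results, matches, input_intact


x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64) * (1 + 1j)
expected, results, matches, input_intact = check_irfft2_repeatable(x_np)
print("reference irfft2:\n", expected)
print("per-call match:", matches)
print("input unchanged:", input_intact)
```

On a backend without the problem, every entry of matches should be True and the input should be unchanged. Based on the session in this report, the affected GPU setup would be expected to fail the comparison from the second call onward.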

Python 3.10.2 | packaged by conda-forge | (main, Jan 28 2022, 16:48:22) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> from jax.config import config
>>> config.update("jax_enable_x64", True)
>>> x = jnp.array([[1.0,2.0],[3.0,4.0]],dtype=jnp.float64)*(1+1j)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>> jnp.fft.irfft2(x)
DeviceArray([[ 2.5, -0.5],
             [-1. ,  0. ]], dtype=float64)
>>> jnp.real(x)
DeviceArray([[ 4.,  6.],
             [-2., -2.]], dtype=float64)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>>
>>>
>>> jnp.fft.irfft2(x)
DeviceArray([[ 1.5, -0.5],
             [ 3.5, -0.5]], dtype=float64)
>>>
>>> jnp.real(x)
DeviceArray([[2., 4.],
             [6., 8.]], dtype=float64)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>>
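For reference, the correct result for this x can be worked out by hand, assuming irfft2 follows NumPy's convention: an inverse complex DFT of length 2 along axis 0, then a length-2 inverse real DFT along the last axis, both with 1/n normalization, discarding the imaginary parts of the DC and Nyquist bins. The result agrees with the first irfft2 call, so it is the second call that is wrong. The unnormalized intermediate real parts, [[4, 6], [-2, -2]], are exactly what jnp.real(x) returned after the first call, which suggests (though the session alone does not prove it) that the GPU transform overwrote the input buffer with an intermediate result.

```latex
% Step 1: inverse complex DFT of length 2 along axis 0
Y = \tfrac{1}{2}\begin{pmatrix}1 & 1\\ 1 & -1\end{pmatrix} X
  = \tfrac{1}{2}\begin{pmatrix}4+4i & 6+6i\\ -2-2i & -2-2i\end{pmatrix}
  = \begin{pmatrix}2+2i & 3+3i\\ -1-i & -1-i\end{pmatrix}

% Step 2: length-2 inverse real DFT along axis 1
% (imaginary parts of the DC and Nyquist bins are discarded)
y_{m,k} = \tfrac{1}{2}\left(\operatorname{Re} Y_{m,0} + (-1)^k \operatorname{Re} Y_{m,1}\right)
\quad\Rightarrow\quad
y = \begin{pmatrix}2.5 & -0.5\\ -1 & 0\end{pmatrix}
```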
