Hello,
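I found a strange issue when calling jnp.fft.irfft2 on a complex array on a GPU (e.g. a V100, but not restricted to it) in double precision. Calling irfft2(x) appears to change the value of x as seen by subsequent irfft2() or jnp.real() calls, while displaying x on its own still shows the original values (as expected); see below. The first irfft2(x) call returns the correct result, [[2.5, -0.5], [-1., 0.]], but jnp.real(x) then returns [[4., 6.], [-2., -2.]] instead of [[1., 2.], [3., 4.]], and a second irfft2(x) call returns a different result. I'm running JAX 0.3.2 with CUDA/11.1.1-GCC-10.2.0 and cuDNN/8.0.5.39-CUDA-11.1.1.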
Thanks!
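For comparison, NumPy's CPU FFT gives reference values for the same input. The short check below assumes only that NumPy is installed; it shows what irfft2(x) and real(x) should return and that the input array is left unchanged.

```python
import numpy as np

# Same input as in the JAX session below, built on the host with NumPy.
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64) * (1 + 1j)
x_before = x.copy()

ref = np.fft.irfft2(x)  # approximately [[2.5, -0.5], [-1.0, 0.0]]
re = np.real(x)         # simply [[1.0, 2.0], [3.0, 4.0]]

# NumPy never modifies its input, so x is unchanged and repeated calls agree.
assert np.array_equal(x, x_before)
assert np.allclose(np.fft.irfft2(x), ref)
```

The reference irfft2 result matches the first JAX call in the transcript, so only the later calls disagree.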
Python 3.10.2 | packaged by conda-forge | (main, Jan 28 2022, 16:48:22) [GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> from jax.config import config
>>> config.update("jax_enable_x64", True)
>>> x = jnp.array([[1.0,2.0],[3.0,4.0]],dtype=jnp.float64)*(1+1j)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>> jnp.fft.irfft2(x)
DeviceArray([[ 2.5, -0.5],
             [-1. , 0. ]], dtype=float64)
>>> jnp.real(x)
DeviceArray([[ 4., 6.],
             [-2., -2.]], dtype=float64)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>>
>>>
>>> jnp.fft.irfft2(x)
DeviceArray([[ 1.5, -0.5],
             [ 3.5, -0.5]], dtype=float64)
>>>
>>> jnp.real(x)
DeviceArray([[2., 4.],
             [6., 8.]], dtype=float64)
>>> x
DeviceArray([[1.+1.j, 2.+2.j],
             [3.+3.j, 4.+4.j]], dtype=complex128)
>>>
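A minimal regression check along the same lines, offered as a sketch: it assumes JAX is installed, enables x64 as in the session above, and runs on whatever default device JAX picks (the GPU, if one is available). The helper name check_irfft2_does_not_mutate is hypothetical and not part of JAX. It calls jnp.fft.irfft2 several times on the same array and compares each result, and jnp.real(x), against NumPy's CPU reference.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Double precision, as in the report; must be set before arrays are created.
jax.config.update("jax_enable_x64", True)


def check_irfft2_does_not_mutate(n_calls=3):
    """Hypothetical helper: call jnp.fft.irfft2 repeatedly on one array and
    compare every result, plus jnp.real(x), with NumPy's CPU reference."""
    x_host = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64) * (1 + 1j)
    x = jnp.asarray(x_host)  # complex128 array on JAX's default device
    ref_fft = np.fft.irfft2(x_host)
    ref_real = np.real(x_host)
    for i in range(n_calls):
        out = np.asarray(jnp.fft.irfft2(x))
        # jnp.real(x) is computed on the device from x's buffer,
        # so it reflects any in-place change to that buffer.
        re = np.asarray(jnp.real(x))
        if not np.allclose(out, ref_fft):
            return False, f"irfft2 call {i}: got {out.tolist()}"
        if not np.allclose(re, ref_real):
            return False, f"after irfft2 call {i}: real(x) = {re.tolist()}"
    return True, "ok"


ok, msg = check_irfft2_does_not_mutate()
print(jax.devices(), ok, msg)
```

On a CPU backend this should report ok; on an affected GPU setup, the transcript above suggests it would fail on the check right after the first irfft2 call.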