Source: https://github.com/jax-ml/jax/issues/16440

jax `__cuda_array_interface__` not working · Issue #16440 · jax-ml/jax · GitHub

Description
import numpy as np
import jax.numpy as jnp
from numba import cuda
import cupy


@cuda.jit
def _sum_nomask(w, res, ind):
    start = cuda.grid(1)
    stride = cuda.gridsize(1)
    n = w.shape[0]
    tot = 0.0
    for i in range(start, n, stride):
        if w[i] > 0:
            tot += w[i]

    cuda.atomic.add(res, ind, tot)



if __name__ == "__main__":
    arr = cupy.arange(10000, dtype=np.float32)
    res = cupy.zeros(2, dtype=np.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("cupy:", res)

    arr = jnp.arange(10000, dtype=jnp.float32)
    res = jnp.zeros(2, dtype=jnp.float32)
    _sum_nomask[500, 32](arr, res, 0)
    print("jax:", res)
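
For reference, the kernel's grid-stride loop can be emulated on the CPU to see what `res[0]` should contain. The sketch below is a hypothetical NumPy helper (`emulate_grid_stride_sum` is not part of Numba or CuPy) that mimics the `[500, 32]` launch: 16,000 simulated threads, each summing the positive entries it visits, with the partial sums combined the way `cuda.atomic.add` combines them. It yields the exact sum 1 + 2 + ... + 9999 = 49,995,000. The CuPy run prints 49995040, which is plausibly float32 rounding from accumulating into a float32 `res`, because float32 cannot represent every integer above 2^24.

```python
import numpy as np


def emulate_grid_stride_sum(w, blocks, threads_per_block):
    """CPU emulation (hypothetical helper) of the _sum_nomask grid-stride loop.

    Simulated thread `start` visits indices start, start + stride, ... and
    sums the positive entries; the per-thread partials are then added
    together, standing in for cuda.atomic.add(res, ind, tot).
    """
    stride = blocks * threads_per_block  # cuda.gridsize(1) for a 1-D launch
    n = w.shape[0]
    partials = np.zeros(stride, dtype=np.float64)
    for start in range(stride):
        # Same indices the kernel's `for i in range(start, n, stride)` walks.
        chunk = w[start:n:stride]
        partials[start] = chunk[chunk > 0].sum()
    return partials.sum()


if __name__ == "__main__":
    arr = np.arange(10000, dtype=np.float64)
    print(emulate_grid_stride_sum(arr, 500, 32))
```

Since `n = 10000` is smaller than the 16,000 launched threads, each thread touches at most one element in this launch; the grid-stride form only matters once the array outgrows the grid.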

Running it, the CuPy call succeeds and the JAX call fails:

$ python test_jax_numba.py 
cupy: [49995040.        0.]
Traceback (most recent call last):
  File "/home/mrbecker/test_jax_numba.py", line 29, in <module>
    _sum_nomask[500, 32](arr, res, 0)
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 542, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 676, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in _compile_for_args
    argtypes = [self.typeof_pyval(a) for a in args]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 683, in <listcomp>
    argtypes = [self.typeof_pyval(a) for a in args]
                ^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/cuda/dispatcher.py", line 690, in typeof_pyval
    return typeof(val, Purpose.argument)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 33, in typeof
    ty = typeof_impl(val, c)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 46, in typeof_impl
    tp = _typeof_buffer(val, c)
         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrbecker/miniforge3/envs/jax-numba/lib/python3.11/site-packages/numba/core/typing/typeof.py", line 69, in _typeof_buffer
    m = memoryview(val)
        ^^^^^^^^^^^^^^^
BufferError: INVALID_ARGUMENT: Python buffer protocol is only defined for CPU buffers.

Further, accessing the `__cuda_array_interface__` attribute directly raises a different error. (This may be a red herring.)

In [1]: import jax

In [2]: a = jax.numpy.arange(10)

In [3]: a
Out[3]: Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [4]: a.__cuda_array_interface__
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: Unregistered type : absl::lts_20230125::StatusOr<dict>

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 a.__cuda_array_interface__

TypeError: Unable to convert function return value to a Python type! The signature was
	(arg0: xla::PyArray) -> absl::lts_20230125::StatusOr<dict>
What jax/jaxlib version are you using?

jax and jaxlib 0.4.12

Which accelerator(s) are you using?

GPU

Additional system info

Linux with an NVIDIA GPU

NVIDIA GPU info
$ nvidia-smi
Thu Jun 15 14:35:41 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro GV100        Off  | 00000000:2D:00.0 Off |                  Off |
| 49%   57C    P0    39W / 250W |      0MiB / 32508MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
