Showing content from https://github.com/jax-ml/jax/issues/12550 below:

Jax crashes on TPU in version 0.3.19 · Issue #12550 · jax-ml/jax · GitHub

Description

Hi,

I installed JAX on a TPU v3-8 VM with:

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

However, running even a trivial JAX computation crashes with the following error:

(base) gerardoduran@t1v-n-7177f451-w-0:~$ python
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> jnp.sqrt(2)
tcmalloc: large alloc 377396076847104 bytes == (nil) @  0x7fefdec9d680 0x7fefdecbe824 0x7fedeee9d2da 0x7fedeee4ebae 0x7fedea76487a 0x7fedea763a05 0x7fedea765c62 0x7fede9a4a82e 0x7fee98171b66 0x7fee971e4541 0x7fee971d0912 0x7fee971ca61d 0x7fee95b3048c 0x7fee95b3f176 0x7fee93f64ded 0x7fee93d40ac8 0x7fee93d412d3 0x7fee93d1c916 0x55ced19903cc 0x55ced1989738 0x55ced199df80 0x55ced1981107 0x55ced199086f 0x55ced198299f 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced199e7f8 0x55ced198299f
Unhandled exception:
    @     0x7fedeedb9b62  (unknown)
    @     0x7fedeeeba4e6  (unknown)
    @     0x7fedeeeba03b  (unknown)
    @     0x7fedeeeb9fb4  (unknown)
    @     0x7fedeee9d32b  (unknown)
    @     0x7fedeee4ebae  (unknown)
    @     0x7fedea76487a  (unknown)
    @     0x7fedea763a05  (unknown)
    @     0x7fedea765c62  (unknown)
    @     0x7fede9a4a82e  TpuCompiler_RunHloPasses
    @     0x7fee98171b66  xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
    @     0x7fee971e4541  xla::Service::BuildExecutable()
    @     0x7fee971d0912  xla::LocalService::CompileExecutables()
    @     0x7fee971ca61d  xla::LocalClient::Compile()
    @     0x7fee95b3048c  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fee95b3f176  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fee93f64ded  xla::PyClient::CompileMlir()
    @     0x7fee93d40ac8  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7fee93d412d3  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7fee93d1c916  pybind11::cpp_function::dispatcher()
    @     0x55ced19903cc  cfunction_call
https://symbolize.stripped_domain/r/?trace=7fedeedb9b62,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00 
libc++abi: terminating due to uncaught exception of type std::bad_alloc: std::bad_alloc
https://symbolize.stripped_domain/r/?trace=7fefde94600b,7fefdec6f41f,7fedeeea17c8,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00 
*** SIGABRT received by PID 12344 (TID 12344) on cpu 47 from PID 12344; ***
E0928 10:21:25.866065   12344 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
E0928 10:21:25.866084   12344 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
E0928 10:21:25.866093   12344 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0928 10:21:25.866101   12344 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
E0928 10:21:25.866109   12344 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0928 10:21:25.866121   12344 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0928 10:21:25.866130   12344 coredump_hook.cc:580] RAW: Discarding core.
E0928 10:21:26.115471   12344 process_state.cc:774] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
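For scale, the byte count in the "tcmalloc: large alloc" line can be converted to tebibytes with a quick calculation (a sketch added for illustration, not part of the original report). Both requests in this issue come out at roughly 343 to 344 TiB for computing a scalar sqrt, far beyond the memory of any single host. That suggests the size is a garbage value computed inside the TPU compiler rather than a genuine memory requirement, although this is an interpretation that the log itself does not confirm.

```python
# Convert the byte counts reported by tcmalloc ("large alloc N bytes") to TiB.
TIB = 2**40  # bytes per tebibyte

def bytes_to_tib(n_bytes: int) -> float:
    """Return n_bytes expressed in tebibytes (2**40 bytes)."""
    return n_bytes / TIB

# Sizes copied verbatim from the two crash logs in this issue.
requests = {
    "jnp.sqrt(2)": 377396076847104,
    "f_tpu(2)": 378179171590144,
}

for call, size in requests.items():
    print(f"{call}: {bytes_to_tib(size):.2f} TiB")
```

This prints about 343.24 TiB for the first crash and 343.95 TiB for the second; the value differs between runs, which is also consistent with it not being a real size.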

I've tried both reinstalling JAX and creating a new TPU v3-8, but I get exactly the same error.

Running jax.devices() does list the TPU cores on the VM:

>>> jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
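The list above matches a v3-8 slice being four chips with two TensorCores each: devices that share a coords value sit on the same chip and differ only in core_on_chip. The small helper below makes that grouping explicit. It is a sketch that assumes only that each device object exposes the id, coords and core_on_chip attributes shown in the repr above; the name cores_by_chip is hypothetical, and it is exercised here with stand-in objects rather than real TPU devices.

```python
from collections import defaultdict
from types import SimpleNamespace

def cores_by_chip(devices):
    """Map chip coordinates -> {core_on_chip: device id}.

    Assumes each device has `id`, `coords` and `core_on_chip` attributes,
    as in the TpuDevice repr printed by jax.devices().
    """
    chips = defaultdict(dict)
    for d in devices:
        chips[tuple(d.coords)][d.core_on_chip] = d.id
    return dict(chips)

# Stand-ins mirroring the eight devices printed above (id, coords, core);
# on the TPU VM one would pass jax.devices() instead.
layout = [
    (0, (0, 0, 0), 0), (1, (0, 0, 0), 1),
    (2, (1, 0, 0), 0), (3, (1, 0, 0), 1),
    (4, (0, 1, 0), 0), (5, (0, 1, 0), 1),
    (6, (1, 1, 0), 0), (7, (1, 1, 0), 1),
]
fake_devices = [SimpleNamespace(id=i, coords=c, core_on_chip=k) for i, c, k in layout]

chips = cores_by_chip(fake_devices)
print(len(chips), "chips:", chips)
```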

The problem seems to be specific to running JAX on the TPU backend. If I replicate @mattjj's code in this issue, the function jitted with backend='cpu' runs fine, but the one jitted with backend='tpu' crashes with the same allocation error:

(base) gerardoduran@t1v-n-7177f451-w-0:~$ python
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax import jit
>>> def f(x): return x**2
... 
>>> f_cpu = jit(f, backend='cpu')
>>> f_tpu = jit(f, backend='tpu')
>>> 
>>> f_cpu(2)
DeviceArray(4, dtype=int32, weak_type=True)
>>> f_tpu(2)
tcmalloc: large alloc 378179171590144 bytes == (nil) @  0x7f78decaf680 0x7f78decd0824 0x7f76eeeaf2da 0x7f76eee60bae 0x7f76ea77687a 0x7f76ea775a05 0x7f76ea777c62 0x7f76e9a5c82e 0x7f7798183b66 0x7f77971f6541 0x7f77971e2912 0x7f77971dc61d 0x7f7795b4248c 0x7f7795b51176 0x7f7793f76ded 0x7f7793d52ac8 0x7f7793d532d3 0x7f7793d2e916 0x55fc669c83cc 0x55fc669c1738 0x55fc669d5f80 0x55fc669b9107 0x55fc669c886f 0x55fc669ba99f 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669d67f8 0x55fc669ba99f
# ... more errors
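Because the failure is a SIGABRT raised inside the TPU compiler (the stack passes through TpuCompiler_RunHloPasses), it cannot be caught with try/except in the same interpreter, and it takes the whole REPL down. One way to compare backends without losing the session is to run each jitted call in a child process and inspect its exit status. The sketch below does that with the standard library only; the helper names run_isolated and jit_snippet are hypothetical, and the generated snippet uses the backend= argument of jax.jit exactly as in the transcript above (later JAX releases may deprecate that argument).

```python
import subprocess
import sys

def run_isolated(code: str, timeout: float = 600.0):
    """Run `code` in a fresh Python interpreter.

    Returns (returncode, stdout, stderr). On POSIX, a child killed by a
    signal gets a negative return code, e.g. -6 for SIGABRT.
    """
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True, text=True, timeout=timeout,
    )
    return proc.returncode, proc.stdout, proc.stderr

def jit_snippet(backend: str) -> str:
    """Build the same x**2 repro as above, pinned to one backend."""
    return (
        "from jax import jit\n"
        "def f(x): return x**2\n"
        f"print(jit(f, backend={backend!r})(2))\n"
    )

if __name__ == "__main__":
    # Each backend runs in its own process, so an abort on 'tpu'
    # does not prevent the 'cpu' result from being reported.
    for backend in ("cpu", "tpu"):
        rc, out, _ = run_isolated(jit_snippet(backend))
        status = "ok" if rc == 0 else f"failed (exit status {rc})"
        print(f"{backend}: {status} {out.strip()}")
```

On the VM described here, one would expect the cpu line to succeed and the tpu line to report a non-zero (on Linux, negative) exit status instead of aborting the parent interpreter.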
What jax/jaxlib version are you using?

jax 0.3.19 / jaxlib 0.3.15
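When reporting a crash like this one, it can help to list the version of every package in the TPU stack, not only jax and jaxlib. A minimal sketch using the standard library's importlib.metadata follows; the distribution name libtpu-nightly is an assumption about how the jax[tpu] extra installed the TPU runtime at the time, so adjust the tuple if pip list shows a different name.

```python
from importlib import metadata

# Distributions to report; "libtpu-nightly" is an assumed package name.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly")

def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```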

Which accelerator(s) are you using?

TPU

Additional system info

No response

NVIDIA GPU info

No response
