Hi,
I installed JAX on a TPU v3-8 VM with:

    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

However, even a trivial JAX computation fails with the following error:

    (base) gerardoduran@t1v-n-7177f451-w-0:~$ python
    Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> import jax
    >>> import jax.numpy as jnp
    >>> jnp.sqrt(2)
    tcmalloc: large alloc 377396076847104 bytes == (nil) @  0x7fefdec9d680 0x7fefdecbe824 0x7fedeee9d2da 0x7fedeee4ebae 0x7fedea76487a 0x7fedea763a05 0x7fedea765c62 0x7fede9a4a82e 0x7fee98171b66 0x7fee971e4541 0x7fee971d0912 0x7fee971ca61d 0x7fee95b3048c 0x7fee95b3f176 0x7fee93f64ded 0x7fee93d40ac8 0x7fee93d412d3 0x7fee93d1c916 0x55ced19903cc 0x55ced1989738 0x55ced199df80 0x55ced1981107 0x55ced199086f 0x55ced198299f 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced199e7f8 0x55ced198299f
    Unhandled exception:
        @ 0x7fedeedb9b62  (unknown)
        @ 0x7fedeeeba4e6  (unknown)
        @ 0x7fedeeeba03b  (unknown)
        @ 0x7fedeeeb9fb4  (unknown)
        @ 0x7fedeee9d32b  (unknown)
        @ 0x7fedeee4ebae  (unknown)
        @ 0x7fedea76487a  (unknown)
        @ 0x7fedea763a05  (unknown)
        @ 0x7fedea765c62  (unknown)
        @ 0x7fede9a4a82e  TpuCompiler_RunHloPasses
        @ 0x7fee98171b66  xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
        @ 0x7fee971e4541  xla::Service::BuildExecutable()
        @ 0x7fee971d0912  xla::LocalService::CompileExecutables()
        @ 0x7fee971ca61d  xla::LocalClient::Compile()
        @ 0x7fee95b3048c  xla::PjRtStreamExecutorClient::Compile()
        @ 0x7fee95b3f176  xla::PjRtStreamExecutorClient::Compile()
        @ 0x7fee93f64ded  xla::PyClient::CompileMlir()
        @ 0x7fee93d40ac8  pybind11::detail::argument_loader<>::call_impl<>()
        @ 0x7fee93d412d3  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
        @ 0x7fee93d1c916  pybind11::cpp_function::dispatcher()
        @ 0x55ced19903cc  cfunction_call
    https://symbolize.stripped_domain/r/?trace=7fedeedb9b62,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00
    libc++abi: terminating due to uncaught exception of type std::bad_alloc: std::bad_alloc
    https://symbolize.stripped_domain/r/?trace=7fefde94600b,7fefdec6f41f,7fedeeea17c8,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00
    *** SIGABRT received by PID 12344 (TID 12344) on cpu 47 from PID 12344; ***
    E0928 10:21:25.866065   12344 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
    E0928 10:21:25.866084   12344 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
    E0928 10:21:25.866093   12344 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
    E0928 10:21:25.866101   12344 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
    E0928 10:21:25.866109   12344 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
    E0928 10:21:25.866121   12344 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
    E0928 10:21:25.866130   12344 coredump_hook.cc:580] RAW: Discarding core.
    E0928 10:21:26.115471   12344 process_state.cc:774] RAW: Raising signal 6 with default behavior
    Aborted (core dumped)

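For scale, the failed tcmalloc requests are far larger than any single host's memory: the two sizes in the logs (377396076847104 and 378179171590144 bytes) work out to roughly 343 TiB and 344 TiB. That suggests the TPU compiler is computing a bogus allocation size, rather than jnp.sqrt(2) genuinely needing that much memory. The conversion as a small stdlib sketch (human_bytes is a hypothetical helper, not part of JAX):

```python
def human_bytes(n: int) -> str:
    """Format a byte count with binary (IEC) units, one decimal place."""
    value = float(n)
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024  # step up to the next binary unit
    return f"{value:.1f} PiB"


# Sizes copied from the two "tcmalloc: large alloc" lines in this report.
for requested in (377396076847104, 378179171590144):
    print(human_bytes(requested))
# 343.2 TiB
# 344.0 TiB
```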
I've tried both reinstalling JAX and creating a new TPU v3-8, but I get exactly the same error.
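Because the crash is a SIGABRT raised from native code (an uncaught std::bad_alloc inside the TPU compilation path), it kills the Python interpreter and cannot be caught with try/except. A minimal sketch for re-checking the failure after each reinstall or on a fresh VM, assuming a POSIX host: run the failing call in a child interpreter and inspect its exit status. REPRO, run_isolated and describe are hypothetical names used only for illustration.

```python
import signal
import subprocess
import sys

# The failing call from the first session above. Only this child script imports
# JAX; the parent process uses the standard library only.
REPRO = "import jax.numpy as jnp; print(jnp.sqrt(2))"


def run_isolated(code: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run `code` in a fresh Python interpreter so a native crash cannot kill this one."""
    return subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


def describe(returncode: int) -> str:
    """Turn a child return code into a short, human-readable status."""
    if returncode == 0:
        return "ok"
    if returncode < 0:
        # On POSIX, subprocess reports death-by-signal as a negative signal number.
        return f"killed by signal {signal.Signals(-returncode).name}"
    return f"exited with status {returncode}"


if __name__ == "__main__":
    result = run_isolated(REPRO)
    print(describe(result.returncode))
    # The tail of stderr holds the tcmalloc / std::bad_alloc messages, if any.
    print(result.stderr[-2000:])
```

On the affected VM this would be expected to report "killed by signal SIGABRT", matching the "Raising signal 6" line in the log.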
Running jax.devices() does show the TPUs on the VM:

    >>> jax.devices()
    [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
     TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
     TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
     TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
     TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
     TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
     TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
     TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

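The listing shows all eight TensorCores of a v3-8: four chips, at coordinates (0,0,0), (1,0,0), (0,1,0) and (1,1,0), with two cores each. So device discovery works, and the failure happens later, while compiling (the trace goes through TpuCompiler_RunHloPasses). A small sketch that groups devices by chip, assuming each device exposes .id, .coords and .core_on_chip as the TpuDevice repr suggests; cores_by_chip and FakeDevice are hypothetical names, and the stand-in list lets the sketch run without a TPU:

```python
from collections import defaultdict, namedtuple
from typing import Any, Dict, Iterable, List, Tuple


def cores_by_chip(devices: Iterable[Any]) -> Dict[Tuple[int, ...], List[int]]:
    """Map each chip's coordinates to its device ids, ordered by core_on_chip."""
    chips: Dict[Tuple[int, ...], List[Tuple[int, int]]] = defaultdict(list)
    for device in devices:
        # tuple() so the key is hashable even if coords comes back as a list.
        chips[tuple(device.coords)].append((device.core_on_chip, device.id))
    return {coords: [dev_id for _, dev_id in sorted(cores)] for coords, cores in chips.items()}


# Offline stand-in mirroring the jax.devices() output above (hypothetical type).
FakeDevice = namedtuple("FakeDevice", ["id", "coords", "core_on_chip"])
LISTED = [
    FakeDevice(0, (0, 0, 0), 0), FakeDevice(1, (0, 0, 0), 1),
    FakeDevice(2, (1, 0, 0), 0), FakeDevice(3, (1, 0, 0), 1),
    FakeDevice(4, (0, 1, 0), 0), FakeDevice(5, (0, 1, 0), 1),
    FakeDevice(6, (1, 1, 0), 0), FakeDevice(7, (1, 1, 0), 1),
]

print(cores_by_chip(LISTED))
# {(0, 0, 0): [0, 1], (1, 0, 0): [2, 3], (0, 1, 0): [4, 5], (1, 1, 0): [6, 7]}

# On the TPU VM itself, the same call would be:
#   import jax
#   print(cores_by_chip(jax.devices()))
```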
The problem seems to be specific to running JAX on the TPU. If I replicate @mattjj's code from this issue, the function jitted with backend='cpu' runs fine, but the one jitted with backend='tpu' crashes in the same way:

    (base) gerardoduran@t1v-n-7177f451-w-0:~$ python
    Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
    Type "help", "copyright", "credits" or "license" for more information.
    >>> from jax import jit
    >>> def f(x): return x**2
    ...
    >>> f_cpu = jit(f, backend='cpu')
    >>> f_tpu = jit(f, backend='tpu')
    >>>
    >>> f_cpu(2)
    DeviceArray(4, dtype=int32, weak_type=True)
    >>> f_tpu(2)
    tcmalloc: large alloc 378179171590144 bytes == (nil) @  0x7f78decaf680 0x7f78decd0824 0x7f76eeeaf2da 0x7f76eee60bae 0x7f76ea77687a 0x7f76ea775a05 0x7f76ea777c62 0x7f76e9a5c82e 0x7f7798183b66 0x7f77971f6541 0x7f77971e2912 0x7f77971dc61d 0x7f7795b4248c 0x7f7795b51176 0x7f7793f76ded 0x7f7793d52ac8 0x7f7793d532d3 0x7f7793d2e916 0x55fc669c83cc 0x55fc669c1738 0x55fc669d5f80 0x55fc669b9107 0x55fc669c886f 0x55fc669ba99f 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669d67f8 0x55fc669ba99f
    # ... more errors

What jax/jaxlib version are you using?
jax 0.3.19 / jaxlib 0.3.15
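Since the crashing frames (TpuCompiler_RunHloPasses) appear to come from the TPU runtime rather than from Python, the installed TPU runtime version may be worth reporting alongside jax and jaxlib. A minimal stdlib sketch that collects them; the distribution name libtpu-nightly is an assumption about what jax[tpu] installed at the time, and installed_versions is a hypothetical helper:

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names to report. "libtpu-nightly" is assumed to be the TPU runtime
# pulled in by jax[tpu]; adjust the tuple if your install differs.
PACKAGES = ("jax", "jaxlib", "libtpu-nightly")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```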
Which accelerator(s) are you using?
TPU
Additional system info
No response
NVIDIA GPU info
No response