When I run JAX with jaxlib==0.1.44 on GPU, I get a segmentation fault on my machine (Python 3.8, CUDA 10.2). The issue no longer occurs if I downgrade jaxlib to 0.1.43.
I installed jaxlib using the installation instructions in the README for both versions, and in both cases I set the XLA CUDA directory to the same location. As far as I can tell, jaxlib is the only thing that changes between the working setup and the crashing one.
I did some digging, and the segfault seems to come from jaxlib/xla_extension.so. Here is what gdb produces:
0x00007fffd6f991e8 in absl::lts_2020_02_25::Mutex::ReaderLock() () from /home/ziyadedher/research/.venv/lib/python3.8/site-packages/jaxlib/xla_extension.so
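To complement the gdb frame above with a Python-level traceback, the standard-library faulthandler module can be enabled before the crashing code runs. This is only a sketch: the helper name enable_crash_traceback is made up for illustration, and the computation that triggers the crash is not shown in this report, so it is left as a placeholder comment.

```python
import faulthandler
import sys


def enable_crash_traceback(file=sys.stderr):
    """Ask Python to dump the tracebacks of all threads on a fatal signal.

    `file` must be an open file object with a real file descriptor.
    Returns True once the handler is active.
    """
    faulthandler.enable(file=file, all_threads=True)
    return faulthandler.is_enabled()


if __name__ == "__main__":
    enable_crash_traceback()
    # Placeholder: import jax here and run the GPU computation that crashes.
    # On SIGSEGV, the Python stack of every thread is written to stderr.
```

Setting the environment variable PYTHONFAULTHANDLER=1 before starting Python has the same effect without changing any code.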
Reverting to jaxlib==0.1.43 fixes the issue.
>>> jax.__version__
'0.1.63'
>>> jaxlib.__version__
'0.1.44'
>>> tensorflow.__version__
'2.2.0-rc3'
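For anyone trying to reproduce this, the same version information can be collected without importing the packages themselves, which helps when an import or the first GPU call crashes the interpreter. A minimal sketch using importlib.metadata (available from Python 3.8); the helper name installed_versions is hypothetical, and the distribution names are assumed to match the import names, which may not hold for variants such as a GPU-specific TensorFlow wheel.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "tensorflow")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed under this name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```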
Some system information, truncated to show the important bits:
$ nvcc --version
Cuda compilation tools, release 10.2, V10.2.89
$ python --version
Python 3.8.2
$ modinfo nvidia
filename: /lib/modules/5.6.4-arch1-1/extramodules/nvidia.ko.xz
version: 440.82