JAX includes a NumPy-compatible `jax.numpy` module with a bunch of nice features (automatic differentiation, JIT compilation, vectorized mapping, a GPU runtime, JS export). They've taken great pains to make sure switching is usually as simple as swapping `import numpy as np` for `import jax.numpy as np`. The same goes, less extensively, for the `jax.scipy` module.
I'd like to do some optimization for which it would be really convenient to automatically differentiate some of the great stuff you've implemented and export it to JS. It should be as simple as changing each `import numpy as np` around the library to:
```python
if HOWEVER_WE_SET_THE_CONFIG:
    import jax.numpy as np
else:
    import numpy as np
```
Changing the type signatures probably leaves us more design choices, but the approach is basically the same.
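For the type signatures, one option is a shared alias that static type checkers read as "NumPy or JAX array", without making `jax` a runtime dependency. A minimal sketch, assuming a hypothetical alias name `Array` and a hypothetical example function `normalize`. `jax.Array` is the array type exposed by recent JAX releases. The string forward reference is only resolved by type checkers, under the `TYPE_CHECKING` guard.

```python
from typing import TYPE_CHECKING, Union

import numpy

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    import jax

# Hypothetical alias: the string "jax.Array" stays an unresolved forward
# reference at runtime, so importing this module does not require jax.
Array = Union[numpy.ndarray, "jax.Array"]


def normalize(x: Array) -> Array:
    """Scale a non-negative vector so its entries sum to 1 (hypothetical example)."""
    return x / x.sum()
```

Functions annotated this way accept NumPy arrays today and JAX arrays once the backend is switched, so only the annotations change, not the function bodies.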
I'd be happy to implement it, but don't want to make a PR that you don't want.
I expect that the added maintenance burden would be pretty minimal.