`jax.grad(jax.nn.softplus)(0.0)` evaluates to 0.0, which is wrong: the correct answer is 0.5.
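The expected value follows from the fact that the derivative of softplus is the logistic sigmoid $\sigma$, which equals one half at zero:

$$
\operatorname{softplus}(x) = \log\left(1 + e^{x}\right), \qquad
\frac{d}{dx}\operatorname{softplus}(x) = \frac{e^{x}}{1 + e^{x}} = \frac{1}{1 + e^{-x}} = \sigma(x), \qquad
\sigma(0) = \frac{1}{2}.
$$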
The problem is easy to see by plotting the gradient over [-2, 2]:

```python
import jax
import matplotlib.pyplot as plt

x = jax.numpy.linspace(-2, 2, num=101)
plt.plot(x, jax.vmap(jax.grad(jax.nn.softplus))(x))
plt.show()
```
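For a numeric check instead of a plot, the sketch below compares the autodiff gradient with the analytic derivative `jax.nn.sigmoid` at a few points, including exactly 0.0. It uses only public JAX APIs (`jax.grad`, `jax.vmap`, `jax.custom_jvp`, `jax.nn.sigmoid`, `jnp.logaddexp`). The function name `softplus_sigmoid_grad` is hypothetical, and pinning the derivative to sigmoid with a custom JVP is one possible workaround, not necessarily how JAX itself addresses this. The result printed for the built-in `jax.nn.softplus` depends on the installed JAX version.

```python
import jax
import jax.numpy as jnp


# Hypothetical workaround: softplus with its derivative fixed to sigmoid.
@jax.custom_jvp
def softplus_sigmoid_grad(x):
    # log(1 + exp(x)), computed stably as logaddexp(x, 0)
    return jnp.logaddexp(x, 0.0)


@softplus_sigmoid_grad.defjvp
def _softplus_sigmoid_grad_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # d/dx softplus(x) = sigmoid(x), which is 0.5 at x = 0
    return softplus_sigmoid_grad(x), jax.nn.sigmoid(x) * t


# Test points include exactly 0.0, where the issue reports the wrong gradient.
xs = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
expected = jax.nn.sigmoid(xs)  # analytic derivative of softplus

builtin = jax.vmap(jax.grad(jax.nn.softplus))(xs)
patched = jax.vmap(jax.grad(softplus_sigmoid_grad))(xs)

print("grad(jax.nn.softplus)(0.0):", jax.grad(jax.nn.softplus)(0.0))
print("max |builtin - sigmoid|:", jnp.max(jnp.abs(builtin - expected)))
print("max |patched - sigmoid|:", jnp.max(jnp.abs(patched - expected)))
```

Because the custom JVP returns `sigmoid(x) * t`, which is linear in the tangent, `jax.grad` can transpose it for reverse mode, and the patched gradient matches sigmoid at every point, including 0.0.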