Calculate the normalized sinc function.
JAX implementation of numpy.sinc().
The normalized sinc function is given by
\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]
where sinc(0) returns the limit value of 1. The sinc function is smooth and infinitely differentiable.
Parameters
x (ArrayLike) – input array; will be promoted to an inexact type.
Returns
An array of the same shape as x containing the result.
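The promotion and shape behavior described above can be checked with a short sketch. It assumes JAX is installed and imported as shown; the variable names are illustrative only, and the exact floating-point dtype depends on whether 64-bit mode is enabled.

```python
import jax.numpy as jnp

# Integer input: sinc promotes it to an inexact (floating-point) dtype.
x_int = jnp.arange(-2, 3)
y = jnp.sinc(x_int)

print(x_int.dtype, "->", y.dtype)   # integer dtype in, floating dtype out
print(y.shape == x_int.shape)       # the output keeps the input's shape
```

At integer arguments other than zero the result is zero up to floating-point rounding, and at zero it is the limit value 1.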
Examples
>>> x = jnp.array([-1, -0.5, 0, 0.5, 1])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinc(x)
Array([-0.   ,  0.637,  1.   ,  0.637, -0.   ], dtype=float32)
Compare this to the naive approach to computing the function, which is undefined at zero:
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sin(jnp.pi * x) / (jnp.pi * x)
Array([-0.   ,  0.637,    nan,  0.637, -0.   ], dtype=float32)
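One common way to avoid the NaN in a hand-written version is to guard the division with `jnp.where`, substituting a harmless denominator at zero. The sketch below is only an illustration of that pattern: `sinc_where` is a hypothetical helper, not the way `jnp.sinc` is implemented.

```python
import jax.numpy as jnp

def sinc_where(x):
    """Hypothetical helper: normalized sinc guarded with jnp.where."""
    x = jnp.asarray(x, dtype=jnp.float32)
    pi_x = jnp.pi * x
    is_zero = pi_x == 0
    # Replace the zero denominator with 1 so that 0/0 is never evaluated.
    safe_pi_x = jnp.where(is_zero, 1.0, pi_x)
    # Select the limit value 1 at zero, and sin(pi x) / (pi x) elsewhere.
    return jnp.where(is_zero, 1.0, jnp.sin(safe_pi_x) / safe_pi_x)

x = jnp.array([-1, -0.5, 0, 0.5, 1])
print(sinc_where(x))
print(jnp.sinc(x))
```

The two printed arrays should agree to within rounding, with 1 rather than NaN at the zero entry.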
JAX defines a custom gradient rule for sinc to allow accurate evaluation of the gradient at zero even for higher-order derivatives:
>>> f = jnp.sinc
>>> for i in range(1, 6):
...   f = jax.grad(f)
...   print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}")
...
(d/dx)^1 f(0.0) = 0.00
(d/dx)^2 f(0.0) = -3.29
(d/dx)^3 f(0.0) = 0.00
(d/dx)^4 f(0.0) = 19.48
(d/dx)^5 f(0.0) = 0.00
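These values can be cross-checked against the Maclaurin series of the normalized sinc function,
\[\mathrm{sinc}(x) = \sum_{k=0}^{\infty} \frac{(-1)^k (\pi x)^{2k}}{(2k+1)!}\]
which gives a derivative of order n = 2k at zero equal to (-1)^k π^{2k} / (2k+1), and zero for odd n. The sketch below compares that closed form with `jax.grad`; `sinc_derivative_at_zero` is a hypothetical helper introduced only for this comparison.

```python
import math
import jax
import jax.numpy as jnp

def sinc_derivative_at_zero(n):
    """Hypothetical helper: n-th derivative of sinc at 0 from its Maclaurin series."""
    if n % 2 == 1:
        return 0.0  # sinc is even, so its odd-order derivatives vanish at zero
    k = n // 2
    return (-1) ** k * math.pi ** n / (n + 1)

f = jnp.sinc
for n in range(1, 5):
    f = jax.grad(f)
    # Print the autodiff result next to the closed-form value.
    print(n, float(f(0.0)), sinc_derivative_at_zero(n))
```

For n = 2 and n = 4 the closed form evaluates to -π²/3 ≈ -3.29 and π⁴/5 ≈ 19.48, matching the output shown above.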