Evaluates the polynomial at specific values.
JAX implementation of numpy.polyval().
For a 1D array of polynomial coefficients p of length M, the function returns the value:
\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]
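The following sketch illustrates the formula in plain Python, without JAX. The sum can be evaluated term by term, or with Horner's rule, which rewrites it as ((p_0 x + p_1) x + p_2) ... + p_{M-1} and needs only one multiply and one add per coefficient. The helper names `polyval_direct` and `polyval_horner` are hypothetical. The sketch is not a copy of the library code.

```python
# Hypothetical helpers: two ways to evaluate p_0 x^{M-1} + ... + p_{M-1}.

def polyval_direct(p, x):
    """Sum p[i] * x**(M - 1 - i) term by term, exactly as in the formula."""
    m = len(p)
    return sum(c * x ** (m - 1 - i) for i, c in enumerate(p))


def polyval_horner(p, x):
    """Horner's rule: fold the coefficients left to right, acc <- acc * x + c."""
    acc = 0
    for c in p:
        acc = acc * x + c
    return acc


p = [2, 5, 1]  # 2x^2 + 5x + 1
print(polyval_direct(p, 3))  # 34
print(polyval_horner(p, 3))  # 34
```

Both forms give 34 for p = [2, 5, 1] at x = 3, which matches the first example below. Horner's rule is the usual choice for a sequential loop, because each step depends only on the running accumulator and the next coefficient.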
Parameters:
p (ArrayLike) – An array of polynomial coefficients of shape (M,).
x (ArrayLike) – A number or an array of numbers.
unroll (int) – A number used to control the number of unrolled steps with lax.scan. It must be specified statically.
Returns:
An array of the same shape as x.
Note
The unroll parameter is JAX-specific. It does not affect correctness, but it can have a major impact on performance when evaluating high-order polynomials. The parameter controls the number of unrolled steps with lax.scan inside the jnp.polyval implementation. Consider setting unroll=128 (or even higher) to improve runtime performance on accelerators, at the cost of increased compilation time.
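The sketch below makes the note concrete. It assumes that a Horner-style loop over the coefficients is a fair model of what lax.scan iterates over. The function `horner_scan` is hypothetical and is not the jnp.polyval source. `unroll` is forwarded to `jax.lax.scan` and marked static for `jax.jit`, which is why it has to be a plain Python int known at trace time.

```python
# Hypothetical helper `horner_scan`: Horner evaluation with jax.lax.scan,
# exposing an `unroll` knob similar to the one jnp.polyval documents.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames="unroll")
def horner_scan(p, x, unroll=1):
    # Cast to float32 so the output dtype resembles the doctest outputs.
    p = jnp.asarray(p, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)

    def step(acc, coeff):
        # One Horner step per coefficient: acc <- acc * x + p_i
        return acc * x + coeff, None

    init = jnp.zeros_like(x)  # accumulator has the same shape as x
    result, _ = lax.scan(step, init, p, unroll=unroll)
    return result


p = jnp.array([2, 5, 1])
x = jnp.array([[2, 1, 5], [3, 4, 7], [1, 3, 5]])
print(horner_scan(p, x))              # same values as jnp.polyval(p, x)
print(horner_scan(p, x, unroll=128))  # unroll changes compilation, not the values
print(jnp.polyval(p, x, unroll=128))  # the library function with the same setting
```

Changing `unroll` should leave the numbers unchanged. Any benefit shows up only in compile time and runtime, so it is worth measuring on the target accelerator rather than assuming a particular value helps.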
Examples
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
If x is a 2D array, polyval returns a 2D array with the same shape as x:
>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)
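Because jnp.polyval is meant to mirror numpy.polyval, one way to sanity-check the 2D example is to compare it directly against NumPy. This sketch assumes NumPy is installed alongside JAX. For integer inputs NumPy keeps an integer dtype, whereas the doctests above return float32, so the comparison is done with a tolerance-based check rather than an exact dtype match.

```python
import numpy as np
import jax.numpy as jnp

p = [2, 5, 1]
x = [[2, 1, 5], [3, 4, 7], [1, 3, 5]]

expected = np.polyval(p, x)                        # NumPy reference; integer inputs give an integer result
actual = jnp.polyval(jnp.array(p), jnp.array(x))  # JAX result, float32 as in the doctest

print(np.allclose(expected, np.asarray(actual)))  # True
```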