Create a square or rectangular identity matrix
JAX implementation of numpy.eye().
N (DimSize) – integer specifying the first dimension of the array.
M (DimSize | None) – optional integer specifying the second dimension of the array; defaults to the same value as N.
k (int | ArrayLike) – optional integer specifying the offset of the diagonal. Use positive values for upper diagonals, and negative values for lower diagonals. Default is zero.
dtype (DTypeLike | None) – optional dtype; defaults to JAX's default floating-point type (float32 unless 64-bit mode is enabled).
device (xc.Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
Identity array of shape (N, M), or (N, N) if M is not specified.
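The offset rule for k can be stated compactly: entry (i, j) is 1 exactly when j - i == k, and 0 otherwise. The sketch below uses only jax.numpy; `eye_sketch` is a hypothetical helper written for illustration, not part of the JAX API and not JAX's actual implementation. It is simply checked against jnp.eye for a few shapes and offsets.

```python
import jax.numpy as jnp

def eye_sketch(N, M=None, k=0, dtype=jnp.float32):
    """Hypothetical reference sketch: entry (i, j) is 1 iff j - i == k."""
    M = N if M is None else M          # square by default, like jnp.eye
    rows = jnp.arange(N)[:, None]      # row indices, shape (N, 1)
    cols = jnp.arange(M)[None, :]      # column indices, shape (1, M)
    # Positive k selects a diagonal above the main one, negative k one below it.
    return (cols - rows == k).astype(dtype)

# Compare the sketch with jnp.eye on square and rectangular shapes.
for args in [(3,), (3, 5), (4, 2)]:
    for k in (-2, -1, 0, 1, 2):
        assert jnp.array_equal(eye_sketch(*args, k=k), jnp.eye(*args, k=k))
```

Broadcasting an (N, 1) column of row indices against a (1, M) row of column indices yields the full (N, M) comparison grid in one step, so no explicit loop over entries is needed.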
Examples
A simple 3x3 identity matrix:
>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Integer identity matrices with offset diagonals:
>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)
Non-square identity matrix:
>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)