Losses — Optax documentation

Losses#

Convex Kullback Leibler divergence#
optax.losses.convex_kl_divergence(log_predictions: chex.Array, targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes a convex version of the Kullback-Leibler divergence loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).

Parameters:
  • log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • targets – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].

References

Kullback and Leibler, On Information and Sufficiency, 1951

Changed in version 0.2.4: Added axis and where arguments.
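
A minimal usage sketch (the input values are illustrative, not from the original docstring):

>>> import jax.numpy as jnp
>>> import optax
>>> targets = jnp.array([[0.5, 0.5], [0.9, 0.1]])
>>> log_predictions = jnp.log(jnp.array([[0.5, 0.5], [0.7, 0.3]]))
>>> loss = optax.losses.convex_kl_divergence(log_predictions, targets)
>>> # one value per row; zero (up to float error) where predictions match targets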

Cosine distance#
optax.losses.cosine_distance(predictions: chex.Array, targets: chex.Array, epsilon: float = 0.0, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes the cosine distance between targets and predictions.

The cosine distance, as implemented here, measures the dissimilarity of two vectors as one minus their cosine similarity: 1 - cos(theta).

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – Ground truth target vectors, with shape […, dim].

  • epsilon – minimum norm for terms in the denominator of the cosine similarity.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

cosine distances, with shape […].

References

Cosine distance, Wikipedia.

Changed in version 0.2.4: Added axis and where arguments.

Cosine similarity#
optax.losses.cosine_similarity(predictions: chex.Array, targets: chex.Array, epsilon: float = 0.0, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes the cosine similarity between targets and predictions.

The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – Ground truth target vectors, with shape […, dim].

  • epsilon – minimum norm for terms in the denominator of the cosine similarity.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

cosine similarity measures, with shape […].

References

Cosine similarity, Wikipedia.

Changed in version 0.2.4: Added axis and where arguments.
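
A short illustrative sketch relating the two cosine losses above (the vectors are made up for demonstration):

>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([[1.0, 0.0], [1.0, 1.0]])
>>> targets = jnp.array([[1.0, 0.0], [-1.0, 1.0]])
>>> sim = optax.losses.cosine_similarity(predictions, targets)
>>> dist = optax.losses.cosine_distance(predictions, targets)
>>> # dist == 1 - sim row-wise; identical vectors give similarity 1 and distance 0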

Connectionist temporal classification loss#
optax.losses.ctc_loss(logits: chex.Array, logit_paddings: chex.Array, labels: chex.Array, label_paddings: chex.Array, blank_id: int = 0, log_epsilon: float = -100000.0) chex.Array[source]#

Computes CTC loss.

See docstring for ctc_loss_with_forward_probs for details.

Parameters:
  • logits – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels – (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence.

  • label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must consist of a run of zeros followed by a run of ones.

  • blank_id – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon – Numerically-stable approximation of log(+0).

Returns:

(B,)-array containing loss values for each sequence in the batch.
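
A minimal shape-only sketch (all array contents here are illustrative):

>>> import jax.numpy as jnp
>>> import optax
>>> B, T, K, N = 1, 4, 3, 2  # batch, time frames, classes (incl. blank), label length
>>> logits = jnp.zeros((B, T, K))
>>> logit_paddings = jnp.zeros((B, T))   # no padded time frames
>>> labels = jnp.array([[1, 2]])         # blank_id defaults to 0
>>> label_paddings = jnp.zeros((B, N))   # no padded labels
>>> loss = optax.losses.ctc_loss(logits, logit_paddings, labels, label_paddings)
>>> loss.shape
(1,)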

optax.losses.ctc_loss_with_forward_probs(logits: chex.Array, logit_paddings: chex.Array, labels: chex.Array, label_paddings: chex.Array, blank_id: int = 0, log_epsilon: float = -100000.0) tuple[chex.Array, chex.Array, chex.Array][source]#

Computes CTC loss and CTC forward-probabilities.

The CTC loss is a loss function based on log-likelihoods of the model that introduces a special blank symbol \(\phi\) to represent variable-length output sequences.

Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: the blank alpha-probability and the non-blank alpha-probability. They are defined as follows:

\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]

Here, \(\pi\) denotes the alignment sequence in the reference [Graves et al, 2006] that is blank-inserted representations of labels. The return values are the logarithms of the above probabilities.

Parameters:
  • logits – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels – (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence.

  • label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must consist of a run of zeros followed by a run of ones.

  • blank_id – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon – Numerically-stable approximation of log(+0).

Returns:

A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch; logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements denote \(\log \alpha_{\mathrm{BLANK}}(t, n)\) and \(\log \alpha_{\mathrm{LABEL}}(t, n)\), respectively, for the b-th sequence in the batch.

References

Graves et al, Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks, 2006
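
A sketch of the auxiliary outputs, reusing illustrative inputs like those in the ctc_loss example above:

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.zeros((1, 4, 3))
>>> logit_paddings = jnp.zeros((1, 4))
>>> labels = jnp.array([[1, 2]])
>>> label_paddings = jnp.zeros((1, 2))
>>> loss, logalpha_blank, logalpha_nonblank = (
...     optax.losses.ctc_loss_with_forward_probs(
...         logits, logit_paddings, labels, label_paddings))
>>> logalpha_blank.shape  # (T, B, N + 1)
(4, 1, 3)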

Fenchel Young loss#
optax.losses.make_fenchel_young_loss(max_fun: MaxFun)[source]#

Creates a Fenchel-Young loss from a max function.

Parameters:

max_fun – the max function on which the Fenchel-Young loss is built.

Returns:

A Fenchel-Young loss function with the same signature.

Examples

Given a max function, e.g., the log sum exp, you can construct a Fenchel-Young loss easily as follows:

>>> from jax.scipy.special import logsumexp
>>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
References

Blondel et al. Learning with Fenchel-Young Losses, 2020

Warning

The resulting loss accepts an arbitrary number of leading dimensions, with fy_loss operating over the last dimension. The jaxopt version of this function instead flattens any input into a single 1D vector.

Hinge loss#
optax.losses.hinge_loss(predictor_outputs: chex.Array, targets: chex.Array) chex.Array[source]#

Computes the hinge loss for binary classification.

Parameters:
  • predictor_outputs – Outputs of the decision function.

  • targets – Target values. Target values should be strictly in the set {-1, 1}.

Returns:

loss value.

optax.losses.multiclass_hinge_loss(scores: chex.Array, labels: chex.Array) chex.Array[source]#

Multiclass hinge loss.

Parameters:
  • scores – scores produced by the model (floats).

  • labels – ground-truth integer labels.

Returns:

loss values

References

Hinge loss, Wikipedia

Added in version 0.2.3.
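
An illustrative sketch covering both hinge losses above (the numbers are arbitrary):

>>> import jax.numpy as jnp
>>> import optax
>>> # binary: targets live in {-1, 1}; the loss is max(0, 1 - target * output)
>>> outputs = jnp.array([2.0, -0.5])
>>> targets = jnp.array([1.0, -1.0])
>>> binary = optax.losses.hinge_loss(outputs, targets)
>>> # multiclass: integer labels against [..., num_classes] scores
>>> scores = jnp.array([[2.0, 0.5, -1.0]])
>>> labels = jnp.array([0])
>>> multi = optax.losses.multiclass_hinge_loss(scores, labels)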

Huber loss#
optax.losses.huber_loss(predictions: chex.Array, targets: chex.Array | None = None, delta: float = 1.0) chex.Array[source]#

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

If gradient descent is applied to the huber loss, it is equivalent to clipping gradients of an l2_loss to [-delta, delta] in the backward pass.

Parameters:
  • predictions – a vector of arbitrary shape […].

  • targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

  • delta – the bounds for the huber loss transformation, defaults to 1.

Returns:

elementwise huber losses, with the same shape as predictions.

References

Huber loss, Wikipedia.
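
A small sketch showing the two regimes (values chosen for illustration only):

>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([0.5, 3.0])
>>> targets = jnp.zeros(2)
>>> # |error| <= delta falls in the quadratic branch, larger errors in the linear branch
>>> loss = optax.losses.huber_loss(predictions, targets, delta=1.0)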

Kullback-Leibler divergence#
optax.losses.kl_divergence(log_predictions: chex.Array, targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes the Kullback-Leibler divergence (relative entropy) loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.

Parameters:
  • log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • targets – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].

References

Kullback and Leibler, On Information and Sufficiency, 1951

Changed in version 0.2.4: Added axis and where arguments.
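
A minimal usage sketch (the distributions are illustrative):

>>> import jax.numpy as jnp
>>> import optax
>>> targets = jnp.array([[0.25, 0.75]])
>>> log_predictions = jnp.log(jnp.array([[0.5, 0.5]]))
>>> loss = optax.losses.kl_divergence(log_predictions, targets)
>>> # non-negative; exactly zero only when predictions equal the targets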

optax.losses.kl_divergence_with_log_targets(log_predictions: chex.Array, log_targets: chex.Array, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes the Kullback-Leibler divergence (relative entropy) loss.

Version of kl_divergence where targets are given in log-space.

Parameters:
  • log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • log_targets – Probabilities of target distribution with shape […, dim]. Expected to be in the log-space.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].

Changed in version 0.2.4: Added axis and where arguments.

L2 Squared loss#
optax.losses.squared_error(predictions: chex.Array, targets: chex.Array | None = None) chex.Array[source]#

Calculates the squared error for a set of predictions.

Mean Squared Error can be computed as squared_error(a, b).mean().

Parameters:
  • predictions – a vector of arbitrary shape […].

  • targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Returns:

elementwise squared differences, with same shape as predictions.

Note

l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.

optax.losses.l2_loss(predictions: chex.Array, targets: chex.Array | None = None) chex.Array[source]#

Calculates the L2 loss for a set of predictions.

Parameters:
  • predictions – a vector of arbitrary shape […].

  • targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Returns:

elementwise squared differences, with same shape as predictions.

Note

the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.
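
A short sketch relating the two losses above (arbitrary example values):

>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([1.0, -2.0])
>>> targets = jnp.array([0.5, 0.0])
>>> se = optax.losses.squared_error(predictions, targets)
>>> l2 = optax.losses.l2_loss(predictions, targets)
>>> # l2 == 0.5 * se elementwise, per the note above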

Log hyperbolic cosine loss#
optax.losses.log_cosh(predictions: chex.Array, targets: chex.Array | None = None) chex.Array[source]#

Calculates the log-cosh loss for a set of predictions.

log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice differentiable alternative to the Huber loss.

Parameters:
  • predictions – a vector of arbitrary shape […].

  • targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Returns:

the log-cosh loss, with same shape as predictions.

References

Chen et al, Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder (https://openreview.net/pdf?id=rkglvsC9Ym), 2019
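
A minimal usage sketch (illustrative values):

>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([0.1, 5.0])
>>> targets = jnp.zeros(2)
>>> loss = optax.losses.log_cosh(predictions, targets)
>>> # roughly x**2 / 2 for the small error and abs(x) - log(2) for the large one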

Normalized temperature scaled cross-entropy (NT-Xent) loss#
optax.losses.ntxent(embeddings: chex.Array, labels: chex.Array, temperature: chex.Numeric = 0.07) chex.Numeric[source]#

Normalized temperature scaled cross entropy loss (NT-Xent).

Examples

>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[ 0.07592554 -0.48634264]
 [ 1.2903206   0.5196119 ]
 [ 0.30040437  0.31034866]
 [ 0.5761609  -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[0.08969027]
 [1.6291292 ]
 [0.8622629 ]
 [0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123
Parameters:
  • embeddings – batch of embeddings, with shape [batch, feature_length]

  • labels – labels for groups that are positive pairs, e.g. if you have a batch of 4 embeddings and the first two and last two are positive pairs, your labels should look like [0, 0, 1, 1]. Shape [batch].

  • temperature – temperature scaling parameter.

Returns:

A scalar NT-Xent loss, averaged over all positive pairs.

References

T. Chen et al, A Simple Framework for Contrastive Learning of Visual Representations, 2020

kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss

Added in version 0.2.3.

Poly loss cross-entropy#
optax.losses.poly_loss_cross_entropy(logits: chex.Array, labels: chex.Array, epsilon: float = 2.0, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes PolyLoss between logits and labels.

The PolyLoss is a loss function that decomposes commonly used classification loss functions into a series of weighted polynomial bases. It is inspired by the Taylor expansion of cross-entropy loss and focal loss in the bases of \((1 - P_t)^j\).

\[L_{Poly} = \sum_1^\infty \alpha_j \cdot (1 - P_t)^j \\ L_{Poly-N} = (\epsilon_1 + 1) \cdot (1 - P_t) + \ldots + \\ (\epsilon_N + \frac{1}{N}) \cdot (1 - P_t)^N + \frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots = \\ - \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j \]

This function provides a simplified version of \(L_{Poly-N}\) with only the coefficient of the first polynomial term being changed.

Parameters:
  • logits – Unnormalized log probabilities, with shape […, num_classes].

  • labels – Valid probability distributions (non-negative, sum to 1), e.g. a one hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].

  • epsilon – The coefficient of the first polynomial term. According to the paper, epsilon = 2.0 is recommended for ImageNet 2D image classification and epsilon = -1.0 for 2D instance segmentation and object detection; it is also recommended to adjust this value per task, e.g. via grid search.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

Poly loss between each prediction and the corresponding target distributions, with shape […].

References

Leng et al, PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions, 2022

Changed in version 0.2.4: Added axis and where arguments.
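
A minimal usage sketch (logits and labels are illustrative):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([[1.2, -0.8, -0.5]])
>>> labels = jnp.array([[1.0, 0.0, 0.0]])
>>> loss = optax.losses.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
>>> # per the formula above, epsilon=0.0 recovers plain softmax cross-entropy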

Perceptron#
optax.losses.perceptron_loss(predictor_outputs: chex.Numeric, targets: chex.Numeric) chex.Numeric[source]#

Binary perceptron loss.

Parameters:
  • predictor_outputs – score produced by the model (float).

  • targets – Target values. Target values should be strictly in the set {-1, 1}.

Returns:

loss value.

References

Perceptron, Wikipedia

optax.losses.multiclass_perceptron_loss(scores: chex.Array, labels: chex.Array) chex.Array[source]#

Multiclass perceptron loss.

Parameters:
  • scores – scores produced by the model.

  • labels – ground-truth integer labels.

Returns:

loss values.

References

Michael Collins. Discriminative training methods for Hidden Markov Models: Theory and experiments with perceptron algorithms. EMNLP 2002

Added in version 0.2.2.
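
An illustrative sketch covering both perceptron losses above (arbitrary values):

>>> import jax.numpy as jnp
>>> import optax
>>> outputs = jnp.array([2.0, -0.5])
>>> targets = jnp.array([1.0, -1.0])   # targets in {-1, 1}
>>> binary = optax.losses.perceptron_loss(outputs, targets)
>>> scores = jnp.array([[2.0, 0.5, -1.0]])
>>> labels = jnp.array([0])
>>> multi = optax.losses.multiclass_perceptron_loss(scores, labels)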

Ranking softmax loss#
optax.losses.ranking_softmax_loss(logits: chex.Array, labels: chex.Array, *, where: chex.Array | None = None, weights: chex.Array | None = None, reduce_fn: collections.abc.Callable[..., chex.Array] | None = <function mean>) chex.Array[source]#

Ranking softmax loss.

Definition:

\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]

Parameters:
  • logits – A [..., list_size]-Array, indicating the score of each item.

  • labels – A [..., list_size]-Array, indicating the relevance label for each item.

  • where – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights – An optional [..., list_size]-Array, indicating the weight for each item.

  • reduce_fn – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.

Returns:

The ranking softmax loss.
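
A minimal usage sketch with a validity mask (the values are illustrative):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([[0.3, 1.9, -0.2]])   # scores for a list of 3 items
>>> labels = jnp.array([[0.0, 1.0, 0.0]])    # relevance label per item
>>> mask = jnp.array([[True, True, False]])  # ignore the third item
>>> loss = optax.losses.ranking_softmax_loss(logits, labels, where=mask)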

Sigmoid binary cross-entropy#
optax.losses.sigmoid_binary_cross_entropy(logits, labels)[source]#

Computes element-wise sigmoid cross entropy given logits and labels.

This function can be used for binary or multiclass classification (where each class is an independent binary prediction and different classes are not mutually exclusive, e.g. predicting that an image contains both a cat and a dog).

Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.

Parameters:
  • logits – Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.

  • labels – Binary labels whose values are {0,1} or multi-class target probabilities. See note about compatibility with logits above.

Returns:

cross entropy for each binary prediction, same shape as logits.

References

Goodfellow et al, Deep Learning, 2016
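
A minimal usage sketch with binary labels (illustrative values):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([1.5, -0.3, 2.0])  # one logit per independent binary prediction
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> loss = optax.losses.sigmoid_binary_cross_entropy(logits, labels)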

Sigmoid focal loss#
optax.losses.sigmoid_focal_loss(logits: chex.Array, labels: chex.Array, alpha: float | None = None, gamma: float = 2.0) chex.Array[source]#

Sigmoid focal loss.

The focal loss is a re-weighted cross entropy for unbalanced problems. Use this loss function if classes are not mutually exclusive. See sigmoid_binary_cross_entropy for more information.

Parameters:
  • logits – Array of floats. The predictions for each example.

  • labels – Array of floats. Labels and logits must have the same shape. The label array must contain the binary classification labels for each element in the data set (0 for out-of-class and 1 for in-class).

  • alpha – (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default None (no weighting).

  • gamma – Exponent of the modulating factor (1 - p_t). Balances easy vs hard examples.

Returns:

A loss value array with a shape identical to the logits and target arrays.

References

Lin et al. Focal Loss for Dense Object Detection, 2017
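
A minimal usage sketch (illustrative values):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([1.5, -0.3, 2.0])
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> loss = optax.losses.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
>>> # with gamma=0.0 and alpha=None this reduces to sigmoid_binary_cross_entropy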

Smoothing labels#
optax.losses.smooth_labels(labels: chex.Array, alpha: float) Array[source]#

Apply label smoothing.

Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favor small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.

Parameters:
  • labels – One hot labels to be smoothed.

  • alpha – The smoothing factor.

Returns:

a smoothed version of the one hot input labels.

References

Muller et al, When does label smoothing help?, 2019
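
A minimal usage sketch (an arbitrary one-hot row):

>>> import jax.numpy as jnp
>>> import optax
>>> one_hot = jnp.array([[1.0, 0.0, 0.0]])
>>> smoothed = optax.losses.smooth_labels(one_hot, alpha=0.1)
>>> # a fraction alpha of the mass is spread uniformly over classes; rows still sum to 1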

Soft-max cross-entropy#
optax.losses.safe_softmax_cross_entropy(logits: chex.Array, labels: chex.Array) chex.Array[source]#

Computes the softmax cross entropy between sets of logits and labels.

In contrast to optax.softmax_cross_entropy(), this function treats labels * logsoftmax(logits) as 0 when logits = -inf and labels = 0, following the convention that 0 log 0 = 0.

Parameters:
  • logits – Unnormalized log probabilities, with shape […, num_classes].

  • labels – Valid probability distributions (non-negative, sum to 1), e.g. a one hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].

Returns:

cross entropy between each prediction and the corresponding target distributions, with shape […].
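
A sketch of the -inf case that motivates this variant (the values are illustrative):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([[10.0, -jnp.inf]])  # the second class is ruled out entirely
>>> labels = jnp.array([[1.0, 0.0]])
>>> loss = optax.losses.safe_softmax_cross_entropy(logits, labels)
>>> # finite result: the 0 * log 0 term is treated as 0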

optax.losses.softmax_cross_entropy(logits: chex.Array, labels: chex.Array, axis: int | tuple[int, ...] | None = -1, where: chex.Array | None = None) chex.Array[source]#

Computes the softmax cross entropy between sets of logits and labels.

This loss function is commonly used for multi-class classification tasks. It measures the dissimilarity between the predicted probability distribution (obtained by applying the softmax function to the logits) and the true probability distribution (represented by the one-hot encoded labels). This loss is also known as categorical cross entropy.

Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size, num_classes]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:

\[\sigma_i = - \sum_j y_{i j} \log\left(\frac{\exp(x_{i j})}{\sum_k \exp(x_{i k})}\right) \,. \]

Parameters:
  • logits – Unnormalized log probabilities, with shape [batch_size, num_classes].

  • labels – One-hot encoded labels, with shape [batch_size, num_classes]. Each row represents the true class distribution for a single example.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.softmax_cross_entropy(logits, labels))
[0.2761 2.9518]

References

Cross-entropy Loss, Wikipedia

Multinomial Logistic Regression, Wikipedia

Changed in version 0.2.4: Added axis and where arguments.

optax.losses.softmax_cross_entropy_with_integer_labels(logits: chex.Array, labels: chex.Array, axis: int | tuple[int, ...] = -1, where: chex.Array | None = None) chex.Array[source]#

Computes softmax cross entropy between the logits and integer labels.

This loss is useful for classification problems with integer labels that are not one-hot encoded. This loss is also known as categorical cross entropy.

Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:

\[\sigma_i = - \log\left(\frac{\exp(x_{i y_i})}{\sum_j \exp(x_{i j})}\right)\,. \]

Parameters:
  • logits – Unnormalized log probabilities, with shape [batch_size, num_classes].

  • labels – Integers specifying the correct class for each input, with shape [batch_size]. Class labels are assumed to be between 0 and num_classes - 1 inclusive.

  • axis – Axis or axes along which to compute. If a tuple of axes is passed then num_classes must match the total number of elements in axis dimensions and a label is interpreted as a flat index in a logits slice of shape logits[axis].

  • where – Elements to include in the computation.

Returns:

Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761 2.9518]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # elements indices in slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
...     logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.4587 0.4587]]

References

Cross-entropy Loss, Wikipedia

Multinomial Logistic Regression, Wikipedia

Changed in version 0.2.4: Added axis and where arguments.

Sparsemax#
optax.losses.sparsemax_loss(logits: chex.Array, labels: chex.Array) chex.Array[source]#

Binary sparsemax loss.

This loss is zero if and only if jax.nn.sparse_sigmoid(logits) == labels.

Parameters:
  • logits – score produced by the model (float).

  • labels – ground-truth integer label (0 or 1).

Returns:

loss value

References

Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins, Vlad Niculae. JMLR 2020. (Sec. 4.4)

Added in version 0.2.3.
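
A minimal usage sketch for the binary case (illustrative values):

>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([2.0, -1.5])
>>> labels = jnp.array([1, 0])  # ground-truth labels in {0, 1}
>>> loss = optax.losses.sparsemax_loss(logits, labels)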

optax.losses.multiclass_sparsemax_loss(scores: chex.Array, labels: chex.Array) chex.Array[source]#

Multiclass sparsemax loss.

Parameters:
  • scores – scores produced by the model.

  • labels – ground-truth integer labels.

Returns:

loss values

References

Martins et al, From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification (https://arxiv.org/abs/1602.02068), 2016.

Triplet margin loss#
optax.losses.triplet_margin_loss(anchors: chex.Array, positives: chex.Array, negatives: chex.Array, axis: int = -1, norm_degree: chex.Numeric = 2, margin: chex.Numeric = 1.0, eps: chex.Numeric = 1e-06) chex.Array[source]#

Returns the triplet loss for a batch of embeddings.

Examples

>>> import jax.numpy as jnp, optax, chex
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
...                                           margin=1.0)
>>> print(output)
[0.1414 0.1414]
Parameters:
  • anchors – An array of anchor embeddings, with shape [batch, feature_dim].

  • positives – An array of positive embeddings (similar to anchors), with shape [batch, feature_dim].

  • negatives – An array of negative embeddings (dissimilar to anchors), with shape [batch, feature_dim].

  • axis – The axis along which to compute the distances (default is -1).

  • norm_degree – The norm degree for distance calculation (default is 2 for Euclidean distance).

  • margin – The minimum margin by which the positive distance should be smaller than the negative distance.

  • eps – A small epsilon value to ensure numerical stability in the distance calculation.

Returns:

Returns the computed triplet loss as an array.

References

V. Balntas et al, Learning shallow convolutional feature descriptors with triplet losses (https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf), 2016

