Computes a convex version of the Kullback-Leibler divergence loss.
Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).
log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.
targets – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
References
Kullback and Leibler, On Information and Sufficiency, 1951
Changed in version 0.2.4: Added axis and where arguments.
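A minimal usage sketch for the convex KL divergence above, assuming the function is exposed as optax.losses.convex_kl_divergence with the documented signature:
>>> import jax.numpy as jnp
>>> import optax
>>> targets = jnp.array([[0.6, 0.3, 0.1]])
>>> log_predictions = jnp.log(jnp.array([[0.7, 0.2, 0.1]]))
>>> loss = optax.losses.convex_kl_divergence(log_predictions, targets)
>>> loss.shape  # one divergence value per leading (batch) dimension
(1,)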
Computes the cosine distance between targets and predictions.
The cosine distance, implemented here, measures the dissimilarity of two vectors as the opposite of cosine similarity: 1 - cos(theta).
predictions – The predicted vectors, with shape […, dim].
targets – Ground truth target vectors, with shape […, dim].
epsilon – minimum norm for terms in the denominator of the cosine similarity.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
cosine distances, with shape […].
References
Cosine distance, Wikipedia.
Changed in version 0.2.4: Added axis and where arguments.
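A short sketch for the cosine distance above, assuming the function is available as optax.cosine_distance:
>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])
>>> distances = optax.cosine_distance(predictions, targets)
>>> distances.shape  # identical directions give 0, orthogonal vectors give 1
(2,)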
Computes the cosine similarity between targets and predictions.
The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.
predictions – The predicted vectors, with shape […, dim].
targets – Ground truth target vectors, with shape […, dim].
epsilon – minimum norm for terms in the denominator of the cosine similarity.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
cosine similarity measures, with shape […].
References
Cosine similarity, Wikipedia.
Changed in version 0.2.4: Added axis and where arguments.
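A short sketch for the cosine similarity above, assuming the function is available as optax.cosine_similarity:
>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])
>>> sims = optax.cosine_similarity(predictions, targets)
>>> sims.shape  # +1 for identical directions, 0 for orthogonal vectors
(2,)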
Computes CTC loss.
See docstring for ctc_loss_with_forward_probs for details.
logits – (B, T, K)-array containing logits of each class, where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels – (B, N)-array containing reference integer labels, where N denotes the max length of the label sequences.
label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.
blank_id – Id for the blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon – Numerically-stable approximation of log(+0).
(B,)-array containing loss values for each sequence in the batch.
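A minimal shape-oriented sketch, assuming the function is exposed as optax.ctc_loss; the dummy logits here are for illustration only:
>>> import jax.numpy as jnp
>>> import optax
>>> B, T, K, N = 2, 5, 4, 3   # batch, time frames, classes (incl. blank), label length
>>> logits = jnp.zeros((B, T, K))
>>> logit_paddings = jnp.zeros((B, T))            # no padded frames
>>> labels = jnp.array([[1, 2, 3], [1, 2, 0]])    # blank_id = 0 by default
>>> label_paddings = jnp.array([[0., 0., 0.], [0., 0., 1.]])  # right-padded
>>> loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings)
>>> loss.shape
(2,)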
Computes CTC loss and CTC forward-probabilities.
The CTC loss is a loss function based on log-likelihoods of the model that introduces a special blank symbol \(\phi\) to represent variable-length output sequences.
Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: the blank alpha-probability and the non-blank alpha-probability. These are defined as follows:
\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]
Here, \(\pi\) denotes the alignment sequence in the reference [Graves et al, 2006], which is a blank-inserted representation of labels. The return values are the logarithms of the above probabilities.
logits – (B, T, K)-array containing logits of each class, where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels – (B, N)-array containing reference integer labels, where N denotes the max length of the label sequences.
label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.
blank_id – Id for the blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon – Numerically-stable approximation of log(+0).
A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch; logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements are log alpha_B(t, n) and log alpha_L(t, n), respectively, for the b-th sequence in the batch.
References
Graves et al, Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks, 2006
Creates a Fenchel-Young loss from a max function.
max_fun – the max function on which the Fenchel-Young loss is built.
A Fenchel-Young loss function with the same signature.
Examples
Given a max function, e.g., the log sum exp, you can construct a Fenchel-Young loss easily as follows:
>>> from jax.scipy.special import logsumexp
>>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
References
Blondel et al, Learning with Fenchel-Young Losses, 2020
Warning
The resulting loss accepts an arbitrary number of leading dimensions, with fy_loss operating over the last dimension. The jaxopt version of this function would instead flatten any input into a single 1D vector.
Computes the hinge loss for binary classification.
predictor_outputs – Outputs of the decision function.
targets – Target values. Target values should be strictly in the set {-1, 1}.
loss value.
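A brief sketch for the binary hinge loss above, assuming it is exposed as optax.hinge_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> predictor_outputs = jnp.array([0.7, -1.2, 0.3])
>>> targets = jnp.array([1.0, -1.0, -1.0])   # labels in {-1, 1}
>>> losses = optax.hinge_loss(predictor_outputs, targets)
>>> losses.shape
(3,)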
Multiclass hinge loss.
scores – scores produced by the model (floats).
labels – ground-truth integer labels.
loss values
References
Hinge loss, Wikipedia
Added in version 0.2.3.
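A sketch for the multiclass hinge loss above, assuming it is exposed as optax.losses.multiclass_hinge_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> scores = jnp.array([[2.0, 1.0, 0.5], [0.3, 2.2, 1.1]])  # [batch, num_classes]
>>> labels = jnp.array([0, 1])                              # integer labels
>>> optax.losses.multiclass_hinge_loss(scores, labels).shape
(2,)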
Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
If gradient descent is applied to the Huber loss, it is equivalent to clipping the gradients of an l2_loss to [-delta, delta] in the backward pass.
predictions – a vector of arbitrary shape […].
targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
delta – the bounds for the Huber loss transformation; defaults to 1.
elementwise Huber losses, with the same shape as predictions.
References
Huber loss, Wikipedia.
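A minimal sketch for the Huber loss above, assuming it is exposed as optax.huber_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([0.2, 1.5, -3.0])
>>> targets = jnp.zeros_like(predictions)
>>> losses = optax.huber_loss(predictions, targets, delta=1.0)
>>> losses.shape   # quadratic for |error| <= delta, linear beyond
(3,)
>>> mean_loss = losses.mean()   # reduce as needed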
Computes the Kullback-Leibler divergence (relative entropy) loss.
Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.
log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.
targets – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
References
Kullback and Leibler, On Information and Sufficiency, 1951
Changed in version 0.2.4: Added axis and where arguments.
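A brief sketch for the KL divergence loss above, assuming it is exposed as optax.kl_divergence:
>>> import jax.numpy as jnp
>>> import optax
>>> targets = jnp.array([[0.5, 0.5], [0.9, 0.1]])
>>> log_predictions = jnp.log(jnp.array([[0.8, 0.2], [0.9, 0.1]]))
>>> kl = optax.kl_divergence(log_predictions, targets)
>>> kl.shape   # one value per distribution; zero when the distributions match
(2,)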
Computes the Kullback-Leibler divergence (relative entropy) loss.
Version of kl_div_loss where targets are given in log-space.
log_predictions – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.
log_targets – Probabilities of target distribution with shape […, dim]. Expected to be in the log-space.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
Changed in version 0.2.4: Added axis and where arguments.
Calculates the squared error for a set of predictions.
Mean Squared Error can be computed as squared_error(a, b).mean().
predictions – a vector of arbitrary shape […].
targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
elementwise squared differences, with same shape as predictions.
Note
l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.
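A brief sketch for the squared error above, assuming it is exposed as optax.squared_error:
>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([1.0, 2.0, 3.0])
>>> targets = jnp.array([1.5, 1.0, 3.0])
>>> errors = optax.squared_error(predictions, targets)   # elementwise (p - t)**2
>>> errors.shape
(3,)
>>> mse = errors.mean()   # mean squared error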
Calculates the L2 loss for a set of predictions.
predictions – a vector of arbitrary shape […].
targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
elementwise squared differences, with same shape as predictions.
Note
the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.
Calculates the log-cosh loss for a set of predictions.
log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice differentiable alternative to the Huber loss.
predictions – a vector of arbitrary shape […].
targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
the log-cosh loss, with same shape as predictions.
References
Chen et al, Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder, 2019. https://openreview.net/pdf?id=rkglvsC9Ym
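A brief sketch for the log-cosh loss above, assuming it is exposed as optax.log_cosh:
>>> import jax.numpy as jnp
>>> import optax
>>> predictions = jnp.array([0.1, -2.0, 5.0])
>>> targets = jnp.zeros_like(predictions)
>>> optax.log_cosh(predictions, targets).shape
(3,)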
Normalized temperature scaled cross entropy loss (NT-Xent).
Examples
>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[ 0.07592554 -0.48634264]
 [ 1.2903206   0.5196119 ]
 [ 0.30040437  0.31034866]
 [ 0.5761609  -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[0.08969027]
 [1.6291292 ]
 [0.8622629 ]
 [0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123
embeddings – batch of embeddings, with shape [batch, feature_length]
labels – labels for groups that are positive pairs, e.g. if you have a batch of 4 embeddings and the first two and last two are positive pairs, your labels should look like [0, 0, 1, 1]. Shape [batch].
temperature – temperature scaling parameter.
A scalar NT-Xent loss value, averaged over all positive pairs.
References
T. Chen et al, A Simple Framework for Contrastive Learning of Visual Representations, 2020
kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss
Added in version 0.2.3.
Computes PolyLoss between logits and labels.
The PolyLoss is a loss function that decomposes commonly used classification loss functions into a series of weighted polynomial bases. It is inspired by the Taylor expansion of cross-entropy loss and focal loss in the bases of \((1 - P_t)^j\).
\[L_{Poly} = \sum_{j=1}^\infty \alpha_j \cdot (1 - P_t)^j \\ L_{Poly-N} = (\epsilon_1 + 1) \cdot (1 - P_t) + \ldots + \\ (\epsilon_N + \frac{1}{N}) \cdot (1 - P_t)^N + \frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots = \\ - \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j \]
This function provides a simplified version of \(L_{Poly-N}\) with only the coefficient of the first polynomial term being changed.
logits – Unnormalized log probabilities, with shape […, num_classes].
labels – Valid probability distributions (non-negative, sum to 1), e.g. a one hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].
epsilon – The coefficient of the first polynomial term. According to the paper, the following values are recommended:
- For ImageNet 2D image classification, epsilon = 2.0.
- For 2D instance segmentation and object detection, epsilon = -1.0.
- It is also recommended to adjust this value based on the task, e.g. by using grid search.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
Poly loss between each prediction and the corresponding target distributions, with shape […].
References
Leng et al, PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions, 2022
Changed in version 0.2.4: Added axis and where arguments.
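A sketch for the PolyLoss above, assuming it is exposed as optax.losses.poly_loss_cross_entropy:
>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([[1.2, -0.8, -0.5]])
>>> labels = jnp.array([[1.0, 0.0, 0.0]])   # one-hot targets
>>> loss = optax.losses.poly_loss_cross_entropy(logits, labels, epsilon=2.0)
>>> loss.shape
(1,)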
Binary perceptron loss.
predictor_outputs – score produced by the model (float).
targets – Target values. Target values should be strictly in the set {-1, 1}.
loss value.
References
Perceptron, Wikipedia
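A brief sketch for the binary perceptron loss above, assuming it is exposed as optax.losses.perceptron_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> predictor_outputs = jnp.array([0.7, -1.2])
>>> targets = jnp.array([1.0, 1.0])   # labels in {-1, 1}
>>> optax.losses.perceptron_loss(predictor_outputs, targets).shape
(2,)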
Multiclass perceptron loss.
scores – scores produced by the model.
labels – ground-truth integer labels.
loss values.
References
Michael Collins. Discriminative training methods for Hidden Markov Models: Theory and experiments with perceptron algorithms. EMNLP 2002
Added in version 0.2.2.
Ranking softmax loss.
Definition:
\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
logits – A [..., list_size]-Array, indicating the score of each item.
labels – A [..., list_size]-Array, indicating the relevance label for each item.
where – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights – An optional [..., list_size]-Array, indicating the weight for each item.
reduce_fn – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
The ranking softmax loss.
Computes element-wise sigmoid cross entropy given logits and labels.
This function can be used for binary or multiclass classification (where each class is an independent binary prediction and different classes are not mutually exclusive, e.g. predicting that an image contains both a cat and a dog).
Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.
logits – Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.
labels – Binary labels whose values are {0,1} or multi-class target probabilities. See note about compatibility with logits above.
cross entropy for each binary prediction, same shape as logits.
References
Goodfellow et al, Deep Learning, 2016
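A brief sketch for the sigmoid cross entropy above, assuming it is exposed as optax.sigmoid_binary_cross_entropy:
>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([0.5, -1.0, 2.0])     # one logit per binary prediction
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> losses = optax.sigmoid_binary_cross_entropy(logits, labels)
>>> losses.shape
(3,)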
Sigmoid focal loss.
The focal loss is a re-weighted cross entropy for unbalanced problems. Use this loss function if classes are not mutually exclusive. See sigmoid_binary_cross_entropy for more information.
logits – Array of floats. The predictions for each example.
labels – Array of floats. Labels and logits must have the same shape. The label array must contain the binary classification labels for each element in the data set (0 for the out-of-class and 1 for in-class).
alpha – (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default None (no weighting).
gamma – Exponent of the modulating factor (1 - p_t). Balances easy vs hard examples.
A loss value array with a shape identical to the logits and target arrays.
References
Lin et al. Focal Loss for Dense Object Detection, 2017
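A brief sketch for the focal loss above, assuming it is exposed as optax.sigmoid_focal_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> logits = jnp.array([0.5, -1.0, 2.0])
>>> labels = jnp.array([1.0, 0.0, 1.0])
>>> loss = optax.sigmoid_focal_loss(logits, labels, alpha=0.25, gamma=2.0)
>>> loss.shape
(3,)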
Apply label smoothing.
Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favor small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.
labels – One hot labels to be smoothed.
alpha – The smoothing factor.
a smoothed version of the one hot input labels.
References
Muller et al, When does label smoothing help?, 2019
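A brief sketch for label smoothing above, assuming it is exposed as optax.smooth_labels; smoothing mixes the one-hot labels with a uniform distribution:
>>> import jax.numpy as jnp
>>> import optax
>>> one_hot = jnp.array([[0.0, 1.0, 0.0]])
>>> smoothed = optax.smooth_labels(one_hot, alpha=0.1)
>>> smoothed.shape   # rows still sum to 1
(1, 3)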
Computes the softmax cross entropy between sets of logits and labels.
Unlike optax.softmax_cross_entropy(), this function handles labels*logsoftmax(logits) as 0 when logits = -inf and labels = 0, following the convention that 0 log 0 = 0.
logits – Unnormalized log probabilities, with shape […, num_classes].
labels – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].
cross entropy between each prediction and the corresponding target distributions, with shape […].
Computes the softmax cross entropy between sets of logits and labels.
This loss function is commonly used for multi-class classification tasks. It measures the dissimilarity between the predicted probability distribution (obtained by applying the softmax function to the logits) and the true probability distribution (represented by the one-hot encoded labels). This loss is also known as categorical cross entropy.
Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size, num_classes]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:
\[\sigma_i = - \sum_j y_{i j} \log\left(\frac{\exp(x_{i j})}{\sum_k \exp(x_{i k})}\right) \,. \]
logits – Unnormalized log probabilities, with shape [batch_size, num_classes].
labels – One-hot encoded labels, with shape [batch_size, num_classes]. Each row represents the true class distribution for a single example.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.softmax_cross_entropy(logits, labels))
[0.2761 2.9518]
References
Cross-entropy Loss, Wikipedia
Multinomial Logistic Regression, Wikipedia
Changed in version 0.2.4: Added axis and where arguments.
Computes softmax cross entropy between the logits and integer labels.
This loss is useful for classification problems with integer labels that are not one-hot encoded. This loss is also known as categorical cross entropy.
Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:
\[\sigma_i = -\log\left(\frac{\exp(x_{i y_i})}{\sum_j \exp(x_{i j})}\right)\,. \]
logits – Unnormalized log probabilities, with shape [batch_size, num_classes].
labels – Integers specifying the correct class for each input, with shape [batch_size]. Class labels are assumed to be between 0 and num_classes - 1 inclusive.
axis – Axis or axes along which to compute. If a tuple of axes is passed, then num_classes must match the total number of elements in the axis dimensions and a label is interpreted as a flat index into a logits slice of shape logits[axis].
where – Elements to include in the computation.
Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761 2.9518]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # elements indices in slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
...   logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.4587 0.4587]]
References
Cross-entropy Loss, Wikipedia
Multinomial Logistic Regression, Wikipedia
Changed in version 0.2.4: Added axis and where arguments.
Binary sparsemax loss.
This loss is zero if and only if jax.nn.sparse_sigmoid(logits) == labels.
logits – score produced by the model (float).
labels – ground-truth integer label (0 or 1).
loss value
References
Blondel, Martins and Niculae, Learning with Fenchel-Young Losses, JMLR 2020 (Sec. 4.4)
Added in version 0.2.3.
Multiclass sparsemax loss.
scores – scores produced by the model.
labels – ground-truth integer labels.
loss values
References
Martins et al, From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, 2016. https://arxiv.org/abs/1602.02068
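A sketch for the multiclass sparsemax loss above, assuming it is exposed as optax.losses.multiclass_sparsemax_loss:
>>> import jax.numpy as jnp
>>> import optax
>>> scores = jnp.array([[2.0, 1.0, 0.5], [0.3, 2.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> optax.losses.multiclass_sparsemax_loss(scores, labels).shape
(2,)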
Returns the triplet loss for a batch of embeddings.
Examples
>>> import jax.numpy as jnp, optax, chex
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
...   margin=1.0)
>>> print(output)
[0.1414 0.1414]
anchors – An array of anchor embeddings, with shape [batch, feature_dim].
positives – An array of positive embeddings (similar to anchors), with shape [batch, feature_dim].
negatives – An array of negative embeddings (dissimilar to anchors), with shape [batch, feature_dim].
axis – The axis along which to compute the distances (default is -1).
norm_degree – The norm degree for distance calculation (default is 2 for Euclidean distance).
margin – The minimum margin by which the positive distance should be smaller than the negative distance.
eps – A small epsilon value to ensure numerical stability in the distance calculation.
Returns the computed triplet loss as an array.
References
V. Balntas et al, Learning shallow convolutional feature descriptors with triplet losses, 2016. https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf