Loss Functions
Mean Squared Error: (1/n) Σ(ŷ − y)². Standard loss for regression. Penalizes large errors quadratically.
Mean Absolute Error (L1 loss): (1/n) Σ|ŷ − y|. Robust to outliers. Linear penalty.
Cross-entropy loss (crossEntropyLoss) for multi-class classification. Expects raw logits (not softmax probabilities). Combines log-softmax and NLL loss for numerical stability. Returns a number when called with plain Tensors, or a GradTensor when called with GradTensors (for backpropagation).
Binary cross-entropy: −[y·log(ŷ) + (1−y)·log(1−ŷ)]. For binary classification. Expects probabilities in (0, 1).
Numerically stable BCE (binaryCrossEntropyWithLogitsLoss) that accepts raw logits (before sigmoid). Combines sigmoid and BCE in one step. Returns a number when called with plain Tensors, or a GradTensor when called with GradTensors (for backpropagation).
Huber loss: quadratic for small errors, linear for large errors. Controlled by delta parameter. Combines benefits of MSE and MAE.
Root Mean Squared Error: √MSE. Same units as the target. More interpretable than MSE.
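To see how the regression losses respond to outliers, here is a short sketch using mseLoss, maeLoss, and huberLoss together; it assumes maeLoss takes the same (prediction, target) arguments as mseLoss, which is not shown explicitly in the example below.

import { mseLoss, maeLoss, huberLoss } from "deepbox/nn";
import { tensor } from "deepbox/ndarray";

// One large outlier (10.0 vs. 2.0) in the predictions
const pred = tensor([2.5, 0.0, 10.0]);
const target = tensor([3.0, -0.5, 2.0]);

mseLoss(pred, target); // ≈ 21.5: the squared 8.0 error dominates the mean
maeLoss(pred, target); // = 3.0: linear penalty (signature assumed to match mseLoss)
huberLoss(pred, target, 1.0); // ≈ 2.58: quadratic below δ = 1.0, linear above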
MSE
MSE = (1/n) Σᵢ (ŷᵢ − yᵢ)²
Where:
- ŷ = Prediction
- y = Target
- n = Number of samples
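The formula is easy to check by hand. A standalone TypeScript sketch (no deepbox needed) on the same values used in the example further down this page:

function mse(pred: number[], target: number[]): number {
  // (1/n) Σ(ŷᵢ − yᵢ)²: mean of squared errors
  let sum = 0;
  for (let i = 0; i < pred.length; i++) {
    const e = pred[i] - target[i];
    sum += e * e;
  }
  return sum / pred.length;
}

mse([2.5, 0.0, 2.1], [3.0, -0.5, 2.0]); // (0.25 + 0.25 + 0.01) / 3 = 0.17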
Cross-Entropy
CE = −log softmax(ẑ)_y = −ẑ_y + log Σⱼ exp(ẑⱼ)
Where:
- ẑ = Raw logits
- y = True class index
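The log-softmax + NLL combination is what makes this stable: the loss reduces to log Σⱼ exp(ẑⱼ) − ẑ_y, and the log-sum-exp can be shifted by max(ẑ) so no exponential overflows. A standalone sketch for a single sample:

function crossEntropy(logits: number[], label: number): number {
  // Shift by the max so Math.exp never overflows
  const m = Math.max(...logits);
  const logSumExp = m + Math.log(logits.reduce((s, z) => s + Math.exp(z - m), 0));
  // CE = −ẑ_y + log Σⱼ exp(ẑⱼ)
  return logSumExp - logits[label];
}

crossEntropy([2.0, 1.0, 0.1], 0); // ≈ 0.417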
BCE
BCE = −[y·log(ŷ) + (1−y)·log(1−ŷ)]
Where:
- ŷ = Predicted probability
- y = True label (0 or 1)
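For the logits variant described above, substituting ŷ = σ(z) into this formula and rearranging gives the standard numerically stable form max(z, 0) − z·y + log(1 + exp(−|z|)), which never exponentiates a large positive number. A standalone per-element sketch of that rearrangement:

function bceWithLogits(z: number, y: number): number {
  // max(z, 0) − z·y + log(1 + exp(−|z|))
  return Math.max(z, 0) - z * y + Math.log(1 + Math.exp(-Math.abs(z)));
}

bceWithLogits(2.1972, 1); // ≈ 0.105, i.e. −log(0.9), since σ(2.1972) ≈ 0.9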
Huber
L_δ(a) = ½a² if |a| ≤ δ, otherwise δ(|a| − ½δ)
Where:
- a = ŷ − y (error)
- δ = Transition threshold
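A standalone sketch of the piecewise definition; note that the two branches agree exactly at |a| = δ, so the loss is smooth at the transition:

function huber(a: number, delta = 1.0): number {
  return Math.abs(a) <= delta
    ? 0.5 * a * a // quadratic region (MSE-like)
    : delta * (Math.abs(a) - 0.5 * delta); // linear region (MAE-like)
}

huber(0.5); // 0.125
huber(1.0); // 0.5 (both branches meet at the boundary)
huber(8.0); // 7.5 (grows linearly, not quadratically)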
import { mseLoss, crossEntropyLoss, binaryCrossEntropyLoss, huberLoss } from "deepbox/nn";
import { tensor } from "deepbox/ndarray";

// Regression loss
const pred = tensor([2.5, 0.0, 2.1]);
const target = tensor([3.0, -0.5, 2.0]);
mseLoss(pred, target); // Scalar tensor
huberLoss(pred, target, 1.0); // Robust to outliers

// Multi-class classification (expects raw logits)
const logits = tensor([[2.0, 1.0, 0.1]]);
const labels = tensor([0], { dtype: "int32" }); // Class 0
crossEntropyLoss(logits, labels);

// Binary classification (expects probabilities)
const probs = tensor([0.9, 0.1, 0.8]);
const binLabels = tensor([1, 0, 1]);
binaryCrossEntropyLoss(probs, binLabels);

Choosing a Loss
- Regression → mseLoss (default), maeLoss (robust), huberLoss (balanced)
- Binary classification → binaryCrossEntropyLoss (with sigmoid output) or binaryCrossEntropyWithLogitsLoss (with raw logits; see the sketch after this list)
- Multi-class classification → crossEntropyLoss (with raw logits, no softmax needed)
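For completeness, a sketch of the logits-based binary variant; it assumes binaryCrossEntropyWithLogitsLoss takes (logits, target) the same way binaryCrossEntropyLoss takes (probabilities, target):

import { binaryCrossEntropyWithLogitsLoss } from "deepbox/nn";
import { tensor } from "deepbox/ndarray";

// Raw scores straight from the model (no sigmoid applied)
const logits = tensor([2.2, -2.2, 1.4]);
const binLabels = tensor([1, 0, 1]);

// Sigmoid + BCE fused in one numerically stable step (signature assumed)
binaryCrossEntropyWithLogitsLoss(logits, binLabels);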