Implementing Focal Loss in PyTorch for Class Imbalance

I. Introduction

“Not all data is created equal. And in machine learning, this imbalance can cost you—big time.”

If you’ve worked with real-world datasets, you already know the struggle. Most classification problems aren’t neatly balanced datasets where each class has an equal number of samples.

In reality, some categories are severely underrepresented. Think about fraud detection—99.9% of transactions are legit, while fraudulent ones are rare.

Standard loss functions like CrossEntropyLoss treat every class equally, which often leads to a frustrating issue: your model becomes biased toward the majority class, completely ignoring the minority class.

I’ve seen this happen myself. Early in my work with imbalanced datasets, I noticed that models would achieve a seemingly high accuracy, yet fail miserably on the underrepresented class.

The usual tricks—class weighting, oversampling—helped, but they weren’t always enough. That’s when I came across Focal Loss, and let me tell you—it changed the game.

Unlike CrossEntropyLoss, which assigns equal weight to all samples, Focal Loss dynamically adjusts the loss for each instance based on how confident the model is.

This means that easier examples contribute less to the gradient updates, while harder ones get more attention.

The result? A model that actually learns to recognize minority classes instead of brushing them off.

In this guide, I’ll walk you through a practical, no-fluff implementation of Focal Loss in PyTorch.

We’ll focus on writing clean, effective code—because at the end of the day, that’s what actually matters.


II. Understanding Focal Loss (Minimal Theory, Maximum Clarity)

Before we jump into implementation, let’s quickly break down the core idea. I promise—no unnecessary theory, just what you need to make it work.

Here’s the Focal Loss formula:

FL(pₜ) = -α · (1 - pₜ)^γ · log(pₜ)
Let’s unpack this.

  • α (alpha) – This controls the weight assigned to each class. If your dataset is heavily imbalanced, setting a higher α for the minority class ensures that the model doesn’t ignore it.
  • γ (gamma) – This is where the magic happens. It reduces the loss for well-classified examples and amplifies it for hard ones. Think of it as a way to force your model to focus on difficult cases instead of getting lazy with easy ones.
  • pₜ – The model’s predicted probability for the correct class. If the model is very confident (high pₜ), the loss contribution shrinks. If it’s unsure (low pₜ), the loss is boosted.
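
To make the γ term concrete, here’s a quick numeric check (illustrative values only) of how the modulating factor (1 - pₜ)^γ scales the loss at γ = 2.0:

gamma = 2.0
for p_t in [0.95, 0.60, 0.20]:
    factor = (1 - p_t) ** gamma
    print(f"p_t = {p_t:.2f} -> loss scaled by {factor:.4f}")

# p_t = 0.95 -> 0.0025  (easy, confidently correct: almost silenced)
# p_t = 0.60 -> 0.1600  (borderline: keeps a moderate share)
# p_t = 0.20 -> 0.6400  (hard, likely misclassified: keeps most of its loss)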

Why Does This Matter?

Let’s say you’re training a model on a dataset where 95% of the samples belong to one class. A typical classifier will quickly learn to predict the majority class most of the time. If it gets 95% accuracy by always guessing the dominant class, that’s misleading, right?

With Focal Loss, we flip the script. Instead of treating every mistake the same, we tell the model:

“Hey, you’re already confident about these easy cases—let’s focus on the ones you keep getting wrong.”

Real-World Example: Where Focal Loss Wins

I first used Focal Loss when working on an anomaly detection project. Standard loss functions weighted every sample equally, so the abundant normal class dominated the total loss and the rare events barely registered—leading to a model that simply ignored anomalies.

Once I switched to Focal Loss with a higher gamma (γ = 2.0) and class-weighted alpha, my model finally started detecting rare cases without sacrificing overall performance.

So now that you know why Focal Loss works, let’s move on to how to implement it in PyTorch—step by step.


III. Installing Dependencies and Environment Setup

Before we jump into implementation, let’s set up the environment properly. I’ve learned the hard way that dependency mismatches can cause unexpected errors, especially when working with PyTorch versions that change behavior slightly over time.

If you haven’t already installed PyTorch, you can do it using the official command from their website:

pip install torch torchvision

However, if you’re setting up a project from scratch, I highly recommend using a requirements.txt file to maintain consistency across environments. Here’s a minimal setup:

torch==2.0.1
torchvision==0.15.2
numpy==1.24.3

Save this in a file (requirements.txt) and install everything in one go:

pip install -r requirements.txt

If you’re working in a Jupyter Notebook, run this inside a cell:

!pip install torch torchvision numpy
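
Once the install finishes, a quick (optional) sanity check confirms the versions and whether a GPU is visible:

import torch
import torchvision

print(torch.__version__)          # 2.0.1 if you pinned the versions above
print(torchvision.__version__)    # 0.15.2 with the pinned requirements
print(torch.cuda.is_available())  # True if a CUDA-capable GPU is detected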

That’s it—quick, clean, and future-proof. Now, let’s move to the fun part: writing our custom Focal Loss in PyTorch.


IV. Implementing Focal Loss in PyTorch

Now, let’s get our hands dirty.

A. Writing a Custom Focal Loss Class

The first time I implemented Focal Loss, I made the mistake of not subclassing nn.Module.

Instead, I wrote it as a standalone function, which made it less flexible when integrating it into complex models. Trust me—using PyTorch’s nn.Module is the way to go.

Here’s how you can implement it properly:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Custom implementation of Focal Loss in PyTorch.
        
        Parameters:
        alpha (float): Weighting factor for the rare class (default 0.25).
        gamma (float): Modulating factor to down-weight easy examples (default 2.0).
        reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the Focal Loss between logits and targets.

        Parameters:
        inputs (Tensor): Model predictions (logits) of shape (batch_size, num_classes).
        targets (Tensor): Ground truth labels of shape (batch_size,).

        Returns:
        Tensor: Computed Focal Loss.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # Probability of the correct class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

Breaking It Down

This implementation follows best practices and plays well with PyTorch’s autograd system. Here’s what’s happening:

  1. We inherit from nn.Module, making this loss function behave like any built-in PyTorch loss.
  2. We calculate standard CrossEntropyLoss but keep the individual loss values (reduction='none') so we can modify them.
  3. We apply the Focal Loss formula by weighting the loss for hard examples and down-weighting the easy ones.
  4. We provide different reduction methods (mean, sum, or none) to give you flexibility depending on your use case.
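
As a quick illustration of that last point, reduction='none' lets you inspect which samples dominate the loss (the logits and labels below are made up):

# Per-sample focal losses with reduction='none' (illustrative logits and labels)
loss_fn = FocalLoss(alpha=0.25, gamma=2.0, reduction='none')
logits = torch.tensor([[3.0, -1.0], [0.1, 0.2], [-2.0, 2.5]])
labels = torch.tensor([0, 1, 0])
print(loss_fn(logits, labels))  # the confidently correct first sample contributes far less than the misclassified third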

How to Use It in Your Training Loop

You might be wondering: How do I actually use this in my training pipeline?

Just like you would with any other loss function:

# Initialize the loss function
criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')

# Sample input (logits) and target (class labels)
inputs = torch.tensor([[2.0, 0.5], [0.3, 1.5]], requires_grad=True)  # Example logits
targets = torch.tensor([0, 1])  # Example labels

# Compute the loss
loss = criterion(inputs, targets)

print(loss.item())  # Print the loss value

I’ve personally used this exact implementation in imbalanced classification problems like medical image analysis and fraud detection, and it consistently outperforms standard loss functions.


B. Integrating Focal Loss into a PyTorch Training Pipeline

At this point, we’ve covered what Focal Loss is and how to implement it as a custom PyTorch loss function.

But theory means nothing without practical application. So, let’s go step by step through integrating Focal Loss into a full PyTorch training pipeline.

“A model is only as good as its training pipeline.” – Something I learned the hard way.

I’ve trained multiple deep learning models on highly imbalanced datasets—think fraud detection, medical imaging, and rare object classification—and trust me, using Focal Loss without a well-structured pipeline is like putting race car tires on a broken-down car. It won’t fix the root problem.

Let’s build a fully functional training loop that properly leverages Focal Loss.

1. Model Creation

For demonstration, I’ll use ResNet-18, but you can replace it with any architecture of your choice. Since we’re focusing on a binary classification problem, we’ll modify the final layer.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Load a ResNet18 pre-trained on ImageNet (recent torchvision versions use the `weights` argument)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the final fully connected layer for binary classification
model.fc = nn.Linear(model.fc.in_features, 2)  # 2 classes (adjust if needed)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Why ResNet?
I often use ResNet-18 when testing pipelines because it’s lightweight, powerful, and pretrained on ImageNet, which gives it a good starting point.

2. Dataset Loading and Data Augmentation

Data augmentation can help mitigate class imbalance to some extent. If your minority class has fewer images, you can use transforms like random cropping, flipping, and color jittering to artificially increase its diversity.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformations (resize, augment, normalize)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
train_dataset = datasets.ImageFolder(root='path_to_data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

# Verify class distribution
class_counts = [len([label for _, label in train_dataset.samples if label == i]) for i in range(2)]
print(f"Class Distribution: {class_counts}")

Key Takeaways from My Experience:

  • Normalize your dataset to match the ImageNet stats if you’re using a pre-trained model.
  • Data augmentation helps but isn’t a magic fix—for extreme imbalance, consider oversampling the minority class (a quick sampler sketch follows this list).
  • Checking class distribution is critical—I’ve seen cases where people assumed their dataset was balanced but later found out it was heavily skewed.
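
On that oversampling point, here’s a minimal sketch (not part of the pipeline above) using PyTorch’s WeightedRandomSampler so minority-class images get drawn more often; it reuses the class_counts computed earlier:

from torch.utils.data import WeightedRandomSampler

# Weight each sample inversely to its class frequency, then sample with replacement
sample_weights = [1.0 / class_counts[label] for _, label in train_dataset.samples]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Pass the sampler instead of shuffle=True (the two are mutually exclusive)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)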

3. Defining Focal Loss and Optimizer

We already implemented Focal Loss, so let’s initialize it alongside our optimizer.

# Initialize Focal Loss with tuned alpha and gamma
criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')

# Choose Adam optimizer with a small learning rate
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Why Adam?
For imbalanced datasets, I’ve found that Adam with a lower learning rate (1e-4 instead of the usual 1e-3) works best. It prevents the model from heavily favoring the dominant class too early.

4. Training Loop with Proper Logging

Here’s the full training loop with Focal Loss. This structure ensures proper loss logging, GPU acceleration, and gradient optimization.

num_epochs = 10  # You can increase this based on dataset size

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    total_loss = 0.0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()  # Reset gradients
        outputs = model(images)  # Forward pass
        loss = criterion(outputs, labels)  # Compute Focal Loss
        
        loss.backward()  # Backpropagation
        optimizer.step()  # Update model parameters
        
        total_loss += loss.item()  # Accumulate loss
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

What’s happening here?

  1. Move images & labels to GPU (if available).
  2. Reset gradients (optimizer.zero_grad()) before every forward pass.
  3. Forward pass through the model.
  4. Compute Focal Loss using our custom loss function.
  5. Backpropagate and update weights.
  6. Log loss at every epoch to track training progress.

Final Thoughts

This fully functional pipeline ensures that Focal Loss is properly utilized in training. If you’ve worked with imbalanced datasets before, you’ll know that:

  1. Loss functions alone don’t fix imbalance. Combine Focal Loss with data augmentation, proper sampling, and the right optimizer.
  2. Monitoring training loss is critical. I’ve seen cases where models start off learning both classes, but after a few epochs, they start predicting only the majority class—so keep an eye on your logs (a minimal monitoring sketch follows this list).
  3. Fine-tune alpha and gamma. The values I used (alpha=0.25, gamma=2.0) are common, but depending on your dataset, tweaking them can lead to better performance.
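
Here’s one way to do that monitoring—a minimal sketch (assuming the two-class setup above) that counts how often each class is predicted over one pass of the data, so a collapse onto the majority class shows up immediately:

# Count predictions per class (2 classes assumed); run this periodically during training
model.eval()
pred_counts = torch.zeros(2, dtype=torch.long)

with torch.no_grad():
    for images, labels in train_loader:
        preds = model(images.to(device)).argmax(dim=1)
        pred_counts += torch.bincount(preds.cpu(), minlength=2)

print(f"Predicted class counts: {pred_counts.tolist()}")  # a near-zero minority count is a red flag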

V. Tuning Focal Loss Parameters for Best Performance

Focal Loss isn’t a plug-and-play solution—you have to fine-tune its parameters to get the best results. I’ve personally seen cases where the wrong alpha (α) and gamma (γ) choices actually worsened performance instead of improving it.

If you’ve worked with imbalanced datasets before, you know how tricky it is to balance learning between the dominant and minority classes. This is where α and γ come into play.

Choosing Alpha (α): Balancing the Class Distribution

α is a weighting factor for each class. It helps balance the impact of the majority and minority classes during training.

  • If α is too low, the model still leans toward the majority class.
  • If α is too high, it may overcompensate and misclassify the majority class instead.

From my experience:

  • α = 0.25 works well for moderate imbalance (e.g., 1:10 class ratio).
  • α = 0.5 is better for severely imbalanced cases (e.g., rare disease detection).
  • If the dataset is extremely imbalanced, consider oversampling first before tweaking α.
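
One caveat: the FocalLoss class from Section IV applies a single scalar α to every sample, so on its own it scales the whole loss rather than re-weighting classes. If you want a true per-class α, here’s a hedged variant (the two-element alpha tuple is just an illustrative default):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLossPerClassAlpha(nn.Module):
    """Variant of the earlier FocalLoss with one alpha weight per class (illustrative extension)."""
    def __init__(self, alpha=(0.25, 0.75), gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = torch.tensor(alpha, dtype=torch.float32)  # one weight per class index
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        alpha_t = self.alpha.to(inputs.device)[targets]  # look up each sample's weight from its class label
        focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss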

Choosing Gamma (γ): Controlling Hard Example Emphasis

Gamma (γ) controls how much Focal Loss focuses on hard-to-classify examples. The higher the γ, the more aggressively the loss down-weights easy examples, forcing the model to pay attention to the difficult ones.

  • γ = 1.0 → Mild down-weighting of easy examples (γ = 0 recovers standard CrossEntropyLoss).
  • γ = 2.0 → The sweet spot for most imbalanced datasets.
  • γ = 3.0 → Even stronger focus on rare classes, useful in cases like medical imaging.

If γ is too high, the model may struggle to generalize because it over-prioritizes rare cases at the cost of common ones.
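
A quick way to build intuition for this trade-off is to print the modulating factor for an easy and a hard example across a few γ values (illustrative numbers only):

# Effect of gamma on an easy example (p_t = 0.9) vs. a hard one (p_t = 0.3)
for gamma in [1.0, 2.0, 3.0]:
    easy = (1 - 0.9) ** gamma
    hard = (1 - 0.3) ** gamma
    print(f"gamma = {gamma}: easy scaled by {easy:.4f}, hard scaled by {hard:.4f}")

# gamma = 1.0: easy 0.1000, hard 0.7000
# gamma = 2.0: easy 0.0100, hard 0.4900
# gamma = 3.0: easy 0.0010, hard 0.3430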

Impact of Tuning α and γ (Comparison Table)

α      γ      Result
0.25   2.0    Best balance for moderate imbalance
0.5    3.0    Strong focus on rare classes, useful for extreme imbalance
0.25   1.0    Closest behavior to CrossEntropyLoss (mild impact)
0.1    5.0    Over-focuses on rare class → risk of overfitting

Pro Tip: If your model’s precision on the minority class is too low, try increasing γ. If your recall is bad, tweak α instead.


VI. Evaluating Model Performance

Let’s be honest—accuracy is one of the worst metrics for imbalanced classification. I’ve seen models hit 99% accuracy while completely failing to detect the minority class. That’s why I always rely on Precision, Recall, and F1-Score instead.

1. Precision, Recall, and F1-Score

Instead of just accuracy, we should analyze:

  • Precision: Of all predicted positives, how many were correct? (Important for fraud detection, where false positives are costly.)
  • Recall: Of all actual positives, how many were correctly detected? (Critical for medical diagnosis—missing a disease is worse than a false alarm.)
  • F1-Score: The harmonic mean of precision & recall, best for imbalanced datasets.

Let’s compute these metrics:

from sklearn.metrics import classification_report

# Function to evaluate model performance
def evaluate(model, data_loader):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)  # reuse the device defined during training
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Print Precision, Recall, and F1-score
    print(classification_report(all_labels, all_preds, digits=4))
    return all_labels, all_preds  # return them so they can be reused (e.g., for the confusion matrix below)

2. Confusion Matrix Visualization

A confusion matrix helps visualize how the model is making predictions. Here’s how to generate one using Seaborn and Matplotlib:

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def plot_confusion_matrix(true_labels, pred_labels):
    cm = confusion_matrix(true_labels, pred_labels)

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=["Class 0", "Class 1"], yticklabels=["Class 0", "Class 1"])
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title("Confusion Matrix")
    plt.show()

# Example usage: reuse the predictions returned by evaluate() (ideally on a held-out validation/test loader)
all_labels, all_preds = evaluate(model, train_loader)
plot_confusion_matrix(all_labels, all_preds)

Final Thoughts

I’ve seen many models fail not because of the architecture or dataset—but because they were evaluated with the wrong metrics.

  1. Never trust accuracy alone. Always look at Precision, Recall, and F1-Score.
  2. A confusion matrix helps catch hidden biases. If your model always predicts the majority class, the confusion matrix will expose it.
  3. Focal Loss tuning matters. Small changes in α and γ can dramatically impact how well your model handles imbalance.

Conclusion

By now, you’ve seen how Focal Loss can be a game-changer for imbalanced classification problems. Instead of letting the majority class dominate training, it shifts focus toward hard-to-classify examples, ensuring better learning for minority classes.

Key Takeaways:

Improves Minority Class Detection – Unlike standard CrossEntropyLoss, Focal Loss helps models learn from rare classes without being overwhelmed by majority samples.

Tunable with α and γ – Fine-tuning these parameters lets you control the trade-off between recall and precision, optimizing for your specific dataset.

Works Best with Proper Evaluation – Accuracy is misleading on imbalanced datasets. Always rely on Precision, Recall, F1-Score, and Confusion Matrices to truly understand model performance.

Final Thoughts: Experiment & Optimize

There’s no one-size-fits-all setting for Focal Loss. The best α and γ values depend on your dataset, so don’t be afraid to experiment.

  • Try different γ values and analyze the effect on minority class predictions.
  • Adjust α if your model struggles with class imbalance.
  • Monitor F1-Score instead of accuracy to gauge real improvements.

Machine learning isn’t about blindly applying techniques—it’s about iteration and refinement. With careful tuning and evaluation, Focal Loss can give your model the edge it needs to handle class imbalance effectively.
