I. Introduction (Short and Practical)
“The real value of transfer learning kicks in when you’ve got a solid base model and just the right amount of data to steer it where you want. That’s exactly where ResNet-50 still shines.”
If you’re reading this, I’ll assume you already know your way around PyTorch and pretrained models. This isn’t a beginner walkthrough or a theory-heavy article.
What I’m sharing here is what’s worked for me when fine-tuning ResNet-50 for real-world classification problems — from niche medical datasets to highly imbalanced product images.
This guide is practical, code-first, and skips the fluff. I’ll walk you through exactly how I approach fine-tuning ResNet-50, the decisions I make along the way, and the kind of results I’ve seen — with all the implementation details that actually matter.
So if you’re looking to plug ResNet-50 into your workflow, train it fast, and get strong performance on a domain-specific task, you’re in the right place.
II. When You Should Fine-Tune ResNet-50 (and When You Shouldn’t)
Let’s be honest — ResNet-50 isn’t always the best choice anymore. I’ve seen cases where switching to EfficientNet-B0 gave me a better speed-accuracy tradeoff.
And for anything that needs understanding spatial relationships deeply (like satellite imagery or fashion datasets), ViT or Swin Transformer usually edges out the classics.
But here’s the thing: ResNet-50 is still a workhorse — especially when:
- You’ve got limited data, but it’s quite different from ImageNet (think grayscale CT scans or industrial inspection images).
- You need something that trains fast, runs on moderate hardware, and has tons of community support.
- You want to prototype quickly before committing to a heavier architecture.
Now, about how much of it to fine-tune:
In some of my projects, I started with linear probing — just training a fresh classifier on top of the frozen backbone. It’s quick, and sometimes good enough.
But when the domain shift is serious (for example, ImageNet to histopathology slides), I’ve found linear probing barely moves the needle. That’s when I go all in and fine-tune the deeper layers.
A simple rule I follow:
If the pretrained model struggles to even “guess” your classes right when frozen — it’s time to wake up more layers.
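One way to make that rule concrete is a quick linear probe: run the frozen backbone once, cache its features, and fit a single linear layer on top. The sketch below is only an illustration. The function names, epoch count, and learning rate are placeholder choices, and `backbone` is assumed to be any module that maps a batch of images to feature vectors (for ResNet-50, that could be the model with `fc` swapped for `nn.Identity()`).

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(backbone, loader, device):
    """Run the frozen backbone once and cache (features, labels) on the CPU."""
    backbone.eval().to(device)
    feats, labels = [], []
    for images, targets in loader:
        feats.append(backbone(images.to(device)).flatten(1).cpu())
        labels.append(targets)
    return torch.cat(feats), torch.cat(labels)


def linear_probe_accuracy(train_feats, train_labels, val_feats, val_labels,
                          num_classes, epochs=50, lr=1e-2):
    """Fit one linear layer on cached features and return validation accuracy."""
    probe = nn.Linear(train_feats.shape[1], num_classes)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(probe(train_feats), train_labels)  # full-batch step
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        preds = probe(val_feats).argmax(dim=1)
    return (preds == val_labels).float().mean().item()
```

If the probe's validation accuracy sits barely above chance, that is a reasonable signal that the frozen features don't transfer and more layers need to be unfrozen.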
You’ll see later how I do progressive unfreezing and which layers I usually tweak first.
Coming up next: the environment setup, exact libraries I use, and how I structure the dataset for training. No ambiguity — just straight into the code.
III. Setup: Tools, Libraries, Environment
Let’s get this out of the way — because broken environments waste hours.
Personally, I stick with PyTorch ≥ 2.1 and torchvision ≥ 0.16. That combo plays nicely together, supports newer transforms, and doesn’t break transforms.v2 if you ever want to explore those. CUDA 11.8 has been stable in my setups, but I’ve also tested this on CUDA 12.1 with no issues so far.
Here’s the deal: if your GPU has low VRAM (<= 8GB), you’ll want to keep batch sizes small and possibly enable mixed precision from the start — I’ll show you how to do that later.
I mostly rely on:
- torch
- torchvision
- tqdm (for live progress updates; trust me, you want this)
- scikit-learn (for evaluation)
- albumentations (when I need more control over augmentations)
- matplotlib (just for sanity-check visualizations)
If I’m using timm, it’s usually for model comparison or for grabbing more exotic pretrained weights. It’s not needed for basic ResNet-50 fine-tuning via torchvision.
requirements.txt (if you want to just get going):
torch>=2.1
torchvision>=0.16
scikit-learn
tqdm
albumentations
matplotlib
You can install it like this:
pip install -r requirements.txt
One tip from my experience: always verify that torch.cuda.is_available() returns True before training. I’ve accidentally run full training loops on CPU because of missing drivers or a bad conda environment.
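A small pre-flight check along these lines can be dropped at the top of a training script. It is a sketch; the helper name `describe_environment` and the choice of fields are assumptions for illustration.

```python
import torch


def describe_environment():
    """Collect the facts worth checking before a long training run."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch_version": torch.__version__,
        "cuda_available": cuda_ok,
        # torch.version.cuda is None on CPU-only builds
        "cuda_build": torch.version.cuda,
        "device": torch.device("cuda" if cuda_ok else "cpu"),
        "gpu_name": torch.cuda.get_device_name(0) if cuda_ok else None,
    }


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```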
Dataset Format I Work With
Most of my projects use either:
- An ImageFolder-style directory (super easy and fast with torchvision)
- A custom PyTorch Dataset when the data is coming from CSVs, parquet, or somewhere else
For this guide, I’ll use the ImageFolder format. It’s clean and works well enough for most fine-tuning workflows.
Directory structure:
data/
  train/
    class1/
    class2/
  val/
    class1/
    class2/
You can use torchvision.datasets.ImageFolder to load it in one line.
IV. Preparing the Dataset
This is one part I don’t take lightly — your augmentation strategy can make or break the model, especially when the dataset is small or noisy.
For example, when I was working with low-resolution industrial images, aggressive color jitter actually hurt the model — it was already struggling with subtle patterns, and more randomness just made it worse.
But when I was fine-tuning ResNet on aerial imagery, stronger augmentations helped a lot.
So: tailor your transforms to your data.
Training Transforms
Here’s what I typically start with:
from torchvision import transforms

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
And for validation:
transform_val = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
You might be wondering: why use the same mean and std as ImageNet?
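Because ResNet-50 was pretrained with these stats. If your data is drastically different (e.g., grayscale medical scans), I suggest computing your dataset’s own per-channel mean/std over a representative sample of the training set and normalizing with those instead. Keep in mind that the pretrained first conv layer still expects 3 channels, so grayscale images need to be replicated to 3 channels (for example with transforms.Grayscale(num_output_channels=3)).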
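For computing those statistics, a sketch like the following works on any loader that yields (images, labels) batches. It assumes the images have gone through ToTensor() but not Normalize(); the function name `compute_mean_std` is a placeholder.

```python
import torch


@torch.no_grad()
def compute_mean_std(loader):
    """Per-channel mean and std over every pixel the loader yields.

    Expects batches of (images, labels) with images shaped (N, C, H, W) and
    scaled to [0, 1], i.e. after ToTensor() but before Normalize().
    """
    pixel_count = 0
    channel_sum = None
    channel_sq_sum = None
    for images, _ in loader:
        images = images.double()  # accumulate in float64 to limit rounding error
        n, c, h, w = images.shape
        if channel_sum is None:
            channel_sum = torch.zeros(c, dtype=torch.double)
            channel_sq_sum = torch.zeros(c, dtype=torch.double)
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
        pixel_count += n * h * w
    mean = channel_sum / pixel_count
    std = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0).sqrt()
    return mean.float(), std.float()
```

Run it over a loader built with deterministic transforms (resize plus ToTensor only), so random augmentation doesn't leak into the statistics.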
Handling Class Imbalance
Here’s a trick I’ve used multiple times — WeightedRandomSampler. When your classes are imbalanced and accuracy looks fine but your model just ignores the minority classes, this helps balance the batches.
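import torch
from torch.utils.data import WeightedRandomSampler

# Assumes train_dataset is an ImageFolder, which exposes integer labels via .targets
targets = torch.tensor(train_dataset.targets)
class_counts = torch.bincount(targets)  # e.g. tensor([5000, 500, 50])
weights = 1.0 / class_counts.float()    # rarer class -> larger weight
sample_weights = weights[targets]       # one weight per sample, without loading any images
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)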
Alternatively, I’ve used focal loss when the class imbalance is extreme and WeightedRandomSampler wasn’t enough. I’ll cover that when we set up the training loop.
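The sampler only takes effect once it is handed to the DataLoader. Here is a sketch of that wiring; the helper name `make_balanced_loader` is an assumption, and `targets` is expected to be one integer label per sample (for ImageFolder, `dataset.targets`).

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler


def make_balanced_loader(dataset, targets, batch_size=32, num_workers=0):
    """Build a DataLoader whose batches are roughly class-balanced.

    `targets` is a sequence of integer class labels, one per sample
    (for ImageFolder that is `dataset.targets`).
    """
    targets = torch.as_tensor(targets, dtype=torch.long)
    class_counts = torch.bincount(targets).clamp(min=1)  # avoid division by zero
    sample_weights = (1.0 / class_counts.float())[targets]
    sampler = WeightedRandomSampler(sample_weights,
                                    num_samples=len(sample_weights),
                                    replacement=True)
    # A sampler and shuffle=True are mutually exclusive, so shuffle stays off.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers)
```

Because sampling is done with replacement, minority images are repeated within an epoch, so pairing this with augmentation matters more than usual.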
V. Loading Pretrained ResNet-50 and Modifying the Head
“You don’t need to reinvent the wheel — you just need to realign it.”
That’s pretty much how I think about using pretrained ResNet-50. You already have a backbone trained on millions of images. You just need to teach it how to speak the language of your dataset.
Load the Pretrained Model
Here’s how I start every time:
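import torchvision.models as models
import torch.nn as nn

# pretrained=True is deprecated since torchvision 0.13; the weights enum replaces it.
# IMAGENET1K_V1 matches what pretrained=True used to load.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)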
Now, depending on the domain shift, I either:
- Freeze all layers initially, and just train the new head
- Or, if the dataset is wildly different, I unfreeze deeper layers selectively
When I’m not sure, I go with this hybrid: freeze everything except the last few blocks.
Freeze / Unfreeze Logic
Here’s what I do in most of my workflows:
for param in model.parameters():
    param.requires_grad = False  # freeze everything

# Unfreeze last ResNet block (layer4) for better adaptation
for param in model.layer4.parameters():
    param.requires_grad = True
If you’re working with a small dataset and want to avoid overfitting, freeze more. If you’re working with a larger dataset or your domain is far from ImageNet, unfreeze more blocks, working backwards from layer4 to layer3 and earlier.
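After toggling requires_grad, it helps to confirm how much of the model is actually trainable. A tiny helper for that, shown here on a toy model rather than ResNet-50 (the name `count_parameters` is illustrative):

```python
import torch.nn as nn


def count_parameters(model):
    """Return (trainable, total) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    for p in toy[0].parameters():
        p.requires_grad = False  # freeze the first layer only
    trainable, total = count_parameters(toy)
    print(f"trainable: {trainable} / {total}")
```

If the trainable count is zero, or equal to the total when you meant to freeze most of the backbone, the freeze logic ran in the wrong order.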
Replace the Classifier Head
This is where you adapt it to your task. In my last few projects, I found adding a hidden layer with dropout helps the model generalize better — especially with fewer samples.
num_classes = 5  # change based on your task
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes)
)
You can experiment with BatchNorm1d after the first Linear too. I’ve used that in some cases where the data was noisy or label distributions were skewed.
One mistake I made early on: forgetting to call model.to(device) after modifying the head. The new layers are created on the CPU, so don’t skip that before you switch to training.
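A cheap way to catch both a misplaced device and a wrong head size is to push one dummy batch through the model before training. This is a sketch under the assumption that the model takes 3-channel square images; the helper name `check_output_shape` is hypothetical.

```python
import torch


@torch.no_grad()
def check_output_shape(model, num_classes, device, image_size=224, batch_size=2):
    """Push a dummy batch through the model and verify the logits shape."""
    model.to(device)
    was_training = model.training
    model.eval()  # dummy data must not touch BatchNorm running statistics
    dummy = torch.randn(batch_size, 3, image_size, image_size, device=device)
    logits = model(dummy)
    model.train(was_training)  # restore the caller's mode
    expected = (batch_size, num_classes)
    if tuple(logits.shape) != expected:
        raise ValueError(f"expected logits of shape {expected}, got {tuple(logits.shape)}")
    return logits.shape
```

Calling it as `check_output_shape(model, num_classes, device)` right after replacing `model.fc` fails fast instead of several minutes into the first epoch.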
VI. Training Logic
This is where things either converge fast — or fall apart quietly.
After fine-tuning a dozen variants of ResNet (18, 34, 50, 101) across different domains, here’s how I approach training:
Optimizer — AdamW vs SGD
If I’m training from scratch or doing large-scale fine-tuning, SGD + momentum (0.9) still gives me the most stable performance. But when I’m only tweaking the head or a few layers, AdamW converges faster and is just less of a hassle.
Here’s my go-to:
import torch.optim as optim
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=1e-2)
If you’re unfreezing the whole model, I suggest starting with a lower LR (1e-5 to 3e-5). It’s slower, but safer.
Learning Rate Scheduler
I’ve had solid results with CosineAnnealingLR, especially for longer runs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)  # T_max: number of scheduler.step() calls (epochs here) in one cosine decay
That said, for very short training runs (<=10 epochs), I prefer StepLR — it keeps things simple and predictable.
If I’m doing more aggressive training with large datasets, I sometimes go with OneCycleLR, but for most fine-tuning scenarios, CosineAnnealing just works.
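To see what the cosine schedule actually does to the learning rate, it can be traced with a dummy optimizer before committing to a run. The sketch assumes one scheduler step per epoch and T_max equal to the number of epochs; `cosine_lr_trace` is an illustrative name.

```python
import torch
import torch.nn as nn


def cosine_lr_trace(base_lr=1e-4, epochs=10):
    """Return the learning rate used in each epoch under CosineAnnealingLR."""
    params = [nn.Parameter(torch.zeros(1))]  # dummy parameter, just to build an optimizer
    optimizer = torch.optim.AdamW(params, lr=base_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    lrs = []
    for _ in range(epochs):
        lrs.append(optimizer.param_groups[0]["lr"])  # LR in effect for this epoch
        # ... one full pass over train_loader would go here ...
        optimizer.step()    # optimizer first,
        scheduler.step()    # then the scheduler, once per epoch
    return lrs


if __name__ == "__main__":
    for epoch, lr in enumerate(cosine_lr_trace()):
        print(f"epoch {epoch}: lr = {lr:.2e}")
```

If T_max is smaller than the number of epochs you train for, the rate climbs back up after the trough, which is rarely what you want in a fine-tune.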
Loss Function
You’re probably using CrossEntropyLoss, and that’s fine. It works well 80% of the time. But if you’re dealing with class imbalance or poor recall on minority classes, Focal Loss can be a game-changer.
Here’s a basic focal loss I’ve used in production:
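import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma = gamma  # focusing parameter; larger values down-weight easy examples more
        self.alpha = alpha  # a scalar here, so it only rescales the loss uniformly
        self.reduction = reduction

    def forward(self, inputs, targets):
        # inputs: raw logits (N, C); targets: class indices (N,)
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability the model assigns to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss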
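A quick sanity check on any focal loss implementation: with gamma=0 and alpha=1 the focusing term disappears and it should match plain cross-entropy. The snippet below restates the same formula as a standalone function (so it runs on its own) and checks that property; the example logits are made up.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    """Mean focal loss for logits (N, C) and integer targets (N,)."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)  # probability assigned to the true class
    return (alpha * (1 - pt) ** gamma * ce).mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 3)
    targets = torch.randint(0, 3, (8,))
    # gamma=0, alpha=1 removes the focusing term: plain cross-entropy remains
    assert torch.allclose(focal_loss(logits, targets, gamma=0.0, alpha=1.0),
                          F.cross_entropy(logits, targets))
    # a confident, correct prediction contributes far less than a confident, wrong one
    easy = focal_loss(torch.tensor([[10.0, 0.0, 0.0]]), torch.tensor([0]))
    hard = focal_loss(torch.tensor([[0.0, 10.0, 0.0]]), torch.tensor([0]))
    print(f"easy example: {easy.item():.2e}, hard example: {hard.item():.2e}")
```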
Mixed Precision Training
This might surprise you: enabling AMP (automatic mixed precision) often gives me 1.5–2x speedups — especially when fine-tuning on 16GB GPUs.
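scaler = torch.cuda.amp.GradScaler()  # newer PyTorch releases prefer torch.amp.GradScaler("cuda")
for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass and loss in mixed precision
        outputs = model(images)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients and skips the step on inf/NaN
    scaler.update()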
Gradient Clipping
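When training becomes unstable, especially with larger heads or noisy datasets, I clip gradients between backward() and the optimizer step. With AMP, unscale the gradients first so the threshold applies to their true values:
scaler.unscale_(optimizer)  # only needed when using GradScaler
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)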
I don’t apply it unless I see loss spikes or exploding gradients.
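Putting mixed precision and clipping together, one epoch might look like the sketch below. It is device-agnostic: with a disabled scaler it runs in plain fp32 on CPU. The names `make_scaler` and `train_one_epoch` are illustrative, and the `hasattr` check is there because `torch.amp.GradScaler` only exists in more recent PyTorch releases.

```python
import torch


def make_scaler(enabled):
    """Create a GradScaler, preferring the newer torch.amp API when present."""
    if hasattr(torch.amp, "GradScaler"):
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)


def train_one_epoch(model, loader, optimizer, criterion, device, scaler,
                    max_grad_norm=None):
    """One training epoch with optional mixed precision and gradient clipping."""
    model.train()
    use_amp = scaler.is_enabled()
    running_loss, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(images), labels)
        scaler.scale(loss).backward()
        if max_grad_norm is not None:
            scaler.unscale_(optimizer)  # clip true gradients, not scaled ones
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * images.size(0)
        seen += images.size(0)
    return running_loss / seen  # average loss per sample
```

Create the scaler once, outside the epoch loop (for example `scaler = make_scaler(enabled=device.type == "cuda")`), so its loss scale carries over between epochs.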
Logging
You can go full-on WandB or TensorBoard, but sometimes I keep it minimal:
print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Acc = {val_acc:.2f}%")
Still, if you’re running long experiments or testing multiple variants, WandB is worth it. I’ve used it to track model configs, visualize confusion matrices, and even compare runs automatically.
VII. Fine-Tuning Strategy: Layer-Wise Unfreezing
“Don’t rip the whole engine out when all you need is a better exhaust.”
That’s the philosophy I follow when it comes to fine-tuning pretrained models like ResNet-50. If you’re dealing with a small but clean dataset, full fine-tuning from the start is often overkill — and it can backfire by leading to overfitting or destroying useful pretrained features.
Why Progressive Unfreezing?
In my own experiments (especially with medical imaging and satellite datasets), progressively unfreezing layers gave consistently better results than blindly training the full model from the get-go. It allows the model to adjust slowly to your domain, like easing into cold water instead of diving headfirst.
The usual routine I follow:
- Train only the new classifier head for a few epochs
- Unfreeze layer4
- Optionally unfreeze layer3, depending on how well the model adapts
This gives you a balance between preserving general features and learning domain-specific patterns.
Implementation: Unfreeze Layer by Layer
Here’s a quick pattern I use to selectively unfreeze:
# Freeze everything
for param in model.parameters():
    param.requires_grad = False

# Unfreeze classifier and last ResNet block
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:
        param.requires_grad = True
If I see that performance is plateauing after training with layer4 unfrozen, I’ll go a step deeper:
for name, param in model.named_parameters():
    if "layer3" in name or "layer4" in name or "fc" in name:
        param.requires_grad = True
Here’s a tip: every time you unfreeze new layers, lower the learning rate—usually by a factor of 10—to prevent catastrophic forgetting.
Also, don’t forget to reinitialize your optimizer after changing the layers being trained. That part’s easy to miss.
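The two reminders above (lower the learning rate, rebuild the optimizer) can be folded into a pair of helpers. This is a sketch, not a fixed recipe: the helper names and default rates are assumptions, and giving the backbone its own, smaller learning rate via parameter groups is one possible way to apply the "factor of 10" idea.

```python
import torch


def set_trainable(model, prefixes):
    """Freeze everything except parameters whose name starts with a given prefix."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(tuple(prefixes))


def build_optimizer(model, head_prefix="fc", head_lr=1e-4, backbone_lr=1e-5,
                    weight_decay=1e-2):
    """Fresh AdamW over the currently trainable parameters, with a lower
    learning rate for backbone blocks than for the classifier head."""
    head, backbone = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            (head if name.startswith(head_prefix) else backbone).append(param)
    groups = []
    if head:
        groups.append({"params": head, "lr": head_lr})
    if backbone:
        groups.append({"params": backbone, "lr": backbone_lr})
    return torch.optim.AdamW(groups, weight_decay=weight_decay)
```

A staged run would then call `set_trainable(model, ("fc",))`, train a few epochs, call `set_trainable(model, ("layer4", "fc"))`, and build a new optimizer with `build_optimizer(model, head_lr=1e-5, backbone_lr=1e-6)` before continuing.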
VIII. Evaluation
“If training is the performance, evaluation is the honest review.”
Once the model trains, I don’t just look at validation loss and accuracy. Those are superficial. In my experience, deep evaluation often tells a story that overall metrics completely miss—especially with imbalanced or multi-class datasets.
Per-Class Accuracy with classification_report
One of the first things I do is drop this after validation:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred, target_names=class_names))
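The `y_true` and `y_pred` arrays have to come from somewhere. Here is a sketch of a validation pass that collects them, along with the softmax probabilities that per-class AUC needs. It assumes a loader yielding (images, labels); `collect_predictions` is an illustrative name.

```python
import torch


@torch.no_grad()
def collect_predictions(model, loader, device):
    """Return (y_true, y_pred, y_probs) as NumPy arrays for a labeled loader."""
    model.eval()
    y_true, y_probs = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_probs.append(torch.softmax(logits, dim=1).cpu())
        y_true.append(labels)
    y_true = torch.cat(y_true).numpy()
    y_probs = torch.cat(y_probs).numpy()  # shape (num_samples, num_classes)
    y_pred = y_probs.argmax(axis=1)
    return y_true, y_pred, y_probs
```

With an ImageFolder dataset, `class_names` can be taken from `val_dataset.classes`, which is already in index order.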
This has saved me so many times—I’ve had models with 90% overall accuracy that completely ignored minority classes. The per-class breakdown reveals these gaps fast.
📉 Confusion Matrix (Visualized)
I usually plot the confusion matrix right after the first evaluation cycle (seaborn isn’t in the requirements.txt above, so install it separately):
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
This might surprise you: I’ve caught mislabeled data just by staring at this heatmap. If one class is being predicted incorrectly 90% of the time, I double-check the labels first.
AUC-ROC Per Class (for Imbalanced Datasets)
When you’re working with imbalanced datasets (retail product detection, fraud classification, etc.), accuracy is almost useless. That’s when I switch to AUC-ROC per class:
from sklearn.metrics import roc_auc_score
import numpy as np
# Assuming y_true and y_probs are numpy arrays
for i in range(num_classes):
    auc = roc_auc_score((y_true == i).astype(int), y_probs[:, i])
    print(f"AUC for class {class_names[i]}: {auc:.4f}")
In some of my fine-tunes, especially for binary classification, AUC was the only metric that exposed subtle performance issues.
🧪 Overfitting: More Than Just Val Loss Going Up
You might be wondering: “Isn’t rising validation loss enough to catch overfitting?”
Not quite. In my runs, I’ve seen models where val loss creeps up slightly, but precision for key classes nosedives. That’s usually a sign the model is learning spurious correlations or shortcuts.
I also monitor:
- A sharp drop in F1 for minority classes
- High variance in predictions (measured across multiple validation sets)
- Unstable gradients or sudden spikes in validation loss
Rule of thumb I follow: if val_acc stays stable but val_f1 drops sharply, you’re overfitting to dominant classes.
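A toy example makes the gap between the two metrics visible. The labels below are invented for illustration (90 majority samples, 10 minority samples), and macro-averaged F1 is assumed as the val_f1 metric.

```python
from sklearn.metrics import accuracy_score, f1_score

# Toy labels: 90 majority-class samples, 10 minority-class samples
y_true = [0] * 90 + [1] * 10

# Checkpoint A finds half of the minority class; checkpoint B predicts only the majority
y_pred_a = [0] * 90 + [1] * 5 + [0] * 5
y_pred_b = [0] * 100

for name, y_pred in [("A", y_pred_a), ("B", y_pred_b)]:
    acc = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    print(f"checkpoint {name}: val_acc={acc:.2f}  val_f1(macro)={macro_f1:.2f}")
```

Between the two checkpoints accuracy only moves from 0.95 to 0.90, while macro F1 falls from about 0.82 to about 0.47, because the minority class is no longer predicted at all.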
IX. Inference Pipeline
“Training teaches, but inference delivers.”
Once the training’s done, the real test is how smoothly your model performs in the wild. Inference isn’t just a model(input) call. When you’re running batches through GPU (or sometimes stuck on CPU in production), a few smart tweaks can make a big difference.
Efficient Batch Inference
I usually wrap inference in a DataLoader — not just for convenience, but to keep memory usage in check when working with large image sets. Here’s a skeleton I’ve used:
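def run_inference(model, dataloader, device):
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in dataloader:
            # ImageFolder-style loaders yield (images, labels); unlabeled loaders yield images only
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            outputs = model(images)
            batch_preds = torch.argmax(outputs, dim=1)
            preds.extend(batch_preds.cpu().numpy())
    return preds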
I’ve made the mistake before of running inference one image at a time. It’s slow, leaves the GPU mostly idle, and doesn’t scale. Don’t do it.
GPU or CPU — Keep It Flexible
I always make sure the inference logic respects the environment. Your model shouldn’t choke just because CUDA isn’t available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
This lets your code run seamlessly on edge devices or inside CI pipelines, where GPU isn’t always guaranteed.
TTA (Test-Time Augmentation): Optional, But Sometimes Worth It
There are times — especially when working on noisy datasets like aerial imagery or e-commerce photos — where Test-Time Augmentation (TTA) helped squeeze out that last 1–2% accuracy.
def tta_inference(model, image, transforms_list):
    # image: a single PIL image (or tensor); each transform must return a (C, H, W) tensor
    model.eval()
    preds = []
    with torch.no_grad():
        for transform in transforms_list:
            augmented = transform(image).unsqueeze(0).to(device)  # uses the global device
            output = model(augmented)
            preds.append(output)
    final_pred = torch.stack(preds).mean(dim=0)  # average logits across augmented views
    return torch.argmax(final_pred)
Personally, I only apply TTA when I really care about squeezing out every ounce of precision — like model competitions or medical projects.
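When the inputs are already normalized tensors in a batch, a lighter variant is to average predictions over the original and a horizontally flipped copy. The sketch below does that. It is a different, simpler setup than the per-image transform list above, and it averages softmax probabilities rather than logits; `tta_predict_batch` is an illustrative name.

```python
import torch


@torch.no_grad()
def tta_predict_batch(model, images, device):
    """Average softmax probabilities over the original and flipped views.

    images: a normalized batch of shape (N, C, H, W).
    Returns (predicted class indices, averaged probabilities), both on the CPU.
    """
    model.eval()
    images = images.to(device)
    views = [
        images,
        torch.flip(images, dims=[3]),  # horizontal flip (width axis)
    ]
    probs = torch.stack([torch.softmax(model(v), dim=1) for v in views]).mean(dim=0)
    return probs.argmax(dim=1).cpu(), probs.cpu()
```

Only use flips that make sense for the domain: a horizontal flip is harmless for most product photos but can change the meaning of images where orientation carries information.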
X. Exporting the Model
“If it’s not portable, it’s not production-ready.”
You might have the best model on your machine, but unless it’s easy to export, version, and serve, it’s not going anywhere. Here’s how I handle this part cleanly.
Saving & Loading Checkpoints (With Metadata)
I like to save not just model weights, but training context as well — things like epoch number, optimizer state, etc.
# Saving
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'model_checkpoint.pth')
# Loading
checkpoint = torch.load('model_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
One thing I learned the hard way: if you use a scheduler, save its state_dict in the checkpoint too and restore it when resuming. Otherwise, the scheduler starts over from its first epoch, the learning rate jumps back up, and the run can spiral.
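One way to keep the learning rate consistent across a resume is to store the scheduler's state_dict next to the optimizer's. The helpers below sketch that; the function names and the `scheduler_state_dict` key are assumptions, not part of the checkpoint format shown above.

```python
import torch


def save_checkpoint(path, model, optimizer, scheduler, epoch, loss):
    """Save everything needed to resume, including the scheduler position."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": float(loss),  # store a plain number, not a tensor tied to the graph
    }, path)


def load_checkpoint(path, model, optimizer, scheduler, device):
    """Restore model, optimizer and scheduler in place; return the next epoch."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    return checkpoint["epoch"] + 1
```

Build the model, optimizer, and scheduler exactly as in the original run before calling `load_checkpoint`, then start the epoch loop from the returned value.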
Exporting to TorchScript or ONNX
When I’ve needed to integrate models into a C++ pipeline or deploy to edge, I went with TorchScript or ONNX.
TorchScript:
scripted_model = torch.jit.script(model)
scripted_model.save("resnet50_scripted.pt")
ONNX:
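model.eval()  # export in eval mode so BatchNorm and Dropout behave as at inference
dummy_input = torch.randn(1, 3, 224, 224).to(device)
torch.onnx.export(model, dummy_input, "resnet50.onnx",
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})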
Watch out: ONNX often breaks with custom layers or non-standard ops. I’ve had issues with F.interpolate, certain timm backbones, and funky pooling layers. When it works, it’s great. When it doesn’t, expect to wrestle with version mismatches.
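Whichever format you pick, it is worth reloading the exported file and comparing its output against the eager model on the same input. The sketch below does this for TorchScript only. An ONNX check would follow the same pattern but needs an ONNX runtime, which is not shown here. The helper name and the tolerance are assumptions.

```python
import torch


def export_and_verify_torchscript(model, example_input, path, atol=1e-5):
    """Script the model, save it, reload it, and compare against eager outputs."""
    model.eval()  # export in inference mode so Dropout/BatchNorm are deterministic
    with torch.no_grad():
        expected = model(example_input)
    scripted = torch.jit.script(model)
    scripted.save(path)
    reloaded = torch.jit.load(path)
    with torch.no_grad():
        actual = reloaded(example_input)
    max_diff = (expected - actual).abs().max().item()
    if max_diff > atol:
        raise RuntimeError(f"exported model deviates from eager model by {max_diff:.3e}")
    return max_diff
```

Running this in the same environment that will serve the model (same PyTorch version, same device type) is what actually catches the unsupported-op surprises.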
Versioning for Production
Personally, I use simple semantic versioning (v1.0.0) and track model hashes along with the data split used. Whether you’re using MLflow, DVC, or just structured folders, the goal is to never wonder, “Which model did I deploy last month?”
At minimum, my saved models folder looks like this:
models/
├── resnet50_v1.2.0/
│ ├── model.pt
│ ├── config.json
│ └── class_index_map.json
Here’s the deal: nothing hurts more than debugging a production issue and realizing you can’t reproduce the exact model version. Lock it down from day one.
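A folder like that can be filled in by a few lines of standard-library Python. The sketch below hashes model.pt and writes the two JSON files next to it. The field names inside config.json (`version`, `model_sha256`, and anything passed via `extra`) are hypothetical; use whatever your deployment tooling expects.

```python
import hashlib
import json
from pathlib import Path


def file_sha256(path, chunk_size=1 << 20):
    """Hash a file in chunks so large checkpoints never load fully into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_model_card(model_dir, version, class_to_idx, extra=None):
    """Write config.json and class_index_map.json next to model.pt."""
    model_dir = Path(model_dir)
    config = {
        "version": version,
        "model_sha256": file_sha256(model_dir / "model.pt"),
        **(extra or {}),  # e.g. an identifier for the data split used
    }
    (model_dir / "config.json").write_text(json.dumps(config, indent=2))
    (model_dir / "class_index_map.json").write_text(json.dumps(class_to_idx, indent=2))
    return config
```

With an ImageFolder dataset, `train_dataset.class_to_idx` can be passed straight in as `class_to_idx`, which keeps the label mapping tied to the exact weights it was trained with.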
Final Thoughts
When ResNet-50 Still Holds Its Ground in 2025
You might be thinking: “Why even bother with ResNet-50 in 2025?”
Well, here’s the truth — I’ve tried newer backbones like ConvNeXt, Swin, and even ViT-based hybrids. They’re great, no doubt. But ResNet-50? It still pulls its weight, especially when:
- You don’t have access to huge compute.
- You’re deploying to environments where smaller models = lower latency.
- You need fast iterations with transfer learning, and your dataset isn’t massive.
In many real-world cases I’ve worked on — from industrial defect detection to quick-turnaround PoCs — ResNet-50 gave me just enough performance with way less pain.
Mistakes I’ve Personally Learned From (The Hard Way)
Let me save you some trouble. Here are things I’ve messed up or seen others trip over more than once:
- Unfreezing too early: If you unfreeze everything right after loading the pretrained model, especially on a small dataset, you’ll overfit fast and hard.
- Blindly trusting val_loss: I’ve had runs where the val loss looked great, but the confusion matrix was screaming “class imbalance.” Always look deeper.
- Forgetting to normalize during inference: This one’s easy to miss. Your model expects the same normalization used during training — skip it, and predictions turn to mush.
- Using wrong image size with pretrained backbones: ResNet was trained on 224×224. You can go bigger, sure, but resizing down to 128? That cost me days of debugging poor performance.
- Exporting without testing: I’ve exported to TorchScript/ONNX, only to find out some ops don’t work in the target environment. Always run a test inference after export.
