Fine-Tuning Gemma for Custom NLP Tasks

1. Why Gemma?

“You don’t always need a 13B model to get 13B results.”

That’s something I’ve learned firsthand after spending weeks fine-tuning various open LLMs for lightweight, on-device use cases.

When I started experimenting with Gemma, I wasn’t chasing hype — I was just tired of hitting memory ceilings with LLaMA2 and constantly fighting with Mistral’s tokenization quirks.

For me, Gemma hit the sweet spot: it’s compact, modular, and fits neatly into a 4-bit LoRA pipeline without the usual headaches.

Plus, being natively aligned with Google’s stack makes it a lot easier to scale later — especially if you’re planning to serve models via Vertex AI or similar.

In short: I picked Gemma because it respects your time and hardware. And after fine-tuning it on multiple real-world tasks — from structured text generation to knowledge-intensive QA — it’s earned its place in my toolbox.


2. Environment Setup

“Before you talk to the model, make sure the model can talk to your GPU.”

Here’s how I personally set up my environment for Gemma fine-tuning. I’m not using Colab, Kaggle, or any sandboxed setup — this is my local development stack, tested on multiple RTX 4090 and A100 machines.

Base Environment (Python 3.10, CUDA 11.8)

I used a clean conda environment to isolate dependencies. You could also use Docker, but I prefer Conda for its speed during iteration.

conda create -n gemma-finetune python=3.10 -y
conda activate gemma-finetune

Core Libraries

These are the libraries I use across almost all my LoRA fine-tuning pipelines:

pip install transformers==4.39.3 peft==0.10.0 trl==0.7.10 bitsandbytes==0.43.1 accelerate==0.28.0

Note: trl is especially useful if you’re using SFTTrainer or planning to integrate RLHF later.

Optional — Check GPU Availability

I always confirm my environment can actually see the GPU before launching any training job. You’d be surprised how often a run silently falls back to CPU on a misconfigured system, especially when device_map="auto" is doing the placement.

import torch
print(torch.cuda.is_available())       # check this first: the next call raises if CUDA isn't visible
print(torch.cuda.get_device_name(0))   # e.g. "NVIDIA GeForce RTX 4090"

Full requirements.txt (if you prefer installing everything in one go)

transformers==4.39.3
peft==0.10.0
trl==0.7.10
bitsandbytes==0.43.1
accelerate==0.28.0
datasets==2.18.0

If you’re using flash-attn or xformers for additional memory savings, I’d suggest installing them from source based on your CUDA version — but more on that when we get into memory optimization.


3. Model Selection: gemma-2b vs gemma-7b

“Don’t ask ‘how big is the model?’ Ask ‘how much of it can your GPU actually handle?’”

I’ve worked with both gemma-2b and gemma-7b under different memory constraints — and the reality is, you won’t always get to pick the bigger one, especially if you’re running on a single GPU setup.

Here’s how I approach this:

Model    | VRAM (4-bit) | VRAM (8-bit) | Full Precision | Context Limit
gemma-2b | ~3.5 GB      | ~6.5 GB      | ~12 GB         | 8192 tokens
gemma-7b | ~7.5–8.5 GB  | ~14 GB       | ~27 GB         | 8192 tokens

If I’m running LoRA fine-tuning on a single 3090 or 4080, gemma-2b is the safer bet. It leaves just enough headroom for a decent batch size, especially when you combine it with gradient accumulation and 4-bit quantization.

Pro tip: gemma-7b is trainable on a single 80GB A100, but you’ll need to reduce your batch size to 1 or 2 depending on sequence length.

Here’s the exact setup I used to load gemma-2b with 4-bit quantization using BitsAndBytes:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=bnb_config,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

Heads up: if you’re trying this with gemma-7b, just swap the model name and be ready to play Tetris with your batch size and gradient accumulation settings.

So which one should you choose? If you’re building for quick prototyping or deploying lightweight models, go with 2B. If you’re doing something heavy like multi-hop reasoning or document-grounded generation, and you’ve got the hardware — 7B pays off.


4. Dataset Preparation

“Garbage in, garbage out — but more like ‘half-cleaned JSON in, hallucinations out.’”

Let me be honest — this part took me the most time when I was setting up Gemma for fine-tuning. Tokenization is usually where things silently break, especially if you’re not respecting padding, truncation, or EOS tokens.

Dataset Format

I’ve used both instruction-following and conversational formats. Here’s an example of what works well with Gemma in instruction tuning scenarios:

{
  "instruction": "Translate this sentence to French.",
  "input": "The weather is nice today.",
  "output": "Il fait beau aujourd'hui."
}

Or you can flatten it like this:

{
  "text": "Instruction: Translate this sentence to French.\nInput: The weather is nice today.\nOutput: Il fait beau aujourd'hui."
}

If your use case is multi-turn chat, you’ll need to handle prompt formatting and EOS tokens differently — but for standard SFT, the above works fine.
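If you do go the multi-turn route, here’s a minimal sketch of the formatting step using the tokenizer’s built-in chat template. It assumes your checkpoint’s tokenizer actually ships a chat template (the instruction-tuned Gemma variants do), and the messages are made up for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# Hypothetical two-turn exchange, just to show the formatting step
messages = [
    {"role": "user", "content": "Translate this sentence to French: The weather is nice today."},
    {"role": "assistant", "content": "Il fait beau aujourd'hui."},
]

# apply_chat_template inserts Gemma's turn markers (<start_of_turn> / <end_of_turn>) for you
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)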

Loading the Dataset

Here’s how I load a cleaned version of the Alpaca dataset:

from datasets import load_dataset

dataset = load_dataset("tatsu-lab/alpaca", split="train")

You might want to subset or filter the data — I usually drop samples with missing fields or overly long outputs.
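Here’s the kind of filter I mean, as a small sketch. The field names match the Alpaca schema above, and the 2,000-character cutoff is just an arbitrary threshold I picked for illustration:

# Keep only samples that have both an instruction and an output, and aren't excessively long
def keep(example):
    return (
        bool(example["instruction"])
        and bool(example["output"])
        and len(example["output"]) < 2000
    )

dataset = dataset.filter(keep)
print(f"Samples after filtering: {len(dataset)}")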

Tokenization Strategy

This is what I use when preparing inputs:

def tokenize(example):
    text = f"Instruction: {example['instruction']}\nInput: {example['input']}\nOutput: {example['output']}"
    # No return_tensors="pt" here: inside dataset.map it would wrap every sample in an extra
    # [1, 512] dimension and break the data collator. Plain lists are what the Trainer expects.
    return tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=512
    )

# Drop the raw string columns so only token IDs reach the data collator
tokenized_dataset = dataset.map(tokenize, remove_columns=dataset.column_names)

Make sure you add the eos_token at the end of each sample, especially if you’re training with labels=input_ids.clone() — otherwise the model doesn’t learn where to stop.
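Here’s how I’d extend the tokenize function above to do exactly that, as a minimal sketch. It assumes plain causal LM loss over the full sequence:

def tokenize(example):
    text = (
        f"Instruction: {example['instruction']}\n"
        f"Input: {example['input']}\n"
        f"Output: {example['output']}{tokenizer.eos_token}"  # explicit EOS so the model learns to stop
    )
    tokens = tokenizer(text, truncation=True, padding="max_length", max_length=512)
    tokens["labels"] = tokens["input_ids"].copy()  # causal LM: labels mirror the inputs
    return tokens

Note that DataCollatorForLanguageModeling(mlm=False) from the next step rebuilds the labels from input_ids anyway (masking padding with -100), so the explicit labels line mainly matters if you swap in a custom collator.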

Data Collator

For Gemma, I’ve had good results with the DataCollatorForLanguageModeling:

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

5. Fine-Tuning Strategy (This is where the magic happens)

“You don’t need to fine-tune the whole castle. Just tweak the hinges on the right doors.”

When I first tried full fine-tuning on Gemma, it was… let’s just say not friendly on my GPU. I ended up switching to parameter-efficient fine-tuning (PEFT) with LoRA, and haven’t looked back since.

If you’re dealing with 2B or 7B models and want results without melting your hardware, LoRA with PEFT is the way to go.

a. LoRA Configuration with PEFT

I like keeping my LoraConfig explicit — I’ve seen weird results when defaults aren’t tuned for your model size.

Here’s what worked well for me on gemma-2b:

from peft import LoraConfig, get_peft_model, TaskType

peft_config = LoraConfig(
    r=16,                      # rank
    lora_alpha=32,             # scaling
    lora_dropout=0.1,          # dropout during training
    bias="none",               # keep bias untouched
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"]  # works well with Gemma’s attention blocks
)

model = get_peft_model(model, peft_config)

“Why q_proj and v_proj?”
From my experience, those are often the most effective hooks when tuning transformer attention without destabilizing training. I’ve tested others (k_proj, o_proj) — they don’t always give ROI unless you’re training much longer.

LoRA drastically cuts down on memory usage, and honestly, it trains faster without losing much quality — especially on instruction-following tasks.
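One sanity check I like to run right after wrapping the model, to confirm how small the trainable slice really is:

# Prints trainable vs. total parameter counts for the PEFT-wrapped model
model.print_trainable_parameters()
# With r=16 on q_proj/v_proj, expect well under 1% of parameters to be trainable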

b. Trainer vs SFTTrainer vs Custom Loop

This might surprise you: you don’t need a custom training loop unless you’re doing something exotic.

I’ve used all three. Here’s what I’ve personally found:

Option               | When I Use It
transformers.Trainer | For general tasks, especially classification / regression.
trl.SFTTrainer       | When doing SFT on instruction datasets (e.g., Alpaca-style).
Custom loop          | Only if I need dynamic loss scaling, multi-loss objectives, or RLHF.

For Gemma SFT, I went with SFTTrainer from trl — it just works, supports LoRA, and integrates well with PEFT models.

Here’s a sample training config that worked well for me on a single 24GB GPU (like 3090 or 4090):

from transformers import TrainingArguments
from trl import SFTTrainer

training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=10,
    save_total_limit=1,
    save_strategy="epoch",
    output_dir="./gemma-finetuned"
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

If you’re wondering whether fp16=True is stable — I’ve had no issues on Ampere and later. But if you’re on older GPUs, you may want to test bf16 or disable it.

One quick thing: SFTTrainer doesn’t automatically call model.gradient_checkpointing_enable() — so you might want to add that manually to save VRAM.
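Concretely, here’s how I slot that in, right after building the trainer and before kicking off training (the snippet also includes the trainer.train() call itself, which the config above stops short of):

# Enable activation checkpointing to trade a bit of compute for a lot of VRAM
model.gradient_checkpointing_enable()
model.enable_input_require_grads()  # keeps gradients flowing into the LoRA layers through checkpointed blocks

trainer.train()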


6. Memory Optimization (Advanced but Necessary)

“Fast is fine, but accuracy is everything. In a gunfight, you need to take your time in a hurry.” – Wyatt Earp

That quote sums up training large models on limited hardware: you want speed, but not at the cost of stability.

When I fine-tuned gemma-2b, I tested a few memory optimization techniques in isolation — but the real breakthrough came from stacking them intelligently.

Quantization with bitsandbytes

If you’re not using 4-bit quantization via bitsandbytes, you’re burning memory for no good reason. I’ve personally had great results with this configuration:

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",      # nf4 usually gives better performance than fp4
    bnb_4bit_use_double_quant=True  # helps stabilize training in my experience
)

Then load your model like this:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    device_map="auto",
    quantization_config=bnb_config
)

Pro tip: With 4-bit + LoRA, I was able to fine-tune on a single 24GB GPU without OOM issues. That wasn’t the case with full precision.

Gradient Checkpointing

This is a must if you’re using LoRA and want to push batch sizes without running into memory walls. I’ve used it reliably with PEFT:

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

I’ve seen people forget enable_input_require_grads() — with gradient checkpointing on a frozen, quantized base, skipping it can stop gradients from reaching the LoRA layers entirely.

Mixed Precision: fp16 vs bf16

Personally, I stick with fp16 unless I’m on a GPU that really shines with bf16 (e.g., A100). But here’s how you can enable either safely:

TrainingArguments(
    ...
    fp16=True,     # or use bf16=True
    ...
)

One heads-up: fp16 can sometimes overflow if training isn’t stable. Watch for loss-scale warnings or a loss that suddenly turns NaN; if your GPU supports bf16, switching over is usually the simpler fix.
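A quick way to check whether bf16 is even an option on your card:

import torch

# True on Ampere (A100, RTX 30xx/40xx) and newer; False on older architectures like Turing
print(torch.cuda.is_bf16_supported())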

Flash Attention?

You might be wondering: “Can I use FlashAttention with Gemma?”

In my experience, FlashAttention v2 support for Gemma hasn’t been as mature as it is for Mistral or Mixtral. I’ve still seen some gains using xformers or scaled dot product attention in custom loops, though that means touching internals, which I wouldn’t recommend unless you’re building your own training stack.


7. Evaluation

“If you can’t measure it, you can’t improve it.” — Peter Drucker

I learned this the hard way. Early on, I’d just feel like my model was doing better — until I compared outputs side-by-side and realized it wasn’t improving at all.

Here’s how I run real-world evaluations on fine-tuned models — beyond just “it looks good.”

Inference Sanity Check (Always Do This)

Before running metrics, I always generate a few examples manually. It helps catch edge cases metrics won’t.

inputs = tokenizer("Instruction: What's the capital of France?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))

I do this on both the base model and fine-tuned model to compare output style, factuality, and instruction compliance.

Structured Evaluation: BLEU / ROUGE / Custom

For instruction-tuned models, I lean on ROUGE more than BLEU — especially for summarization or question answering.

But for true instruction-following tasks, I prefer writing custom match functions that check if the model followed the structure (e.g., included rationale + answer, or steps in reasoning).
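As a toy example of what I mean, here’s a structural check. The “Rationale: … Answer: …” convention is made up, so swap in whatever format you actually trained for, and generations is assumed to be a list of decoded model outputs:

import re

def follows_format(generation: str) -> bool:
    # Expects a rationale section followed by an answer section, in that order
    return bool(re.search(r"Rationale:.*Answer:", generation, flags=re.DOTALL))

# generations: list of decoded model outputs for your eval prompts
structure_rate = sum(follows_format(g) for g in generations) / len(generations)
print(f"Followed the expected structure: {structure_rate:.1%}")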

Here’s a simplified ROUGE example using evaluate:

import evaluate

rouge = evaluate.load("rouge")
# gen = a decoded model output, reference = the gold answer (both plain strings)
results = rouge.compute(predictions=[gen], references=[reference])
print(results)

I’ve also used g-eval to plug GPT-4 into my eval loop for subjective assessments, especially for long-form outputs.

Before vs After: A Real Example

Here’s a comparison from my own fine-tuning:

Before fine-tuning (Gemma base):

“I don’t know what you’re asking. Please provide more information.”

After fine-tuning (Gemma SFT’d on Alpaca-style dataset):

“Sure! The capital of France is Paris. Let me know if you’d like more facts like this.”

Notice the shift in tone? That’s exactly what SFT with a clean instruction dataset does.


8. Saving, Loading, and Deployment

“Training a model is only half the job — the real challenge is using it like it was never fine-tuned at all.”

This might sound obvious, but trust me, how you save your model can make or break your ability to reload it correctly — especially if you’re using LoRA or quantization.

Saving Your Model (PEFT Adapter Included)

In my workflow, I always separate the base model from the LoRA adapter — this keeps things modular and avoids versioning issues when loading.

# Save the LoRA adapter (not the full base model)
model.save_pretrained("./gemma-lora-adapter")
tokenizer.save_pretrained("./gemma-lora-adapter")

If you want to push it to the Hub (and you should, for reproducibility), this is how I do it:

model.push_to_hub("your-username/gemma-lora-adapter")
tokenizer.push_to_hub("your-username/gemma-lora-adapter")

Personally, I’ve found this super helpful when running multiple fine-tuning experiments — each adapter version gets its own repo.

Loading the Fine-Tuned Adapter

I’ve seen folks try to load the full model directly and get weird results. You need to load the base and then attach the LoRA adapter.

Here’s the exact setup I use:

from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    device_map="auto",
    quantization_config=bnb_config,
)

model = PeftModel.from_pretrained(base_model, "./gemma-lora-adapter")

Don’t forget to load the tokenizer from the adapter folder too. This avoids weird tokenization mismatches.
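Since we saved the tokenizer alongside the adapter above, that just means:

from transformers import AutoTokenizer

# Load the tokenizer from the adapter folder saved earlier, not from the Hub
tokenizer = AutoTokenizer.from_pretrained("./gemma-lora-adapter")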

Quantized Inference with pipeline

When I’m doing quick testing (or prototyping a FastAPI endpoint), this is the fastest way to get started:

from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

out = pipe("Instruction: Generate a haiku about data science.", max_new_tokens=60)
print(out[0]['generated_text'])

It just works — and yes, even with 4-bit models and LoRA adapters, as long as your underlying setup is clean.

Production Deployment (vLLM, text-generation-inference, FastAPI)

You might be wondering: “How do I actually serve this thing?”

Here’s what I’ve used in real projects:

  • FastAPI – Lightweight REST endpoint, good for custom pipelines or async workloads.
  • vLLM – If you want throughput. Hugely efficient for batched inference (esp. with longer contexts).
  • text-generation-inference (TGI) – Ideal for Hugging Face Hub deployment, plug-and-play.

When using vLLM, make sure your model and tokenizer are fully merged (i.e., not split across base and adapter). I usually do this merge manually using PEFT utilities when needed.
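Here’s roughly how I do that merge with PEFT. One caveat: merging needs the base weights in half precision rather than 4-bit, so I reload the base without quantization first. The paths are the same illustrative ones used above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Reload the base in fp16; merging into a 4-bit quantized base doesn't work cleanly
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, "./gemma-lora-adapter")

# Fold the LoRA weights into the base so vLLM/TGI see one standalone checkpoint
merged = model.merge_and_unload()
merged.save_pretrained("./gemma-merged")

tokenizer = AutoTokenizer.from_pretrained("./gemma-lora-adapter")
tokenizer.save_pretrained("./gemma-merged")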


9. Troubleshooting (Things I’ve Actually Run Into)

“Everything works… until you hit ‘train’.”
— Me, every single time I test a new config.

Here’s a list of real-world issues I’ve hit — and how I fix them.

Model not learning?

This one’s sneaky. Your loss goes down a bit… then plateaus. Classic symptoms:

  • Learning rate too low → Try 2e-4 to 1e-4 for LoRA, especially on smaller datasets.
  • Too few trainable parameters → Unfreeze more target modules (e.g., add "o_proj", "gate_proj").
  • Bad data mix → Your dataset is too small or too noisy. Try adding synthetic instruction data.

CUDA OOM?

Yep, been there.

  • Batch size → Drop it. Use gradient_accumulation_steps to compensate.
  • Sequence length → Reduce from 1024 to 512 or 256 during experimentation.
  • Enable this combo:
model.gradient_checkpointing_enable()
bnb_config = BitsAndBytesConfig(load_in_4bit=True, ...)

With this combo, I was able to fine-tune a 7B model on a single 24GB GPU.

Tokenizer throwing shape mismatch errors?

Nine times out of ten, it’s this:

  • You’re loading the wrong tokenizer (e.g., from the base model instead of the adapter folder).
  • Or you’re not padding/truncating properly during tokenization.

Here’s how I usually tokenize:

tokenizer("Prompt text here", padding="max_length", truncation=True, max_length=512, return_tensors="pt")

If you’re using a custom collator, double-check how it handles special tokens and attention_mask.


10. Closing Thoughts: The Good, the Quirky, and What’s Next

“Every model has a personality. You just have to spend enough time with it to figure it out.”

After working with Gemma, here’s what stood out — both the wins and the trade-offs.

What Worked Well

Honestly, for a relatively compact open model, Gemma surprised me — especially the 2B variant with 4-bit quantization. Here’s what I found solid:

  • Fast prototyping: With quantized loading + PEFT, I was able to get a fine-tuned model up and running on a single A100 in just a couple of hours.
  • Instruction tuning compatibility: Models like Gemma-2B took to Alpaca-style prompts really well. I didn’t need heavy prompt engineering.
  • Modular deployment: Saving LoRA adapters separately and plugging them into quantized base models made it easy to ship different versions across projects.

What Didn’t Work So Smoothly

Now, let’s talk about where things got messy — because no model’s perfect.

  • Hallucinations: Especially under ambiguous prompts or under-trained adapters. I noticed this more in conversational fine-tunes.
  • Token mismatch errors: The tokenizer is sensitive — load the exact same tokenizer used during training or you’ll waste hours debugging weird output.
  • Context length limitations: While it’s marketed with decent context length, I found generation coherence drops off noticeably after ~1,500 tokens.
  • Inference speed: Even with 4-bit, larger versions (7B) weren’t snappy without Flash Attention or vLLM. You’ll definitely want to optimize your stack.

Where to Go Next

Here’s where I’d point you if you’re planning to push this further:

🔸 QLoRA for Extreme Memory Efficiency

If you’re working with tight memory budgets or planning to train on consumer GPUs, QLoRA is your friend. I’ve used it to run 7B models comfortably on 24GB VRAM.

🔸 Instruction Tuning on Domain Data

Generic instruction data is good for warmup — but fine-tuning on your actual use-case prompts (with edge cases) makes the difference between a flashy demo and a production model.

🔸 RAG on Top

Gemma alone isn’t enough for anything requiring factual grounding (e.g., medical, legal, enterprise data). Plugging it into a RAG setup with Haystack or LlamaIndex makes it actually useful.

🔸 Merge Adapters for Multi-Domain Use

If you’re training domain-specific LoRA adapters (e.g., finance + health), you can merge them using PEFT. Just keep an eye on conflicts in target modules.
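A rough sketch of what that looks like with PEFT’s weighted-adapter merge. The adapter paths and names here are hypothetical, and it assumes both adapters were trained with compatible LoraConfig settings (same rank, overlapping target modules):

from peft import PeftModel

# Attach two hypothetical domain adapters to the same (non-quantized) base model
model = PeftModel.from_pretrained(base_model, "./gemma-lora-finance", adapter_name="finance")
model.load_adapter("./gemma-lora-health", adapter_name="health")

# Blend them into a new adapter with a simple linear combination
model.add_weighted_adapter(
    adapters=["finance", "health"],
    weights=[0.5, 0.5],
    adapter_name="finance_health",
    combination_type="linear",
)
model.set_adapter("finance_health")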
