In our last post, we introduced model distillation and why you may want to distill your model for efficient, task-specific deployment.
Today, we’ll dive into some of the technical details behind distillation and how you can experiment with the process yourself. This guide aims to provide a detailed, practical walkthrough of the distillation process, covering both platform-based solutions and sample code for performing the distillation directly on your own hardware.
Distillation Platforms
At Proxis, we’re building the simplest LLM distillation and fine-tuning platform in existence. Our customers use us to quickly build state-of-the-art models for their task and seamlessly deploy those models into production with our serverless backend.
While we’re not quite ready to open to the public, we are currently running a closed beta for clients looking to distill and fine-tune models on their data. If you’re interested, please fill out our closed-beta invitation request form, and we’ll get back to you within a day if approved.
Distilling a Model from Scratch
If you’d prefer to set up your own distillation pipeline on your own compute infrastructure, the following step-by-step instructions walk through the full process, from environment setup through training, evaluation, and deployment.
Step 1: Setting Up the Environment
Import Libraries
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
Check Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
Step 2: Preparing the Dataset
Choose a Task and Dataset
For this guide, we'll use the IMDb movie reviews dataset for sentiment analysis.
from datasets import load_dataset
dataset = load_dataset('imdb')
Preprocess the Data
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def preprocess(example):
    return tokenizer(example['text'], truncation=True, padding='max_length', max_length=128)
encoded_dataset = dataset.map(preprocess, batched=True)
Create DataLoaders
from torch.utils.data import DataLoader

# Return PyTorch tensors for the columns the models need; without this, the DataLoader
# would yield plain Python lists that cannot be moved to the GPU
encoded_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])

train_dataset = encoded_dataset['train'].shuffle(seed=42).select(range(2000))  # Subset for quick training
test_dataset = encoded_dataset['test'].shuffle(seed=42).select(range(500))

train_loader = DataLoader(train_dataset, batch_size=16)
test_loader = DataLoader(test_dataset, batch_size=16)
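To sanity-check the data pipeline before training, you can optionally pull a single batch and confirm the tensor shapes look right; this snippet is just a quick check and not part of the distillation itself.
# Grab one batch and inspect its shapes
batch = next(iter(train_loader))
print({k: tuple(v.shape) for k, v in batch.items()})
# Expect (16, 128) for input_ids, token_type_ids, and attention_mask, and (16,) for label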
Step 3: Selecting the Teacher Model
Choose a pre-trained model suitable for your task. Note that bert-base-uncased on its own ships with a randomly initialized classification head, so in practice you would either load a checkpoint already fine-tuned for your task or briefly fine-tune the teacher yourself before distilling (a sketch follows below).
teacher_model_name = 'bert-base-uncased'
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_name).to(device)
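As a rough sketch of that optional teacher fine-tuning step, you could run a short pass with the Trainer API over the train_dataset built in Step 2; the hyperparameters here are illustrative rather than tuned.
from transformers import Trainer, TrainingArguments

# A minimal fine-tuning pass so the teacher's classification head is actually trained on IMDb
training_args = TrainingArguments(
    output_dir='teacher_finetuned',
    num_train_epochs=1,
    per_device_train_batch_size=16,
    logging_steps=50,
)
trainer = Trainer(model=teacher_model, args=training_args, train_dataset=train_dataset)
trainer.train()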
Step 4: Designing the Student Model
Customize the Student Model
Reduce the number of layers and hidden dimensions to create a smaller model.
from transformers import BertConfig, BertForSequenceClassification
student_config = BertConfig.from_pretrained(teacher_model_name)
student_config.num_hidden_layers = 4 # Reduce layers from 12 to 4
student_config.hidden_size = 256 # Reduce hidden size from 768 to 256
student_config.num_attention_heads = 4 # Reduce attention heads from 12 to 4
student_model = BertForSequenceClassification(student_config).to(device)
Initialize Student Model Weights
Optionally, initialize the student model weights from the teacher model. Note that because the student uses a smaller hidden size, most teacher tensors will not match in shape and only the few compatible ones are copied; this initialization helps most when you shrink the model by dropping layers while keeping the teacher’s hidden dimensions.
teacher_state_dict = teacher_model.state_dict()
student_state_dict = student_model.state_dict()
# Copy matching parameters
for name in student_state_dict.keys():
    if name in teacher_state_dict and student_state_dict[name].shape == teacher_state_dict[name].shape:
        student_state_dict[name] = teacher_state_dict[name]

student_model.load_state_dict(student_state_dict)
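To see how much of the teacher actually carried over, a quick optional check counts the tensors that matched in shape.
# Count how many parameter tensors were copied from the teacher
copied = sum(
    1 for name in student_state_dict
    if name in teacher_state_dict and student_state_dict[name].shape == teacher_state_dict[name].shape
)
print(f'Copied {copied} of {len(student_state_dict)} parameter tensors from the teacher')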
Step 5: Implementing the Distillation Process
Define Loss Functions
Hard Loss: Between student predictions and true labels.
Soft Loss: Between student and teacher predictions.
import torch.nn as nn
hard_loss_fn = nn.CrossEntropyLoss()
soft_loss_fn = nn.KLDivLoss(reduction='batchmean')
Set Hyperparameters
temperature = 2.0 # Softens probability distribution
alpha = 0.5 # Balances hard and soft losses
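For intuition about what the temperature does, this small standalone example with made-up logits shows how dividing by T = 2.0 flattens the probability distribution, exposing more of the probability mass the teacher assigns to non-argmax classes.
import torch
import torch.nn as nn

logits = torch.tensor([[3.0, 0.5]])  # Hypothetical teacher logits for one example
print(nn.functional.softmax(logits, dim=-1))        # ~[0.92, 0.08] at T = 1
print(nn.functional.softmax(logits / 2.0, dim=-1))  # ~[0.78, 0.22] at T = 2: softer targets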
Prepare Optimizer and Scheduler
from torch.optim import AdamW  # AdamW now lives in torch.optim; the transformers copy is deprecated
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(student_model.parameters(), lr=5e-5)
num_epochs = 3
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
Step 6: Training the Student Model
Training Loop
teacher_model.eval()  # Keep the teacher in inference mode throughout

for epoch in range(num_epochs):
    student_model.train()
    total_loss = 0
    for batch in train_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
        labels = batch['label'].to(device)

        # Teacher predictions (no gradients needed)
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # Student predictions
        student_outputs = student_model(**inputs)

        # Hard loss against the ground-truth labels
        hard_loss = hard_loss_fn(student_outputs.logits, labels)

        # Soften logits with the temperature before the soft loss
        teacher_logits = teacher_outputs.logits / temperature
        student_logits = student_outputs.logits / temperature
        soft_loss = soft_loss_fn(
            nn.functional.log_softmax(student_logits, dim=-1),
            nn.functional.softmax(teacher_logits, dim=-1)
        ) * (temperature ** 2)

        # Combined loss
        loss = alpha * hard_loss + (1 - alpha) * soft_loss
        total_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}')
Step 7: Evaluating the Distilled Model
Evaluation Function
def evaluate(model, loader):
    model.eval()
    total_correct = 0
    total_examples = 0
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
            labels = batch['label'].to(device)
            outputs = model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_examples += labels.size(0)
    accuracy = total_correct / total_examples
    return accuracy
Evaluate Teacher and Student Models
teacher_accuracy = evaluate(teacher_model, test_loader)
student_accuracy = evaluate(student_model, test_loader)
print(f'Teacher Model Accuracy: {teacher_accuracy * 100:.2f}%')
print(f'Student Model Accuracy: {student_accuracy * 100:.2f}%')
Compare Model Sizes
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
print(f'Teacher Model Parameters: {teacher_params / 1e6:.2f}M')
print(f'Student Model Parameters: {student_params / 1e6:.2f}M')
Measure Inference Time
import time
def measure_inference_time(model, loader):
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        for batch in loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k in tokenizer.model_input_names}
            model(**inputs)
    end_time = time.time()
    return end_time - start_time
teacher_time = measure_inference_time(teacher_model, test_loader)
student_time = measure_inference_time(student_model, test_loader)
print(f'Teacher Inference Time: {teacher_time:.2f}s')
print(f'Student Inference Time: {student_time:.2f}s')
Step 8: Fine-Tuning and Deployment
Fine-Tuning (Optional)
If performance is not satisfactory, consider the following adjustments:
Adjust hyperparameters (e.g., alpha, temperature, learning rate).
Increase the size of the training dataset.
Use data augmentation techniques (a minimal sketch follows below).
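As one illustration of that last point, here is a minimal, hypothetical augmentation helper that randomly drops words from a review to create perturbed training variants; real pipelines often use richer techniques such as back-translation, but this shows the general shape.
import random

def random_word_dropout(text, drop_prob=0.1, seed=None):
    # Randomly remove a fraction of words to create a perturbed copy of the input
    rng = random.Random(seed)
    words = text.split()
    kept = [w for w in words if rng.random() > drop_prob]
    return ' '.join(kept) if kept else text

print(random_word_dropout('This movie was surprisingly good and kept me engaged.', drop_prob=0.2, seed=0))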
Save the Student Model
student_model.save_pretrained('distilled_student_model')
tokenizer.save_pretrained('distilled_student_model')
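Before wiring the model into a service, it is worth a quick local check that the saved artifacts load and produce a prediction; the review text here is made up for illustration.
# Reload the saved student model and run a single prediction
reloaded_model = AutoModelForSequenceClassification.from_pretrained('distilled_student_model').to(device)
reloaded_tokenizer = AutoTokenizer.from_pretrained('distilled_student_model')
reloaded_model.eval()

sample = reloaded_tokenizer('A thoughtful, well-acted film.', return_tensors='pt', truncation=True, max_length=128).to(device)
with torch.no_grad():
    logits = reloaded_model(**sample).logits
print('Predicted class:', torch.argmax(logits, dim=-1).item())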
Deploy the Model
You can deploy the model using frameworks like FastAPI or Flask.
from fastapi import FastAPI, Request
import uvicorn
app = FastAPI()

# Load the model and tokenizer once at startup
model = AutoModelForSequenceClassification.from_pretrained('distilled_student_model').to(device)
tokenizer = AutoTokenizer.from_pretrained('distilled_student_model')
model.eval()

@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    inputs = tokenizer(data['text'], return_tensors='pt', truncation=True, padding='max_length', max_length=128).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    return {'prediction': prediction}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
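Once the server is running, you can exercise the endpoint with a small client script; the URL below assumes the default host and port from the snippet above.
import requests

# Send one review to the running service and print the predicted class
response = requests.post(
    'http://localhost:8000/predict',
    json={'text': 'An absolute delight from start to finish.'},
)
print(response.json())  # e.g. {'prediction': 1}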
Go Forth
Distillation is a crucial step in productionizing your models. We hope this post helps you understand the steps you can take to distill your models yourself, or points you toward an efficient distillation platform such as Proxis.