AlgorithmCodeConcept

Deep dive: Knowledge distillation

In this deep dive series, we go over Knowledge Distillation (KD), Partial Function Application, a Learning Rate Scheduler for LLM Finetuning, and Image Transformation with the functional module.

Knowledge distillation

Knowledge Distillation (KD) is a technique to transfer knowledge from a large, complex model (teacher) to a smaller, more efficient one (student). The essence of KD lies in its ability to capture the rich representations learned by the teacher model and impart them to the student, enabling the latter to perform tasks with comparable accuracy while being more resource-efficient.

Implementing KD begins with training the teacher model to its full capacity. Next, the student model is trained using a specific loss function.

This loss function is based not only on the hard labels of the training data but also on the soft outputs (probabilities) generated by the teacher model. These soft outputs convey the teacher's confidence across the classes, offering a more nuanced training signal than hard labels alone.

The process typically uses a temperature parameter to soften the probabilities: a higher temperature smooths the teacher's output distribution, exposing the relative similarities between classes and making the soft targets richer in information and easier for the student to learn from.

Some popular teacher-student pairs include BERT (teacher) and DistilBERT (student) for NLP tasks, or ResNet50 and MobileNet for image classification. Below is an example implementation of the KD loss and a single training step.

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, temperature, alpha):
    # Soft loss: KL divergence between the softened student and teacher distributions,
    # scaled by temperature**2 so its gradients stay on the same scale as the hard loss
    soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                         F.softmax(teacher_logits / temperature, dim=1),
                         reduction='batchmean') * (temperature ** 2)
    # Hard loss: standard cross-entropy against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, targets)
    # Weighted combination of the two components
    return alpha * soft_loss + (1 - alpha) * hard_loss


# Assuming teacher_model, student_model, optimizer, input_data, and targets are defined
# Define temperature and alpha for balancing the loss components
temperature = 5.0
alpha = 0.5

# Compute teacher and student outputs (the teacher needs no gradients)
with torch.no_grad():
    teacher_logits = teacher_model(input_data)
student_logits = student_model(input_data)

# Compute distillation loss
loss = distillation_loss(student_logits, teacher_logits,
                         targets, temperature, alpha)

# Backpropagate and update student model
optimizer.zero_grad()
loss.backward()
optimizer.step()
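Here, alpha balances how much the student imitates the teacher's soft targets versus fitting the ground-truth labels, while the temperature-squared factor compensates for the smaller gradients produced by softened distributions. In a full training loop these steps repeat for every mini-batch, with the teacher kept frozen in eval mode.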

Partial Function Application

The partial function from the functools module in Python allows you to create a new function by partially applying arguments to an existing function. This can be particularly useful when you need to reuse a function with certain arguments fixed, or when you want to create specialized versions of a function without rewriting it entirely.

from functools import partial

def my_function(a, b, c):
    return a + b * c

# Create a new function with 'a' fixed to 10
new_function = partial(my_function, 10)

# Call the new function with only 'b' and 'c' arguments
result = new_function(2, 3)
print(result)  # Output: 16

In this example, new_function is a partially applied version of my_function where the first argument a is fixed to 10. When you call new_function with 2 and 3, it effectively calls my_function(10, 2, 3), resulting in the output 16.

Partial function application can be particularly useful in machine learning tasks, where you often need to apply the same function with different sets of arguments or parameters. For example, you could create specialized versions of a model evaluation function with different evaluation metrics or data preprocessing steps partially applied.
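As a rough sketch of that idea, the snippet below assumes a PyTorch-style model and data loader and uses a hypothetical evaluate helper; partial fixes the metric argument to produce reusable, specialized evaluators.

from functools import partial
from sklearn.metrics import accuracy_score, f1_score

def evaluate(model, data_loader, metric_fn):
    # Hypothetical helper: collect predictions and labels, then score them with metric_fn
    preds, labels = [], []
    for inputs, targets in data_loader:
        preds.extend(model(inputs).argmax(dim=1).tolist())
        labels.extend(targets.tolist())
    return metric_fn(labels, preds)

# Specialized evaluators created by fixing the metric argument
evaluate_accuracy = partial(evaluate, metric_fn=accuracy_score)
evaluate_macro_f1 = partial(evaluate, metric_fn=partial(f1_score, average="macro"))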

Learning rate Scheduler for LLM Finetuning

This section walks through how to implement, from scratch, the learning rate schedule with linear warmup and half-cycle cosine decay commonly used for pre-training and fine-tuning LLMs: the learning rate ramps up linearly from a small value to a peak over the first warmup steps, then decays toward a minimum along half of a cosine cycle for the rest of training.
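A minimal from-scratch sketch of such a schedule might look like the following (the step counts and learning rate values are placeholder choices for illustration):

import math

def get_lr(step, total_steps, warmup_steps, peak_lr, min_lr):
    # Linear warmup: ramp the learning rate up to peak_lr over the first warmup_steps
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    # Half-cycle cosine decay: anneal from peak_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Example: update the optimizer's learning rate at every training step
# (total_steps and optimizer are assumed to be defined elsewhere)
for step in range(total_steps):
    lr = get_lr(step, total_steps, warmup_steps=500, peak_lr=3e-4, min_lr=3e-5)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    # ... forward pass, loss.backward(), optimizer.step() ...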

Image transformation with functional module

The torchvision.transforms.functional module in PyTorch offers fine-grained control over image transformations: each operation is an explicit function call, which is crucial for tasks that require synchronized changes to paired data, such as images and masks in segmentation. It also enables custom, conditional transformations beyond what the standard transforms pipeline provides.

import torchvision.transforms as T
import torchvision.transforms.functional as TF
import random

def synchronized_transform(image, mask):
    # Sample crop parameters once so the same crop is applied to both image and mask
    i, j, h, w = T.RandomCrop.get_params(image, output_size=(256, 256))
    image_cropped = TF.crop(image, i, j, h, w)
    mask_cropped = TF.crop(mask, i, j, h, w)

    # Apply any other synchronized transformations here,
    # e.g. a random horizontal flip applied to both tensors with the same decision
    if random.random() > 0.5:
        image_cropped = TF.hflip(image_cropped)
        mask_cropped = TF.hflip(mask_cropped)
    return image_cropped, mask_cropped
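Because the crop coordinates and the flip decision are sampled once and applied to both tensors, the image and its mask stay spatially aligned after augmentation, which is exactly what a segmentation dataset needs when returning training pairs.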
