How to Fine-Tune a Vision Language Model (VLM) with LoRA

Overview

The following explains how to fine-tune a Vision Language Model (Qwen2-VL) using the Hugging Face ecosystem (TRL). The walkthrough loads the Qwen2-VL-2B-Instruct checkpoint; the 7B-Instruct variant can be swapped in via model_id.

1. Prerequisites

Installing the libraries

The latest PyTorch release currently has an issue that breaks this notebook, so specific earlier versions are pinned below.

pip install -U -q transformers trl datasets bitsandbytes peft qwen-vl-utils wandb accelerate
pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

2. Prepare the Dataset

This tutorial uses the letgoofthepizza/traditional-korea-food-captioning dataset, which contains images of traditional Korean food together with Korean captions.

from datasets import load_dataset

dataset_id = "letgoofthepizza/traditional-korea-food-captioning"
# Note: the same 100-sample slice is reused for train/eval/test in this walkthrough
train_dataset = load_dataset(dataset_id, split="train[:100]")
eval_dataset = load_dataset(dataset_id, split="train[:100]")
test_dataset = load_dataset(dataset_id, split="train[:100]")

3. Load the Model and Check Performance

import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
processor = Qwen2VLProcessor.from_pretrained(model_id)
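
Before fine-tuning, it is worth checking how the base model responds to one sample. A minimal sketch, assuming the format_data and generate_text_from_sample helpers from the full listing further below are already defined:

# Quick baseline check on one sample (format_data and generate_text_from_sample
# are defined in the full listing further below)
sample = format_data(train_dataset[0])
print(generate_text_from_sample(model, processor, sample))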

4. LoRA-Based Fine-Tuning Setup

4.1 Quantization Configuration

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=torch.bfloat16
)
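
The quantization config only takes effect when it is passed to from_pretrained, so the model is reloaded with it (this is the same call used in the full listing below):

# Reload the model with 4-bit quantization applied
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,  # load the base weights in 4-bit NF4
)
processor = Qwen2VLProcessor.from_pretrained(model_id)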

4.2 LoRA Configuration

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, peft_config)
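
Printing the trainable-parameter count confirms how small the LoRA adapters are; the full listing below records roughly 2.5M trainable parameters out of about 8.3B total (about 0.03%):

# Show how many parameters the LoRA adapters actually train
peft_model.print_trainable_parameters()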

5. Supervised Fine-Tuning (SFT) Configuration

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="qwen2-7b-instruct-trl-sft-K-Food",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    logging_steps=10,
    eval_steps=10,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
    bf16=True,
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    push_to_hub=False,
    report_to="wandb",
    remove_unused_columns=False,
)
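
With per_device_train_batch_size=4 and gradient_accumulation_steps=8, each optimizer step sees an effective batch of 4 × 8 = 32 samples per device, so the 100-sample training split above yields only about 3 optimizer steps per epoch.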

6. Training

from trl import SFTTrainer
import wandb

wandb.init(
    project="qwen2-7b-instruct-trl-sft-K-Food",
    name="qwen2-7b-instruct-trl-sft-K-Food",
    config=training_args,
)

# collate_fn is the text/image collator defined in the full listing below
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)

trainer.train()
trainer.save_model(training_args.output_dir)

7. Wrap-Up

This covered the process of fine-tuning a Vision Language Model with LoRA. Training on the Korean food image dataset in this way can produce a more refined vision-language model for this domain. That said, the results of training the Qwen2-VL-7B model on roughly 150,000 samples were disappointing.

'''
Reference: Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)
https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl

Dataset: letgoofthepizza/traditional-korea-food-captioning
https://huggingface.co/datasets/letgoofthepizza/traditional-korea-food-captioning
ref. https://huggingface.co/datasets/HuggingFaceM4/ChartQA

We'll also need to install an earlier version of PyTorch, as the latest version has an issue that
currently prevents this notebook from running correctly. You can learn more about the issue here
and consider updating to the latest version once it's resolved.

pip install -U -q transformers trl datasets bitsandbytes peft qwen-vl-utils wandb accelerate
pip install -q torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
'''


system_message = """You are a Vision Language Model specialized in interpreting visual data from Korean food images.
Your task is to analyze the provided Korean food image and respond to queries with detailed descriptions in Korean.
The images include a variety of Korean foods and text.
Focus your answers on the detailed information, and answer in Korean. Avoid additional explanation unless absolutely necessary."""


def format_data(sample):
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": sample["image"],
                },
                {
                    "type": "text",
                    "text": sample["text"],
                },
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": sample["text"],
                }
            ],
        },
    ]


from datasets import load_dataset

dataset_id = "letgoofthepizza/traditional-korea-food-captioning"

train_dataset = load_dataset(dataset_id, split="train[:100]")
eval_dataset = load_dataset(dataset_id, split="train[:100]")
test_dataset = load_dataset(dataset_id, split="train[:100]")

# Convert each raw sample into the chat-message format
train_dataset = [format_data(sample) for sample in train_dataset]
eval_dataset = [format_data(sample) for sample in eval_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]


'''
# train_dataset[0]

[{'role': 'system', 'content': [{'type': 'text', 'text': 'You are a Vision Language Model specialized in interpreting visual data from Korean food images.\nYour task is to analyze the provided Korean food image and respond to queries with detailed descriptions in Korean.\nThe images include a variety of Korean foods and text.\nFocus your answers on the detailed information, and answer in Korean. Avoid additional explanation unless absolutely necessary.'}]},
 {'role': 'user', 'content': [{'type': 'image', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=773x512 at 0x7FF0E45EE860>}, {'type': 'text', 'text': '콩나물과 참깨로 장식된 막국수 한 그릇, 흰 쌀밥과 가위가 함께 제공됩니다.'}]},
 {'role': 'assistant', 'content': [{'type': 'text', 'text': '콩나물과 참깨로 장식된 막국수 한 그릇, 흰 쌀밥과 가위가 함께 제공됩니다.'}]}]
'''


'''
3. Load Model and Check Performance!

Now that we've loaded the dataset, let's start by loading the model and evaluating its performance
using a sample from the dataset. We'll be using Qwen/Qwen2-VL-7B-Instruct, a Vision Language Model (VLM)
capable of understanding both visual data and text (this run uses the smaller Qwen/Qwen2-VL-2B-Instruct
checkpoint; see model_id below).

If you're exploring alternatives, consider these open-source options:

Meta AI's Llama-3.2-11B-Vision
Mistral AI's Pixtral-12B
Allen AI's Molmo-7B-D-0924

Additionally, you can check the Leaderboards, such as the WildVision Arena or the OpenVLM Leaderboard,
to find the best-performing VLMs.
'''

import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

# model_id = "Qwen/Qwen2-VL-7B-Instruct"
model_id = "Qwen/Qwen2-VL-2B-Instruct"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

processor = Qwen2VLProcessor.from_pretrained(model_id)


# To evaluate the model's performance, we'll use a sample from the dataset. First, let's take a look at the internal structure of this sample.
# train_dataset[0]

# We'll use the sample without the system message to assess the VLM's raw understanding. Here's the input we will use:
# train_dataset[0][1:2]

# Now, let's take a look at the image corresponding to the sample. Can you answer the query based on the visual information?
# train_dataset[0][1]["content"][0]["image"]
# Displays the image

# Let's create a method that takes the model, processor, and sample as inputs to generate the model's answer.
# This will allow us to streamline the inference process and easily evaluate the VLM's performance.

from qwen_vl_utils import process_vision_info


def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda"):
    # Prepare the text input by applying the chat template
    text_input = processor.apply_chat_template(
        sample[1:2], tokenize=False, add_generation_prompt=True  # Use the sample without the system message
    )

    # Process the visual input from the sample
    image_inputs, _ = process_vision_info(sample)

    # Prepare the inputs for the model
    model_inputs = processor(
        text=[text_input],
        images=image_inputs,
        return_tensors="pt",
    ).to(device)  # Move inputs to the specified device

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0]  # Return the first decoded output text


# Example of how to call the method with a sample:
# output = generate_text_from_sample(model, processor, train_dataset[0])
# print(output)

# While the model successfully retrieves the correct visual information, it struggles to answer the question accurately.
# This indicates that fine-tuning might be the key to enhancing its performance. Let's proceed with the fine-tuning process!


# Remove Model and Clean GPU
# Before we proceed with training the model in the next section, let's clear the current variables and clean the GPU to free up resources.
import gc
import time


def clear_memory():
    # Delete variables if they exist in the current global scope
    if "inputs" in globals():
        del globals()["inputs"]
    if "model" in globals():
        del globals()["model"]
    if "processor" in globals():
        del globals()["processor"]
    if "trainer" in globals():
        del globals()["trainer"]
    if "peft_model" in globals():
        del globals()["peft_model"]
    if "bnb_config" in globals():
        del globals()["bnb_config"]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")


# clear_memory()


# 4. Fine-Tune the Model using TRL
# 4.1 Load the Quantized Model for Training
# Next, we'll load the quantized model using bitsandbytes.
from transformers import BitsAndBytesConfig

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and processor with the quantization config applied
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=bnb_config
)
processor = Qwen2VLProcessor.from_pretrained(model_id)


# 4.2 Set Up QLoRA and SFTConfig
# Next, we will configure QLoRA for our training setup.
# QLoRA enables efficient fine-tuning of large language models while significantly reducing the memory footprint compared to traditional methods.
# Unlike standard LoRA, which reduces the number of trainable parameters by applying a low-rank approximation,
# QLoRA takes it a step further by also quantizing the frozen base-model weights to 4-bit.
# This leads to even lower memory requirements and improved training efficiency,
# making it an excellent choice for optimizing our model's performance without sacrificing quality.
from peft import LoraConfig, get_peft_model

# Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
# Before training: trainable params: 2,523,136 || all params: 8,293,898,752 || trainable%: 0.0304
# peft_model.print_trainable_parameters()


# We will use Supervised Fine-Tuning (SFT) to refine our model's performance on the task at hand.
# To do this, we'll define the training arguments using the SFTConfig class from the TRL library.
# SFT allows us to provide labeled data, helping the model learn to generate more accurate responses based on the input it receives.
# This approach ensures that the model is tailored to our specific use case, leading to better performance in understanding and responding to visual queries.
# TRL ("Transformer Reinforcement Learning") is a library used alongside Hugging Face's Transformers library.
# It is mainly used to fine-tune language models, including with reinforcement learning.
# TRL is particularly useful for training conversational AI models and supports techniques such as
# Supervised Fine-Tuning (SFT) and Reinforcement Learning from Human Feedback (RLHF).
from trl import SFTConfig


# Configure training arguments
training_args = SFTConfig(
    output_dir="qwen2-7b-instruct-trl-sft-K-Food",  # Directory to save the model
    num_train_epochs=3,  # Number of training epochs
    per_device_train_batch_size=4,  # Batch size for training
    per_device_eval_batch_size=4,  # Batch size for evaluation
    gradient_accumulation_steps=8,  # Steps to accumulate gradients
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
    # Optimizer and scheduler settings
    optim="adamw_torch_fused",  # Optimizer type
    learning_rate=2e-4,  # Learning rate for training
    lr_scheduler_type="constant",  # Type of learning rate scheduler
    # Logging and evaluation
    logging_steps=10,  # Steps interval for logging
    eval_steps=10,  # Steps interval for evaluation
    eval_strategy="steps",  # Strategy for evaluation
    save_strategy="steps",  # Strategy for saving the model
    save_steps=20,  # Steps interval for saving
    metric_for_best_model="eval_loss",  # Metric to evaluate the best model
    greater_is_better=False,  # Whether higher metric values are better
    load_best_model_at_end=True,  # Load the best model after training
    # Mixed precision and gradient settings
    bf16=True,  # Use bfloat16 precision
    tf32=True,  # Use TensorFloat-32 precision
    max_grad_norm=0.3,  # Maximum norm for gradient clipping
    warmup_ratio=0.03,  # Ratio of total steps for warmup
    # Hub and reporting
    push_to_hub=False,  # Whether to push model to Hugging Face Hub
    report_to="wandb",  # Reporting tool for tracking metrics
    # Gradient checkpointing settings
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
    # Dataset configuration
    dataset_text_field="",  # Text field in dataset (unused; the custom collator handles formatting)
    dataset_kwargs={"skip_prepare_dataset": True},  # Skip TRL's default dataset preparation
    # max_seq_length=1024  # Maximum sequence length for input
)

training_args.remove_unused_columns = False  # Keep unused columns in dataset


# 4.3 Training the Model
# We will log our training progress using Weights & Biases (W&B). Let's connect our notebook to W&B to capture essential information during training.
import wandb

wandb.init(
    project="qwen2-7b-instruct-trl-sft-K-Food",  # change this
    name="qwen2-7b-instruct-trl-sft-K-Food",  # change this
    config=training_args,
)


# We need a collator function to properly retrieve and batch the data during the training procedure.
# This function will handle the formatting of our dataset inputs, ensuring they are correctly structured for the model. Let's define the collator function below.

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example, tokenize=False) for example in examples
    ]  # Prepare texts for processing
    image_inputs = [process_vision_info(example)[0] for example in examples]  # Process the images to extract inputs

    # Tokenize the texts and process the images
    batch = processor(
        text=texts, images=image_inputs, return_tensors="pt", padding=True
    )  # Encode texts and images into tensors

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels

    # Ignore the image token indices in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):  # Check if the processor is Qwen2VLProcessor
        image_tokens = [151652, 151653, 151655]  # Specific image token IDs for Qwen2VLProcessor
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]  # Convert image token to ID

    # Mask image token IDs in the labels
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels  # Add labels to the batch

    return batch  # Return the prepared batch
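
# A quick way to sanity-check the collator (not part of the original notebook) is to run it on a
# couple of formatted samples and confirm that padding and image tokens are masked out of the labels:

# batch = collate_fn([train_dataset[0], train_dataset[1]])  # illustrative sanity check
# print(batch["input_ids"].shape, batch["labels"].shape)
# print((batch["labels"] == -100).sum().item(), "label positions masked from the loss")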


# We will define the SFTTrainer, which is a wrapper around the transformers.Trainer class and inherits its attributes and methods.
# This class simplifies the fine-tuning process by properly initializing the PeftModel when a PeftConfig object is provided.
# By using SFTTrainer, we can efficiently manage the training workflow and ensure a smooth fine-tuning experience for our Vision Language Model.
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)

trainer.train()
trainer.save_model(training_args.output_dir)
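
After training, the saved LoRA adapter can be reattached to the base model for inference. A minimal sketch (not part of the original notebook), assuming the adapter was saved to training_args.output_dir as above:

# Reload the base model and attach the fine-tuned LoRA adapter (illustrative sketch)
from peft import PeftModel

base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)
ft_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
ft_processor = Qwen2VLProcessor.from_pretrained(model_id)

# Caption a held-out sample with the fine-tuned model, reusing the earlier helper
print(generate_text_from_sample(ft_model, ft_processor, test_dataset[0]))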


print("TILL HERE!!")



