Skip to content

KTO finetuning - float division by zero #1651

@jetlime

Description

@jetlime

I am attempting to finetune the LLama3-8B-Instruct model on the UNSW-NB15 dataset.

dataset = load_dataset("Jetlime/NF-UNSW-NB15-v2", streaming=False, split="train")

# Model
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, token=HUGGING_FACE_READ_TOKEN)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    attn_implementation=attn_implementation,
    token=HUGGING_FACE_READ_TOKEN
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

# Use only a small subset of the training set for a first finetuning trial
dataset = dataset.train_test_split(test_size=0.95, seed=123, stratify_by_column="Attack")
dataset_finetuning = dataset["train"]
dataset_finetuning

# Dataset({
#    features: ['input', 'output', 'Attack'],
#    num_rows: 113538
#})

# Creating the dataset columns required by the KTO finetuner
import random
def format_chat_template(row):
    row['prompt'] = row["input"]
    if random.randrange(0,1):
        row["label"] = False
        if row["output"] == 1:
            row["completion"] = '0'
        else:
            row["completion"] = '1'
    else:
        row["label"] = True
        row["completion"] = str(row["output"])
    return row

dataset_finetuning = dataset_finetuning.map(
    format_chat_template, num_proc=os.cpu_count()
)
dataset_finetuning

When I then perform the training,

training_args = KTOConfig(
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
    output_dir="./results-KTO/"
)

kto_trainer = KTOTrainer(
    model,
    args=training_args,
    train_dataset=dataset_finetuning,
    tokenizer=tokenizer,
)

kto_trainer.train()

# Tokenizing train dataset: 100%|██████████| 113538/113538 [01:26<00:00, 1313.22 examples/s]
# Extracting KL train dataset: 100%|██████████| 113538/113538 [00:08<00:00, 14077.15 examples/s]
# Processing tokenized train dataset: 100%|██████████| 113538/113538 [00:42<00:00, 2678.53 examples/s]
# Processing tokenized train KL dataset: 100%|██████████| 113538/113538 [00:40<00:00, 2805.04 examples/s]
# Filtering desirable examples: 100%|██████████| 113538/113538 [01:37<00:00, 1163.86 examples/s]
# Filtering undesirable examples: 100%|██████████| 113538/113538 [01:36<00:00, 1170.86 examples/s]

I obtain a Zero Division Error:

ZeroDivisionError                         Traceback (most recent call last)
Cell In[7], [line 8](vscode-notebook-cell:?execution_count=7&line=8)
      [1](vscode-notebook-cell:?execution_count=7&line=1) training_args = KTOConfig(
      [2](vscode-notebook-cell:?execution_count=7&line=2)     beta=0.1,
      [3](vscode-notebook-cell:?execution_count=7&line=3)     desirable_weight=1.0,
      [4](vscode-notebook-cell:?execution_count=7&line=4)     undesirable_weight=1.0,
      [5](vscode-notebook-cell:?execution_count=7&line=5)     output_dir="./results-KTO/"
      [6](vscode-notebook-cell:?execution_count=7&line=6) )
----> [8](vscode-notebook-cell:?execution_count=7&line=8) kto_trainer = KTOTrainer(
      [9](vscode-notebook-cell:?execution_count=7&line=9)     model,
     [10](vscode-notebook-cell:?execution_count=7&line=10)     args=training_args,
     [11](vscode-notebook-cell:?execution_count=7&line=11)     train_dataset=dataset_finetuning,
     [12](vscode-notebook-cell:?execution_count=7&line=12)     tokenizer=tokenizer,
     [13](vscode-notebook-cell:?execution_count=7&line=13) )
     [15](vscode-notebook-cell:?execution_count=7&line=15) kto_trainer.train()
     [16](vscode-notebook-cell:?execution_count=7&line=16) kto_trainer.save_model(new_model)

File ~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:599, in KTOTrainer.__init__(self, model, ref_model, args, train_dataset, eval_dataset, tokenizer, data_collator, model_init, callbacks, optimizers, preprocess_logits_for_metrics, peft_config, compute_metrics, model_adapter_name, ref_adapter_name)
    [597](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:597) des_weight_lower_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1, 2)
    [598](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:598) des_weight_upper_bound = round((len(undesirable) * self.undesirable_weight / len(desirable)) * 1.33, 2)
--> [599](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:599) und_weight_lower_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1.33, 2)
    [600](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:600) und_weight_upper_bound = round((len(desirable) * self.desirable_weight / len(undesirable)) / 1, 2)
    [602](https://vscode-remote+ssh-002dremote-002b10-002e35-002e14-002e201.vscode-resource.vscode-cdn.net/home/paul/Documents/llm-nids/~/.local/lib/python3.10/site-packages/trl/trainer/kto_trainer.py:602) des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound

ZeroDivisionError: float division by zero

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions