
sft Example not working #1693

@TATOAO

Description

Current version:
commit 84156f1 (HEAD -> main, origin/main, origin/HEAD)
updated on Jun 3 2024

I tried to run the following command for examples/scripts/sft.py:

# regular:
python examples/scripts/sft.py \
    --model_name_or_path="facebook/opt-350m" \
    --report_to="wandb" \
    --learning_rate=1.41e-5 \
    --per_device_train_batch_size=64 \
    --gradient_accumulation_steps=16 \
    --output_dir="sft_openassistant-guanaco" \
    --logging_steps=1 \
    --num_train_epochs=3 \
    --max_steps=-1 \
    --push_to_hub \
    --gradient_checkpointing

Error message:

Map:   0%|                                                                                                                                                                                                 | 0/9846 [00:00<?, ? examples/s]
Traceback (most recent call last):
  File "/Users/tatoaoliang/Downloads/Work/trl/examples/scripts/sft.py", line 137, in <module>
    trainer = SFTTrainer(
              ^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 360, in __init__
    train_dataset = self._prepare_dataset(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 506, in _prepare_dataset
    return self._prepare_non_packed_dataloader(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 574, in _prepare_non_packed_dataloader
    tokenized_dataset = dataset.map(
                        ^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 602, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 567, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3156, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3548, in _map_single
    batch = apply_function_on_filtered_inputs(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/arrow_dataset.py", line 3417, in apply_function_on_filtered_inputs
    processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/trl/trl/trainer/sft_trainer.py", line 545, in tokenize
    element[dataset_text_field] if not use_formatting_func else formatting_func(element),
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/Users/tatoaoliang/Downloads/Work/virv/llama/lib/python3.11/site-packages/datasets/formatting/formatting.py", line 271, in __getitem__
    value = self.data[key]
            ~~~~~~~~~^^^^^
KeyError: None

I checked the code; here is the relevant snippet of the _prepare_non_packed_dataloader function in trl/trainer/sft_trainer.py, starting at line 529:

    def _prepare_non_packed_dataloader(
        self,
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        #### debugger told me that formatting_func is None and dataset_text_field is None
        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        #### so use_formatting_func is False
        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,
                return_length=False,
            )

So the KeyError comes from element[dataset_text_field]: both dataset_text_field and formatting_func are None, so use_formatting_func is False and tokenize indexes the row with a None key. For this dataset, formatting_func (or dataset_text_field) should not be None.
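A minimal sketch of the failing branch (the element dict here is a hypothetical stand-in for the LazyRow that datasets passes to tokenize):

element = {"text": "### Human: Hello### Assistant: Hi there"}

dataset_text_field = None
formatting_func = None
use_formatting_func = formatting_func is not None and dataset_text_field is None  # False

# Falls into the left branch and indexes the row with a None key:
element[dataset_text_field] if not use_formatting_func else formatting_func(element)
# -> KeyError: None, the same error as in the traceback above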

formatting_func itself is set in sft_trainer.py, line 313:

formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)

and get_formatting_func_from_dataset is defined in trl/extras/dataset_formatting.py, line 60:

def get_formatting_func_from_dataset(
    dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer
) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}]
    - `instruction` with [{"prompt": str, "completion": str}]

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if isinstance(dataset, Dataset):
        if "messages" in dataset.features:
            if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "messages")
        if "conversations" in dataset.features:
            if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
                logging.info("Formatting dataset with chatml format")
                return conversations_formatting_function(tokenizer, "conversations")
        elif dataset.features == FORMAT_MAPPING["instruction"]:
            logging.info("Formatting dataset with instruction format")
            return instructions_formatting_function(tokenizer)

    return None
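A quick check (a sketch, assuming the import path matches the file above) confirms that the guanaco dataset matches none of the supported formats, so this function returns None:

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.extras.dataset_formatting import get_formatting_func_from_dataset

ds = load_dataset("timdettmers/openassistant-guanaco", split="train")
print(ds.features)  # only a single "text" column of type string

tok = AutoTokenizer.from_pretrained("facebook/opt-350m")
# No "messages" or "conversations" column and no instruction schema, so this prints None:
print(get_formatting_func_from_dataset(ds, tok))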

But the openassistant-guanaco dataset only has a "text" feature, so it matches none of these formats, get_formatting_func_from_dataset returns None, and the tokenize call above fails.

https://huggingface.co/datasets/timdettmers/openassistant-guanaco?row=0
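A possible workaround (a sketch I have not verified against this exact commit; depending on the TRL version, dataset_text_field is passed directly to SFTTrainer or via SFTConfig, and the example script may expose it as --dataset_text_field="text") is to point the trainer at the "text" column explicitly instead of relying on the automatic format detection:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
train_dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",   # avoids the KeyError: None above
    max_seq_length=512,
    args=TrainingArguments(output_dir="sft_openassistant-guanaco"),
)
trainer.train()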
