I might be misunderstanding something here, but it looks like `PPOTrainer` always sets `self.data_collator` to `DataCollatorForLanguageModeling` (see this permalink), and so uses a user-provided data collator inconsistently.
```python
def __init__(
    self,
    config: Optional[PPOConfig] = None,
    model: Optional[PreTrainedModelWrapper] = None,
    ref_model: Optional[PreTrainedModelWrapper] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    data_collator: Optional[typing.Callable] = None,  # <--- takes a data collator as a kwarg here
    num_shared_layers: Optional[int] = None,
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
```
It uses the provided `data_collator` if a `dataset` is provided...
```python
if self.dataset is not None:
    self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
```
...but ignores it when assigning `self.data_collator`:
```python
# Step 3: Initialize optimizer and data collator
self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
if optimizer is None:
    self.optimizer = Adam(
        filter(lambda p: p.requires_grad, self.model.parameters()),
        lr=self.config.learning_rate,
    )
else:
    self.optimizer = optimizer
```
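For what it's worth, here is a minimal sketch (untested; assumes a local `gpt2` checkpoint, and `my_collator` is just a hypothetical pass-through collator for illustration) of how the mismatch would show up:

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

def my_collator(features):
    # hypothetical pass-through collator
    return {key: [f[key] for f in features] for key in features[0]}

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

trainer = PPOTrainer(
    config=PPOConfig(),
    model=model,
    tokenizer=tokenizer,
    data_collator=my_collator,  # accepted as a kwarg...
)
# ...but self.data_collator is still DataCollatorForLanguageModeling
print(trainer.data_collator is my_collator)  # False
```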
This `self.data_collator` is the one used in methods like `prepare_model_inputs`:
```python
def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
    if self.is_encoder_decoder:
        input_data = self.data_collator(
            [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
        ).to(self.current_device)

        decoder_inputs = self.data_collator(
            [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
        ).to(self.current_device)

        input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
        input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
    else:
        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
        input_data = self.data_collator(
            [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
        ).to(self.current_device)

    input_data.pop("labels", None)  # we don't want to compute LM losses
    return input_data
```
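If I'm reading this right, one possible fix would be to fall back to the default collator only when none is passed. A sketch, not a tested patch:

```python
# Step 3: Initialize optimizer and data collator
# Respect a user-provided collator; only fall back to
# DataCollatorForLanguageModeling when none is given.
if data_collator is not None:
    self.data_collator = data_collator
else:
    self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
```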