Description
The current handling of `response_masks` inside the `batch_forward_pass` function does not take padding into consideration, which results in a shape mismatch during masking. I think the response mask should not be concatenated with a `torch.zeros(query_length)` tensor, and the masking operation should be done without slicing.
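For illustration, here is a minimal sketch of the concatenation pattern described above (variable names and values are hypothetical; this is not the library's actual code):

```python
import torch

# Hypothetical sketch of the described handling, not the library's code:
# zeros are prepended for the query tokens before masking.
query_length = 10
response_mask = torch.ones(9)  # a response of length 9

# The per-sample mask becomes query_length + response_length long.
combined = torch.cat((torch.zeros(query_length), response_mask))
print(combined.shape[0])  # 19
```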
An example with a batch size of 2:
- The first sample in the batch has a query of size 10 and a response of size 9 (the response mask also has size 9).
- The second sample in the batch has a query of size 10 and a response of size 5 (the response mask also has size 5).
With the concatenation, `response_masks_batch[1]` has size 15, while `start` will be 14 for the second sample (due to the padding) and `end` will be 19.
Hence, the operation

```python
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
```

will yield `RuntimeError: The size of tensor a (5) must match the size of tensor b (1) at non-singleton dimension 0`, since `response_masks_batch[1][14:19]` is the same as `response_masks_batch[1][14:15]`, which has length 1.
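The mismatch can be reproduced in isolation (a standalone sketch using the sizes from the example above, not the library's code):

```python
import torch

# Second sample: query length 10, response length 5; the batch is
# padded, so start/end computed on it are 14 and 19.
padded_mask = torch.cat((torch.zeros(10), torch.ones(5)))  # length 15
start, end = 14, 19

# Slicing the length-15 mask with [14:19] clips to a single element,
# while masks[j, 14:19] on the padded batch has 5 elements.
print(padded_mask[start:end].shape[0])  # 1
```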
Removing the concatenation of the response mask, and removing the slicing of the response mask (since the response mask already has length `end - start`, which is equal to the length of `masks[j, start:end]`), should be a better approach.
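Under that change, the update would use the unsliced response mask directly; a hedged sketch with the example's sizes (the mask values here are made up):

```python
import torch

# Proposed approach, sketched: the response mask already has
# end - start elements, so no concatenation or slicing is needed.
masks = torch.ones(19)
response_mask = torch.tensor([1., 1., 1., 0., 0.])  # hypothetical values
start, end = 14, 19
masks[start:end] = masks[start:end] * response_mask  # both sides are (5,)
```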