-
Notifications
You must be signed in to change notification settings - Fork 668
Closed
Description
There is probably a typo in the following example:
# Example: construct a vision transform, build a multimodal chat dataset from a
# local JSON file, then decode the first sample's tokens and print the shape of
# its first image tensor.
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset
# NOTE(review): `Llama3VisionTransform` is not imported in this snippet; only
# `llama3_2_vision_transform` is imported above — confirm which name is intended.
transform = Llama3VisionTransform(
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
prompt_template="torchtune.data.QuestionAnswerTemplate",
max_seq_len=8192,
image_size=560,
)
ds = multimodal_chat_dataset(
# NOTE(review): `model_transform` is never defined in this snippet; the object
# created above is bound to `transform`, so this value raises NameError as written.
model_transform=model_transform,
source="json",
data_files="data/my_data.json",
# Maps the dataset's expected column names (keys) to the column names used in
# the JSON file (values) — presumably; verify direction against the API docs.
column_map={
"dialogue": "conversations",
"image_path": "image",
},
image_dir="/home/user/dataset/", # /home/user/dataset/images/clock.jpg
image_tag="<image>",
split="train",
)
# Fetch the first processed sample; it is indexed below by "tokens" and
# "encoder_input".
tokenized_dict = ds[0]
# Decode with special tokens kept so the chat/header markers are visible.
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion:<|image|>What time is it on the clock?Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00AM.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])
Shouldn't the argument be `model_transform=transform` rather than `model_transform=model_transform`, since only `transform` is defined in the example?
Metadata
Metadata
Assignees
Labels
No labels