
IterableDataset's state_dict shard_example_idx is always equal to the number of samples in a shard #7475

@bruno-hays


Describe the bug

I've noticed strange behaviour with IterableDataset.state_dict(): the value of shard_example_idx is always equal to the number of samples in a shard, rather than the number of examples consumed so far.

Steps to reproduce the bug

I am reusing the example from the documentation:

from datasets import Dataset

ds = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=1)
state_dict = None
# Iterate through the dataset and print examples
for idx, example in enumerate(ds):
    print(example)
    if idx == 2:
        state_dict = ds.state_dict()
        print("checkpoint")
        break
print(state_dict)

Returns:

{'a': 0}
{'a': 1}
checkpoint
{'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 6, 'type': 'ArrowExamplesIterable'}, 'epoch': 0}
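One hypothesis consistent with this output (an assumption, not verified against the datasets source) is that the position is recorded at batch granularity: if ArrowExamplesIterable advanced shard_example_idx by a whole record batch before yielding that batch's rows, a checkpoint taken mid-batch would report the end of the batch. With 6 rows read as one batch that gives 6; with num_shards=2 and one batch per shard it gives 3. The toy below only models that symptom; BatchTrackedIterable and checkpoint_after are hypothetical names, not part of datasets.

```python
from typing import Iterator


class BatchTrackedIterable:
    """Hypothetical toy, NOT the datasets implementation: it records its
    position only at batch granularity."""

    def __init__(self, rows: list[int], batch_size: int) -> None:
        self.rows = rows
        self.batch_size = batch_size
        self._state = {"shard_idx": 0, "shard_example_idx": 0}

    def state_dict(self) -> dict:
        # Return a copy so callers cannot mutate the live state.
        return dict(self._state)

    def __iter__(self) -> Iterator[dict]:
        self._state = {"shard_idx": 0, "shard_example_idx": 0}
        for start in range(0, len(self.rows), self.batch_size):
            batch = self.rows[start : start + self.batch_size]
            # Assumed behaviour: the position jumps to the end of the batch
            # as soon as the batch is read, before any of its rows is yielded.
            self._state["shard_example_idx"] = start + len(batch)
            for value in batch:
                yield {"a": value}


def checkpoint_after(ds, n: int) -> dict:
    """Consume n examples from ds, then return its state_dict."""
    it = iter(ds)
    for _ in range(n):
        next(it)
    return ds.state_dict()


# One shard of 6 rows read as a single batch, checkpointed after 2 examples.
print(checkpoint_after(BatchTrackedIterable(list(range(6)), batch_size=6), 2))
# {'shard_idx': 0, 'shard_example_idx': 6}
```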

Expected behavior

shard_example_idx should be 2 instead of 6.
If we run with num_shards=2, shard_example_idx is 3 (the size of each shard) instead of 2, and so on.
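For comparison, the expected behaviour can be modelled by updating the state once per example, before the example is handed to the caller. This is a hypothetical sketch: ExampleTrackedIterable and checkpoint_after are illustrative names, not datasets APIs, and it assumes the 6 rows are split into contiguous, near-equal shards.

```python
from typing import Iterator


class ExampleTrackedIterable:
    """Hypothetical sharded iterable that advances its state once per
    example, i.e. the behaviour the report expects."""

    def __init__(self, rows: list[int], num_shards: int) -> None:
        # Assumption: rows are split into contiguous, near-equal shards.
        size, extra = divmod(len(rows), num_shards)
        self.shards: list[list[int]] = []
        start = 0
        for i in range(num_shards):
            end = start + size + (1 if i < extra else 0)
            self.shards.append(rows[start:end])
            start = end
        self._state = {"shard_idx": 0, "shard_example_idx": 0}

    def state_dict(self) -> dict:
        # Return a copy so callers cannot mutate the live state.
        return dict(self._state)

    def __iter__(self) -> Iterator[dict]:
        self._state = {"shard_idx": 0, "shard_example_idx": 0}
        for shard_idx, shard in enumerate(self.shards):
            for example_idx, value in enumerate(shard):
                # Update before yielding so a checkpoint taken right after
                # receiving this example already counts it.
                self._state = {"shard_idx": shard_idx, "shard_example_idx": example_idx + 1}
                yield {"a": value}


def checkpoint_after(ds, n: int) -> dict:
    """Consume n examples from ds, then return its state_dict."""
    it = iter(ds)
    for _ in range(n):
        next(it)
    return ds.state_dict()


for num_shards in (1, 2):
    ds = ExampleTrackedIterable(list(range(6)), num_shards=num_shards)
    print(num_shards, checkpoint_after(ds, 2))
# 1 {'shard_idx': 0, 'shard_example_idx': 2}
# 2 {'shard_idx': 0, 'shard_example_idx': 2}
```

Here the count is independent of shard size: after two examples the checkpoint reports 2 whether there is one shard or two.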

Environment info

  • datasets version: 3.4.1
  • Platform: macOS-14.6.1-arm64-arm-64bit
  • Python version: 3.12.9
  • huggingface_hub version: 0.29.3
  • PyArrow version: 19.0.1
  • Pandas version: 2.2.3
  • fsspec version: 2024.12.0
