
Slow iteration for iterable dataset with numpy formatting for array data #7206

@alex-hh

Describe the bug

When an iterable dataset contains large array features, setting a format such as `numpy` with `with_format` and then applying `map` makes iteration significantly slower.

Steps to reproduce the bug

```python
import time
import numpy as np
from datasets import Dataset, Features, Array3D

# Two float32 columns of (N, 10, 10) arrays, with N alternating 2000/1000 over 50 rows
features = Features({"array0": Array3D((None, 10, 10), dtype="float32"),
                     "array1": Array3D((None, 10, 10), dtype="float32")})
dataset = Dataset.from_dict({f"array{i}": [np.zeros((x, 10, 10), dtype=np.float32) for x in [2000, 1000] * 25]
                             for i in range(2)}, features=features)
```
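For scale, the repro dataset has few rows but a lot of bytes. A back-of-the-envelope sketch of its size, assuming float32 (4 bytes per element) and exactly the shapes used above:

```python
# First-dimension sizes used in the repro: [2000, 1000] repeated 25 times (50 rows)
first_dims = [2000, 1000] * 25
inner = 10 * 10          # trailing (10, 10) dimensions of each Array3D value
bytes_per_float32 = 4
num_columns = 2          # "array0" and "array1"

rows = len(first_dims)
elements_per_column = sum(d * inner for d in first_dims)
total_bytes = elements_per_column * bytes_per_float32 * num_columns

print(rows)                 # 50
print(elements_per_column)  # 7500000
print(total_bytes / 1e6)    # 60.0 (MB of raw array data)
```

So each row carries roughly 0.4 to 0.8 MB per column, which is why per-example overhead in the formatting or `map` path can dominate iteration time.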

Then

```python
# Identity map on a numpy-formatted iterable dataset
ds = dataset.to_iterable_dataset()
ds = ds.with_format("numpy").map(lambda x: x)
t0 = time.time()
for ex in ds:
    pass
t1 = time.time()
print(t1 - t0)
```

takes 27 s, whereas

```python
# Same numpy formatting, without map
ds = dataset.to_iterable_dataset()
ds = ds.with_format("numpy")
t0 = time.time()
for ex in ds:
    pass
t1 = time.time()
print(t1 - t0)
```

takes ~1s
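To compare the two pipelines more directly, a small timing helper could be used. This is a minimal sketch; `time_iteration` and `slowdown` are hypothetical helpers (not part of `datasets`), and `time.perf_counter` is swapped in for `time.time` because it is intended for measuring elapsed intervals:

```python
import time
from typing import Any, Iterable, Tuple


def time_iteration(iterable: Iterable[Any]) -> Tuple[float, int]:
    """Exhaust `iterable`; return (elapsed seconds, number of items seen)."""
    count = 0
    start = time.perf_counter()
    for _ in iterable:
        count += 1
    return time.perf_counter() - start, count


def slowdown(mapped_seconds: float, baseline_seconds: float) -> float:
    """Ratio of the mapped pipeline's time to the baseline pipeline's time."""
    return mapped_seconds / baseline_seconds


# With the timings reported above (27 s with map, ~1 s without):
print(slowdown(27.0, 1.0))  # 27.0

# Possible usage with the repro, assuming `dataset` from above:
# base = dataset.to_iterable_dataset().with_format("numpy")
# mapped = dataset.to_iterable_dataset().with_format("numpy").map(lambda x: x)
# (t_base, n), (t_map, _) = time_iteration(base), time_iteration(mapped)
# print(n, slowdown(t_map, t_base))
```

Counting the items also confirms that both pipelines yield the same 50 examples, so the ratio compares like with like.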

Expected behavior

Applying `map`, even with an identity function, should not introduce a significant slowdown when formatting is enabled.

Environment info

`datasets` version: 3.0.2
