
Conversation

@enkilee (Contributor) commented on Jul 1, 2024

PR Category

User Experience

PR Types

Improvements

Description

paddle-bot (bot) commented on Jul 1, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Jul 1, 2024
@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Jul 1, 2024
@@ -382,25 +397,37 @@ class DataLoader:
please see :code:`paddle.io.IterableDataset`
"""

return_list: bool
collate_fn: default_collate_fn | None
Contributor:

This needs to be a Callable ~ you can refer to the input/output types of default_collate_fn below; it looks fairly complicated ... ...

Contributor Author:

OK, I'll try to revise it.

collate_fn: default_collate_fn | None
use_buffer_reader: bool
prefetch_factor: int
worker_init_fn: Callable[..., Any] | None
Contributor:

I looked at this test case for reference:

class TestDynamicDataLoaderIterInitFuncSplit(unittest.TestCase):
    def test_main(self):
        place = base.CPUPlace()
        with base.dygraph.guard(place):
            dataset = RangeIterableDataset(0, 10)

            def worker_spliter(worker_id):
                worker_info = get_worker_info()

                dataset = worker_info.dataset
                start = dataset.start
                end = dataset.end
                num_per_worker = int(
                    math.ceil((end - start) / float(worker_info.num_workers))
                )

                worker_id = worker_info.id
                dataset.start = start + worker_id * num_per_worker
                dataset.end = min(dataset.start + num_per_worker, end)

            dataloader = DataLoader(
                dataset,
                places=place,
                num_workers=1,
                batch_size=1,
                drop_last=True,
                worker_init_fn=worker_spliter,
            )

            rets = []
            for d in dataloader:
                rets.append(d.numpy()[0][0])

            assert tuple(sorted(rets)) == tuple(range(0, 10))

It looks like this could be Callable[[int], None] here.
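For reference, a minimal sketch of how that annotation could read (a hypothetical excerpt for illustration, not the merged stub):

from __future__ import annotations

from typing import Callable


class DataLoader:  # hypothetical annotation excerpt only
    # worker_init_fn receives the worker id (an int) and returns nothing,
    # matching worker_spliter in the test case above.
    worker_init_fn: Callable[[int], None] | None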

Contributor Author:

Got it.

prefetch_factor: int
worker_init_fn: Callable[..., Any] | None
dataset: Dataset
feed_list: list[Tensor] | tuple[Tensor]
Contributor:

Could Sequence be used here?

Contributor Author:

Got it, thanks.

worker_init_fn: Callable[..., Any] | None
dataset: Dataset
feed_list: list[Tensor] | tuple[Tensor]
places: list[Place] | tuple[Place] | list[str]
Contributor:

How about PlaceLike from paddle._typing? i.e. Sequence[PlaceLike]

Paddle's Place and CPUPlace have no inheritance relationship, so Place cannot be used directly.

p.s. I only discovered this recently myself 🤣
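For readers unfamiliar with it, a simplified sketch of what a PlaceLike-style alias looks like (an assumption for illustration; the real alias lives in paddle._typing and covers more device types):

from __future__ import annotations

from typing import Union

import paddle

# Simplified, assumed sketch of a PlaceLike-style alias: the concrete place
# classes share no common base with Place, so a union (including device
# strings such as "cpu" or "gpu:0") is used instead of Place alone.
PlaceLike = Union[paddle.CPUPlace, paddle.CUDAPlace, str]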

Contributor Author:

Got it, thanks.

import numpy.typing as npt

from paddle import Tensor
from python.paddle._typing.device_like import PlaceLike
Contributor:

Suggested change
from python.paddle._typing.device_like import PlaceLike
from python.paddle._typing import PlaceLike

@@ -382,25 +404,37 @@ class DataLoader:
please see :code:`paddle.io.IterableDataset`
"""

return_list: bool
collate_fn: Callable[[_Collate_Fn_State], None]
Contributor:

This one is indeed a bit tricky; I gave it a try:

from __future__ import annotations

import numbers
import numpy as np
import numpy.typing as npt
from typing import Protocol, overload, Sequence, Mapping, Any, TYPE_CHECKING, AnyStr, TypeVar

import paddle

if TYPE_CHECKING:
    from paddle import Tensor

    KT = TypeVar('KT')
    VT = TypeVar('VT')

@overload
def default_collate_fn(batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]) -> npt.NDArray[Any]: ...
@overload
def default_collate_fn(batch: Sequence[Tensor]) -> Tensor: ...
@overload
def default_collate_fn(batch: Sequence[AnyStr]) -> AnyStr: ...
@overload
def default_collate_fn(batch: Sequence[Mapping[KT, VT]]) -> Mapping[KT, VT]: ...
@overload
def default_collate_fn(batch: Sequence[Sequence[VT]]) -> Sequence[VT]: ...
def default_collate_fn(batch):
    sample = batch[0]
    if isinstance(sample, np.ndarray):
        batch = np.stack(batch, axis=0)
        return batch
    elif isinstance(sample, paddle.Tensor):
        return paddle.stack(batch, axis=0)
    elif isinstance(sample, numbers.Number):
        batch = np.array(batch)
        return batch
    elif isinstance(sample, (str, bytes)):
        return batch
    elif isinstance(sample, Mapping):
        return {
            key: default_collate_fn([d[key] for d in batch]) for key in sample
        }
    elif isinstance(sample, Sequence):
        sample_fields_num = len(sample)
        if not all(len(sample) == sample_fields_num for sample in iter(batch)):
            raise RuntimeError(
                "fields number not same among samples in a batch"
            )
        return [default_collate_fn(fields) for fields in zip(*batch)]

    raise TypeError(
        "batch data con only contains: tensor, numpy.ndarray, "
        f"dict, list, number, but got {type(sample)}"
    )


class _Collate_Fn(Protocol):
    @overload
    def __call__(self, batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]) -> npt.NDArray[Any]: ...
    @overload
    def __call__(self, batch: Sequence[Tensor]) -> Tensor: ...
    @overload
    def __call__(self, batch: Sequence[AnyStr]) -> AnyStr: ...
    @overload
    def __call__(self, batch: Sequence[Mapping[KT, VT]]) -> Mapping[KT, VT]: ...
    @overload
    def __call__(self, batch: Sequence[Sequence[VT]]) -> Sequence[VT]: ...

fn: _Collate_Fn = default_collate_fn

@SigureMo could you help take a look at this _Collate_Fn Protocol and see whether it is feasible?

Contributor Author:

Does this overload default_collate_fn again here? Does it have to be duplicated like this?

Member:

At a glance this looks fine; if the checks pass, it is fine. There are only a few minor style issues:

  • _Collate_Fn -> _CollateFn
  • KT -> _K, VT -> _V; K and V, like T, are commonly used generic type variable names

As for the _CollateFn type having to be rewritten as a Protocol: there is no way around it. I complained about this just recently, but the duplicated code simply has to be written twice (paddle.jit.save does the same thing).

Contributor:

> Does this overload default_collate_fn again here? Does it have to be duplicated like this?

Sorry, I only just saw the comment here ... ...

My code was just an example: the _CollateFn can be derived from default_collate_fn ~

Only _CollateFn needs to be kept ~ there is no need to write default_collate_fn out again ~~~ 🤣🤣🤣
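Condensed, the end state suggested here would look roughly like this (a sketch with most overloads elided; names follow the style suggestions above, and the DataLoader excerpt is hypothetical):

from __future__ import annotations

import numbers
from typing import Any, Protocol, Sequence, overload

import numpy.typing as npt

from paddle import Tensor


class _CollateFn(Protocol):
    # Sketch only: the remaining overloads (str/bytes, Mapping, nested
    # Sequence) mirror default_collate_fn and are elided here.
    @overload
    def __call__(
        self, batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number]
    ) -> npt.NDArray[Any]: ...
    @overload
    def __call__(self, batch: Sequence[Tensor]) -> Tensor: ...


class DataLoader:  # hypothetical annotation excerpt only
    collate_fn: _CollateFn | None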

worker_init_fn: Callable[[int], None]
dataset: Dataset
feed_list: Sequence[Tensor] | None
places: Sequence[PlaceLike] | list[str] | None
Contributor:

Suggested change
places: Sequence[PlaceLike] | list[str] | None
places: Sequence[PlaceLike] | None

list[str] should not be needed anymore ~ PlaceLike already includes str ~

Contributor:

In addition, DataLoader's

    def __len__(self):

    def __iter__(self):

    def __call__(self):

need to be annotated as well ~ otherwise the type may not be determinable in for loops and similar places ~ (a sketch follows below)
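A hedged sketch of what those annotations could look like; the yielded item type here is an assumption for illustration, not the actual stub:

from __future__ import annotations

from typing import Any, Iterator


class DataLoader:  # hypothetical annotation excerpt only
    def __len__(self) -> int: ...

    # The element type yielded per step is assumed to be Any here; the real
    # stub would pin down what a batch actually looks like.
    def __iter__(self) -> Iterator[Any]: ...

    def __call__(self) -> Iterator[Any]: ...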


from paddle import Tensor
from python.paddle._typing.device_like import PlaceLike
from python.paddle.io.dataloader.dataset import Dataset
Member:

Where does the `python.` prefix come from?


batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
collate_fn: Callable[[_CollateFn], None] = None,
Contributor:

(Shouldn't this be _CollateFn | None?)

@megemini (Contributor) commented on Jul 8, 2024:

collate_fn: Callable[[_CollateFn], None] | None = None,

collate_fn: _CollateFn | None = None

Comment on lines 43 to 97
@overload
def default_collate_fn(
    batch: Sequence[npt.NDArray[Any]] | Sequence[numbers.Number],
) -> npt.NDArray[Any]:
    ...


@overload
def default_collate_fn(batch: Sequence[Tensor]) -> Tensor:
    ...


@overload
def default_collate_fn(batch: Sequence[AnyStr]) -> AnyStr:
    ...


@overload
def default_collate_fn(batch: Sequence[Mapping[_K, _V]]) -> Mapping[_K, _V]:
    ...


@overload
def default_collate_fn(batch: Sequence[Sequence[_V]]) -> Sequence[_V]:
    ...


def default_collate_fn(batch):
    sample = batch[0]
    if isinstance(sample, np.ndarray):
        batch = np.stack(batch, axis=0)
        return batch
    elif isinstance(sample, paddle.Tensor):
        return paddle.stack(batch, axis=0)
    elif isinstance(sample, numbers.Number):
        batch = np.array(batch)
        return batch
    elif isinstance(sample, (str, bytes)):
        return batch
    elif isinstance(sample, Mapping):
        return {
            key: default_collate_fn([d[key] for d in batch]) for key in sample
        }
    elif isinstance(sample, Sequence):
        sample_fields_num = len(sample)
        if not all(len(sample) == sample_fields_num for sample in iter(batch)):
            raise RuntimeError(
                "fields number not same among samples in a batch"
            )
        return [default_collate_fn(fields) for fields in zip(*batch)]

    raise TypeError(
        "batch data con only contains: tensor, numpy.ndarray, "
        f"dict, list, number, but got {type(sample)}"
    )
@megemini (Contributor) commented on Jul 8, 2024:

No need to write this out again here; if it is needed, just write it in the file where this function lives ~ 🤣

Update

This function is in python/paddle/io/dataloader/collate.py ~ it is not a public API, so let's skip annotating it for now ~

Contributor Author:

Got it.

...


fn: _CollateFn = default_collate_fn
Contributor:

This one is not needed either ~ just keep the _CollateFn above ~

Contributor Author:

Got it.

@megemini (Contributor) left a comment:

I have not found any other issues for now ~ thanks for the hard work ~ 🤟🤟🤟

@SigureMo (Member) left a comment:

LGTMeow 🐾

@SigureMo SigureMo merged commit 475deec into PaddlePaddle:develop Jul 13, 2024
lixcli pushed a commit to lixcli/Paddle that referenced this pull request Jul 22, 2024
@enkilee enkilee deleted the typing-b01 branch January 8, 2025 08:04