
86kkd (Contributor) commented Jun 23, 2024

PR Category

User Experience

PR Types

Improvements

Description

Type annotations:

  • python/paddle/vision/datasets/cifar.py

Related links

@SigureMo @megemini


paddle-bot bot commented Jun 23, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

megemini (Contributor) left a comment

@86kkd For these files, you can first take a look at https://github.com/cattidea/paddlepaddle-stubs/tree/main/paddle-stubs/vision/datasets

The main thing needed is to add class attributes, for example:

from typing import Any

from paddle.io import Dataset


class Cifar10(Dataset):
    mode: Any = ...
    backend: Any = ...
    data_file: Any = ...
    transform: Any = ...
    dtype: Any = ...

However, `Any` can't be used here; the concrete types still need to be worked out.

Also, @SigureMo, could we add something like the following to `python/paddle/vision/transforms/transforms.py`:
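As a rough idea of what could replace `Any`, the hypothetical class below narrows `mode` and `backend` to `Literal` types and validates them at construction. The allowed values follow the options documented for `paddle.vision.datasets.Cifar10` (`'train'`/`'test'`, `'pil'`/`'cv2'`); the names `Cifar10Sketch`, `_Mode`, `_Backend`, the `'pil'` fallback, the `'float32'` default, and the `Callable[[Any], Any]` placeholder for `transform` are all assumptions for illustration, not Paddle's actual code.

```python
from __future__ import annotations

from typing import Any, Callable, Literal, cast, get_args

# Hypothetical aliases; the allowed values mirror the options documented for
# paddle.vision.datasets.Cifar10 (mode: 'train'/'test', backend: 'pil'/'cv2').
_Mode = Literal["train", "test"]
_Backend = Literal["pil", "cv2"]


class Cifar10Sketch:
    """Hypothetical stand-in for Cifar10 with precise attribute types."""

    mode: _Mode
    backend: _Backend
    data_file: str | None
    transform: Callable[[Any], Any] | None  # placeholder for a real transform type
    dtype: str

    def __init__(
        self,
        data_file: str | None = None,
        mode: str = "train",
        transform: Callable[[Any], Any] | None = None,
        backend: str | None = None,
        dtype: str = "float32",  # assumed default, for illustration only
    ) -> None:
        mode = mode.lower()
        if mode not in get_args(_Mode):
            raise ValueError(f"mode must be one of {get_args(_Mode)}, got {mode!r}")
        # Assumed fallback: use 'pil' when no backend is given.
        backend = "pil" if backend is None else backend.lower()
        if backend not in get_args(_Backend):
            raise ValueError(
                f"backend must be one of {get_args(_Backend)}, got {backend!r}"
            )
        # The membership checks above justify narrowing str to the Literal types.
        self.mode = cast(_Mode, mode)
        self.backend = cast(_Backend, backend)
        self.data_file = data_file
        self.transform = transform
        self.dtype = dtype


ds = Cifar10Sketch(mode="TEST")
print(ds.mode, ds.backend)  # test pil
```

Deriving the runtime check from the `Literal` alias via `get_args` keeps the accepted values and the type annotation in one place.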

from __future__ import annotations

from typing import Any, Callable, Sequence, Tuple, TypeVar, Union, overload

import numpy as np
import numpy.typing as npt


class Tensor: ...


class PILImage: ...


_DataT = TypeVar("_DataT", Tensor, PILImage, npt.NDArray[Any])


# ---------- transform ----------
_TransformTensor = Callable[[Tensor], Tensor]
_TransformTensors = Callable[[Tuple[Tensor, ...]], Tuple[Tensor, ...]]
_TransformPILImage = Callable[[PILImage], PILImage]
_TransformPILImages = Callable[[Tuple[PILImage, ...]], Tuple[PILImage, ...]]
_TransformNDArray = Callable[[npt.NDArray[Any]], npt.NDArray[Any]]
_TransformNDArrays = Callable[
    [Tuple[npt.NDArray[Any], ...]], Tuple[npt.NDArray[Any], ...]
]

_Transform = Union[
    _TransformTensor,
    _TransformTensors,
    _TransformPILImage,
    _TransformPILImages,
    _TransformNDArray,
    _TransformNDArrays,
]
# ---------- transform end ----------


class BaseTransform:
    def __init__(self, keys=None):
        self.keys = keys

    @overload
    def __call__(self, inputs: _DataT) -> _DataT: ...

    @overload
    def __call__(self, inputs: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...

    def __call__(self, inputs):
        if isinstance(inputs, (Tensor, PILImage, np.ndarray)):
            return inputs
        return tuple(inputs)


class Resize(BaseTransform):
    def __init__(self, size, interpolation="bilinear", keys=None): ...

    @overload
    def __call__(self, inputs: _DataT) -> _DataT: ...

    @overload
    def __call__(self, inputs: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...

    def __call__(self, inputs):
        if isinstance(inputs, (Tensor, PILImage, np.ndarray)):
            return inputs
        return tuple(inputs)


class Compose:
    transforms: Sequence[BaseTransform]

    def __init__(self, transforms: Sequence[BaseTransform]) -> None:
        self.transforms = transforms

    @overload
    def __call__(self, data: _DataT) -> _DataT: ...

    @overload
    def __call__(self, data: tuple[_DataT, ...]) -> tuple[_DataT, ...]: ...

    def __call__(self, data):
        if isinstance(data, (Tensor, PILImage, np.ndarray)):
            return data
        return tuple(data)


class Cifar10:
    transform: _Transform

    def __init__(self, transform: _Transform) -> None:
        self.transform = transform


t1 = Resize(1)
cifar = Cifar10(transform=t1)  # OK for a static type checker

t2 = Compose([Resize(1), Resize(2)])
cifar = Cifar10(transform=t2)  # OK for a static type checker


def trans_pass(a: Tensor) -> Tensor:
    return Tensor()


def trans_fail(a: Tensor) -> tuple[Tensor, Tensor]:
    return Tensor(), Tensor()


cifar = Cifar10(transform=trans_pass)  # OK for a static type checker
cifar = Cifar10(transform=trans_fail)  # a static type checker (e.g. mypy) should reject this; it still runs

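Add the definitions in the `transform` section above (between the `# ---------- transform ----------` markers), so that all of these places in `datasets` can use `_Transform`.

Actually, I personally think `python/paddle/vision/transforms/transforms.py` should have been designed around a `Protocol` from the start. Because that `Protocol` is missing, the `__call__` signatures of `BaseTransform` and `Compose` are inconsistent (their input parameter names differ).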
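To illustrate the point about a missing `Protocol`, here is a minimal sketch assuming a hypothetical `TransformLike` protocol whose `__call__` parameter is positional-only (the `/` marker, Python 3.8+). Because callers can never pass that argument by keyword, a static checker accepts implementations whether they name the parameter `inputs` (like `BaseTransform`) or `data` (like `Compose`). `AddOne`, `Double`, and `ComposeLike` are toy stand-ins, not Paddle classes.

```python
from __future__ import annotations

from typing import Any, Protocol, Sequence, runtime_checkable


@runtime_checkable
class TransformLike(Protocol):
    """Hypothetical protocol for anything usable as a transform."""

    # "/" makes `data` positional-only, so implementations may name their
    # parameter `inputs`, `data`, or anything else and still conform.
    def __call__(self, data: Any, /) -> Any: ...


class AddOne:
    # Named like BaseTransform.__call__: the parameter is `inputs`.
    def __call__(self, inputs: Any) -> Any:
        return inputs + 1


class Double:
    def __call__(self, inputs: Any) -> Any:
        return inputs * 2


class ComposeLike:
    # Named like Compose.__call__: the parameter is `data`.
    def __init__(self, transforms: Sequence[TransformLike]) -> None:
        self.transforms = list(transforms)

    def __call__(self, data: Any) -> Any:
        for t in self.transforms:
            data = t(data)  # always invoked positionally
        return data


pipeline = ComposeLike([AddOne(), Double()])
print(pipeline(3))  # 8
```

Note that `@runtime_checkable` only makes `isinstance()` verify that `__call__` exists; signature compatibility is enforced by static checkers such as mypy or pyright, not at runtime.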

backend=None,
data_file: str | None = None,
mode: str = 'train',
transform: Callable[[(list | tuple)], BaseTransform] | None = None,
A Member commented on the lines above:

What kind of notation is this?
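One way to read the annotation being questioned: `Callable[[(list | tuple)], BaseTransform]` describes a function that takes a single `list` or `tuple` and returns a `BaseTransform`, i.e. a transform factory rather than a transform. The sketch below uses a simplified stand-in `BaseTransform` (an assumption, not Paddle's class) and `typing.get_args` to make the parameter and return types explicit; `list | tuple` is spelled `Union[list, tuple]` so it also evaluates on Python versions before 3.10.

```python
from typing import Callable, Union, get_args


class BaseTransform:
    """Simplified stand-in for paddle.vision.transforms.BaseTransform."""

    def __call__(self, inputs):
        return inputs


# The annotation under review, with `list | tuple` written as a Union.
Questioned = Callable[[Union[list, tuple]], BaseTransform]

params, ret = get_args(Questioned)
print(params)  # one parameter: a list or a tuple
print(ret.__name__)  # BaseTransform: calling it would *produce* a transform

# A transform instance is itself the callable: it maps data to data.
t = BaseTransform()
print(type(t([1, 2])).__name__)  # list
```

By contrast, the `_Transform` alias proposed earlier in the thread types the transform itself as a data-to-data callable.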

SigureMo changed the title [Typing][A-86] Add type annotations for python/paddle/vision/datasets/cifar.py [Typing][A-86, A-92] Add type annotations for python/paddle/vision/{datasets/cifar.py, image.py} Jun 23, 2024
SigureMo changed the title [Typing][A-86, A-92] Add type annotations for python/paddle/vision/{datasets/cifar.py, image.py} [Typing][A-86, A-92] Add type annotations for python/paddle/vision/{datasets/cifar.py,image.py} Jun 23, 2024
SigureMo previously approved these changes Jun 23, 2024

SigureMo (Member) left a comment

LGTMeow 🐾

SigureMo merged commit e26bf13 into PaddlePaddle:develop Jun 24, 2024
86kkd deleted the a-86 branch June 24, 2024 04:01
luotao1 added the HappyOpenSource (Happy Open Source event issues and PRs) label Jun 24, 2024
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Jun 25, 2024
…datasets/cifar.py,image.py}` (PaddlePaddle#65386)


---------

Co-authored-by: Nyakku Shigure <sigure.qaq@gmail.com>
Labels: contributor (External developers), HappyOpenSource (Happy Open Source event issues and PRs)

5 participants