megemini
Contributor

megemini commented Jul 6, 2024

PR Category

User Experience

PR Types

Improvements

Description

Type annotations for:

  • python/paddle/distribution/distribution.py

Related links

@SigureMo @megemini


paddle-bot bot commented Jul 6, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See Paddle CI Manual for details.

@@ -18,7 +18,10 @@
# 'Normal',
# 'Uniform']
Member

Let's delete the TODO and `__all__` comments here; they don't seem useful.

def _to_tensor(self, *args):
def _to_tensor(
self, *args: TensorLike | NestedNumbericSequence
) -> tuple[Tensor, ...]:
Member

SigureMo Jul 6, 2024

Should this be Tensor?

from typing import TypeGuard
from typing_extensions import TypeIs


def is_str_seq(*args: str | int) -> TypeGuard[tuple[str, ...]]:
    return all(isinstance(arg, str) for arg in args)


def foo(x: str | int):
    if is_str_seq(x, x):
        reveal_type(x) # tuple[str, ...]
    else:
        reveal_type(x)  # str | int (TypeGuard does not narrow the negative branch)

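It looks like what gets guarded here is the type of each element (the first argument itself is narrowed to the tuple type).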

Contributor Author

"Should this be Tensor?"

Do you mean the output?

return tuple(variable_args)
"It looks like what gets guarded here is the type of each element."

Do you mean this part?

    def _validate_args(
        self, *args: TensorLike | NestedNumbericSequence
    ) -> TypeGuard[tuple[Tensor, ...]]:

Member

SigureMo Jul 6, 2024

"Do you mean this part?"

Yes, this part.

Contributor Author

    def _validate_args(
        self, *args: TensorLike | NestedNumbericSequence
    ) -> TypeGuard[tuple[Tensor, ...]]:
        """
        Argument validation for distribution args
        Args:
            value (float, list, numpy.ndarray, Tensor)
        Raises:
            ValueError: if one argument is Tensor, all arguments should be Tensor
        """
        is_variable = False
        is_number = False
        for arg in args:
            if isinstance(arg, (Variable, paddle.pir.Value)):
                is_variable = True
            else:
                is_number = True

        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor'
            )

        return is_variable

Yes, it guards each element. What's the problem with that?
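To make the runtime contract of `_validate_args` concrete, here is a Paddle-free sketch of the same all-or-nothing check. It assumes a hypothetical `FakeTensor` class standing in for `Variable` / `paddle.pir.Value`, and it is an illustration rather than the Paddle implementation. The function accepts either all tensors or no tensors, which is why "every element is a Tensor" is the property a guard would want to express.

```python
class FakeTensor:
    """Hypothetical stand-in for Variable / paddle.pir.Value."""


def validate_args(*args: object) -> bool:
    # Same rule as _validate_args: True if every argument is a tensor,
    # False if none is, ValueError if tensors and numbers are mixed.
    has_tensor = any(isinstance(a, FakeTensor) for a in args)
    has_number = any(not isinstance(a, FakeTensor) for a in args)
    if has_tensor and has_number:
        raise ValueError(
            'if one argument is Tensor, all arguments should be Tensor'
        )
    return has_tensor
```

As in the original loop, calling it with no arguments returns `False`.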

Member

Every element is a Tensor? Then why is it `TypeGuard[tuple[Tensor, ...]]`?

Member

from typing import TypeGuard
from typing_extensions import TypeIs


def is_str_seq(*args: str | int) -> TypeGuard[str]:
    return all(isinstance(arg, str) for arg in args)


def foo(x: str | int, y: str | int):
    if is_str_seq(x, y):
        reveal_type(x)  # str
        reveal_type(y)  # str | int
    else:
        reveal_type(x)  # str | int
        reveal_type(y)  # str | int

I tried it and it doesn't guard them: pyright only narrows the first argument, which is what the TypeGuard specification says.

mypy doesn't even narrow the first argument.

Contributor Author

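"Every element is a Tensor? Then why is it `TypeGuard[tuple[Tensor, ...]]`?"

Isn't the whole point of this method to make sure every element is a Tensor? The input is `*args`, so I used `tuple[Tensor, ...]`.

"I tried it and it doesn't guard them: pyright only narrows the first argument, which is what the TypeGuard specification says."

Then I guess the implementations just aren't complete enough yet. So what should we do? Write `TypeGuard[Tensor]`?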

Member

Let's go with `TypeGuard[Tensor]`. It works for most use cases, and since this is a non-public API, that's acceptable.

Member

SigureMo left a comment

LGTMeow 🐾

SigureMo merged commit 3477632 into PaddlePaddle:develop Jul 7, 2024
Labels
contributor External developers
2 participants