megemini
Contributor

PR Category

User Experience

PR Types

Improvements

Description

Type annotations:

  • python/paddle/incubate/autograd/functional.py

Related links

@SigureMo @megemini

paddle-bot bot commented Jul 13, 2024

Your PR has been submitted. Thanks for your contribution to the open-source project!
Please wait for the results of the automated CI tests first. See the Paddle CI Manual for details.

@overload
def jvp(
    func: _Func, xs: Sequence[Tensor], v: TensorOrTensors | None = None
) -> tuple[Tensor]:
Member

Suggested change
) -> tuple[Tensor]:
) -> tuple[Tensor, ...]:

Can there be more than one?

Also, isn't the return value here made up of two parts?
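
To make the difference behind this suggestion concrete: `tuple[Tensor]` describes a tuple with exactly one element, while `tuple[Tensor, ...]` describes a tuple of any length. A minimal sketch, assuming a hypothetical stand-in `Tensor` class so that it runs without Paddle installed:

```python
from __future__ import annotations


class Tensor:
    """Hypothetical stand-in for paddle.Tensor, for illustration only."""


def exactly_one() -> tuple[Tensor]:
    # tuple[Tensor] is a 1-tuple: exactly one Tensor.
    return (Tensor(),)


def any_number(n: int) -> tuple[Tensor, ...]:
    # tuple[Tensor, ...] is a homogeneous tuple of arbitrary length.
    return tuple(Tensor() for _ in range(n))


single = exactly_one()
several = any_number(3)
```

A static type checker would be expected to reject `return (Tensor(), Tensor())` inside `exactly_one`, which is why a function returning several tensors needs the `...` form.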

@overload
def vjp(
    func: _Func, xs: Sequence[Tensor], v: TensorOrTensors | None = None
) -> tuple[Tensor]:
Member

Same as below.

from paddle import Tensor
from paddle._typing import TensorOrTensors

_T_TensorOrTensors = TypeVar("_T_TensorOrTensors", Tensor, Sequence[Tensor])
Member

Suggested change
_T_TensorOrTensors = TypeVar("_T_TensorOrTensors", Tensor, Sequence[Tensor])
_TensorOrTensorsT = TypeVar("_TensorOrTensorsT", Tensor, Sequence[Tensor])

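Let's consistently use a T suffix for generic type parameters. Single-character type parameters continue in sequence, e.g. U, V, and so on (dicts are the exception and use K and V).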

_T_TensorOrTensors = TypeVar("_T_TensorOrTensors", Tensor, Sequence[Tensor])

class _Func(Protocol):
    def __call__(self, _T_TensorOrTensors) -> _T_TensorOrTensors:
Member

According to line 171, func seems to support *args? Could we do something like this?

from __future__ import annotations

from typing_extensions import TypeVarTuple, Unpack

TensorTs = TypeVarTuple("TensorTs")  # a variadic (plural) type parameter; by convention it uses the Ts suffix


def foo(*args: Unpack[TensorTs]) -> tuple[Unpack[TensorTs]]: ...


x = foo(1, "")
reveal_type(x)  # tuple[int, str]

The one drawback is that TypeVarTuple does not support a type bound.

Contributor Author

This is fairly troublesome. Taking vjp as an example:

class _Func(Protocol):
    @overload
    def __call__(self, arg: Tensor, /) -> Tensor: ...
    @overload
    def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ...


@overload
def vjp(
    func: Callable[[Tensor], Tensor],
    xs: Tensor,
    v: TensorOrTensors | None = None,
) -> tuple[Tensor, Tensor]: ...
@overload
def vjp(
    func: Callable[[Unpack[tuple[Tensor, ...]]], tuple[Tensor, ...]],
    xs: Tensor,
    v: TensorOrTensors | None = None,
) -> tuple[tuple[Tensor, ...], Tensor]: ...
@overload
def vjp(
    func: Callable[[Tensor], Tensor],
    xs: Sequence[Tensor],
    v: TensorOrTensors | None = None,
) -> tuple[Tensor, tuple[Tensor, ...]]: ...
@overload
def vjp(
    func: Callable[[Unpack[tuple[Tensor, ...]]], tuple[Tensor, ...]],
    xs: Sequence[Tensor],
    v: TensorOrTensors | None = None,
) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]: ...

The first element of the return value depends on the output of func, and the second depends on xs.

Contributor Author

Also, nothing specifies whether func outputs a single Tensor or several ...

Contributor Author

I think it should be treated as a single Tensor.

Contributor Author

Reference example:

            >>> import paddle

            >>> def func(x, y):
            ...     return paddle.matmul(x, y)
            ...
            >>> x = paddle.to_tensor([[1., 2.], [3., 4.]])
            >>> J = paddle.incubate.autograd.Jacobian(func, [x, x])

Looks like we need TypeVarTuple after all.

Contributor Author

That still doesn't work:

def bar(a: ???) -> None:
    return

def t1(a:Tensor, b:Tensor) -> Tensor: # pass
    return Tensor()

def t2(a:Tensor) -> Tensor: # pass
    return Tensor()

def t3(a:Tensor, b:str) -> Tensor: # fail or pass
    return Tensor()

def t4(a:str)->Tensor: # fail
    return Tensor()

def t5(a:str)->str: # fail
    return '1'

Is there any way to write an annotation that satisfies the cases above? At the very least, t1 and t2 should pass.

Member

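The first element of the return value depends on the output of func, and the second depends on xs.

In that case, how about using a few more generic type parameters?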

_FuncOutT = TypeVar("_FuncOutT", Tensor, Sequence[Tensor])
_InputT = TypeVar("_InputT", Tensor, Sequence[Tensor])
def vjp(
    func: Callable[[Tensor], _FuncOutT],  # needs an overload for the Sequence case, or use a Protocol
    xs: _InputT,
    v: TensorOrTensors | None = None,
) -> tuple[_FuncOutT, _InputT]: ...

This is only a rough idea; see whether it helps.
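
The idea above can be sketched in runnable form. This is a minimal illustration under stated assumptions: `Tensor` is a hypothetical stand-in class, and `vjp_shape` is a toy function that mirrors only the shape of the return value (output of `func` first, something tied to `xs` second). It computes no derivative and is not Paddle's `vjp`.

```python
from __future__ import annotations

from typing import Callable, Sequence, TypeVar


class Tensor:
    """Hypothetical stand-in for paddle.Tensor, for illustration only."""

    def __init__(self, value: float = 0.0) -> None:
        self.value = value


_FuncOutT = TypeVar("_FuncOutT", Tensor, Sequence[Tensor])
_InputT = TypeVar("_InputT", Tensor, Sequence[Tensor])


def vjp_shape(
    func: Callable[..., _FuncOutT], xs: _InputT
) -> tuple[_FuncOutT, _InputT]:
    """Toy function: returns (func output, xs) with no differentiation."""
    # A single Tensor is passed as one argument; a sequence is unpacked.
    args = (xs,) if isinstance(xs, Tensor) else tuple(xs)
    return func(*args), xs


def double(x: Tensor) -> Tensor:
    return Tensor(x.value * 2)


def add(x: Tensor, y: Tensor) -> Tensor:
    return Tensor(x.value + y.value)


out_single, xs_single = vjp_shape(double, Tensor(1.5))
out_multi, xs_multi = vjp_shape(add, [Tensor(1.0), Tensor(2.0)])
```

Here `func` is typed as `Callable[..., _FuncOutT]` so that both one-argument and multi-argument functions are accepted; the trade-off is that the parameter types of `func` go unchecked.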

Contributor Author

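Neither the code nor the examples make clear whether func returns only a Tensor or a tuple[Tensor]. Judging from the examples, though, it should be: the input can be several Tensors, and the output is a single Tensor.

So the question now is how to satisfy what I wrote in the earlier example: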

def func(a: ???) -> None:
    return

def t1(a:Tensor, b:Tensor) -> Tensor: # pass
    return Tensor()

def t2(a:Tensor) -> Tensor: # pass
    return Tensor()

func(t1)
func(t2)

Member

How about using ParamSpec?

_FuncInputT = ParamSpec("_FuncInputT")
def func(a: Callable[_FuncInputT, _FuncOutT]) -> None: ...

The catch is that the constraint is lost; this generic does not actually act as a generic constraint.

Contributor Author

Right, it doesn't help much; it's really no different from Callable[..., Tensor] ...

@megemini
Contributor Author

Update 20240714

  • Changed vjp and jvp to:
@overload
def jvp(
    func: Callable[..., Tensor],
    xs: Tensor,
    v: TensorOrTensors | None = None,
) -> tuple[Tensor, Tensor]:
    ...


@overload
def jvp(
    func: Callable[..., tuple[Tensor, ...]],
    xs: Tensor,
    v: TensorOrTensors | None = None,
) -> tuple[tuple[Tensor, ...], Tensor]:
    ...


@overload
def jvp(
    func: Callable[..., Tensor],
    xs: Sequence[Tensor],
    v: TensorOrTensors | None = None,
) -> tuple[Tensor, tuple[Tensor, ...]]:
    ...


@overload
def jvp(
    func: Callable[..., tuple[Tensor, ...]],
    xs: Sequence[Tensor],
    v: TensorOrTensors | None = None,
) -> tuple[tuple[Tensor, ...], tuple[Tensor, ...]]:
    ...

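That is, func can take one or more inputs and returns either a Tensor or a tuple[Tensor, ...]; the first element of the overall return value has the same type as that output.

  • Changed the func parameter of Jacobian and Hessian to Callable[..., TensorOrTensors]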

@SigureMo
Member

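That is, func can take one or more inputs and returns either a Tensor or a tuple[Tensor, ...]; the first element of the overall return value has the same type as that output.

The constraint here probably can't guarantee that the output of func matches the first element of the overall output? For example, tuple[Tensor, Tensor] and tuple[Tensor, Tensor, Tensor]. Could generics be used here?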

@megemini
Contributor Author

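That is, func can take one or more inputs and returns either a Tensor or a tuple[Tensor, ...]; the first element of the overall return value has the same type as that output.

The constraint here probably can't guarantee that the output of func matches the first element of the overall output? For example, tuple[Tensor, Tensor] and tuple[Tensor, Tensor, Tensor]. Could generics be used here?

Yes, generics would be better.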
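
One way the generic approach could look, as a sketch only; this is not necessarily what the merged code does. It assumes a hypothetical stand-in `Tensor` class and a toy helper `call_and_pair` that merely pairs the output of `func` with its inputs. A bound `TypeVar` lets a type checker carry the exact return type of `func`, such as `tuple[Tensor, Tensor]`, through to the result instead of widening it to `tuple[Tensor, ...]`.

```python
from __future__ import annotations

from typing import Callable, Sequence, Tuple, TypeVar, Union


class Tensor:
    """Hypothetical stand-in for paddle.Tensor, for illustration only."""


# Bound TypeVar: any subtype of the bound is preserved as-is by a checker.
_OutT = TypeVar("_OutT", bound=Union[Tensor, Tuple[Tensor, ...]])


def call_and_pair(
    func: Callable[..., _OutT], xs: Sequence[Tensor]
) -> tuple[_OutT, tuple[Tensor, ...]]:
    """Toy helper: returns func's output paired with the inputs as a tuple."""
    inputs = tuple(xs)
    return func(*inputs), inputs


def two_outputs(a: Tensor, b: Tensor) -> tuple[Tensor, Tensor]:
    return a, b


# A checker is expected to infer the first element as tuple[Tensor, Tensor].
outputs, inputs = call_and_pair(two_outputs, [Tensor(), Tensor()])
```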

Member

@SigureMo SigureMo left a comment

LGTMeow 🐾

@SigureMo SigureMo merged commit e813d6c into PaddlePaddle:develop Jul 14, 2024
lixcli pushed a commit to lixcli/Paddle that referenced this pull request Jul 22, 2024
3 participants