
Conversation

megemini
Contributor

PR Category

User Experience

PR Types

New features

Description

No.13: Add the RAdam / NAdam APIs to Paddle

Related RFC:

Local tests pass. In addition, the following script compares the results against PyTorch, and they match:

import numpy as np

import torch
from torch.autograd import Variable

import paddle

np.random.seed(2024)
data = np.random.rand(300, 400)

for opt_name, opt_torch, opt_paddle in [['NAdam', torch.optim.NAdam, paddle.optimizer.NAdam], ['RAdam', torch.optim.RAdam, paddle.optimizer.RAdam]]:
    print(f'------ optimizer is : {opt_name} ------')

    tensor = torch.FloatTensor(data)
    x = Variable(tensor, requires_grad=True)

    optimizer = opt_torch([x], lr=0.001)
    for i in range(5):
        optimizer.zero_grad()
        y = torch.mean((x - 5) * (x - 5))
        y.backward()
        optimizer.step()

    x_torch = x.detach().numpy()

    for device in ['cpu', 'gpu']:
        print(f'------ compare {device} ------')
        paddle.set_device(device)
        x = paddle.to_tensor(data)
        x.stop_gradient = False

        optimizer = opt_paddle(parameters=[x], learning_rate=0.001)
        for i in range(5):
            optimizer.clear_grad()
            y = paddle.mean((x - 5) * (x - 5))
            y.backward()
            optimizer.step()

        x_paddle = x.numpy()

        np.testing.assert_allclose(x_torch, x_paddle, atol=1e-06, rtol=1e-06)
        print(f'------- compare finish ---------')

This compares the post-optimization results of NAdam and RAdam on both CPU and GPU; they are consistent.

One common issue surfaced along the way: when the number of optimization steps is large, the CPU and GPU results can differ by more than 1e-06. The GPU results agree well with PyTorch ~ Paddle's other optimizers show the same behavior.

Just pointing it out here ~

Also, because Paddle's optimizers are implemented with accumulators, the concrete implementation logic differs from the original algorithm even though the optimization results are the same ~

For the concrete implementation, see the radam_step / nadam_step functions in the test_xxx_op.py test cases ~
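
For reference only, here is a rough numpy sketch of the textbook NAdam update (the form documented for torch.optim.NAdam). The function name and the way it passes state around are hypothetical; this is not the accumulator-based kernel used in this PR:

import numpy as np

def nadam_step_textbook(param, grad, m, v, t,
                        lr=0.002, beta1=0.9, beta2=0.999,
                        eps=1e-8, psi=0.004):
    # t is the 1-based step index, psi is the momentum decay.
    mus = [beta1 * (1.0 - 0.5 * 0.96 ** (i * psi)) for i in range(1, t + 2)]
    mu_t, mu_next = mus[t - 1], mus[t]
    mu_prod = float(np.prod(mus[:t]))   # mu_1 * ... * mu_t
    mu_prod_next = mu_prod * mu_next    # mu_1 * ... * mu_{t+1}

    m = beta1 * m + (1.0 - beta1) * grad       # first moment
    v = beta2 * v + (1.0 - beta2) * grad ** 2  # second moment

    m_hat = mu_next * m / (1.0 - mu_prod_next) + (1.0 - mu_t) * grad / (1.0 - mu_prod)
    v_hat = v / (1.0 - beta2 ** t)

    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v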

@cxxly please review ~


paddle-bot bot commented Apr 18, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor (External developers) label Apr 18, 2024
@megemini
Contributor Author

megemini commented Apr 18, 2024

@cxxly When I submitted last night, my local tests passed, but after a merge in between the unit tests started failing, a precision issue ... ... Do you know whether anything changed in op_test?

But the PyTorch comparison script above still works fine ... ...

------ optimizer is : NAdam ------
------ compare cpu ------
------- compare finish ---------
------ compare gpu ------
W0419 07:04:03.844563 221339 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.7
W0419 07:04:03.845705 221339 gpu_resources.cc:164] device: 0, cuDNN Version: 8.5.
------- compare finish ---------
------ optimizer is : RAdam ------
------ compare cpu ------
------- compare finish ---------
------ compare gpu ------
------- compare finish ---------

I'll take another look myself ... ... a few dozen PRs were merged in between ... ...

@luotao1
Contributor

luotao1 commented Apr 19, 2024

I see that the CI for commit 2752082 (before the merge) did not pass either. Only the PR that migrated the unit test directories was merged last night, so it shouldn't have had any effect.

@megemini
Contributor Author

megemini commented Apr 19, 2024

I see that the CI for commit 2752082 (before the merge) did not pass either. Only the PR that migrated the unit test directories was merged last night, so it shouldn't have had any effect.

That CI run didn't finish; it reported a conflict, so I merged again in between ~

My local tests passed before the merge and broke after it ~

I just narrowed it down a bit. It looks like there is an issue in the op_test logic. For example, the test case uses nadam_step to simulate the operator in numpy; it should take only 1 step, but after the merge op_test seems to take 2 steps, which breaks the check ~

Here is some output I printed inside op_test:

!!!!! inputs
{'param': array([[ 0.17602904,  0.3982175 , -0.6236961 , ...,  0.6821805 ,
         0.9441111 ,  0.5645144 ],
       [-0.605939  ,  0.22125213, -0.04228899, ..., -0.7997418 ,
         0.02449642, -0.02278226],
       [-0.01884528,  0.50991696,  0.8460069 , ..., -0.05510011,
        -0.30808878, -0.4063227 ],
       ...,
       [-0.9146876 , -0.0642081 ,  0.8627311 , ..., -0.33185562,
         0.5504968 ,  0.14223264],
       [ 0.15039325, -0.45783746,  0.10544593, ...,  0.6430891 ,
         0.8044486 ,  0.12357192],
       [-0.23231815,  0.17546286,  0.13293621, ..., -0.4734008 ,
         0.47730142, -0.41493982]], dtype=float32), 'grad': array([[ 0.59897286, -0.96861345, -0.8794145 , ..., -0.9679654 ,
         0.35871917, -0.5837766 ],
       [ 0.32519364, -0.622245  , -0.33584148, ...,  0.10742767,
        -0.6007129 ,  0.2832118 ],
       [ 0.37178418,  0.80451465,  0.23622747, ..., -0.65858865,
         0.39903817,  0.9290114 ],
       ...,
       [-0.53591865,  0.65896475, -0.92882305, ..., -0.14718425,
        -0.56659144,  0.7059012 ],
       [ 0.8947429 ,  0.6457763 ,  0.9349599 , ...,  0.24845287,
        -0.41513   , -0.24681845],
       [ 0.96900505, -0.4984258 , -0.21041307, ...,  0.25712886,
         0.59644985, -0.53189933]], dtype=float32), 

'momentum_decay_pow': array(0.884736, dtype=float32), 

'beta2_pow': array(0.7660609, dtype=float32), 

'mu_product': array(0.474552, dtype=float32), 

'moment1': array([[ 0.56653005,  0.61273986, -0.00243789, ...,  0.80402356,
        -0.61513036,  0.7725169 ],
       [-0.76927406,  0.9316274 ,  0.20868428, ..., -0.53259283,
         0.52867246, -0.5667158 ],
       [-0.07326724,  0.33402517, -0.1505982 , ..., -0.47788525,
        -0.3946228 ,  0.58341897],
       ...,
       [-0.5391075 ,  0.94511294, -0.9080138 , ...,  0.9197616 ,
        -0.84012556, -0.73314774],
       [ 0.65273803,  0.5870666 ,  0.18124168, ...,  0.38077796,
        -0.7243844 , -0.00442055],
       [-0.75882095,  0.18046115,  0.09214061, ...,  0.4812588 ,
        -0.8366428 , -0.67993087]], dtype=float32), 'moment2': array([[0.46832946, 0.7352264 , 0.9196088 , ..., 0.04938603, 0.8743996 ,
        0.21438688],
       [0.39720523, 0.2538434 , 0.5387814 , ..., 0.96552783, 0.09327286,
        0.8940937 ],
       [0.04364423, 0.79084235, 0.9966176 , ..., 0.19338697, 0.90396744,
        0.7609614 ],
       ...,
       [0.09379539, 0.4152249 , 0.4628522 , ..., 0.23871422, 0.80177   ,
        0.47308218],
       [0.86667293, 0.1730922 , 0.49166903, ..., 0.60770994, 0.8060526 ,
        0.44773874],
       [0.1273331 , 0.2625222 , 0.22494832, ..., 0.80136675, 0.05598379,
        0.01957218]], dtype=float32), 'learning_rate': array(0.003, dtype=float32)}
!!!!! attrs
{'epsilon': 1e-08, 'beta1': 0.78, 'beta2': 0.915, 'momentum_decay': 0.004, 'momentum_decay_base': 0.96}

--------------------
param_out actual_np
[[ 0.1743741   0.39931554 -0.6224579  ...  0.6845692   0.9439621
   0.56524587]
 [-0.60594314  0.22185223 -0.04182015 ... -0.799576    0.02598498
  -0.02284868]
 [-0.02085438  0.50848264  0.84576494 ... -0.05261055 -0.30842945
  -0.4081155 ]
 ...
 [-0.91161144 -0.06639929  0.8652213  ... -0.33253878  0.551909
   0.14145973]
 [ 0.14870569 -0.46055287  0.10357524 ...  0.6423588   0.80555993
   0.12408908]
 [-0.23428343  0.17657013  0.13343854 ... -0.47411665  0.47615242
  -0.4095929 ]]
param_out expect_np
[[ 0.1743592   0.3993776  -0.62242097 ...  0.6847349   0.9439293
   0.56533515]
 [-0.6059954   0.22194509 -0.04179393 ... -0.7995945   0.02609378
  -0.0228765 ]
 [-0.02092784  0.5084554   0.84575117 ... -0.05257884 -0.30845746
  -0.4081416 ]
 ...
 [-0.9115866  -0.06640426  0.86524254 ... -0.3324781   0.5519114
   0.14139257]
 [ 0.14868431 -0.46057892  0.10352963 ...  0.642358    0.8055586
   0.12410427]
 [-0.23441313  0.17661788  0.13346191 ... -0.4741149   0.4759967
  -0.40957034]]
--------------------
momentum_decay_pow_out actual_np
0.8153727
momentum_decay_pow_out expect_np
0.8493466
--------------------
beta2_pow_out actual_np
0.64136535
beta2_pow_out expect_np
0.70094573
--------------------
mu_product_out actual_np
0.072285436
mu_product_out expect_np
0.18519613
--------------------
moment1_out actual_np
[[ 0.57366747  0.2648421  -0.19537276 ...  0.41418594 -0.40088344
   0.47413227]
 [-0.52849114  0.5897754   0.08888859 ... -0.3917883   0.28020766
  -0.37973168]
 [ 0.02464409  0.43753287 -0.06549654 ... -0.51764    -0.22001737
   0.65944934]
 ...
 [-0.53840595  0.8821603  -0.9125919  ...  0.6850335  -0.77994806
  -0.4165569 ]
 [ 0.7059791   0.59998274  0.34705973 ...  0.35166642 -0.65634847
  -0.0577481 ]
 [-0.37869918  0.031106    0.0255788  ...  0.4319502  -0.52136236
  -0.6473639 ]]
moment1_out expect_np
[[ 0.57366747  0.26484212 -0.19537275 ...  0.414186   -0.40088344
   0.4741323 ]
 [-0.52849114  0.58977544  0.0888886  ... -0.3917883   0.2802077
  -0.3797317 ]
 [ 0.02464408  0.43753284 -0.06549655 ... -0.51764    -0.22001737
   0.6594493 ]
 ...
 [-0.5384059   0.8821603  -0.9125918  ...  0.6850335  -0.77994806
  -0.41655692]
 [ 0.7059791   0.59998274  0.34705967 ...  0.35166642 -0.65634847
  -0.05774809]
 [-0.3786992   0.03110601  0.0255788  ...  0.4319502  -0.52136236
  -0.6473639 ]]
--------------------
moment2_out actual_np
[[0.45901677 0.75248015 0.90717846 ... 0.12482955 0.8110134  0.22513157]
 [0.37243164 0.26517776 0.5025721  ... 0.88443893 0.11601742 0.8249135 ]
 [0.05168346 0.77863646 0.9166484  ... 0.21381688 0.84066486 0.76963997]
 ...
 [0.11023553 0.41684073 0.4968403  ... 0.22026488 0.76090676 0.4752254 ]
 [0.86105376 0.19382665 0.52417994 ... 0.5613015  0.7521865  0.41485912]
 [0.19632229 0.2613242  0.20959097 ... 0.73887044 0.08146411 0.04195648]]
moment2_out expect_np
[[0.45901677 0.7524802  0.90717846 ... 0.12482958 0.8110134  0.22513159]
 [0.37243164 0.26517776 0.5025721  ... 0.88443893 0.11601742 0.8249135 ]
 [0.05168346 0.7786365  0.9166484  ... 0.2138169  0.84066486 0.76963997]
 ...
 [0.11023553 0.41684073 0.49684033 ... 0.22026488 0.76090676 0.47522542]
 [0.86105376 0.19382666 0.52417994 ... 0.5613015  0.7521865  0.41485912]
 [0.19632232 0.26132423 0.20959097 ... 0.73887044 0.08146413 0.04195648]]

You can see this with momentum_decay_pow as an example. The update formula is momentum_decay_pow *= momentum_decay_base. With the inputs above, momentum_decay_pow = 0.884736 and momentum_decay_base = 0.96. If only 1 step is taken, momentum_decay_pow = 0.884736 * 0.96 = 0.84934656, which matches expect_np. But op_test took two steps, so momentum_decay_pow = 0.884736 * 0.96 * 0.96 = 0.815372698, which matches actual_np! expect_np and actual_np differ, hence the failure!

The affected accumulators also include beta2_pow and mu_product, so param ends up different as well!

beta2_pow = 0.7660609 * 0.915 = 0.700945724 -> 1 step
beta2_pow = 0.7660609 * 0.915 * 0.915 = 0.641365337 -> 2 steps
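
A minimal check of the one-step vs. two-step hypothesis, using the accumulator values from the dump above:

momentum_decay_pow, momentum_decay_base = 0.884736, 0.96
beta2_pow, beta2 = 0.7660609, 0.915

print(momentum_decay_pow * momentum_decay_base)       # 0.84934656 -> matches expect_np (1 step)
print(momentum_decay_pow * momentum_decay_base ** 2)  # 0.81537270 -> matches actual_np (2 steps)
print(beta2_pow * beta2)                               # 0.70094572 -> matches expect_np (1 step)
print(beta2_pow * beta2 ** 2)                          # 0.64136534 -> matches actual_np (2 steps)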

It was fine before the merge; I did a similar comparison inside op_test back then, and it really took only 1 step!

@megemini
Contributor Author

Also, PyTorch's results match my operator here. For now it looks like op_test may have a problem when appending the op ~ This only involves static graph mode ~

@megemini
Contributor Author

megemini commented Apr 19, 2024

I see that the CI for commit 2752082 (before the merge) did not pass either. Only the PR that migrated the unit test directories was merged last night, so it shouldn't have had any effect.

The branch I had before this merge was from two or three months ago ... ... rebuilding after every merge takes too long, so I got lazy ... ...

@luotao1
Contributor

luotao1 commented Apr 19, 2024

The branch I had before this merge was from two or three months ago

Do you want to bisect to locate it?

@megemini
Contributor Author

The branch I had before this merge was from two or three months ago

Do you want to bisect to locate it?

Thanks! It's my fault; I had modified op_test locally ... ...

Sorry for wasting everyone's time. I'll @ everyone for review once CI passes ~ Sorry ~

@megemini
Contributor Author

@cxxly CI has basically all passed ~ The earlier failures were caused by issues with my local unit tests and have been resolved ~ Please review, thank you ~

@luotao1 luotao1 added the API label Apr 24, 2024

paddle-ci-bot bot commented Apr 27, 2024

Sorry to inform you that 60f8ebe's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@luotao1
Contributor

luotao1 commented May 7, 2024

PR-CI-Coverage timed out.

You can increase the time limit with something like set_tests_properties(test_autograd_dynamic PROPERTIES TIMEOUT 100).

@megemini
Contributor Author

Update 20240510

  • Extended the unit test timeouts for radam / nadam
  • The unit tests now use random data instead of the uci_housing dataset (loading uci_housing can time out, so random data is used instead).

Also, radam / nadam pass in PR-CI-Coverage; the failing APIs shouldn't involve these two ~

@cxxly please review ~

cxxly
cxxly previously approved these changes May 11, 2024
Contributor

@cxxly cxxly left a comment


LGTM


def __init__(
    self,
    learning_rate=0.001,
Contributor


The default value of learning_rate is 0.002 in other deep learning frameworks. Which is better?

Contributor Author


The default learning_rate of NAdam is:

  • PyTorch: 0.002
  • TensorFlow/Keras: 0.001

I can't tell which is better 😅 Maybe it is better to let users set it themselves?

Contributor


Does the paper have any suggestion? Which default value is commonly used?

Contributor Author


The empirical value is 0.002, and it is also used for Adam:

The best learning rate found for SGD was .2, for momentum/NAG was .5, for RMSProp was .001, and for Adam/Nadam was .002. - INCORPORATING NESTEROV MOMENTUM INTO ADAM

but the default value of paddle.optimizer.Adam is 0.001 (learning_rate=0.001 ...).

Shall I change it to 0.002 for NAdam?

Contributor

@jeff41404 jeff41404 May 13, 2024


but the default value of paddle.optimizer.Adam is 0.001 (learning_rate=0.001 ...).

Modifying the default value of the existing Adam API requires more testing and can be omitted from this PR.

Shall I change it to 0.002 for NAdam?

OK, if there is authoritative evidence, this can be done.

jeff41404
jeff41404 previously approved these changes May 14, 2024
Contributor

@jeff41404 jeff41404 left a comment


LGTM

Comment on lines 82 to 83
**Notes**:
**Currently, RAdam doesn't support sparse parameter optimization.**
Contributor


Suggested change
**Notes**:
**Currently, RAdam doesn't support sparse parameter optimization.**
Note:
Currently, RAdam doesn't support sparse parameter optimization.

See https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/api_contributing_guides/api_docs_guidelines_cn.html#zhujie

beta2 (float|Tensor): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a 0-D Tensor with shape [] and data type as float32.
The default value is 0.999.
epsilon (float): A small float value for numerical stability.
Contributor


Suggested change
epsilon (float): A small float value for numerical stability.
epsilon (float, optional): A small float value for numerical stability.

Contributor Author


I have a small question here ~ I've been working on type hints lately; in Python, usually only parameters like x=None are treated as Optional, whereas here it seems that any parameter with a default value is described as optional.
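
A minimal illustration of the distinction (the signatures below are made up for the example):

from typing import Optional

def f(epsilon: float = 1e-8) -> None: ...                  # has a default, but the annotation is plain float
def g(weight_decay: Optional[float] = None) -> None: ...   # only this one is Optional in the typing sense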

The default value is 0.999.
epsilon (float): A small float value for numerical stability.
The default value is 1e-08.
weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
Contributor


__init__ uses weight_decay=None, which conflicts with the docstring description; please double-check.

Comment on lines 79 to 80
**Notes**:
**Currently, NAdam doesn't support sparse parameter optimization.**
Contributor


Suggested change
**Notes**:
**Currently, NAdam doesn't support sparse parameter optimization.**
Note:
Currently, NAdam doesn't support sparse parameter optimization.

The default value is 0.999.
epsilon (float): A small float value for numerical stability.
The default value is 1e-08.
weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
Contributor


__init__ uses weight_decay=None, which conflicts with the docstring description; please double-check.

then the parameters are list of dict. Note that the learning_rate in parameter groups
represents the scale of base learning_rate.
The default value is None in static graph mode, at this time all parameters will be updated.
beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
Contributor


Suggested change
beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.

beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
It should be a float number or a 0-D Tensor with shape [] and data type as float32.
The default value is 0.9.
beta2 (float|Tensor): The exponential decay rate for the 2nd moment estimates.
Contributor


Suggested change
beta2 (float|Tensor): The exponential decay rate for the 2nd moment estimates.
beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.

then the parameters are list of dict. Note that the learning_rate in parameter groups
represents the scale of base learning_rate.
The default value is None in static graph mode, at this time all parameters will be updated.
beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
Contributor


Suggested change
beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates.

beta1 (float|Tensor): The exponential decay rate for the 1st moment estimates.
It should be a float number or a 0-D Tensor with shape [] and data type as float32.
The default value is 0.9.
beta2 (float|Tensor): The exponential decay rate for the 2nd moment estimates.
Contributor


Suggested change
beta2 (float|Tensor): The exponential decay rate for the 2nd moment estimates.
beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.

&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\hspace{0mm} \text{ 其中 } \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
Contributor


Would a Chinese 其中 suddenly showing up in the English docs look a bit odd? (doge)

Contributor Author


Yep, removed it ~~~ 😆

@megemini
Contributor Author

Update 20240514

  • Revised the docstrings

@sunzhongkai588 please review ~

sunzhongkai588
sunzhongkai588 previously approved these changes May 15, 2024
Contributor

@sunzhongkai588 sunzhongkai588 left a comment


LGTM for docs

@luotao1 luotao1 changed the title 【Hackathon 6th No.13】为 Paddle 新增 RAdam / NAdam API 【Hackathon 6th No.13】为 Paddle 新增 RAdam / NAdam API -part May 15, 2024
@luotao1
Contributor

luotao1 commented May 15, 2024

Shun, you can submit the Chinese documentation now. Also, the PIR pipeline failed:

2024-05-14 22:17:04 The following tests FAILED: 
2024-05-14 22:17:04 	1036 - test_paddlescience (Failed)
2024-05-14 22:17:04 	1067 - test_radam_op (Failed)
2024-05-14 22:17:04 	1114 - test_paddlescience (Failed)
2024-05-14 22:17:04 	1147 - test_radam_op (Failed)
2024-05-14 22:17:04             	1066 - test_nadam_op (Failed)
2024-05-14 22:17:04             	995 - test_nadam_op (Failed)

@megemini
Contributor Author

Shun, you can submit the Chinese documentation now. Also, the PIR pipeline failed:

2024-05-14 22:17:04 The following tests FAILED: 
2024-05-14 22:17:04 	1036 - test_paddlescience (Failed)
2024-05-14 22:17:04 	1067 - test_radam_op (Failed)
2024-05-14 22:17:04 	1114 - test_paddlescience (Failed)
2024-05-14 22:17:04 	1147 - test_radam_op (Failed)
2024-05-14 22:17:04             	1066 - test_nadam_op (Failed)
2024-05-14 22:17:04             	995 - test_nadam_op (Failed)

It seems PR-CI-Py3-PIR wasn't marked Required before, so I hadn't paid much attention to it ... ...

I took a look at the log:

2024-05-15 14:23:33 Traceback (most recent call last):
2024-05-15 14:23:33   File "/workspace/Paddle/build/test/legacy_test/test_nadam_op.py", line 292, in test_nadam_static
2024-05-15 14:23:33     conv = paddle.static.nn.conv2d(data, 8, 3)
2024-05-15 14:23:33   File "/workspace/Paddle/build/python/paddle/static/nn/common.py", line 1061, in conv2d
2024-05-15 14:23:33     helper.append_op(
2024-05-15 14:23:33   File "/workspace/Paddle/build/python/paddle/base/layer_helper.py", line 50, in append_op
2024-05-15 14:23:33     return self.main_program.current_block().append_op(*args, **kwargs)
2024-05-15 14:23:33   File "/workspace/Paddle/build/python/paddle/base/framework.py", line 4609, in append_op
2024-05-15 14:23:33     op = Operator(
2024-05-15 14:23:33   File "/workspace/Paddle/build/python/paddle/base/framework.py", line 3234, in __init__
2024-05-15 14:23:33     raise TypeError(
2024-05-15 14:23:33 TypeError: The type of '%Input' in operator conv2d should be one of [str, bytes, Variable]. but received : Value(define_op_name=pd_op.data, index=0, dtype=builtin.tensor<2x3x8x8xf32>, stop_gradient=True)

The problem is the conv2d call in the test case under PIR ~

These test cases were adapted from some earlier Adam test cases; the PIR call path must have changed since then ~

Let me fix it then ~
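
For illustration, one possible way the static-graph test could build the conv layer (a sketch based only on the input shape in the traceback, not necessarily the exact change made in this PR):

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    data = paddle.static.data(name='x', shape=[2, 3, 8, 8], dtype='float32')
    # Assumption: paddle.nn.Conv2D builds the conv in a PIR-compatible way,
    # unlike the legacy paddle.static.nn.conv2d call that raised the TypeError above.
    conv = paddle.nn.Conv2D(in_channels=3, out_channels=8, kernel_size=3)(data)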
