-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[PIR]Gen check DataType #59354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR]Gen check DataType #59354
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
好的,我去加一下 |
对于第四种情况,从CI看编译还有些问题,以ci中编译报错的bilinear这个API为例。pd_api.cc中这个api的签名如下
最后一个输入类型是paddle::optionalpir::Value的,paddle::optional的意思是说bias这个参数可能是null。对于bilinear会根据最后一个输入的dtype选择kernel,如果bias为null那就会根据weight的dtype来选择kernel。那么对于bilinear这个api来说,生成的检查逻辑可能类似下面:
另外:我不确定是不是存在这样的API,比如bilinear的weight参数也是optional的,那么生成的逻辑是不是需要多个ifelse呢?可以先看下是否会有这样的API |
或者如果optional相关的api比较少这里,支持起来又比较麻烦的话,可以先加个黑名单之类的,跳过这类api。等下一个pr在支持这类情况~ |
test_conj_op.py中Testfp16ConjOp.testfp16 这个单测修改一下,只在gpu环境下跑吧,可以用is_compiled_with_cuda这个api判断。因为这个cpu kernel没有注册fp16类型~ |
好的,已修改 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!
PR types
Others
PR changes
Others
Description
[PIR]Gen check DataType
根据yaml中的配置情况,有下面4中情况需要支持check dtype(已经全部完成):
这里input是指的Value类型的输入参数,其他类型的参数被称为attr
参考 #58954