Contributor

@YuanRisheng YuanRisheng commented Sep 6, 2023

PR types

Others

PR changes

Others

Description

Pcard-67164
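Implements model training with the new IR on the Python side:

  • Implement create_parameter under the new IR on the Python side; abstract Set/GetParameterOp into C++ APIs and bind them to Python
  • Extend the pybind-layer data structures: support adding custom attributes to Program, further simplify access to data such as Program and Block to lower the cost of adapting the Python side to the new IR, and support retrieving all trainable OpResults under a Block
  • Adapt the xavier and constant initializers to the new IR so that parameters can be initialized under the new IR
  • Adapt Optimizer to the new IR: obtain the learning rate under the new IR, hook into the new IR backward logic, and implement parameter optimization through the optimizer under the new IR
  • Support optional inplace input/output scenarios in IR op auto-generation
  • Adapt C++ APIs such as uniform/sgd/subtract/square/linear to Python
  • Strengthen conversion among Fluid dtype, new IR dtype and numpy type; add conversion from IR dtype to string; add IR type checking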
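
To make the pieces in the list above concrete, here is a minimal pure-Python sketch of the training flow they enable: a uniform initializer, a linear forward pass, a subtract/square loss, and an SGD update. It does not use Paddle or the new IR; the toy data, `TRUE_W`, and every function name are assumptions made for illustration only.

```python
import random

random.seed(0)

# Toy data for y = 2*x0 - 3*x1 (TRUE_W is an assumption for this demo).
TRUE_W = [2.0, -3.0]
xs = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(64)]
ys = [sum(w * v for w, v in zip(TRUE_W, x)) for x in xs]


def linear(x, w):
    """Linear layer without bias: dot product of one sample with the weights."""
    return sum(wi * xi for wi, xi in zip(w, x))


def mse_and_grad(w):
    """Mean squared error over the data set and its gradient w.r.t. w."""
    n = len(xs)
    loss = 0.0
    grad = [0.0] * len(w)
    for x, y in zip(xs, ys):
        diff = linear(x, w) - y          # subtract
        loss += diff * diff / n          # square, then mean
        for i, xi in enumerate(x):
            grad[i] += 2.0 * diff * xi / n
    return loss, grad


def sgd_step(w, grad, lr):
    """Plain SGD update: w <- w - lr * grad."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


w = [random.uniform(-0.1, 0.1) for _ in TRUE_W]   # uniform initializer
initial_loss, _ = mse_and_grad(w)
for _ in range(300):
    _, grad = mse_and_grad(w)
    w = sgd_step(w, grad, lr=0.2)
final_loss, _ = mse_and_grad(w)
```

In the PR itself these steps are expressed as ops in a static-graph Program rather than as eager Python arithmetic; the sketch only shows the math being wired together.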

paddle-bot bot commented Sep 6, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor

@Aurelius84 Aurelius84 left a comment

Great work! LGTM overall

core.DataType.INT16: 'int16',
core.DataType.INT32: 'int32',
core.DataType.INT64: 'int64',
core.DataType.UINT8: 'uint8',
Contributor

Does the bf16 type need to be considered here?

Contributor Author

Yes, it does. This will be fixed in the next PR.
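
As a rough illustration of the follow-up discussed here, the sketch below shows a dtype-to-string table that also covers bfloat16, with a type check before the lookup. `DataType` is a hypothetical stand-in enum, not Paddle's `core.DataType`, and the string `'bfloat16'` is an assumed spelling.

```python
from enum import Enum, auto


class DataType(Enum):
    """Hypothetical stand-in for core.DataType; not Paddle's real enum."""
    INT16 = auto()
    INT32 = auto()
    INT64 = auto()
    UINT8 = auto()
    BFLOAT16 = auto()


# Mapping in the style of the excerpt, extended with a bf16 entry (assumed name).
DTYPE_TO_STR = {
    DataType.INT16: 'int16',
    DataType.INT32: 'int32',
    DataType.INT64: 'int64',
    DataType.UINT8: 'uint8',
    DataType.BFLOAT16: 'bfloat16',
}


def dtype_to_str(dtype):
    """Return the string name of a dtype, rejecting values of the wrong type."""
    if not isinstance(dtype, DataType):
        raise TypeError(f"expected DataType, got {type(dtype).__name__}")
    try:
        return DTYPE_TO_STR[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype}") from None
```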

if bias is not None:
    return paddle._ir_ops.add(out, bias)
else:
    return out
Contributor

Why doesn't this part use linear?

Contributor Author

Because there is no linear API.
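
For readers following this exchange, the pattern under review composes a linear layer from two separate steps: a matrix multiply followed by an optional bias add. The sketch below shows that composition in plain Python. It assumes `out` in the excerpt is the result of a preceding matmul, which the excerpt does not show, and the helper names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def add_bias(out, bias):
    """Add a bias vector to every row of a matrix."""
    return [[v + b for v, b in zip(row, bias)] for row in out]


def linear(x, weight, bias=None):
    """Linear layer built from matmul plus an optional add."""
    out = matmul(x, weight)
    if bias is not None:
        return add_bias(out, bias)
    return out
```

For example, `linear([[1, 2]], [[1, 0], [0, 1]], [10, 20])` multiplies by the identity matrix and then adds the bias to each row.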

var, var.shape, float(self._value), var.dtype, place
)
return None
else:
Contributor

Are the full_ interfaces inconsistent between dynamic and static graph modes?

Contributor Author

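The main issue is that the semantics differ between the dynamic and static graph scenarios. In the new IR static graph scenario an OpResult has to be created here, so full_ cannot be used.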

)
for index, grad in enumerate(grads):
Contributor

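grad has an internal interface, caculate_gradients_helper, which returns exactly the dict that maps forward values to their backward (gradient) values. Consider using the return value of that internal interface directly.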

Contributor Author

Thanks, this will be fixed in the next PR.
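
The suggestion above amounts to pairing each parameter with its gradient through a forward-to-backward dict instead of relying on positional order from `enumerate(grads)`. A minimal sketch of consuming such a dict follows; the function name and the shape of the dict are assumptions, since the helper's actual return type is not shown in this thread.

```python
def pair_params_with_grads(params, value_to_grad):
    """Build (param, grad) pairs from a forward-value -> gradient mapping.

    Parameters without an entry (for example, frozen ones) are skipped,
    so the result does not depend on the order in which grads were produced.
    """
    pairs = []
    for param in params:
        grad = value_to_grad.get(param)
        if grad is not None:
            pairs.append((param, grad))
    return pairs
```

A lookup keyed by the forward value stays correct even if some parameters receive no gradient, which an index-based pairing would silently misalign.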

@@ -300,6 +303,23 @@ def GenBuildOutputs(
VLOG(4) << "Builder construction meta_{name}";
phi::MetaTensor meta_{name}(&ir_meta_tensor_{name});
"""

CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE = """
Contributor

Would it be enough to keep only CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE? CREATE_INPUT_METATENSOR_TEMPLATE could then be deleted.

Contributor Author

To keep the generated code more concise, I suggest keeping CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE here.
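
The trade-off in this exchange is easier to see with a small generator sketch: a plain template emits a shorter snippet for required inputs, while an optional template wraps construction in a guard. The plain template reuses the two lines visible in the diff; the body of the optional template and the `{name}_is_provided` guard are illustrative assumptions, since the real template text is not shown here.

```python
# Plain template: mirrors the lines visible in the diff excerpt.
PLAIN_INPUT_TEMPLATE = """
  VLOG(4) << "Builder construction meta_{name}";
  phi::MetaTensor meta_{name}(&ir_meta_tensor_{name});
"""

# Illustrative only: the guard name and body are assumptions, not Paddle's code.
OPTIONAL_INPUT_TEMPLATE = """
  phi::MetaTensor meta_{name};
  if ({name}_is_provided) {{
    meta_{name} = phi::MetaTensor(&ir_meta_tensor_{name});
  }}
"""


def gen_input_metatensors(inputs):
    """Render C++ snippets for a list of (name, is_optional) inputs."""
    chunks = []
    for name, is_optional in inputs:
        template = OPTIONAL_INPUT_TEMPLATE if is_optional else PLAIN_INPUT_TEMPLATE
        chunks.append(template.format(name=name))
    return "".join(chunks)


generated = gen_input_metatensors([("x", False), ("bias", True)])
```

Keeping both templates means required inputs get the two-line form, and only optional inputs pay for the extra guard in the generated C++.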

@@ -39,5 +40,28 @@ pir::OpResult split_grad(std::vector<pir::OpResult> out_grads, int axis) {

return split_grad_op.x_grad();
}
pir::OpResult get_parameter(const std::string& name,
                            phi::DataType dtype,
                            const std::vector<int64_t>& shape) {
Contributor

TODO: get_parameter should ideally take only name as input; the type can be obtained from Program::parameters_.

Contributor Author

Thanks, this will be fixed in the next PR.
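
The TODO above proposes a name-only lookup, with dtype and shape recovered from a parameter table owned by the program. The sketch below shows that idea in Python under stated assumptions: `Program`, `ParamInfo`, and `set_parameter` are hypothetical stand-ins, not the pir classes, and the real `Program::parameters_` layout is not shown in this thread.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ParamInfo:
    """Hypothetical record of what the program knows about a parameter."""
    dtype: str
    shape: tuple


@dataclass
class Program:
    """Hypothetical stand-in for a program that owns a parameter table."""
    parameters: dict = field(default_factory=dict)

    def set_parameter(self, name, dtype, shape):
        self.parameters[name] = ParamInfo(dtype, tuple(shape))


def get_parameter(program, name):
    """Look a parameter up by name only; dtype and shape come from the table."""
    try:
        return program.parameters[name]
    except KeyError:
        raise KeyError(f"parameter {name!r} is not registered") from None


program = Program()
program.set_parameter("fc.weight", "float32", [4, 2])
info = get_parameter(program, "fc.weight")
```

With this shape, callers cannot pass a dtype or shape that disagrees with what was registered, which is the main benefit of dropping those arguments.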

if (defining_op->HasAttribute(attr_name)) {
  auto attrs = defining_op->attribute(attr_name)
                   .dyn_cast<pir::ArrayAttribute>()
                   .AsVector();
Contributor

Suggest adding a check on the attr here: confirm that it is a BoolAttribute before calling dyn_cast<pir::BoolAttribute>.

Contributor Author

OK, this will be fixed in the next PR.

phi::IntArray shape = CastPyArg2IntArray(shape_obj, "shape", 2);
// Call ir static api
auto static_api_out =
    paddle::dialect::get_parameter(name, dtype, shape.GetData());
Contributor

Same as above.

@@ -1142,6 +1164,79 @@ def _create_optimization_pass(
end = len(target_block.ops)
return target_block._slice_ops(start, end)

def _new_ir_create_optimization_pass(
Contributor

Suggested change
- def _new_ir_create_optimization_pass(
+ def _pir_create_optimization_pass(

Contributor Author

Python-side naming will be unified later in a dedicated change.

Contributor

@wanghuancoder wanghuancoder left a comment

LGTM

Contributor

@sunzhongkai588 sunzhongkai588 left a comment

LGTM, no docs change

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

8 participants