[NewIR]Support train model in new ir using Python API #57010
Your PR was submitted successfully. Thank you for contributing to this open-source project!
Great work! LGTM overall
core.DataType.INT16: 'int16',
core.DataType.INT32: 'int32',
core.DataType.INT64: 'int64',
core.DataType.UINT8: 'uint8',
Does the bf16 type need to be considered here?
Yes, it does. I'll fix that in the next PR.
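For readers following the bf16 question, the mapping under review can be sketched as below. This is a minimal stand-alone illustration, not the PR's code: `DataType` is a hypothetical stand-in for `core.DataType`, and the `BFLOAT16` member and the `'bfloat16'` string are assumptions about how the follow-up entry might look.

```python
from enum import Enum, auto


class DataType(Enum):
    """Hypothetical stand-in for core.DataType (not Paddle's real enum)."""
    INT16 = auto()
    INT32 = auto()
    INT64 = auto()
    UINT8 = auto()
    BFLOAT16 = auto()  # assumed member name for bf16


# Enum-to-string table in the style of the reviewed lines.
DATA_TYPE_TO_NAME = {
    DataType.INT16: 'int16',
    DataType.INT32: 'int32',
    DataType.INT64: 'int64',
    DataType.UINT8: 'uint8',
    DataType.BFLOAT16: 'bfloat16',  # assumed spelling of the bf16 entry
}


def dtype_name(dtype):
    """Return the string name for dtype, failing loudly on unmapped types."""
    try:
        return DATA_TYPE_TO_NAME[dtype]
    except KeyError:
        raise TypeError(f"unsupported data type: {dtype!r}") from None
```

Raising on an unmapped type, rather than returning a default, makes a missing entry such as bf16 visible at the call site.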
if bias is not None:
    return paddle._ir_ops.add(out, bias)
else:
    return out
Why isn't linear used for this part?
Because there is no linear API.
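The exchange above implies that a linear layer is being composed from more primitive ops. A minimal pure-Python sketch of that composition follows. It assumes `out` is the result of a matrix multiply, which the excerpt does not show; `matmul`, `add`, and `linear` here are hypothetical helpers on nested lists, not Paddle functions.

```python
def matmul(x, w):
    """Multiply an (n, k) matrix by a (k, m) matrix, both as nested lists."""
    return [
        [sum(a * b for a, b in zip(row, col)) for col in zip(*w)]
        for row in x
    ]


def add(out, bias):
    """Add a length-m bias vector to every row of an (n, m) matrix."""
    return [[v + b for v, b in zip(row, bias)] for row in out]


def linear(x, weight, bias=None):
    """Compose a linear layer from matmul and add, mirroring the reviewed branch."""
    out = matmul(x, weight)
    if bias is not None:
        return add(out, bias)
    else:
        return out
```

For example, `linear([[1, 2]], [[1, 0], [0, 1]], [10, 20])` multiplies by the identity and then adds the bias to each row.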
        var, var.shape, float(self._value), var.dtype, place
    )
    return None
else:
Are the dynamic-graph and static-graph full_ interfaces inconsistent?
The main issue is that the semantics differ between the dynamic-graph and static-graph scenarios. In the new IR static-graph scenario an OpResult has to be created here, so full_ cannot be used.
)
for index, grad in enumerate(grads):
grad has an internal interface, caculate_gradients_helper, that returns the dict mapping forward values to their backward counterparts. Consider using that interface's return value directly.
Thanks, I'll fix that in the next PR.
@@ -300,6 +303,23 @@ def GenBuildOutputs(
  VLOG(4) << "Builder construction meta_{name}";
  phi::MetaTensor meta_{name}(&ir_meta_tensor_{name});
"""

CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE = """
Would it be enough to keep only CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE? CREATE_INPUT_METATENSOR_TEMPLATE could then be deleted.
To keep the generated code more concise, I suggest keeping CREATE_OPTIONAL_INPUT_METATENSOR_TEMPLATE here.
@@ -39,5 +40,28 @@ pir::OpResult split_grad(std::vector<pir::OpResult> out_grads, int axis) {

  return split_grad_op.x_grad();
}
pir::OpResult get_parameter(const std::string& name,
                            phi::DataType dtype,
                            const std::vector<int64_t>& shape) {
TODO: get_parameter should ideally take only name as input; the type can be obtained from Program::parameters_.
Thanks, I'll fix that in the next PR.
if (defining_op->HasAttribute(attr_name)) {
  auto attrs = defining_op->attribute(attr_name)
                   .dyn_cast<pir::ArrayAttribute>()
                   .AsVector();
I suggest adding a check on the attr here: verify that it is a BoolAttribute before calling dyn_cast<pir::BoolAttribute>.
OK, I'll fix that in the next PR.
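The check-before-cast pattern the reviewer asks for can be illustrated with a small Python analogue. This is only a sketch of the idea, not the C++ fix: `Attribute`, `BoolAttribute`, `StrAttribute`, and `read_bool_flags` are hypothetical names invented for the example, with `isinstance` playing the role of the type check before `dyn_cast`.

```python
class Attribute:
    """Hypothetical base class standing in for pir::Attribute."""


class BoolAttribute(Attribute):
    """Hypothetical stand-in for pir::BoolAttribute."""

    def __init__(self, data):
        self.data = bool(data)


class StrAttribute(Attribute):
    """Hypothetical non-bool attribute, used to show the failure path."""

    def __init__(self, data):
        self.data = str(data)


def read_bool_flags(attrs):
    """Return the bool payloads of attrs, checking each element's type first."""
    flags = []
    for index, attr in enumerate(attrs):
        # Verify the element type before reading it as a bool attribute.
        if not isinstance(attr, BoolAttribute):
            raise TypeError(
                f"attribute {index} is {type(attr).__name__}, "
                "expected BoolAttribute"
            )
        flags.append(attr.data)
    return flags
```

Failing with a named index and type gives a clearer error than reading the payload of an element that turned out to be the wrong kind.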
phi::IntArray shape = CastPyArg2IntArray(shape_obj, "shape", 2);
// Call ir static api
auto static_api_out =
    paddle::dialect::get_parameter(name, dtype, shape.GetData());
Same as above.
@@ -1142,6 +1164,79 @@ def _create_optimization_pass(
        end = len(target_block.ops)
        return target_block._slice_ops(start, end)

    def _new_ir_create_optimization_pass(
Suggested change:
-    def _new_ir_create_optimization_pass(
+    def _pir_create_optimization_pass(
Naming on the Python side will be unified later in a dedicated change.
LGTM
LGTM, no docs change
LGTM
PR types
Others
PR changes
Others
Description
Pcard-67164
Implements model training with the new IR from the Python side: