-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Description
需求描述 Feature Description
需求背景
在飞桨新一代IR的开发过程中,出于兼容旧IR体系的考虑,我们开发了ProgramTranslator(相关代码位于 paddle/fluid/ir_adaptor/translator/),用于将定义在旧IR下的计算图转化为新IR表示。
在翻译旧IR时,我们会遇到fill_constant op,它对应新IR下的full op。full 的定义如下:
- op : full
args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor(out)
在新IR下,这应该是一个只有输出没有输入的op。然而,在旧IR下,其定义如下:
class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddAttr<int>("dtype",
"(int, default 5 (FP32)) "
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>("shape",
"(vector<int64_t>) The shape of the output")
.SetDefault({});
AddInput("ValueTensor",
"(Tensor, optional) If provided, fill_constant Op will use this "
"as value to set the output Tensor, this has a higher priority "
"than attr(str_value), the shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("ShapeTensor",
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape).")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1].")
.AsDuplicable()
.AsDispensable();
AddAttr<float>("value", "(float, default 0.0f) The value to be filled")
.SetDefault(0.0f);
AddAttr<std::string>(
"str_value",
"(string, default empty) The str convert to value to be filled")
.SetDefault("");
AddAttr<bool>("force_cpu",
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddAttr<int>("place_type",
"(int, default -1) allow mamually setting place where the "
"variable should be hold. "
"-1: not set manually, determine the place by executor. "
"0: CPUPlace. "
"1: CUDAPlace. "
"2: CUDAPinnedPlace. "
"3: XPUPlace. ")
.SetDefault(-1);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
}
};
可以看到,它是有输入的,但是属于可变输入。对旧IR而言,可变输入意味着,某个参数既有可能是Op的输入(Tensor,来自另一个Op的输出),又有可能是Op的属性(常量)。在新IR下,这种情况应该被统一为Op的输入。
为了将fill_constant翻译到新IR下,我们根据输入的不同,将它翻译到不同Op:
- 对于不包含可变输入的情形,将其翻译为
fullop。 - 如果包含可变输入,将其翻译为
full_with_inputop。full_with_input需要被注册到paddle_dialect中。
攻略方式
1.注册新OP
在paddle/fluid/ir/dialect/pd_op.yaml中注册一个新Op full_with_input,它的一个可能定义如下:
- name: full_with_input
inputs:
- typename: Tensor
name: value
optional: true
no_need_buffer: false
data_transform: {}
- typename: Tensor
name: shape
optional: true
no_need_buffer: false
data_transform: {}
attrs:
- {typename: DataType, name: dtype}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
custom_verify: true
这是一个示例定义,你可能需要修改它。
另外注意到这里设置了custom_verify: true(这依赖PR #55428),这个字段表明不需要为该Op自动生成一份verify函数。 verify函数用于校验函数签名,在custom_verify: true为true时,其实现一般位于paddle/fluid/ir/dialect/pd_op_verify.cc中。你可以先观察下去掉该字段自动生成的verify函数,考虑下为什么你会需要一个自定义的verify函数。自动生成的op定义位于${PADDLE_BINARY_DIR}/paddle/fluid/ir/dialect/pd_op.cc中。
完善ProgramTranslator
1. 在paddle/fluid/ir_adaptor/translator/op_translator.cc中注册FillConstantTranscriber
OpTranslator::OpTranslator() {
general_handler = OpTranscriber();
special_handlers["feed"] = FeedOpTranscriber();
....
special_handlers["fill_constant"] = FillConstantTranscriber();
}这段代码表明当ProgramTranslator遇到fill_constant时,会调用 FillConstantTranscriber进行处理。
2. FillConstantTranscriber的可能实现
struct FillConstantTranscriber : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program) override {
bool has_mutable_attribute = op_desc.HasInput("ShapeTensor", true) && op_desc.Input("ShapeTensor", true).size() > 0;
has_mutable_attribute |= op_desc.HasInput("ShapeTensorList", true) && op_desc.Input("ShapeTensorList", true).size() > 0;
has_mutable_attribute |= op_desc.HasInput("ValueTensor", true) && op_desc.Input("ValueTensor", true).size() > 0;
if (has_mutable_attribute) {
return OpTranscriber()(ctx, param_map, op_desc, program);
} else {
return FullWithInputOpTranscriber()(ctx, param_map, op_desc, program);
}
}
};这段代码按照op_desc中是否包含可变输入,将fill_constant的翻译分别转发给 已有的翻译函数OpTranscriber和尚未实现的FullWithInputOpTranscriber。后者是这项工作的主要部分。
3. 实现 FillConstantTranscriber的一些参考
目前在paddle/fluid/ir_adaptor/translator/op_translator.cc已有一些特殊的翻译函数,比较值得参考的包括AssignValueOpTranscriber、FetchOpTranscriber.
正确性验证
1. 基础目标
- 手写使用了
fill_constant的静态图网络,开启FLAGS_enable_new_ir_in_executor=1并验证执行的正确性
2. 进阶目标
- 在开启
FLAGS_NEW_IR_OPTEST=1 FLAGS_NEW_IR_OPTEST_WHITE_LIST=1的情形下通过单测test_fill_constant_op。如果你想实现这个目标,需要做的可能比上面列出的事项要多一些,但是逻辑是相似的,欢迎尝试。
奖励
一颗【夏玻利利葡萄】