Skip to content

【快乐开源】CINN编译器符号推导扩量 #66444

@gongshaotian

Description

@gongshaotian

任务划分🧾

快乐开源2024

符号推导按照接口实现难度划分为“简单”、“中等”和“复杂”三个等级,前期只开放简单任务,后续会逐步开放更具有挑战性的任务。任务列表见:InferSymbolicShape接口实现任务列表🧾

⚠️由于"复杂"等级的任务容易引入隐蔽bug,这里做一些特殊要求:

  • "复杂" 难度问题一次最多领取两个任务,做完再领
  • 为了方便定位bug,每个任务单独提交PR
  • 设置难度等级的目的是为了引导大家循序渐进的了解符号推导任务的目标和实现方法,因此要求【启航计划⛵️】的同学至少完成两个"中等"难度的任务,再去尝试解决"复杂"任务。有特殊情况的联系导师解决。

认领方式
请大家直接在👆的excel表中认领任务,如:

  • 任务认领要求:为了方便排查 Bug,一次最多同时认领5个任务,一个PR最多包含5个 Op 的 InferSymbolicShape 接口实现
  • PR提交格式:【Infer Symbolic Shape No.xxx】开头,注明任务编号
  • 认领后,超过2周没有提交PR,将重新释放

一、需求背景

深度学习模型常需要处理各种形状和尺寸的数据,支持动态 Shape 特性的深度学习框架允许模型在训练和推理过程中适应不同尺寸的输入,从而提高模型的灵活性和通用性。动态 Shape 功能允许模型在训练和推理时延迟计算张量的部分或全部维度,直到运行时再确定,因此可以根据实际的输入尺寸选择合适的优化策略,以达到最佳性能。
paddle.reshape()为例,现在的动态维度用 "-1" 表示,信息量少表示能力较弱,很多约束信息没法表示出来。而在CINN编译器引入了Shape Dialect之后,CINN能够直接基于动态Shape语义进行编译与优化。在CINN内部能够直接使用“S0”、“S1”这样的符号表示张量的维度信息,并能够通过添加一些维度约束限制为简化符号推导过程和后续编译优化提供指导信息

二、参考文档

在实现具体算子的符号推导接口时,需要了解符号的表示推导相关的基础概念
7954e1ad8f9b33a42f3f4d88c1d604b2

2.1 动态Shape符号表示

在CINN中,符号通过DimExpr、ShapeOrData、ShapeOrDataDimExprs三个不同的抽象层次进行表示
image

2.1.1 DimExpr

DimExpr是Shape Dialect最底层的数据结构,用于表示单个维度对应的符号信息。目前符号表示的语法支持int64_t的整数,string(一般为符号推导产生的从"S0","S1","S2"...一系列的新符号)以及加减乘除等复合语法,符号的操作也支持加减乘除运算以及相等判断。

// paddle/pir/include/dialect/shape/utils/dim_expr.h

using DimExprBase = std::variant<std::int64_t,
                                 std::string,
                                 Negative<DimExpr>,
                                 Reciprocal<DimExpr>,
                                 Add<DimExpr>,
                                 Mul<DimExpr>,
                                 Max<DimExpr>,
                                 Min<DimExpr>,
                                 Broadcast<DimExpr>>;

2.1.2 ShapeOrData

基于单个维度的抽象表示DimExpr,Shape Dialect使用ShapeOrData表示Tensor对应的符号维度信息,如[1, S0, S1]。

// paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h

std::vector<T> shape_;
std::optional<std::vector<T>> data_;

解释说明:一般来说,Tensor对应的符号维度信息用vector数据结构表示就够了,但对于一些特殊情况,vector的表示能力会有所不足,例如 :

y = paddle.shape(x) // shape of x: [S0, S1],  shape of y: [2] 
z = y.reshape(y)    // z[S0, S1]

在这个例子中,Tensor y 是存储了Tensor x 的 shape 信息,y 本身的shape为[2]。单算子的符号推导是根据操作数的shape信息和attribute信息实现的,此时如果要实现reshape op的符号推导接口就会发现我们真正需要的是Tensor y本身存储的信息[S0, S1],而不是 shape of y: [2] ,因此CINN中设计了data区以提升张量形状信息的表示能力。

2.1.3 ShapeOrDataDimExprs

ShapeOrDataDimExprs用于表示value的符号信息。由于value不仅可以是DenseTensorType还可能是VectorType,相应地ShapeOrDataDimExprs也需要是TensorShapeOrDataDimExprs或TensorListShapeOrDataDimExprs,实现中使用std::variant支持多种类型控制。

// paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h

using NullShapeOrDataDimExpr = std::monostate;
using TensorShapeOrDataDimExprs = ShapeOrData<DimExpr>;
using TensorListShapeOrDataDimExprs = std::vector<TensorShapeOrDataDimExprs>;

class RankedTensorArrayShapeOrDataDimExprs {
  //...
};

using ShapeOrDataDimExprsBase =
    std::variant<NullShapeOrDataDimExpr,
                 TensorShapeOrDataDimExprs,
                 TensorListShapeOrDataDimExprs,
                 RankedTensorArrayShapeOrDataDimExprs>;

class ShapeOrDataDimExprs : public ShapeOrDataDimExprsBase

2.2 符号推导重要组件

在实现单算子符号推导接口时,除了基础的符号表示,还需要掌握 CINN 编译器的 Shape Dialect 中提供的两个重要组件:符号推导的上下文管理器 InferSymbolicShapeContext 和 符号间的约束信息管理 ConstraintsManager
image

2.2.1 InferSymbolicShapeContext

InferSymbolicShapeContext是符号推导的一个上下文环境类,单算子符号推导接口开发主要会用到以下接口:

接口 功能描述
GetNextSymName 获取下一个新符号
GetShapeOrDataForValue 获取value对应的符号信息
SetStaticShapeForValue 作为保底机制给value设置静态符号,如果含有-1则赋予一个新符号
SetShapeOrDataForValue 给value设置符号信息
AddEqualCstr 给符号增加相等约束
AddGreatThanOneCstr 给符号增加大于一的约束
AddBroadcastableCstr 给符号增加Broadcastable约束

2.2.2 符号约束

建立约束的核心设计理念是,减少新符号,提升性能。目前约束包括Equal、GTOne(大于 1)、Broadcastable。实现具体Op的符号推导接口时无需关注ConstraintsManager的具体实现,只需了解上诉三个约束的添加方法:

// InferSymbolicShapeContext *infer_context;

infer_context->AddEqualCstr(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs);
infer_context->AddGreatThanOneCstr(const symbol::DimExpr& dim_expr);
infer_context->AddBroadcastableCstr(const symbol::DimExpr& lhs, const symbol::DimExpr& rhs);

其中的Broadcastable约束,将从后向前比较两个 Tensor 的形状,需要满足如下至少一个条件才能进行广播:

  • 两个 Tensor 的维度大小相等
  • 其中一个 Tensor 的维度等于 1
  • 其中一个 Tensor 的维度不存在
    因此如果能知道某个维度的值大于1,那么它在参与Broadcast时候一定与最终的Broadcast结果相同;如果两个参与Broadcast的维度都大于1,那么这两个维度一定相等且于最终的Broadcast结果相同。

2.2.3 单算子符号推导

算子继承InferSymbolicShapeInterface接口来实现符号推导,该接口传入符号推导上下文并对齐进行修改。具体来说就是从符号推导上下文中获取所需输入value的符号信息然后通过符号计算得到并在符号推导上下文中设置输入value的符号信息。除此之外,需要根据算子的计算特点在符号推导上下文中加入符号的约束关系。以Matmul算子的符号推导接口为例:

#define OP_DECLARE_INFER_SYMBOLIC_SHAPE(name) \
  bool name##OpInferSymbolicShape(            \
      pir::Operation* op, pir::InferSymbolicShapeContext* infer_context);

// paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)

// paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc
bool MatmulOpInferSymbolicShape(pir::Operation *op,
                                pir::InferSymbolicShapeContext *infer_context) {
  // ......
  infer_context->SetShapeOrDataForValue(op->result(0),
                                        ShapeOrData{TensorExprs(out_dims)});

  if ((ndims_x == ndims_y) && ndims_x >= 2) {
    if (transpose_x_attr == false && transpose_y_attr == false) {
      infer_context->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 2]);
    } else if (transpose_x_attr == false && transpose_y_attr == true) {
      infer_context->AddEqualCstr(x_dims[ndims_x - 1], y_dims[ndims_x - 1]);
    } else if (transpose_x_attr == true && transpose_y_attr == false) {
      infer_context->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 2]);
    } else {
      infer_context->AddEqualCstr(x_dims[ndims_x - 2], y_dims[ndims_x - 1]);
    }

    for (size_t i = 0; i < ndims_x - 2; ++i) {
      infer_context->AddEqualCstr(x_dims[i], y_dims[i]);
    }
  }
  return true;
}

三、开发流程和示例

3.1 开发流程

  1. Op声明中添加符号推导接口 主要分为以下几个类别
    • 在 paddle/phi/ops/yaml/ops.yaml 文件或paddle/phi/ops/yaml/inconsistent/static_ops.yaml 文件中添加 InferSymbolicShapeInterface 接口
  2. 接口实现
    • 通过yaml文件中添加的Op,需要根据 Op的参数类型在 paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape 目录下对应文件中添加接口实现(参考InferMeta的放置)
    • 实现方式:
  3. 本地Op级别测试
    • 在 test/legacy_test/目录和 test/deprecated/legacy_test目录下搜索相关单测并执行
      • 由 OpTest 实现的单测,在 test_check_output() 中同时打开 check_pir 和 check_symbol_infer 两个flag时会触发符号推导单测检查机制,检查添加对应Op的符号推导结果是否符合预期
      • 部分单测给定信息不足,会导致该单测无法推导出具体的维度,需要视情况关闭(谨慎⚠️)该单测的 check_symbol_infer 符号推导单测检查 flag
      • 单测实现方式有基于 OpTest 和基于 unittest 的两种实现方式,缺少基于OpTest的单测会导致 Coverage 流水线执行失败,需要手动构造至少一个对应单测⚠️
  4. 提交PR跑CI验证
  5. 合入PR,修改任务进度

编译和测试命令
编译命令(以python3.9为例):

cd build

rm -rf CMakeCache.txt

cmake .. -DWITH_GPU=ON \
        -DWITH_PROFILER=OFF \
        -DPY_VERSION=3.9 \
        -DWITH_DISTRIBUTE=OFF  \
        -DON_INFER=ON \
        -DWITH_TESTING=ON \
        -DWITH_CINN=ON

make -j 32

python -m pip uninstall paddlepaddle-gpu -y
python -m pip install python/dist/paddlepaddle_gpu-0.0.0-cp39-cp39-linux_x86_64.whl

测试命令(以test_reshape_op.py为例):

FLAGS_print_ir=1 GLOG_v=3 ctest -R test_reshape_op -VV > reshape_op.txt 2>&1

3.2 添加接口示例PR

#65880
#65889

3.3 常见Debug问题

在本地Op级别测试和CI验证的Debug阶段可能会遇到如下类型的问题:

  1. 误用data fix ExpandOp::InferSymbolicShape #63755
  2. 未将所有输出value的符号都设置,漏设置 fix FlashAttnOpInferSymbolicShape and FlashAttnInferMeta #63816
  3. 没有data却误调用带data的构造函数 [CINN]fix InferSymbolicShape for builtin.slice and pd_op.stack #64173
  4. 对vector tensor的错误设置 [CINN]fix InferSymbolicShape for builtin.slice and pd_op.stack #64173
  5. 忘记考虑0D情况 [CINN]fix 0-D AssignValueOpInferSymbolicShape #64615

Metadata

Metadata

Labels

PFCCPaddle Framework Contributor Club,https://github.com/PaddlePaddle/community/tree/master/pfccstatus/close已关闭type/others其他问题

Type

No type

Projects

Status

Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions