
Non scalar as condition to If op #5520

@Mexyy

Description


Ask a Question

Question

I wrote a test case for the ONNX If op and ran it with onnxruntime. When the condition input is np.array([1, 2, 0, 3]).astype(bool), the output comes from then_branch. When it is np.array([0, 2, 0, 3]).astype(bool), the output comes from else_branch. Judging from these results, only the first element seems to decide whether the condition is True or False. However, the ONNX source code in op_if.py does this:

    # excerpt from op_if.py (the except clause and the rest of the method are omitted)
    if len(cond.shape) > 0:
        try:
            evaluated_condition = all(cond)

`all` checks whether every element is True. Because [1, 2, 0, 3] contains a 0, this code would evaluate the condition as False and pick else_branch, which does not match the experiment result above.
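To make the discrepancy concrete, the snippet below evaluates two candidate rules on the two inputs from the experiment: Python's built-in `all` (what the op_if.py excerpt does) and "take the first element" (what the onnxruntime results suggest). The helper names `reference_rule` and `first_element_rule` are hypothetical, and the first-element rule is only an inference from the two observations, not documented onnxruntime behavior.

```python
import numpy as np


def reference_rule(cond: np.ndarray) -> bool:
    """Mirror of the op_if.py excerpt: a non-scalar condition is reduced with all()."""
    return bool(all(cond))


def first_element_rule(cond: np.ndarray) -> bool:
    """Hypothetical rule inferred from the onnxruntime results: only element 0 matters."""
    return bool(cond.flat[0])


inputs = {
    "[1, 2, 0, 3]": np.array([1, 2, 0, 3]).astype(bool),
    "[0, 2, 0, 3]": np.array([0, 2, 0, 3]).astype(bool),
}

for label, cond in inputs.items():
    ref = "then_branch" if reference_rule(cond) else "else_branch"
    first = "then_branch" if first_element_rule(cond) else "else_branch"
    print(f"{label}: all() -> {ref}, first element -> {first}")
```

The two rules disagree only on [1, 2, 0, 3], which is exactly the input where the observed onnxruntime output (then_branch) differs from what the op_if.py code would pick (else_branch).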

Further information

  • Relevant Area:

  • Is this issue related to a specific model?
    Model name:
    Model opset:

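    import numpy as np
    import onnx
    import onnxruntime
    from onnx import helper

    # Define input and output tensors.
    # Note: 'cond' is declared as a scalar (shape []), but a 1-D tensor of shape [4] is fed below.
    cond = onnx.helper.make_tensor_value_info('cond', onnx.TensorProto.BOOL, [])
    res = onnx.helper.make_tensor_value_info('res', onnx.TensorProto.FLOAT, [5])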

    # Define then and else output tensors
    then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [5])
    else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [5])

    # Define constant nodes for then and else branches
    x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
    y = np.array([5, 4, 3, 2, 1]).astype(np.float32)

    then_const_node = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=['then_out'],
        value=onnx.numpy_helper.from_array(x),
    )

    else_const_node = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=['else_out'],
        value=onnx.numpy_helper.from_array(y),
    )

    # Define then and else subgraphs
    then_body = onnx.helper.make_graph(
        [then_const_node],
        'then_body',
        [],
        [then_out]
    )

    else_body = onnx.helper.make_graph(
        [else_const_node],
        'else_body',
        [],
        [else_out]
    )

    # Define If node
    if_node = onnx.helper.make_node(
        'If',
        inputs=['cond'],
        outputs=['res'],
        then_branch=then_body,
        else_branch=else_body,
    )

    graph_def = helper.make_graph([if_node],
                                  'if_model',
                                  inputs=[cond],
                                  outputs=[res])

    # Convert the model to ONNX format
    model = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=[helper.make_opsetid("", 11)])
    onnx.save(model, 'if_model.onnx')

    # Print the model
    # print(model)

    # Run the model using onnxruntime
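    # Explicit providers avoid an error on onnxruntime builds that enable more than one execution provider
    sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])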
    cond_tensor = np.array([1, 2, 0, 3]).astype(bool)
    print("cond_tensor: ", cond_tensor)
    # print("all(cond_tensor): ", all(cond_tensor))
    input_data = {'cond': cond_tensor}
    output = sess.run([], input_data)
    print("if_output: ", output)
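
For context, the ONNX operator documentation for If describes `cond` as a boolean tensor that must contain a single element, so the result for a 4-element condition is not defined by the spec and may differ between runtimes. A minimal sketch of a portable workaround, assuming you want one specific semantics, is to reduce the condition to a single boolean on the caller's side before feeding it. The function `to_scalar_condition` and its `mode` parameter are hypothetical, not part of onnx or onnxruntime.

```python
import numpy as np


def to_scalar_condition(cond: np.ndarray, mode: str = "all") -> np.ndarray:
    """Hypothetical helper: collapse a boolean tensor to the single-element
    condition that the ONNX If spec expects.

    mode="all"    -> True only if every element is True (op_if.py semantics)
    mode="any"    -> True if at least one element is True
    mode="strict" -> reject anything that is not exactly one element
    """
    cond = np.asarray(cond, dtype=bool)
    if mode == "strict":
        if cond.size != 1:
            raise ValueError(f"If condition must have exactly one element, got shape {cond.shape}")
        return cond.reshape(())
    if mode == "all":
        return np.array(cond.all())
    if mode == "any":
        return np.array(cond.any())
    raise ValueError(f"unknown mode: {mode!r}")


cond_tensor = np.array([1, 2, 0, 3]).astype(bool)
scalar_cond = to_scalar_condition(cond_tensor, mode="all")
print(scalar_cond, scalar_cond.shape)  # False ()
# scalar_cond can then be fed as {'cond': scalar_cond}, matching the declared shape [].
```

The same reduction could also be done inside the graph before the If node; this sketch keeps it on the caller's side so the model itself stays spec-compliant and unchanged.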
