Status: Closed
Labels: question (Questions about ONNX), topic: spec clarification (Clarification of the ONNX spec needed)
Ask a Question
Question
I wrote a test case for the ONNX `If` op and ran it with onnxruntime (script below). When the condition input is `np.array([1, 2, 0, 3]).astype(bool)`, the output comes from `then_branch`. When it is `np.array([0, 2, 0, 3]).astype(bool)`, the output comes from `else_branch`. From these results it looks like the first element decides whether the condition is True or False. However, the ONNX source code in `op_if.py` does this:
```python
if len(cond.shape) > 0:
    try:
        evaluated_condition = all(cond)
        # ... (excerpt; the rest of op_if.py is omitted)
```
`all()` checks whether every element is True. Under that logic both inputs should select `else_branch`, because `[1, 2, 0, 3]` contains a 0, so the code does not match the experimental result.
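To make the discrepancy concrete, here is a small sketch that applies the quoted `op_if.py` rule and a first-element rule to both inputs from the experiment. `reference_style_condition` mirrors only the three quoted lines, and `first_element_condition` is a hypothetical helper modelling the behavior observed with onnxruntime; neither is an ONNX or onnxruntime API.

```python
import numpy as np


def reference_style_condition(cond: np.ndarray) -> bool:
    """Mirror the quoted op_if.py lines: a non-scalar cond is reduced with all()."""
    if len(cond.shape) > 0:
        return bool(all(cond))
    return bool(cond)


def first_element_condition(cond: np.ndarray) -> bool:
    """Hypothetical helper: look only at the first element (the observed behavior)."""
    return bool(cond.reshape(-1)[0])


for values in ([1, 2, 0, 3], [0, 2, 0, 3]):
    cond = np.array(values).astype(bool)
    print(values, "all():", reference_style_condition(cond),
          "first element:", first_element_condition(cond))
```

The two rules agree on `[0, 2, 0, 3]` (both False) and disagree only on `[1, 2, 0, 3]`, which is exactly the input for which `then_branch` was observed.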
Further information
Is this issue related to a specific model? No named model is given; the reproduction script below builds a minimal `If` model with opset 11.
```python
import numpy as np
import onnx
import onnxruntime
from onnx import helper, numpy_helper

# Define input and output tensors
cond = onnx.helper.make_tensor_value_info('cond', onnx.TensorProto.BOOL, [])
res = onnx.helper.make_tensor_value_info('res', onnx.TensorProto.FLOAT, [5])
# Define then and else output tensors
then_out = onnx.helper.make_tensor_value_info('then_out', onnx.TensorProto.FLOAT, [5])
else_out = onnx.helper.make_tensor_value_info('else_out', onnx.TensorProto.FLOAT, [5])
# Define constant nodes for then and else branches
x = np.array([1, 2, 3, 4, 5]).astype(np.float32)
y = np.array([5, 4, 3, 2, 1]).astype(np.float32)
then_const_node = onnx.helper.make_node(
    'Constant',
    inputs=[],
    outputs=['then_out'],
    value=onnx.numpy_helper.from_array(x),
)
else_const_node = onnx.helper.make_node(
    'Constant',
    inputs=[],
    outputs=['else_out'],
    value=onnx.numpy_helper.from_array(y),
)
# Define then and else subgraphs
then_body = onnx.helper.make_graph(
    [then_const_node],
    'then_body',
    [],
    [then_out]
)
else_body = onnx.helper.make_graph(
    [else_const_node],
    'else_body',
    [],
    [else_out]
)
# Define If node
if_node = onnx.helper.make_node(
    'If',
    inputs=['cond'],
    outputs=['res'],
    then_branch=then_body,
    else_branch=else_body,
)
graph_def = helper.make_graph([if_node],
                              'if_model',
                              inputs=[cond],
                              outputs=[res])
# Build the ONNX model (opset 11)
model = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=[helper.make_opsetid("", 11)])
onnx.save(model, 'if_model.onnx')
# Print the model
# print(model)
# Run the model using onnxruntime
sess = onnxruntime.InferenceSession(model.SerializeToString())
cond_tensor = np.array([1, 2, 0, 3]).astype(bool)
print("cond_tensor: ", cond_tensor)
# print("all(cond_tensor): ", all(cond_tensor))
input_data = {'cond': cond_tensor}
output = sess.run([], input_data)
print("if_output: ", output)
```
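The ONNX operator spec says the `If` op's `cond` input must contain a single element, so a 4-element tensor falls outside the spec, which is consistent with the two implementations handling it differently. Below is a minimal sketch of a caller-side guard, assuming the goal is to reject multi-element conditions and make any reduction explicit; `to_if_condition` is a hypothetical helper, not an ONNX or onnxruntime API.

```python
import numpy as np


def to_if_condition(values) -> np.ndarray:
    """Hypothetical helper: build a spec-conforming cond for the ONNX If op.

    Multi-element inputs are rejected instead of relying on a runtime's
    implicit rule (all() in op_if.py vs. the first element observed in
    onnxruntime).
    """
    cond = np.asarray(values).astype(bool)
    if cond.size != 1:
        raise ValueError(f"If cond must contain a single element, got shape {cond.shape}")
    # Reshape to a 0-d array to match the graph's declared cond shape [].
    return cond.reshape(())


# Example: state the "all elements true" intent explicitly before sess.run.
cond_tensor = to_if_condition(np.all(np.array([1, 2, 0, 3]).astype(bool)))
print(cond_tensor)  # a 0-d bool array holding False
```

Reducing with `np.all` (or taking `values[0]`) on the caller side makes the chosen semantics visible in the code rather than depending on which runtime evaluates the model.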