Skip to content

Parser fails to parse a float attribute when it is -inf #5102

@justinchuby

Description

@justinchuby

Bug Report

Is the issue related to model conversion?

No

Describe the bug

The parser fails to parse `tmp_11 = Constant <value_float = -inf> ()` (see microsoft/onnxscript#567 (comment)).

System information

  • ONNX version (e.g. 1.13): main
  • Python version: 3.10

Reproduction instructions

# Minimal reproduction for issue #5102: the ONNX textual-format parser fails
# on a float attribute whose value is -inf.
#
# `model_text` declares a main graph `torch_jit` plus three model-local
# functions: `this._attention_scale`, `this._causal_attention_mask`, and
# `onnxscript.atenlib._aten_scaled_dot_product_attention_float_mask_onnx`.
# The offending statement is inside `_causal_attention_mask`:
#     tmp_11 = Constant <value_float = -inf> ()
# NOTE: keep the string literal byte-for-byte as-is. Anything added inside
# it (including comments) becomes part of the text handed to the parser.
import onnx

model_text = """
<
ir_version: 8,
opset_import: ["" : 18, "onnxscript.atenlib" : 1, "this" : 1],
producer_name: "pytorch",
producer_version: "2.1.0"
>
torch_jit (float[4,3,8] input_0, float[4,4,6,8] input_1, float[4,4,6,8] input_2) => (float[4,4,3,8] output) {
intermediate_3 = this._attention_scale (input_0)
intermediate_4 = this._causal_attention_mask (input_0, input_1)
output = onnxscript.atenlib._aten_scaled_dot_product_attention_float_mask_onnx <dropout_p = 0> (input_0, input_1, input_2, intermediate_4, intermediate_3)
}
<
domain: "this",
opset_import: ["" : 18]
>
_attention_scale (query) => (scale)
{
tmp = Shape (query)
subscript_index = Constant <value = int64[1] subscript_index {-1}> ()
tmp_gather = Gather (tmp, subscript_index)
subscript_axis = Constant <value = int64[1] subscript_axis {0}> ()
tmp_0 = Squeeze (tmp_gather, subscript_axis)
embedding_size = CastLike (tmp_0, query)
tmp_1 = Constant <value = float tmp_1 {1}> ()
tmp_2 = Sqrt (embedding_size)
tmp_1_cast = CastLike (tmp_1, tmp_2)
scale = Div (tmp_1_cast, tmp_2)
}
<
domain: "this",
opset_import: ["" : 18]
>
_causal_attention_mask (query, key) => (attn_mask_13)
{
tmp = Shape (query)
one = Constant <value = int64[1] one {1}> ()
subscript_axis = Constant <value = int64[1] subscript_axis {0}> ()
tmp_0 = Constant <value = int64 tmp_0 {-2}> ()
tmp_0_reshaped = Reshape (tmp_0, one)
tmp_1 = Constant <value = int64 tmp_1 {-1}> ()
tmp_1_reshaped = Reshape (tmp_1, one)
target_length = Slice (tmp, tmp_0_reshaped, tmp_1_reshaped, subscript_axis)
tmp_2 = Shape (key)
one_3 = Constant <value = int64[1] one_3 {1}> ()
subscript_axis_4 = Constant <value = int64[1] subscript_axis_4 {0}> ()
tmp_5 = Constant <value = int64 tmp_5 {-2}> ()
tmp_5_reshaped = Reshape (tmp_5, one_3)
tmp_6 = Constant <value = int64 tmp_6 {-1}> ()
tmp_6_reshaped = Reshape (tmp_6, one_3)
source_length = Slice (tmp_2, tmp_5_reshaped, tmp_6_reshaped, subscript_axis_4)
size = Concat <axis = 0> (target_length, source_length)
tmp_7 = Constant <value = float tmp_7 {1}> ()
attn_mask = Expand (tmp_7, size)
attn_mask_8 = Trilu <upper = 0> (attn_mask)
tmp_9 = Constant <value = float tmp_9 {0}> ()
tmp_9_cast = CastLike (tmp_9, attn_mask_8)
tmp_10 = Equal (attn_mask_8, tmp_9_cast)
tmp_11 = Constant <value_float = -inf> ()
tmp_12 = Constant <value = float tmp_12 {0}> ()
tmp_12_cast = CastLike (tmp_12, tmp_11)
attn_mask_13 = Where (tmp_10, tmp_11, tmp_12_cast)
}
<
domain: "onnxscript.atenlib",
opset_import: ["" : 18]
>
_aten_scaled_dot_product_attention_float_mask_onnx <dropout_p>(query, key, value, attn_mask, scale) => (return_val)
{
key_shape = Shape (key)
one = Constant <value = int64[1] one {1}> ()
subscript_axis = Constant <value = int64[1] subscript_axis {0}> ()
tmp = Constant <value = int64 tmp {-1}> ()
tmp_reshaped = Reshape (tmp, one)
key_shape_shape = Shape (key_shape)
key_shape_shape_dim = Gather (key_shape_shape, subscript_axis)
key_last_dim = Slice (key_shape, tmp_reshaped, key_shape_shape_dim, subscript_axis)
one_0 = Constant <value = int64[1] one_0 {1}> ()
subscript_axis_1 = Constant <value = int64[1] subscript_axis_1 {0}> ()
tmp_2 = Constant <value = int64 tmp_2 {-2}> ()
tmp_2_reshaped = Reshape (tmp_2, one_0)
tmp_3 = Constant <value = int64 tmp_3 {-1}> ()
tmp_3_reshaped = Reshape (tmp_3, one_0)
key_second_last_dim = Slice (key_shape, tmp_2_reshaped, tmp_3_reshaped, subscript_axis_1)
one_4 = Constant <value = int64[1] one_4 {1}> ()
subscript_axis_5 = Constant <value = int64[1] subscript_axis_5 {0}> ()
tmp_6 = Constant <value = int64 tmp_6 {-2}> ()
tmp_6_reshaped = Reshape (tmp_6, one_4)
key_first_dims = Slice (key_shape, subscript_axis_5, tmp_6_reshaped, subscript_axis_5)
tmp_7 = Constant <value_ints = [-1]> ()
key_squeezed_shape = Concat <axis = 0> (tmp_7, key_second_last_dim, key_last_dim)
key_squeezed = Reshape (key, key_squeezed_shape)
key_squeezed_transposed = Transpose <perm = [0, 2, 1]> (key_squeezed)
key_transposed_shape = Concat <axis = 0> (key_first_dims, key_last_dim, key_second_last_dim)
key_transposed = Reshape (key_squeezed_transposed, key_transposed_shape)
tmp_8 = Sqrt (scale)
query_scaled = Mul (query, tmp_8)
tmp_9 = Sqrt (scale)
key_transposed_scaled = Mul (key_transposed, tmp_9)
tmp_10 = MatMul (query_scaled, key_transposed_scaled)
tmp_11 = Add (tmp_10, attn_mask)
attn_weight = Softmax <axis = -1> (tmp_11)
dropout_p = Constant <value_float: float = @dropout_p> ()
attn_weight_12, _ = Dropout (attn_weight, dropout_p)
return_val = MatMul (attn_weight_12, value)
}
"""

# Per the report, this call fails on the `-inf` literal in
# `_causal_attention_mask`. The expected behavior is that parsing succeeds.
model = onnx.parser.parse_model(model_text)

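Expected behavior

The model should parse successfully: `-inf` should be accepted as the value of a float attribute such as `value_float`.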

Notes

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions