-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Closed
Labels
Description
Bug Report
Describe the bug
Inputs used in a subgraph disappear from their graph when two models are merged through `onnx.compose.merge_models`. It only happens when the model containing the subgraph is the second graph in the merge.
If the two models have the same input names, the error does not appear, because the second graph relies on the input node of the first one.
See the Reproduction instructions below.
onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted, however input 'base_two' of node:
name: Slice_22 OpType: Slice
is not output of any previous nodes.
System information
- OS Platform and Distribution: Linux Ubuntu 20.04
- ONNX version: '1.15.0'
- Python version: 3.9
Reproduction instructions
Full reproduction script. Generated ONNX files are attached.
onnx_models.zip
import onnx
import torch
from torch import nn
from torch import Tensor
class ModelGraphOne(nn.Module):
    """First model of the reproduction: doubles its input tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        """Return ``inputs`` multiplied element-wise by 2."""
        return 2 * inputs
@torch.jit._script_if_tracing
def extract_from_coords(base: Tensor, coords: Tensor):
    """Copy one window of ``base`` per row of ``coords`` into a stacked tensor.

    Args:
        base: 2-D tensor that the windows are sliced from.
        coords: Tensor whose rows are read as ``[x1, x2, y1, y2]``; row ``i``
            selects ``base[x1:x2, y1:y2]``.

    Returns:
        A tensor of shape ``(coords.size(0), 10, 10)``. The output is
        preallocated with a fixed 10x10 window size, so every slice is
        expected to be 10x10.
    """
    n_extractions = coords.size(0)
    extractions = torch.zeros(n_extractions, 10, 10)
    # NOTE(review): this data-dependent loop is presumably what ends up as
    # the Loop node (with a subgraph) in the exported ONNX model.
    for extract_id, extract_coords in enumerate(coords):
        x1 = extract_coords[0]
        x2 = extract_coords[1]
        y1 = extract_coords[2]
        y2 = extract_coords[3]
        extractions[extract_id] = base[x1:x2, y1:y2]
    return extractions
class ModelGraphTwo(nn.Module):
    """Second model of the reproduction: extracts windows from ``base``."""

    def __init__(self):
        super().__init__()

    def forward(self, base: Tensor, coords: Tensor) -> Tensor:
        """Return the windows of ``base`` selected by the rows of ``coords``.

        Delegates to ``extract_from_coords``, which contains the loop.
        """
        extractions = extract_from_coords(base, coords)
        return extractions
# Example inputs used to export both models.
base = torch.randn(100, 100)
# Each row is read by extract_from_coords as [x1, x2, y1, y2].
coords = torch.tensor([[0, 10, 0, 10], [10, 20, 30, 40]])
model_one = ModelGraphOne()
model_two = ModelGraphTwo()
# Case 1: model_one and model_two use the same input name ("base"),
# so the merge works as expected.
torch.onnx.export(model_one, args=(base, ), f="model_one.onnx", input_names=["base"], output_names=["base_mult"])
torch.onnx.export(model_two, args=(base, coords), f="model_two.onnx", input_names=["base", "coords"]) # has a subgraph
onnx_model_one = onnx.load("model_one.onnx")
onnx_model_two = onnx.load("model_two.onnx")
merged_model = onnx.compose.merge_models(onnx_model_one, onnx_model_two, io_map=[("base_mult", "base")])
print("Models merged successfully")
# Case 2: the input names of model_one and model_two are now different
# ("base" vs. "base_two"). If the model containing the subgraph is the
# first one in the merge, it still works as expected.
torch.onnx.export(model_one, args=(base, ), f="model_one.onnx", input_names=["base"], output_names=["base_mult"])
torch.onnx.export(model_two, args=(base, coords), f="model_two.onnx", input_names=["base_two", "coords"], output_names=["extract"]) # has a subgraph
onnx_model_one = onnx.load("model_one.onnx")
onnx_model_two = onnx.load("model_two.onnx")
onnx.compose.merge_models(onnx_model_two, onnx_model_one, io_map=[("extract", "base")])
print("Models merged successfully")
# Case 3: same exported files as case 2, but the model containing the
# subgraph is the second one in the merge -> fails with the error below.
onnx.compose.merge_models(onnx_model_one, onnx_model_two, io_map=[("base_mult", "base_two")])
# onnx.onnx_cpp2py_export.checker.ValidationError: Nodes in a graph must be topologically sorted, however input 'base_two' of node:
# name: Slice_22 OpType: Slice
# is not output of any previous nodes.
#
# ==> Context: Bad node spec for node. Name: Loop_8 OpType: Loop
Expected behavior
`onnx.compose.merge_models(onnx_model_one, onnx_model_two, io_map=[("base_mult", "base_two")])` should work.