pytorch: BatchNorms forced into training mode by torch.onnx.export when running stats are None

nvbavucw · posted 12 months ago in Other

As described in this git issue (a very thorough write-up), I am trying to load a .onnx model with the OpenVINO backend. However, after setting track_running_stats=False, the BatchNorm layers are treated as being in training mode. Here is how I set that flag before converting my torch model to ONNX:

model.eval()

# Remove the running statistics so BatchNorm uses batch statistics.
# Note: children() only visits direct submodules, so BatchNorm layers
# nested inside containers (e.g. nn.Sequential) are skipped.
for child in model.children():
    if isinstance(child, nn.BatchNorm2d):
        child.track_running_stats = False
        child.running_mean = None
        child.running_var = None
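The behaviour being relied on here can be reproduced on a single layer (a minimal sketch, assuming only torch is installed): once the running stats are removed, BatchNorm2d normalizes with the current batch's statistics even in eval() mode.

```python
import torch
import torch.nn as nn

# Minimal sketch: a single BatchNorm2d with its running stats removed.
bn = nn.BatchNorm2d(3)
bn.eval()
bn.track_running_stats = False
bn.running_mean = None
bn.running_var = None

x = torch.randn(2, 3, 4, 4)
y = bn(x)

# Even in eval mode, the output is normalized with batch statistics:
# the per-channel mean of y is ~0.
print(y.mean(dim=(0, 2, 3)))
```

This batch-statistics path is exactly what the ONNX exporter later encodes as a training-mode BatchNormalization node.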

Then I export the model to ONNX:

dummy_input = torch.randn(1, 3, 200, 200, requires_grad=True)
torch.onnx.export(model, dummy_input, model_path,
                  export_params=True, opset_version=16,
                  training=torch.onnx.TrainingMode.PRESERVE)


Finally, I get this error when loading it in OpenVINO:

Error: Check '(node.get_outputs_size() == 1)' failed at src/frontends/onnx/frontend/src/op/batch_norm.cpp:67:
While validating ONNX node '<Node(BatchNormalization): BatchNormalization_10>':
Training mode of BatchNormalization is not supported.


As suggested in the git issue, I tried inspecting the BatchNorm inputs/outputs:

for node in onnx_model.graph.node:
    if any(("BatchNorm" in s or "bn" in s) for s in node.input) or \
       any(("BatchNorm" in s or "bn" in s) for s in node.output):
        print('Node:', node.name)
        print(node)


So you can see the nodes related to BN:

Node: ReduceMean_5
input: "onnx::ReduceMean_22"
output: "onnx::BatchNormalization_23"
name: "ReduceMean_5"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: ReduceMean_9
input: "onnx::ReduceMean_26"
output: "onnx::BatchNormalization_27"
name: "ReduceMean_9"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: BatchNormalization_10
input: "input"
input: "bn1.weight"
input: "bn1.bias"
input: "onnx::BatchNormalization_23"
input: "onnx::BatchNormalization_27"
output: "input.4"
output: "29"
output: "30"
name: "BatchNormalization_10"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 1
  type: INT
}

Node: BatchNormalization_13
input: "input.8"
input: "bn2.0.weight"
input: "bn2.0.bias"
input: "bn2.0.running_mean"
input: "bn2.0.running_var"
output: "input.12"
name: "BatchNormalization_13"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 0
  type: INT
}

Node: ReduceMean_18
input: "onnx::ReduceMean_37"
output: "onnx::BatchNormalization_38"
name: "ReduceMean_18"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: ReduceMean_22
input: "onnx::ReduceMean_41"
output: "onnx::BatchNormalization_42"
name: "ReduceMean_22"
op_type: "ReduceMean"
attribute {
  name: "axes"
  ints: 0
  ints: 1
  type: INTS
}
attribute {
  name: "keepdims"
  i: 0
  type: INT
}

Node: BatchNormalization_23
input: "input.16"
input: "bn3.weight"
input: "bn3.bias"
input: "onnx::BatchNormalization_38"
input: "onnx::BatchNormalization_42"
output: "43"
output: "44"
output: "45"
name: "BatchNormalization_23"
op_type: "BatchNormalization"
attribute {
  name: "epsilon"
  f: 9.999999747378752e-06
  type: FLOAT
}
attribute {
  name: "momentum"
  f: 0.8999999761581421
  type: FLOAT
}
attribute {
  name: "training_mode"
  i: 1
  type: INT
}


You can indeed see that 2 of the 3 BN layers have training_mode = 1 (i.e. True). How can I make ONNX treat them as being in eval mode, while keeping track_running_stats=False?
I am not very familiar with ONNX, and more generally a DL beginner, so any suggestions are welcome!

pgky5nke

I finally found a solution, based on the suggestions made in the GitHub issue linked in the question.
With track_running_stats set to False, the BatchNormalization layers are treated as being in training mode, as the ONNX graph shows.
I edited the graph directly: I removed the unused extra outputs (the saved mean/var) from the training-mode BatchNormalization nodes, then manually set those nodes to eval mode (training_mode = 0). You must remove the unused outputs, not just set the training_mode attribute to 0, otherwise the check still fails.

for node in onnx_model.graph.node:
    if node.op_type == "BatchNormalization":
        for attribute in node.attribute:
            if attribute.name == 'training_mode':
                if attribute.i == 1:
                    # Drop the two extra outputs (saved mean/var) that
                    # only exist in training mode.
                    del node.output[1:]
                attribute.i = 0

After that, I was able to run inference correctly and got the expected results.
