|
|
@@ -143,9 +143,11 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
std::map<std::string, std::string> kOnnxOpMap = { |
|
|
|
const std::map<std::string, std::string> kOnnxOpMap = { |
|
|
|
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, |
|
|
|
}; |
|
|
|
const char* const MATMULV2 = "MatMulV2"; |
|
|
|
const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2}; |
|
|
|
} |
|
|
|
|
|
|
|
Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, |
|
|
@@ -419,6 +421,11 @@ void OnnxModelParser::UpdateFormat(ge::Graph &graph) { |
|
|
|
ge::Operator op; |
|
|
|
graph.FindOpByName(name, op); |
|
|
|
auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); |
|
|
|
if (std::find(kNoNeedUpdateFormat.begin(), kNoNeedUpdateFormat.end(), op_dsc->GetType()) |
|
|
|
!= kNoNeedUpdateFormat.end()) { |
|
|
|
GELOGW("Op %s:%s no need update format.", op_dsc->GetName().c_str(), op_dsc->GetType().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto input_size = op_dsc->GetAllInputsSize(); |
|
|
|
for (size_t i = 0; i < input_size; i++) { |
|
|
|
auto input = op_dsc->MutableInputDesc(static_cast<uint32_t>(i)); |
|
|
|