From ecd5ef56e5310f31ebee4deb5c795ca61f26134c Mon Sep 17 00:00:00 2001 From: wangzhengjun Date: Wed, 23 Mar 2022 17:17:35 +0800 Subject: [PATCH] batch train for bert 2.x --- parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc | 3 +++ tests/st/testcase/test_tensorflow_parser.cc | 1 + .../testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc | 1 + 3 files changed, 5 insertions(+) diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc index bc50850..76fc5ef 100644 --- a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc @@ -109,6 +109,9 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge return FAILED; } } + const auto out_desc = op_dest->MutableOutputDesc(0); + GE_CHECK_NOTNULL(out_desc); + out_desc->SetDataType(out_type); std::shared_ptr pkg_node = ge::parser::MakeShared(); GE_CHECK_NOTNULL(pkg_node); diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 07346f9..7ed5ef8 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -1913,6 +1913,7 @@ TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) EXPECT_EQ(ret, SUCCESS); op_dest->SetType(ge::parser::SHAPE); + op_dest->AddOutputDesc(GeTensorDesc()); ret = autoMappingParser.ParseParams(node_def, op_dest); EXPECT_EQ(ret, SUCCESS); } diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 0afb636..710609e 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -2082,6 +2082,7 @@ TEST_F(UtestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test) EXPECT_EQ(ret, SUCCESS); op_dest->SetType(ge::parser::SHAPE); + op_dest->AddOutputDesc(GeTensorDesc()); ret = autoMappingParser.ParseParams(node_def, op_dest); EXPECT_EQ(ret, SUCCESS); }