Browse Source

batch train for bert 2.x

pull/491/head
wangzhengjun 3 years ago
parent
commit
ecd5ef56e5
3 changed files with 5 additions and 0 deletions
  1. +3
    -0
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  2. +1
    -0
      tests/st/testcase/test_tensorflow_parser.cc
  3. +1
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 3
- 0
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -109,6 +109,9 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge
return FAILED; return FAILED;
} }
} }
const auto out_desc = op_dest->MutableOutputDesc(0);
GE_CHECK_NOTNULL(out_desc);
out_desc->SetDataType(out_type);


std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>(); std::shared_ptr<NodeDef> pkg_node = ge::parser::MakeShared<NodeDef>();
GE_CHECK_NOTNULL(pkg_node); GE_CHECK_NOTNULL(pkg_node);


+ 1
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1913,6 +1913,7 @@ TEST_F(STestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test)
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


op_dest->SetType(ge::parser::SHAPE); op_dest->SetType(ge::parser::SHAPE);
op_dest->AddOutputDesc(GeTensorDesc());
ret = autoMappingParser.ParseParams(node_def, op_dest); ret = autoMappingParser.ParseParams(node_def, op_dest);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }


+ 1
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -2082,6 +2082,7 @@ TEST_F(UtestTensorflowParser, tensorflow_auto_mapping_parser_adapter_test)
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);


op_dest->SetType(ge::parser::SHAPE); op_dest->SetType(ge::parser::SHAPE);
op_dest->AddOutputDesc(GeTensorDesc());
ret = autoMappingParser.ParseParams(node_def, op_dest); ret = autoMappingParser.ParseParams(node_def, op_dest);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
} }


Loading…
Cancel
Save