Browse Source

size op parse process out_dtype attr

pull/343/head
wangxiaotian22 3 years ago
parent
commit
d825ba37cf
7 changed files with 1655 additions and 2 deletions
  1. +7
    -0
      inc/external/OWNERS
  2. +12
    -0
      parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
  3. +1
    -0
      tests/CMakeLists.txt
  4. +81
    -0
      tests/depends/graph/CMakeLists.txt
  5. +1489
    -0
      tests/depends/graph/src/attr_util_stub.cc
  6. +2
    -2
      tests/ut/parser/CMakeLists.txt
  7. +63
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc

+ 7
- 0
inc/external/OWNERS View File

@@ -0,0 +1,7 @@
approvers:
- gegenhua
reviewers:
- xchu42
- sheng-nan
- ji_chen
- wqtshg

+ 12
- 0
parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc View File

@@ -84,6 +84,18 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge
op_dest->GetType().c_str(), dynamic_tensor_num);
}

if (op_dest->GetType() == SIZE) {
ge::DataType out_type = DT_INT32;
if (AttrUtils::GetDataType(op_dest, kShapeAttrOutType, out_type)) {
if (!AttrUtils::SetInt(op_dest, kShapeAttrDtype, static_cast<int64_t>(out_type))) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kShapeAttrDtype,
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set attr dtype for op:%s failed.", op_dest->GetName().c_str());
return FAILED;
}
}
}

// add nodedef for shape insert by adapter when online_infer_dynamic
if (op_dest->GetType() == SHAPE) {
ge::DataType out_type = DT_INT32;


+ 1
- 0
tests/CMakeLists.txt View File

@@ -18,6 +18,7 @@ add_subdirectory(depends/slog)
add_subdirectory(depends/mmpa)
add_subdirectory(depends/profiler)
add_subdirectory(depends/error_manager)
add_subdirectory(depends/graph)

if (ENABLE_PARSER_COV OR ENABLE_PARSER_UT)
add_subdirectory(ut)


+ 81
- 0
tests/depends/graph/CMakeLists.txt View File

@@ -0,0 +1,81 @@
# Copyright 2019-2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

#cmake_minimum_required(VERSION 2.8)

project(STUB_ATTR_UTIL)

################################################################################
set(PARSER_PROTO_LIST
"${PARSER_DIR}/metadef/proto/om.proto"
"${PARSER_DIR}/metadef/proto/ge_ir.proto"
"${PARSER_DIR}/metadef/proto/task.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/attr_value.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/function.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/graph.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/graph_library.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/node_def.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/op_def.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/resource_handle.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/tensor.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/tensor_shape.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/types.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/versions.proto"
"${PARSER_DIR}/metadef/proto/caffe/caffe.proto"
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto"
#"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
)

protobuf_generate(ge PARSER_PROTO_SRCS PARSER_PROTO_HDRS ${PARSER_PROTO_LIST})


file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"src/attr_util_stub.cc"
)

include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${PARSER_DIR}/metadef/inc)
include_directories(${PARSER_DIR}/metadef/inc/graph)
include_directories(${PARSER_DIR}/metadef/inc/external)
include_directories(${PARSER_DIR}/metadef/inc/external/graph)
include_directories(${PARSER_DIR}/metadef/graph)
include_directories(${PARSER_DIR}/metadef/third_party)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external/ge)
include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc)
include_directories(${PARSER_DIR}/metadef/third_party/transformer/inc)
include_directories(${PARSER_DIR}/metadef)
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto)


add_library(attr_util_stub STATIC
${SRCS} ${PARSER_PROTO_HDRS}
)

target_compile_definitions(attr_util_stub PRIVATE
google=ascend_private
)

target_compile_options(attr_util_stub PRIVATE
-O2 -g -fno-common
)

target_link_libraries(attr_util_stub PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
c_sec
)

+ 1489
- 0
tests/depends/graph/src/attr_util_stub.cc
File diff suppressed because it is too large
View File


+ 2
- 2
tests/ut/parser/CMakeLists.txt View File

@@ -79,7 +79,6 @@ set(MATEDEF_SRC_FILES
"${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc"
"${PARSER_DIR}/metadef/graph/format_refiner.cc"
"${PARSER_DIR}/metadef/graph/ge_attr_define.cc"
"${PARSER_DIR}/metadef/graph/ge_attr_value.cc"
"${PARSER_DIR}/metadef/graph/ge_tensor.cc"
"${PARSER_DIR}/metadef/graph/gnode.cc"
"${PARSER_DIR}/metadef/graph/graph.cc"
@@ -308,6 +307,7 @@ set(PARSER_UT_FILES
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc"
)

############ libut_parser_common.a ############
@@ -349,6 +349,6 @@ target_link_libraries(ut_parser
$<BUILD_INTERFACE:intf_pub>
ut_parser_proto
-Wl,--whole-archive ut_parser_common -Wl,--no-whole-archive
ut_parser_graph ut_parser_register error_manager_stub mmpa_stub
ut_parser_graph ut_parser_register error_manager_stub mmpa_stub attr_util_stub
gtest gtest_main slog_stub ascend_protobuf c_sec -lrt -ldl -lgcov
)

+ 63
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_auto_mapping_parser_adapter_unittest.cc View File

@@ -0,0 +1,63 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h"
#include "framework/omg/parser/parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"


namespace ge {
class UtestTensorflowAutoMappingParserAdapter : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}

};


TEST_F(UtestTensorflowAutoMappingParserAdapter, success) {
auto parser = TensorFlowAutoMappingParserAdapter();

domi::tensorflow::NodeDef arg_node;
arg_node.set_name("size");
arg_node.set_op("Size");
auto attr = arg_node.mutable_attr();
domi::tensorflow::AttrValue value;
value.set_type(domi::tensorflow::DataType::DT_HALF);
(*attr)["out_type"] = value;

auto op_desc = ge::parser::MakeShared<ge::OpDesc>("size", "Size");
auto ret = parser.ParseParams(reinterpret_cast<Message *>(&arg_node), op_desc);
EXPECT_EQ(ret, ge::SUCCESS);


auto ret2 = ge::AttrUtils::SetBool(op_desc, "test_fail", true);
EXPECT_EQ(ret2, true);
EXPECT_EQ(ge::AttrUtils::HasAttr(op_desc, "test_fail"), true);

ret = parser.ParseParams(reinterpret_cast<Message *>(&arg_node), op_desc);
EXPECT_EQ(ret, ge::FAILED);
}


} // namespace ge

Loading…
Cancel
Save