From: @t00456437 Reviewed-by: @ji_chen,@xchu42 Signed-off-by: @ji_chentags/v1.1.0
@@ -37,6 +37,7 @@ if (ENABLE_OPEN_SRC) | |||||
include(cmake/external_libs/protobuf_static.cmake) | include(cmake/external_libs/protobuf_static.cmake) | ||||
include(cmake/external_libs/protoc.cmake) | include(cmake/external_libs/protoc.cmake) | ||||
include(cmake/external_libs/gflags.cmake) | include(cmake/external_libs/gflags.cmake) | ||||
include(cmake/external_libs/gtest.cmake) | |||||
include(cmake/external_libs/securec.cmake) | include(cmake/external_libs/securec.cmake) | ||||
include(cmake/external_libs/json.cmake) | include(cmake/external_libs/json.cmake) | ||||
include(cmake/FindModule.cmake) | include(cmake/FindModule.cmake) | ||||
@@ -78,6 +79,7 @@ if (ENABLE_OPEN_SRC) | |||||
else() | else() | ||||
find_module(slog libslog.so ${ASCEND_ATC_DIR}) | find_module(slog libslog.so ${ASCEND_ATC_DIR}) | ||||
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | ||||
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||||
if(PLATFORM STREQUAL "train") | if(PLATFORM STREQUAL "train") | ||||
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | ||||
find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | ||||
@@ -123,8 +125,13 @@ if (ENABLE_OPEN_SRC) | |||||
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | ||||
#find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | #find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | ||||
else() | else() | ||||
message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") | |||||
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||||
endif() | endif() | ||||
if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||||
add_subdirectory(tests) | |||||
endif() | |||||
endif() | endif() | ||||
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | ||||
@@ -59,7 +59,7 @@ checkopts() | |||||
ENABLE_GE_ST="off" | ENABLE_GE_ST="off" | ||||
ENABLE_GE_COV="off" | ENABLE_GE_COV="off" | ||||
GE_ONLY="on" | GE_ONLY="on" | ||||
PLATFORM="inference" | |||||
PLATFORM="" | |||||
PRODUCT="normal" | PRODUCT="normal" | ||||
ENABLE_GITEE="off" | ENABLE_GITEE="off" | ||||
# Process the options | # Process the options | ||||
@@ -166,6 +166,9 @@ build_graphengine() | |||||
elif [ "x${PLATFORM}" = "xinference" ] | elif [ "x${PLATFORM}" = "xinference" ] | ||||
then | then | ||||
TARGET="ge_compiler atc_ge_local_engine atc_ge_local_opskernel_builder atc_host_cpu_engine atc_host_cpu_opskernel_builder atc opensrc_ascendcl ${TARGET}" | TARGET="ge_compiler atc_ge_local_engine atc_ge_local_opskernel_builder atc_host_cpu_engine atc_host_cpu_opskernel_builder atc opensrc_ascendcl ${TARGET}" | ||||
elif [ "X$ENABLE_GE_UT" = "Xon" ] | |||||
then | |||||
TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | |||||
elif [ "x${PLATFORM}" = "xall" ] | elif [ "x${PLATFORM}" = "xall" ] | ||||
then | then | ||||
# build all the target | # build all the target | ||||
@@ -0,0 +1,60 @@ | |||||
if (HAVE_GTEST) | |||||
return() | |||||
endif() | |||||
include(ExternalProject) | |||||
if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||||
set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||||
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||||
endif() | |||||
if (ENABLE_GITEE) | |||||
set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | |||||
set(MD5 "") | |||||
else() | |||||
set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.0.tar.gz") | |||||
set(MD5 "") | |||||
endif () | |||||
set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||||
set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||||
ExternalProject_Add(gtest_build | |||||
URL ${REQ_URL} | |||||
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | |||||
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||||
BUILD_COMMAND $(MAKE) | |||||
INSTALL_COMMAND $(MAKE) install | |||||
EXCLUDE_FROM_ALL TRUE | |||||
) | |||||
set(GTEST_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gtest) | |||||
file(MAKE_DIRECTORY ${GTEST_PKG_DIR}/include) | |||||
add_library(gtest SHARED IMPORTED) | |||||
set_target_properties(gtest PROPERTIES | |||||
IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest.so | |||||
) | |||||
add_library(gtest_main SHARED IMPORTED) | |||||
set_target_properties(gtest_main PROPERTIES | |||||
IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest_main.so | |||||
) | |||||
target_include_directories(gtest INTERFACE ${GTEST_PKG_DIR}/include) | |||||
target_include_directories(gtest_main INTERFACE ${GTEST_PKG_DIR}/include) | |||||
set(INSTALL_BASE_DIR "") | |||||
set(INSTALL_LIBRARY_DIR lib) | |||||
install(FILES ${GTEST_PKG_DIR}/lib/libgtest.so ${GTEST_PKG_DIR}/lib/libgtest_main.so OPTIONAL | |||||
DESTINATION ${INSTALL_LIBRARY_DIR}) | |||||
add_dependencies(gtest gtest_build) | |||||
#set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") | |||||
set(HAVE_GTEST TRUE) |
@@ -22,7 +22,7 @@ endif () | |||||
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") | set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") | ||||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
ExternalProject_Add(protobuf_build | ExternalProject_Add(protobuf_build | ||||
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
URL ${REQ_URL} | |||||
CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
-Dprotobuf_WITH_ZLIB=OFF | -Dprotobuf_WITH_ZLIB=OFF | ||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | ||||
@@ -20,7 +20,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | ||||
ExternalProject_Add(protobuf_static_build | ExternalProject_Add(protobuf_static_build | ||||
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
URL ${REQ_URL} | |||||
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | ||||
CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
@@ -22,6 +22,7 @@ add_subdirectory(depends/runtime) | |||||
add_subdirectory(depends/omg) | add_subdirectory(depends/omg) | ||||
add_subdirectory(depends/hccl) | add_subdirectory(depends/hccl) | ||||
add_subdirectory(depends/profiler) | add_subdirectory(depends/profiler) | ||||
add_subdirectory(depends/error_manager) | |||||
if (ENABLE_GE_COV OR ENABLE_GE_UT) | if (ENABLE_GE_COV OR ENABLE_GE_UT) | ||||
add_subdirectory(ut) | add_subdirectory(ut) | ||||
@@ -13,60 +13,84 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(STUB_CCE) | project(STUB_CCE) | ||||
set(CMAKE_CXX_STANDARD 11) | set(CMAKE_CXX_STANDARD 11) | ||||
include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/src/common) | |||||
include_directories(${GE_SOURCE_DIR}/src/common/graph) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/framework) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||||
include_directories(${GE_CODE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||||
include_directories(${GE_CODE_DIR}/metadef) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||||
include_directories(${GE_CODE_DIR}/metadef/graph) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||||
"${GE_SOURCE_DIR}/src/proto/ge_ir.proto" | |||||
"${GE_SOURCE_DIR}/src/proto/task.proto" | |||||
set(PROTO_LIST | |||||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||||
"${GE_CODE_DIR}/metadef/proto/ge_ir.proto" | |||||
"${GE_CODE_DIR}/metadef/proto/task.proto" | |||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_define.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/anchor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_value.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/buffer.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/compute_graph.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/graph.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/model.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/model_serialize.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/node.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/op_desc.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory_impl.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/tensor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/detail/attributes_holder.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/anchor_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/graph_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/node_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/op_desc_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/type_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/op_imp.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/shape_refiner.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_tensor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/opsproto/opsproto_manager.cc" | |||||
set(SRCS | |||||
"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/graph.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/model.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/node.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/operator.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | |||||
"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||||
) | ) | ||||
add_library(cce_ge_stub SHARED src/cce_stub.cc ${PROTO_SRCS} ${PROTO_HDRS}) | add_library(cce_ge_stub SHARED src/cce_stub.cc ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
target_link_libraries(cce_ge_stub protobuf::protobuf) | |||||
target_compile_definitions(cce_ge_stub PRIVATE | |||||
google=ascend_private | |||||
) | |||||
target_link_libraries(cce_ge_stub | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
-Wl,--no-as-needed | |||||
ascend_protobuf | |||||
-Wl,--as-needed | |||||
c_sec | |||||
) | |||||
add_library(cce_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | add_library(cce_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
target_link_libraries(cce_stub protobuf::protobuf) | |||||
target_compile_definitions(cce_stub PRIVATE | |||||
google=ascend_private | |||||
) | |||||
target_link_libraries(cce_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
-Wl,--no-as-needed | |||||
ascend_protobuf | |||||
-Wl,--as-needed | |||||
c_sec | |||||
) |
@@ -0,0 +1,34 @@ | |||||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(STUB_ERROR_MANAGER) | |||||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"src/error_manager_stub.cc" | |||||
) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/framework) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||||
add_library(error_manager_stub SHARED ${SRCS}) | |||||
target_link_libraries(error_manager_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
) |
@@ -0,0 +1,85 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "common/util/error_manager/error_manager.h" | |||||
ErrorManager &ErrorManager::GetInstance() { | |||||
static ErrorManager instance; | |||||
return instance; | |||||
} | |||||
/// | |||||
/// @brief init | |||||
/// @param [in] path: current so path | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::Init(std::string path) { return 0; } | |||||
/// | |||||
/// @brief Report error message | |||||
/// @param [in] error_code: error code | |||||
/// @param [in] args_map: parameter map | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map) { | |||||
return 0; | |||||
} | |||||
/// | |||||
/// @brief output error message | |||||
/// @param [in] handle: print handle | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::OutputErrMessage(int handle) { return 0; } | |||||
/// | |||||
/// @brief output message | |||||
/// @param [in] handle: print handle | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::OutputMessage(int handle) { return 0; } | |||||
/// | |||||
/// @brief Report error message | |||||
/// @param [in] key: vector parameter key | |||||
/// @param [in] value: vector parameter value | |||||
/// | |||||
void ErrorManager::ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key, | |||||
const std::vector<std::string> &value) { | |||||
} | |||||
/// | |||||
/// @brief report graph compile failed message such as error code and op_name in mstune case | |||||
/// @param [in] msg: failed message map, key is error code, value is op_name | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::ReportMstuneCompileFailedMsg(const std::map<std::string, std::string> &msg) { return 0; } | |||||
/// | |||||
/// @brief save graph compile failed message from thread local map to global map | |||||
/// @param [in] graph_name: graph name | |||||
/// | |||||
void ErrorManager::SaveMstuneCompileFailedMsg(const std::string &graph_name) {} | |||||
/// | |||||
/// @brief get graph compile failed message in mstune case | |||||
/// @param [in] graph_name: graph name | |||||
/// @param [out] msg_map: failed message map, key is error code, value is op_name list | |||||
/// @return int 0(success) -1(fail) | |||||
/// | |||||
int ErrorManager::GetMstuneCompileFailedMsg(const std::string &graph_name, std::map<std::string, std::vector<std::string>> &msg_map) { return 0; } | |||||
@@ -13,14 +13,18 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(hccl_stub) | project(hccl_stub) | ||||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
"src/hccl_stub.cc" | "src/hccl_stub.cc" | ||||
) | ) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
add_library(hccl_stub SHARED ${SRC_FILES}) | |||||
add_library(hccl_stub SHARED ${SRC_FILES}) | |||||
target_link_libraries(hccl_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
) |
@@ -18,27 +18,27 @@ | |||||
#include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
hcclResult_t hcom_all_gather(const char *tag, void *input_count_ptr, void *output_ptr, u64 input_count, | |||||
hcclDataType_t data_type, const char *group, rtStream_t stream) { | |||||
HcclResult hcom_all_gather(const char *tag, void *input_count_ptr, void *output_ptr, u64 input_count, | |||||
HcclDataType data_type, const char *group, rtStream_t stream) { | |||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | } | ||||
hcclResult_t hcom_broadcast(const char *tag, void *ptr, u64 count, hcclDataType_t data_type, u32 root, | |||||
HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType data_type, u32 root, | |||||
const char *group, rtStream_t stream) { | const char *group, rtStream_t stream) { | ||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | } | ||||
hcclResult_t hcom_all_reduce(const char *tag, void *input_ptr, void *output_ptr, u64 count, hcclDataType_t data_type, | |||||
hcclRedOp_t op, const char *group, rtStream_t stream) { | |||||
HcclResult hcom_all_reduce(const char *tag, void *input_ptr, void *output_ptr, u64 count, HcclDataType data_type, | |||||
HcclReduceOp op, const char *group, rtStream_t stream) { | |||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | } | ||||
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 max_segment_num, | |||||
HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 max_segment_num, | |||||
u32 *segment_num, u32 *segment_idx) { | u32 *segment_num, u32 *segment_idx) { | ||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | } | ||||
hcclResult_t hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_ptr, u64 count, | |||||
hcclDataType_t data_type, hcclRedOp_t op, const char *group, rtStream_t stream) { | |||||
HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_ptr, u64 count, | |||||
HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | |||||
return HCCL_SUCCESS; | return HCCL_SUCCESS; | ||||
} | |||||
} |
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(STUB_MMPA) | project(STUB_MMPA) | ||||
@@ -21,10 +21,18 @@ file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"src/mmpa_stub.cc" | "src/mmpa_stub.cc" | ||||
) | ) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/framework) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||||
add_library(mmpa_stub SHARED ${SRCS}) | add_library(mmpa_stub SHARED ${SRCS}) | ||||
target_link_libraries(mmpa_stub protobuf::protobuf) | |||||
target_link_libraries(mmpa_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
-Wl,--no-as-needed | |||||
ascend_protobuf | |||||
-Wl,--as-needed | |||||
c_sec | |||||
) |
@@ -217,3 +217,58 @@ INT32 mmScandir(const CHAR *path, mmDirent ***entryList, mmFilter filterFunc, m | |||||
VOID mmScandirFree(mmDirent **entryList, INT32 count) | VOID mmScandirFree(mmDirent **entryList, INT32 count) | ||||
{ | { | ||||
} | } | ||||
INT32 mmAccess2(const CHAR *pathName, INT32 mode) | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen) | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmGetErrorCode() | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmIsDir(const CHAR *fileName) | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | |||||
{ | |||||
return 0; | |||||
} | |||||
INT32 mmDlclose(VOID *handle) | |||||
{ | |||||
return 0; | |||||
} | |||||
CHAR *mmDlerror() | |||||
{ | |||||
return ""; | |||||
} | |||||
INT32 mmDladdr(VOID *addr, mmDlInfo *info) | |||||
{ | |||||
return 0; | |||||
} | |||||
VOID *mmDlopen(const CHAR *fileName, INT32 mode) | |||||
{ | |||||
return NULL; | |||||
} | |||||
VOID *mmDlsym(VOID *handle, const CHAR *funcName) | |||||
{ | |||||
return NULL; | |||||
} |
@@ -13,33 +13,47 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(OMG_CCE) | project(OMG_CCE) | ||||
set(CMAKE_CXX_STANDARD 11) | set(CMAKE_CXX_STANDARD 11) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/src/ge) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/framework) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||||
include_directories(${GE_CODE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||||
include_directories(${GE_CODE_DIR}/ge) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||||
"${GE_SOURCE_DIR}/src/proto/task.proto" | |||||
set(PROTO_LIST | |||||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||||
"${GE_CODE_DIR}/metadef/proto/task.proto" | |||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
# "${GE_SOURCE_DIR}/src/ge/common/util.cc" | |||||
"src/omg_stub.cc" | |||||
set(SRCS | |||||
# "${GE_CODE_DIR}/src/ge/common/util.cc" | |||||
"src/omg_stub.cc" | |||||
) | ) | ||||
add_library(omg_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | add_library(omg_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
target_link_libraries(omg_stub protobuf::protobuf) | |||||
target_compile_definitions(omg_stub PRIVATE | |||||
google=ascend_private | |||||
) | |||||
target_link_libraries(omg_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
-Wl,--no-as-needed | |||||
ascend_protobuf | |||||
-Wl,--as-needed | |||||
c_sec | |||||
json | |||||
) |
@@ -643,7 +643,7 @@ Status GetInputOutputDescInfo(uint32_t model_id, vector<InputOutputDescInfo> &in | |||||
} | } | ||||
Status DataInput(const InputData *input_data, OutputData *output_data) { return SUCCESS; } | Status DataInput(const InputData *input_data, OutputData *output_data) { return SUCCESS; } | ||||
/* | |||||
class ModelManager { | class ModelManager { | ||||
public: | public: | ||||
static std::shared_ptr<ModelManager> GetInstance(); | static std::shared_ptr<ModelManager> GetInstance(); | ||||
@@ -741,6 +741,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
*/ | |||||
} // namespace ge | } // namespace ge | ||||
namespace ge { | namespace ge { | ||||
@@ -13,12 +13,16 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(profiler_stub) | project(profiler_stub) | ||||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
"src/profiler_stub.cc" | "src/profiler_stub.cc" | ||||
) | ) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
add_library(profiler_stub SHARED ${SRC_FILES}) | |||||
add_library(profiler_stub SHARED ${SRC_FILES}) | |||||
target_link_libraries(profiler_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
) |
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(STUB_MMPA) | project(STUB_MMPA) | ||||
@@ -21,7 +21,12 @@ file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"src/runtime_stub.cc" | "src/runtime_stub.cc" | ||||
) | ) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/framework) | |||||
add_library(runtime_stub SHARED ${SRCS}) | add_library(runtime_stub SHARED ${SRCS}) | ||||
target_link_libraries(runtime_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
c_sec | |||||
) |
@@ -221,8 +221,9 @@ rtError_t rtCpuKernelLaunch(const void *so_name, const void *kernel_name, uint32 | |||||
return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
} | } | ||||
rtError_t rtModelGetTaskId(void *handle, uint32_t *task_id) { | |||||
rtError_t rtModelGetTaskId(void *handle, uint32_t *task_id, uint32_t *stream_id) { | |||||
*task_id = 0; | *task_id = 0; | ||||
*stream_id = 0; | |||||
return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
} | } | ||||
rtError_t rtEndGraph(rtModel_t model, rtStream_t stream) { return RT_ERROR_NONE; } | rtError_t rtEndGraph(rtModel_t model, rtStream_t stream) { return RT_ERROR_NONE; } | ||||
@@ -307,3 +308,79 @@ rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQueueFlag_t | |||||
{ | { | ||||
return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
} | } | ||||
rtError_t rtSetSocVersion(const char *version) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtMallocHostSharedMemory(rtMallocHostSharedMemoryIn *in, | |||||
rtMallocHostSharedMemoryOut *out) | |||||
{ | |||||
out->ptr = new uint8_t[in->size]; | |||||
out->devPtr = new uint8_t[in->size]; | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtFreeHostSharedMemory(rtFreeHostSharedMemoryIn *in) | |||||
{ | |||||
delete[] (uint8_t*)in->ptr; | |||||
delete[] (uint8_t*)in->devPtr; | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deplyType) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtDebugUnRegister(rtModel_t model) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtDumpAddrSet(rtModel_t model, void *addr, uint32_t dumpSize, uint32_t flag) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtSetCtxINFMode(bool mode) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) | |||||
{ | |||||
*label = new uint32_t; | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *value) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *maxStrCount, uint32_t *maxTaskCount) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} | |||||
rtError_t rtModelExit(rtModel_t model, rtStream_t stream) | |||||
{ | |||||
return RT_ERROR_NONE; | |||||
} |
@@ -13,11 +13,14 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
# ============================================================================ | # ============================================================================ | ||||
cmake_minimum_required(VERSION 2.8) | |||||
#cmake_minimum_required(VERSION 2.8) | |||||
project(slog_stub) | project(slog_stub) | ||||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
"src/*.cc" | "src/*.cc" | ||||
) | ) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
add_library(slog_stub SHARED ${SRC_FILES}) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
add_library(slog_stub SHARED ${SRC_FILES}) | |||||
target_link_libraries(slog_stub PRIVATE | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
) |
@@ -38,6 +38,8 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu | |||||
dav_log(module_id, fmt); | dav_log(module_id, fmt); | ||||
} | } | ||||
int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } | |||||
int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } | int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } | ||||
int CheckLogLevel(int moduleId, int logLevel) | int CheckLogLevel(int moduleId, int logLevel) | ||||
@@ -17,30 +17,34 @@ project(ut_libgraph) | |||||
set(CMAKE_CXX_STANDARD 11) | set(CMAKE_CXX_STANDARD 11) | ||||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||||
"${GE_SOURCE_DIR}/src/proto/ge_ir.proto" | |||||
"${onnx_INC}/onnx/onnx.proto" | |||||
set(PROTO_LIST | |||||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||||
"${GE_CODE_DIR}/metadef/proto/ge_ir.proto" | |||||
"${GE_CODE_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||||
) | ) | ||||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
# include directories | # include directories | ||||
include_directories(${CMAKE_CURRENT_LIST_DIR}) | include_directories(${CMAKE_CURRENT_LIST_DIR}) | ||||
include_directories(${GE_SOURCE_DIR}/src) | |||||
include_directories(${GE_SOURCE_DIR}/src/common) | |||||
include_directories(${GE_SOURCE_DIR}/src/common/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/common) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||||
include_directories(${GE_CODE_DIR}) | |||||
include_directories(${GE_CODE_DIR}/metadef) | |||||
include_directories(${GE_CODE_DIR}/metadef/graph) | |||||
include_directories(${GE_CODE_DIR}/inc) | |||||
include_directories(${GE_CODE_DIR}/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||||
include_directories(${GE_CODE_DIR}/metadef/inc/common) | |||||
include_directories(${GE_CODE_DIR}/metadef/third_party) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | |||||
file(GLOB_RECURSE UT_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
set(UT_FILES | |||||
"testcase/ge_graph/ge_anchor_utils_unittest.cc" | "testcase/ge_graph/ge_anchor_utils_unittest.cc" | ||||
"testcase/ge_graph/ge_def_type_unittest.cc" | "testcase/ge_graph/ge_def_type_unittest.cc" | ||||
"testcase/ge_graph/ge_graph_anchor_unittest.cc" | "testcase/ge_graph/ge_graph_anchor_unittest.cc" | ||||
@@ -56,41 +60,59 @@ file(GLOB_RECURSE UT_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
"testcase/ge_graph/ge_model_unittest.cc" | "testcase/ge_graph/ge_model_unittest.cc" | ||||
) | ) | ||||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
"${GE_SOURCE_DIR}/src/common/graph/option/ge_local_context.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/option/ge_context.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/anchor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_value.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/attr_value.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/buffer.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/compute_graph.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_define.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/graph.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/model.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/model_serialize.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/node.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/op_desc.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator_reg.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory_impl.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/range_vistor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/tensor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/ge_tensor.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/shape_refiner.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/format_refiner.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/inference_context.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/detail/attributes_holder.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/anchor_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/graph_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/node_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/op_desc_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/type_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/ge_ir_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/utils/tensor_utils.cc" | |||||
"${GE_SOURCE_DIR}/src/common/ops/op_imp.cc" | |||||
"${GE_SOURCE_DIR}/src/common/graph/opsproto/opsproto_manager.cc" | |||||
set(SRC_FILES | |||||
#"${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/attr_value.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/graph.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/gnode.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/ascend_string.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/model.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/node.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/operator.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/operator_reg.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/range_vistor.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/inference_context.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||||
#"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc" | |||||
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | |||||
#"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||||
) | ) | ||||
#add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | ||||
target_link_libraries(ut_libgraph graphengine::gtest graphengine::gtest_main slog_stub protobuf::protobuf graphengine::securec rt dl) | |||||
target_compile_definitions(ut_libgraph PRIVATE | |||||
google=ascend_private | |||||
) | |||||
target_link_libraries(ut_libgraph | |||||
$<BUILD_INTERFACE:intf_pub> | |||||
graph | |||||
gtest | |||||
gtest_main | |||||
slog_stub | |||||
ascend_protobuf | |||||
c_sec | |||||
-lrt | |||||
-ldl | |||||
) |
@@ -85,7 +85,7 @@ ut::GraphBuilder BuildGraph1() { | |||||
builder.AddDataEdge(var2, 0, conv1, 1); | builder.AddDataEdge(var2, 0, conv1, 1); | ||||
builder.AddDataEdge(conv1, 0, relu1, 0); | builder.AddDataEdge(conv1, 0, relu1, 0); | ||||
builder.AddDataEdge(relu1, 0, netoutput1, 0); | builder.AddDataEdge(relu1, 0, netoutput1, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -134,7 +134,7 @@ ut::GraphBuilder BuildGraph2() { | |||||
builder.AddDataEdge(var6, 0, bn1, 4); | builder.AddDataEdge(var6, 0, bn1, 4); | ||||
builder.AddDataEdge(bn1, 0, relu1, 0); | builder.AddDataEdge(bn1, 0, relu1, 0); | ||||
builder.AddDataEdge(relu1, 0, netoutput1, 0); | builder.AddDataEdge(relu1, 0, netoutput1, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -189,7 +189,7 @@ ut::GraphBuilder BuildGraph3() { | |||||
builder.AddDataEdge(relu1, 0, conv2, 0); | builder.AddDataEdge(relu1, 0, conv2, 0); | ||||
builder.AddDataEdge(var3, 0, conv2, 1); | builder.AddDataEdge(var3, 0, conv2, 1); | ||||
builder.AddDataEdge(conv2, 0, netoutput1, 0); | builder.AddDataEdge(conv2, 0, netoutput1, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -248,7 +248,7 @@ ut::GraphBuilder BuildGraph4() { | |||||
builder.AddDataEdge(relu1, 0, conv2, 0); | builder.AddDataEdge(relu1, 0, conv2, 0); | ||||
builder.AddDataEdge(var3, 0, conv2, 1); | builder.AddDataEdge(var3, 0, conv2, 1); | ||||
builder.AddDataEdge(conv2, 0, netoutput1, 0); | builder.AddDataEdge(conv2, 0, netoutput1, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -305,7 +305,7 @@ ut::GraphBuilder BuilderGraph5() { | |||||
builder.AddDataEdge(relug1, 0, bng1, 0); | builder.AddDataEdge(relug1, 0, bng1, 0); | ||||
builder.AddDataEdge(bng1, 0, apply1, 0); | builder.AddDataEdge(bng1, 0, apply1, 0); | ||||
builder.AddDataEdge(apply1, 0, netoutput1, 0); | builder.AddDataEdge(apply1, 0, netoutput1, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -353,7 +353,7 @@ ut::GraphBuilder BuildGraph6() { | |||||
builder.AddDataEdge(constant, 0, addn, 2); | builder.AddDataEdge(constant, 0, addn, 2); | ||||
builder.AddDataEdge(addn, 0, netoutput, 0); | builder.AddDataEdge(addn, 0, netoutput, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -397,7 +397,7 @@ ut::GraphBuilder BuildGraph7() { | |||||
builder.AddDataEdge(constant, 0, addn, 2); | builder.AddDataEdge(constant, 0, addn, 2); | ||||
builder.AddDataEdge(addn, 0, netoutput, 0); | builder.AddDataEdge(addn, 0, netoutput, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
@@ -449,7 +449,7 @@ ut::GraphBuilder BuildGraph8() { | |||||
builder.AddDataEdge(relu, 0, reshape, 0); | builder.AddDataEdge(relu, 0, reshape, 0); | ||||
builder.AddDataEdge(reshape, 0, conv, 1); | builder.AddDataEdge(reshape, 0, conv, 1); | ||||
builder.AddDataEdge(conv, 0, netoutput, 0); | builder.AddDataEdge(conv, 0, netoutput, 0); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
return builder; | return builder; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
@@ -457,7 +457,7 @@ ut::GraphBuilder BuildGraph8() { | |||||
TEST_F(UtestFormatRefiner, data_format) { | TEST_F(UtestFormatRefiner, data_format) { | ||||
auto builder = BuildGraph8(); | auto builder = BuildGraph8(); | ||||
auto graph = builder.GetGraph(); | auto graph = builder.GetGraph(); | ||||
FormatRefiner::SetInferOrigineFormatFlag(false); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(false); | |||||
graph->SaveDataFormat(FORMAT_NCHW); | graph->SaveDataFormat(FORMAT_NCHW); | ||||
EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); | EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); | ||||
auto data2 = graph->FindNode("data2"); | auto data2 = graph->FindNode("data2"); | ||||
@@ -466,18 +466,18 @@ TEST_F(UtestFormatRefiner, data_format) { | |||||
EXPECT_EQ(data2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | EXPECT_EQ(data2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | ||||
EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); | EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); | ||||
EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
} | } | ||||
TEST_F(UtestFormatRefiner, constant_fail) { | TEST_F(UtestFormatRefiner, constant_fail) { | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
auto builder = BuildGraph6(); | auto builder = BuildGraph6(); | ||||
auto graph = builder.GetGraph(); | auto graph = builder.GetGraph(); | ||||
EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_FAILED); | EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_FAILED); | ||||
} | } | ||||
TEST_F(UtestFormatRefiner, scalar_nodes_infer) { | TEST_F(UtestFormatRefiner, scalar_nodes_infer) { | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
auto builder = BuildGraph6(); | auto builder = BuildGraph6(); | ||||
auto graph = builder.GetGraph(); | auto graph = builder.GetGraph(); | ||||
auto constant = graph->FindNode("constant"); | auto constant = graph->FindNode("constant"); | ||||
@@ -650,7 +650,7 @@ TEST_F(UtestFormatRefiner, infer_origine_format_failed) { | |||||
} | } | ||||
TEST_F(UtestFormatRefiner, save_format) { | TEST_F(UtestFormatRefiner, save_format) { | ||||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||||
auto builder = BuildGraph6(); | auto builder = BuildGraph6(); | ||||
auto graph = builder.GetGraph(); | auto graph = builder.GetGraph(); | ||||
graph->SaveDataFormat(FORMAT_NHWC); | graph->SaveDataFormat(FORMAT_NHWC); | ||||
@@ -658,4 +658,4 @@ TEST_F(UtestFormatRefiner, save_format) { | |||||
EXPECT_EQ(save_format, FORMAT_NHWC); | EXPECT_EQ(save_format, FORMAT_NHWC); | ||||
graph->SaveDataFormat(FORMAT_ND); | graph->SaveDataFormat(FORMAT_ND); | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -1060,7 +1060,7 @@ TEST(UtestGeModelSerialize, test_model_serialize_imp_invalid_param) { | |||||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | auto graph = std::make_shared<ComputeGraph>("test_graph"); | ||||
auto node = graph->AddNode(std::make_shared<OpDesc>()); | auto node = graph->AddNode(std::make_shared<OpDesc>()); | ||||
node->op_ = nullptr; | node->op_ = nullptr; | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
Model model; | Model model; | ||||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | ||||
EXPECT_FALSE(imp.SerializeModel(model, &model_def)); | EXPECT_FALSE(imp.SerializeModel(model, &model_def)); | ||||
@@ -1101,26 +1101,26 @@ TEST(UTEST_ge_model_unserialize, test_invalid_tensor) { | |||||
TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | ||||
{ // valid | { // valid | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.mutable_attr(); | auto attrs = mode_def.mutable_attr(); | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_td(); | auto tensor_desc_attr = attr_def->mutable_td(); | ||||
tensor_desc_attr->set_layout("NCHW"); | tensor_desc_attr->set_layout("NCHW"); | ||||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
Model model; | Model model; | ||||
EXPECT_TRUE(imp.UnserializeModel(model, mode_def)); | EXPECT_TRUE(imp.UnserializeModel(model, mode_def)); | ||||
} | } | ||||
{ // invalid layout | { // invalid layout | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.mutable_attr(); | auto attrs = mode_def.mutable_attr(); | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_td(); | auto tensor_desc_attr = attr_def->mutable_td(); | ||||
tensor_desc_attr->set_layout("InvalidLayout"); | tensor_desc_attr->set_layout("InvalidLayout"); | ||||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
Model model; | Model model; | ||||
@@ -1131,13 +1131,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||||
EXPECT_EQ(tensor_desc.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid datatype | { // invalid datatype | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.mutable_attr(); | auto attrs = mode_def.mutable_attr(); | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_td(); // tensor desc | auto tensor_desc_attr = attr_def->mutable_td(); // tensor desc | ||||
tensor_desc_attr->set_layout("NHWC"); | tensor_desc_attr->set_layout("NHWC"); | ||||
tensor_desc_attr->set_dtype((proto::DataType)100); | |||||
tensor_desc_attr->set_dtype((ge::proto::DataType)100); | |||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
Model model; | Model model; | ||||
@@ -1148,13 +1148,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||||
EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | ||||
} | } | ||||
{ // invalid datatype | { // invalid datatype | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.mutable_attr(); | auto attrs = mode_def.mutable_attr(); | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | ||||
tensor_desc_attr->set_layout("NHWC"); | tensor_desc_attr->set_layout("NHWC"); | ||||
tensor_desc_attr->set_dtype((proto::DataType)100); | |||||
tensor_desc_attr->set_dtype((ge::proto::DataType)100); | |||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
Model model; | Model model; | ||||
@@ -1167,13 +1167,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||||
EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | ||||
} | } | ||||
{ // invalid attrmap | { // invalid attrmap | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->mutable_attr(); // graph attr | auto attrs = mode_def.add_graph()->mutable_attr(); // graph attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | ||||
tensor_desc_attr->set_layout("NCHW"); | tensor_desc_attr->set_layout("NCHW"); | ||||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||||
auto attrs1 = tensor_desc_attr->mutable_attr(); | auto attrs1 = tensor_desc_attr->mutable_attr(); | ||||
auto attr1 = (*attrs1)["key2"]; // empty attr | auto attr1 = (*attrs1)["key2"]; // empty attr | ||||
@@ -1191,13 +1191,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||||
EXPECT_EQ(attr_value.GetValueType(), GeAttrValue::VT_NONE); | EXPECT_EQ(attr_value.GetValueType(), GeAttrValue::VT_NONE); | ||||
} | } | ||||
{ // invalid attrmap2 | { // invalid attrmap2 | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | ||||
tensor_desc_attr->set_layout("NCHW"); | tensor_desc_attr->set_layout("NCHW"); | ||||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||||
auto attrs1 = tensor_desc_attr->mutable_attr(); | auto attrs1 = tensor_desc_attr->mutable_attr(); | ||||
auto attr1 = (*attrs1)["key2"].mutable_list(); // empty list attr | auto attr1 = (*attrs1)["key2"].mutable_list(); // empty list attr | ||||
@@ -1219,14 +1219,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||||
} | } | ||||
TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | ||||
{ // invalid graph | { // invalid graph | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto graph_attr = attr_def->mutable_g(); | auto graph_attr = attr_def->mutable_g(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1245,15 +1245,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid list graph | { // invalid list graph | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH); | attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH); | ||||
auto graph_attr = attr_def->mutable_list()->add_g(); | auto graph_attr = attr_def->mutable_list()->add_g(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1273,14 +1273,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid named_attrs | { // invalid named_attrs | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto graph_attr = attr_def->mutable_func(); | auto graph_attr = attr_def->mutable_func(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1298,15 +1298,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid list named_attrs | { // invalid list named_attrs | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS); | attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS); | ||||
auto graph_attr = attr_def->mutable_list()->add_na(); | auto graph_attr = attr_def->mutable_list()->add_na(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1325,14 +1325,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid tensor_desc | { // invalid tensor_desc | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto graph_attr = attr_def->mutable_td(); | auto graph_attr = attr_def->mutable_td(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1350,15 +1350,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid list tensor_desc | { // invalid list tensor_desc | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC); | attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC); | ||||
auto graph_attr = attr_def->mutable_list()->add_td(); | auto graph_attr = attr_def->mutable_list()->add_td(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1377,14 +1377,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid tensor | { // invalid tensor | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
auto graph_attr = attr_def->mutable_t()->mutable_desc(); | auto graph_attr = attr_def->mutable_t()->mutable_desc(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1402,15 +1402,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid list tensor | { // invalid list tensor | ||||
proto::ModelDef mode_def; | |||||
ge::proto::ModelDef mode_def; | |||||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | ||||
auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1429,15 +1429,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | ||||
} | } | ||||
{ // invalid list tensor | { // invalid list tensor | ||||
proto::GraphDef graph_def; | |||||
ge::proto::GraphDef graph_def; | |||||
auto attrs = graph_def.add_op()->mutable_attr(); // node attr | auto attrs = graph_def.add_op()->mutable_attr(); // node attr | ||||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | ||||
auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | ||||
auto attrs_of_graph = graph_attr->mutable_attr(); | auto attrs_of_graph = graph_attr->mutable_attr(); | ||||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | ||||
tensor_val->set_dtype(proto::DT_INT8); | |||||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||||
tensor_val->set_layout("invalidLayout"); | tensor_val->set_layout("invalidLayout"); | ||||
ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
@@ -1462,7 +1462,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||||
TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | ||||
// model invalid node input | // model invalid node input | ||||
{ | { | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
auto op_def = model_def.add_graph()->add_op(); // node attr | auto op_def = model_def.add_graph()->add_op(); // node attr | ||||
op_def->add_input("invalidNodeName:0"); | op_def->add_input("invalidNodeName:0"); | ||||
@@ -1475,7 +1475,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// model invalid node control input | // model invalid node control input | ||||
{ | { | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
auto op_def = model_def.add_graph()->add_op(); // node attr | auto op_def = model_def.add_graph()->add_op(); // node attr | ||||
op_def->add_input("invalidNodeName:-1"); | op_def->add_input("invalidNodeName:-1"); | ||||
@@ -1488,7 +1488,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// model invalid graph input | // model invalid graph input | ||||
{ | { | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
model_def.add_graph()->add_input("invalidNodeName:0"); | model_def.add_graph()->add_input("invalidNodeName:0"); | ||||
Buffer buffer(model_def.ByteSizeLong()); | Buffer buffer(model_def.ByteSizeLong()); | ||||
@@ -1500,7 +1500,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// model invalid graph input | // model invalid graph input | ||||
{ | { | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
model_def.add_graph()->add_output("invalidNodeName:0"); | model_def.add_graph()->add_output("invalidNodeName:0"); | ||||
Buffer buffer(model_def.ByteSizeLong()); | Buffer buffer(model_def.ByteSizeLong()); | ||||
@@ -1512,7 +1512,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// graph invalid node input | // graph invalid node input | ||||
{ | { | ||||
proto::GraphDef graph_def; | |||||
ge::proto::GraphDef graph_def; | |||||
auto op_def = graph_def.add_op(); // node attr | auto op_def = graph_def.add_op(); // node attr | ||||
op_def->add_input("invalidNodeName:0"); | op_def->add_input("invalidNodeName:0"); | ||||
@@ -1525,7 +1525,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// graph invalid node control input | // graph invalid node control input | ||||
{ | { | ||||
proto::GraphDef graph_def; | |||||
ge::proto::GraphDef graph_def; | |||||
auto op_def = graph_def.add_op(); // node attr | auto op_def = graph_def.add_op(); // node attr | ||||
op_def->add_input("invalidNodeName:-1"); | op_def->add_input("invalidNodeName:-1"); | ||||
@@ -1538,7 +1538,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// graph invalid graph input | // graph invalid graph input | ||||
{ | { | ||||
proto::GraphDef graph_def; | |||||
ge::proto::GraphDef graph_def; | |||||
graph_def.add_input("invalidNodeName:0"); | graph_def.add_input("invalidNodeName:0"); | ||||
Buffer buffer(graph_def.ByteSizeLong()); | Buffer buffer(graph_def.ByteSizeLong()); | ||||
@@ -1550,7 +1550,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// graph invalid graph output | // graph invalid graph output | ||||
{ | { | ||||
proto::GraphDef graph_def; | |||||
ge::proto::GraphDef graph_def; | |||||
graph_def.add_output("invalidNodeName:0"); | graph_def.add_output("invalidNodeName:0"); | ||||
Buffer buffer(graph_def.ByteSizeLong()); | Buffer buffer(graph_def.ByteSizeLong()); | ||||
@@ -1562,7 +1562,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||||
} | } | ||||
// model invalid node input anchor | // model invalid node input anchor | ||||
{ | { | ||||
proto::ModelDef model_def; | |||||
ge::proto::ModelDef model_def; | |||||
auto graph_def = model_def.add_graph(); | auto graph_def = model_def.add_graph(); | ||||
auto node_def1 = graph_def->add_op(); // node attr | auto node_def1 = graph_def->add_op(); // node attr | ||||
node_def1->set_name("node1"); | node_def1->set_name("node1"); | ||||
@@ -151,7 +151,7 @@ TEST_F(UtestGeNode, update_opdesc) { | |||||
EXPECT_EQ(n1->UpdateOpDesc(desc_ptr2), GRAPH_SUCCESS); | EXPECT_EQ(n1->UpdateOpDesc(desc_ptr2), GRAPH_SUCCESS); | ||||
} | } | ||||
/* | |||||
TEST_F(UtestGeNode, add_link_from) { | TEST_F(UtestGeNode, add_link_from) { | ||||
OpDescPtr desc_ptr = std::make_shared<OpDesc>("name", "type"); | OpDescPtr desc_ptr = std::make_shared<OpDesc>("name", "type"); | ||||
EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS); | EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS); | ||||
@@ -179,6 +179,7 @@ TEST_F(UtestGeNode, add_link_from) { | |||||
NodePtr n8 = graph_ptr1->AddNode(desc_ptr1); | NodePtr n8 = graph_ptr1->AddNode(desc_ptr1); | ||||
EXPECT_EQ(n8->AddLinkFromForParse(n7), GRAPH_PARAM_INVALID); | EXPECT_EQ(n8->AddLinkFromForParse(n7), GRAPH_PARAM_INVALID); | ||||
} | } | ||||
*/ | |||||
TEST_F(UtestGeNode, add_link_from_fail) { | TEST_F(UtestGeNode, add_link_from_fail) { | ||||
OpDescPtr desc_ptr = std::make_shared<OpDesc>("name1", "type1"); | OpDescPtr desc_ptr = std::make_shared<OpDesc>("name1", "type1"); | ||||
@@ -18,7 +18,7 @@ | |||||
#include "common/formats/format_transfers/datatype_transfer.h" | #include "common/formats/format_transfers/datatype_transfer.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/formats/formats.h" | #include "common/formats/formats.h" | ||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
@@ -17,9 +17,10 @@ | |||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h" | #include "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -644,4 +645,4 @@ TEST_F(UTEST_FormatTransferNc1hwc0ToNchw, invalid_src_data_type) { | |||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h" | #include "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransfer5dNhwc : public testing::Test { | class UtestFormatTransfer5dNhwc : public testing::Test { | ||||
@@ -759,4 +762,4 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | |||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h" | #include "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferC1hwncoc0Hwcn : public testing::Test { | class UtestFormatTransferC1hwncoc0Hwcn : public testing::Test { | ||||
@@ -13710,4 +13713,4 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp32_success_gt_cube) { | |||||
} | } | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -19,11 +19,14 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fractal_nz.h" | #include "common/formats/format_transfers/format_transfer_fractal_nz.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/formats/formats.h" | #include "common/formats/formats.h" | ||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "time.h" | #include "time.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferNdFractNz : public testing::Test { | class UtestFormatTransferNdFractNz : public testing::Test { | ||||
@@ -9164,4 +9167,4 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | |||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -19,11 +19,14 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fractal_zz.h" | #include "common/formats/format_transfers/format_transfer_fractal_zz.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/formats/formats.h" | #include "common/formats/formats.h" | ||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "time.h" | #include "time.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferNdFractZz : public testing::Test { | class UtestFormatTransferNdFractZz : public testing::Test { | ||||
@@ -7988,4 +7991,4 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | |||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" | #include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferFracZHwcn : public testing::Test { | class UtestFormatTransferFracZHwcn : public testing::Test { | ||||
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fracz_nchw.h" | #include "common/formats/format_transfers/format_transfer_fracz_nchw.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferFraczNchw : public testing::Test { | class UtestFormatTransferFraczNchw : public testing::Test { | ||||
@@ -10486,4 +10489,4 @@ TEST_F(UtestFormatTransferFraczNchw, fp32_1) { | |||||
} | } | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fracz_nhwc.h" | #include "common/formats/format_transfers/format_transfer_fracz_nhwc.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferFraczNhwc : public testing::Test { | class UtestFormatTransferFraczNhwc : public testing::Test { | ||||
@@ -5422,4 +5425,4 @@ TEST_F(UtestFormatTransferFraczNhwc, fracz_to_nhwc_fp32_success_gt_cube) { | |||||
} | } | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h" | #include "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferHwcnC1hwncoc0 : public testing::Test { | class UtestFormatTransferHwcnC1hwncoc0 : public testing::Test { | ||||
@@ -13745,4 +13748,4 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_fp32_success_gt_cube) { | |||||
} | } | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | #include "common/formats/format_transfers/format_transfer_fractal_z.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -34460,4 +34463,4 @@ TEST_F(UtestFormatTransferHwcnFz, build_transfer_not_support) { | |||||
EXPECT_EQ(transfer, nullptr); | EXPECT_EQ(transfer, nullptr); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -18,7 +18,10 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | #include "common/formats/format_transfers/format_transfer_fractal_z.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -16873,4 +16876,4 @@ TEST_F(UtestFormatTransferNchwFz, build_transfer_uint8) { | |||||
EXPECT_NE(transfer, nullptr); | EXPECT_NE(transfer, nullptr); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||||
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
class UtestFormatTransferNhwc5d : public testing::Test { | class UtestFormatTransferNhwc5d : public testing::Test { | ||||
@@ -747,4 +750,4 @@ TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | |||||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | #include "common/formats/format_transfers/format_transfer_fractal_z.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -5351,4 +5354,4 @@ TEST_F(UtestFormatTransferNhwcFz, build_transfer_uint8) { | |||||
EXPECT_NE(transfer, nullptr); | EXPECT_NE(transfer, nullptr); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -18,9 +18,13 @@ | |||||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
//#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "register/register_format_transfer.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -78,4 +82,4 @@ TEST_F(UtestFormatTransfer, get_size_by_data_type) { | |||||
EXPECT_EQ(DT_UNDEFINED, 26); | EXPECT_EQ(DT_UNDEFINED, 26); | ||||
} | } | ||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -189,18 +189,20 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
bool ExpectStreamEq(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) == expect; } | bool ExpectStreamEq(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) == expect; } | ||||
bool ExpectStreamNe(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) != expect; } | bool ExpectStreamNe(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) != expect; } | ||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, vector<EngineConfPtr> &confs, | |||||
Status AssignLogicalStreams(Graph2SubGraphInfoList &subgraph_map, vector<EngineConfPtr> &confs, | |||||
std::map<std::string, int> &max_parallel_num, ComputeGraphPtr &whole_graph) { | std::map<std::string, int> &max_parallel_num, ComputeGraphPtr &whole_graph) { | ||||
SchedulerConf scheduler_conf; | SchedulerConf scheduler_conf; | ||||
if (confs.empty()) { | if (confs.empty()) { | ||||
for (const auto &subgraph : subgraphs) { | |||||
EngineConfPtr conf = make_shared<EngineConf>(); | |||||
conf->id = subgraph->GetEngineName(); | |||||
if (conf->id == "ge_local") { | |||||
conf->skip_assign_stream = true; | |||||
conf->attach = true; | |||||
} | |||||
scheduler_conf.cal_engines[conf->id] = conf; | |||||
for (const auto &subgraph_pair : subgraph_map) { | |||||
for (const auto &sub_graph : subgraph_pair.second) { | |||||
EngineConfPtr conf = make_shared<EngineConf>(); | |||||
conf->id = sub_graph->GetEngineName(); | |||||
if (conf->id == "ge_local") { | |||||
conf->skip_assign_stream = true; | |||||
conf->attach = true; | |||||
} | |||||
scheduler_conf.cal_engines[conf->id] = conf; | |||||
} | |||||
} | } | ||||
} else { | } else { | ||||
for (auto &conf : confs) { | for (auto &conf : confs) { | ||||
@@ -217,11 +219,21 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
scheduler_confs["scheduler"] = scheduler_conf; | scheduler_confs["scheduler"] = scheduler_conf; | ||||
LogicalStreamAllocator allocator(scheduler_confs, max_parallel_num); | LogicalStreamAllocator allocator(scheduler_confs, max_parallel_num); | ||||
int64_t stream_num = 0; | int64_t stream_num = 0; | ||||
return allocator.Assign(whole_graph, subgraphs, stream_num); | |||||
return allocator.Assign(whole_graph, subgraph_map, stream_num); | |||||
} | } | ||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, std::map<std::string, int> &max_parallel_num, | |||||
vector<EngineConfPtr> &confs) { | |||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | |||||
vector<EngineConfPtr> &confs, | |||||
std::map<std::string, int> &max_parallel_num, | |||||
ComputeGraphPtr &whole_graph) { | |||||
Graph2SubGraphInfoList subgraph_map; | |||||
subgraph_map[whole_graph] = subgraphs; | |||||
return AssignLogicalStreams(subgraph_map, confs, max_parallel_num, whole_graph); | |||||
} | |||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | |||||
vector<EngineConfPtr>& confs, | |||||
std::map<std::string, int> &max_parallel_num) { | |||||
ComputeGraphPtr whole_graph = make_shared<ComputeGraph>("whole_graph"); | ComputeGraphPtr whole_graph = make_shared<ComputeGraph>("whole_graph"); | ||||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num, whole_graph); | return AssignLogicalStreams(subgraphs, confs, max_parallel_num, whole_graph); | ||||
} | } | ||||
@@ -229,12 +241,12 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | ||||
vector<EngineConfPtr> confs = vector<EngineConfPtr>()) { | vector<EngineConfPtr> confs = vector<EngineConfPtr>()) { | ||||
std::map<std::string, int> max_parallel_num; | std::map<std::string, int> max_parallel_num; | ||||
return AssignLogicalStreams(subgraphs, max_parallel_num, confs); | |||||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num); | |||||
} | } | ||||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, std::map<std::string, int> &max_parallel_num) { | Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, std::map<std::string, int> &max_parallel_num) { | ||||
vector<EngineConfPtr> confs; | vector<EngineConfPtr> confs; | ||||
return AssignLogicalStreams(subgraphs, max_parallel_num, confs); | |||||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num); | |||||
} | } | ||||
/// typical case | /// typical case | ||||
@@ -295,7 +307,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
Status status = AssignLogicalStreams({const1, const2, get_next, genmask1, genmask2, domask, subgraph4, subgraph5, | Status status = AssignLogicalStreams({const1, const2, get_next, genmask1, genmask2, domask, subgraph4, subgraph5, | ||||
subgraph6, allreduce1, allreduce2, apply1, apply2}, | subgraph6, allreduce1, allreduce2, apply1, apply2}, | ||||
max_parallel_num, confs); | |||||
confs, max_parallel_num); | |||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
EXPECT_EQ(GetStream(get_next), 0); | EXPECT_EQ(GetStream(get_next), 0); | ||||
@@ -652,7 +664,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||||
vector<EngineConfPtr> confs = {conf1, conf2}; | vector<EngineConfPtr> confs = {conf1, conf2}; | ||||
Status status = | Status status = | ||||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); | |||||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, confs, max_parallel_num); | |||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
EXPECT_EQ(GetStream(subgraph1), 0); | EXPECT_EQ(GetStream(subgraph1), 0); | ||||
EXPECT_EQ(GetStream(subgraph2), 0); | EXPECT_EQ(GetStream(subgraph2), 0); | ||||
@@ -695,7 +707,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | |||||
vector<EngineConfPtr> confs = {conf1, conf2, conf3}; | vector<EngineConfPtr> confs = {conf1, conf2, conf3}; | ||||
Status status = | Status status = | ||||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); | |||||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5},confs, max_parallel_num); | |||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
EXPECT_EQ(GetStream(subgraph1), 4); | EXPECT_EQ(GetStream(subgraph1), 4); | ||||
EXPECT_EQ(GetStream(subgraph2), 0); | EXPECT_EQ(GetStream(subgraph2), 0); | ||||
@@ -858,7 +870,7 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { | |||||
std::map<std::string, int> max_parallel_num; | std::map<std::string, int> max_parallel_num; | ||||
LogicalStreamPass::Context context; | LogicalStreamPass::Context context; | ||||
context.next_stream = 5; | context.next_stream = 5; | ||||
context.hcom_parallel = true; | |||||
context.enable_hcom_parallel = true; | |||||
vector<LogicalStreamPass::SubgraphPtr> subgraphs; | vector<LogicalStreamPass::SubgraphPtr> subgraphs; | ||||
LogicalStreamPassPtr allreduce_pass = std::make_shared<AllReduceParallelPass>(); | LogicalStreamPassPtr allreduce_pass = std::make_shared<AllReduceParallelPass>(); | ||||
ret = allreduce_pass->Run(graph, subgraphs, context); | ret = allreduce_pass->Run(graph, subgraphs, context); | ||||
@@ -152,7 +152,7 @@ TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { | |||||
ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); | ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); | ||||
ge::NodePtr node_a = graph->AddNode(op_def_a); | ge::NodePtr node_a = graph->AddNode(op_def_a); | ||||
MemoryBlock* memory_block = new MemoryBlock(0); | MemoryBlock* memory_block = new MemoryBlock(0); | ||||
memory_block->Init(1, kOutput, node_a, 0); | |||||
memory_block->Init(1, kOutput, node_a, 0, 1); | |||||
memory_block->real_size_list_.clear(); | memory_block->real_size_list_.clear(); | ||||
memory_block->Resize(); | memory_block->Resize(); | ||||
@@ -165,7 +165,7 @@ namespace ge { | |||||
class MockBlockMemAssigner : public BlockMemAssigner { | class MockBlockMemAssigner : public BlockMemAssigner { | ||||
public: | public: | ||||
explicit MockBlockMemAssigner(ge::ComputeGraphPtr compute_graph) : BlockMemAssigner(compute_graph){}; | |||||
explicit MockBlockMemAssigner(ge::ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol, const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors) : BlockMemAssigner(compute_graph, anchor_to_symbol, symbol_to_anchors) {}; | |||||
virtual ~MockBlockMemAssigner(){}; | virtual ~MockBlockMemAssigner(){}; | ||||
@@ -177,7 +177,10 @@ class MockBlockMemAssigner : public BlockMemAssigner { | |||||
TEST_F(UtestMemoryAssignerTest, Mock_block_mem_assigner_failed) { | TEST_F(UtestMemoryAssignerTest, Mock_block_mem_assigner_failed) { | ||||
ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | ||||
make_graph(graph); | make_graph(graph); | ||||
MockBlockMemAssigner mock_assigner(graph); | |||||
std::map<std::string, std::string> anchor_to_symbol; | |||||
std::map<std::string, std::list<NodeIndexIO>> symbol_to_anchors; | |||||
EXPECT_EQ(GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol), GRAPH_SUCCESS); | |||||
MockBlockMemAssigner mock_assigner(graph, anchor_to_symbol, symbol_to_anchors); | |||||
EXPECT_EQ(mock_assigner.Assign(), FAILED); | EXPECT_EQ(mock_assigner.Assign(), FAILED); | ||||
} | } |
@@ -40,20 +40,22 @@ std::vector<void *> stub_get_output_addrs(const RuntimeParam &model_param, Const | |||||
} | } | ||||
TEST_F(UtestDataDumper, LoadDumpInfo_no_output_addrs_fail) { | TEST_F(UtestDataDumper, LoadDumpInfo_no_output_addrs_fail) { | ||||
DataDumper data_dumper; | |||||
RuntimeParam rts_param; | |||||
DataDumper data_dumper(rts_param); | |||||
data_dumper.SetModelName("test"); | data_dumper.SetModelName("test"); | ||||
data_dumper.SetModelId(2333); | data_dumper.SetModelId(2333); | ||||
data_dumper.SetMemory(std::move(RuntimeParam{})); | |||||
std::shared_ptr<OpDesc> op_desc_1(new OpDesc()); | std::shared_ptr<OpDesc> op_desc_1(new OpDesc()); | ||||
op_desc_1->AddOutputDesc("test", GeTensorDesc()); | op_desc_1->AddOutputDesc("test", GeTensorDesc()); | ||||
data_dumper.SaveDumpTask(0, op_desc_1, 0); | |||||
data_dumper.SaveDumpTask(0, 0, op_desc_1, 0); | |||||
string dump_mode = "output"; | |||||
data_dumper.dump_properties_.SetDumpMode(dump_mode); | |||||
Status ret = data_dumper.LoadDumpInfo(); | Status ret = data_dumper.LoadDumpInfo(); | ||||
EXPECT_EQ(ret, PARAM_INVALID); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | } | ||||
TEST_F(UtestDataDumper, UnloadDumpInfo_success) { | TEST_F(UtestDataDumper, UnloadDumpInfo_success) { | ||||
DataDumper data_dumper; | |||||
RuntimeParam rts_param; | |||||
DataDumper data_dumper(rts_param); | |||||
data_dumper.SetModelName("test"); | data_dumper.SetModelName("test"); | ||||
data_dumper.SetModelId(2333); | data_dumper.SetModelId(2333); | ||||
@@ -25,7 +25,6 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
#include "graph/load/new_model_manager/model_output.h" | |||||
#include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
#include <cce/taskdown_api.h> | #include <cce/taskdown_api.h> | ||||
@@ -38,7 +37,6 @@ | |||||
#include "graph/load/new_model_manager/task_info/stream_switch_task_info.h" | #include "graph/load/new_model_manager/task_info/stream_switch_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/profiler_trace_task_info.h" | #include "graph/load/new_model_manager/task_info/profiler_trace_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/memcpy_async_task_info.h" | #include "graph/load/new_model_manager/task_info/memcpy_async_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/label_goto_task_info.h" | |||||
#include "graph/load/new_model_manager/task_info/label_set_task_info.h" | #include "graph/load/new_model_manager/task_info/label_set_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/kernel_ex_task_info.h" | #include "graph/load/new_model_manager/task_info/kernel_ex_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/kernel_task_info.h" | #include "graph/load/new_model_manager/task_info/kernel_task_info.h" | ||||
@@ -113,7 +113,7 @@ class DModelListener : public ge::ModelListener { | |||||
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | ||||
}; | }; | ||||
shared_ptr<ge::ModelListener> UTEST_CALL_BACK_FUN(new DModelListener()); | |||||
shared_ptr<ModelListener> UTEST_CALL_BACK_FUN(new DModelListener()); | |||||
TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | ||||
ModelManager mm; | ModelManager mm; | ||||
@@ -164,7 +164,7 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { | |||||
delete[](uint8_t *) data.model_data; | delete[](uint8_t *) data.model_data; | ||||
} | } | ||||
shared_ptr<ge::ModelListener> LabelCallBack(new DModelListener()); | |||||
shared_ptr<ModelListener> LabelCallBack(new DModelListener()); | |||||
// test HandleCommand | // test HandleCommand | ||||
TEST_F(UtestModelManagerModelManager, command_success1) { | TEST_F(UtestModelManagerModelManager, command_success1) { | ||||
@@ -306,6 +306,8 @@ TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_fail) { | |||||
EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfo(2, input_shape, output_shape)); | EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfo(2, input_shape, output_shape)); | ||||
} | } | ||||
/* | |||||
// test GetInputOutputDescInfo fail | // test GetInputOutputDescInfo fail | ||||
TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) { | TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) { | ||||
ModelManager manager; | ModelManager manager; | ||||
@@ -314,6 +316,7 @@ TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) | |||||
vector<InputOutputDescInfo> output_shape; | vector<InputOutputDescInfo> output_shape; | ||||
EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfoForZeroCopy(2, input_shape, output_shape)); | EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfoForZeroCopy(2, input_shape, output_shape)); | ||||
} | } | ||||
*/ | |||||
// test Stop | // test Stop | ||||
TEST_F(UtestModelManagerModelManager, stop_fail) { | TEST_F(UtestModelManagerModelManager, stop_fail) { | ||||
@@ -324,7 +327,7 @@ TEST_F(UtestModelManagerModelManager, stop_fail) { | |||||
// build input_data | // build input_data | ||||
TEST_F(UtestModelManagerModelManager, check_data_len_success) { | TEST_F(UtestModelManagerModelManager, check_data_len_success) { | ||||
shared_ptr<ge::ModelListener> g_label_call_back(new DModelListener()); | |||||
shared_ptr<ModelListener> g_label_call_back(new DModelListener()); | |||||
DavinciModel model(0, g_label_call_back); | DavinciModel model(0, g_label_call_back); | ||||
ModelManager model_manager; | ModelManager model_manager; | ||||
ge::InputData input_data; | ge::InputData input_data; | ||||
@@ -134,6 +134,55 @@ class OmeTestOpUtils { | |||||
} | } | ||||
} | } | ||||
static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model) { | |||||
if (model == nullptr) { | |||||
GELOGE(FAILED, "Model is null"); | |||||
return FAILED; | |||||
} | |||||
ge_model = ge::MakeShared<ge::GeModel>(); | |||||
GE_CHECK_NOTNULL(ge_model); | |||||
ge_model->SetGraph(model->GetGraph()); | |||||
ge_model->SetName(model->GetName()); | |||||
ge_model->SetVersion(model->GetVersion()); | |||||
ge_model->SetPlatformVersion(model->GetPlatformVersion()); | |||||
ge_model->SetAttr(model->MutableAttrMap()); | |||||
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||||
ge::Buffer weight; | |||||
(void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); | |||||
ge_model->SetWeight(weight); | |||||
if (model->HasAttr(MODEL_ATTR_TASKS)) { | |||||
ge::Buffer task_buffer; | |||||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, | |||||
"Get bytes failed."); | |||||
std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>(); | |||||
GE_CHECK_NOTNULL(task); | |||||
GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); | |||||
GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); | |||||
GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()), | |||||
return INTERNAL_ERROR, "ReadProtoFromArray failed."); | |||||
ge_model->SetModelTaskDef(task); | |||||
} | |||||
TBEKernelStore kernel_store; | |||||
if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { | |||||
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | |||||
auto node_op_desc = n->GetOpDesc(); | |||||
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | |||||
TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||||
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | |||||
kernel_store.AddTBEKernel(tbe_kernel); | |||||
GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||||
} | |||||
} | |||||
if (!kernel_store.Build()) { | |||||
GELOGE(FAILED, "TBE Kernels store build failed!"); | |||||
return FAILED; | |||||
} | |||||
ge_model->SetTBEKernelStore(kernel_store); | |||||
return SUCCESS; | |||||
} | |||||
static void LoadStandardModelDataLocal(ge::ModelData &data) { | static void LoadStandardModelDataLocal(ge::ModelData &data) { | ||||
static const std::string STANDARD_MODEL_DATA_PATH = | static const std::string STANDARD_MODEL_DATA_PATH = | ||||
"llt/framework/domi/ut/ome/test/data/standard_partition_model.txt"; | "llt/framework/domi/ut/ome/test/data/standard_partition_model.txt"; | ||||
@@ -151,7 +200,7 @@ class OmeTestOpUtils { | |||||
ge::Model::Load((uint8_t *)data.model_data, data.model_len, *model_); | ge::Model::Load((uint8_t *)data.model_data, data.model_len, *model_); | ||||
GeModelPtr ge_model; | GeModelPtr ge_model; | ||||
ModelHelper::TransModelToGeModel(model_, ge_model); | |||||
TransModelToGeModel(model_, ge_model); | |||||
davinciModel.Assign(ge_model); | davinciModel.Assign(ge_model); | ||||
if (data.model_data != nullptr) { | if (data.model_data != nullptr) { | ||||
@@ -178,7 +227,7 @@ class OmeTestOpUtils { | |||||
model->SetGraph(graph); | model->SetGraph(graph); | ||||
GeModelPtr ge_model; | GeModelPtr ge_model; | ||||
ModelHelper::TransModelToGeModel(model, ge_model); | |||||
TransModelToGeModel(model, ge_model); | |||||
davinciModel.Assign(ge_model); | davinciModel.Assign(ge_model); | ||||
} | } | ||||
@@ -24,9 +24,7 @@ | |||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
#include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
#include "graph/load/new_model_manager/model_output.h" | |||||
#include "graph/load/new_model_manager/model_utils.h" | #include "graph/load/new_model_manager/model_utils.h" | ||||
#include "graph/load/output/output.h" | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "new_op_test_utils.h" | #include "new_op_test_utils.h" | ||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
@@ -49,7 +49,7 @@ class UtestTestPass : public BaseNodePass { | |||||
for (const auto &node_name : iter->second) { | for (const auto &node_name : iter->second) { | ||||
auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); | auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); | ||||
GraphUtils::IsolateNode(del_node, {0}); | GraphUtils::IsolateNode(del_node, {0}); | ||||
AddNodeDeleted(del_node.get()); | |||||
AddNodeDeleted(del_node); | |||||
} | } | ||||
} | } | ||||
iter = names_to_add_repass_.find(node->GetName()); | iter = names_to_add_repass_.find(node->GetName()); | ||||
@@ -38,7 +38,7 @@ class UtestGraphPassesFlowCtrlPass : public testing::Test { | |||||
EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->Init(session_version, session_id, device_id, job_id)); | EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->Init(session_version, session_id, device_id, job_id)); | ||||
} | } | ||||
void TearDown() { VarManagerPool::Instance().Destroy(); } | |||||
void TearDown() { VarManagerPool::Instance().Destory(); } | |||||
public: | public: | ||||
/// Set up a graph with the following network structure | /// Set up a graph with the following network structure | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/add_kernel.h" | |||||
#include "host_kernels/add_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -31,7 +31,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/broadcast_args_kernel.h" | |||||
#include "host_kernels/broadcast_args_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -31,7 +31,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/broadcast_gradient_args_kernel.h" | |||||
#include "host_kernels/broadcast_gradient_args_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/cast_kernel.h" | |||||
#include "host_kernels/cast_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/concat_offset_kernel.h" | |||||
#include "host_kernels/concat_offset_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/concat_v2_kernel.h" | |||||
#include "host_kernels/concat_v2_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/dynamic_stitch_kernel.h" | |||||
#include "host_kernels/dynamic_stitch_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -27,7 +27,7 @@ | |||||
#include "common/op/attr_value_util.h" | #include "common/op/attr_value_util.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -31,7 +31,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/empty_kernel.h" | |||||
#include "host_kernels/empty_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/expanddims_kernel.h" | |||||
#include "host_kernels/expanddims_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/fill_kernel.h" | |||||
#include "host_kernels/fill_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
@@ -18,13 +18,13 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/floordiv_kernel.h" | |||||
#include "host_kernels/floordiv_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -20,7 +20,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/floormod_kernel.h" | |||||
#include "host_kernels/floormod_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -27,7 +27,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/operator.h" | #include "graph/operator.h" | ||||
#include "graph/passes/constant_folding_pass.h" | #include "graph/passes/constant_folding_pass.h" | ||||
#include "graph/passes/folding_kernel/broadcast_args_kernel.h" | |||||
#include "host_kernels/broadcast_args_kernel.h" | |||||
#include "inc/kernel_factory.h" | #include "inc/kernel_factory.h" | ||||
#include "shape_refiner.h" | #include "shape_refiner.h" | ||||
@@ -19,7 +19,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/gather_v2_kernel.h" | |||||
#include "host_kernels/gather_v2_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -29,7 +29,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -984,4 +984,4 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, AbnormalTest) { | |||||
status = kernel->Compute(op_desc_ptr, input_7, outputs); | status = kernel->Compute(op_desc_ptr, input_7, outputs); | ||||
EXPECT_NE(ge::SUCCESS, status); | EXPECT_NE(ge::SUCCESS, status); | ||||
} | } | ||||
} | |||||
} |
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/greater_kernel.h" | |||||
#include "host_kernels/greater_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -20,7 +20,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/maximum_kernel.h" | |||||
#include "host_kernels/maximum_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -20,7 +20,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/mul_kernel.h" | |||||
#include "host_kernels/mul_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/pack_kernel.h" | |||||
#include "host_kernels/pack_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/permute_kernel.h" | |||||
#include "host_kernels/permute_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -20,7 +20,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/range_kernel.h" | |||||
#include "host_kernels/range_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/rank_kernel.h" | |||||
#include "host_kernels/rank_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,14 +18,14 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/reduce_prod_kernel.h" | |||||
#include "host_kernels/reduce_prod_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/passes/folding_kernel/concat_v2_kernel.h" | |||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/concat_v2_kernel.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -16,12 +16,12 @@ | |||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "graph/passes/folding_kernel/reformat_kernel.h" | |||||
#include "host_kernels/reformat_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/reshape_kernel.h" | |||||
#include "host_kernels/reshape_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/rsqrt_kernel.h" | |||||
#include "host_kernels/rsqrt_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/shape_kernel.h" | |||||
#include "host_kernels/shape_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/shape_n_kernel.h" | |||||
#include "host_kernels/shape_n_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -19,7 +19,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/size_kernel.h" | |||||
#include "host_kernels/size_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/slice_kernel.h" | |||||
#include "host_kernels/slice_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/squeeze_kernel.h" | |||||
#include "host_kernels/squeeze_kernel.h" | |||||
#include "../graph_builder_utils.h" | #include "../graph_builder_utils.h" | ||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/ssd_prior_box_kernel.h" | |||||
#include "host_kernels/ssd_prior_box_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/strided_slice_kernel.h" | |||||
#include "host_kernels/strided_slice_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -28,7 +28,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -18,7 +18,7 @@ | |||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/folding_kernel/sub_kernel.h" | |||||
#include "host_kernels/sub_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -16,7 +16,7 @@ | |||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "graph/passes/folding_kernel/transdata_kernel.h" | |||||
#include "host_kernels/transdata_kernel.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -26,7 +26,7 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/passes/dimension_compute_pass.h" | #include "graph/passes/dimension_compute_pass.h" | ||||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||||
#include "host_kernels/kernel_utils.h" | |||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
@@ -155,7 +155,7 @@ TEST_F(UtestGraphPassesNetOutputPass, add_ctrl_edge_for_netout_from_leaf_success | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -200,7 +200,7 @@ TEST_F(UtestGraphPassesNetOutputPass, only_target_node_success) { | |||||
std::vector<ge::NodePtr> target_nodes = {mul1, mul2}; | std::vector<ge::NodePtr> target_nodes = {mul1, mul2}; | ||||
compute_graph->SetGraphTargetNodesInfo(target_nodes); | compute_graph->SetGraphTargetNodesInfo(target_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -256,7 +256,7 @@ TEST_F(UtestGraphPassesNetOutputPass, targets_with_retval_success) { | |||||
} | } | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -300,7 +300,7 @@ TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_no_duplicate_s | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -348,7 +348,7 @@ TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_duplicate_succ | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -398,7 +398,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_target_node_success) { | |||||
compute_graph->SetGraphTargetNodesInfo(target_nodes); | compute_graph->SetGraphTargetNodesInfo(target_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -462,7 +462,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -518,7 +518,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -582,7 +582,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
// check contain netoutput | // check contain netoutput | ||||
@@ -626,7 +626,7 @@ TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) { | |||||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | ||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
} | } | ||||
@@ -641,7 +641,7 @@ TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) { | |||||
compute_graph->SetGraphOutNodesInfo(output_nodes); | compute_graph->SetGraphOutNodesInfo(output_nodes); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | ||||
@@ -687,7 +687,7 @@ TEST_F(UtestGraphPassesNetOutputPass, retval_node_for_out_success) { | |||||
} | } | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | ||||
@@ -737,7 +737,7 @@ TEST_F(UtestGraphPassesNetOutputPass, check_order_and_const_flag_success) { | |||||
GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0)); | GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0)); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | ||||
@@ -775,7 +775,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_check_fail) { | |||||
compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name); | compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | EXPECT_EQ(status, ge::INTERNAL_ERROR); | ||||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | ||||
@@ -817,7 +817,7 @@ TEST_F(UtestGraphPassesNetOutputPass, retval_node_check_fail) { | |||||
} | } | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | EXPECT_EQ(status, ge::INTERNAL_ERROR); | ||||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | ||||
@@ -832,7 +832,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_update_desc_check_fail) { | |||||
EXPECT_NE(netout_node, nullptr); | EXPECT_NE(netout_node, nullptr); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | EXPECT_EQ(status, ge::INTERNAL_ERROR); | ||||
} | } | ||||
@@ -852,7 +852,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_remove_check_fail) { | |||||
EXPECT_EQ(mul1, nullptr); | EXPECT_EQ(mul1, nullptr); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
} | } | ||||
@@ -72,7 +72,7 @@ ComputeGraphPtr CreatePadGraph() { | |||||
TEST_F(UtestGraphPassesPassManagerPass, all_pass_success) { | TEST_F(UtestGraphPassesPassManagerPass, all_pass_success) { | ||||
PassManager manager; | PassManager manager; | ||||
manager.AddPass(new SuccessGraphPass); | |||||
manager.AddPass("", new SuccessGraphPass); | |||||
EXPECT_EQ(manager.GraphPasses().size(), 1); | EXPECT_EQ(manager.GraphPasses().size(), 1); | ||||
ComputeGraphPtr graph = CreatePadGraph(); | ComputeGraphPtr graph = CreatePadGraph(); | ||||
@@ -83,7 +83,7 @@ TEST_F(UtestGraphPassesPassManagerPass, all_pass_success) { | |||||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_success) { | TEST_F(UtestGraphPassesPassManagerPass, graph_pass_success) { | ||||
ComputeGraphPtr graph = CreatePadGraph(); | ComputeGraphPtr graph = CreatePadGraph(); | ||||
SuccessGraphPass pass; | SuccessGraphPass pass; | ||||
vector<GraphPass *> passes = {&pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
} | } | ||||
@@ -91,7 +91,7 @@ TEST_F(UtestGraphPassesPassManagerPass, graph_pass_success) { | |||||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_not_changed) { | TEST_F(UtestGraphPassesPassManagerPass, graph_pass_not_changed) { | ||||
ComputeGraphPtr graph = CreatePadGraph(); | ComputeGraphPtr graph = CreatePadGraph(); | ||||
NotChangedGraphPass pass; | NotChangedGraphPass pass; | ||||
vector<GraphPass *> passes = {&pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(NOT_CHANGED, status); | EXPECT_EQ(NOT_CHANGED, status); | ||||
} | } | ||||
@@ -99,7 +99,7 @@ TEST_F(UtestGraphPassesPassManagerPass, graph_pass_not_changed) { | |||||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_error) { | TEST_F(UtestGraphPassesPassManagerPass, graph_pass_error) { | ||||
ComputeGraphPtr graph = CreatePadGraph(); | ComputeGraphPtr graph = CreatePadGraph(); | ||||
ErrorGraphPass pass; | ErrorGraphPass pass; | ||||
vector<GraphPass *> passes = {&pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(FAILED, status); | EXPECT_EQ(FAILED, status); | ||||
} | } |
@@ -67,7 +67,7 @@ TEST_F(UtestGraphPassesPrunePass, no_net_out_put_node) { | |||||
uint64_t size_ori = graph->GetDirectNode().size(); | uint64_t size_ori = graph->GetDirectNode().size(); | ||||
PrunePass prune_pass; | PrunePass prune_pass; | ||||
vector<GraphPass *> passes = {&prune_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(ge::SUCCESS, status); | EXPECT_EQ(ge::SUCCESS, status); | ||||
@@ -109,7 +109,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_only_one_path) { | |||||
uint64_t size_ori = graph->GetDirectNode().size(); | uint64_t size_ori = graph->GetDirectNode().size(); | ||||
PrunePass prune_pass; | PrunePass prune_pass; | ||||
vector<GraphPass *> passes = {&prune_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
uint64_t size = graph->GetDirectNode().size(); | uint64_t size = graph->GetDirectNode().size(); | ||||
@@ -250,7 +250,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_multi_path) { | |||||
uint64_t size_ori = graph->GetDirectNode().size(); | uint64_t size_ori = graph->GetDirectNode().size(); | ||||
PrunePass prune_pass; | PrunePass prune_pass; | ||||
vector<GraphPass *> passes = {&prune_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
uint64_t size_after_proc = graph->GetDirectNode().size(); | uint64_t size_after_proc = graph->GetDirectNode().size(); | ||||
@@ -323,7 +323,7 @@ TEST_F(UtestGraphPassesPrunePass, multi_net_out_put_node_with_circle_net) { | |||||
uint64_t size_ori = graph->GetDirectNode().size(); | uint64_t size_ori = graph->GetDirectNode().size(); | ||||
PrunePass prune_pass; | PrunePass prune_pass; | ||||
vector<GraphPass *> passes = {&prune_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(ge::SUCCESS, status); | EXPECT_EQ(ge::SUCCESS, status); | ||||
uint64_t size_after_proc = graph->GetDirectNode().size(); | uint64_t size_after_proc = graph->GetDirectNode().size(); | ||||
@@ -464,7 +464,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_two_isolate_data_nod | |||||
uint64_t size_ori = graph->GetDirectNode().size(); | uint64_t size_ori = graph->GetDirectNode().size(); | ||||
PrunePass prune_pass; | PrunePass prune_pass; | ||||
vector<GraphPass *> passes = {&prune_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
uint64_t size = graph->GetDirectNode().size(); | uint64_t size = graph->GetDirectNode().size(); | ||||
@@ -68,7 +68,7 @@ TEST_F(UtestResourcePairControlPass, resource_pair_control) { | |||||
EXPECT_EQ(stackpop0->GetInControlNodes().size(), 0); | EXPECT_EQ(stackpop0->GetInControlNodes().size(), 0); | ||||
ResourcePairAddControlPass add_pass; | ResourcePairAddControlPass add_pass; | ||||
vector<GraphPass*> passes = {&add_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes = { {"", &add_pass} }; | |||||
EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | ||||
auto stackpush1 = graph->FindNode("stackpush1"); | auto stackpush1 = graph->FindNode("stackpush1"); | ||||
@@ -80,7 +80,7 @@ TEST_F(UtestResourcePairControlPass, resource_pair_control) { | |||||
EXPECT_EQ(stackpop1->GetInControlNodes().at(0)->GetName(), "stackpush1"); | EXPECT_EQ(stackpop1->GetInControlNodes().at(0)->GetName(), "stackpush1"); | ||||
ResourcePairRemoveControlPass remove_pass; | ResourcePairRemoveControlPass remove_pass; | ||||
passes = {&remove_pass}; | |||||
passes = { {"", &remove_pass} }; | |||||
EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | ||||
auto stackpush2 = graph->FindNode("stackpush1"); | auto stackpush2 = graph->FindNode("stackpush1"); | ||||
@@ -70,7 +70,7 @@ ge::ComputeGraphPtr CreateSaveGraph() { | |||||
TEST_F(UtestGraphPassesSavePass, cover_run_success) { | TEST_F(UtestGraphPassesSavePass, cover_run_success) { | ||||
ge::ComputeGraphPtr compute_graph = CreateSaveGraph(); | ge::ComputeGraphPtr compute_graph = CreateSaveGraph(); | ||||
ge::PassManager pass_managers; | ge::PassManager pass_managers; | ||||
pass_managers.AddPass(new (std::nothrow) SavePass); | |||||
pass_managers.AddPass("", new (std::nothrow) SavePass); | |||||
Status status = pass_managers.Run(compute_graph); | Status status = pass_managers.Run(compute_graph); | ||||
EXPECT_EQ(status, ge::SUCCESS); | EXPECT_EQ(status, ge::SUCCESS); | ||||
} | } |
@@ -19,7 +19,6 @@ | |||||
#include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
#define protected public | #define protected public | ||||
#define private public | #define private public | ||||
#include "graph/passes/switch_op_pass.h" | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/debug/memory_dumper.h" | #include "common/debug/memory_dumper.h" | ||||
@@ -27,7 +26,6 @@ | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/passes/control_op_attr_pass.h" | |||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
#undef protected | #undef protected | ||||
#undef private | #undef private | ||||
@@ -19,7 +19,6 @@ | |||||
#include <string> | #include <string> | ||||
#define private public | #define private public | ||||
#include "graph/passes/switch_pass.h" | |||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
@@ -54,9 +54,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
NodePtr found_node = graph->FindNode("transpose1"); | NodePtr found_node = graph->FindNode("transpose1"); | ||||
@@ -73,9 +75,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { | |||||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
NodePtr found_node = graph->FindNode("transpose1"); | NodePtr found_node = graph->FindNode("transpose1"); | ||||
@@ -100,9 +104,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { | |||||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | ||||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
NodePtr found_node0 = graph->FindNode("transpose1"); | NodePtr found_node0 = graph->FindNode("transpose1"); | ||||
@@ -128,9 +134,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { | |||||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | ||||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
NodePtr found_node0 = graph->FindNode("transpose1"); | NodePtr found_node0 = graph->FindNode("transpose1"); | ||||
@@ -151,9 +159,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { | |||||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
} | } | ||||
@@ -171,9 +181,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { | |||||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | ||||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | ||||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||||
ge::IsolatedOpRemovePass isolate_pass; | ge::IsolatedOpRemovePass isolate_pass; | ||||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||||
std::vector<std::pair<string, GraphPass*>> passes; | |||||
passes.emplace_back("", &isolate_pass); | |||||
passes.emplace_back("", &unused_pass); | |||||
Status status = PassManager::Run(graph, passes); | Status status = PassManager::Run(graph, passes); | ||||
EXPECT_EQ(SUCCESS, status); | EXPECT_EQ(SUCCESS, status); | ||||
} | } |
@@ -39,7 +39,6 @@ | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
#include "cce/dnn_struct_base.hpp" | #include "cce/dnn_struct_base.hpp" | ||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | ||||
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | ||||
#include "common/formats/format_transfers/datatype_transfer.h" | #include "common/formats/format_transfers/datatype_transfer.h" | ||||
@@ -38,10 +38,9 @@ class UtestSingleOpManager : public testing::Test { | |||||
}; | }; | ||||
TEST_F(UtestSingleOpManager, test_get_resource) { | TEST_F(UtestSingleOpManager, test_get_resource) { | ||||
uintptr_t resource_id = 0x1; | |||||
rtStream_t stream = (rtStream_t)0x01; | |||||
auto &instance = SingleOpManager::GetInstance(); | auto &instance = SingleOpManager::GetInstance(); | ||||
ASSERT_EQ(instance.TryGetResource(resource_id), nullptr); | |||||
ASSERT_NE(instance.GetResource(resource_id), nullptr); | |||||
ASSERT_NE(instance.GetResource(0x01, stream), nullptr); | |||||
} | } | ||||
TEST_F(UtestSingleOpManager, test_get_op_from_model) { | TEST_F(UtestSingleOpManager, test_get_op_from_model) { | ||||
@@ -56,7 +55,7 @@ TEST_F(UtestSingleOpManager, test_get_op_from_model) { | |||||
model_data.model_len = model_str.size(); | model_data.model_len = model_str.size(); | ||||
ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | ||||
ASSERT_EQ(instance.GetResource(resource_id)->GetOperator(model_data.model_data), nullptr); | |||||
ASSERT_EQ(instance.GetResource(resource_id, stream)->GetOperator(model_data.model_data), nullptr); | |||||
} | } | ||||
TEST_F(UtestSingleOpManager, test_relesase_resource) { | TEST_F(UtestSingleOpManager, test_relesase_resource) { | ||||
@@ -64,7 +63,7 @@ TEST_F(UtestSingleOpManager, test_relesase_resource) { | |||||
auto &instance = SingleOpManager::GetInstance(); | auto &instance = SingleOpManager::GetInstance(); | ||||
ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | ||||
instance.GetResource(0x99); | |||||
instance.GetResource(0x99, stream); | |||||
ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | ||||
} | } | ||||
@@ -92,4 +91,4 @@ TEST_F(UtestSingleOpManager, get_resource_failed) { | |||||
auto &instance = SingleOpManager::GetInstance(); | auto &instance = SingleOpManager::GetInstance(); | ||||
ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | ||||
} | |||||
} |
@@ -97,7 +97,9 @@ TEST_F(UtestSingleOpModel, test_set_inputs_and_outputs) { | |||||
model.output_offset_list_.push_back(0); | model.output_offset_list_.push_back(0); | ||||
model.output_sizes_.push_back(16); | model.output_sizes_.push_back(16); | ||||
SingleOp single_op; | |||||
std::mutex stream_mu_; | |||||
rtStream_t stream_ = nullptr; | |||||
SingleOp single_op(&stream_mu_, stream_); | |||||
ASSERT_EQ(model.SetInputsAndOutputs(single_op), SUCCESS); | ASSERT_EQ(model.SetInputsAndOutputs(single_op), SUCCESS); | ||||
} | } | ||||
@@ -111,25 +113,29 @@ TEST_F(UtestSingleOpModel, test_build_kernel_task) { | |||||
model.output_offset_list_.push_back(0); | model.output_offset_list_.push_back(0); | ||||
model.output_sizes_.push_back(16); | model.output_sizes_.push_back(16); | ||||
auto graph = make_shared<ComputeGraph>("graph"); | |||||
auto op_desc = make_shared<OpDesc>("AddN", "AddN"); | auto op_desc = make_shared<OpDesc>("AddN", "AddN"); | ||||
vector<int64_t> shape{16, 16}; | vector<int64_t> shape{16, 16}; | ||||
GeShape ge_shape(shape); | GeShape ge_shape(shape); | ||||
GeTensorDesc desc(ge_shape); | GeTensorDesc desc(ge_shape); | ||||
op_desc->AddInputDesc(desc); | op_desc->AddInputDesc(desc); | ||||
op_desc->AddOutputDesc(desc); | op_desc->AddOutputDesc(desc); | ||||
auto node = graph->AddNode(op_desc); | |||||
std::mutex stream_mu_; | |||||
rtStream_t stream_ = nullptr; | |||||
SingleOp single_op(&stream_mu_, stream_); | |||||
SingleOp single_op; | |||||
domi::KernelDef kernel_def; | domi::KernelDef kernel_def; | ||||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::CCE_AI_CORE); | |||||
OpTask *task = nullptr; | |||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), UNSUPPORTED); | |||||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::TE); | |||||
TbeOpTask *task = nullptr; | |||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), UNSUPPORTED); | |||||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::TE); | kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::TE); | ||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), INTERNAL_ERROR); | |||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), INTERNAL_ERROR); | |||||
model.op_list_[0] = op_desc; | |||||
model.op_list_[0] = node; | |||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), PARAM_INVALID); | |||||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), PARAM_INVALID); | |||||
ASSERT_EQ(task, nullptr); | ASSERT_EQ(task, nullptr); | ||||
delete task; | delete task; | ||||
} | } | ||||
@@ -145,18 +151,22 @@ TEST_F(UtestSingleOpModel, test_parse_arg_table) { | |||||
SingleOpModel op_model("model", model_data_str.c_str(), model_data_str.size()); | SingleOpModel op_model("model", model_data_str.c_str(), model_data_str.size()); | ||||
TbeOpTask task; | TbeOpTask task; | ||||
SingleOp op; | |||||
OpDescPtr op_desc; | |||||
std::mutex stream_mu_; | |||||
rtStream_t stream_ = nullptr; | |||||
SingleOp op(&stream_mu_, stream_); | |||||
op.arg_table_.resize(2); | op.arg_table_.resize(2); | ||||
auto *args = new uintptr_t[2]; | |||||
args[0] = 0x100000; | |||||
args[1] = 0x200000; | |||||
task.SetKernelArgs(args, 16, 1); | |||||
auto args = std::unique_ptr<uint8_t[]>(new uint8_t[sizeof(uintptr_t) * 2]); | |||||
auto *arg_base = (uintptr_t*)args.get(); | |||||
arg_base[0] = 0x100000; | |||||
arg_base[1] = 0x200000; | |||||
task.SetKernelArgs(std::move(args), 16, 1, op_desc); | |||||
op_model.model_params_.addr_mapping_[0x100000] = 1; | op_model.model_params_.addr_mapping_[0x100000] = 1; | ||||
op_model.ParseArgTable(&task, op); | op_model.ParseArgTable(&task, op); | ||||
ASSERT_EQ(op.arg_table_[0].size(), 0); | ASSERT_EQ(op.arg_table_[0].size(), 0); | ||||
ASSERT_EQ(op.arg_table_[1].size(), 1); | ASSERT_EQ(op.arg_table_[1].size(), 1); | ||||
ASSERT_EQ(op.arg_table_[1].front(), &args[0]); | |||||
ASSERT_EQ(op.arg_table_[1].front(), &arg_base[0]); | |||||
} | } |
@@ -38,8 +38,9 @@ class UtestStreamResource : public testing::Test { | |||||
rtStream_t stream; | rtStream_t stream; | ||||
}; | }; | ||||
/* | |||||
TEST_F(UtestStreamResource, test_cache_op) { | TEST_F(UtestStreamResource, test_cache_op) { | ||||
StreamResource res; | |||||
StreamResource res((uintptr_t)1); | |||||
auto *op = new SingleOp(); | auto *op = new SingleOp(); | ||||
string stub_name = "stubFunc"; | string stub_name = "stubFunc"; | ||||
const void *key = stub_name.c_str(); | const void *key = stub_name.c_str(); | ||||
@@ -47,31 +48,34 @@ TEST_F(UtestStreamResource, test_cache_op) { | |||||
res.CacheOperator(key, op); | res.CacheOperator(key, op); | ||||
ASSERT_NE(res.GetOperator(key), nullptr); | ASSERT_NE(res.GetOperator(key), nullptr); | ||||
} | } | ||||
*/ | |||||
TEST_F(UtestStreamResource, test_malloc_memory) { | TEST_F(UtestStreamResource, test_malloc_memory) { | ||||
StreamResource res; | |||||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||||
StreamResource res((uintptr_t)1); | |||||
string purpose("test"); | |||||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||||
} | } | ||||
TEST_F(UtestStreamResource, test_do_malloc_memory) { | TEST_F(UtestStreamResource, test_do_malloc_memory) { | ||||
size_t max_allocated = 0; | size_t max_allocated = 0; | ||||
vector<uint8_t *> allocated; | vector<uint8_t *> allocated; | ||||
string purpose("test"); | |||||
uint8_t *ret = StreamResource::DoMallocMemory(100, max_allocated, allocated); | |||||
StreamResource res((uintptr_t)1); | |||||
uint8_t *ret = res.DoMallocMemory(purpose, 100, max_allocated, allocated); | |||||
ASSERT_EQ(allocated.size(), 1); | ASSERT_EQ(allocated.size(), 1); | ||||
ASSERT_NE(allocated.back(), nullptr); | ASSERT_NE(allocated.back(), nullptr); | ||||
ASSERT_EQ(max_allocated, 100); | ASSERT_EQ(max_allocated, 100); | ||||
StreamResource::DoMallocMemory(50, max_allocated, allocated); | |||||
StreamResource::DoMallocMemory(99, max_allocated, allocated); | |||||
StreamResource::DoMallocMemory(100, max_allocated, allocated); | |||||
res.DoMallocMemory(purpose, 50, max_allocated, allocated); | |||||
res.DoMallocMemory(purpose, 99, max_allocated, allocated); | |||||
res.DoMallocMemory(purpose, 100, max_allocated, allocated); | |||||
ASSERT_EQ(allocated.size(), 1); | ASSERT_EQ(allocated.size(), 1); | ||||
ASSERT_EQ(max_allocated, 100); | ASSERT_EQ(max_allocated, 100); | ||||
StreamResource::DoMallocMemory(101, max_allocated, allocated); | |||||
res.DoMallocMemory(purpose, 101, max_allocated, allocated); | |||||
ASSERT_EQ(allocated.size(), 2); | ASSERT_EQ(allocated.size(), 2); | ||||
ASSERT_EQ(max_allocated, 101); | ASSERT_EQ(max_allocated, 101); | ||||