From: @t00456437 Reviewed-by: @ji_chen,@xchu42 Signed-off-by: @ji_chentags/v1.1.0
@@ -37,6 +37,7 @@ if (ENABLE_OPEN_SRC) | |||
include(cmake/external_libs/protobuf_static.cmake) | |||
include(cmake/external_libs/protoc.cmake) | |||
include(cmake/external_libs/gflags.cmake) | |||
include(cmake/external_libs/gtest.cmake) | |||
include(cmake/external_libs/securec.cmake) | |||
include(cmake/external_libs/json.cmake) | |||
include(cmake/FindModule.cmake) | |||
@@ -78,6 +79,7 @@ if (ENABLE_OPEN_SRC) | |||
else() | |||
find_module(slog libslog.so ${ASCEND_ATC_DIR}) | |||
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | |||
find_module(error_manager liberror_manager.so ${ASCEND_ATC_DIR}) | |||
if(PLATFORM STREQUAL "train") | |||
find_module(msprof libmsprof.so ${ASCEND_DRIVER_COMMON_DIR}) | |||
find_module(hccl libhccl.so ${ASCEND_RUNTIME_DIR}) | |||
@@ -123,8 +125,13 @@ if (ENABLE_OPEN_SRC) | |||
find_module(ascend_hal_stub libascend_hal.so ${ASCEND_DRIVER_DIR}/driver) | |||
#find_module(ascendcl_static libascendcl.a ${ASCEND_ACL_DIR}) | |||
else() | |||
message(FATAL_ERROR "PLATFORM param is invalid, should be train or inference, build terminated") | |||
message(STATUS "PLATFORM param is invalid, should be train or inference, you choose nothing!") | |||
endif() | |||
if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||
add_subdirectory(tests) | |||
endif() | |||
endif() | |||
set(METADEF_DIR ${CMAKE_CURRENT_LIST_DIR}/metadef) | |||
@@ -59,7 +59,7 @@ checkopts() | |||
ENABLE_GE_ST="off" | |||
ENABLE_GE_COV="off" | |||
GE_ONLY="on" | |||
PLATFORM="inference" | |||
PLATFORM="" | |||
PRODUCT="normal" | |||
ENABLE_GITEE="off" | |||
# Process the options | |||
@@ -166,6 +166,9 @@ build_graphengine() | |||
elif [ "x${PLATFORM}" = "xinference" ] | |||
then | |||
TARGET="ge_compiler atc_ge_local_engine atc_ge_local_opskernel_builder atc_host_cpu_engine atc_host_cpu_opskernel_builder atc opensrc_ascendcl ${TARGET}" | |||
elif [ "X$ENABLE_GE_UT" = "Xon" ] | |||
then | |||
TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | |||
elif [ "x${PLATFORM}" = "xall" ] | |||
then | |||
# build all the target | |||
@@ -0,0 +1,60 @@ | |||
if (HAVE_GTEST) | |||
return() | |||
endif() | |||
include(ExternalProject) | |||
if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||
(${CMAKE_INSTALL_PREFIX} STREQUAL "C:/Program Files (x86)/ascend")) | |||
set(CMAKE_INSTALL_PREFIX ${GE_CODE_DIR}/output CACHE STRING "path for install()" FORCE) | |||
message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | |||
endif() | |||
if (ENABLE_GITEE) | |||
set(REQ_URL "https://gitee.com/mirrors/googletest/repository/archive/release-1.8.0.tar.gz") | |||
set(MD5 "") | |||
else() | |||
set(REQ_URL "https://github.com/google/googletest/archive/release-1.8.0.tar.gz") | |||
set(MD5 "") | |||
endif () | |||
set (gtest_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||
set (gtest_CFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack") | |||
ExternalProject_Add(gtest_build | |||
URL ${REQ_URL} | |||
CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS=${gtest_CXXFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gtest <SOURCE_DIR> | |||
-DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_MACOSX_RPATH=TRUE -Dgtest_disable_pthreads=ON | |||
BUILD_COMMAND $(MAKE) | |||
INSTALL_COMMAND $(MAKE) install | |||
EXCLUDE_FROM_ALL TRUE | |||
) | |||
set(GTEST_PKG_DIR ${CMAKE_INSTALL_PREFIX}/gtest) | |||
file(MAKE_DIRECTORY ${GTEST_PKG_DIR}/include) | |||
add_library(gtest SHARED IMPORTED) | |||
set_target_properties(gtest PROPERTIES | |||
IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest.so | |||
) | |||
add_library(gtest_main SHARED IMPORTED) | |||
set_target_properties(gtest_main PROPERTIES | |||
IMPORTED_LOCATION ${GTEST_PKG_DIR}/lib/libgtest_main.so | |||
) | |||
target_include_directories(gtest INTERFACE ${GTEST_PKG_DIR}/include) | |||
target_include_directories(gtest_main INTERFACE ${GTEST_PKG_DIR}/include) | |||
set(INSTALL_BASE_DIR "") | |||
set(INSTALL_LIBRARY_DIR lib) | |||
install(FILES ${GTEST_PKG_DIR}/lib/libgtest.so ${GTEST_PKG_DIR}/lib/libgtest_main.so OPTIONAL | |||
DESTINATION ${INSTALL_LIBRARY_DIR}) | |||
add_dependencies(gtest gtest_build) | |||
#set(HAVE_GFLAGS TRUE CACHE BOOL "gflags build add") | |||
set(HAVE_GTEST TRUE) |
@@ -22,7 +22,7 @@ endif () | |||
set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2 -Dgoogle=ascend_private") | |||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||
ExternalProject_Add(protobuf_build | |||
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||
URL ${REQ_URL} | |||
CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||
-Dprotobuf_WITH_ZLIB=OFF | |||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR} | |||
@@ -20,7 +20,7 @@ set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fst | |||
set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | |||
set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | |||
ExternalProject_Add(protobuf_static_build | |||
URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||
URL ${REQ_URL} | |||
#URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | |||
#SOURCE_DIR ${METADEF_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | |||
CONFIGURE_COMMAND ${CMAKE_COMMAND} | |||
@@ -22,6 +22,7 @@ add_subdirectory(depends/runtime) | |||
add_subdirectory(depends/omg) | |||
add_subdirectory(depends/hccl) | |||
add_subdirectory(depends/profiler) | |||
add_subdirectory(depends/error_manager) | |||
if (ENABLE_GE_COV OR ENABLE_GE_UT) | |||
add_subdirectory(ut) | |||
@@ -13,60 +13,84 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(STUB_CCE) | |||
set(CMAKE_CXX_STANDARD 11) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/src/common) | |||
include_directories(${GE_SOURCE_DIR}/src/common/graph) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/inc/framework) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||
include_directories(${GE_CODE_DIR}/metadef) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/metadef/graph) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||
"${GE_SOURCE_DIR}/src/proto/ge_ir.proto" | |||
"${GE_SOURCE_DIR}/src/proto/task.proto" | |||
set(PROTO_LIST | |||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||
"${GE_CODE_DIR}/metadef/proto/ge_ir.proto" | |||
"${GE_CODE_DIR}/metadef/proto/task.proto" | |||
) | |||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_define.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/anchor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_value.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/buffer.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/compute_graph.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/graph.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/model.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/model_serialize.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/node.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/op_desc.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory_impl.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/tensor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/detail/attributes_holder.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/anchor_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/graph_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/node_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/op_desc_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/type_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/op_imp.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/shape_refiner.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_tensor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/opsproto/opsproto_manager.cc" | |||
set(SRCS | |||
"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||
"${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||
"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||
"${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||
"${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||
"${GE_CODE_DIR}/metadef/graph/graph.cc" | |||
"${GE_CODE_DIR}/metadef/graph/model.cc" | |||
"${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||
"${GE_CODE_DIR}/metadef/graph/node.cc" | |||
"${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||
"${GE_CODE_DIR}/metadef/graph/operator.cc" | |||
"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||
"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||
"${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | |||
"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||
"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||
"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||
) | |||
add_library(cce_ge_stub SHARED src/cce_stub.cc ${PROTO_SRCS} ${PROTO_HDRS}) | |||
target_link_libraries(cce_ge_stub protobuf::protobuf) | |||
target_compile_definitions(cce_ge_stub PRIVATE | |||
google=ascend_private | |||
) | |||
target_link_libraries(cce_ge_stub | |||
$<BUILD_INTERFACE:intf_pub> | |||
-Wl,--no-as-needed | |||
ascend_protobuf | |||
-Wl,--as-needed | |||
c_sec | |||
) | |||
add_library(cce_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | |||
target_link_libraries(cce_stub protobuf::protobuf) | |||
target_compile_definitions(cce_stub PRIVATE | |||
google=ascend_private | |||
) | |||
target_link_libraries(cce_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
-Wl,--no-as-needed | |||
ascend_protobuf | |||
-Wl,--as-needed | |||
c_sec | |||
) |
@@ -0,0 +1,34 @@ | |||
# Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(STUB_ERROR_MANAGER) | |||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"src/error_manager_stub.cc" | |||
) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/inc/framework) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||
add_library(error_manager_stub SHARED ${SRCS}) | |||
target_link_libraries(error_manager_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
) |
@@ -0,0 +1,85 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "common/util/error_manager/error_manager.h" | |||
ErrorManager &ErrorManager::GetInstance() { | |||
static ErrorManager instance; | |||
return instance; | |||
} | |||
/// | |||
/// @brief init | |||
/// @param [in] path: current so path | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::Init(std::string path) { return 0; } | |||
/// | |||
/// @brief Report error message | |||
/// @param [in] error_code: error code | |||
/// @param [in] args_map: parameter map | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map) { | |||
return 0; | |||
} | |||
/// | |||
/// @brief output error message | |||
/// @param [in] handle: print handle | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::OutputErrMessage(int handle) { return 0; } | |||
/// | |||
/// @brief output message | |||
/// @param [in] handle: print handle | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::OutputMessage(int handle) { return 0; } | |||
/// | |||
/// @brief Report error message | |||
/// @param [in] key: vector parameter key | |||
/// @param [in] value: vector parameter value | |||
/// | |||
void ErrorManager::ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key, | |||
const std::vector<std::string> &value) { | |||
} | |||
/// | |||
/// @brief report graph compile failed message such as error code and op_name in mstune case | |||
/// @param [in] msg: failed message map, key is error code, value is op_name | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::ReportMstuneCompileFailedMsg(const std::map<std::string, std::string> &msg) { return 0; } | |||
/// | |||
/// @brief save graph compile failed message from thread local map to global map | |||
/// @param [in] graph_name: graph name | |||
/// | |||
void ErrorManager::SaveMstuneCompileFailedMsg(const std::string &graph_name) {} | |||
/// | |||
/// @brief get graph compile failed message in mstune case | |||
/// @param [in] graph_name: graph name | |||
/// @param [out] msg_map: failed message map, key is error code, value is op_name list | |||
/// @return int 0(success) -1(fail) | |||
/// | |||
int ErrorManager::GetMstuneCompileFailedMsg(const std::string &graph_name, std::map<std::string, std::vector<std::string>> &msg_map) { return 0; } | |||
@@ -13,14 +13,18 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(hccl_stub) | |||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
"src/hccl_stub.cc" | |||
) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
add_library(hccl_stub SHARED ${SRC_FILES}) | |||
add_library(hccl_stub SHARED ${SRC_FILES}) | |||
target_link_libraries(hccl_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
) |
@@ -18,27 +18,27 @@ | |||
#include "hccl/hcom.h" | |||
hcclResult_t hcom_all_gather(const char *tag, void *input_count_ptr, void *output_ptr, u64 input_count, | |||
hcclDataType_t data_type, const char *group, rtStream_t stream) { | |||
HcclResult hcom_all_gather(const char *tag, void *input_count_ptr, void *output_ptr, u64 input_count, | |||
HcclDataType data_type, const char *group, rtStream_t stream) { | |||
return HCCL_SUCCESS; | |||
} | |||
hcclResult_t hcom_broadcast(const char *tag, void *ptr, u64 count, hcclDataType_t data_type, u32 root, | |||
HcclResult hcom_broadcast(const char *tag, void *ptr, u64 count, HcclDataType data_type, u32 root, | |||
const char *group, rtStream_t stream) { | |||
return HCCL_SUCCESS; | |||
} | |||
hcclResult_t hcom_all_reduce(const char *tag, void *input_ptr, void *output_ptr, u64 count, hcclDataType_t data_type, | |||
hcclRedOp_t op, const char *group, rtStream_t stream) { | |||
HcclResult hcom_all_reduce(const char *tag, void *input_ptr, void *output_ptr, u64 count, HcclDataType data_type, | |||
HcclReduceOp op, const char *group, rtStream_t stream) { | |||
return HCCL_SUCCESS; | |||
} | |||
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 max_segment_num, | |||
HcclResult hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 max_segment_num, | |||
u32 *segment_num, u32 *segment_idx) { | |||
return HCCL_SUCCESS; | |||
} | |||
hcclResult_t hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_ptr, u64 count, | |||
hcclDataType_t data_type, hcclRedOp_t op, const char *group, rtStream_t stream) { | |||
HcclResult hcom_reduce_scatter(const char *tag, void *input_ptr, void *output_ptr, u64 count, | |||
HcclDataType data_type, HcclReduceOp op, const char *group, rtStream_t stream) { | |||
return HCCL_SUCCESS; | |||
} | |||
} |
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(STUB_MMPA) | |||
@@ -21,10 +21,18 @@ file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"src/mmpa_stub.cc" | |||
) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/inc/framework) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||
add_library(mmpa_stub SHARED ${SRCS}) | |||
target_link_libraries(mmpa_stub protobuf::protobuf) | |||
target_link_libraries(mmpa_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
-Wl,--no-as-needed | |||
ascend_protobuf | |||
-Wl,--as-needed | |||
c_sec | |||
) |
@@ -217,3 +217,58 @@ INT32 mmScandir(const CHAR *path, mmDirent ***entryList, mmFilter filterFunc, m | |||
VOID mmScandirFree(mmDirent **entryList, INT32 count) | |||
{ | |||
} | |||
INT32 mmAccess2(const CHAR *pathName, INT32 mode) | |||
{ | |||
return 0; | |||
} | |||
INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) | |||
{ | |||
return 0; | |||
} | |||
INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen) | |||
{ | |||
return 0; | |||
} | |||
INT32 mmGetErrorCode() | |||
{ | |||
return 0; | |||
} | |||
INT32 mmIsDir(const CHAR *fileName) | |||
{ | |||
return 0; | |||
} | |||
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | |||
{ | |||
return 0; | |||
} | |||
INT32 mmDlclose(VOID *handle) | |||
{ | |||
return 0; | |||
} | |||
CHAR *mmDlerror() | |||
{ | |||
return ""; | |||
} | |||
INT32 mmDladdr(VOID *addr, mmDlInfo *info) | |||
{ | |||
return 0; | |||
} | |||
VOID *mmDlopen(const CHAR *fileName, INT32 mode) | |||
{ | |||
return NULL; | |||
} | |||
VOID *mmDlsym(VOID *handle, const CHAR *funcName) | |||
{ | |||
return NULL; | |||
} |
@@ -13,33 +13,47 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(OMG_CCE) | |||
set(CMAKE_CXX_STANDARD 11) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/src/ge) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/inc/framework) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||
include_directories(${GE_CODE_DIR}/ge) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||
"${GE_SOURCE_DIR}/src/proto/task.proto" | |||
set(PROTO_LIST | |||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||
"${GE_CODE_DIR}/metadef/proto/task.proto" | |||
) | |||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
# "${GE_SOURCE_DIR}/src/ge/common/util.cc" | |||
"src/omg_stub.cc" | |||
set(SRCS | |||
# "${GE_CODE_DIR}/src/ge/common/util.cc" | |||
"src/omg_stub.cc" | |||
) | |||
add_library(omg_stub SHARED ${SRCS} ${PROTO_SRCS} ${PROTO_HDRS}) | |||
target_link_libraries(omg_stub protobuf::protobuf) | |||
target_compile_definitions(omg_stub PRIVATE | |||
google=ascend_private | |||
) | |||
target_link_libraries(omg_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
-Wl,--no-as-needed | |||
ascend_protobuf | |||
-Wl,--as-needed | |||
c_sec | |||
json | |||
) |
@@ -643,7 +643,7 @@ Status GetInputOutputDescInfo(uint32_t model_id, vector<InputOutputDescInfo> &in | |||
} | |||
Status DataInput(const InputData *input_data, OutputData *output_data) { return SUCCESS; } | |||
/* | |||
class ModelManager { | |||
public: | |||
static std::shared_ptr<ModelManager> GetInstance(); | |||
@@ -741,6 +741,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy | |||
return SUCCESS; | |||
} | |||
*/ | |||
} // namespace ge | |||
namespace ge { | |||
@@ -13,12 +13,16 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(profiler_stub) | |||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
"src/profiler_stub.cc" | |||
) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
add_library(profiler_stub SHARED ${SRC_FILES}) | |||
add_library(profiler_stub SHARED ${SRC_FILES}) | |||
target_link_libraries(profiler_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
) |
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(STUB_MMPA) | |||
@@ -21,7 +21,12 @@ file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"src/runtime_stub.cc" | |||
) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/inc/framework) | |||
add_library(runtime_stub SHARED ${SRCS}) | |||
target_link_libraries(runtime_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
c_sec | |||
) |
@@ -221,8 +221,9 @@ rtError_t rtCpuKernelLaunch(const void *so_name, const void *kernel_name, uint32 | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtModelGetTaskId(void *handle, uint32_t *task_id) { | |||
rtError_t rtModelGetTaskId(void *handle, uint32_t *task_id, uint32_t *stream_id) { | |||
*task_id = 0; | |||
*stream_id = 0; | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtEndGraph(rtModel_t model, rtStream_t stream) { return RT_ERROR_NONE; } | |||
@@ -307,3 +308,79 @@ rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQueueFlag_t | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtSetSocVersion(const char *version) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtMallocHostSharedMemory(rtMallocHostSharedMemoryIn *in, | |||
rtMallocHostSharedMemoryOut *out) | |||
{ | |||
out->ptr = new uint8_t[in->size]; | |||
out->devPtr = new uint8_t[in->size]; | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtFreeHostSharedMemory(rtFreeHostSharedMemoryIn *in) | |||
{ | |||
delete[] (uint8_t*)in->ptr; | |||
delete[] (uint8_t*)in->devPtr; | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deplyType) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtDebugUnRegister(rtModel_t model) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtDumpAddrSet(rtModel_t model, void *addr, uint32_t dumpSize, uint32_t flag) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtSetCtxINFMode(bool mode) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) | |||
{ | |||
*label = new uint32_t; | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *value) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *maxStrCount, uint32_t *maxTaskCount) | |||
{ | |||
return RT_ERROR_NONE; | |||
} | |||
rtError_t rtModelExit(rtModel_t model, rtStream_t stream) | |||
{ | |||
return RT_ERROR_NONE; | |||
} |
@@ -13,11 +13,14 @@ | |||
# limitations under the License. | |||
# ============================================================================ | |||
cmake_minimum_required(VERSION 2.8) | |||
#cmake_minimum_required(VERSION 2.8) | |||
project(slog_stub) | |||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
"src/*.cc" | |||
) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
add_library(slog_stub SHARED ${SRC_FILES}) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
add_library(slog_stub SHARED ${SRC_FILES}) | |||
target_link_libraries(slog_stub PRIVATE | |||
$<BUILD_INTERFACE:intf_pub> | |||
) |
@@ -38,6 +38,8 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu | |||
dav_log(module_id, fmt); | |||
} | |||
int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } | |||
int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } | |||
int CheckLogLevel(int moduleId, int logLevel) | |||
@@ -17,30 +17,34 @@ project(ut_libgraph) | |||
set(CMAKE_CXX_STANDARD 11) | |||
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
"${GE_SOURCE_DIR}/src/proto/om.proto" | |||
"${GE_SOURCE_DIR}/src/proto/ge_ir.proto" | |||
"${onnx_INC}/onnx/onnx.proto" | |||
set(PROTO_LIST | |||
"${GE_CODE_DIR}/metadef/proto/om.proto" | |||
"${GE_CODE_DIR}/metadef/proto/ge_ir.proto" | |||
"${GE_CODE_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||
) | |||
ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
# include directories | |||
include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||
include_directories(${GE_SOURCE_DIR}/src) | |||
include_directories(${GE_SOURCE_DIR}/src/common) | |||
include_directories(${GE_SOURCE_DIR}/src/common/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc) | |||
include_directories(${GE_SOURCE_DIR}/inc/external) | |||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/graph) | |||
include_directories(${GE_SOURCE_DIR}/inc/common) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${GE_CODE_DIR}) | |||
include_directories(${GE_CODE_DIR}/metadef) | |||
include_directories(${GE_CODE_DIR}/metadef/graph) | |||
include_directories(${GE_CODE_DIR}/inc) | |||
include_directories(${GE_CODE_DIR}/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | |||
include_directories(${GE_CODE_DIR}/metadef/inc) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/graph) | |||
include_directories(${GE_CODE_DIR}/metadef/inc/common) | |||
include_directories(${GE_CODE_DIR}/metadef/third_party) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | |||
include_directories(${CMAKE_BINARY_DIR}) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | |||
file(GLOB_RECURSE UT_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
set(UT_FILES | |||
"testcase/ge_graph/ge_anchor_utils_unittest.cc" | |||
"testcase/ge_graph/ge_def_type_unittest.cc" | |||
"testcase/ge_graph/ge_graph_anchor_unittest.cc" | |||
@@ -56,41 +60,59 @@ file(GLOB_RECURSE UT_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
"testcase/ge_graph/ge_model_unittest.cc" | |||
) | |||
file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
"${GE_SOURCE_DIR}/src/common/graph/option/ge_local_context.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/option/ge_context.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/anchor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_value.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/attr_value.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/buffer.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/compute_graph.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_attr_define.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/graph.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/model.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/model_serialize.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/node.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/op_desc.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator_reg.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/operator_factory_impl.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/range_vistor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/tensor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/ge_tensor.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/shape_refiner.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/format_refiner.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/inference_context.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/detail/attributes_holder.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/anchor_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/graph_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/node_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/op_desc_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/type_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/ge_ir_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/utils/tensor_utils.cc" | |||
"${GE_SOURCE_DIR}/src/common/ops/op_imp.cc" | |||
"${GE_SOURCE_DIR}/src/common/graph/opsproto/opsproto_manager.cc" | |||
set(SRC_FILES | |||
#"${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/anchor.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/attr_value.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/buffer.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/compute_graph.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/graph.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/gnode.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/ascend_string.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/model.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/model_serialize.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/node.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/op_desc.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/operator.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/operator_reg.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/operator_factory.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/operator_factory_impl.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/range_vistor.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/tensor.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/ge_tensor.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/shape_refiner.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/inference_context.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||
#"${GE_CODE_DIR}/metadef/graph/utils/tensor_utils.cc" | |||
"${GE_CODE_DIR}/metadef/ops/op_imp.cpp" | |||
#"${GE_CODE_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||
) | |||
#add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||
add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||
target_link_libraries(ut_libgraph graphengine::gtest graphengine::gtest_main slog_stub protobuf::protobuf graphengine::securec rt dl) | |||
target_compile_definitions(ut_libgraph PRIVATE | |||
google=ascend_private | |||
) | |||
target_link_libraries(ut_libgraph | |||
$<BUILD_INTERFACE:intf_pub> | |||
graph | |||
gtest | |||
gtest_main | |||
slog_stub | |||
ascend_protobuf | |||
c_sec | |||
-lrt | |||
-ldl | |||
) |
@@ -85,7 +85,7 @@ ut::GraphBuilder BuildGraph1() { | |||
builder.AddDataEdge(var2, 0, conv1, 1); | |||
builder.AddDataEdge(conv1, 0, relu1, 0); | |||
builder.AddDataEdge(relu1, 0, netoutput1, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -134,7 +134,7 @@ ut::GraphBuilder BuildGraph2() { | |||
builder.AddDataEdge(var6, 0, bn1, 4); | |||
builder.AddDataEdge(bn1, 0, relu1, 0); | |||
builder.AddDataEdge(relu1, 0, netoutput1, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -189,7 +189,7 @@ ut::GraphBuilder BuildGraph3() { | |||
builder.AddDataEdge(relu1, 0, conv2, 0); | |||
builder.AddDataEdge(var3, 0, conv2, 1); | |||
builder.AddDataEdge(conv2, 0, netoutput1, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -248,7 +248,7 @@ ut::GraphBuilder BuildGraph4() { | |||
builder.AddDataEdge(relu1, 0, conv2, 0); | |||
builder.AddDataEdge(var3, 0, conv2, 1); | |||
builder.AddDataEdge(conv2, 0, netoutput1, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -305,7 +305,7 @@ ut::GraphBuilder BuilderGraph5() { | |||
builder.AddDataEdge(relug1, 0, bng1, 0); | |||
builder.AddDataEdge(bng1, 0, apply1, 0); | |||
builder.AddDataEdge(apply1, 0, netoutput1, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -353,7 +353,7 @@ ut::GraphBuilder BuildGraph6() { | |||
builder.AddDataEdge(constant, 0, addn, 2); | |||
builder.AddDataEdge(addn, 0, netoutput, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -397,7 +397,7 @@ ut::GraphBuilder BuildGraph7() { | |||
builder.AddDataEdge(constant, 0, addn, 2); | |||
builder.AddDataEdge(addn, 0, netoutput, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
@@ -449,7 +449,7 @@ ut::GraphBuilder BuildGraph8() { | |||
builder.AddDataEdge(relu, 0, reshape, 0); | |||
builder.AddDataEdge(reshape, 0, conv, 1); | |||
builder.AddDataEdge(conv, 0, netoutput, 0); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
return builder; | |||
} | |||
} // namespace | |||
@@ -457,7 +457,7 @@ ut::GraphBuilder BuildGraph8() { | |||
TEST_F(UtestFormatRefiner, data_format) { | |||
auto builder = BuildGraph8(); | |||
auto graph = builder.GetGraph(); | |||
FormatRefiner::SetInferOrigineFormatFlag(false); | |||
//FormatRefiner::SetInferOrigineFormatFlag(false); | |||
graph->SaveDataFormat(FORMAT_NCHW); | |||
EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); | |||
auto data2 = graph->FindNode("data2"); | |||
@@ -466,18 +466,18 @@ TEST_F(UtestFormatRefiner, data_format) { | |||
EXPECT_EQ(data2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | |||
EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); | |||
EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
} | |||
TEST_F(UtestFormatRefiner, constant_fail) { | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
auto builder = BuildGraph6(); | |||
auto graph = builder.GetGraph(); | |||
EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_FAILED); | |||
} | |||
TEST_F(UtestFormatRefiner, scalar_nodes_infer) { | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
auto builder = BuildGraph6(); | |||
auto graph = builder.GetGraph(); | |||
auto constant = graph->FindNode("constant"); | |||
@@ -650,7 +650,7 @@ TEST_F(UtestFormatRefiner, infer_origine_format_failed) { | |||
} | |||
TEST_F(UtestFormatRefiner, save_format) { | |||
FormatRefiner::SetInferOrigineFormatFlag(true); | |||
//FormatRefiner::SetInferOrigineFormatFlag(true); | |||
auto builder = BuildGraph6(); | |||
auto graph = builder.GetGraph(); | |||
graph->SaveDataFormat(FORMAT_NHWC); | |||
@@ -658,4 +658,4 @@ TEST_F(UtestFormatRefiner, save_format) { | |||
EXPECT_EQ(save_format, FORMAT_NHWC); | |||
graph->SaveDataFormat(FORMAT_ND); | |||
} | |||
} // namespace ge | |||
} // namespace ge |
@@ -1060,7 +1060,7 @@ TEST(UtestGeModelSerialize, test_model_serialize_imp_invalid_param) { | |||
auto graph = std::make_shared<ComputeGraph>("test_graph"); | |||
auto node = graph->AddNode(std::make_shared<OpDesc>()); | |||
node->op_ = nullptr; | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
Model model; | |||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||
EXPECT_FALSE(imp.SerializeModel(model, &model_def)); | |||
@@ -1101,26 +1101,26 @@ TEST(UTEST_ge_model_unserialize, test_invalid_tensor) { | |||
TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
{ // valid | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.mutable_attr(); | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_td(); | |||
tensor_desc_attr->set_layout("NCHW"); | |||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||
ModelSerializeImp imp; | |||
Model model; | |||
EXPECT_TRUE(imp.UnserializeModel(model, mode_def)); | |||
} | |||
{ // invalid layout | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.mutable_attr(); | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_td(); | |||
tensor_desc_attr->set_layout("InvalidLayout"); | |||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||
ModelSerializeImp imp; | |||
Model model; | |||
@@ -1131,13 +1131,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
EXPECT_EQ(tensor_desc.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid datatype | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.mutable_attr(); | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_td(); // tensor desc | |||
tensor_desc_attr->set_layout("NHWC"); | |||
tensor_desc_attr->set_dtype((proto::DataType)100); | |||
tensor_desc_attr->set_dtype((ge::proto::DataType)100); | |||
ModelSerializeImp imp; | |||
Model model; | |||
@@ -1148,13 +1148,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | |||
} | |||
{ // invalid datatype | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.mutable_attr(); | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | |||
tensor_desc_attr->set_layout("NHWC"); | |||
tensor_desc_attr->set_dtype((proto::DataType)100); | |||
tensor_desc_attr->set_dtype((ge::proto::DataType)100); | |||
ModelSerializeImp imp; | |||
Model model; | |||
@@ -1167,13 +1167,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
EXPECT_EQ(tensor_desc.GetDataType(), DT_UNDEFINED); | |||
} | |||
{ // invalid attrmap | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->mutable_attr(); // graph attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | |||
tensor_desc_attr->set_layout("NCHW"); | |||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||
auto attrs1 = tensor_desc_attr->mutable_attr(); | |||
auto attr1 = (*attrs1)["key2"]; // empty attr | |||
@@ -1191,13 +1191,13 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
EXPECT_EQ(attr_value.GetValueType(), GeAttrValue::VT_NONE); | |||
} | |||
{ // invalid attrmap2 | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto tensor_desc_attr = attr_def->mutable_t()->mutable_desc(); // tensor | |||
tensor_desc_attr->set_layout("NCHW"); | |||
tensor_desc_attr->set_dtype(proto::DataType::DT_INT8); | |||
tensor_desc_attr->set_dtype(ge::proto::DataType::DT_INT8); | |||
auto attrs1 = tensor_desc_attr->mutable_attr(); | |||
auto attr1 = (*attrs1)["key2"].mutable_list(); // empty list attr | |||
@@ -1219,14 +1219,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_TensorDesc) { | |||
} | |||
TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
{ // invalid graph | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto graph_attr = attr_def->mutable_g(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1245,15 +1245,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid list graph | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH); | |||
auto graph_attr = attr_def->mutable_list()->add_g(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1273,14 +1273,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid named_attrs | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto graph_attr = attr_def->mutable_func(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1298,15 +1298,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid list named_attrs | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS); | |||
auto graph_attr = attr_def->mutable_list()->add_na(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1325,14 +1325,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid tensor_desc | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto graph_attr = attr_def->mutable_td(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1350,15 +1350,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid list tensor_desc | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC); | |||
auto graph_attr = attr_def->mutable_list()->add_td(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1377,14 +1377,14 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid tensor | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
auto graph_attr = attr_def->mutable_t()->mutable_desc(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1402,15 +1402,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid list tensor | |||
proto::ModelDef mode_def; | |||
ge::proto::ModelDef mode_def; | |||
auto attrs = mode_def.add_graph()->add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | |||
auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1429,15 +1429,15 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
EXPECT_EQ(tensor_desc1.GetDataType(), DT_INT8); | |||
} | |||
{ // invalid list tensor | |||
proto::GraphDef graph_def; | |||
ge::proto::GraphDef graph_def; | |||
auto attrs = graph_def.add_op()->mutable_attr(); // node attr | |||
proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
ge::proto::AttrDef *attr_def = &(*attrs)["key1"]; | |||
attr_def->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); | |||
auto graph_attr = attr_def->mutable_list()->add_t()->mutable_desc(); | |||
auto attrs_of_graph = graph_attr->mutable_attr(); | |||
auto tensor_val = (*attrs_of_graph)["key2"].mutable_td(); | |||
tensor_val->set_dtype(proto::DT_INT8); | |||
tensor_val->set_dtype(ge::proto::DT_INT8); | |||
tensor_val->set_layout("invalidLayout"); | |||
ModelSerializeImp imp; | |||
@@ -1462,7 +1462,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_attr) { | |||
TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
// model invalid node input | |||
{ | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
auto op_def = model_def.add_graph()->add_op(); // node attr | |||
op_def->add_input("invalidNodeName:0"); | |||
@@ -1475,7 +1475,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// model invalid node control input | |||
{ | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
auto op_def = model_def.add_graph()->add_op(); // node attr | |||
op_def->add_input("invalidNodeName:-1"); | |||
@@ -1488,7 +1488,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// model invalid graph input | |||
{ | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
model_def.add_graph()->add_input("invalidNodeName:0"); | |||
Buffer buffer(model_def.ByteSizeLong()); | |||
@@ -1500,7 +1500,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// model invalid graph input | |||
{ | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
model_def.add_graph()->add_output("invalidNodeName:0"); | |||
Buffer buffer(model_def.ByteSizeLong()); | |||
@@ -1512,7 +1512,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// graph invalid node input | |||
{ | |||
proto::GraphDef graph_def; | |||
ge::proto::GraphDef graph_def; | |||
auto op_def = graph_def.add_op(); // node attr | |||
op_def->add_input("invalidNodeName:0"); | |||
@@ -1525,7 +1525,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// graph invalid node control input | |||
{ | |||
proto::GraphDef graph_def; | |||
ge::proto::GraphDef graph_def; | |||
auto op_def = graph_def.add_op(); // node attr | |||
op_def->add_input("invalidNodeName:-1"); | |||
@@ -1538,7 +1538,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// graph invalid graph input | |||
{ | |||
proto::GraphDef graph_def; | |||
ge::proto::GraphDef graph_def; | |||
graph_def.add_input("invalidNodeName:0"); | |||
Buffer buffer(graph_def.ByteSizeLong()); | |||
@@ -1550,7 +1550,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// graph invalid graph output | |||
{ | |||
proto::GraphDef graph_def; | |||
ge::proto::GraphDef graph_def; | |||
graph_def.add_output("invalidNodeName:0"); | |||
Buffer buffer(graph_def.ByteSizeLong()); | |||
@@ -1562,7 +1562,7 @@ TEST(UTEST_ge_model_unserialize, test_invalid_input_output) { | |||
} | |||
// model invalid node input anchor | |||
{ | |||
proto::ModelDef model_def; | |||
ge::proto::ModelDef model_def; | |||
auto graph_def = model_def.add_graph(); | |||
auto node_def1 = graph_def->add_op(); // node attr | |||
node_def1->set_name("node1"); | |||
@@ -151,7 +151,7 @@ TEST_F(UtestGeNode, update_opdesc) { | |||
EXPECT_EQ(n1->UpdateOpDesc(desc_ptr2), GRAPH_SUCCESS); | |||
} | |||
/* | |||
TEST_F(UtestGeNode, add_link_from) { | |||
OpDescPtr desc_ptr = std::make_shared<OpDesc>("name", "type"); | |||
EXPECT_EQ(desc_ptr->AddInputDesc("x", GeTensorDesc(GeShape({1, 16, 16, 16}), FORMAT_NCHW)), GRAPH_SUCCESS); | |||
@@ -179,6 +179,7 @@ TEST_F(UtestGeNode, add_link_from) { | |||
NodePtr n8 = graph_ptr1->AddNode(desc_ptr1); | |||
EXPECT_EQ(n8->AddLinkFromForParse(n7), GRAPH_PARAM_INVALID); | |||
} | |||
*/ | |||
TEST_F(UtestGeNode, add_link_from_fail) { | |||
OpDescPtr desc_ptr = std::make_shared<OpDesc>("name1", "type1"); | |||
@@ -18,7 +18,7 @@ | |||
#include "common/formats/format_transfers/datatype_transfer.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/formats/formats.h" | |||
#include "common/fp16_t.h" | |||
@@ -17,9 +17,10 @@ | |||
#include <gtest/gtest.h> | |||
#include "common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -644,4 +645,4 @@ TEST_F(UTEST_FormatTransferNc1hwc0ToNchw, invalid_src_data_type) { | |||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransfer5dNhwc : public testing::Test { | |||
@@ -759,4 +762,4 @@ TEST_F(UtestFormatTransfer5dNhwc, invalid_src_dst_shape_relation) { | |||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferC1hwncoc0Hwcn : public testing::Test { | |||
@@ -13710,4 +13713,4 @@ TEST_F(UtestFormatTransferC1hwncoc0Hwcn, sixd_to_hwcn_fp32_success_gt_cube) { | |||
} | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -19,11 +19,14 @@ | |||
#include "common/formats/format_transfers/format_transfer_fractal_nz.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/formats/formats.h" | |||
#include "common/fp16_t.h" | |||
#include "time.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferNdFractNz : public testing::Test { | |||
@@ -9164,4 +9167,4 @@ TEST_F(UtestFormatTransferNdFractNz, invalid_src_dst_shape_relation) { | |||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -19,11 +19,14 @@ | |||
#include "common/formats/format_transfers/format_transfer_fractal_zz.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/formats/formats.h" | |||
#include "common/fp16_t.h" | |||
#include "time.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferNdFractZz : public testing::Test { | |||
@@ -7988,4 +7991,4 @@ TEST_F(UtestFormatTransferNdFractZz, invalid_src_dst_shape_relation) { | |||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_fracz_hwcn.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferFracZHwcn : public testing::Test { | |||
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_fracz_nchw.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferFraczNchw : public testing::Test { | |||
@@ -10486,4 +10489,4 @@ TEST_F(UtestFormatTransferFraczNchw, fp32_1) { | |||
} | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_fracz_nhwc.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferFraczNhwc : public testing::Test { | |||
@@ -5422,4 +5425,4 @@ TEST_F(UtestFormatTransferFraczNhwc, fracz_to_nhwc_fp32_success_gt_cube) { | |||
} | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferHwcnC1hwncoc0 : public testing::Test { | |||
@@ -13745,4 +13748,4 @@ TEST_F(UtestFormatTransferHwcnC1hwncoc0, hwcn_to_6d_fp32_success_gt_cube) { | |||
} | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -34460,4 +34463,4 @@ TEST_F(UtestFormatTransferHwcnFz, build_transfer_not_support) { | |||
EXPECT_EQ(transfer, nullptr); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -18,7 +18,10 @@ | |||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -16873,4 +16876,4 @@ TEST_F(UtestFormatTransferNchwFz, build_transfer_uint8) { | |||
EXPECT_NE(transfer, nullptr); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,12 @@ | |||
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/fp16_t.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
class UtestFormatTransferNhwc5d : public testing::Test { | |||
@@ -747,4 +750,4 @@ TEST_F(UtestFormatTransferNhwc5d, unsupport_dst_format) { | |||
EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,7 +18,10 @@ | |||
#include "common/formats/format_transfers/format_transfer_fractal_z.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -5351,4 +5354,4 @@ TEST_F(UtestFormatTransferNhwcFz, build_transfer_uint8) { | |||
EXPECT_NE(transfer, nullptr); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -18,9 +18,13 @@ | |||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
//#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/formats/utils/formats_trans_utils.h" | |||
#include "register/register_format_transfer.h" | |||
#include "framework/common/ge_inner_error_codes.h" | |||
namespace ge { | |||
namespace formats { | |||
@@ -78,4 +82,4 @@ TEST_F(UtestFormatTransfer, get_size_by_data_type) { | |||
EXPECT_EQ(DT_UNDEFINED, 26); | |||
} | |||
} // namespace formats | |||
} // namespace ge | |||
} // namespace ge |
@@ -189,18 +189,20 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
bool ExpectStreamEq(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) == expect; } | |||
bool ExpectStreamNe(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) != expect; } | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, vector<EngineConfPtr> &confs, | |||
Status AssignLogicalStreams(Graph2SubGraphInfoList &subgraph_map, vector<EngineConfPtr> &confs, | |||
std::map<std::string, int> &max_parallel_num, ComputeGraphPtr &whole_graph) { | |||
SchedulerConf scheduler_conf; | |||
if (confs.empty()) { | |||
for (const auto &subgraph : subgraphs) { | |||
EngineConfPtr conf = make_shared<EngineConf>(); | |||
conf->id = subgraph->GetEngineName(); | |||
if (conf->id == "ge_local") { | |||
conf->skip_assign_stream = true; | |||
conf->attach = true; | |||
} | |||
scheduler_conf.cal_engines[conf->id] = conf; | |||
for (const auto &subgraph_pair : subgraph_map) { | |||
for (const auto &sub_graph : subgraph_pair.second) { | |||
EngineConfPtr conf = make_shared<EngineConf>(); | |||
conf->id = sub_graph->GetEngineName(); | |||
if (conf->id == "ge_local") { | |||
conf->skip_assign_stream = true; | |||
conf->attach = true; | |||
} | |||
scheduler_conf.cal_engines[conf->id] = conf; | |||
} | |||
} | |||
} else { | |||
for (auto &conf : confs) { | |||
@@ -217,11 +219,21 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
scheduler_confs["scheduler"] = scheduler_conf; | |||
LogicalStreamAllocator allocator(scheduler_confs, max_parallel_num); | |||
int64_t stream_num = 0; | |||
return allocator.Assign(whole_graph, subgraphs, stream_num); | |||
return allocator.Assign(whole_graph, subgraph_map, stream_num); | |||
} | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, std::map<std::string, int> &max_parallel_num, | |||
vector<EngineConfPtr> &confs) { | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | |||
vector<EngineConfPtr> &confs, | |||
std::map<std::string, int> &max_parallel_num, | |||
ComputeGraphPtr &whole_graph) { | |||
Graph2SubGraphInfoList subgraph_map; | |||
subgraph_map[whole_graph] = subgraphs; | |||
return AssignLogicalStreams(subgraph_map, confs, max_parallel_num, whole_graph); | |||
} | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | |||
vector<EngineConfPtr>& confs, | |||
std::map<std::string, int> &max_parallel_num) { | |||
ComputeGraphPtr whole_graph = make_shared<ComputeGraph>("whole_graph"); | |||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num, whole_graph); | |||
} | |||
@@ -229,12 +241,12 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, | |||
vector<EngineConfPtr> confs = vector<EngineConfPtr>()) { | |||
std::map<std::string, int> max_parallel_num; | |||
return AssignLogicalStreams(subgraphs, max_parallel_num, confs); | |||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num); | |||
} | |||
Status AssignLogicalStreams(vector<SubGraphInfoPtr> subgraphs, std::map<std::string, int> &max_parallel_num) { | |||
vector<EngineConfPtr> confs; | |||
return AssignLogicalStreams(subgraphs, max_parallel_num, confs); | |||
return AssignLogicalStreams(subgraphs, confs, max_parallel_num); | |||
} | |||
/// typical case | |||
@@ -295,7 +307,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||
Status status = AssignLogicalStreams({const1, const2, get_next, genmask1, genmask2, domask, subgraph4, subgraph5, | |||
subgraph6, allreduce1, allreduce2, apply1, apply2}, | |||
max_parallel_num, confs); | |||
confs, max_parallel_num); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
EXPECT_EQ(GetStream(get_next), 0); | |||
@@ -652,7 +664,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||
vector<EngineConfPtr> confs = {conf1, conf2}; | |||
Status status = | |||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); | |||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, confs, max_parallel_num); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
EXPECT_EQ(GetStream(subgraph1), 0); | |||
EXPECT_EQ(GetStream(subgraph2), 0); | |||
@@ -695,7 +707,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | |||
vector<EngineConfPtr> confs = {conf1, conf2, conf3}; | |||
Status status = | |||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); | |||
AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5},confs, max_parallel_num); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
EXPECT_EQ(GetStream(subgraph1), 4); | |||
EXPECT_EQ(GetStream(subgraph2), 0); | |||
@@ -858,7 +870,7 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { | |||
std::map<std::string, int> max_parallel_num; | |||
LogicalStreamPass::Context context; | |||
context.next_stream = 5; | |||
context.hcom_parallel = true; | |||
context.enable_hcom_parallel = true; | |||
vector<LogicalStreamPass::SubgraphPtr> subgraphs; | |||
LogicalStreamPassPtr allreduce_pass = std::make_shared<AllReduceParallelPass>(); | |||
ret = allreduce_pass->Run(graph, subgraphs, context); | |||
@@ -152,7 +152,7 @@ TEST_F(UtestMemoryAssignerTest, MemoryBlock_Resize_RealSizeList_is_empty) { | |||
ge::OpDescPtr op_def_a = createOpWithWsSize("A", 6000); | |||
ge::NodePtr node_a = graph->AddNode(op_def_a); | |||
MemoryBlock* memory_block = new MemoryBlock(0); | |||
memory_block->Init(1, kOutput, node_a, 0); | |||
memory_block->Init(1, kOutput, node_a, 0, 1); | |||
memory_block->real_size_list_.clear(); | |||
memory_block->Resize(); | |||
@@ -165,7 +165,7 @@ namespace ge { | |||
class MockBlockMemAssigner : public BlockMemAssigner { | |||
public: | |||
explicit MockBlockMemAssigner(ge::ComputeGraphPtr compute_graph) : BlockMemAssigner(compute_graph){}; | |||
explicit MockBlockMemAssigner(ge::ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol, const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors) : BlockMemAssigner(compute_graph, anchor_to_symbol, symbol_to_anchors) {}; | |||
virtual ~MockBlockMemAssigner(){}; | |||
@@ -177,7 +177,10 @@ class MockBlockMemAssigner : public BlockMemAssigner { | |||
TEST_F(UtestMemoryAssignerTest, Mock_block_mem_assigner_failed) { | |||
ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>(""); | |||
make_graph(graph); | |||
MockBlockMemAssigner mock_assigner(graph); | |||
std::map<std::string, std::string> anchor_to_symbol; | |||
std::map<std::string, std::list<NodeIndexIO>> symbol_to_anchors; | |||
EXPECT_EQ(GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol), GRAPH_SUCCESS); | |||
MockBlockMemAssigner mock_assigner(graph, anchor_to_symbol, symbol_to_anchors); | |||
EXPECT_EQ(mock_assigner.Assign(), FAILED); | |||
} |
@@ -40,20 +40,22 @@ std::vector<void *> stub_get_output_addrs(const RuntimeParam &model_param, Const | |||
} | |||
TEST_F(UtestDataDumper, LoadDumpInfo_no_output_addrs_fail) { | |||
DataDumper data_dumper; | |||
RuntimeParam rts_param; | |||
DataDumper data_dumper(rts_param); | |||
data_dumper.SetModelName("test"); | |||
data_dumper.SetModelId(2333); | |||
data_dumper.SetMemory(std::move(RuntimeParam{})); | |||
std::shared_ptr<OpDesc> op_desc_1(new OpDesc()); | |||
op_desc_1->AddOutputDesc("test", GeTensorDesc()); | |||
data_dumper.SaveDumpTask(0, op_desc_1, 0); | |||
data_dumper.SaveDumpTask(0, 0, op_desc_1, 0); | |||
string dump_mode = "output"; | |||
data_dumper.dump_properties_.SetDumpMode(dump_mode); | |||
Status ret = data_dumper.LoadDumpInfo(); | |||
EXPECT_EQ(ret, PARAM_INVALID); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestDataDumper, UnloadDumpInfo_success) { | |||
DataDumper data_dumper; | |||
RuntimeParam rts_param; | |||
DataDumper data_dumper(rts_param); | |||
data_dumper.SetModelName("test"); | |||
data_dumper.SetModelId(2333); | |||
@@ -25,7 +25,6 @@ | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/model_serialize.h" | |||
#include "graph/load/new_model_manager/davinci_model.h" | |||
#include "graph/load/new_model_manager/model_output.h" | |||
#include "common/properties_manager.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include <cce/taskdown_api.h> | |||
@@ -38,7 +37,6 @@ | |||
#include "graph/load/new_model_manager/task_info/stream_switch_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/profiler_trace_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/memcpy_async_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/label_goto_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/label_set_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/kernel_ex_task_info.h" | |||
#include "graph/load/new_model_manager/task_info/kernel_task_info.h" | |||
@@ -113,7 +113,7 @@ class DModelListener : public ge::ModelListener { | |||
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; } | |||
}; | |||
shared_ptr<ge::ModelListener> UTEST_CALL_BACK_FUN(new DModelListener()); | |||
shared_ptr<ModelListener> UTEST_CALL_BACK_FUN(new DModelListener()); | |||
TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) { | |||
ModelManager mm; | |||
@@ -164,7 +164,7 @@ TEST_F(UtestModelManagerModelManager, case_load_model_encypt_type_unsupported) { | |||
delete[](uint8_t *) data.model_data; | |||
} | |||
shared_ptr<ge::ModelListener> LabelCallBack(new DModelListener()); | |||
shared_ptr<ModelListener> LabelCallBack(new DModelListener()); | |||
// test HandleCommand | |||
TEST_F(UtestModelManagerModelManager, command_success1) { | |||
@@ -306,6 +306,8 @@ TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_fail) { | |||
EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfo(2, input_shape, output_shape)); | |||
} | |||
/* | |||
// test GetInputOutputDescInfo fail | |||
TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) { | |||
ModelManager manager; | |||
@@ -314,6 +316,7 @@ TEST_F(UtestModelManagerModelManager, get_input_output_desc_info_zero_copy_fail) | |||
vector<InputOutputDescInfo> output_shape; | |||
EXPECT_EQ(ge::PARAM_INVALID, manager.GetInputOutputDescInfoForZeroCopy(2, input_shape, output_shape)); | |||
} | |||
*/ | |||
// test Stop | |||
TEST_F(UtestModelManagerModelManager, stop_fail) { | |||
@@ -324,7 +327,7 @@ TEST_F(UtestModelManagerModelManager, stop_fail) { | |||
// build input_data | |||
TEST_F(UtestModelManagerModelManager, check_data_len_success) { | |||
shared_ptr<ge::ModelListener> g_label_call_back(new DModelListener()); | |||
shared_ptr<ModelListener> g_label_call_back(new DModelListener()); | |||
DavinciModel model(0, g_label_call_back); | |||
ModelManager model_manager; | |||
ge::InputData input_data; | |||
@@ -134,6 +134,55 @@ class OmeTestOpUtils { | |||
} | |||
} | |||
static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model) { | |||
if (model == nullptr) { | |||
GELOGE(FAILED, "Model is null"); | |||
return FAILED; | |||
} | |||
ge_model = ge::MakeShared<ge::GeModel>(); | |||
GE_CHECK_NOTNULL(ge_model); | |||
ge_model->SetGraph(model->GetGraph()); | |||
ge_model->SetName(model->GetName()); | |||
ge_model->SetVersion(model->GetVersion()); | |||
ge_model->SetPlatformVersion(model->GetPlatformVersion()); | |||
ge_model->SetAttr(model->MutableAttrMap()); | |||
auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||
ge::Buffer weight; | |||
(void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); | |||
ge_model->SetWeight(weight); | |||
if (model->HasAttr(MODEL_ATTR_TASKS)) { | |||
ge::Buffer task_buffer; | |||
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, | |||
"Get bytes failed."); | |||
std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>(); | |||
GE_CHECK_NOTNULL(task); | |||
GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); | |||
GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); | |||
GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()), | |||
return INTERNAL_ERROR, "ReadProtoFromArray failed."); | |||
ge_model->SetModelTaskDef(task); | |||
} | |||
TBEKernelStore kernel_store; | |||
if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { | |||
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | |||
auto node_op_desc = n->GetOpDesc(); | |||
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | |||
TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||
GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | |||
kernel_store.AddTBEKernel(tbe_kernel); | |||
GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||
} | |||
} | |||
if (!kernel_store.Build()) { | |||
GELOGE(FAILED, "TBE Kernels store build failed!"); | |||
return FAILED; | |||
} | |||
ge_model->SetTBEKernelStore(kernel_store); | |||
return SUCCESS; | |||
} | |||
static void LoadStandardModelDataLocal(ge::ModelData &data) { | |||
static const std::string STANDARD_MODEL_DATA_PATH = | |||
"llt/framework/domi/ut/ome/test/data/standard_partition_model.txt"; | |||
@@ -151,7 +200,7 @@ class OmeTestOpUtils { | |||
ge::Model::Load((uint8_t *)data.model_data, data.model_len, *model_); | |||
GeModelPtr ge_model; | |||
ModelHelper::TransModelToGeModel(model_, ge_model); | |||
TransModelToGeModel(model_, ge_model); | |||
davinciModel.Assign(ge_model); | |||
if (data.model_data != nullptr) { | |||
@@ -178,7 +227,7 @@ class OmeTestOpUtils { | |||
model->SetGraph(graph); | |||
GeModelPtr ge_model; | |||
ModelHelper::TransModelToGeModel(model, ge_model); | |||
TransModelToGeModel(model, ge_model); | |||
davinciModel.Assign(ge_model); | |||
} | |||
@@ -24,9 +24,7 @@ | |||
#include "common/debug/memory_dumper.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include "graph/load/new_model_manager/davinci_model.h" | |||
#include "graph/load/new_model_manager/model_output.h" | |||
#include "graph/load/new_model_manager/model_utils.h" | |||
#include "graph/load/output/output.h" | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "new_op_test_utils.h" | |||
#include "proto/om.pb.h" | |||
@@ -49,7 +49,7 @@ class UtestTestPass : public BaseNodePass { | |||
for (const auto &node_name : iter->second) { | |||
auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name); | |||
GraphUtils::IsolateNode(del_node, {0}); | |||
AddNodeDeleted(del_node.get()); | |||
AddNodeDeleted(del_node); | |||
} | |||
} | |||
iter = names_to_add_repass_.find(node->GetName()); | |||
@@ -38,7 +38,7 @@ class UtestGraphPassesFlowCtrlPass : public testing::Test { | |||
EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->Init(session_version, session_id, device_id, job_id)); | |||
} | |||
void TearDown() { VarManagerPool::Instance().Destroy(); } | |||
void TearDown() { VarManagerPool::Instance().Destory(); } | |||
public: | |||
/// Set up a graph with the following network structure | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/add_kernel.h" | |||
#include "host_kernels/add_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -31,7 +31,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/broadcast_args_kernel.h" | |||
#include "host_kernels/broadcast_args_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -31,7 +31,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/broadcast_gradient_args_kernel.h" | |||
#include "host_kernels/broadcast_gradient_args_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/cast_kernel.h" | |||
#include "host_kernels/cast_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/concat_offset_kernel.h" | |||
#include "host_kernels/concat_offset_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/concat_v2_kernel.h" | |||
#include "host_kernels/concat_v2_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/dynamic_stitch_kernel.h" | |||
#include "host_kernels/dynamic_stitch_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -27,7 +27,7 @@ | |||
#include "common/op/attr_value_util.h" | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -31,7 +31,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/empty_kernel.h" | |||
#include "host_kernels/empty_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/expanddims_kernel.h" | |||
#include "host_kernels/expanddims_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/fill_kernel.h" | |||
#include "host_kernels/fill_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/fp16_t.h" | |||
@@ -18,13 +18,13 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/floordiv_kernel.h" | |||
#include "host_kernels/floordiv_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include "common/types.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -20,7 +20,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/floormod_kernel.h" | |||
#include "host_kernels/floormod_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -27,7 +27,7 @@ | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/operator.h" | |||
#include "graph/passes/constant_folding_pass.h" | |||
#include "graph/passes/folding_kernel/broadcast_args_kernel.h" | |||
#include "host_kernels/broadcast_args_kernel.h" | |||
#include "inc/kernel_factory.h" | |||
#include "shape_refiner.h" | |||
@@ -19,7 +19,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/gather_v2_kernel.h" | |||
#include "host_kernels/gather_v2_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -29,7 +29,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -984,4 +984,4 @@ TEST_F(UtestGraphPassesFoldingKernelGatherV2Kernel, AbnormalTest) { | |||
status = kernel->Compute(op_desc_ptr, input_7, outputs); | |||
EXPECT_NE(ge::SUCCESS, status); | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/greater_kernel.h" | |||
#include "host_kernels/greater_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -20,7 +20,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/maximum_kernel.h" | |||
#include "host_kernels/maximum_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -20,7 +20,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/mul_kernel.h" | |||
#include "host_kernels/mul_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/pack_kernel.h" | |||
#include "host_kernels/pack_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/permute_kernel.h" | |||
#include "host_kernels/permute_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -20,7 +20,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/range_kernel.h" | |||
#include "host_kernels/range_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/rank_kernel.h" | |||
#include "host_kernels/rank_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,14 +18,14 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/reduce_prod_kernel.h" | |||
#include "host_kernels/reduce_prod_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
#include "common/op/ge_op_utils.h" | |||
#include "common/types.h" | |||
#include "graph/passes/folding_kernel/concat_v2_kernel.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/concat_v2_kernel.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -16,12 +16,12 @@ | |||
#include <gtest/gtest.h> | |||
#include "graph/passes/folding_kernel/reformat_kernel.h" | |||
#include "host_kernels/reformat_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "common/types.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/reshape_kernel.h" | |||
#include "host_kernels/reshape_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/rsqrt_kernel.h" | |||
#include "host_kernels/rsqrt_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/shape_kernel.h" | |||
#include "host_kernels/shape_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/shape_n_kernel.h" | |||
#include "host_kernels/shape_n_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -19,7 +19,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/size_kernel.h" | |||
#include "host_kernels/size_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/slice_kernel.h" | |||
#include "host_kernels/slice_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/squeeze_kernel.h" | |||
#include "host_kernels/squeeze_kernel.h" | |||
#include "../graph_builder_utils.h" | |||
#include "common/debug/log.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/ssd_prior_box_kernel.h" | |||
#include "host_kernels/ssd_prior_box_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/strided_slice_kernel.h" | |||
#include "host_kernels/strided_slice_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -28,7 +28,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -18,7 +18,7 @@ | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/folding_kernel/sub_kernel.h" | |||
#include "host_kernels/sub_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -16,7 +16,7 @@ | |||
#include <gtest/gtest.h> | |||
#include "graph/passes/folding_kernel/transdata_kernel.h" | |||
#include "host_kernels/transdata_kernel.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -26,7 +26,7 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/passes/dimension_compute_pass.h" | |||
#include "graph/passes/folding_kernel/kernel_utils.h" | |||
#include "host_kernels/kernel_utils.h" | |||
#include "graph/types.h" | |||
#include "graph/utils/attr_utils.h" | |||
#include "graph/utils/graph_utils.h" | |||
@@ -155,7 +155,7 @@ TEST_F(UtestGraphPassesNetOutputPass, add_ctrl_edge_for_netout_from_leaf_success | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -200,7 +200,7 @@ TEST_F(UtestGraphPassesNetOutputPass, only_target_node_success) { | |||
std::vector<ge::NodePtr> target_nodes = {mul1, mul2}; | |||
compute_graph->SetGraphTargetNodesInfo(target_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -256,7 +256,7 @@ TEST_F(UtestGraphPassesNetOutputPass, targets_with_retval_success) { | |||
} | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -300,7 +300,7 @@ TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_no_duplicate_s | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -348,7 +348,7 @@ TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_duplicate_succ | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -398,7 +398,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_target_node_success) { | |||
compute_graph->SetGraphTargetNodesInfo(target_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -462,7 +462,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -518,7 +518,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -582,7 +582,7 @@ TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_targe | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
// check contain netoutput | |||
@@ -626,7 +626,7 @@ TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) { | |||
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}}; | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
@@ -641,7 +641,7 @@ TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) { | |||
compute_graph->SetGraphOutNodesInfo(output_nodes); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | |||
@@ -687,7 +687,7 @@ TEST_F(UtestGraphPassesNetOutputPass, retval_node_for_out_success) { | |||
} | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | |||
@@ -737,7 +737,7 @@ TEST_F(UtestGraphPassesNetOutputPass, check_order_and_const_flag_success) { | |||
GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0)); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | |||
@@ -775,7 +775,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_check_fail) { | |||
compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | |||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | |||
@@ -817,7 +817,7 @@ TEST_F(UtestGraphPassesNetOutputPass, retval_node_check_fail) { | |||
} | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | |||
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); | |||
@@ -832,7 +832,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_update_desc_check_fail) { | |||
EXPECT_NE(netout_node, nullptr); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::INTERNAL_ERROR); | |||
} | |||
@@ -852,7 +852,7 @@ TEST_F(UtestGraphPassesNetOutputPass, out_node_remove_check_fail) { | |||
EXPECT_EQ(mul1, nullptr); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) NetOutputPass); | |||
pass_managers.AddPass("", new (std::nothrow) NetOutputPass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} | |||
@@ -72,7 +72,7 @@ ComputeGraphPtr CreatePadGraph() { | |||
TEST_F(UtestGraphPassesPassManagerPass, all_pass_success) { | |||
PassManager manager; | |||
manager.AddPass(new SuccessGraphPass); | |||
manager.AddPass("", new SuccessGraphPass); | |||
EXPECT_EQ(manager.GraphPasses().size(), 1); | |||
ComputeGraphPtr graph = CreatePadGraph(); | |||
@@ -83,7 +83,7 @@ TEST_F(UtestGraphPassesPassManagerPass, all_pass_success) { | |||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_success) { | |||
ComputeGraphPtr graph = CreatePadGraph(); | |||
SuccessGraphPass pass; | |||
vector<GraphPass *> passes = {&pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
} | |||
@@ -91,7 +91,7 @@ TEST_F(UtestGraphPassesPassManagerPass, graph_pass_success) { | |||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_not_changed) { | |||
ComputeGraphPtr graph = CreatePadGraph(); | |||
NotChangedGraphPass pass; | |||
vector<GraphPass *> passes = {&pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(NOT_CHANGED, status); | |||
} | |||
@@ -99,7 +99,7 @@ TEST_F(UtestGraphPassesPassManagerPass, graph_pass_not_changed) { | |||
TEST_F(UtestGraphPassesPassManagerPass, graph_pass_error) { | |||
ComputeGraphPtr graph = CreatePadGraph(); | |||
ErrorGraphPass pass; | |||
vector<GraphPass *> passes = {&pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(FAILED, status); | |||
} |
@@ -67,7 +67,7 @@ TEST_F(UtestGraphPassesPrunePass, no_net_out_put_node) { | |||
uint64_t size_ori = graph->GetDirectNode().size(); | |||
PrunePass prune_pass; | |||
vector<GraphPass *> passes = {&prune_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(ge::SUCCESS, status); | |||
@@ -109,7 +109,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_only_one_path) { | |||
uint64_t size_ori = graph->GetDirectNode().size(); | |||
PrunePass prune_pass; | |||
vector<GraphPass *> passes = {&prune_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||
Status status = PassManager::Run(graph, passes); | |||
uint64_t size = graph->GetDirectNode().size(); | |||
@@ -250,7 +250,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_multi_path) { | |||
uint64_t size_ori = graph->GetDirectNode().size(); | |||
PrunePass prune_pass; | |||
vector<GraphPass *> passes = {&prune_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||
Status status = PassManager::Run(graph, passes); | |||
uint64_t size_after_proc = graph->GetDirectNode().size(); | |||
@@ -323,7 +323,7 @@ TEST_F(UtestGraphPassesPrunePass, multi_net_out_put_node_with_circle_net) { | |||
uint64_t size_ori = graph->GetDirectNode().size(); | |||
PrunePass prune_pass; | |||
vector<GraphPass *> passes = {&prune_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(ge::SUCCESS, status); | |||
uint64_t size_after_proc = graph->GetDirectNode().size(); | |||
@@ -464,7 +464,7 @@ TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_two_isolate_data_nod | |||
uint64_t size_ori = graph->GetDirectNode().size(); | |||
PrunePass prune_pass; | |||
vector<GraphPass *> passes = {&prune_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} }; | |||
Status status = PassManager::Run(graph, passes); | |||
uint64_t size = graph->GetDirectNode().size(); | |||
@@ -68,7 +68,7 @@ TEST_F(UtestResourcePairControlPass, resource_pair_control) { | |||
EXPECT_EQ(stackpop0->GetInControlNodes().size(), 0); | |||
ResourcePairAddControlPass add_pass; | |||
vector<GraphPass*> passes = {&add_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes = { {"", &add_pass} }; | |||
EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | |||
auto stackpush1 = graph->FindNode("stackpush1"); | |||
@@ -80,7 +80,7 @@ TEST_F(UtestResourcePairControlPass, resource_pair_control) { | |||
EXPECT_EQ(stackpop1->GetInControlNodes().at(0)->GetName(), "stackpush1"); | |||
ResourcePairRemoveControlPass remove_pass; | |||
passes = {&remove_pass}; | |||
passes = { {"", &remove_pass} }; | |||
EXPECT_EQ(PassManager::Run(graph, passes), SUCCESS); | |||
auto stackpush2 = graph->FindNode("stackpush1"); | |||
@@ -70,7 +70,7 @@ ge::ComputeGraphPtr CreateSaveGraph() { | |||
TEST_F(UtestGraphPassesSavePass, cover_run_success) { | |||
ge::ComputeGraphPtr compute_graph = CreateSaveGraph(); | |||
ge::PassManager pass_managers; | |||
pass_managers.AddPass(new (std::nothrow) SavePass); | |||
pass_managers.AddPass("", new (std::nothrow) SavePass); | |||
Status status = pass_managers.Run(compute_graph); | |||
EXPECT_EQ(status, ge::SUCCESS); | |||
} |
@@ -19,7 +19,6 @@ | |||
#include "omg/omg_inner_types.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/switch_op_pass.h" | |||
#include "common/debug/log.h" | |||
#include "common/debug/memory_dumper.h" | |||
@@ -27,7 +26,6 @@ | |||
#include "common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/graph.h" | |||
#include "graph/passes/control_op_attr_pass.h" | |||
#include "inc/pass_manager.h" | |||
#undef protected | |||
#undef private | |||
@@ -19,7 +19,6 @@ | |||
#include <string> | |||
#define private public | |||
#include "graph/passes/switch_pass.h" | |||
#include "common/ge_inner_error_codes.h" | |||
#include "inc/pass_manager.h" | |||
@@ -54,9 +54,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_reshape) { | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), reshape_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node = graph->FindNode("transpose1"); | |||
@@ -73,9 +75,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_squeeze) { | |||
GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), squeeze_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node = graph->FindNode("transpose1"); | |||
@@ -100,9 +104,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv) { | |||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||
@@ -128,9 +134,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, transpose_and_conv3) { | |||
NodePtr conv2_node = AddNode(graph, "conv2", CONVOLUTION); | |||
GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), conv2_node->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
NodePtr found_node0 = graph->FindNode("transpose1"); | |||
@@ -151,9 +159,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, cast_and_cast) { | |||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
} | |||
@@ -171,9 +181,11 @@ TEST_F(UtestGraphPassesUnusedAndIsolatedOpRemovePass, remove_parent_node) { | |||
GraphUtils::AddEdge(conv3_node->GetOutDataAnchor(0), transpose_node->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(transpose_node->GetOutDataAnchor(0), transpose_node_1->GetInDataAnchor(0)); | |||
ge::UnusedOpRemovePass unused_pass(FMK_TYPE_T); | |||
ge::UnusedOpRemovePass unused_pass(TENSORFLOW); | |||
ge::IsolatedOpRemovePass isolate_pass; | |||
vector<GraphPass *> passes = {&unused_pass, &isolate_pass}; | |||
std::vector<std::pair<string, GraphPass*>> passes; | |||
passes.emplace_back("", &isolate_pass); | |||
passes.emplace_back("", &unused_pass); | |||
Status status = PassManager::Run(graph, passes); | |||
EXPECT_EQ(SUCCESS, status); | |||
} |
@@ -39,7 +39,6 @@ | |||
#include "graph/manager/graph_var_manager.h" | |||
#include "graph_builder_utils.h" | |||
#include "cce/dnn_struct_base.hpp" | |||
#include "common/formats/format_transfers/format_transfer.h" | |||
#include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" | |||
#include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" | |||
#include "common/formats/format_transfers/datatype_transfer.h" | |||
@@ -38,10 +38,9 @@ class UtestSingleOpManager : public testing::Test { | |||
}; | |||
TEST_F(UtestSingleOpManager, test_get_resource) { | |||
uintptr_t resource_id = 0x1; | |||
rtStream_t stream = (rtStream_t)0x01; | |||
auto &instance = SingleOpManager::GetInstance(); | |||
ASSERT_EQ(instance.TryGetResource(resource_id), nullptr); | |||
ASSERT_NE(instance.GetResource(resource_id), nullptr); | |||
ASSERT_NE(instance.GetResource(0x01, stream), nullptr); | |||
} | |||
TEST_F(UtestSingleOpManager, test_get_op_from_model) { | |||
@@ -56,7 +55,7 @@ TEST_F(UtestSingleOpManager, test_get_op_from_model) { | |||
model_data.model_len = model_str.size(); | |||
ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | |||
ASSERT_EQ(instance.GetResource(resource_id)->GetOperator(model_data.model_data), nullptr); | |||
ASSERT_EQ(instance.GetResource(resource_id, stream)->GetOperator(model_data.model_data), nullptr); | |||
} | |||
TEST_F(UtestSingleOpManager, test_relesase_resource) { | |||
@@ -64,7 +63,7 @@ TEST_F(UtestSingleOpManager, test_relesase_resource) { | |||
auto &instance = SingleOpManager::GetInstance(); | |||
ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | |||
instance.GetResource(0x99); | |||
instance.GetResource(0x99, stream); | |||
ASSERT_EQ(instance.ReleaseResource(stream), SUCCESS); | |||
} | |||
@@ -92,4 +91,4 @@ TEST_F(UtestSingleOpManager, get_resource_failed) { | |||
auto &instance = SingleOpManager::GetInstance(); | |||
ASSERT_EQ(instance.GetOpFromModel("model", model_data, stream, &single_op), FAILED); | |||
} | |||
} |
@@ -97,7 +97,9 @@ TEST_F(UtestSingleOpModel, test_set_inputs_and_outputs) { | |||
model.output_offset_list_.push_back(0); | |||
model.output_sizes_.push_back(16); | |||
SingleOp single_op; | |||
std::mutex stream_mu_; | |||
rtStream_t stream_ = nullptr; | |||
SingleOp single_op(&stream_mu_, stream_); | |||
ASSERT_EQ(model.SetInputsAndOutputs(single_op), SUCCESS); | |||
} | |||
@@ -111,25 +113,29 @@ TEST_F(UtestSingleOpModel, test_build_kernel_task) { | |||
model.output_offset_list_.push_back(0); | |||
model.output_sizes_.push_back(16); | |||
auto graph = make_shared<ComputeGraph>("graph"); | |||
auto op_desc = make_shared<OpDesc>("AddN", "AddN"); | |||
vector<int64_t> shape{16, 16}; | |||
GeShape ge_shape(shape); | |||
GeTensorDesc desc(ge_shape); | |||
op_desc->AddInputDesc(desc); | |||
op_desc->AddOutputDesc(desc); | |||
auto node = graph->AddNode(op_desc); | |||
std::mutex stream_mu_; | |||
rtStream_t stream_ = nullptr; | |||
SingleOp single_op(&stream_mu_, stream_); | |||
SingleOp single_op; | |||
domi::KernelDef kernel_def; | |||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::CCE_AI_CORE); | |||
OpTask *task = nullptr; | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), UNSUPPORTED); | |||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::TE); | |||
TbeOpTask *task = nullptr; | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), UNSUPPORTED); | |||
kernel_def.mutable_context()->set_kernel_type(cce::ccKernelType::TE); | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), INTERNAL_ERROR); | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), INTERNAL_ERROR); | |||
model.op_list_[0] = op_desc; | |||
model.op_list_[0] = node; | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, single_op, &task), PARAM_INVALID); | |||
ASSERT_EQ(model.BuildKernelTask(kernel_def, &task), PARAM_INVALID); | |||
ASSERT_EQ(task, nullptr); | |||
delete task; | |||
} | |||
@@ -145,18 +151,22 @@ TEST_F(UtestSingleOpModel, test_parse_arg_table) { | |||
SingleOpModel op_model("model", model_data_str.c_str(), model_data_str.size()); | |||
TbeOpTask task; | |||
SingleOp op; | |||
OpDescPtr op_desc; | |||
std::mutex stream_mu_; | |||
rtStream_t stream_ = nullptr; | |||
SingleOp op(&stream_mu_, stream_); | |||
op.arg_table_.resize(2); | |||
auto *args = new uintptr_t[2]; | |||
args[0] = 0x100000; | |||
args[1] = 0x200000; | |||
task.SetKernelArgs(args, 16, 1); | |||
auto args = std::unique_ptr<uint8_t[]>(new uint8_t[sizeof(uintptr_t) * 2]); | |||
auto *arg_base = (uintptr_t*)args.get(); | |||
arg_base[0] = 0x100000; | |||
arg_base[1] = 0x200000; | |||
task.SetKernelArgs(std::move(args), 16, 1, op_desc); | |||
op_model.model_params_.addr_mapping_[0x100000] = 1; | |||
op_model.ParseArgTable(&task, op); | |||
ASSERT_EQ(op.arg_table_[0].size(), 0); | |||
ASSERT_EQ(op.arg_table_[1].size(), 1); | |||
ASSERT_EQ(op.arg_table_[1].front(), &args[0]); | |||
ASSERT_EQ(op.arg_table_[1].front(), &arg_base[0]); | |||
} |
@@ -38,8 +38,9 @@ class UtestStreamResource : public testing::Test { | |||
rtStream_t stream; | |||
}; | |||
/* | |||
TEST_F(UtestStreamResource, test_cache_op) { | |||
StreamResource res; | |||
StreamResource res((uintptr_t)1); | |||
auto *op = new SingleOp(); | |||
string stub_name = "stubFunc"; | |||
const void *key = stub_name.c_str(); | |||
@@ -47,31 +48,34 @@ TEST_F(UtestStreamResource, test_cache_op) { | |||
res.CacheOperator(key, op); | |||
ASSERT_NE(res.GetOperator(key), nullptr); | |||
} | |||
*/ | |||
TEST_F(UtestStreamResource, test_malloc_memory) { | |||
StreamResource res; | |||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||
ASSERT_NE(res.MallocMemory(100), nullptr); | |||
StreamResource res((uintptr_t)1); | |||
string purpose("test"); | |||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||
ASSERT_NE(res.MallocMemory(purpose, 100), nullptr); | |||
} | |||
TEST_F(UtestStreamResource, test_do_malloc_memory) { | |||
size_t max_allocated = 0; | |||
vector<uint8_t *> allocated; | |||
string purpose("test"); | |||
uint8_t *ret = StreamResource::DoMallocMemory(100, max_allocated, allocated); | |||
StreamResource res((uintptr_t)1); | |||
uint8_t *ret = res.DoMallocMemory(purpose, 100, max_allocated, allocated); | |||
ASSERT_EQ(allocated.size(), 1); | |||
ASSERT_NE(allocated.back(), nullptr); | |||
ASSERT_EQ(max_allocated, 100); | |||
StreamResource::DoMallocMemory(50, max_allocated, allocated); | |||
StreamResource::DoMallocMemory(99, max_allocated, allocated); | |||
StreamResource::DoMallocMemory(100, max_allocated, allocated); | |||
res.DoMallocMemory(purpose, 50, max_allocated, allocated); | |||
res.DoMallocMemory(purpose, 99, max_allocated, allocated); | |||
res.DoMallocMemory(purpose, 100, max_allocated, allocated); | |||
ASSERT_EQ(allocated.size(), 1); | |||
ASSERT_EQ(max_allocated, 100); | |||
StreamResource::DoMallocMemory(101, max_allocated, allocated); | |||
res.DoMallocMemory(purpose, 101, max_allocated, allocated); | |||
ASSERT_EQ(allocated.size(), 2); | |||
ASSERT_EQ(max_allocated, 101); | |||