@@ -42,12 +42,12 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) | |||||
include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) | include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) | ||||
include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) | include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) | ||||
include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) | include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) | ||||
include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) | |||||
set(CMAKE_SKIP_RPATH TRUE) | set(CMAKE_SKIP_RPATH TRUE) | ||||
# for CPU/GPU mode, find c_sec and slog from local prebuild | # for CPU/GPU mode, find c_sec and slog from local prebuild | ||||
if(NOT ENABLE_D AND NOT GE_ONLY) | if(NOT ENABLE_D AND NOT GE_ONLY) | ||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) | set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) | ||||
find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH}) | |||||
find_library(slog libslog.so ${GE_PREBUILD_PATH}) | find_library(slog libslog.so ${GE_PREBUILD_PATH}) | ||||
# if D_LINK_PATH is set in environment variables, search libraries in given path | # if D_LINK_PATH is set in environment variables, search libraries in given path | ||||
elseif(DEFINED ENV{D_LINK_PATH}) | elseif(DEFINED ENV{D_LINK_PATH}) | ||||
@@ -64,6 +64,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) | |||||
message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") | ||||
endif() | endif() | ||||
set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) | ||||
find_library(c_sec libc_sec.so ${GE_LIB_PATH}) | |||||
find_library(slog libslog.so ${GE_LIB_PATH}) | find_library(slog libslog.so ${GE_LIB_PATH}) | ||||
find_library(mmpa libmmpa.so ${GE_LIB_PATH}) | find_library(mmpa libmmpa.so ${GE_LIB_PATH}) | ||||
find_library(runtime libruntime.so ${GE_LIB_PATH}) | find_library(runtime libruntime.so ${GE_LIB_PATH}) | ||||
@@ -80,6 +81,7 @@ else() | |||||
endif() | endif() | ||||
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64/common) | set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64/common) | ||||
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) | ||||
find_library(c_sec libc_sec.so ${ASCEND_DRIVER_DIR}) | |||||
find_library(slog libslog.so ${ASCEND_DRIVER_DIR}) | find_library(slog libslog.so ${ASCEND_DRIVER_DIR}) | ||||
find_library(mmpa libmmpa.so ${ASCEND_DRIVER_DIR}) | find_library(mmpa libmmpa.so ${ASCEND_DRIVER_DIR}) | ||||
find_library(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) | find_library(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) | ||||
@@ -128,7 +130,7 @@ elseif(GE_ONLY) | |||||
add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) | add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) | ||||
endif() | endif() | ||||
if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) | |||||
add_subdirectory(tests) | |||||
endif() | |||||
# if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) | |||||
# add_subdirectory(tests) | |||||
# endif() | |||||
@@ -41,7 +41,7 @@ checkopts() | |||||
{ | { | ||||
VERBOSE="" | VERBOSE="" | ||||
THREAD_NUM=8 | THREAD_NUM=8 | ||||
ENABLE_GE_UT_ONLY_COMPILE="off" | |||||
# ENABLE_GE_UT_ONLY_COMPILE="off" | |||||
ENABLE_GE_UT="off" | ENABLE_GE_UT="off" | ||||
ENABLE_GE_ST="off" | ENABLE_GE_ST="off" | ||||
ENABLE_GE_COV="off" | ENABLE_GE_COV="off" | ||||
@@ -52,7 +52,7 @@ checkopts() | |||||
OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | ||||
case "${opt}" in | case "${opt}" in | ||||
u) | u) | ||||
ENABLE_GE_UT_ONLY_COMPILE="on" | |||||
# ENABLE_GE_UT_ONLY_COMPILE="on" | |||||
ENABLE_GE_UT="on" | ENABLE_GE_UT="on" | ||||
GE_ONLY="off" | GE_ONLY="off" | ||||
;; | ;; | ||||
@@ -137,39 +137,39 @@ find ${OUTPUT_PATH} -name "*.so*" -print0 | xargs -0 chmod 500 | |||||
echo "---------------- GraphEngine output generated ----------------" | echo "---------------- GraphEngine output generated ----------------" | ||||
if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
cp ${BUILD_PATH}/graphengine/tests/st/st_resnet50_train ${OUTPUT_PATH} | |||||
fi | |||||
if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
cp ${BUILD_PATH}/graphengine/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} | |||||
cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} | |||||
cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} | |||||
cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} | |||||
cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} | |||||
if [[ "X${ENABLE_GE_UT_ONLY_COMPILE}" != "Xon" ]]; then | |||||
export LD_LIBRARY_PATH=${D_LINK_PATH}/x86_64/:${BUILD_PATH}../third_party/prebuild/x86_64/:${BUILD_PATH}/graphengine/:/usr/local/HiAI/driver/lib64:/usr/local/HiAI/runtime/lib64:${LD_LIBRARY_PATH} | |||||
echo ${LD_LIBRARY_PATH} | |||||
${OUTPUT_PATH}/ut_libgraph && | |||||
${OUTPUT_PATH}/ut_libge_multiparts_utest && | |||||
${OUTPUT_PATH}/ut_libge_distinct_load_utest && | |||||
${OUTPUT_PATH}/ut_libge_others_utest && | |||||
${OUTPUT_PATH}/ut_libge_kernel_utest | |||||
if [[ "$?" -ne 0 ]]; then | |||||
echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
exit 1; | |||||
fi | |||||
fi | |||||
if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
echo "Generating coverage statistics, please wait..." | |||||
cd ${BASEPATH} | |||||
rm -rf ${BASEPATH}/cov | |||||
mkdir ${BASEPATH}/cov | |||||
gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html | |||||
fi | |||||
fi | |||||
# if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
# cp ${BUILD_PATH}/graphengine/tests/st/st_resnet50_train ${OUTPUT_PATH} | |||||
# fi | |||||
# if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
# cp ${BUILD_PATH}/graphengine/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} | |||||
# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} | |||||
# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} | |||||
# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} | |||||
# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} | |||||
# if [[ "X${ENABLE_GE_UT_ONLY_COMPILE}" != "Xon" ]]; then | |||||
# export LD_LIBRARY_PATH=${D_LINK_PATH}/x86_64/:${BUILD_PATH}../third_party/prebuild/x86_64/:${BUILD_PATH}/graphengine/:/usr/local/HiAI/driver/lib64:/usr/local/HiAI/runtime/lib64:${LD_LIBRARY_PATH} | |||||
# echo ${LD_LIBRARY_PATH} | |||||
# ${OUTPUT_PATH}/ut_libgraph && | |||||
# ${OUTPUT_PATH}/ut_libge_multiparts_utest && | |||||
# ${OUTPUT_PATH}/ut_libge_distinct_load_utest && | |||||
# ${OUTPUT_PATH}/ut_libge_others_utest && | |||||
# ${OUTPUT_PATH}/ut_libge_kernel_utest | |||||
# if [[ "$?" -ne 0 ]]; then | |||||
# echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
# exit 1; | |||||
# fi | |||||
# fi | |||||
# if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then | |||||
# echo "Generating coverage statistics, please wait..." | |||||
# cd ${BASEPATH} | |||||
# rm -rf ${BASEPATH}/cov | |||||
# mkdir ${BASEPATH}/cov | |||||
# gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html | |||||
# fi | |||||
# fi | |||||
# generate output package in tar form, including ut/st libraries/executables | # generate output package in tar form, including ut/st libraries/executables | ||||
cd ${BASEPATH} | cd ${BASEPATH} | ||||
@@ -1,11 +0,0 @@ | |||||
# Declare the securec third-party package through graphengine_add_pkg, passing | |||||
# its version, archive URL, MD5, library name (c_sec) and a patch file. | |||||
# NOTE(review): securec_INC and the securec::c_sec target used below are | |||||
# presumably defined by graphengine_add_pkg -- confirm against its definition. | |||||
graphengine_add_pkg(securec | |||||
VER 1.1.10 | |||||
URL https://gitee.com/openeuler/bounds_checking_function/repository/archive/v1.1.10.tar.gz | |||||
MD5 0782dd2351fde6920d31a599b23d8c91 | |||||
LIBS c_sec | |||||
PATCHES ${GE_SOURCE_DIR}/third_party/patch/securec/securec.patch001 | |||||
CMAKE_OPTION " " | |||||
) | |||||
# Add the securec include directory at directory scope. | |||||
include_directories(${securec_INC}) | |||||
# Copy libc_sec.so into ${CMAKE_SOURCE_DIR}/build/graphengine (file(COPY) runs at configure time). | |||||
file(COPY ${securec_INC}/../lib/libc_sec.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) | |||||
# Expose securec::c_sec under the project-namespaced alias graphengine::securec. | |||||
add_library(graphengine::securec ALIAS securec::c_sec) |
@@ -1,36 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef COMPRESS_H | |||||
#define COMPRESS_H | |||||
#include <uchar.h> | |||||
// Status codes returned by CompressWeights: RET_SUCCESS is 0, RET_ERROR is -1. | |||||
enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; | |||||
// Parameters passed to CompressWeights to describe one compression request. | |||||
struct CompressConfig { | |||||
size_t inputSize; // length of the data to compress | |||||
size_t engineNum; // number of decompress engines | |||||
size_t maxRatio; // size of a basic compression block; only 64 is supported now (8x: 64, 4x: 32) | |||||
size_t channel; // number of L2 or DDR channels, used for load balancing | |||||
size_t fractalSize; // size of a compressing block | |||||
bool isTight; // whether to compose the compressed data tightly | |||||
}; | |||||
// Compress weight data according to compressConfig and report a CmpStatus. | |||||
// NOTE(review): presumably reads from input, writes index data to indexs and | |||||
// compressed bytes to output, and stores the resulting length in | |||||
// compressedLength -- only the declaration is visible here, confirm against | |||||
// the implementation. | |||||
CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||||
size_t& compressedLength); | |||||
#endif // COMPRESS_H |
@@ -40,8 +40,6 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | ||||
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||||
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||||
// Hccl flag: ge.exec.hcclFlag = 1 means load the plugin for opskernel; otherwise ge.exec.hcclFlag = 0 | // Hccl flag: ge.exec.hcclFlag = 1 means load the plugin for opskernel; otherwise ge.exec.hcclFlag = 0 | ||||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
@@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
static std::unique_ptr<InferenceContext> Create(); | static std::unique_ptr<InferenceContext> Create(); | ||||
private: | private: | ||||
explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
std::shared_ptr<InferenceContextImpl> inference_context_impl_; | std::shared_ptr<InferenceContextImpl> inference_context_impl_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -116,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
namespace ge { | namespace ge { | ||||
using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
// Abstract interface for a host CPU operator implementation. | |||||
// Subclasses must override Compute; the class cannot be instantiated directly. | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||||
public: | |||||
HostCpuOp() = default; | |||||
virtual ~HostCpuOp() = default; | |||||
// Pure virtual entry point: receives the operator, a name-to-tensor map of | |||||
// inputs and a mutable name-to-tensor map of outputs; returns a graphStatus. | |||||
virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||||
std::map<std::string, Tensor> &outputs) = 0; | |||||
}; | |||||
// Helper whose constructor takes an op type name and a factory function that | |||||
// returns a HostCpuOp pointer; REGISTER_HOST_CPU_OP_BUILDER creates static | |||||
// instances of it. | |||||
// NOTE(review): presumably the constructor records the factory in a registry | |||||
// keyed by op_type -- the definition is not visible here, confirm. | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||||
public: | |||||
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||||
}; | |||||
// REGISTER_HOST_CPU_OP_BUILDER(name, op) defines a static HostCpuOpRegistrar | |||||
// named register_host_cpu_op<N>, constructed from `name` and a lambda that | |||||
// returns `new (std::nothrow) op()`. | |||||
// The _UNIQ_HELPER indirection lets __COUNTER__ expand to a number before it | |||||
// is pasted onto the variable name with ##, so each use gets a distinct name. | |||||
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||||
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||||
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -51,24 +51,24 @@ inline pid_t GetTid() { | |||||
return tid; | return tid; | ||||
} | } | ||||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||||
#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() | |||||
#define GE_TIMESTAMP_END(stage, stage_name) \ | #define GE_TIMESTAMP_END(stage, stage_name) \ | ||||
do { \ | do { \ | ||||
uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||||
uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ | |||||
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | ||||
(endUsec_##stage - startUsec_##stage)); \ | (endUsec_##stage - startUsec_##stage)); \ | ||||
} while (0); | } while (0); | ||||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||||
uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||||
uint64_t call_num_of##stage = 0; \ | |||||
#define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||||
uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ | |||||
uint64_t call_num_of##stage = 0; \ | |||||
uint64_t time_of##stage = 0 | uint64_t time_of##stage = 0 | ||||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||||
#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) | |||||
#define GE_TIMESTAMP_ADD(stage) \ | |||||
time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||||
#define GE_TIMESTAMP_ADD(stage) \ | |||||
time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ | |||||
call_num_of##stage++ | call_num_of##stage++ | ||||
#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | ||||
@@ -103,17 +103,17 @@ using cce::ccStatus_t; | |||||
} while (0); | } while (0); | ||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
DOMI_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
std::string msg; \ | |||||
(void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ | |||||
(void)msg.append( \ | |||||
domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
DOMI_LOGE("%s", msg.c_str()); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
// If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
@@ -152,6 +152,7 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||||
GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | ||||
GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | ||||
@@ -20,7 +20,7 @@ | |||||
#include <gflags/gflags.h> | #include <gflags/gflags.h> | ||||
#include <string> | #include <string> | ||||
namespace ge { | |||||
namespace domi { | |||||
class GflagsUtils { | class GflagsUtils { | ||||
public: | public: | ||||
static bool IsSetCommandTrue(const char *name) { | static bool IsSetCommandTrue(const char *name) { | ||||
@@ -66,6 +66,6 @@ class GflagsUtils { | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ |
@@ -26,7 +26,7 @@ | |||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
namespace ge { | |||||
namespace domi { | |||||
class ModelHelper { | class ModelHelper { | ||||
public: | public: | ||||
ModelHelper() = default; | ModelHelper() = default; | ||||
@@ -65,8 +65,9 @@ class ModelHelper { | |||||
Status LoadTask(OmFileLoadHelper& om_load_helper); | Status LoadTask(OmFileLoadHelper& om_load_helper); | ||||
Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | ||||
Status ReleaseLocalModelData() noexcept; | Status ReleaseLocalModelData() noexcept; | ||||
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | ||||
const uint8_t* data, size_t size); | const uint8_t* data, size_t size); | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ |
@@ -26,10 +26,8 @@ | |||||
#include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
using ProcParam = struct PROC_PARAM; | using ProcParam = struct PROC_PARAM; | ||||
using std::string; | |||||
using std::vector; | |||||
namespace ge { | |||||
namespace domi { | |||||
struct ModelPartition { | struct ModelPartition { | ||||
ModelPartitionType type; | ModelPartitionType type; | ||||
uint8_t *data = 0; | uint8_t *data = 0; | ||||
@@ -90,5 +88,5 @@ class OmFileSaveHelper { | |||||
ModelFileHeader model_header_; | ModelFileHeader model_header_; | ||||
OmFileContext context_; | OmFileContext context_; | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ |
@@ -30,7 +30,7 @@ | |||||
using std::vector; | using std::vector; | ||||
namespace ge { | |||||
namespace domi { | |||||
// Size of RC memory alignment, 2M | // Size of RC memory alignment, 2M | ||||
constexpr size_t ALIGN_SIZE = 2097152; | constexpr size_t ALIGN_SIZE = 2097152; | ||||
@@ -118,6 +118,6 @@ class L2CacheOptimize { | |||||
bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | ||||
bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ |
@@ -21,17 +21,11 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <string> | #include <string> | ||||
#include "common/op/attr_define.h" | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
using domi::AttrDef; | |||||
using domi::AttrDef_ListValue; | |||||
using domi::ModelDef; | |||||
using domi::NamedAttrs; | |||||
using domi::OpDef; | |||||
namespace ge { | |||||
namespace domi { | |||||
using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | ||||
using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | ||||
@@ -156,6 +150,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const | |||||
bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | ||||
bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | ||||
bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ |
@@ -62,8 +62,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||||
class OpUtils { | class OpUtils { | ||||
public: | public: | ||||
/// | /// | ||||
@@ -22,7 +22,7 @@ | |||||
#include <math.h> | #include <math.h> | ||||
#include <stdint.h> | #include <stdint.h> | ||||
namespace ge { | |||||
namespace domi { | |||||
// general | // general | ||||
const float DEFAULT_ALPHA_VALUE = 1.0; | const float DEFAULT_ALPHA_VALUE = 1.0; | ||||
const float DEFAULT_BETA_VALUE = 0.0; | const float DEFAULT_BETA_VALUE = 0.0; | ||||
@@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||||
// Shufflechannel | // Shufflechannel | ||||
const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ |
@@ -20,7 +20,7 @@ | |||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
namespace ge { | |||||
namespace domi { | |||||
class OpTypeContainer { | class OpTypeContainer { | ||||
public: | public: | ||||
static OpTypeContainer *Instance() { | static OpTypeContainer *Instance() { | ||||
@@ -57,6 +57,6 @@ class OpTypeRegistrar { | |||||
const OpTypeRegistrar g_##var_name##_reg(str_name); | const OpTypeRegistrar g_##var_name##_reg(str_name); | ||||
#define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | #define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ |
@@ -25,10 +25,10 @@ | |||||
/// MAKE_GUARD([&] { Release Resource 1 }) | /// MAKE_GUARD([&] { Release Resource 1 }) | ||||
/// Acquire Resource 2 | /// Acquire Resource 2 | ||||
// MAKE_GUARD([&] { Release Resource 2 }) | // MAKE_GUARD([&] { Release Resource 2 }) | ||||
#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||||
#define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) | |||||
#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | ||||
namespace ge { | |||||
namespace domi { | |||||
class ScopeGuard { | class ScopeGuard { | ||||
public: | public: | ||||
// Noncopyable | // Noncopyable | ||||
@@ -55,6 +55,6 @@ class ScopeGuard { | |||||
std::function<void()> on_exit_scope_; | std::function<void()> on_exit_scope_; | ||||
bool dismissed_; | bool dismissed_; | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ | #endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ |
@@ -25,7 +25,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
namespace ge { | |||||
namespace domi { | |||||
class StringUtils { | class StringUtils { | ||||
public: | public: | ||||
static std::string &Ltrim(std::string &s) { | static std::string &Ltrim(std::string &s) { | ||||
@@ -151,6 +151,6 @@ class StringUtils { | |||||
return ret > 0 ? buffer : ""; | return ret > 0 ? buffer : ""; | ||||
} | } | ||||
}; | }; | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ |
@@ -26,7 +26,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "framework/common/fmk_types.h" | #include "framework/common/fmk_types.h" | ||||
#include "framework/common/op_types.h" | #include "framework/common/op_types.h" | ||||
@@ -47,7 +46,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | ||||
} // namespace ge | |||||
namespace domi { | |||||
// Supported public properties name | // Supported public properties name | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | ||||
@@ -67,6 +68,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | ||||
/// @brief Data structure definition related to task sinking | |||||
/// Build mode | |||||
enum BuildMode { | |||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 fusion disabled) | |||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all fusion functions disabled) | |||||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 fusion all enabled) | |||||
}; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | ||||
@@ -333,7 +342,7 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | |||||
REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | ||||
REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | ||||
// ANN dedicated operator | |||||
/***************ANN dedicated operator *************************/ | |||||
REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | ||||
@@ -350,7 +359,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); | |||||
REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | ||||
REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | ||||
// Training operator | |||||
/********************Training operator ***********************/ | |||||
REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | ||||
REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | ||||
REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | ||||
@@ -434,7 +443,6 @@ REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | |||||
REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
@@ -442,7 +450,6 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | ||||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
@@ -821,6 +828,9 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | |||||
// number of partitions in the current model | // number of partitions in the current model | ||||
static constexpr uint32_t PARTITION_SIZE = 4; | static constexpr uint32_t PARTITION_SIZE = 4; | ||||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) \ | |||||
(sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) | |||||
enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | ||||
struct ModelPartitionMemInfo { | struct ModelPartitionMemInfo { | ||||
@@ -834,8 +844,6 @@ struct ModelPartitionTable { | |||||
ModelPartitionMemInfo partition[0]; | ModelPartitionMemInfo partition[0]; | ||||
}; | }; | ||||
#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) | |||||
static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | ||||
// Filter format | // Filter format | ||||
@@ -967,8 +975,8 @@ typedef enum tagDomiNanPropagation { | |||||
// mode of cropandresize | // mode of cropandresize | ||||
typedef enum tagDomiCropAndResizeMode { | typedef enum tagDomiCropAndResizeMode { | ||||
DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear | |||||
DOMI_RESIZE_METHOD_NEAREST, // resize nearest | |||||
DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ | |||||
DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ | |||||
DOMI_RESIZE_RESERVED | DOMI_RESIZE_RESERVED | ||||
} domiCropAndResizeMode_t; | } domiCropAndResizeMode_t; | ||||
@@ -1055,15 +1063,6 @@ struct BasicInfo { | |||||
uint32_t total_size; // total memory size | uint32_t total_size; // total memory size | ||||
}; | }; | ||||
#pragma pack() // Cancels single-byte alignment | #pragma pack() // Cancels single-byte alignment | ||||
} // namespace ge | |||||
namespace domi { | |||||
/// @brief Data structure definition related to task sinking | |||||
enum BuildMode { | |||||
GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||||
GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||||
GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||||
}; | |||||
} // namespace domi | } // namespace domi | ||||
#endif // INC_FRAMEWORK_COMMON_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_TYPES_H_ |
@@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; | |||||
*/ | */ | ||||
#define CEIL(N, n) (((N) + (n)-1) / (n)) | #define CEIL(N, n) (((N) + (n)-1) / (n)) | ||||
namespace ge { | |||||
namespace domi { | |||||
using google::protobuf::Message; | using google::protobuf::Message; | ||||
/// | /// | ||||
@@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); | |||||
/// @param [out] result | /// @param [out] result | ||||
/// | /// | ||||
bool ValidateStr(const std::string &filePath, const std::string &mode); | bool ValidateStr(const std::string &filePath, const std::string &mode); | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_COMMON_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_UTIL_H_ |
@@ -28,16 +28,12 @@ | |||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
using domi::DOMI_TENSOR_ND; | |||||
using domi::DOMI_TENSOR_RESERVED; | |||||
using domi::domiTensorFormat_t; | |||||
using domi::FrameworkType; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::unordered_map; | using std::unordered_map; | ||||
using std::vector; | using std::vector; | ||||
namespace ge { | |||||
namespace domi { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief run model | * @brief run model | ||||
@@ -97,7 +93,7 @@ struct OmgContext { | |||||
std::string ddk_version; | std::string ddk_version; | ||||
// preferential format used by the entire network | // preferential format used by the entire network | ||||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||||
FrameworkType type = FMK_TYPE_RESERVED; | |||||
RunMode run_mode = ONLY_PRE_CHECK; | RunMode run_mode = ONLY_PRE_CHECK; | ||||
bool train_flag = false; | bool train_flag = false; | ||||
// whether to use FP16 high precision | // whether to use FP16 high precision | ||||
@@ -106,25 +102,23 @@ struct OmgContext { | |||||
std::string output_type; | std::string output_type; | ||||
// Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | ||||
// network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||||
// is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||||
// FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||||
// operator needs to be deleted for the convolution kernel. | |||||
// network require special processing based on the specific network. | |||||
// e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||||
// fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||||
// convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||||
std::string net_name; | std::string net_name; | ||||
// Whether to use dynamic batch size or dynamic image size | // Whether to use dynamic batch size or dynamic image size | ||||
bool is_dynamic_input = false; | bool is_dynamic_input = false; | ||||
std::string dynamic_batch_size; | std::string dynamic_batch_size; | ||||
std::string dynamic_image_size; | std::string dynamic_image_size; | ||||
}; | }; | ||||
} // namespace ge | |||||
namespace domi { | |||||
/** | /** | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
* @brief get OMG context | * @brief get OMG context | ||||
* @return OmgContext context | * @return OmgContext context | ||||
*/ | */ | ||||
ge::OmgContext &GetContext(); | |||||
OmgContext &GetContext(); | |||||
struct TEBinInfo { | struct TEBinInfo { | ||||
// It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | ||||
@@ -26,7 +26,7 @@ | |||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
namespace ge { | |||||
namespace domi { | |||||
class PlatformVersionManager { | class PlatformVersionManager { | ||||
public: | public: | ||||
PlatformVersionManager() = delete; | PlatformVersionManager() = delete; | ||||
@@ -40,6 +40,6 @@ class PlatformVersionManager { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
}; // class PlatformManager | }; // class PlatformManager | ||||
} // namespace ge | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_OMG_VERSION_H_ | #endif // INC_FRAMEWORK_OMG_VERSION_H_ |
@@ -58,8 +58,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | ||||
@@ -76,7 +74,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||||
// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||||
// ATTR_NAME_WEIGHTS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | ||||
@@ -124,13 +123,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | ||||
@@ -148,24 +140,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
// to be deleted | // to be deleted | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | ||||
@@ -178,15 +158,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | ||||
// _Arg | // _Arg | ||||
@@ -275,29 +255,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
// Huberloss | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||||
// SSDRealDivTileMul | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
// SSDSumMulRealDivMean | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
/// ConcatFive2Four | |||||
/// ConcatFour2Five | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||||
// Scale | // Scale | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | ||||
@@ -334,6 +292,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | ||||
// Roipooling | // Roipooling | ||||
@@ -346,7 +305,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||||
// DetectionOutput | // DetectionOutput | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | ||||
@@ -405,7 +363,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||||
// Permute | // Permute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||||
// SSD Normalize | // SSD Normalize | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | ||||
@@ -446,15 +403,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | ||||
// Log | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||||
// Pack | // Pack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | ||||
// Dynamic stitch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
// Unpack | // Unpack | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | ||||
// Gathernd | // Gathernd | ||||
@@ -463,16 +414,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||||
// Argmax | // Argmax | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
// Upsample | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
// Relu | // Relu | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | ||||
@@ -511,7 +454,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; | |||||
// Generate_rpn_proposal | // Generate_rpn_proposal | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | ||||
@@ -544,7 +486,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | ||||
static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
// ENTER | // ENTER | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | ||||
@@ -570,9 +511,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | ||||
// RetinaNet | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||||
// MatMul | // MatMul | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | ||||
@@ -621,30 +559,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
// Upsample | // Upsample | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | ||||
// PadV2 | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
// MirrorPad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
// Filler | // Filler | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | ||||
@@ -665,6 +583,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | ||||
@@ -689,6 +637,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
// Public attribute | // Public attribute | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | ||||
@@ -740,159 +696,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
// L2_normalize | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
// HCOM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
// Log time stamp | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||||
// SpaceToDepth/DepthToSpace | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
// MaxPoolGradWithArgmax | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
// AvgPoolGrad | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
// Varible | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
// Assign | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||||
// ShapeN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
// Space2bacth batch2space | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||||
// Depth_to_space space_to_depth | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
// FakeQuantWithMinMaxVars | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
// Mobilenet_ssd_conv_fusion | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
// Lsh project | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||||
// Control flow | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
// GatherV2 attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
// Reshape attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
// Axis attr def | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
// The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
// For constant folding | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
// Used for mark the active label list to find stream of activated node | // Used for mark the active label list to find stream of activated node | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | ||||
@@ -905,6 +708,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// Control flow | // Control flow | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | ||||
@@ -979,33 +783,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | ||||
// functional ops attr | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | |||||
// used for label switch | // used for label switch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | ||||
// Varible | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
// HCOM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | |||||
// Dynamic stitch | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ |
@@ -22,7 +22,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/detail/attributes_holder.h" | |||||
#include "detail/attributes_holder.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -25,7 +25,11 @@ | |||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
namespace domi { | |||||
class ModelHelper; | |||||
} | |||||
namespace ge { | namespace ge { | ||||
using domi::ModelHelper; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -14,8 +14,8 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef INC_GRAPH_USR_TYPES_H_ | |||||
#define INC_GRAPH_USR_TYPES_H_ | |||||
#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
@@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||||
#undef USR_TYPE_BYTES_DEC | #undef USR_TYPE_BYTES_DEC | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_USR_TYPES_H_ | |||||
#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ |
@@ -262,8 +262,6 @@ class GraphUtils { | |||||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | ||||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | ||||
static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||||
}; | }; | ||||
class ComputeGraphBuilder { | class ComputeGraphBuilder { | ||||
@@ -59,6 +59,7 @@ include_directories(${GE_SOURCE_DIR}/inc/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/common) | include_directories(${GE_SOURCE_DIR}/inc/common) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
include_directories(${GE_SOURCE_DIR}/build) | include_directories(${GE_SOURCE_DIR}/build) | ||||
@@ -53,6 +53,7 @@ void Anchor::UnlinkAll() noexcept { | |||||
if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | ||||
GELOGW("unlink peer_anchor_ptr failed."); | GELOGW("unlink peer_anchor_ptr failed."); | ||||
} | } | ||||
} while (!peer_anchors_.empty()); | } while (!peer_anchors_.empty()); | ||||
} | } | ||||
} | } | ||||
@@ -54,34 +54,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS | |||||
return s; | return s; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | ||||
if (sub_graph_.empty()) { | |||||
return Vistor<NodePtr>(shared_from_this(), nodes_); | |||||
} | |||||
std::vector<NodePtr> all_nodes; | |||||
std::deque<NodePtr> candidates; | |||||
candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); | |||||
while (!candidates.empty()) { | |||||
NodePtr node = candidates.front(); | |||||
all_nodes.emplace_back(node); | |||||
candidates.pop_front(); | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
vector<NodePtr> all_nodes(nodes_.size()); | |||||
(void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); | |||||
for (const auto &sub_graph : sub_graph_) { | |||||
if (sub_graph == nullptr) { | |||||
GELOGW("sub graph is nullptr"); | |||||
continue; | continue; | ||||
} | } | ||||
const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { | |||||
auto subgraph = GetSubgraph(*name_iter); | |||||
if (subgraph != nullptr) { | |||||
candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); | |||||
} | |||||
for (const auto &node : sub_graph->GetAllNodes()) { | |||||
all_nodes.push_back(node); | |||||
} | } | ||||
} | } | ||||
return Vistor<NodePtr>(shared_from_this(), all_nodes); | return Vistor<NodePtr>(shared_from_this(), all_nodes); | ||||
} | } | ||||
size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | ||||
@@ -619,7 +602,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||||
graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::vector<NodePtr> &stack) { | std::vector<NodePtr> &stack) { | ||||
GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); | |||||
GELOGI("Runing_Dfs_Sort"); | |||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | ||||
@@ -664,7 +647,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::deque<NodePtr> &stack) { | std::deque<NodePtr> &stack) { | ||||
GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); | |||||
GELOGI("Runing_Bfs_Sort"); | |||||
std::vector<NodePtr> stack_input; | std::vector<NodePtr> stack_input; | ||||
std::map<string, NodePtr> breadth_node_map; | std::map<string, NodePtr> breadth_node_map; | ||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
@@ -752,7 +735,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
use_BFS = true; | use_BFS = true; | ||||
} | } | ||||
} else { | } else { | ||||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||||
GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); | |||||
} | } | ||||
if (use_BFS) { | if (use_BFS) { | ||||
@@ -66,6 +66,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
anchor_points.clear(); | anchor_points.clear(); | ||||
// Get all anchor point nodes and switch nodes | // Get all anchor point nodes and switch nodes | ||||
for (const auto &node_ptr : graph->GetAllNodes()) { | for (const auto &node_ptr : graph->GetAllNodes()) { | ||||
std::vector<bool> is_node_set_format; | |||||
if (node_ptr == nullptr) { | if (node_ptr == nullptr) { | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
@@ -42,8 +42,6 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||||
const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | ||||
const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||||
const std::string ATTR_NAME_PAD = "pad"; | const std::string ATTR_NAME_PAD = "pad"; | ||||
const std::string ATTR_NAME_PADS = "pad"; | const std::string ATTR_NAME_PADS = "pad"; | ||||
@@ -85,7 +83,6 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||||
const std::string ATTR_NAME_AXIS = "axis"; | const std::string ATTR_NAME_AXIS = "axis"; | ||||
const std::string ATTR_NAME_BROADCAST = "broadcast"; | const std::string ATTR_NAME_BROADCAST = "broadcast"; | ||||
const std::string ATTR_NAME_OUTPUT = "output"; | |||||
const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | ||||
const std::string ATTR_NAME_TIDX = "t_idx"; | const std::string ATTR_NAME_TIDX = "t_idx"; | ||||
@@ -106,13 +103,6 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||||
const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | ||||
const std::string ATTR_NAME_AIPP = "aipp"; | const std::string ATTR_NAME_AIPP = "aipp"; | ||||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||||
const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | ||||
const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | ||||
@@ -121,7 +111,6 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||||
const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | ||||
const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | ||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | ||||
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | ||||
@@ -133,11 +122,9 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||||
const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | ||||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | ||||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | ||||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
// To be deleted | // To be deleted | ||||
const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | ||||
@@ -151,13 +138,15 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||||
const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | ||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
// Refinedet | // Refinedet | ||||
const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | ||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | ||||
const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
// _Arg | // _Arg | ||||
const std::string ATTR_NAME_INDEX = "index"; | const std::string ATTR_NAME_INDEX = "index"; | ||||
@@ -247,30 +236,6 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||||
const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | ||||
const std::string BATCHNORM_ATTR_SCALE = "scale"; | const std::string BATCHNORM_ATTR_SCALE = "scale"; | ||||
const std::string BATCHNORM_ATTR_BIAS = "bias"; | const std::string BATCHNORM_ATTR_BIAS = "bias"; | ||||
const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||||
const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||||
// huberloss | |||||
const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||||
// SSDRealDivTileMul | |||||
const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||||
// SSDSumMulRealDivMean | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||||
const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||||
// ConcatFive2Four | |||||
// ConcatFour2Five | |||||
const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||||
const std::string SSD_CLASS_NUM = "class_num"; | |||||
const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||||
const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||||
const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||||
const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||||
// Scale | // Scale | ||||
const std::string SCALE_ATTR_SCALE = "scale"; | const std::string SCALE_ATTR_SCALE = "scale"; | ||||
@@ -375,7 +340,6 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||||
// Permute | // Permute | ||||
const std::string PERMUTE_ATTR_ORDER = "order"; | const std::string PERMUTE_ATTR_ORDER = "order"; | ||||
const std::string PERMUTE_ATTR_PERM = "perm"; | |||||
// SSD Normalize | // SSD Normalize | ||||
const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | ||||
@@ -403,10 +367,6 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | ||||
const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | ||||
// RefinedetDetectionOutput | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
// PRelu | // PRelu | ||||
const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | ||||
@@ -420,16 +380,11 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||||
const std::string POWER_ATTR_NAME_SCALE = "scale"; | const std::string POWER_ATTR_NAME_SCALE = "scale"; | ||||
const std::string POWER_ATTR_NAME_SHIFT = "shift"; | const std::string POWER_ATTR_NAME_SHIFT = "shift"; | ||||
// log | |||||
const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||||
const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||||
const std::string LOG_ATTR_NAME_BASE = "base"; | |||||
// Pack | // Pack | ||||
const std::string PACK_ATTR_NAME_NUM = "N"; | const std::string PACK_ATTR_NAME_NUM = "N"; | ||||
// Unpack | // Unpack | ||||
const std::string UNPACK_ATTR_NAME_NUM = "num"; | const std::string UNPACK_ATTR_NAME_NUM = "num"; | ||||
const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
// Gathernd | // Gathernd | ||||
const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | ||||
const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | ||||
@@ -439,13 +394,6 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||||
const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | ||||
const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | ||||
const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | ||||
const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||||
const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||||
const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||||
// upsample | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||||
const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||||
// Relu | // Relu | ||||
const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | ||||
@@ -485,7 +433,6 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; | |||||
const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | ||||
const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | ||||
const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | ||||
const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; | |||||
// Generate_rpn_proposal | // Generate_rpn_proposal | ||||
const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | ||||
@@ -584,42 +531,19 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||||
const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | ||||
// Rnn | // Rnn | ||||
const std::string RNN_TENSORFLOW = "rnn_tensorflow"; | |||||
const std::string RNN_MODE_STATIC = "rnn_static"; | |||||
const std::string MUTI_RNN = "multi_rnn"; | |||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string RNN_MODE_ = "rnn_"; | const std::string RNN_MODE_ = "rnn_"; | ||||
const std::string CNN_RNN = "cnn_rnn"; | |||||
const std::string MUTI_RNN = "multi_rnn"; | |||||
const std::string CELL_MODE = "mode"; | const std::string CELL_MODE = "mode"; | ||||
const std::string LSTM_CELL = "lstm_cell"; | const std::string LSTM_CELL = "lstm_cell"; | ||||
const std::string GRU_CELL = "gru_cell"; | const std::string GRU_CELL = "gru_cell"; | ||||
const std::string RNN_HT = "ht"; | const std::string RNN_HT = "ht"; | ||||
const std::string RNN_XT_HT = "xt_ht"; | const std::string RNN_XT_HT = "xt_ht"; | ||||
const std::string RNN_BATCH_SIZE = "batch_size"; | const std::string RNN_BATCH_SIZE = "batch_size"; | ||||
const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||||
const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||||
const std::string LSTM_ACTIVATE = "lstm_activate"; | |||||
const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||||
const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||||
const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||||
const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||||
const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||||
// Upsample | // Upsample | ||||
const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | ||||
// PadV2 | |||||
const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||||
const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||||
const std::string PADV2_ATTR_NAME_T = "T"; | |||||
const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// MirrorPad | |||||
const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||||
const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
// Filler | // Filler | ||||
const std::string FILLER_TYPE = "filler_type"; | const std::string FILLER_TYPE = "filler_type"; | ||||
const std::string FILLER_VALUE = "filler_value"; | const std::string FILLER_VALUE = "filler_value"; | ||||
@@ -630,6 +554,9 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||||
// TopKV2 | // TopKV2 | ||||
const std::string TOPKV2_ATTR_K = "k"; | const std::string TOPKV2_ATTR_K = "k"; | ||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
// Calibaration | // Calibaration | ||||
const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | ||||
const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | ||||
@@ -733,121 +660,6 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||||
const std::string TARGET_TYPE_LITE = "LITE"; | const std::string TARGET_TYPE_LITE = "LITE"; | ||||
// l2_normalize | |||||
const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||||
const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||||
const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||||
const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||||
const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||||
const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||||
const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||||
// HCOM | |||||
const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
const std::string HCOM_ATTR_GROUP = "group"; | |||||
const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||||
const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||||
const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||||
const std::string HCOM_ATTR_FUSION = "fusion"; | |||||
const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
// SpaceToDepth/DepthToSpace | |||||
const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||||
// SparseSoftmaxCrossEntropyWithLogits | |||||
const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||||
// MaxPoolGradWithArgmax | |||||
const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||||
// AvgPoolGrad | |||||
const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||||
// Pad | |||||
const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||||
// Varible | |||||
const std::string VAR_ATTR_FORMAT = "_var_format"; | |||||
const std::string VAR_ATTR_NAME = "var_name"; | |||||
const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||||
const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||||
const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||||
const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||||
const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||||
const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||||
const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||||
const std::string VAR_ATTR_SHAPE = "shape"; | |||||
const std::string HALF_VAR_NAME_END = "_fp16"; | |||||
const std::string VAR_ATTR_INITED = "var_is_inited"; | |||||
const std::string VAR_ATTR_CONTAINER = "container"; | |||||
const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||||
const std::string VAR_ATTR_DTYPE = "dtype"; | |||||
const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||||
const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
// Assign | |||||
const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||||
// space2bacth batch2space | |||||
const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||||
const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||||
// depth_to_space space_to_depth | |||||
const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
// FakeQuantWithMinMaxVars | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||||
const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||||
// mobilenet_ssd_conv_fusion | |||||
const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||||
const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||||
// lsh project | |||||
const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||||
// log time stamp | |||||
const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||||
const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||||
// ShapeN | |||||
const std::string SHAPEN_ATTR_N = "N"; | |||||
const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||||
const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||||
// GatherV2 attr def | |||||
const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||||
const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||||
const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||||
// Reshape attr def | |||||
const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||||
const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||||
// axis attr def | |||||
const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||||
const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||||
const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||||
// For constant folding | |||||
const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||||
const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | ||||
const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | ||||
@@ -882,8 +694,6 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||||
const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | ||||
const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | ||||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | ||||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | ||||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | ||||
@@ -954,14 +764,7 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | ||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
// functional ops attr | |||||
const std::string ATTR_NAME_WHILE_COND = "cond"; | |||||
const std::string ATTR_NAME_WHILE_BODY = "body"; | |||||
// used for label switch | // used for label switch | ||||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | ||||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | ||||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||||
} // namespace ge | } // namespace ge |
@@ -21,8 +21,9 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "framework/common/types.h" | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
@@ -28,7 +28,6 @@ | |||||
#include <cstring> | #include <cstring> | ||||
#include <fstream> | #include <fstream> | ||||
#include <iomanip> | #include <iomanip> | ||||
#include <queue> | |||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -2000,60 +1999,4 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & | |||||
GELOGD("Build exist nodes succ."); | GELOGD("Build exist nodes succ."); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) { | |||||
std::vector<NodePtr> stack_input; | |||||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||||
graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Sort nodes failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; | |||||
std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, | |||||
[](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); | |||||
std::queue<NodePtr> stack; | |||||
NodePtr cur_node = nullptr; | |||||
std::map<string, NodePtr> name_node_map; | |||||
vector<string> nodes_name; | |||||
while (!stack_input.empty() || !stack.empty()) { | |||||
if (!stack.empty()) { | |||||
cur_node = stack.front(); | |||||
stack.pop(); | |||||
} else { | |||||
cur_node = stack_input.back(); | |||||
stack_input.pop_back(); | |||||
} | |||||
node_vec.emplace_back(cur_node); | |||||
compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); | |||||
for (const auto &iter : name_node_map) { | |||||
nodes_name.emplace_back(iter.first); | |||||
} | |||||
std::sort(nodes_name.begin(), nodes_name.end()); | |||||
for (const auto &iter : nodes_name) { | |||||
stack.push(name_node_map[iter]); | |||||
} | |||||
name_node_map.clear(); | |||||
nodes_name.clear(); | |||||
} | |||||
// If they are not equal, there is a closed loop | |||||
if (node_vec.size() != compute_graph->nodes_.size()) { | |||||
std::set<Node *> itered_nodes_set; | |||||
for (auto &node : node_vec) { | |||||
itered_nodes_set.insert(node.get()); | |||||
} | |||||
GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", | |||||
compute_graph->nodes_.size(), node_vec.size()); | |||||
for (auto &node : compute_graph->nodes_) { | |||||
if (itered_nodes_set.count(node.get()) == 0) { | |||||
GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); | |||||
} | |||||
} | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -282,7 +282,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||||
case FORMAT_FRACTAL_Z_3D: | case FORMAT_FRACTAL_Z_3D: | ||||
case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | ||||
case FORMAT_NDC1HWC0: | case FORMAT_NDC1HWC0: | ||||
case FORMAT_FRACTAL_Z_C04: | |||||
graph_status = CalcElementCntByDims(dims, element_cnt); | graph_status = CalcElementCntByDims(dims, element_cnt); | ||||
break; | break; | ||||
default: | default: | ||||
@@ -41,9 +41,9 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework/common) | include_directories(${GE_SOURCE_DIR}/inc/framework/common) | ||||
include_directories(${GE_SOURCE_DIR}/inc/runtime) | include_directories(${GE_SOURCE_DIR}/inc/runtime) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -55,7 +55,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
@@ -93,7 +92,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -196,7 +194,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
"graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
"graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
"graph/passes/replace_with_empty_const_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
"graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
"graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
@@ -271,7 +268,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
@@ -308,7 +304,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -409,7 +404,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
"graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
"graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
"graph/passes/replace_with_empty_const_pass.cc" | |||||
"graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
"graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
"graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
@@ -474,7 +468,7 @@ target_link_libraries(ge_compiler | |||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${msprof} | ${msprof} | ||||
${runtime_compiler} | |||||
${runtime} | |||||
${resouce} | ${resouce} | ||||
rt | rt | ||||
dl) | dl) |
@@ -46,6 +46,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -32,6 +32,8 @@ | |||||
using domi::GetContext; | using domi::GetContext; | ||||
using domi::OpRegistry; | using domi::OpRegistry; | ||||
using domi::RealPath; | |||||
using domi::StringUtils; | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
@@ -41,7 +41,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | ||||
"formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | ||||
"formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | ||||
"formats/format_transfers/format_transfer_nchw_fz_c04.cc" | |||||
"formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | ||||
"formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | ||||
"formats/format_transfers/format_transfer_transpose.cc" | "formats/format_transfers/format_transfer_transpose.cc" | ||||
@@ -80,6 +79,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -17,6 +17,7 @@ | |||||
#include "common/auth/file_saver.h" | #include "common/auth/file_saver.h" | ||||
#include <fcntl.h> | #include <fcntl.h> | ||||
#include <securec.h> | #include <securec.h> | ||||
#include <unistd.h> | #include <unistd.h> | ||||
#include <cstdlib> | #include <cstdlib> | ||||
@@ -28,6 +29,8 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
using domi::CreateDirectory; | |||||
using domi::ModelEncryptType; | |||||
using ge::ModelBufferData; | using ge::ModelBufferData; | ||||
namespace { | namespace { | ||||
@@ -267,4 +270,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -26,26 +26,30 @@ | |||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
using domi::ModelFileHeader; | |||||
using domi::ModelPartition; | |||||
using domi::ModelPartitionTable; | |||||
struct PROC_PARAM { | struct PROC_PARAM { | ||||
uint8_t *model_name; | uint8_t *model_name; | ||||
// ISV Ek buffer | |||||
/* ISV Ek buffer */ | |||||
uint8_t *model_key; | uint8_t *model_key; | ||||
uint32_t model_key_len; | uint32_t model_key_len; | ||||
// ISV root certificate buffer | |||||
/* ISV root certificate buffer */ | |||||
uint8_t *root_cert; | uint8_t *root_cert; | ||||
uint32_t root_cert_len; | uint32_t root_cert_len; | ||||
// ISV private key buffer | |||||
/* ISV private key buffer */ | |||||
uint8_t *pri_key; | uint8_t *pri_key; | ||||
uint32_t pri_key_len; | uint32_t pri_key_len; | ||||
// Raw AI Module Image buffer | |||||
/* Raw AI Module Image buffer */ | |||||
uint8_t *ai_image; | uint8_t *ai_image; | ||||
uint32_t ai_image_len; | uint32_t ai_image_len; | ||||
// ISV HW key buffer | |||||
/* ISV HW key buffer */ | |||||
uint8_t *hw_key; | uint8_t *hw_key; | ||||
uint32_t hw_key_len; | uint32_t hw_key_len; | ||||
}; | }; | ||||
@@ -62,11 +66,11 @@ using std::string; | |||||
class FileSaver { | class FileSaver { | ||||
public: | public: | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model, no encryption | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model, no encryption | |||||
* @return Status result | |||||
*/ | |||||
static Status SaveToFile(const string &file_path, const ge::ModelData &model, | static Status SaveToFile(const string &file_path, const ge::ModelData &model, | ||||
const ModelFileHeader *model_file_header = nullptr); | const ModelFileHeader *model_file_header = nullptr); | ||||
@@ -80,26 +84,26 @@ class FileSaver { | |||||
static Status SaveToFile(const string &file_path, const void *data, int len); | static Status SaveToFile(const string &file_path, const void *data, int len); | ||||
protected: | protected: | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief Check validity of the file path | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief Check validity of the file path | |||||
* @return Status result | |||||
*/ | |||||
static Status CheckPath(const string &file_path); | static Status CheckPath(const string &file_path); | ||||
static Status WriteData(const void *data, uint32_t size, int32_t fd); | static Status WriteData(const void *data, uint32_t size, int32_t fd); | ||||
static Status OpenFile(int32_t &fd, const std::string &file_path); | static Status OpenFile(int32_t &fd, const std::string &file_path); | ||||
/// | |||||
/// @ingroup domi_common | |||||
/// @brief save model to file | |||||
/// @param [in] file_path file output path | |||||
/// @param [in] file_header file header info | |||||
/// @param [in] data model data | |||||
/// @param [in] len model length | |||||
/// @return Status result | |||||
/// | |||||
/** | |||||
* @ingroup domi_common | |||||
* @brief save model to file | |||||
* @param [in] file_path file output path | |||||
* @param [in] file_header file header info | |||||
* @param [in] data model data | |||||
* @param [in] len model length | |||||
* @return Status result | |||||
*/ | |||||
static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | ||||
int len); | int len); | ||||
@@ -16,7 +16,6 @@ | |||||
#include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
using ge::OmgContext; | |||||
namespace domi { | namespace domi { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | ||||
static OmgContext context; | static OmgContext context; | ||||
@@ -134,6 +134,10 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
if (args.src_data_size == 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); | |||||
return PARAM_INVALID; | |||||
} | |||||
int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | ||||
@@ -145,12 +149,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
result.length = total_size; | |||||
if (total_size == 0) { | |||||
GELOGI("In TransDataType, total_size is zero, has no data."); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | ||||
@@ -164,6 +162,7 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
result.length = total_size; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -27,9 +27,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
namespace { | namespace { | ||||
bool CheckDataTypeSupported(const DataType &data_type) { | |||||
return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||||
} | |||||
bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } | |||||
Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | ||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
@@ -53,11 +51,10 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||||
if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || | |||||
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | ||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | |||||
src_shape.at(kC1hwncoc0C0) != cube_size) { | |||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || | |||||
src_shape.at(kC1hwncoc0C0) != kCubeSize) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -81,7 +78,6 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
auto c0 = args.src_shape.at(kC1hwncoc0C0); | auto c0 = args.src_shape.at(kC1hwncoc0C0); | ||||
auto co = args.src_shape.at(kC1hwncoc0Co); | auto co = args.src_shape.at(kC1hwncoc0Co); | ||||
auto c = args.dst_shape.at(kHwcnC); | auto c = args.dst_shape.at(kHwcnC); | ||||
auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||||
int64_t cn = c * n; | int64_t cn = c * n; | ||||
int64_t wcn = w * cn; | int64_t wcn = w * cn; | ||||
int64_t coc0 = co * c0; | int64_t coc0 = co * c0; | ||||
@@ -97,8 +93,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||||
int64_t c_head_addr = w_head_addr + c_idx * n; | int64_t c_head_addr = w_head_addr + c_idx * n; | ||||
for (int64_t n_idx = 0; n_idx < n; n_idx++) { | for (int64_t n_idx = 0; n_idx < n; n_idx++) { | ||||
int64_t dst_idx = c_head_addr + n_idx; | int64_t dst_idx = c_head_addr + n_idx; | ||||
int64_t c1_idx = c_idx / cube_size; | |||||
int64_t c0_idx = c_idx % cube_size; | |||||
int64_t c1_idx = c_idx / kCubeSize; | |||||
int64_t c0_idx = c_idx % kCubeSize; | |||||
int64_t co_idx = c0_idx; | int64_t co_idx = c0_idx; | ||||
int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | ||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
@@ -134,11 +130,6 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -88,11 +88,6 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -89,11 +89,6 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -116,11 +116,6 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -189,11 +184,6 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -119,11 +119,6 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -199,11 +194,6 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -269,11 +259,6 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -117,11 +117,6 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -194,11 +189,6 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -133,12 +133,6 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -133,12 +133,6 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -146,7 +140,6 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
@@ -132,12 +132,6 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -27,20 +27,16 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
namespace { | namespace { | ||||
bool CheckDataTypeSupported(const DataType &data_type) { | |||||
return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||||
} | |||||
bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } | |||||
Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<int64_t> &src_shape, | |||||
std::vector<int64_t> &dst_shape) { | |||||
auto cube_size = GetCubeSizeByDataType(data_type); | |||||
Status TransShapeHwcnToC1hwncoc0(const std::vector<int64_t> &src_shape, std::vector<int64_t> &dst_shape) { | |||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size))); | |||||
dst_shape.push_back((src_shape.at(kHwcnC) - 1) / kCubeSize + 1); | |||||
dst_shape.push_back(src_shape.at(kHwcnH)); | dst_shape.push_back(src_shape.at(kHwcnH)); | ||||
dst_shape.push_back(src_shape.at(kHwcnW)); | dst_shape.push_back(src_shape.at(kHwcnW)); | ||||
dst_shape.push_back(src_shape.at(kHwcnN)); | dst_shape.push_back(src_shape.at(kHwcnN)); | ||||
dst_shape.push_back(cube_size); | |||||
dst_shape.push_back(cube_size); | |||||
dst_shape.push_back(kCubeSize); | |||||
dst_shape.push_back(kCubeSize); | |||||
if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { | ||||
GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -69,7 +65,7 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
std::vector<int64_t> expect_dst_shape; | std::vector<int64_t> expect_dst_shape; | ||||
auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); | |||||
auto ret = TransShapeHwcnToC1hwncoc0(args.src_shape, expect_dst_shape); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -169,12 +165,6 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -198,7 +188,7 @@ Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vec | |||||
GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); | |||||
return TransShapeHwcnToC1hwncoc0(src_shape, dst_shape); | |||||
} else { | } else { | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { | |||||
} | } | ||||
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || | if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || | ||||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || | src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || | ||||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) { | |||||
src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNchwC) - 1) / c0 + 1) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -130,12 +130,6 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
} | } | ||||
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | ||||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | ||||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | |||||
src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNhwcC) - 1) / c0 + 1) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -130,12 +130,6 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -1,314 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "common/formats/format_transfers/format_transfer_nchw_fz_c04.h" | |||||
#include "common/formats/format_transfers/format_transfer_transpose.h" | |||||
#include <securec.h> | |||||
#include <memory> | |||||
#include <stdlib.h> | |||||
#include "common/formats/utils/formats_definitions.h" | |||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
#include "common/util.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/utils/type_utils.h" | |||||
/** [Explanation of the transfer from NCHW to FZ_C04]
 * First step: pad the N and C axes. C must be less than or equal to 4.
 *   After padding, the shape is (ceil(n/16)*16, 4, h, w).
 * Second step: transpose. The shape becomes (ceil(n/16)*16, h, w, 4).
 * Third step: view the 4D tensor as 2D, where the first dim is N and the second dim is Z = h*w*c,
 *   and pad it to (N, ceil(Z/16)*16).
 * Last step: view (N, ceil(Z/16)*16) as 4D (N/16, 16, Z/16, 16) and transpose it to (Z/16, N/16, 16, 16).
 */
namespace ge { | |||||
namespace formats { | |||||
namespace { | |||||
constexpr int64_t kMaxDimsNumC = 4; | |||||
// Returns SUCCESS when GetSizeByDataType() reports a positive element size for
// `data_type`, and UNSUPPORTED otherwise.
Status CheckDataTypeSupport(DataType data_type) {
  if (GetSizeByDataType(data_type) > 0) {
    return SUCCESS;
  }
  return UNSUPPORTED;
}
// Computes the destination shape for the given NCHW dimensions:
//   {Ceil(c * h * w, c0), Ceil(n, c0), c0, c0}
// where c0 is the cube size reported for `data_type`.
//
// Returns UNSUPPORTED when the cube size is negative, PARAM_INVALID when the
// resulting shape fails IsShapeValid(), and SUCCESS otherwise. `dst_shape` is
// overwritten once the cube size has been accepted.
Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape) {
  const auto cube = GetCubeSizeByDataType(data_type);
  if (cube < 0) {
    return UNSUPPORTED;
  }

  const auto flat_len = c * h * w;
  const auto outer_dim = Ceil(flat_len, cube);
  const auto batch_blocks = Ceil(n, static_cast<int64_t>(cube));

  dst_shape.clear();
  dst_shape.push_back(outer_dim);
  dst_shape.push_back(batch_blocks);
  dst_shape.push_back(cube);
  dst_shape.push_back(cube);

  if (IsShapeValid(dst_shape)) {
    return SUCCESS;
  }
  GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str());
  return PARAM_INVALID;
}
// Validates that `src_shape` is a well-formed NCHW shape (via CheckShapeValid)
// and forwards its four dimensions to TransShape() to obtain `dst_shape`.
// Returns PARAM_INVALID when the source shape is rejected, otherwise whatever
// TransShape() returns.
Status TransShapeNchwToFzC04(const std::vector<int64_t> &src_shape, DataType data_type,
                             std::vector<int64_t> &dst_shape) {
  const bool shape_ok = CheckShapeValid(src_shape, kNchwDimsNum);
  if (!shape_ok) {
    return PARAM_INVALID;
  }
  return TransShape(src_shape.at(kNchwN), src_shape.at(kNchwC), src_shape.at(kNchwH), src_shape.at(kNchwW),
                    data_type, dst_shape);
}
/**
 * Converts NCHW data into the FRACTAL_Z_C04 layout.
 *
 * Steps performed on `args.data` (dimensions are read from `args.src_shape`):
 *   1. Transpose with perm {0, 2, 3, 1} (NCHW -> NHWC).
 *   2. Copy each of the n rows of c*h*w elements into a zero-filled buffer whose rows are
 *      Ceil(h*w*c, c0) * c0 elements long, i.e. pad the flattened h*w*c axis.
 *   3. View that buffer as {Ceil(n, c0), c0, Ceil(h*w*c, c0), c0} and transpose it with
 *      perm {2, 0, 1, 3} into `result`.
 *
 * @param args   Source description; src_shape is indexed with the kNchw* constants.
 * @param result Receives the output of the final transpose. When the padded size is 0 only
 *               result.length is written (set to 0).
 * @return SUCCESS on success; NOT_CHANGED when a transpose fails or the first transpose
 *         produces an unexpected length; INTERNAL_ERROR on multiplication overflow or a
 *         memset_s/memcpy_s failure; OUT_OF_MEMORY when the padded buffer cannot be allocated.
 */
Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) {
  int64_t n = args.src_shape.at(kNchwN);
  int64_t c = args.src_shape.at(kNchwC);
  int64_t h = args.src_shape.at(kNchwH);
  int64_t w = args.src_shape.at(kNchwW);

  int64_t c0 = GetCubeSizeByDataType(args.src_data_type);
  int size = GetSizeByDataType(args.src_data_type);

  // Step 1: NCHW -> NHWC.
  TransResult trans_result_1;
  std::vector<int64_t> perm_arg_1 = {0, 2, 3, 1};
  auto ret = ge::formats::Transpose(args.data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to NHWC");
    return NOT_CHANGED;
  }

  // The transpose must not change the byte count.
  // NOTE(review): this product is not overflow-checked here; it presumably relies on the
  // caller (PaddingNC in this file performs such checks) - confirm for other callers.
  size_t expect_size = n * c * h * w * size;
  if (trans_result_1.length != expect_size) {
    GELOGE(INTERNAL_ERROR, "size is not match after transpose!");
    return NOT_CHANGED;
  }

  // Step 2: prepare the padded 4-D view {n_o, c_o, h_o, w_o}.
  int64_t tmp = h * w * c;
  int64_t n_o = Ceil(n, static_cast<int64_t>(c0));
  int64_t c_o = c0;
  int64_t h_o = Ceil(tmp, c0);
  int64_t w_o = c0;
  std::vector<int64_t> shape_o = {n_o, c_o, h_o, w_o};

  // Overflow checks for every partial product of the padded element count.
  // int64_t values are logged with %ld, matching the other logs in this file.
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o);
                  return INTERNAL_ERROR);
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o);
                  return INTERNAL_ERROR);
  auto t1 = h_o * w_o;
  auto t2 = n_o * c_o;
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2);
                  return INTERNAL_ERROR);

  int64_t total_ele_cnt = n_o * c_o * h_o * w_o;
  // `size` is an int; widen it explicitly so the variadic argument matches %ld.
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", total_ele_cnt,
                         static_cast<int64_t>(size));
                  return INTERNAL_ERROR);
  int64_t dst_size = total_ele_cnt * size;
  if (dst_size == 0) {
    result.length = 0;
    return SUCCESS;
  }

  std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
  if (dst == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
           TypeUtils::FormatToSerialString(args.src_format).c_str(),
           TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
    return OUT_OF_MEMORY;
  }
  auto mem_ret = memset_s(dst.get(), dst_size, 0, dst_size);
  if (mem_ret != EOK) {
    GELOGE(INTERNAL_ERROR, "memset_s failed!");
    return INTERNAL_ERROR;
  }

  // Copy every source row (block bytes) to the start of its padded row (stride bytes).
  const int64_t block = c * h * w * size;
  const int64_t stride = h_o * w_o * size;
  auto p_s = trans_result_1.data.get();
  auto p_d = dst.get();
  for (int64_t k = 0; k < n; ++k) {
    const int64_t dst_offset = k * stride;
    // The capacity handed to memcpy_s must be measured from the actual write position.
    // The write pointer advances by `stride`, so subtracting only `block` per row (as the
    // previous bookkeeping did) would overstate the space left in the buffer.
    const int64_t remaining = dst_size - dst_offset;
    auto cpy_ret = memcpy_s(p_d + dst_offset, static_cast<size_t>(remaining), p_s + k * block,
                            static_cast<size_t>(block));
    if (cpy_ret != EOK) {
      GELOGE(INTERNAL_ERROR, "memcpy_s failed!");
      return INTERNAL_ERROR;
    }
  }

  // Step 3: transpose the padded 4-D view with perm {2, 0, 1, 3}.
  std::vector<int64_t> perm_arg_2 = {2, 0, 1, 3};
  ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to Transpose padded data with perm {2, 0, 1, 3}");
    return NOT_CHANGED;
  }

  return SUCCESS;
}
/**
 * Pads the N and C axes of an NCHW tensor with zeros.
 *
 * `args_tmp` starts as a copy of `args`; its src_shape is rewritten to
 * {Ceil(n, c0) * c0, kMaxDimsNumC, h, w}. A zero-filled buffer of that size is allocated into
 * `dst`, every (n, c) plane of h*w elements from `args.data` is copied to its padded
 * position, and `args_tmp.data` is pointed at the new buffer.
 *
 * @param args      Original transfer arguments (not modified).
 * @param args_tmp  Receives the copy of `args` with the padded shape and data pointer.
 * @param dst       Receives ownership of the padded buffer; left untouched when the padded
 *                  size is 0 (in that case args_tmp.data still equals args.data).
 * @return SUCCESS on success; PARAM_INVALID when the shape is rejected by CheckShapeValid or
 *         c exceeds kMaxDimsNumC; INTERNAL_ERROR on multiplication overflow or a
 *         memset_s/memcpy_s failure; OUT_OF_MEMORY when allocation fails.
 */
Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uint8_t> &dst) {
  args_tmp = args;
  auto src_shape = args_tmp.src_shape;
  if (!CheckShapeValid(src_shape, kNchwDimsNum)) {
    return PARAM_INVALID;
  }
  int64_t c0 = GetCubeSizeByDataType(args.src_data_type);

  auto n = src_shape.at(kNchwN);
  auto c = src_shape.at(kNchwC);
  auto h = src_shape.at(kNchwH);
  auto w = src_shape.at(kNchwW);

  if (c > kMaxDimsNumC) {
    // c is a signed 64-bit value, so it is logged with %ld.
    GELOGE(PARAM_INVALID, "Invalid dim c num[%ld].It should be in (0,4]", c);
    return PARAM_INVALID;
  }

  auto n_o = Ceil(n, c0) * c0;
  auto c_o = kMaxDimsNumC;
  auto h_o = h;
  auto w_o = w;
  args_tmp.src_shape.at(kNchwN) = n_o;
  args_tmp.src_shape.at(kNchwC) = c_o;
  args_tmp.src_shape.at(kNchwH) = h_o;
  args_tmp.src_shape.at(kNchwW) = w_o;

  // Overflow checks for every partial product of the padded element count.
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", h_o, w_o);
                  return INTERNAL_ERROR);
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", n_o, c_o);
                  return INTERNAL_ERROR);
  auto t1 = h_o * w_o;
  auto t2 = n_o * c_o;
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", t1, t2);
                  return INTERNAL_ERROR);

  int64_t total_ele_cnt = n_o * c_o * h_o * w_o;
  int size = GetSizeByDataType(args.src_data_type);
  // `size` is an int; widen it explicitly so the variadic argument matches %ld.
  GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size),
                  GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%ld], B[%ld]", total_ele_cnt,
                         static_cast<int64_t>(size));
                  return INTERNAL_ERROR);

  int64_t dst_size = total_ele_cnt * size;
  if (dst_size == 0) {
    return SUCCESS;
  }

  dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
  if (dst == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
           TypeUtils::FormatToSerialString(args.src_format).c_str(),
           TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size);
    return OUT_OF_MEMORY;
  }
  auto ret = memset_s(dst.get(), dst_size, 0, dst_size);
  if (ret != EOK) {
    GELOGE(INTERNAL_ERROR, "memset_s failed!");
    return INTERNAL_ERROR;
  }

  auto p_s = args.data;
  auto p_d = dst.get();
  // One h*w plane in bytes; h_o == h and w_o == w, so source and padded planes are equal in size.
  const int64_t block = h * w * size;
  // 64-bit indices: n and c are int64_t, so an int counter could overflow before reaching them.
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t j = 0; j < c; ++j) {
      const int64_t dst_offset = (i * c_o + j) * block;
      const int64_t src_offset = (i * c + j) * block;
      // The capacity handed to memcpy_s must be measured from the actual write position.
      // Padded channels are skipped in the destination, so subtracting only the copied
      // bytes (as the previous bookkeeping did) would overstate the space left.
      const int64_t remaining = dst_size - dst_offset;
      ret = memcpy_s(p_d + dst_offset, static_cast<size_t>(remaining), p_s + src_offset,
                     static_cast<size_t>(block));
      if (ret != EOK) {
        GELOGE(INTERNAL_ERROR, "memcpy_s failed!");
        return INTERNAL_ERROR;
      }
    }
  }
  args_tmp.data = dst.get();
  return SUCCESS;
}
} // namespace | |||||
// Entry point of the NCHW -> FRACTAL_Z_C04 transfer.
//
// The input is first padded on the N and C axes (PaddingNC). The destination shape
// expected for the padded input is then derived with TransShape() and compared with
// args.dst_shape when the latter is not empty. Finally the data conversion itself is
// delegated to TransFormatFromNchwToFzC04().
//
// Returns the failing status of PaddingNC()/TransShape(), PARAM_INVALID when a non-empty
// dst_shape differs from the expected one, UNSUPPORTED for any other format pair, and
// otherwise the status of TransFormatFromNchwToFzC04().
Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult &result) {
  GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s",
         TypeUtils::FormatToSerialString(args.src_format).c_str(),
         TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(),
         TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str());

  // padded_buf keeps the padded data alive for as long as padded_args.data points at it.
  TransArgs padded_args = args;
  std::shared_ptr<uint8_t> padded_buf = nullptr;
  auto status = PaddingNC(args, padded_args, padded_buf);
  if (status != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Padding in NC axis failed!");
    return status;
  }

  std::vector<int64_t> wanted_shape;
  status = TransShape(padded_args.src_format, padded_args.src_shape, padded_args.src_data_type,
                      padded_args.dst_format, wanted_shape);
  if (status != SUCCESS) {
    return status;
  }

  const bool shape_given = !padded_args.dst_shape.empty();
  if (shape_given && padded_args.dst_shape != wanted_shape) {
    GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s",
           TypeUtils::FormatToSerialString(padded_args.src_format).c_str(),
           TypeUtils::FormatToSerialString(padded_args.dst_format).c_str(),
           ShapeToString(padded_args.dst_shape).c_str(), ShapeToString(wanted_shape).c_str());
    return PARAM_INVALID;
  }

  if (padded_args.src_format != FORMAT_NCHW || padded_args.dst_format != FORMAT_FRACTAL_Z_C04) {
    return UNSUPPORTED;
  }
  return TransFormatFromNchwToFzC04(padded_args, result);
}
// Derives the FRACTAL_Z_C04 shape for an NCHW source shape.
// Returns UNSUPPORTED when the data type is rejected by CheckDataTypeSupport() or the
// format pair is not NCHW -> FRACTAL_Z_C04; otherwise the result of TransShapeNchwToFzC04().
Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector<int64_t> &src_shape,
                                             DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) {
  const bool type_ok = (CheckDataTypeSupport(data_type) == SUCCESS);
  if (!type_ok) {
    return UNSUPPORTED;
  }
  const bool pair_ok = (src_format == FORMAT_NCHW) && (dst_format == FORMAT_FRACTAL_Z_C04);
  return pair_ok ? TransShapeNchwToFzC04(src_shape, data_type, dst_shape) : UNSUPPORTED;
}
REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | |||||
} // namespace formats | |||||
} // namespace ge |
@@ -1,35 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ | |||||
#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ | |||||
#include <vector> | |||||
#include "common/formats/format_transfers/format_transfer.h" | |||||
namespace ge { | |||||
namespace formats { | |||||
class FormatTransferNchwToFZC04 : public FormatTransfer { | |||||
public: | |||||
Status TransFormat(const ge::formats::TransArgs &args, ge::formats::TransResult &result) override; | |||||
Status TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, Format dst_format, | |||||
std::vector<int64_t> &dst_shape) override; | |||||
}; | |||||
} // namespace formats | |||||
} // namespace ge | |||||
#endif  // GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_
@@ -40,7 +40,7 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNchwN)); | dst_shape.push_back(src_shape.at(kNchwN)); | ||||
dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); | |||||
dst_shape.push_back((src_shape.at(kNchwC) - 1) / c0 + 1); | |||||
dst_shape.push_back(src_shape.at(kNchwH)); | dst_shape.push_back(src_shape.at(kNchwH)); | ||||
dst_shape.push_back(src_shape.at(kNchwW)); | dst_shape.push_back(src_shape.at(kNchwW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
@@ -74,8 +74,25 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
} | |||||
// Guarantee the validity of parameters in check function | |||||
int size = GetSizeByDataType(args.src_data_type); | |||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
if (total_size <= 0) { | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGD( | |||||
"Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
"%s, dst shape %s memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, | GELOGE(OUT_OF_MEMORY, | ||||
@@ -152,39 +169,6 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
result.length = static_cast<size_t>(total_size); | result.length = static_cast<size_t>(total_size); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | |||||
Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
} | |||||
// Guarantee the validity of parameters in check function | |||||
int size = GetSizeByDataType(args.src_data_type); | |||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
if (total_size <= 0) { | |||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGD( | |||||
"Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
"%s, dst shape %s memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
@@ -38,7 +38,7 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNhwcN)); | dst_shape.push_back(src_shape.at(kNhwcN)); | ||||
dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); | |||||
dst_shape.push_back((src_shape.at(kNhwcC) - 1) / c0 + 1); | |||||
dst_shape.push_back(src_shape.at(kNhwcH)); | dst_shape.push_back(src_shape.at(kNhwcH)); | ||||
dst_shape.push_back(src_shape.at(kNhwcW)); | dst_shape.push_back(src_shape.at(kNhwcW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
@@ -161,12 +161,6 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -51,8 +51,8 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||||
return false; | return false; | ||||
} | } | ||||
for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
if (dim < 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
if (dim <= 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, zero dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -146,16 +146,12 @@ Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Data | |||||
int64_t dst_ele_num = GetItemNumByShape(dst_shape); | int64_t dst_ele_num = GetItemNumByShape(dst_shape); | ||||
int64_t data_size = GetSizeByDataType(src_data_type); | int64_t data_size = GetSizeByDataType(src_data_type); | ||||
int64_t dst_size = data_size * dst_ele_num; | int64_t dst_size = data_size * dst_ele_num; | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | ||||
JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | ||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | ||||
if (dst_ele_num == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
int64_t dst_index = 0; | int64_t dst_index = 0; | ||||
std::vector<int64_t> dst_indexes(dst_shape.size()); | std::vector<int64_t> dst_indexes(dst_shape.size()); | ||||
while (dst_index < dst_ele_num) { | while (dst_index < dst_ele_num) { | ||||
@@ -24,7 +24,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -39,13 +38,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str()); | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
auto src_shape_size = GetItemNumByShape(args.src_shape); | |||||
if (args.data == nullptr && src_shape_size != 0) { | |||||
if (args.data == nullptr) { | |||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | GELOGE(PARAM_INVALID, "Invalid input null data"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return transfer->TransFormat(args, result); | return transfer->TransFormat(args, result); | ||||
} | } | ||||
@@ -75,12 +71,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (args.data == nullptr && args.src_data_size != 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | |||||
return PARAM_INVALID; | |||||
} | |||||
return transfer->TransDataType(args, result); | return transfer->TransDataType(args, result); | ||||
} | } | ||||
@@ -69,11 +69,11 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
} | } | ||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : shape) { | for (auto dim : shape) { | ||||
if (dim < 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
if (dim < 1) { | |||||
GELOGE(PARAM_INVALID, "Invalid zero dim in the shape %s", ShapeToString(shape).c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
if (dim != 0 && kShapeItemNumMAX / dim < num) { | |||||
if (kShapeItemNumMAX / dim < num) { | |||||
GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -64,9 +64,6 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst); | |||||
template <typename T> | template <typename T> | ||||
T Ceil(T n1, T n2) { | T Ceil(T n1, T n2) { | ||||
if (n1 == 0) { | |||||
return 0; | |||||
} | |||||
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | ||||
} | } | ||||
@@ -601,4 +601,4 @@ int16_t GetManBitLength(T man) { | |||||
return len; | return len; | ||||
} | } | ||||
}; // namespace ge | }; // namespace ge | ||||
#endif // GE_COMMON_FP16_T_H_ | |||||
#endif // GE_COMMON_FP16_T_H_ |
@@ -27,7 +27,6 @@ | |||||
#include <string> | #include <string> | ||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/util.h" | |||||
namespace ge { | namespace ge { | ||||
static const int kMaxNumOfSo = 64; | static const int kMaxNumOfSo = 64; | ||||
@@ -101,7 +100,7 @@ Status PluginManager::LoadSo(const string &path, const vector<string> &func_chec | |||||
} | } | ||||
std::string file_name = single_path.substr(single_path.rfind('/') + 1, string::npos); | std::string file_name = single_path.substr(single_path.rfind('/') + 1, string::npos); | ||||
string file_path_dlopen = RealPath(single_path.c_str()); | |||||
string file_path_dlopen = domi::RealPath(single_path.c_str()); | |||||
if (file_path_dlopen.empty()) { | if (file_path_dlopen.empty()) { | ||||
GELOGW("Failed to get realpath of %s!", single_path.c_str()); | GELOGW("Failed to get realpath of %s!", single_path.c_str()); | ||||
continue; | continue; | ||||
@@ -226,7 +225,7 @@ Status PluginManager::Load(const string &path, const vector<string> &func_check_ | |||||
} | } | ||||
std::string canonical_path_str = (std::string(canonical_path) + "/" + file_name); | std::string canonical_path_str = (std::string(canonical_path) + "/" + file_name); | ||||
string file_path_dlopen = RealPath(canonical_path_str.c_str()); | |||||
string file_path_dlopen = domi::RealPath(canonical_path_str.c_str()); | |||||
if (file_path_dlopen.empty()) { | if (file_path_dlopen.empty()) { | ||||
GELOGW("failed to get realpath of %s", canonical_path_str.c_str()); | GELOGW("failed to get realpath of %s", canonical_path_str.c_str()); | ||||
continue; | continue; | ||||
@@ -1,121 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#include <nlohmann/json.hpp> | |||||
#include <set> | |||||
#include <string> | |||||
#include "ge/ge_api_error_codes.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "model/ge_model.h" | |||||
namespace ge { | |||||
using Json = nlohmann::json; | |||||
// Summary of a cached graph: node/edge counts, a hash of the whole graph and a
// per-node hash table keyed by string. All counters and the graph hash start at 0
// and the table starts empty.
struct CacheInfo {
  size_t node_num = 0;
  size_t edge_num = 0;
  size_t graph_hash = 0;
  // Qualified explicitly so this header does not depend on a `using std::map`
  // leaked by some other include.
  std::map<std::string, size_t> nodes_hash;

  CacheInfo() {}
};
class ModelCacheHelper { | |||||
public: | |||||
ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); | |||||
Status SaveCacheInfoToCache() const; | |||||
Status SaveVarManagerToCache(bool before_build) const; | |||||
Status SaveOmModelToCache(const GeModelPtr &ge_model) const; | |||||
bool IsModelCacheHit() const; | |||||
Status RecoverVarManagerFromCache() const; | |||||
Status LoadOmModelFromCache(GeModelPtr &ge_model) const; | |||||
Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); | |||||
Status ClearCache(uint32_t graph_id) const; | |||||
private: | |||||
Status GetComputeGraphHash(size_t &hash) const; | |||||
Status GetNodesHash(map<std::string, size_t> &hash_map) const; | |||||
Status GetCacheInfo(CacheInfo &cache_info) const; | |||||
Status RecoverMemResource(const Json &json) const; | |||||
Status RecoverAllocatedGraphId(const Json &json) const; | |||||
Status RecoverChangedGraphId(const Json &json) const; | |||||
Status RecoverVarAddrAndTensorDesc(const Json &json) const; | |||||
Status RecoverBroadcastInfo(const Json &json) const; | |||||
Status RecoverTransRoads(const Json &json) const; | |||||
static Status RecompileNodes(GeModelPtr &ge_model); | |||||
bool IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const; | |||||
bool IsMemResourceSameAsCache(Json &json) const; | |||||
bool IsChangedGraphIdSameAsCache(Json &json) const; | |||||
bool IsAllocatedGraphIdSameAsCache(Json &json) const; | |||||
bool IsCurVarTensorDescSameAsCache(Json &json) const; | |||||
bool IsVarAddrMgrMapSameAsCache(Json &json) const; | |||||
bool IsBroadcastInfoSameAsCache(Json &json) const; | |||||
bool IsTransRoadsSameAsCache(Json &json) const; | |||||
bool IsVarManagerSameAsCache(Json &json) const; | |||||
bool IsVarManagerParamSameAsCache(Json &json) const; | |||||
Status SaveJsonToFile(const string &file_name, const Json &json) const; | |||||
Status LoadJsonFromFile(const string &file_name, Json &json) const; | |||||
Status GetNodesHashMapJson(Json &json) const; | |||||
Status GetMemResourceMap(Json &json) const; | |||||
Status GetVarAddrMgrMapJson(Json &json) const; | |||||
Status GetCurVarTensorDescMapJson(Json &json) const; | |||||
Status GetTransRoadsJson(Json &json) const; | |||||
Status GetChangedGraphIdJson(Json &json) const; | |||||
Status GetAllocatedGraphIdJson(Json &json) const; | |||||
Status GetBroadcastInfoJson(Json &json) const; | |||||
Status GetVarResourceJson(Json &json) const; | |||||
Status GetVarManagerJson(Json &json) const; | |||||
static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); | |||||
static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); | |||||
static Status ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource); | |||||
static Status ParseVarAddrMgrMapFromJson(const Json &json, | |||||
std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector, | |||||
std::unordered_set<uint64_t> &var_offset_set); | |||||
static Status ParseCurVarTensorDescMapFromJson( | |||||
const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map); | |||||
static Status ParseTransRoadsFromJson(const Json &json, | |||||
std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads); | |||||
static Status ParseChangedGraphIdFromJson(const Json &json, | |||||
std::unordered_map<std::string, uint32_t> &changed_graph_id); | |||||
static Status ParseAllocatedGraphIdFromJson(const Json &json, | |||||
std::unordered_map<std::string, uint32_t> &allocated_graph_id); | |||||
static Status ParseBroadcastInfoFromJson(const Json &json, | |||||
std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info); | |||||
static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); | |||||
uint64_t session_id_; | |||||
uint32_t graph_id_; | |||||
string cache_path_; | |||||
ComputeGraphPtr compute_graph_; | |||||
std::set<string> var_names_; | |||||
bool is_cache_path_valid_for_output; | |||||
static map<uint32_t, uint32_t> graph_id_run_times_; | |||||
}; | |||||
using ModelCacheHelperPtr = std::shared_ptr<ModelCacheHelper>; | |||||
} // namespace ge | |||||
#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ |
@@ -26,17 +26,15 @@ | |||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
using domi::ModelTaskDef; | |||||
using ge::ModelBufferData; | using ge::ModelBufferData; | ||||
using ge::TBEKernelPtr; | using ge::TBEKernelPtr; | ||||
using ge::TBEKernelStore; | using ge::TBEKernelStore; | ||||
using std::string; | using std::string; | ||||
namespace { | namespace { | ||||
const int64_t kOriginalOmPartitionNum = 1; | const int64_t kOriginalOmPartitionNum = 1; | ||||
} | } | ||||
namespace ge { | |||||
namespace domi { | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } | ||||
Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | ||||
@@ -508,4 +506,4 @@ Status ModelHelper::ReleaseLocalModelData() noexcept { | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace domi |
@@ -25,10 +25,11 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
using ge::FileSaver; | |||||
using ge::ModelBufferData; | using ge::ModelBufferData; | ||||
using std::string; | using std::string; | ||||
namespace ge { | |||||
namespace domi { | |||||
// For Load | // For Load | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { | ||||
if (CheckModelValid(model) != SUCCESS) { | if (CheckModelValid(model) != SUCCESS) { | ||||
@@ -225,4 +226,4 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat | |||||
return SUCCESS; | return SUCCESS; | ||||
#endif | #endif | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace domi |
@@ -26,7 +26,7 @@ | |||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
namespace ge { | |||||
namespace domi { | |||||
/** | /** | ||||
* @ingroup domi_calibration | * @ingroup domi_calibration | ||||
@@ -68,6 +68,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // end namespace ge | |||||
} // end namespace domi | |||||
#endif // GE_COMMON_MATH_UTIL_H_ | #endif // GE_COMMON_MATH_UTIL_H_ |
@@ -22,9 +22,15 @@ | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "framework/common/debug/ge_log.h" | |||||
using domi::GetFileLength; | |||||
using domi::MODEL_FILE_MAGIC_NUM; | |||||
using domi::ModelEncryptType; | |||||
using domi::ModelFileHeader; | |||||
using domi::RealPath; | |||||
namespace ge { | namespace ge { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} | ||||
@@ -63,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||||
const char *model_char = model_str.c_str(); | const char *model_char = model_str.c_str(); | ||||
uint32_t len = static_cast<uint32_t>(model_str.length()); | uint32_t len = static_cast<uint32_t>(model_str.length()); | ||||
// Write data to file | // Write data to file | ||||
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||||
int32_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||||
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | ||||
// Need to both print the error info of mmWrite and mmClose, so return ret after mmClose | // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose | ||||
GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); | GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); | ||||
@@ -18,7 +18,7 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
namespace ge { | |||||
namespace domi { | |||||
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ | #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ | ||||
FMK_FUNC_DEV_VISIBILITY void SetAttrDef(ARG_TYPE value, AttrDef *out) { \ | FMK_FUNC_DEV_VISIBILITY void SetAttrDef(ARG_TYPE value, AttrDef *out) { \ | ||||
GE_CHECK_NOTNULL_JUST_RETURN(out); \ | GE_CHECK_NOTNULL_JUST_RETURN(out); \ | ||||
@@ -312,4 +312,4 @@ DEFINE_GET_ATTR_LIST_SIZE(const std::string &, uint32_t, u); | |||||
DEFINE_GET_ATTR_LIST_SIZE(const std::string &, float, f); | DEFINE_GET_ATTR_LIST_SIZE(const std::string &, float, f); | ||||
DEFINE_GET_ATTR_LIST_SIZE(const std::string &, double, f); | DEFINE_GET_ATTR_LIST_SIZE(const std::string &, double, f); | ||||
DEFINE_GET_ATTR_LIST_SIZE(const std::string &, bool, b); | DEFINE_GET_ATTR_LIST_SIZE(const std::string &, bool, b); | ||||
} // namespace ge | |||||
} // namespace domi |
@@ -25,10 +25,10 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/op/attr_define.h" | |||||
#include "framework/common/op/attr_value_util.h" | #include "framework/common/op/attr_value_util.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -69,8 +69,6 @@ const uint32_t FOR_LIMIT_INPUT = 1; | |||||
const uint32_t FOR_DELTA_INPUT = 2; | const uint32_t FOR_DELTA_INPUT = 2; | ||||
const uint32_t FOR_DATA_INPUT = 3; | const uint32_t FOR_DATA_INPUT = 3; | ||||
const int NORMAL_TENSOR_SIZE = 4; | |||||
// Get the value of key from attr | // Get the value of key from attr | ||||
#define AIPP_GET_ATTR_VALUE(KEY, ATTR_TYPE) \ | #define AIPP_GET_ATTR_VALUE(KEY, ATTR_TYPE) \ | ||||
if (aipp_attr.GetItem(#KEY).GetValue<ATTR_TYPE>(KEY) != SUCCESS) { \ | if (aipp_attr.GetItem(#KEY).GetValue<ATTR_TYPE>(KEY) != SUCCESS) { \ | ||||
@@ -179,7 +177,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::TransferDim(con | |||||
for (auto dim_temp : dim) { | for (auto dim_temp : dim) { | ||||
new_dim_list.push_back(dim_temp); | new_dim_list.push_back(dim_temp); | ||||
} | } | ||||
if (input_shape_size > DIM_DEFAULT_SIZE) { | |||||
if (input_shape_size > domi::DIM_DEFAULT_SIZE) { | |||||
dim_vector = dim; | dim_vector = dim; | ||||
GELOGI("Dim_vector size is %zu, do not to transfer dim", input_shape_size); | GELOGI("Dim_vector size is %zu, do not to transfer dim", input_shape_size); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -182,7 +182,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In | |||||
return SUCCESS; | return SUCCESS; | ||||
} else { | } else { | ||||
std::string prof_options_str = std::string(prof_options); | std::string prof_options_str = std::string(prof_options); | ||||
profiling_opts_ = StringUtils::Split(prof_options_str, ':'); | |||||
profiling_opts_ = domi::StringUtils::Split(prof_options_str, ':'); | |||||
is_profiling_ = true; | is_profiling_ = true; | ||||
} | } | ||||
GELOGI("The profiling in options is %s, %s", is_profiling, prof_options); | GELOGI("The profiling in options is %s, %s", is_profiling, prof_options); | ||||
@@ -314,119 +314,122 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | ||||
const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { | |||||
const std::vector<TaskDescInfo> &task_desc_info) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | ||||
if (reporter == nullptr) { | if (reporter == nullptr) { | ||||
GELOGI("Profiling report is nullptr!"); | GELOGI("Profiling report is nullptr!"); | ||||
return; | return; | ||||
} | } | ||||
std::string data; | std::string data; | ||||
for (const auto &task : task_desc_info) { | |||||
std::string op_name = task.op_name; | |||||
uint32_t block_dim = task.block_dim; | |||||
uint32_t task_id = task.task_id; | |||||
uint32_t stream_id = task.stream_id; | |||||
data = op_name.append(" ").append(std::to_string(block_dim) | |||||
.append(" ") | |||||
.append(std::to_string(task_id)) | |||||
.append(" ") | |||||
.append(std::to_string(stream_id)) | |||||
.append("\n")); | |||||
Msprof::Engine::ReporterData reporter_data{}; | |||||
reporter_data.deviceId = device_id; | |||||
reporter_data.data = (unsigned char *)data.c_str(); | |||||
reporter_data.dataLen = data.size(); | |||||
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); | |||||
if (ret != EOK) { | |||||
GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); | |||||
return; | |||||
} | |||||
for (size_t i = 0; i < device_id_.size(); ++i) { | |||||
for (const auto &task : task_desc_info) { | |||||
std::string op_name = task.op_name; | |||||
uint32_t block_dim = task.block_dim; | |||||
uint32_t task_id = task.task_id; | |||||
uint32_t stream_id = task.stream_id; | |||||
data = op_name.append(" ").append(std::to_string(block_dim) | |||||
.append(" ") | |||||
.append(std::to_string(task_id)) | |||||
.append(" ") | |||||
.append(std::to_string(stream_id)) | |||||
.append("\n")); | |||||
Msprof::Engine::ReporterData reporter_data{}; | |||||
reporter_data.deviceId = device_id_[i]; | |||||
reporter_data.data = (unsigned char *)data.c_str(); | |||||
reporter_data.dataLen = data.size(); | |||||
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); | |||||
if (ret != EOK) { | |||||
GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); | |||||
return; | |||||
} | |||||
ret = reporter->Report(&reporter_data); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Reporter data of task_desc_info fail!"); | |||||
return; | |||||
ret = reporter->Report(&reporter_data); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Reporter data of task_desc_info fail!"); | |||||
return; | |||||
} | |||||
} | } | ||||
} | |||||
data.clear(); | |||||
data.clear(); | |||||
} | |||||
#endif | #endif | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | ||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { | |||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | ||||
GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); | GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); | ||||
std::string data; | std::string data; | ||||
for (const auto &graph : compute_graph_desc_info) { | |||||
data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); | |||||
for (size_t i = 0; i < graph.input_format.size(); ++i) { | |||||
data.append(" input_id:") | |||||
.append(std::to_string(i)) | |||||
.append(" input_format:") | |||||
.append(std::to_string(graph.input_format.at(i))) | |||||
.append(" input_data_type:") | |||||
.append(std::to_string(graph.input_data_type.at(i))) | |||||
.append(" input_shape:\""); | |||||
size_t input_shape_len = graph.input_shape.at(i).size(); | |||||
if (input_shape_len == 0) { | |||||
data.append(""); | |||||
} else if (input_shape_len == 1) { | |||||
data.append(std::to_string(graph.input_shape.at(i).at(0))); | |||||
} else { | |||||
for (size_t j = 0; j < input_shape_len - 1; ++j) { | |||||
data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); | |||||
for (size_t idx = 0; idx < device_id_.size(); ++idx) { | |||||
for (const auto &graph : compute_graph_desc_info) { | |||||
data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); | |||||
for (size_t i = 0; i < graph.input_format.size(); ++i) { | |||||
data.append(" input_id:") | |||||
.append(std::to_string(i)) | |||||
.append(" input_format:") | |||||
.append(std::to_string(graph.input_format.at(i))) | |||||
.append(" input_data_type:") | |||||
.append(std::to_string(graph.input_data_type.at(i))) | |||||
.append(" input_shape:\""); | |||||
size_t input_shape_len = graph.input_shape.at(i).size(); | |||||
if (input_shape_len == 0) { | |||||
data.append(""); | |||||
} else if (input_shape_len == 1) { | |||||
data.append(std::to_string(graph.input_shape.at(i).at(0))); | |||||
} else { | |||||
for (size_t j = 0; j < input_shape_len - 1; ++j) { | |||||
data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); | |||||
} | |||||
data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); | |||||
} | } | ||||
data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); | |||||
} | |||||
data.append("\""); | |||||
} | |||||
data.append("\""); | |||||
} | |||||
for (size_t i = 0; i < graph.output_format.size(); ++i) { | |||||
data.append(" output_id:") | |||||
.append(std::to_string(i)) | |||||
.append(" output_format:") | |||||
.append(std::to_string(graph.output_format.at(i))) | |||||
.append(" output_data_type:") | |||||
.append(std::to_string(graph.output_data_type.at(i))) | |||||
.append(" output_shape:\""); | |||||
size_t output_shape_len = graph.output_shape.at(i).size(); | |||||
if (output_shape_len == 0) { | |||||
data.append(""); | |||||
} else if (output_shape_len == 1) { | |||||
data.append(std::to_string(graph.output_shape.at(i).at(0))); | |||||
} else { | |||||
for (size_t j = 0; j < output_shape_len - 1; ++j) { | |||||
data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); | |||||
for (size_t i = 0; i < graph.output_format.size(); ++i) { | |||||
data.append(" output_id:") | |||||
.append(std::to_string(i)) | |||||
.append(" output_format:") | |||||
.append(std::to_string(graph.output_format.at(i))) | |||||
.append(" output_data_type:") | |||||
.append(std::to_string(graph.output_data_type.at(i))) | |||||
.append(" output_shape:\""); | |||||
size_t output_shape_len = graph.output_shape.at(i).size(); | |||||
if (output_shape_len == 0) { | |||||
data.append(""); | |||||
} else if (output_shape_len == 1) { | |||||
data.append(std::to_string(graph.output_shape.at(i).at(0))); | |||||
} else { | |||||
for (size_t j = 0; j < output_shape_len - 1; ++j) { | |||||
data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); | |||||
} | |||||
data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 1))); | |||||
} | } | ||||
data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 1))); | |||||
data.append("\""); | |||||
} | } | ||||
data.append("\""); | |||||
} | |||||
data.append("\n"); | |||||
data.append("\n"); | |||||
Msprof::Engine::ReporterData reporter_data{}; | |||||
Report(device_id, data, *reporter, reporter_data); | |||||
Msprof::Engine::ReporterData reporter_data{}; | |||||
Report(idx, data, *reporter, reporter_data); | |||||
data.clear(); | |||||
data.clear(); | |||||
} | |||||
} | } | ||||
#endif | #endif | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( | ||||
const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | |||||
const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, | |||||
Msprof::Engine::ReporterData &reporter_data) { | Msprof::Engine::ReporterData &reporter_data) { | ||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
size_t index = data.size() / kReportMaxLen; | size_t index = data.size() / kReportMaxLen; | ||||
if (index >= 1) { | if (index >= 1) { | ||||
reporter_data.deviceId = device_id; | |||||
reporter_data.deviceId = device_id_[idx]; | |||||
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | ||||
GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); | ||||
for (size_t i = 0; i < index; ++i) { | for (size_t i = 0; i < index; ++i) { | ||||
@@ -442,7 +445,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); | GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); | ||||
} | } | ||||
} else { | } else { | ||||
reporter_data.deviceId = device_id; | |||||
reporter_data.deviceId = device_id_[idx]; | |||||
reporter_data.data = (unsigned char *)data.c_str(); | reporter_data.data = (unsigned char *)data.c_str(); | ||||
reporter_data.dataLen = data.size(); | reporter_data.dataLen = data.size(); | ||||
int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); | ||||
@@ -457,24 +460,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | ||||
const std::vector<TaskDescInfo> &task_desc_info, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | const std::vector<TaskDescInfo> &task_desc_info, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | ||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
int32_t device_id = 0; | |||||
rtError_t rt_ret = rtGetDevice(&device_id); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); | |||||
return; | |||||
} | |||||
GELOGI("current device_id:%d", device_id); | |||||
auto ret = std::find(device_id_.begin(), device_id_.end(), device_id); | |||||
if (ret == device_id_.end()) { | |||||
GELOGE(FAILED, "get valid device_id failed, profiling report failed."); | |||||
return; | |||||
} | |||||
GELOGI("start ProfilingTaskDescInfo."); | GELOGI("start ProfilingTaskDescInfo."); | ||||
ProfilingTaskDescInfo(task_desc_info, device_id); | |||||
ProfilingTaskDescInfo(task_desc_info); | |||||
GELOGI("start ProfilingGraphDescInfo."); | GELOGI("start ProfilingGraphDescInfo."); | ||||
ProfilingGraphDescInfo(compute_graph_desc_info, device_id); | |||||
ProfilingGraphDescInfo(compute_graph_desc_info); | |||||
GELOGI("Report profiling data for GE end."); | GELOGI("Report profiling data for GE end."); | ||||
#endif | #endif | ||||
} | } | ||||
@@ -50,11 +50,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
void ReportProfilingData(const std::vector<TaskDescInfo> &task_desc_info, | void ReportProfilingData(const std::vector<TaskDescInfo> &task_desc_info, | ||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info); | const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info); | ||||
void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | |||||
void Report(const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, | |||||
Msprof::Engine::ReporterData &reporter_data); | Msprof::Engine::ReporterData &reporter_data); | ||||
void ProfilingTaskDescInfo(const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id); | |||||
void ProfilingGraphDescInfo(const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | |||||
const int32_t &device_id); | |||||
void ProfilingTaskDescInfo(const std::vector<TaskDescInfo> &task_desc_info); | |||||
void ProfilingGraphDescInfo(const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info); | |||||
void SetProfilingConfig(const string &profiling_cfg); | void SetProfilingConfig(const string &profiling_cfg); | ||||
vector<int32_t> GetProfilingDeviceId() const { return device_id_; } | vector<int32_t> GetProfilingDeviceId() const { return device_id_; } | ||||
@@ -59,7 +59,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::Init(co | |||||
// Load file contents | // Load file contents | ||||
bool PropertiesManager::LoadFileContent(const std::string &file_path) { | bool PropertiesManager::LoadFileContent(const std::string &file_path) { | ||||
// Normalize the path | // Normalize the path | ||||
string resolved_file_path = RealPath(file_path.c_str()); | |||||
string resolved_file_path = domi::RealPath(file_path.c_str()); | |||||
if (resolved_file_path.empty()) { | if (resolved_file_path.empty()) { | ||||
DOMI_LOGE("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); | DOMI_LOGE("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); | ||||
return false; | return false; | ||||
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -26,13 +27,14 @@ const std::string DUMP_LAYER = "layer"; | |||||
const std::string DUMP_FILE_PATH = "path"; | const std::string DUMP_FILE_PATH = "path"; | ||||
} // namespace ge | } // namespace ge | ||||
using ge::OpTypeRegistrar; | |||||
namespace ge { | |||||
namespace domi { | |||||
const int DEFAULT_FORMAT = static_cast<const int>(ge::FORMAT_NCHW); | const int DEFAULT_FORMAT = static_cast<const int>(ge::FORMAT_NCHW); | ||||
// Supported public property names | |||||
const std::string PROP_OME_START_TIME = "ome_start_time"; // start time | |||||
const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; // dump path | |||||
const std::string PROP_OME_LOG_PATH = "ome_log_path"; // log path | |||||
/** | |||||
* @brief Supported public property names | |||||
*/ | |||||
const std::string PROP_OME_START_TIME = "ome_start_time"; /**< start time */ | |||||
const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; /**< dump path */ | |||||
const std::string PROP_OME_LOG_PATH = "ome_log_path"; /**< log path */ | |||||
// Profile related constant | // Profile related constant | ||||
const uint32_t CCE_PROFILE_ON = 0; | const uint32_t CCE_PROFILE_ON = 0; | ||||
@@ -385,7 +387,6 @@ REGISTER_OPTYPE_DEFINE(STREAMSWITCH, "StreamSwitch"); | |||||
REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | ||||
REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | ||||
REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | ||||
REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DEFINE(SEND, "Send"); | REGISTER_OPTYPE_DEFINE(SEND, "Send"); | ||||
@@ -393,7 +394,6 @@ REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | |||||
REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | ||||
REGISTER_OPTYPE_DEFINE(LABELGOTOEX, "LabelGotoEx"); | |||||
REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | ||||
REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
@@ -469,315 +469,315 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 8589934592; // Max size of 8 GB. | |||||
const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. | const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. | ||||
#endif | #endif | ||||
/// | |||||
///@brief Magic number of model file | |||||
/// | |||||
/** | |||||
* @brief Magic number of model file | |||||
*/ | |||||
const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number | const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number | ||||
/// | |||||
///@brief Model head length | |||||
/// | |||||
/** | |||||
* @brief Model head length | |||||
*/ | |||||
const uint32_t MODEL_FILE_HEAD_LEN = 256; | const uint32_t MODEL_FILE_HEAD_LEN = 256; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Input node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Input node type | |||||
*/ | |||||
const std::string INPUT_TYPE = "Input"; | const std::string INPUT_TYPE = "Input"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief AIPP label, label AIPP conv operator | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief AIPP label, label AIPP conv operator | |||||
*/ | |||||
const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; | const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief AIPP label, label aipp data operator | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief AIPP label, label aipp data operator | |||||
*/ | |||||
const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; | const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Record the w dimension of model input corresponding to dynamic AIPP | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Record the w dimension of model input corresponding to dynamic AIPP | |||||
*/ | |||||
const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; | const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Record the H dimension of model input corresponding to dynamic AIPP | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Record the H dimension of model input corresponding to dynamic AIPP | |||||
*/ | |||||
const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; | const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief The tag of the data operator. Mark this input to the dynamic AIPP operator | |||||
*/ | |||||
const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; | const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief DATA node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief DATA node type | |||||
*/ | |||||
const std::string DATA_TYPE = "Data"; | const std::string DATA_TYPE = "Data"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief DATA node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief DATA node type | |||||
*/ | |||||
const std::string AIPP_DATA_TYPE = "AippData"; | const std::string AIPP_DATA_TYPE = "AippData"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Frame operator type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Frame operator type | |||||
*/ | |||||
const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; | const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Data node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Data node type | |||||
*/ | |||||
const std::string ANN_DATA_TYPE = "AnnData"; | const std::string ANN_DATA_TYPE = "AnnData"; | ||||
const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; | const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; | ||||
const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; | const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; | ||||
const std::string ANN_CONV_TYPE = "AnnConvolution"; | const std::string ANN_CONV_TYPE = "AnnConvolution"; | ||||
const std::string ANN_FC_TYPE = "AnnFullConnection"; | const std::string ANN_FC_TYPE = "AnnFullConnection"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Convolution node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Convolution node type | |||||
*/ | |||||
const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; | const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; | ||||
const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; | const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Convolution node type | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Convolution node type | |||||
*/ | |||||
const std::string OP_TYPE_CONVOLUTION = "Convolution"; | const std::string OP_TYPE_CONVOLUTION = "Convolution"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Add convolution node name to AIPP | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Add convolution node name to AIPP | |||||
*/ | |||||
const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; | const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Operator configuration item separator | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Operator configuration item separator | |||||
*/ | |||||
const std::string OP_CONF_DELIMITER = ":"; | const std::string OP_CONF_DELIMITER = ":"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief attr value name | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief attr value name | |||||
*/ | |||||
const std::string ATTR_NAME_VALUE1 = "value1"; | const std::string ATTR_NAME_VALUE1 = "value1"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief attr value name, 6d_2_4d C | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief attr value name, 6d_2_4d C | |||||
*/ | |||||
const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; | const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief alpha default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief alpha default value | |||||
*/ | |||||
const float ALPHA_DEFAULT_VALUE = 1.0; | const float ALPHA_DEFAULT_VALUE = 1.0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief beta default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief beta default value | |||||
*/ | |||||
const float BETA_DEFAULT_VALUE = 0.0; | const float BETA_DEFAULT_VALUE = 0.0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief coef default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief coef default value | |||||
*/ | |||||
const float COEF_DEFAULT_VALUE = 0.0; | const float COEF_DEFAULT_VALUE = 0.0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Relu6 coef value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Relu6 coef value | |||||
*/ | |||||
const float RELU6_COEF = 6.0; | const float RELU6_COEF = 6.0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief stride default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief stride default value | |||||
*/ | |||||
const uint32_t STRIDE_DEFAULT_VALUE = 1; | const uint32_t STRIDE_DEFAULT_VALUE = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief pad default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief pad default value | |||||
*/ | |||||
const uint32_t PAD_DEFAULT_VALUE = 0; | const uint32_t PAD_DEFAULT_VALUE = 0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief dilation default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief dilation default value | |||||
*/ | |||||
const int DILATION_DEFAULT_VALUE = 1; | const int DILATION_DEFAULT_VALUE = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief kernel default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief kernel default value | |||||
*/ | |||||
const uint32_t KERNEL_DEFAULT_VALUE = 0; | const uint32_t KERNEL_DEFAULT_VALUE = 0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief defaule convolution group size | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief defaule convolution group size | |||||
*/ | |||||
const uint32_t DEFAULT_CONV_GROUP = 1; | const uint32_t DEFAULT_CONV_GROUP = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Default deconvolution adj | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Default deconvolution adj | |||||
*/ | |||||
const uint32_t DEFAULT_DECONV_ADJ = 0; | const uint32_t DEFAULT_DECONV_ADJ = 0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Represents value 1 | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Represents value 1 | |||||
*/ | |||||
const uint32_t NUM_ONE = 1; | const uint32_t NUM_ONE = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief spatial dim size default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief spatial dim size default value | |||||
*/ | |||||
const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; | const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief dim extended default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief dim extended default value | |||||
*/ | |||||
const int32_t DIM_DEFAULT_VALUE = 1; | const int32_t DIM_DEFAULT_VALUE = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief The first weight list in opdef is filter | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief The first weight list in opdef is filter | |||||
*/ | |||||
const int32_t WEIGHT_FILTER_INDEX = 0; | const int32_t WEIGHT_FILTER_INDEX = 0; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief The second weight list in opdef is bias | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief The second weight list in opdef is bias | |||||
*/ | |||||
const int32_t WEIGHT_BIAS_INDEX = 1; | const int32_t WEIGHT_BIAS_INDEX = 1; | ||||
const int32_t TENSOR_ND_SUPPORT_SIZE = 8; | const int32_t TENSOR_ND_SUPPORT_SIZE = 8; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief NCHW index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief NCHW index default value | |||||
*/ | |||||
const uint32_t NCHW_DIM_N = 0; | const uint32_t NCHW_DIM_N = 0; | ||||
const uint32_t NCHW_DIM_C = 1; | const uint32_t NCHW_DIM_C = 1; | ||||
const uint32_t NCHW_DIM_H = 2; | const uint32_t NCHW_DIM_H = 2; | ||||
const uint32_t NCHW_DIM_W = 3; | const uint32_t NCHW_DIM_W = 3; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief KCHW index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief KCHW index default value | |||||
*/ | |||||
const uint32_t KCHW_DIM_K = 0; | const uint32_t KCHW_DIM_K = 0; | ||||
const uint32_t KCHW_DIM_C = 1; | const uint32_t KCHW_DIM_C = 1; | ||||
const uint32_t KCHW_DIM_H = 2; | const uint32_t KCHW_DIM_H = 2; | ||||
const uint32_t KCHW_DIM_W = 3; | const uint32_t KCHW_DIM_W = 3; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief HWCK index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief HWCK index default value | |||||
*/ | |||||
const uint32_t HWCK_DIM_H = 0; | const uint32_t HWCK_DIM_H = 0; | ||||
const uint32_t HWCK_DIM_W = 1; | const uint32_t HWCK_DIM_W = 1; | ||||
const uint32_t HWCK_DIM_C = 2; | const uint32_t HWCK_DIM_C = 2; | ||||
const uint32_t HWCK_DIM_K = 3; | const uint32_t HWCK_DIM_K = 3; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief NHWC index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief NHWC index default value | |||||
*/ | |||||
const uint32_t NHWC_DIM_N = 0; | const uint32_t NHWC_DIM_N = 0; | ||||
const uint32_t NHWC_DIM_H = 1; | const uint32_t NHWC_DIM_H = 1; | ||||
const uint32_t NHWC_DIM_W = 2; | const uint32_t NHWC_DIM_W = 2; | ||||
const uint32_t NHWC_DIM_C = 3; | const uint32_t NHWC_DIM_C = 3; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief CHWN index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief CHWN index default value | |||||
*/ | |||||
const uint32_t CHWN_DIM_N = 3; | const uint32_t CHWN_DIM_N = 3; | ||||
const uint32_t CHWN_DIM_C = 0; | const uint32_t CHWN_DIM_C = 0; | ||||
const uint32_t CHWN_DIM_H = 1; | const uint32_t CHWN_DIM_H = 1; | ||||
const uint32_t CHWN_DIM_W = 2; | const uint32_t CHWN_DIM_W = 2; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief CHW index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief CHW index default value | |||||
*/ | |||||
const uint32_t CHW_DIM_C = 0; | const uint32_t CHW_DIM_C = 0; | ||||
const uint32_t CHW_DIM_H = 1; | const uint32_t CHW_DIM_H = 1; | ||||
const uint32_t CHW_DIM_W = 2; | const uint32_t CHW_DIM_W = 2; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief HWC index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief HWC index default value | |||||
*/ | |||||
const uint32_t HWC_DIM_H = 0; | const uint32_t HWC_DIM_H = 0; | ||||
const uint32_t HWC_DIM_W = 1; | const uint32_t HWC_DIM_W = 1; | ||||
const uint32_t HWC_DIM_C = 2; | const uint32_t HWC_DIM_C = 2; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief Pad index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Pad index default value | |||||
*/ | |||||
const uint32_t PAD_H_HEAD = 0; | const uint32_t PAD_H_HEAD = 0; | ||||
const uint32_t PAD_H_TAIL = 1; | const uint32_t PAD_H_TAIL = 1; | ||||
const uint32_t PAD_W_HEAD = 2; | const uint32_t PAD_W_HEAD = 2; | ||||
const uint32_t PAD_W_TAIL = 3; | const uint32_t PAD_W_TAIL = 3; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief window index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief window index default value | |||||
*/ | |||||
const uint32_t WINDOW_H = 0; | const uint32_t WINDOW_H = 0; | ||||
const uint32_t WINDOW_W = 1; | const uint32_t WINDOW_W = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief stride index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief stride index default value | |||||
*/ | |||||
const uint32_t STRIDE_H = 0; | const uint32_t STRIDE_H = 0; | ||||
const uint32_t STRIDE_W = 1; | const uint32_t STRIDE_W = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief dilation index default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief dilation index default value | |||||
*/ | |||||
const uint32_t DILATION_H = 0; | const uint32_t DILATION_H = 0; | ||||
const uint32_t DILATION_W = 1; | const uint32_t DILATION_W = 1; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief the num of XRBG channel | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief the num of XRBG channel | |||||
*/ | |||||
const uint32_t XRGB_CHN_NUM = 4; | const uint32_t XRGB_CHN_NUM = 4; | ||||
/// | |||||
///@ingroup domi_omg | |||||
///@brief global pooling default value | |||||
/// | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief global pooling default value | |||||
*/ | |||||
const bool DEFAULT_GLOBAL_POOLING = false; | const bool DEFAULT_GLOBAL_POOLING = false; | ||||
const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// | |||||
const uint32_t MODEL_VERSION = 0x10000000; /**< Model version 1.0 */ | |||||
// Eltwise's input size | // Eltwise's input size | ||||
const int ELTWISE_MIN_INPUT_SIZE = 2; | const int ELTWISE_MIN_INPUT_SIZE = 2; | ||||
// flowctrl | |||||
/* flowctrl */ | |||||
const std::string NODE_NAME_STREAM_SWITCH = "IteratorCtrl_StreamSwitch"; | const std::string NODE_NAME_STREAM_SWITCH = "IteratorCtrl_StreamSwitch"; | ||||
const std::string NODE_NAME_STREAM_ACTIVE = "IteratorCtrl_StreamActive"; | const std::string NODE_NAME_STREAM_ACTIVE = "IteratorCtrl_StreamActive"; | ||||
const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER = "npu_runconfig/iterations_per_loop"; | const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER = "npu_runconfig/iterations_per_loop"; | ||||
@@ -792,4 +792,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; | |||||
const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; | const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; | ||||
const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; | const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; | ||||
}; // namespace ge | |||||
}; // namespace domi |
@@ -57,7 +57,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||||
const int kMaxFileSizeLimit = INT_MAX; | const int kMaxFileSizeLimit = INT_MAX; | ||||
} // namespace | } // namespace | ||||
namespace ge { | |||||
namespace domi { | |||||
static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, return false, "incorrect parameter. nullptr == proto"); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, return false, "incorrect parameter. nullptr == proto"); | ||||
@@ -196,7 +196,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | ||||
auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
if (dir_path_len >= PATH_MAX) { | if (dir_path_len >= PATH_MAX) { | ||||
GELOGW("Directory path is too long."); | |||||
GELOGE(ge::FAILED, "Directory path is too long."); | |||||
return -1; | return -1; | ||||
} | } | ||||
char tmp_dir_path[PATH_MAX] = {0}; | char tmp_dir_path[PATH_MAX] = {0}; | ||||
@@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
directory_path.c_str()); | directory_path.c_str()); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -218,7 +218,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); | |||||
GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
directory_path.c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
@@ -338,7 +339,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | ||||
// The specified path is empty | // The specified path is empty | ||||
if (file_path.empty()) { | if (file_path.empty()) { | ||||
GELOGW("Path is empty."); | |||||
GELOGE(ge::FAILED, "Path is empty."); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -357,23 +358,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
// Unable to get absolute path (does not exist or does not have permission to access) | // Unable to get absolute path (does not exist or does not have permission to access) | ||||
if (real_path.empty()) { | if (real_path.empty()) { | ||||
GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGE(ge::FAILED, "Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
// The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
if (access(real_path.c_str(), R_OK) != 0) { | if (access(real_path.c_str(), R_OK) != 0) { | ||||
GELOGW("Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGE(ge::FAILED, "Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
// The specified path is empty | // The specified path is empty | ||||
if (file_path.empty()) { | if (file_path.empty()) { | ||||
GELOGW("Path is empty."); | |||||
GELOGE(ge::FAILED, "Path is empty."); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -393,8 +394,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
// Can get absolute path (file exists) | // Can get absolute path (file exists) | ||||
if (!real_path.empty()) { | if (!real_path.empty()) { | ||||
// File is not readable or writable | // File is not readable or writable | ||||
if (access(real_path.c_str(), W_OK | F_OK) != 0) { | |||||
GELOGW("Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
if (access(real_path.c_str(), R_OK | W_OK | F_OK) != 0) { | |||||
GELOGE(ge::FAILED, "Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -412,7 +413,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | ||||
// Determine whether the specified path is valid by creating the path | // Determine whether the specified path is valid by creating the path | ||||
if (CreateDirectory(prefix_path) != 0) { | if (CreateDirectory(prefix_path) != 0) { | ||||
GELOGW("Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
GELOGE(ge::FAILED, "Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -435,4 +436,4 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str | |||||
return true; | return true; | ||||
#endif | #endif | ||||
} | } | ||||
} // namespace ge | |||||
} // namespace domi |
@@ -47,7 +47,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../graph/load/new_model_manager/task_info/kernel_task_info.cc" | "../graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/label_set_task_info.cc" | "../graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -86,6 +85,7 @@ include_directories(${GE_SOURCE_DIR}/inc) | |||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -193,8 +193,15 @@ Status GeExecutor::Initialize() { | |||||
} | } | ||||
// Start profiling | // Start profiling | ||||
int32_t device_id = 0; | |||||
rtError_t rt_ret = rtGetDevice(&device_id); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); | |||||
return FAILED; | |||||
} | |||||
GELOGI("current device_id:%d", device_id); | |||||
Options profiling_options; | Options profiling_options; | ||||
profiling_options.device_id = 0; | |||||
profiling_options.device_id = device_id; | |||||
profiling_options.job_id = ""; | profiling_options.job_id = ""; | ||||
ProfilingManager::Instance().Init(profiling_options); | ProfilingManager::Instance().Init(profiling_options); | ||||
@@ -345,7 +352,7 @@ Status GeExecutor::LoadModelOffline(uint32_t &model_id, const std::string &path, | |||||
return GE_EXEC_NOT_INIT; | return GE_EXEC_NOT_INIT; | ||||
} | } | ||||
string filePath = RealPath(path.c_str()); | |||||
string filePath = domi::RealPath(path.c_str()); | |||||
if (filePath.empty()) { | if (filePath.empty()) { | ||||
GELOGE(ge::FAILED, "fileath is invalid. please check your text file '%s'.", path.c_str()); | GELOGE(ge::FAILED, "fileath is invalid. please check your text file '%s'.", path.c_str()); | ||||
return ge::FAILED; | return ge::FAILED; | ||||
@@ -396,6 +403,10 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { | |||||
return GE_EXEC_NOT_INIT; | return GE_EXEC_NOT_INIT; | ||||
} | } | ||||
// stop profiling | |||||
if (!ProfilingManager::Instance().ProfilingOpTraceOn() && ProfilingManager::Instance().ProfilingOn()) { | |||||
ProfilingManager::Instance().StopProfiling(); | |||||
} | |||||
return GraphLoader::UnloadModel(model_id); | return GraphLoader::UnloadModel(model_id); | ||||
} | } | ||||
@@ -554,7 +565,7 @@ Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_da | |||||
return GE_EXEC_NOT_INIT; | return GE_EXEC_NOT_INIT; | ||||
} | } | ||||
string filePath = RealPath(path.c_str()); | |||||
string filePath = domi::RealPath(path.c_str()); | |||||
if (filePath.empty()) { | if (filePath.empty()) { | ||||
GELOGE(ge::FAILED, "filePath is invalid. please check your text file '%s'.", path.c_str()); | GELOGE(ge::FAILED, "filePath is invalid. please check your text file '%s'.", path.c_str()); | ||||
return ge::FAILED; | return ge::FAILED; | ||||
@@ -35,6 +35,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -237,7 +237,7 @@ Status HostCpuEngine::LoadLib(const std::string &lib_path) { | |||||
} | } | ||||
Status HostCpuEngine::GetRealPath(std::string &path) { | Status HostCpuEngine::GetRealPath(std::string &path) { | ||||
std::string real_path = RealPath(path.c_str()); | |||||
std::string real_path = domi::RealPath(path.c_str()); | |||||
if (real_path.empty()) { | if (real_path.empty()) { | ||||
GELOGW("File path %s is invalid.", path.c_str()); | GELOGW("File path %s is invalid.", path.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -21,7 +21,7 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/operator.h" | #include "graph/operator.h" | ||||
#include "inc/register/register.h" | |||||
#include "register/register.h" | |||||
namespace ge { | namespace ge { | ||||
class HostCpuEngine { | class HostCpuEngine { | ||||
@@ -17,6 +17,8 @@ | |||||
#include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" | #include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" | ||||
#include <memory> | #include <memory> | ||||
#include "common/constant/constant.h" | #include "common/constant/constant.h" | ||||
#include "framework/common/debug/ge_log.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -26,6 +26,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework/common) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework/ge_runtime) | include_directories(${GE_SOURCE_DIR}/inc/framework/ge_runtime) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -447,11 +447,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | ||||
/// and that of unknown shape is zero too. | /// and that of unknown shape is zero too. | ||||
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | ||||
int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||||
if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||||
elem_num = 1; | |||||
} | |||||
int64_t elem_num = | |||||
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
if (constant->weight_data.size() < sizeof(uint64_t)) { | if (constant->weight_data.size() < sizeof(uint64_t)) { | ||||
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | ||||
return false; | return false; | ||||
@@ -28,6 +28,11 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
using domi::DATA; | |||||
using domi::ModelHelper; | |||||
using domi::NETOUTPUT; | |||||
using domi::NODE_NAME_NET_OUTPUT; | |||||
using domi::SaveParam; | |||||
using ge::ModelBufferData; | using ge::ModelBufferData; | ||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
@@ -101,7 +106,7 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
const char *path_env = std::getenv("ASCEND_OPP_PATH"); | const char *path_env = std::getenv("ASCEND_OPP_PATH"); | ||||
if (path_env != nullptr) { | if (path_env != nullptr) { | ||||
string path = path_env; | string path = path_env; | ||||
string file_path = RealPath(path.c_str()); | |||||
string file_path = domi::RealPath(path.c_str()); | |||||
if (file_path.empty()) { | if (file_path.empty()) { | ||||
GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | ||||
return; | return; | ||||
@@ -143,7 +148,7 @@ Status GeGenerator::Initialize(const map<string, string> &options) { | |||||
GELOGI("opsproto_path is %s", opsproto_path.c_str()); | GELOGI("opsproto_path is %s", opsproto_path.c_str()); | ||||
OpsProtoManager *manager = OpsProtoManager::Instance(); | OpsProtoManager *manager = OpsProtoManager::Instance(); | ||||
map<string, string> option_tmp; | map<string, string> option_tmp; | ||||
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
option_tmp.insert(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
(void)manager->Initialize(option_tmp); | (void)manager->Initialize(option_tmp); | ||||
Status ret = impl_->graph_manager_.Initialize(options); | Status ret = impl_->graph_manager_.Initialize(options); | ||||
@@ -258,7 +263,7 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
map<string, GeAttrValue> op_attrs = op_desc->GetAllAttrs(); | map<string, GeAttrValue> op_attrs = op_desc->GetAllAttrs(); | ||||
// 1. Create ComputeGraph. | // 1. Create ComputeGraph. | ||||
string name = ge::CurrentTimeInStr() + "_" + model_file_name; | |||||
string name = domi::CurrentTimeInStr() + "_" + model_file_name; | |||||
ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name); | ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name); | ||||
if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -116,7 +116,7 @@ Status_t OpTaskGernerator(const char *op_type, const OpTensor_t *in_tensor, int | |||||
CHECK_PARAM_NOT_NULL(om_file); | CHECK_PARAM_NOT_NULL(om_file); | ||||
const std::string om_file_name(om_file); | const std::string om_file_name(om_file); | ||||
std::string op_name = std::string(op_type) + "_" + std::to_string(ge::GetCurrentTimestap()); | |||||
std::string op_name = std::string(op_type) + "_" + std::to_string(domi::GetCurrentTimestap()); | |||||
ge::OpDescPtr op_desc = ge::MakeShared<ge::OpDesc>(op_name, op_type); | ge::OpDescPtr op_desc = ge::MakeShared<ge::OpDesc>(op_name, op_type); | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
return ge::FAILED; | return ge::FAILED; | ||||
@@ -18,15 +18,18 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
#include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
#include "graph/build/run_context.h" | |||||
#include "graph/build/stream_graph_optimizer.h" | #include "graph/build/stream_graph_optimizer.h" | ||||
#include "graph/build/run_context.h" | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
using domi::ATTR_MODEL_MEMORY_SIZE; | |||||
using domi::ATTR_MODEL_WEIGHT_SIZE; | |||||
using domi::BuildMode; | using domi::BuildMode; | ||||
using domi::DATA; | |||||
namespace { | namespace { | ||||
const int32_t kInvalidPerfLevel = -1; | const int32_t kInvalidPerfLevel = -1; | ||||
@@ -98,10 +101,8 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | ||||
GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | ||||
auto subgraph_map = graph_partitioner_.GetSubGraphMap(); | |||||
GE_TIMESTAMP_START(BuildSubgraph); | GE_TIMESTAMP_START(BuildSubgraph); | ||||
ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
ge::ModelBuilder builder(comp_graph, subgraph_ptr_list, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
GELOGI("[Build] invoke the other opskernel to generate task."); | GELOGI("[Build] invoke the other opskernel to generate task."); | ||||
@@ -137,7 +138,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
} | } | ||||
GE_TIMESTAMP_START(GetTaskInfo); | GE_TIMESTAMP_START(GetTaskInfo); | ||||
ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | |||||
ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_ptr_list, session_id); | |||||
GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | ||||
@@ -157,7 +158,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
} | } | ||||
Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | ||||
ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, | |||||
ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
uint64_t session_id) { | uint64_t session_id) { | ||||
GE_CHECK_NOTNULL(model_ptr); | GE_CHECK_NOTNULL(model_ptr); | ||||
GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
@@ -192,7 +193,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr | |||||
} | } | ||||
StreamGraphOptimizer stream_optimizer; | StreamGraphOptimizer stream_optimizer; | ||||
ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_map, run_context.GetRunContext()); | |||||
ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_ptr_list, run_context.GetRunContext()); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Optimize streamed subGraph fail."); | GELOGE(ret, "Optimize streamed subGraph fail."); | ||||
return ret; | return ret; | ||||
@@ -53,7 +53,7 @@ class GraphBuilder { | |||||
private: | private: | ||||
Status CalcOpParam(const ge::ComputeGraphPtr &graph); | Status CalcOpParam(const ge::ComputeGraphPtr &graph); | ||||
Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | ||||
Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); | |||||
std::vector<SubGraphInfoPtr> &subgraph_ptr_list, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status SetInputSize(const ge::NodePtr &node_ptr); | Status SetInputSize(const ge::NodePtr &node_ptr); | ||||
Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | ||||
Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | ||||
@@ -16,17 +16,22 @@ | |||||
#include "graph/build/logical_stream_allocator.h" | #include "graph/build/logical_stream_allocator.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/op/attr_define.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
using std::map; | using std::map; | ||||
using std::set; | using std::set; | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
using domi::ATTR_NAME_STREAM_LABEL; | |||||
using domi::CONSTANT; | |||||
using domi::CONSTANTOP; | |||||
using domi::HCOMALLREDUCE; | |||||
namespace { | namespace { | ||||
const char *const kAICPUEngineName = "DNN_VM_AICPU"; | const char *const kAICPUEngineName = "DNN_VM_AICPU"; | ||||
const char *const kAttrNameParentOpType = "parentOpType"; | const char *const kAttrNameParentOpType = "parentOpType"; | ||||
@@ -70,7 +75,7 @@ bool LogicalStreamPass::HasNonConstInputNode(const Subgraph &subgraph) const { | |||||
return false; | return false; | ||||
} | } | ||||
Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
map<string, int64_t> label_streams; | map<string, int64_t> label_streams; | ||||
@@ -97,7 +102,7 @@ Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> & | |||||
return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
@@ -129,7 +134,8 @@ Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt | |||||
return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status AssignByDependencyPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
if (IsHeadNodeExceeded(subgraphs)) { | if (IsHeadNodeExceeded(subgraphs)) { | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
@@ -297,7 +303,7 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
subgraph->stream_id = stream_id; | subgraph->stream_id = stream_id; | ||||
engine_next_streams_[engine_name] = stream_id + 1; | engine_next_streams_[engine_name] = stream_id + 1; | ||||
assigned_subgraphs_.emplace_back(subgraph); | |||||
assigned_subgraphs_.emplace(subgraph); | |||||
if ((stream_id + 1) > engine_stream_num_[engine_name]) { | if ((stream_id + 1) > engine_stream_num_[engine_name]) { | ||||
engine_stream_num_[engine_name] = stream_id + 1; | engine_stream_num_[engine_name] = stream_id + 1; | ||||
@@ -310,15 +316,6 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
} | } | ||||
void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | ||||
// If the parent stream is valid, the first assigned stream will reuse the parent stream id | |||||
// and other streams use new id. To ensure that the id of the new stream is continuous, | |||||
// we first subtract one from next_stream. | |||||
int64_t to_be_updated_stream = kInvalidStream; | |||||
if (context.parent_stream != kInvalidStream) { | |||||
context.next_stream--; | |||||
to_be_updated_stream = context.next_stream; | |||||
} | |||||
// Update the starting stream id for each engine. | // Update the starting stream id for each engine. | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
map<string, int64_t> engine_start_streams; | map<string, int64_t> engine_start_streams; | ||||
@@ -328,16 +325,10 @@ void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | |||||
next_stream += stream_count; | next_stream += stream_count; | ||||
} | } | ||||
// Update the subgraph streams assigned by engine. | |||||
// Update the subgraphs assigned by the engine. | |||||
for (auto &subgraph : assigned_subgraphs_) { | for (auto &subgraph : assigned_subgraphs_) { | ||||
subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | ||||
if (subgraph->stream_id == to_be_updated_stream) { | |||||
subgraph->stream_id = context.parent_stream; | |||||
GELOGI("Subgraph %s of engine %s reuses parent stream %ld.", subgraph->name.c_str(), | |||||
subgraph->engine_conf.id.c_str(), context.parent_stream); | |||||
} else { | |||||
GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
} | |||||
GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
} | } | ||||
} | } | ||||
@@ -351,7 +342,7 @@ void AssignByDependencyPass::UpdateReusedSubgraphs() { | |||||
} | } | ||||
} | } | ||||
Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
// Check if all subgraphs have been assigned a stream. | // Check if all subgraphs have been assigned a stream. | ||||
for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
const string &engine_name = subgraph->engine_conf.id; | const string &engine_name = subgraph->engine_conf.id; | ||||
@@ -367,7 +358,7 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
} | } | ||||
// Init the stream id of node. | // Init the stream id of node. | ||||
for (NodePtr &node : graph->GetDirectNode()) { | |||||
for (NodePtr &node : whole_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
node->GetOpDesc()->SetStreamId(kInvalidStream); | node->GetOpDesc()->SetStreamId(kInvalidStream); | ||||
} | } | ||||
@@ -389,11 +380,76 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
} | } | ||||
// Update stream id for nodes belong to skipped engine subgraph | // Update stream id for nodes belong to skipped engine subgraph | ||||
GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs)); | |||||
GE_CHK_STATUS_RET(UpdateForSkippedEngine(whole_graph, subgraphs)); | |||||
RefreshContinuousStreams(whole_graph, context); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status AllReduceParallelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
if (!context.hcom_parallel) { | |||||
return NOT_CHANGED; | |||||
} | |||||
GELOGI("AllReduceParallelPass is enabled."); | |||||
GraphUtils::DumpGEGraph(whole_graph, "BeforeAllReduceParallel"); | |||||
// All successors of HcomAllReduce. | |||||
set<NodePtr> all_reduce_succs; | |||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
continue; | |||||
} | |||||
string reduce_stream_label; | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
// ATTR_NAME_STREAM_LABEL is optional. | |||||
(void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
set<NodePtr> cur_nodes = {node}; | |||||
while (!cur_nodes.empty()) { | |||||
set<NodePtr> all_out_data_nodes; | |||||
for (auto &curr_node : cur_nodes) { | |||||
for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
string out_stream_label; | |||||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
// ATTR_NAME_STREAM_LABEL is optional. | |||||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
if (out_stream_label == reduce_stream_label) { | |||||
all_reduce_succs.emplace(out_node); | |||||
all_out_data_nodes.emplace(out_node); | |||||
} | |||||
} | |||||
} | |||||
cur_nodes = all_out_data_nodes; | |||||
} | |||||
} | |||||
map<int64_t, int64_t> old_stream_to_new; | |||||
for (const NodePtr &node : all_reduce_succs) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
if (old_stream != kInvalidStream) { | |||||
int64_t new_stream = kInvalidStream; | |||||
auto iter = old_stream_to_new.find(old_stream); | |||||
if (iter != old_stream_to_new.end()) { | |||||
new_stream = iter->second; | |||||
} else { | |||||
new_stream = context.next_stream; | |||||
context.next_stream++; | |||||
old_stream_to_new.emplace(old_stream, new_stream); | |||||
} | |||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
} | |||||
} | |||||
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
} | |||||
int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | ||||
set<int64_t> stream_ids; | set<int64_t> stream_ids; | ||||
@@ -421,11 +477,11 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | |||||
return kInvalidStream; | return kInvalidStream; | ||||
} | } | ||||
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, | |||||
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, | |||||
const vector<SubgraphPtr> &subgraphs) { | const vector<SubgraphPtr> &subgraphs) { | ||||
set<OpDescPtr> nodes_to_be_updated; | set<OpDescPtr> nodes_to_be_updated; | ||||
// Check if subgraph is engine skipped and without stream label or not | |||||
// Check if sub graph is engine skipped and without stream label or not | |||||
for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | ||||
auto graph = subgraph->subgraph_info.GetSubGraph(); | auto graph = subgraph->subgraph_info.GetSubGraph(); | ||||
@@ -441,7 +497,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph | |||||
} | } | ||||
// Try reassign the stream id | // Try reassign the stream id | ||||
for (ge::NodePtr &node : graph->GetDirectNode()) { | |||||
for (ge::NodePtr &node : whole_graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
int64_t stream_id = op_desc->GetStreamId(); | int64_t stream_id = op_desc->GetStreamId(); | ||||
@@ -458,7 +514,6 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph | |||||
} | } | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -475,65 +530,40 @@ bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { | |||||
return true; | return true; | ||||
} | } | ||||
Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
if (!context.hcom_parallel) { | |||||
return NOT_CHANGED; | |||||
} | |||||
GELOGI("AllReduceParallelPass is enabled."); | |||||
GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); | |||||
// All successors of HcomAllReduce. | |||||
set<NodePtr> all_reduce_succs; | |||||
for (const NodePtr &node : graph->GetDirectNode()) { | |||||
if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
continue; | |||||
} | |||||
string reduce_stream_label; | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
(void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
void NodeStreamUpdatePass::RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const { | |||||
int64_t stream_num = context.next_stream; | |||||
vector<bool> stream_has_node(stream_num); | |||||
set<NodePtr> cur_nodes = {node}; | |||||
while (!cur_nodes.empty()) { | |||||
set<NodePtr> all_out_data_nodes; | |||||
for (auto &curr_node : cur_nodes) { | |||||
for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
string out_stream_label; | |||||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
if (out_stream_label == reduce_stream_label) { | |||||
all_reduce_succs.emplace(out_node); | |||||
all_out_data_nodes.emplace(out_node); | |||||
} | |||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
if (node != nullptr) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
stream_has_node[stream_id] = true; | |||||
} | } | ||||
} | } | ||||
cur_nodes = all_out_data_nodes; | |||||
} | } | ||||
} | } | ||||
map<int64_t, int64_t> old_stream_to_new; | |||||
for (const NodePtr &node : all_reduce_succs) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
if (old_stream != kInvalidStream) { | |||||
int64_t new_stream = kInvalidStream; | |||||
auto iter = old_stream_to_new.find(old_stream); | |||||
if (iter != old_stream_to_new.end()) { | |||||
new_stream = iter->second; | |||||
} else { | |||||
new_stream = context.next_stream; | |||||
context.next_stream++; | |||||
old_stream_to_new.emplace(old_stream, new_stream); | |||||
} | |||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
context.next_stream = 0; | |||||
vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
if (stream_has_node[old_stream]) { | |||||
old_to_new_streams[old_stream] = context.next_stream; | |||||
++context.next_stream; | |||||
} | } | ||||
} | } | ||||
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | ||||
@@ -542,10 +572,9 @@ LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> | |||||
context_.hcom_parallel = hcom_parallel; | context_.hcom_parallel = hcom_parallel; | ||||
} | } | ||||
Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const vector<SubGraphInfoPtr> &subgraph_infos, | |||||
int64_t &stream_num) { | int64_t &stream_num) { | ||||
GE_CHECK_NOTNULL(whole_graph); | GE_CHECK_NOTNULL(whole_graph); | ||||
map<string, EngineConfPtr> engine_confs; | map<string, EngineConfPtr> engine_confs; | ||||
GE_TIMESTAMP_START(InitEngineConfs); | GE_TIMESTAMP_START(InitEngineConfs); | ||||
for (const auto &item : scheduler_confs_) { | for (const auto &item : scheduler_confs_) { | ||||
@@ -559,64 +588,16 @@ Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const | |||||
} | } | ||||
GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | ||||
Status status = DoAssign(whole_graph, subgraph_map, engine_confs); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Assign streams failed."); | |||||
return status; | |||||
} | |||||
vector<ComputeGraphPtr> subgraphs = whole_graph->GetAllSubgraphs(); | |||||
for (const ComputeGraphPtr &subgraph : subgraphs) { | |||||
Status status = DoAssign(subgraph, subgraph_map, engine_confs); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Assign streams failed."); | |||||
return status; | |||||
} | |||||
} | |||||
RefreshContinuousStreams(whole_graph); | |||||
stream_num = context_.next_stream; | |||||
GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
return SUCCESS; | |||||
} | |||||
Status LogicalStreamAllocator::DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
const map<string, EngineConfPtr> &engine_confs) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
NodePtr parent_node = graph->GetParentNode(); | |||||
if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { | |||||
context_.parent_stream = kInvalidStream; | |||||
} else { | |||||
context_.parent_stream = parent_node->GetOpDesc()->GetStreamId(); | |||||
} | |||||
auto iter = subgraph_map.find(graph); | |||||
if (iter == subgraph_map.end()) { | |||||
GELOGE(FAILED, "Graph %s not found.", graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
const vector<SubGraphInfoPtr> &subgraph_info_list = iter->second; | |||||
vector<SubgraphPtr> subgraphs; | vector<SubgraphPtr> subgraphs; | ||||
GE_TIMESTAMP_START(ConvertSubgraphs); | GE_TIMESTAMP_START(ConvertSubgraphs); | ||||
Status status = ConvertSubgraphs(subgraph_info_list, engine_confs, subgraphs); | |||||
Status status = ConvertSubgraphs(subgraph_infos, engine_confs, subgraphs); | |||||
GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "Create subgraphs failed."); | GELOGE(status, "Create subgraphs failed."); | ||||
return status; | return status; | ||||
} | } | ||||
GELOGI("Subgraphs of graph %s:", graph->GetName().c_str()); | |||||
for (const auto &subgraph : subgraphs) { | |||||
if (subgraph != nullptr) { | |||||
GELOGI("subgraph: %s", subgraph->name.c_str()); | |||||
} | |||||
} | |||||
return RunPasses(graph, subgraphs); | |||||
return RunPasses(whole_graph, subgraphs, stream_num); | |||||
} | } | ||||
Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | ||||
@@ -655,7 +636,8 @@ Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &s | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vector<SubgraphPtr> &subgraphs) { | |||||
Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
int64_t &stream_num) { | |||||
vector<LogicalStreamPassPtr> passes; | vector<LogicalStreamPassPtr> passes; | ||||
passes.emplace_back(MakeShared<AssignByLabelPass>()); | passes.emplace_back(MakeShared<AssignByLabelPass>()); | ||||
passes.emplace_back(MakeShared<IndependentStreamPass>()); | passes.emplace_back(MakeShared<IndependentStreamPass>()); | ||||
@@ -666,7 +648,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec | |||||
for (auto &pass : passes) { | for (auto &pass : passes) { | ||||
GE_CHECK_NOTNULL(pass); | GE_CHECK_NOTNULL(pass); | ||||
Status status = pass->Run(graph, subgraphs, context_); | |||||
Status status = pass->Run(whole_graph, subgraphs, context_); | |||||
if (status == SUCCESS) { | if (status == SUCCESS) { | ||||
GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | ||||
} else if (status == NOT_CHANGED) { | } else if (status == NOT_CHANGED) { | ||||
@@ -677,42 +659,9 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | |||||
} | |||||
void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &graph) { | |||||
int64_t stream_num = context_.next_stream; | |||||
vector<bool> stream_has_node(stream_num); | |||||
for (const NodePtr &node : graph->GetAllNodes()) { | |||||
if (node != nullptr) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
stream_has_node[stream_id] = true; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
context_.next_stream = 0; | |||||
vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
if (stream_has_node[old_stream]) { | |||||
old_to_new_streams[old_stream] = context_.next_stream; | |||||
++context_.next_stream; | |||||
} | |||||
} | |||||
stream_num = context_.next_stream; | |||||
GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
for (const NodePtr &node : graph->GetAllNodes()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -60,7 +60,7 @@ class LogicalStreamPass { | |||||
}; | }; | ||||
struct Context { | struct Context { | ||||
int64_t parent_stream = kInvalidStream; | |||||
// Next stream id. | |||||
int64_t next_stream = 0; | int64_t next_stream = 0; | ||||
bool hcom_parallel = false; | bool hcom_parallel = false; | ||||
}; | }; | ||||
@@ -71,7 +71,7 @@ class LogicalStreamPass { | |||||
virtual ~LogicalStreamPass() = default; | virtual ~LogicalStreamPass() = default; | ||||
const std::string &GetName() const; | const std::string &GetName() const; | ||||
virtual Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
virtual Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
protected: | protected: | ||||
bool IsEngineSkip(const Subgraph &subgraph) const; | bool IsEngineSkip(const Subgraph &subgraph) const; | ||||
@@ -93,21 +93,21 @@ using LogicalStreamPassPtr = std::shared_ptr<LogicalStreamPass>; | |||||
class AssignByLabelPass : public LogicalStreamPass { | class AssignByLabelPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Engines such as hccl require independent Stream. | // Engines such as hccl require independent Stream. | ||||
class IndependentStreamPass : public LogicalStreamPass { | class IndependentStreamPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Reuse streams or assign new streams based on dependencies. | // Reuse streams or assign new streams based on dependencies. | ||||
class AssignByDependencyPass : public LogicalStreamPass { | class AssignByDependencyPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
private: | private: | ||||
void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | ||||
@@ -132,7 +132,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
std::map<std::string, int64_t> engine_stream_num_; | std::map<std::string, int64_t> engine_stream_num_; | ||||
// Subgraphs of assign stream by engine | // Subgraphs of assign stream by engine | ||||
std::vector<SubgraphPtr> assigned_subgraphs_; | |||||
std::set<SubgraphPtr> assigned_subgraphs_; | |||||
// <current subgraph, reused subgraph> | // <current subgraph, reused subgraph> | ||||
std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | ||||
@@ -142,7 +142,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
class NodeStreamUpdatePass : public LogicalStreamPass { | class NodeStreamUpdatePass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
private: | private: | ||||
/// Optimize for case like: | /// Optimize for case like: | ||||
@@ -150,18 +150,19 @@ class NodeStreamUpdatePass : public LogicalStreamPass { | |||||
/// To case: | /// To case: | ||||
/// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | ||||
/// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | ||||
Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
Status UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
int64_t GetSingleInoutStream(const NodePtr &node) const; | int64_t GetSingleInoutStream(const NodePtr &node) const; | ||||
// Judge if all predecessors' streams of node are INVALID_STREAM | // Judge if all predecessors' streams of node are INVALID_STREAM | ||||
bool AreAllPredStreamsInvalid(const NodePtr &node) const; | bool AreAllPredStreamsInvalid(const NodePtr &node) const; | ||||
void RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const; | |||||
}; | }; | ||||
// AllReduce and backward operators execute in parallel. | // AllReduce and backward operators execute in parallel. | ||||
class AllReduceParallelPass : public LogicalStreamPass { | class AllReduceParallelPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Assign logical streams which is not limited by the number of tasks. | // Assign logical streams which is not limited by the number of tasks. | ||||
@@ -177,16 +178,13 @@ class LogicalStreamAllocator { | |||||
LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | ||||
~LogicalStreamAllocator() = default; | ~LogicalStreamAllocator() = default; | ||||
Status Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num); | |||||
Status Assign(const ComputeGraphPtr &whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs, int64_t &stream_num); | |||||
private: | private: | ||||
Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
const map<string, EngineConfPtr> &engine_confs); | |||||
Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | ||||
const std::map<std::string, EngineConfPtr> &engine_confs, | const std::map<std::string, EngineConfPtr> &engine_confs, | ||||
std::vector<SubgraphPtr> &subgraphs); | std::vector<SubgraphPtr> &subgraphs); | ||||
Status RunPasses(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
void RefreshContinuousStreams(const ComputeGraphPtr &graph); | |||||
Status RunPasses(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs, int64_t &stream_num); | |||||
const std::map<std::string, SchedulerConf> &scheduler_confs_; | const std::map<std::string, SchedulerConf> &scheduler_confs_; | ||||
const std::map<std::string, int> &max_parallel_num_; | const std::map<std::string, int> &max_parallel_num_; | ||||
@@ -33,6 +33,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
include_directories(${GE_SOURCE_DIR}/inc/external/graph) | include_directories(${GE_SOURCE_DIR}/inc/external/graph) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
@@ -100,13 +100,13 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
GELOGD("Origin ranges:"); | GELOGD("Origin ranges:"); | ||||
for (auto &v : ranges) { | for (auto &v : ranges) { | ||||
GELOGD("__%s", ToString(v).c_str()); | |||||
GELOGD("__%s", domi::ToString(v).c_str()); | |||||
} | } | ||||
PlanRanges(range_number_limit, ranges); | PlanRanges(range_number_limit, ranges); | ||||
GELOGD("Origin ranges:"); | GELOGD("Origin ranges:"); | ||||
for (auto &v : ranges) { | for (auto &v : ranges) { | ||||
GELOGD("__%s", ToString(v).c_str()); | |||||
GELOGD("__%s", domi::ToString(v).c_str()); | |||||
} | } | ||||
for (auto &range : ranges) { | for (auto &range : ranges) { | ||||
@@ -115,7 +115,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector<int64_t> &range_ceils) { | |||||
range_ceils.push_back(range.back()); | range_ceils.push_back(range.back()); | ||||
} | } | ||||
} | } | ||||
GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); | |||||
GELOGI("Range ceils: %s", domi::ToString(range_ceils).c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -29,6 +29,7 @@ | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "common/op/attr_define.h" | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
@@ -46,6 +47,29 @@ const int kReuseMaxCharNum = 2000; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
using domi::AIPP_DATA_TYPE; | |||||
using domi::AIPPDATA; | |||||
using domi::ANN_DATA_TYPE; | |||||
using domi::APPLYMOMENTUM; | |||||
using domi::ASSIGN; | |||||
using domi::ASSIGNADD; | |||||
using domi::ASSIGNSUB; | |||||
using domi::CONSTANT; | |||||
using domi::CONSTANTOP; | |||||
using domi::DATA; | |||||
using domi::DATA_TYPE; | |||||
using domi::ENTER; | |||||
using domi::FASTRCNNPREDICTIONS; | |||||
using domi::HCOMALLREDUCE; | |||||
using domi::HCOMBROADCAST; | |||||
using domi::MULTISHAPE; | |||||
using domi::NETOUTPUT; | |||||
using domi::NEXTITERATION; | |||||
using domi::PROPOSAL; | |||||
using domi::REFENTER; | |||||
using domi::REFNEXTITERATION; | |||||
using domi::VARIABLE; | |||||
using domi::ZEROSLIKE; | |||||
using std::map; | using std::map; | ||||
using std::pair; | using std::pair; | ||||
using std::string; | using std::string; | ||||
@@ -134,7 +158,7 @@ string ToString(ge::NodeTypeIndex &x) { | |||||
string MemoryBlock::String() { | string MemoryBlock::String() { | ||||
stringstream ss; | stringstream ss; | ||||
ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << ""; | ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << ""; | ||||
ss << "real_size_list: " << ToString(real_size_list_) << ""; | |||||
ss << "real_size_list: " << domi::ToString(real_size_list_) << ""; | |||||
ss << "ref_count: " << ref_count_ << ""; | ss << "ref_count: " << ref_count_ << ""; | ||||
ss << "members: "; | ss << "members: "; | ||||
for (auto x : NodeTypeIndexList()) { | for (auto x : NodeTypeIndexList()) { | ||||
@@ -175,7 +199,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | |||||
all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); | all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); | ||||
} | } | ||||
sort(all_memory_size.begin(), all_memory_size.end()); | sort(all_memory_size.begin(), all_memory_size.end()); | ||||
GELOGI("All memory size: %s", ToString(all_memory_size).c_str()); | |||||
GELOGI("All memory size: %s", domi::ToString(all_memory_size).c_str()); | |||||
for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { | ||||
if (*iter == 0) { | if (*iter == 0) { | ||||
@@ -18,6 +18,7 @@ | |||||
#include <cstring> | #include <cstring> | ||||
#include <set> | #include <set> | ||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
#include "common/op/attr_define.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/build/memory/hybrid_mem_assigner.h" | #include "graph/build/memory/hybrid_mem_assigner.h" | ||||
#include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
@@ -28,6 +29,19 @@ | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
using domi::AIPP_DATA_TYPE; | |||||
using domi::ATOMICADDRCLEAN; | |||||
using domi::ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
using domi::ATTR_NAME_AUTOMIC_ADD_START; | |||||
using domi::CONCAT; | |||||
using domi::CONSTANTOP; | |||||
using domi::DATA_TYPE; | |||||
using domi::HCOMBROADCAST; | |||||
using domi::LABELSWITCHBYINDEX; | |||||
using domi::NODE_NAME_NET_OUTPUT; | |||||
using domi::STREAMMERGE; | |||||
using domi::VARIABLE; | |||||
namespace { | namespace { | ||||
const int kDataOutputIndex = 0; | const int kDataOutputIndex = 0; | ||||
const int kAllInputAddrIsAtomic = -1; | const int kAllInputAddrIsAtomic = -1; | ||||
@@ -423,10 +437,8 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { | |||||
pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | ||||
} | } | ||||
memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
AlignMemOffset(MEM_ALIGN_SIZE); | |||||
GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -511,10 +523,8 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { | |||||
} | } | ||||
op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
AlignMemOffset(MEM_ALIGN_SIZE); | |||||
GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -16,6 +16,7 @@ | |||||
#include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
#include <vector> | #include <vector> | ||||
#include "common/op/attr_define.h" | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
@@ -50,10 +51,10 @@ Status VarMemAssignUtil::AssignMemory2VariableNode(ge::ComputeGraphPtr &compute_ | |||||
Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_graph) { | Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_graph) { | ||||
GE_IF_BOOL_EXEC(compute_graph == nullptr, return FAILED); | GE_IF_BOOL_EXEC(compute_graph == nullptr, return FAILED); | ||||
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | ||||
GE_IF_BOOL_EXEC((n->GetType() != VARIABLE) && (n->GetType() != CONSTANTOP), continue); | |||||
GE_IF_BOOL_EXEC((n->GetType() != domi::VARIABLE) && (n->GetType() != domi::CONSTANTOP), continue); | |||||
string ref_var_src_var_name; | string ref_var_src_var_name; | ||||
GE_CHECK_NOTNULL(n->GetOpDesc()); | GE_CHECK_NOTNULL(n->GetOpDesc()); | ||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); | |||||
string node_name = n->GetName(); | string node_name = n->GetName(); | ||||
GE_IF_BOOL_EXEC(n->GetOpDesc()->GetAllOutputsDesc().empty(), | GE_IF_BOOL_EXEC(n->GetOpDesc()->GetAllOutputsDesc().empty(), | ||||
GELOGE(FAILED, "node:%s has no OutputDesc.", n->GetName().c_str()); | GELOGE(FAILED, "node:%s has no OutputDesc.", n->GetName().c_str()); | ||||
@@ -63,7 +64,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr | |||||
if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) { | if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) { | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM)); | VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM)); | ||||
GE_IF_BOOL_EXEC(n->GetType() == VARIABLE, | |||||
GE_IF_BOOL_EXEC(n->GetType() == domi::VARIABLE, | |||||
GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID()))); | GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID()))); | ||||
GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID()) | GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID()) | ||||
->SetAllocatedGraphId(node_name, compute_graph->GetGraphID())); | ->SetAllocatedGraphId(node_name, compute_graph->GetGraphID())); | ||||
@@ -84,7 +85,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr | |||||
Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t session_id) { | Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t session_id) { | ||||
string src_var_name; | string src_var_name; | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
if (ge::AttrUtils::GetStr(node->GetOpDesc(), VAR_ATTR_SRC_VAR_NAME, src_var_name)) { | |||||
if (ge::AttrUtils::GetStr(node->GetOpDesc(), domi::VAR_ATTR_SRC_VAR_NAME, src_var_name)) { | |||||
ge::GeTensorDesc cur_tensor_desc; | ge::GeTensorDesc cur_tensor_desc; | ||||
uint8_t *dev_ptr = nullptr; | uint8_t *dev_ptr = nullptr; | ||||
rtMemType_t memory_type = RT_MEMORY_HBM; | rtMemType_t memory_type = RT_MEMORY_HBM; | ||||
@@ -99,10 +100,11 @@ Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t se | |||||
Status VarMemAssignUtil::AssignVarAttr2Nodes(ge::ComputeGraphPtr &compute_graph) { | Status VarMemAssignUtil::AssignVarAttr2Nodes(ge::ComputeGraphPtr &compute_graph) { | ||||
for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { | for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { | ||||
GE_IF_BOOL_EXEC(node->GetType() != VARIABLE, continue); | |||||
GE_IF_BOOL_EXEC(node->GetType() != domi::VARIABLE, continue); | |||||
string ref_var_src_var_name; | string ref_var_src_var_name; | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), | |||||
continue); | |||||
GE_CHK_STATUS_RET(DealVariableNode(compute_graph->GetGraphID(), node, compute_graph->GetSessionID())); | GE_CHK_STATUS_RET(DealVariableNode(compute_graph->GetGraphID(), node, compute_graph->GetSessionID())); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -140,7 +142,8 @@ Status VarMemAssignUtil::DealExportVariableNode(const ge::NodePtr &node, const g | |||||
GE_IF_BOOL_EXEC(var_out_anchor == nullptr, return FAILED); | GE_IF_BOOL_EXEC(var_out_anchor == nullptr, return FAILED); | ||||
for (const ge::InDataAnchorPtr &dst_in_var_anchor : var_out_anchor->GetPeerInDataAnchors()) { | for (const ge::InDataAnchorPtr &dst_in_var_anchor : var_out_anchor->GetPeerInDataAnchors()) { | ||||
ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ||||
if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { | |||||
if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || | |||||
(dst_node->GetType() == domi::ASSIGNSUB)) { | |||||
if (dst_in_var_anchor == dst_node->GetInDataAnchor(0)) { | if (dst_in_var_anchor == dst_node->GetInDataAnchor(0)) { | ||||
GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, var_node, session_id)); | GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, var_node, session_id)); | ||||
} | } | ||||
@@ -208,19 +211,20 @@ Status VarMemAssignUtil::DealVariableNode(uint32_t graph_id, const ge::NodePtr & | |||||
for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { | for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { | ||||
for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { | for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { | ||||
ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); | ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); | ||||
if (dst_node->GetType() == HCOMBROADCAST) { | |||||
if (dst_node->GetType() == domi::HCOMBROADCAST) { | |||||
GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); | GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); | ||||
continue; | continue; | ||||
} | } | ||||
if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { | |||||
if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || | |||||
(dst_node->GetType() == domi::ASSIGNSUB)) { | |||||
if (dst_in_data_anchor == dst_node->GetInDataAnchor(0)) { | if (dst_in_data_anchor == dst_node->GetInDataAnchor(0)) { | ||||
GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, node, session_id)); | GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, node, session_id)); | ||||
} | } | ||||
} | } | ||||
auto dst_type = dst_node->GetType(); | auto dst_type = dst_node->GetType(); | ||||
bool is_trans_node = | |||||
(dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); | |||||
bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || | |||||
(dst_type == domi::PERMUTE); | |||||
if (is_trans_node) { | if (is_trans_node) { | ||||
NodePtr final_trans_node = GetFinalTransNode(dst_node); | NodePtr final_trans_node = GetFinalTransNode(dst_node); | ||||
GE_CHK_STATUS_RET(DealTransNode(final_trans_node)); | GE_CHK_STATUS_RET(DealTransNode(final_trans_node)); | ||||
@@ -237,8 +241,8 @@ ge::NodePtr VarMemAssignUtil::GetFinalTransNode(const ge::NodePtr &trans_node) { | |||||
for (const auto &dst_in_anchor : trans_out_data_anchor->GetPeerInDataAnchors()) { | for (const auto &dst_in_anchor : trans_out_data_anchor->GetPeerInDataAnchors()) { | ||||
NodePtr dst_node = dst_in_anchor->GetOwnerNode(); | NodePtr dst_node = dst_in_anchor->GetOwnerNode(); | ||||
auto dst_type = dst_node->GetType(); | auto dst_type = dst_node->GetType(); | ||||
bool is_trans_node = | |||||
(dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); | |||||
bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || | |||||
(dst_type == domi::PERMUTE); | |||||
if (is_trans_node && (dst_in_anchor->GetIdx() == 0)) { | if (is_trans_node && (dst_in_anchor->GetIdx() == 0)) { | ||||
final_ref_node = GetFinalTransNode(dst_node); | final_ref_node = GetFinalTransNode(dst_node); | ||||
} | } | ||||
@@ -252,7 +256,8 @@ Status VarMemAssignUtil::DealTransNode(const ge::NodePtr &final_trans_node) { | |||||
GE_IF_BOOL_EXEC(final_trans_out_anchor == nullptr, return SUCCESS); | GE_IF_BOOL_EXEC(final_trans_out_anchor == nullptr, return SUCCESS); | ||||
for (const ge::InDataAnchorPtr &dst_in_var_anchor : final_trans_out_anchor->GetPeerInDataAnchors()) { | for (const ge::InDataAnchorPtr &dst_in_var_anchor : final_trans_out_anchor->GetPeerInDataAnchors()) { | ||||
ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ||||
if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { | |||||
if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || | |||||
(dst_node->GetType() == domi::ASSIGNSUB)) { | |||||
GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); | GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); | ||||
} | } | ||||
} | } | ||||
@@ -264,7 +269,8 @@ Status VarMemAssignUtil::DealExportTransNode(const ge::NodePtr &node, const ge:: | |||||
GE_CHECK_NOTNULL(node_out_anchor); | GE_CHECK_NOTNULL(node_out_anchor); | ||||
for (const ge::InDataAnchorPtr &dst_in_var_anchor : node_out_anchor->GetPeerInDataAnchors()) { | for (const ge::InDataAnchorPtr &dst_in_var_anchor : node_out_anchor->GetPeerInDataAnchors()) { | ||||
ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); | ||||
if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { | |||||
if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || | |||||
(dst_node->GetType() == domi::ASSIGNSUB)) { | |||||
GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); | GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); | ||||
} | } | ||||
} | } | ||||
@@ -300,7 +306,7 @@ Status VarMemAssignUtil::AssignMemory2HasRefAttrNode(ge::ComputeGraphPtr &comput | |||||
for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | ||||
string ref_var_src_var_name; | string ref_var_src_var_name; | ||||
GE_CHECK_NOTNULL(n->GetOpDesc()); | GE_CHECK_NOTNULL(n->GetOpDesc()); | ||||
bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); | |||||
bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); | |||||
GE_IF_BOOL_EXEC(is_ref, | GE_IF_BOOL_EXEC(is_ref, | ||||
GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); | GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); | ||||
} | } | ||||
@@ -323,7 +329,7 @@ Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, | |||||
GE_CHECK_SIZE(ref_attr_node_output_list.size()); | GE_CHECK_SIZE(ref_attr_node_output_list.size()); | ||||
int out_index = 0; | int out_index = 0; | ||||
bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, out_index); | |||||
bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), domi::REF_VAR_PRE_PEER_OUT_INDEX, out_index); | |||||
if (!is_get) { | if (!is_get) { | ||||
GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); | GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); | ||||
} | } | ||||