@@ -50,7 +50,7 @@ CommentPragmas: '^ IWYU pragma:' | |||||
CompactNamespaces: false | CompactNamespaces: false | ||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true | ConstructorInitializerAllOnOneLineOrOnePerLine: true | ||||
ConstructorInitializerIndentWidth: 4 | ConstructorInitializerIndentWidth: 4 | ||||
ContinuationIndentWidth: 2 | |||||
ContinuationIndentWidth: 4 | |||||
Cpp11BracedListStyle: true | Cpp11BracedListStyle: true | ||||
DerivePointerAlignment: true | DerivePointerAlignment: true | ||||
DisableFormat: false | DisableFormat: false | ||||
@@ -95,6 +95,7 @@ else () | |||||
#find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | #find_module(ascendcl_static libascendcl.a ${GE_LIB_PATH}) | ||||
else() | else() | ||||
find_module(slog libalog.so ${ASCEND_ATC_DIR}) | find_module(slog libalog.so ${ASCEND_ATC_DIR}) | ||||
find_module(opt_feature libopt_feature.so ${ASCEND_ATC_DIR}) | |||||
find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | find_module(static_mmpa libmmpa.a ${ASCEND_ATC_DIR}) | ||||
if(PLATFORM STREQUAL "train") | if(PLATFORM STREQUAL "train") | ||||
find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | find_module(adump_server libadump_server.a ${ASCEND_RUNTIME_DIR}) | ||||
@@ -144,7 +144,6 @@ build_graphengine() | |||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_UT=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_UT=ON" | ||||
fi | fi | ||||
if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | ||||
CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | ||||
fi | fi | ||||
@@ -176,7 +175,7 @@ build_graphengine() | |||||
TARGET="ge_compiler atc_atc.bin ge_executor_shared ${TARGET}" | TARGET="ge_compiler atc_atc.bin ge_executor_shared ${TARGET}" | ||||
elif [ "X$ENABLE_GE_ST" = "Xon" ] | elif [ "X$ENABLE_GE_ST" = "Xon" ] | ||||
then | then | ||||
TARGET="ge_graph_dsl_test graph_engine_test" | |||||
TARGET="ge_graph_dsl_test ge_running_env_test graph_engine_test" | |||||
elif [ "X$ENABLE_GE_UT" = "Xon" ] | elif [ "X$ENABLE_GE_UT" = "Xon" ] | ||||
then | then | ||||
TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | TARGET="ut_libgraph ut_libge_multiparts_utest ut_libge_others_utest ut_libge_kernel_utest ut_libge_distinct_load_utest" | ||||
@@ -244,13 +243,13 @@ if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then | |||||
mkdir -p ${OUTPUT_PATH}/plugin/opskernel | mkdir -p ${OUTPUT_PATH}/plugin/opskernel | ||||
cp ${BUILD_PATH}/tests/framework/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine | cp ${BUILD_PATH}/tests/framework/libnnengine.so ${OUTPUT_PATH}/plugin/nnengine | ||||
cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config | cp ${BUILD_PATH}/engine_conf.json ${OUTPUT_PATH}/plugin/nnengine/ge_config | ||||
cp ${BUILD_PATH}/tests/framework/libhost_cpu_engine.so ${OUTPUT_PATH}/plugin/opskernel | |||||
cp ${BUILD_PATH}/tests/framework/libge_local_engine.so ${OUTPUT_PATH}/plugin/opskernel | cp ${BUILD_PATH}/tests/framework/libge_local_engine.so ${OUTPUT_PATH}/plugin/opskernel | ||||
cp ${BUILD_PATH}/tests/framework/stub_engine/libfe.so ${OUTPUT_PATH}/plugin/opskernel | |||||
#prepare st execution bin | #prepare st execution bin | ||||
cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH} | cp ${BUILD_PATH}/tests/st/testcase/graph_engine_test ${OUTPUT_PATH} | ||||
cp ${BUILD_PATH}/tests/framework/ge_running_env/tests/ge_running_env_test ${OUTPUT_PATH} | |||||
cp ${BUILD_PATH}/tests/framework/ge_graph_dsl/tests/ge_graph_dsl_test ${OUTPUT_PATH} | cp ${BUILD_PATH}/tests/framework/ge_graph_dsl/tests/ge_graph_dsl_test ${OUTPUT_PATH} | ||||
#execute st testcase | #execute st testcase | ||||
RUN_TEST_CASE=${OUTPUT_PATH}/ge_running_env_test && ${RUN_TEST_CASE} | |||||
RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE} | RUN_TEST_CASE=${OUTPUT_PATH}/graph_engine_test && ${RUN_TEST_CASE} | ||||
RUN_TEST_CASE=${OUTPUT_PATH}/ge_graph_dsl_test && ${RUN_TEST_CASE} | RUN_TEST_CASE=${OUTPUT_PATH}/ge_graph_dsl_test && ${RUN_TEST_CASE} | ||||
if [[ "$?" -ne 0 ]]; then | if [[ "$?" -ne 0 ]]; then | ||||
@@ -436,6 +436,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/build/memory/max_block_mem_assigner.cc" | "graph/build/memory/max_block_mem_assigner.cc" | ||||
"graph/build/memory/var_mem_assign_util.cc" | "graph/build/memory/var_mem_assign_util.cc" | ||||
"graph/build/memory/buffer_pool_mem_assigner.cc" | "graph/build/memory/buffer_pool_mem_assigner.cc" | ||||
"ge_opt_info/ge_opt_info.cc" | |||||
) | ) | ||||
set(INFER_SRC_LIST | set(INFER_SRC_LIST | ||||
@@ -715,6 +716,7 @@ set(INFER_SRC_LIST | |||||
"graph/build/memory/max_block_mem_assigner.cc" | "graph/build/memory/max_block_mem_assigner.cc" | ||||
"graph/build/memory/var_mem_assign_util.cc" | "graph/build/memory/var_mem_assign_util.cc" | ||||
"graph/build/memory/buffer_pool_mem_assigner.cc" | "graph/build/memory/buffer_pool_mem_assigner.cc" | ||||
"ge_opt_info/ge_opt_info.cc" | |||||
) | ) | ||||
if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) | ||||
@@ -769,11 +771,13 @@ target_include_directories(ge_runner SYSTEM PRIVATE | |||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ||||
${GE_CODE_DIR}/../abl/adump/external | ${GE_CODE_DIR}/../abl/adump/external | ||||
${GE_CODE_DIR}/../abl/licctrl | |||||
#### blue zone | #### blue zone | ||||
${ASCEND_DIR}/driver/include | ${ASCEND_DIR}/driver/include | ||||
${ASCEND_DIR}/fwkacllib/include | ${ASCEND_DIR}/fwkacllib/include | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | ${GE_CODE_DIR}/third_party/fwkacllib/inc | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info | |||||
) | ) | ||||
target_link_options(ge_runner PRIVATE | target_link_options(ge_runner PRIVATE | ||||
@@ -796,6 +800,7 @@ target_link_libraries(ge_runner PRIVATE | |||||
runtime | runtime | ||||
error_manager | error_manager | ||||
ascend_hal_stub | ascend_hal_stub | ||||
opt_feature | |||||
-Wl,--as-needed | -Wl,--as-needed | ||||
json | json | ||||
-lrt | -lrt | ||||
@@ -843,11 +848,13 @@ target_include_directories(ge_compiler SYSTEM PRIVATE | |||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ${GE_CODE_DIR}/../toolchain/ide/ide-daemon/external | ||||
${GE_CODE_DIR}/../abl/adump/external | ${GE_CODE_DIR}/../abl/adump/external | ||||
${GE_CODE_DIR}/../abl/licctrl | |||||
#### blue zone #### | #### blue zone #### | ||||
${ASCEND_DIR}/driver/include | ${ASCEND_DIR}/driver/include | ||||
${ASCEND_DIR}/fwkacllib/include | ${ASCEND_DIR}/fwkacllib/include | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | ${GE_CODE_DIR}/third_party/fwkacllib/inc | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info | |||||
) | ) | ||||
target_link_options(ge_compiler PRIVATE | target_link_options(ge_compiler PRIVATE | ||||
@@ -867,6 +874,7 @@ target_link_libraries(ge_compiler PRIVATE | |||||
error_manager | error_manager | ||||
slog | slog | ||||
runtime_compile | runtime_compile | ||||
opt_feature | |||||
-Wl,--as-needed | -Wl,--as-needed | ||||
json | json | ||||
-lrt | -lrt | ||||
@@ -18,6 +18,7 @@ | |||||
#include <cstdio> | #include <cstdio> | ||||
#include <string> | #include <string> | ||||
#include <regex> | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
@@ -37,6 +38,159 @@ const uint32_t kAtomicOverflow = (0x1 << 1); | |||||
const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); | const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); | ||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::Split(const std::string &s, | |||||
std::vector<std::string> &result, | |||||
const char *delchar) { | |||||
if (s.empty()) { | |||||
return; | |||||
} | |||||
result.clear(); | |||||
char *buffer = new (std::nothrow)char[s.size() + 1]; | |||||
if (buffer == nullptr) { | |||||
GELOGE(FAILED, "[Split][string] failed while malloc memory, string value is:%s", s.c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Memory malloc may fail when split string, get fatal exception, " | |||||
"string value is:%s", s.c_str()); | |||||
return; | |||||
} | |||||
buffer[s.size()] = '\0'; | |||||
errno_t e = strcpy_s(buffer, s.size() + 1, s.c_str()); | |||||
if (e != EOK) { | |||||
delete[] buffer; | |||||
return; | |||||
} | |||||
char *context = nullptr; | |||||
char *p = strtok_s(buffer, delchar, &context); | |||||
while (p != nullptr) { | |||||
result.emplace_back(p); | |||||
p = strtok_s(nullptr, delchar, &context); | |||||
} | |||||
delete[] buffer; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpStep(const std::string &dump_step) { | |||||
std::string modified_dum_step = dump_step + "|"; | |||||
std::smatch result; | |||||
std::vector<string> match_vecs; | |||||
std::regex pattern(R"((\d{1,}-\d{1,}\||\d{1,}\|)+)"); | |||||
if (regex_match(modified_dum_step, result, pattern)) { | |||||
Split(result.str(), match_vecs, "|"); | |||||
if (match_vecs.empty()) { | |||||
REPORT_CALL_ERROR("E19999", "Split may get fatal exception, dump_step:%s.", dump_step.c_str()); | |||||
GELOGE(FAILED, "[Check][Param] failed. Split may get fatal exception, ge.exec.dumpStep:%s.", dump_step.c_str()); | |||||
return FAILED; | |||||
} | |||||
// 100 is the max sets of dump steps. | |||||
if (match_vecs.size() > 100) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpStep", | |||||
dump_step.c_str(), | |||||
" is not supported, only support dump <= 100 sets of data"})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " | |||||
"dump_step only support dump <= 100 sets of data.", dump_step.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
for (const auto &match_vec : match_vecs) { | |||||
std::vector<string> vec_after_split; | |||||
Split(match_vec, vec_after_split, "-"); | |||||
if (match_vecs.empty()) { | |||||
REPORT_CALL_ERROR("E19999", "Split may get fatal exception."); | |||||
GELOGE(FAILED, "[Check][Param] failed, split may get fatal exception."); | |||||
return FAILED; | |||||
} | |||||
if (vec_after_split.size() > 1) { | |||||
if (std::atoi(vec_after_split[0].c_str()) >= std::atoi(vec_after_split[1].c_str())) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpStep", | |||||
dump_step.c_str(), | |||||
" is not supported." | |||||
"in range steps, the first step is >= second step, correct example:'0|5|10-20"})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " | |||||
"in range steps, the first step is >= second step, correct example:'0|5|10-20'", dump_step.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpStep", | |||||
dump_step.c_str(), | |||||
" is not supported, correct example:'0|5|10|50-100."})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] get dump_step value:%s, " | |||||
"dump_step string style is error, correct example:'0|5|10|50-100.'", dump_step.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpMode(const std::string &dump_mode) { | |||||
const std::set<string> dump_mode_list = {"input", "output", "all"}; | |||||
std::set<string>::iterator iter; | |||||
if ((iter = dump_mode_list.find(dump_mode)) == dump_mode_list.end()) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpMode", | |||||
dump_mode.c_str(), | |||||
" is not supported, should be one of the following:[input, output, all]"})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] the dump_debug_mode:%s, is is not supported," | |||||
"should be one of the following:[input, output, all].", dump_mode.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckDumpPath(const std::string &input) { | |||||
if (mmIsDir(input.c_str()) != EN_OK) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpPath", | |||||
input.c_str(), | |||||
" is not a directory."})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] the path:%s, is not directory.", input.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
char trusted_path[MMPA_MAX_PATH] = { "\0" }; | |||||
if (mmRealPath(input.c_str(), trusted_path, MMPA_MAX_PATH) != EN_OK) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpPath", | |||||
input.c_str(), | |||||
" dumpPath invalid."})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] the dumpPath:%s, is invalid.", input.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (mmAccess2(trusted_path, R_OK | W_OK) != EN_OK) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpPath", | |||||
input.c_str(), | |||||
" does't have read, write permissions."})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] the path:%s, does't have read, write permissions.", input.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::CheckEnableDump(const std::string &input) { | |||||
std::set<string> enable_dump_option_list = {"1", "0"}; | |||||
auto it = enable_dump_option_list.find(input); | |||||
if (it == enable_dump_option_list.end()) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.enableDump", | |||||
input.c_str(), | |||||
" only support 1 or 0."})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] Not support ge.exec.enableDump or ge.exec.enableDumpDebug format:%s, " | |||||
"only support 1 or 0.", input.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { | ||||
CopyFrom(other); | CopyFrom(other); | ||||
} | } | ||||
@@ -47,7 +201,26 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties: | |||||
return *this; | return *this; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::SetDumpOptions() { | |||||
if (enable_dump_ == kEnableFlag) { | |||||
std::string dump_step; | |||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { | |||||
GE_CHK_STATUS_RET(CheckDumpStep(dump_step), "[Check][dump_step] failed."); | |||||
GELOGI("Get dump step %s successfully", dump_step.c_str()); | |||||
SetDumpStep(dump_step); | |||||
} | |||||
string dump_mode = "output"; | |||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { | |||||
GELOGI("Get dump mode %s successfully", dump_mode.c_str()); | |||||
GE_CHK_STATUS_RET(CheckDumpMode(dump_mode), "[Check][dump_mode] failed."); | |||||
SetDumpMode(dump_mode); | |||||
} | |||||
AddPropertyValue(DUMP_ALL_MODEL, {}); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status DumpProperties::InitByOptions() { | |||||
enable_dump_.clear(); | enable_dump_.clear(); | ||||
enable_dump_debug_.clear(); | enable_dump_debug_.clear(); | ||||
dump_path_.clear(); | dump_path_.clear(); | ||||
@@ -57,17 +230,32 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOpti | |||||
is_infer_op_debug_ = false; | is_infer_op_debug_ = false; | ||||
op_debug_mode_ = 0; | op_debug_mode_ = 0; | ||||
std::string enable_dump; | |||||
std::string enable_dump = std::to_string(false); | |||||
(void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); | (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); | ||||
enable_dump_ = enable_dump; | enable_dump_ = enable_dump; | ||||
if (!enable_dump_.empty()) { | |||||
GE_CHK_STATUS_RET(CheckEnableDump(enable_dump_), "[Check][enable_dump] failed."); | |||||
} | |||||
std::string enable_dump_debug; | |||||
std::string enable_dump_debug = std::to_string(false); | |||||
(void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); | (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); | ||||
enable_dump_debug_ = enable_dump_debug; | enable_dump_debug_ = enable_dump_debug; | ||||
if (!enable_dump_debug_.empty()) { | |||||
GE_CHK_STATUS_RET(CheckEnableDump(enable_dump_debug_), "[Check][enable_dump_debug] failed."); | |||||
} | |||||
if ((enable_dump_ == kEnableFlag) && (enable_dump_debug_ == kEnableFlag)) { | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.enableDump and ge.exec.enableDumpDebug", | |||||
enable_dump_ + ", " + enable_dump_debug, | |||||
"ge.exec.enableDump and ge.exec.enableDumpDebug cannot be set to 1 at the same time."})); | |||||
GELOGE(FAILED, "ge.exec.enableDump and ge.exec.enableDumpDebug cannot be both set to 1 at the same time."); | |||||
return FAILED; | |||||
} | |||||
if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { | if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { | ||||
std::string dump_path; | std::string dump_path; | ||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { | if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { | ||||
GE_CHK_STATUS_RET(CheckDumpPath(dump_path), "Check dump path failed."); | |||||
if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { | if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { | ||||
dump_path = dump_path + "/"; | dump_path = dump_path + "/"; | ||||
} | } | ||||
@@ -75,25 +263,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOpti | |||||
GELOGI("Get dump path %s successfully", dump_path.c_str()); | GELOGI("Get dump path %s successfully", dump_path.c_str()); | ||||
SetDumpPath(dump_path); | SetDumpPath(dump_path); | ||||
} else { | } else { | ||||
GELOGW("Dump path is not set"); | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpPath", | |||||
dump_path, | |||||
"ge.exec.dumpPath is not set."})); | |||||
GELOGE(FAILED, "[Check][dump_path] failed. Dump path is not set."); | |||||
return FAILED; | |||||
} | } | ||||
} | } | ||||
if (enable_dump_ == kEnableFlag) { | |||||
std::string dump_step; | |||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { | |||||
GELOGI("Get dump step %s successfully", dump_step.c_str()); | |||||
SetDumpStep(dump_step); | |||||
} | |||||
string dump_mode; | |||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { | |||||
GELOGI("Get dump mode %s successfully", dump_mode.c_str()); | |||||
SetDumpMode(dump_mode); | |||||
} | |||||
AddPropertyValue(DUMP_ALL_MODEL, {}); | |||||
} | |||||
GE_CHK_STATUS_RET(SetDumpOptions(), "SetDumpOptions failed."); | |||||
GE_CHK_STATUS_RET(SetDumpDebugOptions(), "SetDumpDebugOptions failed."); | |||||
SetDumpDebugOptions(); | |||||
return SUCCESS; | |||||
} | } | ||||
// The following is the new dump scenario of the fusion operator | // The following is the new dump scenario of the fusion operator | ||||
@@ -253,14 +437,20 @@ void DumpProperties::CopyFrom(const DumpProperties &other) { | |||||
} | } | ||||
} | } | ||||
void DumpProperties::SetDumpDebugOptions() { | |||||
Status DumpProperties::SetDumpDebugOptions() { | |||||
if (enable_dump_debug_ == kEnableFlag) { | if (enable_dump_debug_ == kEnableFlag) { | ||||
std::string dump_debug_mode; | std::string dump_debug_mode; | ||||
if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { | if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { | ||||
GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); | GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); | ||||
} else { | } else { | ||||
GELOGW("Dump debug mode is not set."); | |||||
return; | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpDebugMode", | |||||
dump_debug_mode, | |||||
"ge.exec.dumpDebugMode is not set."})); | |||||
GELOGE(PARAM_INVALID, "[Check][dump_debug_mode] failed. Dump debug mode is not set."); | |||||
return PARAM_INVALID; | |||||
} | } | ||||
if (dump_debug_mode == OP_DEBUG_AICORE) { | if (dump_debug_mode == OP_DEBUG_AICORE) { | ||||
@@ -276,10 +466,17 @@ void DumpProperties::SetDumpDebugOptions() { | |||||
is_train_op_debug_ = true; | is_train_op_debug_ = true; | ||||
op_debug_mode_ = kAllOverflow; | op_debug_mode_ = kAllOverflow; | ||||
} else { | } else { | ||||
GELOGW("ge.exec.dumpDebugMode is invalid."); | |||||
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
std::vector<std::string>({ | |||||
"ge.exec.dumpDebugMode", | |||||
dump_debug_mode, | |||||
"ge.exec.dumpDebugMode is invalid."})); | |||||
GELOGE(PARAM_INVALID, "[Set][DumpDebugOptions] failed, ge.exec.dumpDebugMode is invalid."); | |||||
return PARAM_INVALID; | |||||
} | } | ||||
} else { | } else { | ||||
GELOGI("ge.exec.enableDumpDebug is false or is not set."); | GELOGI("ge.exec.enableDumpDebug is false or is not set."); | ||||
} | } | ||||
return SUCCESS; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -23,6 +23,7 @@ | |||||
#include <vector> | #include <vector> | ||||
namespace ge { | namespace ge { | ||||
using Status = uint32_t; | |||||
class DumpProperties { | class DumpProperties { | ||||
public: | public: | ||||
DumpProperties() = default; | DumpProperties() = default; | ||||
@@ -33,7 +34,7 @@ class DumpProperties { | |||||
DumpProperties &operator=(const DumpProperties &dump); | DumpProperties &operator=(const DumpProperties &dump); | ||||
void InitByOptions(); | |||||
Status InitByOptions(); | |||||
void AddPropertyValue(const std::string &model, const std::set<std::string> &layers); | void AddPropertyValue(const std::string &model, const std::set<std::string> &layers); | ||||
@@ -95,7 +96,20 @@ class DumpProperties { | |||||
private: | private: | ||||
void CopyFrom(const DumpProperties &other); | void CopyFrom(const DumpProperties &other); | ||||
void SetDumpDebugOptions(); | |||||
Status SetDumpDebugOptions(); | |||||
Status SetDumpOptions(); | |||||
void Split(const std::string &s, std::vector<std::string> &result, const char *delchar); | |||||
Status CheckDumpStep(const std::string &dump_step); | |||||
Status CheckDumpMode(const std::string &dump_mode); | |||||
Status CheckDumpPath(const std::string &input); | |||||
Status CheckEnableDump(const std::string &input); | |||||
std::string enable_dump_; | std::string enable_dump_; | ||||
std::string enable_dump_debug_; | std::string enable_dump_debug_; | ||||
@@ -0,0 +1,58 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_opt_info/ge_opt_info.h" | |||||
#include <string> | |||||
#include <map> | |||||
#include "graph/ge_local_context.h" | |||||
#include "ge/ge_api_types.h" | |||||
#include "common/debug/ge_log.h" | |||||
#include "opt_info.h" | |||||
namespace ge { | |||||
Status GeOptInfo::SetOptInfo() { | |||||
std::string soc_ver; | |||||
graphStatus ret = GetThreadLocalContext().GetOption(SOC_VERSION, soc_ver); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Get soc version failed."); | |||||
GELOGE(FAILED, "[Get][SocVersion]Get soc version failed."); | |||||
return FAILED; | |||||
} | |||||
GELOGD("Soc version:%s.", soc_ver.c_str()); | |||||
std::map<std::string, std::string> opt_info; | |||||
// the first arg does not work at present. | |||||
if (gelc::GetOptInfo(gelc::kOffline, soc_ver, opt_info) != gelc::SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Get optional information failed, is_offline:%d, soc version:%s", | |||||
gelc::kOffline, soc_ver.c_str()); | |||||
GELOGE(FAILED, "[Get][OptInfo]Get optional information failed, is_offline:%d, soc version:%s", | |||||
gelc::kOffline, soc_ver.c_str()); | |||||
return FAILED; | |||||
} | |||||
// do nothing if get empty information | |||||
if (opt_info.empty()) { | |||||
GELOGI("Optional information is empty."); | |||||
return SUCCESS; | |||||
} | |||||
std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions(); | |||||
for (const auto &itr : opt_info) { | |||||
graph_options.emplace(itr.first, itr.second); | |||||
GELOGI("Get optional information success, key:%s, value:%s.", itr.first.c_str(), itr.second.c_str()); | |||||
} | |||||
GetThreadLocalContext().SetGraphOption(graph_options); | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge |
@@ -14,23 +14,18 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||||
#ifndef GE_OPT_INFO_GE_OPT_INFO_H_ | |||||
#define GE_OPT_INFO_GE_OPT_INFO_H_ | |||||
#include "stub_engine/ops_kernel_store/op/op.h" | |||||
#include "ge/ge_api_error_codes.h" | |||||
#include "register/register_types.h" | |||||
namespace ge { | namespace ge { | ||||
namespace st { | |||||
class GE_FUNC_VISIBILITY HostOp : public Op { | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeOptInfo { | |||||
public: | public: | ||||
HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} | |||||
~HostOp() override = default; | |||||
HostOp &operator=(const HostOp &op) = delete; | |||||
HostOp(const HostOp &op) = delete; | |||||
Status Run() override; | |||||
GeOptInfo() = default; | |||||
static Status SetOptInfo(); | |||||
}; | }; | ||||
} // namespace st | |||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ | |||||
#endif // GE_OPT_INFO_GE_OPT_INFO_H_ |
@@ -16,6 +16,7 @@ | |||||
#include "ge_runtime/task/hccl_task.h" | #include "ge_runtime/task/hccl_task.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include "framework/common/util.h" | |||||
#include "ge_runtime/task/task_factory.h" | #include "ge_runtime/task/task_factory.h" | ||||
#include "common/opskernel/ops_kernel_info_store.h" | #include "common/opskernel/ops_kernel_info_store.h" | ||||
#include "common/opskernel/ge_task_info.h" | #include "common/opskernel/ge_task_info.h" | ||||
@@ -20,7 +20,6 @@ | |||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/manager/graph_mem_manager.h" | #include "graph/manager/graph_mem_manager.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -94,7 +93,8 @@ void IncreaseCount(std::map<size_t, size_t> &count, size_t size) { | |||||
} | } | ||||
} | } | ||||
CachingAllocator::CachingAllocator(rtMemType_t memory_type) : memory_type_(memory_type), memory_allocator_(nullptr) { | |||||
CachingAllocator::CachingAllocator(rtMemType_t memory_type) | |||||
: memory_type_(memory_type), memory_allocator_(nullptr), called_malloc_counts_(0), called_free_counts_(0) { | |||||
for (uint32_t i = 0; i < kNumBins; i++) { | for (uint32_t i = 0; i < kNumBins; i++) { | ||||
free_block_bins_[i] = nullptr; | free_block_bins_[i] = nullptr; | ||||
} | } | ||||
@@ -121,6 +121,8 @@ Status CachingAllocator::Initialize(uint32_t device_id) { | |||||
if (memory_allocator_ == nullptr) { | if (memory_allocator_ == nullptr) { | ||||
return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
} | } | ||||
called_malloc_counts_ = 0; | |||||
called_free_counts_ = 0; | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
@@ -133,6 +135,7 @@ void CachingAllocator::Finalize(uint32_t device_id) { | |||||
uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | ||||
GELOGI("Start malloc pool memory, size = %zu, device id = %u", size, device_id); | GELOGI("Start malloc pool memory, size = %zu, device id = %u", size, device_id); | ||||
called_malloc_counts_++; | |||||
size = GetBlockSize(size); | size = GetBlockSize(size); | ||||
uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
Block *block = FindFreeBlock(size, org_ptr, device_id); | Block *block = FindFreeBlock(size, org_ptr, device_id); | ||||
@@ -156,6 +159,7 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device | |||||
Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { | Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { | ||||
GELOGI("Free device id = %u", device_id); | GELOGI("Free device id = %u", device_id); | ||||
called_free_counts_++; | |||||
if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "Param ptr is nullptr, device_id:%u, check invalid", device_id); | REPORT_INNER_ERROR("E19999", "Param ptr is nullptr, device_id:%u, check invalid", device_id); | ||||
GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device_id:%u", device_id); | GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device_id:%u", device_id); | ||||
@@ -283,6 +287,7 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | |||||
if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
GELOGE(ge::FAILED, "[Malloc][Memory] failed, no enough memory for size = %zu, device_id = %u", memory_size, | GELOGE(ge::FAILED, "[Malloc][Memory] failed, no enough memory for size = %zu, device_id = %u", memory_size, | ||||
device_id); | device_id); | ||||
PrintStatics(DLOG_ERROR); | |||||
return ge::FAILED; | return ge::FAILED; | ||||
} | } | ||||
GELOGT(TRACE_RUNNING, "Try to free cached memory size:%zu and malloc memory size:%zu success.", | GELOGT(TRACE_RUNNING, "Try to free cached memory size:%zu and malloc memory size:%zu success.", | ||||
@@ -385,14 +390,14 @@ void CachingAllocator::FreeBlockBins() { | |||||
} | } | ||||
void PrintCount(std::map<size_t, size_t> &count, const std::string &name, size_t total_size, size_t total_count) { | void PrintCount(std::map<size_t, size_t> &count, const std::string &name, size_t total_size, size_t total_count) { | ||||
GELOGI("%6s total[size:%10zu count:%10zu].", name.c_str(), total_size, total_count); | |||||
GEEVENT("%6s total[size:%11zu count:%11zu].", name.c_str(), total_size, total_count); | |||||
for (auto &it : count) { | for (auto &it : count) { | ||||
GELOGI(" |- block[size:%10zu count:%10zu].", it.first, it.second); | |||||
GEEVENT(" |- block[size:%11zu count:%11zu].", it.first, it.second); | |||||
} | } | ||||
} | } | ||||
void CachingAllocator::PrintStatics() { | |||||
if (!IsLogEnable(GE_MODULE_NAME, DLOG_INFO)) { | |||||
void CachingAllocator::PrintStatics(int32_t level) { | |||||
if (!IsLogEnable(GE_MODULE_NAME, level)) { | |||||
return; | return; | ||||
} | } | ||||
size_t total_using_size = 0; | size_t total_using_size = 0; | ||||
@@ -435,6 +440,7 @@ void CachingAllocator::PrintStatics() { | |||||
} | } | ||||
} while (0); | } while (0); | ||||
GEEVENT("Called counts[malloc:%11zu free:%11zu].", called_malloc_counts_.load(), called_free_counts_.load()); | |||||
PrintCount(malloc_block_stat, "Malloc", total_malloc_size, total_malloc_count); | PrintCount(malloc_block_stat, "Malloc", total_malloc_size, total_malloc_count); | ||||
PrintCount(using_block_stat, "Using", total_using_size, total_using_count); | PrintCount(using_block_stat, "Using", total_using_size, total_using_count); | ||||
PrintCount(free_block_stat, "Free", total_free_size, total_free_count); | PrintCount(free_block_stat, "Free", total_free_size, total_free_count); | ||||
@@ -27,6 +27,7 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include <unordered_set> | #include <unordered_set> | ||||
#include "framework/common/debug/ge_log.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/manager/block_memory.h" | #include "graph/manager/block_memory.h" | ||||
@@ -192,9 +193,10 @@ class CachingAllocator { | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief print the memory info in pool | /// @brief print the memory info in pool | ||||
/// @param [in] log level | |||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void PrintStatics(); | |||||
void PrintStatics(int32_t level = DLOG_INFO); | |||||
private: | private: | ||||
rtMemType_t memory_type_; | rtMemType_t memory_type_; | ||||
@@ -213,6 +215,12 @@ class CachingAllocator { | |||||
// malloced memorys from device | // malloced memorys from device | ||||
std::map<size_t, size_t> malloced_memory_; | std::map<size_t, size_t> malloced_memory_; | ||||
//user call Malloc total counts | |||||
std::atomic<size_t> called_malloc_counts_; | |||||
//user call Free total counts | |||||
std::atomic<size_t> called_free_counts_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ | #endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ |
@@ -27,6 +27,7 @@ | |||||
#include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
#include "common/thread_pool.h" | #include "common/thread_pool.h" | ||||
#include "common/dump/dump_manager.h" | #include "common/dump/dump_manager.h" | ||||
#include "ge_opt_info/ge_opt_info.h" | |||||
#include "analyzer/analyzer.h" | #include "analyzer/analyzer.h" | ||||
#include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
#include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
@@ -1002,6 +1003,12 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
return ret; | return ret; | ||||
} | } | ||||
ret = GeOptInfo::SetOptInfo(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Set][OptInfo] Set optional information failed."); | |||||
return ret; | |||||
} | |||||
/// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph; | /// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH no need PreRunOptimizeOriginalGraph; | ||||
/// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph. | /// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE no need PreRunOptimizeOriginalGraph. | ||||
/// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph. | /// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need PreRunOptimizeOriginalGraph. | ||||
@@ -3131,10 +3138,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
} | } | ||||
// Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency | // Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency | ||||
if (count > 1 && graph_node->GetBuildFlag()) { | if (count > 1 && graph_node->GetBuildFlag()) { | ||||
graph_node->Lock(); | |||||
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id); | ||||
// In online inference concurrency senario, graph_node is allowed to be locked for 'count' times | // In online inference concurrency senario, graph_node is allowed to be locked for 'count' times | ||||
graph_node->SetSemSize(count); | graph_node->SetSemSize(count); | ||||
graph_node->Lock(); | |||||
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context, | ||||
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback })); | ||||
GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | GELOGI("[PreRunThread] Loop end. Start to run with cached build model."); | ||||
@@ -84,9 +84,8 @@ Status InferBasePass::Run(NodePtr &node) { | |||||
bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } | bool InferBasePass::NeedInfer(const NodePtr &node) const { return true; } | ||||
void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | void InferBasePass::AddChangedNodesImmediateRepass(const std::set<NodePtr> &changed_nodes) { | ||||
for (const auto &node_ele : changed_nodes) { | |||||
AddImmediateRePassNode(node_ele); | |||||
} | |||||
// need passed_nodes set to solve the problem that multi-input operators do repass in advance. | |||||
// when there is passed_nodes set, wo should call AddImmediateRePassNode for all nodes in changed_nodes. | |||||
} | } | ||||
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) { | ||||
@@ -286,6 +286,9 @@ graphStatus InferValueRangePass::GenerateWorstValueRange(NodePtr &node) { | |||||
} | } | ||||
std::vector<std::pair<int64_t, int64_t>> output_i_value_range(output_i_shape_size, {1, -1}); | std::vector<std::pair<int64_t, int64_t>> output_i_value_range(output_i_shape_size, {1, -1}); | ||||
if (output_i_shape.IsScalar()) { | |||||
output_i_value_range.emplace_back(1, -1); | |||||
} | |||||
output_desc->SetValueRange(output_i_value_range); | output_desc->SetValueRange(output_i_value_range); | ||||
GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i, | GELOGD("Node %s output %zu shape is %s, the generated worst value range is %s.", node->GetName().c_str(), i, | ||||
formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str()); | formats::ShapeToString(output_i_shape).c_str(), formats::RangeToString(output_i_value_range).c_str()); | ||||
@@ -16,8 +16,6 @@ | |||||
#include "graph/passes/mark_force_unknown_for_cond_pass.h" | #include "graph/passes/mark_force_unknown_for_cond_pass.h" | ||||
#include <queue> | |||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
@@ -26,17 +24,7 @@ namespace { | |||||
inline bool IsMergeInLoop(const NodePtr &node) { | inline bool IsMergeInLoop(const NodePtr &node) { | ||||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | ||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopMergeInputs.count(node_type) > 0; | |||||
} | |||||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopSwitchInputs.count(node_type) > 0; | |||||
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||||
} | } | ||||
} | } | ||||
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
GELOGD("MarkForceUnknownForCondPass Enter"); | GELOGD("MarkForceUnknownForCondPass Enter"); | ||||
std::map<NodePtr, std::vector<NodePtr>> switch_groups; | std::map<NodePtr, std::vector<NodePtr>> switch_groups; | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||||
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||||
if (kMergeOpTypes.count(node_type) == 0) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool MarkForceUnknownForCondPass::DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||||
std::queue<std::pair<NodePtr, uint32_t>> &search_queue) { | |||||
/// LoopCond --->\. | |||||
/// \. | |||||
/// Enter-----------+ \. | |||||
/// +--> Merge --> Switch --> Exit | |||||
/// NextIteration---+ | |||||
const auto is_loop_op = [](const NodePtr &n) { | |||||
return NodeUtils::GetNodeType(n) == LOOPCOND; | |||||
}; | |||||
const auto is_exit_op = [](const NodePtr &n) { | |||||
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||||
}; | |||||
const auto src_nodes = node->GetInAllNodes(); | |||||
const auto dst_nodes = node->GetOutAllNodes(); | |||||
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||||
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||||
return false; | |||||
} | |||||
for (const auto &m : src_nodes) { | |||||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||||
for (const auto &n : m->GetInAllNodes()) { | |||||
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||||
continue; | |||||
} | |||||
search_queue.push({n, dst_span}); | |||||
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||||
n->GetName().c_str(), dst_span); | |||||
} | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||||
/// | /// | ||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | ||||
// Switch --> {Switch --> Merge} --> Merge | // Switch --> {Switch --> Merge} --> Merge | ||||
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||||
std::unordered_set<NodePtr> nodes_seen; | std::unordered_set<NodePtr> nodes_seen; | ||||
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | ||||
while (!search_queue.empty()) { | while (!search_queue.empty()) { | ||||
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
const auto dst_span = search_queue.front().second; | const auto dst_span = search_queue.front().second; | ||||
search_queue.pop(); | search_queue.pop(); | ||||
// Switch --> Identity --> Constant | |||||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | |||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
nodes_seen.insert(in_node); | |||||
if (in_node->GetType() == IDENTITY) { | |||||
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||||
in_node->GetName().c_str(), dst_span); | |||||
search_queue.push({in_node, dst_span}); | |||||
} | |||||
} | |||||
for (const auto &in_node : dst_node->GetInDataNodes()) { | |||||
for (const auto &in_node : dst_node->GetInAllNodes()) { | |||||
if (nodes_seen.count(in_node) > 0) { | if (nodes_seen.count(in_node) > 0) { | ||||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
nodes_seen.insert(in_node); | nodes_seen.insert(in_node); | ||||
std::string node_type; | |||||
(void)GetOriginalType(in_node, node_type); | |||||
const std::string node_type = NodeUtils::GetNodeType(in_node); | |||||
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | ||||
in_node->GetName().c_str(), dst_span); | in_node->GetName().c_str(), dst_span); | ||||
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | ||||
if (DealAsLoopSwitch(in_node, dst_span, search_queue)) { | |||||
continue; | |||||
} | |||||
if (dst_span > 0) { | if (dst_span > 0) { | ||||
search_queue.push({in_node, dst_span - 1}); | search_queue.push({in_node, dst_span - 1}); | ||||
} else { | } else { | ||||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||||
in_node->GetName().c_str()); | |||||
} else { | |||||
switch_group.emplace_back(in_node); | |||||
} | |||||
switch_group.emplace_back(in_node); | |||||
} | } | ||||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | ||||
search_queue.push({in_node, dst_span + 1}); | search_queue.push({in_node, dst_span + 1}); | ||||
@@ -19,6 +19,8 @@ | |||||
#include "inc/graph_pass.h" | #include "inc/graph_pass.h" | ||||
#include <queue> | |||||
namespace ge { | namespace ge { | ||||
class MarkForceUnknownForCondPass : public GraphPass { | class MarkForceUnknownForCondPass : public GraphPass { | ||||
public: | public: | ||||
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { | |||||
private: | private: | ||||
/// | /// | ||||
/// @brief Deal with Switch node for LoopCond | |||||
/// @param [in] Switch node | |||||
/// @param [in] dest span | |||||
/// @param [out] Search queue | |||||
/// @return true: Switch In while loop / false: Not in while Loop. | |||||
/// | |||||
bool DealAsLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> &search_queue); | |||||
/// | |||||
/// @brief Mark force unknown shape for Switch node | /// @brief Mark force unknown shape for Switch node | ||||
/// @param [in] merge node | /// @param [in] merge node | ||||
/// @param [out] switch group | /// @param [out] switch group | ||||
@@ -24,7 +24,9 @@ using std::string; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const int64_t kLoopType = 1; | |||||
constexpr int64_t kLoopType = 1; | |||||
constexpr uint8_t kMaxTransOp = 3; | |||||
constexpr uint8_t kTransOpIoSize = 1; | |||||
} | } | ||||
Status NextIterationPass::Run(ComputeGraphPtr graph) { | Status NextIterationPass::Run(ComputeGraphPtr graph) { | ||||
@@ -287,18 +289,25 @@ void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, i | |||||
std::string node_type; | std::string node_type; | ||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
SetControlFlowGroup(switch_node, group_index); | SetControlFlowGroup(switch_node, group_index); | ||||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(node, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
} else { | |||||
// For: Switch -> Cast -> Exit | |||||
for (const auto &n : node->GetOutDataNodes()) { | |||||
(void)GetOriginalType(n, node_type); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
for (auto node : switch_node->GetOutDataNodes()) { | |||||
// Switch --> Exit | |||||
// Switch --> Cast --> Exit | |||||
// Switch --> TransData --> Cast --> Exit | |||||
for (uint8_t i = 0; i < kMaxTransOp; ++i) { | |||||
if (node->GetInDataNodes().size() != kTransOpIoSize || node->GetAllOutDataAnchorsSize() != kTransOpIoSize) { | |||||
break; | |||||
} | } | ||||
if (kExitOpTypes.count(NodeUtils::GetNodeType(node)) > 0) { | |||||
SetControlFlowGroup(node, group_index); | |||||
break; | |||||
} | |||||
const auto &all_nodes = node->GetOutAllNodes(); | |||||
if (all_nodes.size() != kTransOpIoSize) { | |||||
break; | |||||
} | |||||
node = all_nodes.at(0); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -71,7 +71,7 @@ Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { | |||||
GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); | GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
if (!IsEmptyTenor(output_desc_ptr->GetShape())) { | |||||
if (!IsKnownEmptyTenor(output_desc_ptr->GetShape())) { | |||||
is_all_output_empty = false; | is_all_output_empty = false; | ||||
break; | break; | ||||
} | } | ||||
@@ -107,12 +107,16 @@ Status ReplaceWithEmptyConstPass::GetOutputsOfCurrNode(const NodePtr &node_to_re | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { | |||||
bool ReplaceWithEmptyConstPass::IsKnownEmptyTenor(const GeShape &shape) const { | |||||
bool is_known_empty_tensor = false; | |||||
for (auto dim : shape.GetDims()) { | for (auto dim : shape.GetDims()) { | ||||
if (dim == 0) { | |||||
return true; | |||||
if (dim < 0) { | |||||
// current dim is unknown dim, skip replace | |||||
return false; | |||||
} else if (dim == 0) { | |||||
is_known_empty_tensor = true; | |||||
} | } | ||||
} | } | ||||
return false; | |||||
return is_known_empty_tensor; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -26,7 +26,7 @@ class ReplaceWithEmptyConstPass : public FoldingPass { | |||||
private: | private: | ||||
Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector<GeTensorPtr> &outputs); | Status GetOutputsOfCurrNode(const NodePtr &node_to_replace, vector<GeTensorPtr> &outputs); | ||||
bool IsEmptyTenor(const GeShape &shape) const; | |||||
bool IsKnownEmptyTenor(const GeShape &shape) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ | #endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ |
@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
} | |||||
return stream_switch; | return stream_switch; | ||||
} | } | ||||
@@ -326,17 +326,45 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
} | } | ||||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | ||||
if (node_item_->root_data_.count(input_idx) > 0) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||||
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||||
return items.second.count(idx) > 0; | |||||
}; | |||||
return std::any_of(items.begin(), items.end(), is_exist); | |||||
}; | |||||
if (root_tensor_values_.count(input_idx) > 0) { | |||||
return; | |||||
} | } | ||||
if (node_item_->enter_data_.count(input_idx) > 0) { | |||||
if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | |||||
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | ||||
root_tensor_values_[input_idx] = tensor; | root_tensor_values_[input_idx] = tensor; | ||||
} | } | ||||
} | } | ||||
void NodeState::UpdatePersistTensor() { | |||||
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||||
for (const auto &item : items) { | |||||
for (const auto idx : item.second) { | |||||
UpdatePersistTensor(idx); | |||||
} | |||||
} | |||||
}; | |||||
if (root_tensor_values_.empty()) { | |||||
return; | |||||
} | |||||
update_tensor(node_item_->root_data_); | |||||
if (iteration_count_ > 0) { | |||||
update_tensor(node_item_->enter_data_); | |||||
} | |||||
} | |||||
void NodeState::UpdatePersistTensor(int input_idx) { | void NodeState::UpdatePersistTensor(int input_idx) { | ||||
const auto it = root_tensor_values_.find(input_idx); | const auto it = root_tensor_values_.find(input_idx); | ||||
if (it == root_tensor_values_.end()) { | if (it == root_tensor_values_.end()) { | ||||
@@ -363,16 +391,9 @@ void NodeState::ResetContext(uint64_t iteration) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
for (auto item : node_item_->root_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
if (iteration > 0) { | if (iteration > 0) { | ||||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | ||||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ||||
for (auto item : node_item_->enter_data_) { | |||||
UpdatePersistTensor(item.first); | |||||
} | |||||
} | } | ||||
iteration_count_ = iteration; | iteration_count_ = iteration; | ||||
@@ -132,6 +132,7 @@ struct NodeState { | |||||
void RunNextIteration(); | void RunNextIteration(); | ||||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | void SavePersistTensor(int input_idx, const TensorValue &tensor); | ||||
void UpdatePersistTensor(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -109,7 +109,6 @@ Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector<TensorValue | |||||
GE_CHECK_NOTNULL(output_desc); | GE_CHECK_NOTNULL(output_desc); | ||||
output_desc->SetShape(tensor_desc->GetShape()); | output_desc->SetShape(tensor_desc->GetShape()); | ||||
output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | output_desc->SetOriginShape(tensor_desc->GetOriginShape()); | ||||
output_desc->SetDataType(tensor_desc->GetDataType()); | |||||
node_state->SetSkipInferShape(true); | node_state->SetSkipInferShape(true); | ||||
} | } | ||||
} | } | ||||
@@ -373,6 +373,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, | |||||
auto executor = node_item.node_executor; | auto executor = node_item.node_executor; | ||||
GE_CHECK_NOTNULL(executor); | GE_CHECK_NOTNULL(executor); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); | ||||
node_state.UpdatePersistTensor(); | |||||
GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.", | ||||
node_state.GetName().c_str()); | node_state.GetName().c_str()); | ||||
RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); | ||||
@@ -147,6 +147,7 @@ class HybridModel { | |||||
GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
std::map<uint32_t, NodeItem *> input_nodes_; | std::map<uint32_t, NodeItem *> input_nodes_; | ||||
ComputeGraphPtr root_graph_; | ComputeGraphPtr root_graph_; | ||||
ComputeGraphPtr orig_root_graph_; | |||||
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | ||||
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | ||||
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | ||||
@@ -147,6 +147,7 @@ Status HybridModelBuilder::Build() { | |||||
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | hybrid_model_.model_name_ = ge_root_model_->GetModelName(); | ||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName()); | |||||
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), | ||||
"[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); | "[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName()); | ||||
@@ -171,11 +172,12 @@ Status HybridModelBuilder::Build() { | |||||
Status HybridModelBuilder::BuildForSingleOp() { | Status HybridModelBuilder::BuildForSingleOp() { | ||||
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName()); | ||||
hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph(); | |||||
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); | ||||
GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | GELOGI("[%s] Start to build hybrid model.", GetGraphName()); | ||||
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | auto ret = ge_root_model_->GetSubgraphInstanceNameToModel(); | ||||
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()]; | |||||
GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model), | |||||
const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()]; | |||||
GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model), | |||||
"[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); | "[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName()); | ||||
GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName()); | ||||
@@ -190,6 +192,27 @@ Status HybridModelBuilder::ValidateParams() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::CopyGraph() { | |||||
GELOGD("Copy compute graph begin."); | |||||
auto root_graph = ge_root_model_->GetRootGraph(); | |||||
std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName(); | |||||
ComputeGraphPtr new_root_graph = MakeShared<ComputeGraph>(new_graph_name); | |||||
GE_CHECK_NOTNULL(new_root_graph); | |||||
int32_t depth = 0; | |||||
std::map<ConstNodePtr, NodePtr> node_old_2_new; | |||||
std::map<ConstOpDescPtr, OpDescPtr> op_desc_old_2_new; | |||||
graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Copy compute graph failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
hybrid_model_.root_graph_ = new_root_graph; | |||||
GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | ||||
@@ -810,12 +833,13 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, | |||||
} | } | ||||
Status HybridModelBuilder::LoadGraph() { | Status HybridModelBuilder::LoadGraph() { | ||||
auto root_graph = ge_root_model_->GetRootGraph(); | |||||
auto root_graph = hybrid_model_.root_graph_; | |||||
if (!GetContext().GetHostExecFlag()) { | if (!GetContext().GetHostExecFlag()) { | ||||
std::shared_ptr<ComputeGraph> merged_graph; | std::shared_ptr<ComputeGraph> merged_graph; | ||||
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", | ||||
root_graph->GetDirectNodesSize(), | root_graph->GetDirectNodesSize(), | ||||
root_graph->GetAllNodesSize()); | root_graph->GetAllNodesSize()); | ||||
hybrid_model_.orig_root_graph_ = root_graph; | |||||
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), | GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), | ||||
"[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); | "[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName()); | ||||
root_graph = std::move(merged_graph); | root_graph = std::move(merged_graph); | ||||
@@ -873,6 +897,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
} | } | ||||
for (auto &it : hybrid_model_.known_shape_sub_models_) { | for (auto &it : hybrid_model_.known_shape_sub_models_) { | ||||
auto node_item = MutableNodeItem(it.first); | auto node_item = MutableNodeItem(it.first); | ||||
GE_CHECK_NOTNULL(node_item); | |||||
AscendString graph_name; | AscendString graph_name; | ||||
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | ||||
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | ||||
@@ -1121,7 +1146,9 @@ Status HybridModelBuilder::InitWeights() { | |||||
sub_weight_buffer->GetSize()); | sub_weight_buffer->GetSize()); | ||||
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); | ||||
if (subgraph != ge_root_model_->GetRootGraph()) { | if (subgraph != ge_root_model_->GetRootGraph()) { | ||||
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first); | |||||
subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first); | |||||
} else { | |||||
subgraph = hybrid_model_.root_graph_; | |||||
} | } | ||||
GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); | hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer)); | ||||
@@ -1300,7 +1327,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const | |||||
} | } | ||||
Status HybridModelBuilder::IndexTaskDefs() { | Status HybridModelBuilder::IndexTaskDefs() { | ||||
const auto root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
const auto &root_graph_name = root_graph->GetName(); | const auto &root_graph_name = root_graph->GetName(); | ||||
if (SetOutputNameAttr(*root_graph) != SUCCESS) { | if (SetOutputNameAttr(*root_graph) != SUCCESS) { | ||||
GELOGW("Set output name attr failed."); | GELOGW("Set output name attr failed."); | ||||
@@ -1334,7 +1361,7 @@ Status HybridModelBuilder::IndexTaskDefs() { | |||||
Status HybridModelBuilder::IndexSpecialNodes() { | Status HybridModelBuilder::IndexSpecialNodes() { | ||||
GELOGD("Start to index special nodes"); | GELOGD("Start to index special nodes"); | ||||
const auto &root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
for (auto &node : root_graph->GetAllNodes()) { | for (auto &node : root_graph->GetAllNodes()) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
@@ -1489,7 +1516,7 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0; | runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0; | ||||
ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); | ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value); | ||||
runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0; | runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0; | ||||
runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); | |||||
runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID(); | |||||
value = 0; | value = 0; | ||||
for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { | for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { | ||||
(void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); | (void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value); | ||||
@@ -1626,7 +1653,7 @@ Status HybridModelBuilder::TransAllVarData() { | |||||
} | } | ||||
Status HybridModelBuilder::CopyVarData() { | Status HybridModelBuilder::CopyVarData() { | ||||
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(), | |||||
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_, | |||||
runtime_param_.session_id, | runtime_param_.session_id, | ||||
hybrid_model_.device_id_), | hybrid_model_.device_id_), | ||||
"[Invoke][CopyVarData] failed."); | "[Invoke][CopyVarData] failed."); | ||||
@@ -1709,7 +1736,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem | |||||
} | } | ||||
Status HybridModelBuilder::RecoverGraphUnknownFlag() { | Status HybridModelBuilder::RecoverGraphUnknownFlag() { | ||||
const auto &root_graph = ge_root_model_->GetRootGraph(); | |||||
const auto &root_graph = hybrid_model_.root_graph_; | |||||
for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | ||||
GE_CHECK_NOTNULL(sub_graph); | GE_CHECK_NOTNULL(sub_graph); | ||||
for (const auto &node : sub_graph->GetDirectNode()) { | for (const auto &node : sub_graph->GetDirectNode()) { | ||||
@@ -56,6 +56,7 @@ class HybridModelBuilder { | |||||
Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); | Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); | ||||
Status ValidateParams(); | Status ValidateParams(); | ||||
Status LoadGraph(); | Status LoadGraph(); | ||||
Status CopyGraph(); | |||||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | ||||
Status LoadTask(NodeItem &node_item); | Status LoadTask(NodeItem &node_item); | ||||
@@ -395,11 +395,13 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
data_send_.emplace(node_item); | data_send_.emplace(node_item); | ||||
node_item->data_recv_[this] = anchor_index; | node_item->data_recv_[this] = anchor_index; | ||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_[anchor_index] = this; | |||||
auto &data_anchors = node_item->root_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
// If Enter feed Not Merge, take as root Node. | // If Enter feed Not Merge, take as root Node. | ||||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | ||||
node_item->enter_data_[anchor_index] = this; | |||||
auto &data_anchors = node_item->enter_data_[this]; | |||||
data_anchors.emplace(anchor_index); | |||||
} | } | ||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
@@ -148,9 +148,9 @@ struct NodeItem { | |||||
int64_t frame_index_ = -1; | int64_t frame_index_ = -1; | ||||
int64_t parent_frame_ = -1; | int64_t parent_frame_ = -1; | ||||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | ||||
std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||||
std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node | |||||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | ||||
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||||
std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node | |||||
std::set<const NodeItem *> data_send_; // Send data notify to | std::set<const NodeItem *> data_send_; // Send data notify to | ||||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | std::map<const NodeItem *, int> data_recv_; // Recv data notify from | ||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
@@ -18,6 +18,7 @@ | |||||
#include "framework/common/taskdown_common.h" | #include "framework/common/taskdown_common.h" | ||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
#include "external/runtime/rt_error_codes.h" | #include "external/runtime/rt_error_codes.h" | ||||
#include "single_op/task/build_task_utils.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -196,6 +197,11 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> | |||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); | ||||
GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); | GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); | ||||
GE_CHK_STATUS_RET_NOLOG(CheckOverflow(context)); | GE_CHK_STATUS_RET_NOLOG(CheckOverflow(context)); | ||||
GE_CHECK_NOTNULL(context.GetExecutionContext()->model); | |||||
GELOGD("[DEBUG_TASK_INFO : Executor Task] %s/%s %s", | |||||
context.GetExecutionContext()->model->GetModelName().c_str(), | |||||
(*it)->GetName().empty() ? (*it)->GetLogName().c_str() : (*it)->GetName().c_str(), | |||||
BuildTaskUtils::GetTaskInfo(context).c_str()); | |||||
// save profiling data | // save profiling data | ||||
uint32_t task_id = 0; | uint32_t task_id = 0; | ||||
uint32_t stream_id = 0; | uint32_t stream_id = 0; | ||||
@@ -208,7 +214,7 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> | |||||
context.SetTaskId(task_id); | context.SetTaskId(task_id); | ||||
context.SetStreamId(stream_id); | context.SetStreamId(stream_id); | ||||
GELOGD("Aicore node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | GELOGD("Aicore node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | ||||
(void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicore, (*it)->GetBlockDim()); | |||||
(void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicore, (*it)->GetBlockDim(), (*it)->GetOpType()); | |||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | ||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); | ||||
} | } | ||||
@@ -33,6 +33,7 @@ namespace { | |||||
constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | ||||
constexpr char const *kAttrOpParamSize = "op_para_size"; | constexpr char const *kAttrOpParamSize = "op_para_size"; | ||||
constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; | constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; | ||||
const string kAtomicOpType = "DynamicAtomicAddrClean"; | |||||
std::atomic<std::uint64_t> log_id(0); | std::atomic<std::uint64_t> log_id(0); | ||||
} // namespace | } // namespace | ||||
@@ -51,6 +52,7 @@ bool TbeHandleRegistry::AddHandle(std::unique_ptr<TbeHandleHolder> &&holder) { | |||||
} | } | ||||
Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { | Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { | ||||
op_type_ = op_desc.GetType(); | |||||
log_name_ = op_desc.GetName() + "_tvmbin"; | log_name_ = op_desc.GetName() + "_tvmbin"; | ||||
log_id_ = log_id++; | log_id_ = log_id++; | ||||
auto op_desc_ptr = MakeShared<OpDesc>(op_desc); | auto op_desc_ptr = MakeShared<OpDesc>(op_desc); | ||||
@@ -538,6 +540,10 @@ const std::string &AiCoreOpTask::GetName() const { | |||||
return stub_name_; | return stub_name_; | ||||
} | } | ||||
const std::string &AiCoreOpTask::GetOpType() const { | |||||
return op_type_; | |||||
} | |||||
std::string AiCoreOpTask::GetKeyForOpParamSize() const { | std::string AiCoreOpTask::GetKeyForOpParamSize() const { | ||||
return kAttrOpParamSize; | return kAttrOpParamSize; | ||||
} | } | ||||
@@ -631,6 +637,10 @@ std::string AtomicAddrCleanOpTask::GetKeyForKernelName(const OpDesc &op_desc) co | |||||
return op_desc.GetName() + "_atomic_kernelname"; | return op_desc.GetName() + "_atomic_kernelname"; | ||||
} | } | ||||
const std::string &AtomicAddrCleanOpTask::GetOpType() const { | |||||
return kAtomicOpType; | |||||
} | |||||
Status AtomicAddrCleanOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { | Status AtomicAddrCleanOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { | ||||
GELOGD("[%s] Start to invoke OpAtomicCalculate.", node->GetName().c_str()); | GELOGD("[%s] Start to invoke OpAtomicCalculate.", node->GetName().c_str()); | ||||
GE_CHK_STATUS_RET(optiling::OpAtomicCalculateV2(*node, tiling_info), | GE_CHK_STATUS_RET(optiling::OpAtomicCalculateV2(*node, tiling_info), | ||||
@@ -72,12 +72,16 @@ class AiCoreOpTask { | |||||
const std::string& GetName() const; | const std::string& GetName() const; | ||||
const std::string& GetLogName() const {return log_name_;} | |||||
bool GetClearAtomic() const {return clear_atomic_;} | bool GetClearAtomic() const {return clear_atomic_;} | ||||
uint32_t GetBlockDim() const {return block_dim_;} | uint32_t GetBlockDim() const {return block_dim_;} | ||||
void SetSingleOp(bool is_single_op) {is_single_op_ = is_single_op;}; | void SetSingleOp(bool is_single_op) {is_single_op_ = is_single_op;}; | ||||
virtual const std::string& GetOpType() const; | |||||
protected: | protected: | ||||
Status UpdateTilingInfo(TaskContext &context); | Status UpdateTilingInfo(TaskContext &context); | ||||
virtual std::string GetKeyForOpParamSize() const; | virtual std::string GetKeyForOpParamSize() const; | ||||
@@ -117,12 +121,14 @@ class AiCoreOpTask { | |||||
uint64_t log_id_ = 0; | uint64_t log_id_ = 0; | ||||
std::string log_name_; | std::string log_name_; | ||||
uint32_t offset_ = 0; | uint32_t offset_ = 0; | ||||
std::string op_type_; | |||||
}; | }; | ||||
class AtomicAddrCleanOpTask : public AiCoreOpTask { | class AtomicAddrCleanOpTask : public AiCoreOpTask { | ||||
public: | public: | ||||
Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; | Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; | ||||
Status UpdateArgs(TaskContext &task_context) override; | Status UpdateArgs(TaskContext &task_context) override; | ||||
const std::string& GetOpType() const override; | |||||
protected: | protected: | ||||
std::string GetKeyForOpParamSize() const override; | std::string GetKeyForOpParamSize() const override; | ||||
@@ -207,7 +207,7 @@ Status AicpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function<void( | |||||
context.SetTaskId(task_id); | context.SetTaskId(task_id); | ||||
context.SetStreamId(stream_id); | context.SetStreamId(stream_id); | ||||
GELOGD("Aicpu node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | GELOGD("Aicpu node[%s] task_id: %u, stream_id: %u.", context.GetNodeName(), task_id, stream_id); | ||||
(void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicpu, 0); | |||||
(void)context.SaveProfilingTaskDescInfo(task_id, stream_id, kTaskTypeAicpu, 0, node_type_); | |||||
auto callback = [=, &context]() { | auto callback = [=, &context]() { | ||||
GELOGD("Node[%s] callback start.", node_name_.c_str()); | GELOGD("Node[%s] callback start.", node_name_.c_str()); | ||||
RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start"); | RECORD_CALLBACK_EVENT(context.GetExecutionContext(), node_name_.c_str(), "[TaskCallback] Start"); | ||||
@@ -460,10 +460,6 @@ Status TaskContext::PropagateOutputs() { | |||||
subgraph_context_->all_inputs_[input_offset].SetName( | subgraph_context_->all_inputs_[input_offset].SetName( | ||||
node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx)); | ||||
} | } | ||||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||||
GE_CHECK_NOTNULL(dst_node_state); | |||||
dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||||
} | } | ||||
} | } | ||||
(void)guard; | (void)guard; | ||||
@@ -495,6 +491,7 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
node_state_->SavePersistTensor(index, *input_tensor); | |||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); | ||||
} | } | ||||
@@ -574,8 +571,8 @@ Status TaskContext::Synchronize() { | |||||
return execution_context_->Synchronize(GetStream()); | return execution_context_->Synchronize(GetStream()); | ||||
} | } | ||||
Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, | |||||
const std::string &task_type, uint32_t block_dim) { | |||||
Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, const std::string &task_type, | |||||
uint32_t block_dim, const std::string &op_type) { | |||||
if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | ||||
const NodeItem &node_item = GetNodeItem(); | const NodeItem &node_item = GetNodeItem(); | ||||
auto op_desc = node_item.GetOpDesc(); | auto op_desc = node_item.GetOpDesc(); | ||||
@@ -589,7 +586,7 @@ Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream | |||||
TaskDescInfo tmp_task_desc_info; | TaskDescInfo tmp_task_desc_info; | ||||
tmp_task_desc_info.model_name = dynamic_model_name; | tmp_task_desc_info.model_name = dynamic_model_name; | ||||
tmp_task_desc_info.op_name = op_desc->GetName(); | tmp_task_desc_info.op_name = op_desc->GetName(); | ||||
tmp_task_desc_info.op_type = op_desc->GetType(); | |||||
tmp_task_desc_info.op_type = op_type; | |||||
tmp_task_desc_info.block_dim = block_dim; | tmp_task_desc_info.block_dim = block_dim; | ||||
tmp_task_desc_info.task_type = task_type; | tmp_task_desc_info.task_type = task_type; | ||||
tmp_task_desc_info.task_id = task_id; | tmp_task_desc_info.task_id = task_id; | ||||
@@ -118,8 +118,8 @@ class TaskContext { | |||||
void *handle_ = nullptr; | void *handle_ = nullptr; | ||||
const std::vector<TaskDescInfo>& GetProfilingTaskDescInfo() const { return task_desc_info; } | const std::vector<TaskDescInfo>& GetProfilingTaskDescInfo() const { return task_desc_info; } | ||||
Status SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, | |||||
const std::string &task_type, uint32_t block_dim); | |||||
Status SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream_id, const std::string &task_type, | |||||
uint32_t block_dim, const std::string &op_type); | |||||
void ClearProfilingTaskDescInfo() { task_desc_info.clear(); } | void ClearProfilingTaskDescInfo() { task_desc_info.clear(); } | ||||
private: | private: | ||||
@@ -121,7 +121,7 @@ Status InnerSession::Initialize() { | |||||
GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); | GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); | ||||
DumpProperties dump_properties; | DumpProperties dump_properties; | ||||
dump_properties.InitByOptions(); | |||||
GE_CHK_STATUS_RET(dump_properties.InitByOptions(), "Init dump properties failed."); | |||||
GE_CHK_STATUS_RET(AddDumpProperties(dump_properties), "[Add][DumpProperties] failed."); | GE_CHK_STATUS_RET(AddDumpProperties(dump_properties), "[Add][DumpProperties] failed."); | ||||
ret = graph_manager_.Initialize(options_); | ret = graph_manager_.Initialize(options_); | ||||
@@ -297,6 +297,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c | |||||
for (auto &task : tasks_) { | for (auto &task : tasks_) { | ||||
ret = task->LaunchKernel(stream_); | ret = task->LaunchKernel(stream_); | ||||
GELOGD("[DEBUG_TASK_INFO : Static Task] %s %s", | |||||
task->GetTaskName().c_str(), | |||||
BuildTaskUtils::GetTaskInfo(task->GetOpdesc(), inputs, outputs).c_str()); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -447,6 +450,8 @@ Status DynamicSingleOp::ExecuteAsync(const vector<GeTensorDesc> &input_desc, | |||||
} else { | } else { | ||||
GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); | GE_CHK_STATUS_RET_NOLOG(op_task_->LaunchKernel(input_desc, input_buffers, output_desc, output_buffers, stream_)); | ||||
} | } | ||||
GELOGD("[DEBUG_TASK_INFO : Dynamic Task] %s", | |||||
BuildTaskUtils::GetTaskInfo(op_task_->GetOpdesc(), input_buffers, output_buffers).c_str()); | |||||
GE_CHK_STATUS_RET_NOLOG(op_task_->OpenDump(stream_)); | GE_CHK_STATUS_RET_NOLOG(op_task_->OpenDump(stream_)); | ||||
GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get(), kShapeTypeDynamic)); | GE_CHK_STATUS_RET_NOLOG(ProfilingTaskInfo(op_task_.get(), kShapeTypeDynamic)); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -70,7 +70,9 @@ std::vector<void *> BuildTaskUtils::GetKernelArgs(const OpDescPtr &op_desc, | |||||
return JoinAddresses(addresses); | return JoinAddresses(addresses); | ||||
} | } | ||||
std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { | |||||
std::string BuildTaskUtils::InnerGetTaskInfo(const OpDescPtr &op_desc, | |||||
const std::vector<const void *> &input_addrs, | |||||
const std::vector<const void *> &output_addrs) { | |||||
std::stringstream ss; | std::stringstream ss; | ||||
if (op_desc != nullptr) { | if (op_desc != nullptr) { | ||||
auto op_type = op_desc->GetType(); | auto op_type = op_desc->GetType(); | ||||
@@ -87,7 +89,10 @@ std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { | |||||
} | } | ||||
ss << TypeUtils::DataTypeToSerialString(input->GetDataType()) << " "; | ss << TypeUtils::DataTypeToSerialString(input->GetDataType()) << " "; | ||||
ss << TypeUtils::FormatToSerialString(input->GetFormat()); | ss << TypeUtils::FormatToSerialString(input->GetFormat()); | ||||
ss << VectorToString(input->GetShape().GetDims()); | |||||
ss << VectorToString(input->GetShape().GetDims()) << " "; | |||||
if (idx < input_addrs.size()) { | |||||
ss << input_addrs[idx]; | |||||
} | |||||
if (idx < op_desc->GetInputsSize() - 1) { | if (idx < op_desc->GetInputsSize() - 1) { | ||||
ss << ","; | ss << ","; | ||||
} | } | ||||
@@ -101,7 +106,10 @@ std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { | |||||
const GeShape &out_shape = output->GetShape(); | const GeShape &out_shape = output->GetShape(); | ||||
const auto &dims = out_shape.GetDims(); | const auto &dims = out_shape.GetDims(); | ||||
ss << TypeUtils::FormatToSerialString(out_format); | ss << TypeUtils::FormatToSerialString(out_format); | ||||
ss << VectorToString(dims); | |||||
ss << VectorToString(dims) << " "; | |||||
if (idx < output_addrs.size()) { | |||||
ss << output_addrs[idx]; | |||||
} | |||||
if (idx < op_desc->GetOutputsSize() - 1) { | if (idx < op_desc->GetOutputsSize() - 1) { | ||||
ss << ","; | ss << ","; | ||||
} | } | ||||
@@ -110,4 +118,44 @@ std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { | |||||
} | } | ||||
return ss.str(); | return ss.str(); | ||||
} | } | ||||
std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { | |||||
vector<const void *> input_addrs; | |||||
vector<const void *> output_addrs; | |||||
return InnerGetTaskInfo(op_desc, input_addrs, output_addrs); | |||||
} | |||||
std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc, | |||||
const std::vector<DataBuffer> &inputs, | |||||
const std::vector<DataBuffer> &outputs) { | |||||
vector<const void *> input_addrs; | |||||
vector<const void *> output_addrs; | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return ""); | |||||
if (op_desc->GetAllInputsSize() == inputs.size()) { | |||||
std::for_each(inputs.begin(), inputs.end(), [&](const DataBuffer &db) { input_addrs.push_back(db.data); }); | |||||
} | |||||
if (op_desc->GetOutputsSize() == outputs.size()) { | |||||
std::for_each(outputs.begin(), outputs.end(), [&](const DataBuffer &db) { output_addrs.push_back(db.data); }); | |||||
} | |||||
return InnerGetTaskInfo(op_desc, input_addrs, output_addrs); | |||||
} | |||||
std::string BuildTaskUtils::GetTaskInfo(const hybrid::TaskContext &task_context) { | |||||
auto &node_item = task_context.GetNodeItem(); | |||||
auto op_desc = node_item.GetOpDesc(); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return ""); | |||||
vector<const void *> input_addrs; | |||||
vector<const void *> output_addrs; | |||||
if (op_desc->GetAllInputsSize() == static_cast<uint32_t>(task_context.NumInputs())) { | |||||
for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | |||||
input_addrs.push_back(task_context.GetInput(i)->GetData()); | |||||
} | |||||
} | |||||
if (op_desc->GetOutputsSize() == static_cast<uint32_t>(task_context.NumOutputs())) { | |||||
for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | |||||
output_addrs.push_back(task_context.GetOutput(i)->GetData()); | |||||
} | |||||
} | |||||
return InnerGetTaskInfo(op_desc, input_addrs, output_addrs); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -23,6 +23,7 @@ | |||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "single_op/single_op.h" | #include "single_op/single_op.h" | ||||
#include "single_op/single_op_model.h" | #include "single_op/single_op_model.h" | ||||
#include "hybrid/node_executor/task_context.h" | |||||
namespace ge { | namespace ge { | ||||
class BuildTaskUtils { | class BuildTaskUtils { | ||||
@@ -35,7 +36,14 @@ class BuildTaskUtils { | |||||
bool keep_workspace = true); | bool keep_workspace = true); | ||||
static std::vector<void *> JoinAddresses(const std::vector<std::vector<void *>> &addresses); | static std::vector<void *> JoinAddresses(const std::vector<std::vector<void *>> &addresses); | ||||
static std::vector<void *> GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); | static std::vector<void *> GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); | ||||
static std::string InnerGetTaskInfo(const OpDescPtr &op_desc, | |||||
const std::vector<const void *> &input_addrs, | |||||
const std::vector<const void *> &output_addrs); | |||||
static std::string GetTaskInfo(const OpDescPtr &op_desc); | static std::string GetTaskInfo(const OpDescPtr &op_desc); | ||||
static std::string GetTaskInfo(const OpDescPtr &op_desc, | |||||
const std::vector<DataBuffer> &inputs, | |||||
const std::vector<DataBuffer> &outputs); | |||||
static std::string GetTaskInfo(const hybrid::TaskContext& task_context); | |||||
template<typename T> | template<typename T> | ||||
static std::string VectorToString(const std::vector<T> &values) { | static std::string VectorToString(const std::vector<T> &values) { | ||||
std::stringstream ss; | std::stringstream ss; | ||||
@@ -89,6 +89,7 @@ Status OpTask::OpenDump(rtStream_t stream) { | |||||
void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { | void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { | ||||
this->stub_name_ = name; | this->stub_name_ = name; | ||||
this->stub_func_ = stub_func; | this->stub_func_ = stub_func; | ||||
this->task_name_ = name; | |||||
} | } | ||||
void TbeOpTask::SetKernelArgs(std::unique_ptr<uint8_t[]> &&args, size_t arg_size, uint32_t block_dim, | void TbeOpTask::SetKernelArgs(std::unique_ptr<uint8_t[]> &&args, size_t arg_size, uint32_t block_dim, | ||||
@@ -345,49 +346,95 @@ Status TbeOpTask::AllocateWorkspaces(const vector<int64_t> &workspace_sizes) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TbeOpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc, | |||||
const vector<DataBuffer> &input_buffers, | |||||
vector<GeTensorDesc> &output_desc, | |||||
vector<DataBuffer> &output_buffers, | |||||
rtStream_t stream) { | |||||
GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); | |||||
GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); | |||||
std::vector<void *> args; | |||||
for (auto &buffer : input_buffers) { | |||||
args.emplace_back(buffer.data); | |||||
Status TbeOpTask::UpdateTilingArgs(rtStream_t stream) { | |||||
size_t args_size = input_num_ + output_num_ + workspaces_.size(); | |||||
if (tiling_buffer_ != nullptr) { | |||||
args_size++; | |||||
} | } | ||||
for (auto &buffer : output_buffers) { | |||||
args.emplace_back(buffer.data); | |||||
size_t temp_size = args_size * sizeof(void *); | |||||
if (arg_size_ < temp_size) { | |||||
GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); | |||||
std::unique_ptr<uint8_t[]> args(new (std::nothrow) uint8_t[temp_size]()); | |||||
GE_CHECK_NOTNULL(args); | |||||
if (memcpy_s(args.get(), temp_size, args_.get(), arg_size_) != EOK) { | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
} | |||||
args_ = std::move(args); | |||||
arg_size_ = temp_size; | |||||
} | } | ||||
for (auto &buffer : workspaces_) { | |||||
args.emplace_back(buffer); | |||||
uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get()); | |||||
size_t arg_index = input_num_ + output_num_; | |||||
for (size_t i = 0; i < workspaces_.size(); ++i) { | |||||
arg_base[arg_index++] = reinterpret_cast<uintptr_t>(workspaces_[i]); | |||||
} | } | ||||
if (tiling_buffer_ != nullptr) { | if (tiling_buffer_ != nullptr) { | ||||
GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); | GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); | ||||
GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), | GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), | ||||
RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); | RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); | ||||
arg_base[arg_index] = reinterpret_cast<uintptr_t>(tiling_buffer_); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TbeOpTask::SetArgIndex() { | |||||
const vector<bool> v_is_input_const = op_desc_->GetIsInputConst(); | |||||
size_t input_index = 0; | |||||
for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { | |||||
const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast<uint32_t>(i)); | |||||
if (tensor_desc == nullptr) { | |||||
GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i); | |||||
continue; | |||||
} | |||||
if (i < v_is_input_const.size() && v_is_input_const[i]) { | |||||
GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i); | |||||
input_index++; | |||||
continue; | |||||
} | |||||
arg_index_.emplace_back(input_index); | |||||
input_index++; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
args.emplace_back(tiling_buffer_); | |||||
Status TbeOpTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<DataBuffer> &outputs) { | |||||
if (arg_index_.size() != inputs.size()) { | |||||
GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Size] Args size is %zu, but get input size is %zu.", | |||||
arg_index_.size(), inputs.size()); | |||||
REPORT_INNER_ERROR("E19999", "[Check][Size] Args size is %zu, but get input size is %zu.", | |||||
arg_index_.size(), inputs.size()); | |||||
return ACL_ERROR_GE_PARAM_INVALID; | |||||
} | } | ||||
GELOGD("Dst size is %zu, src size is %zu.", arg_size_, args.size() * sizeof(void *)); | |||||
// node with workspace: build can not get size of workspace, need to update arg_size_ when execute | |||||
if (arg_size_ < (args.size() * sizeof(void *))) { | |||||
size_t temp_size = args.size() * sizeof(void *); | |||||
GELOGD("Need to reset size of args_ from %zu to %zu.", arg_size_, temp_size); | |||||
args_.reset(new(std::nothrow) uint8_t[temp_size]()); | |||||
GE_CHECK_NOTNULL(args_); | |||||
arg_size_ = temp_size; | |||||
uintptr_t *arg_base = reinterpret_cast<uintptr_t *>(args_.get()); | |||||
for (size_t i = 0; i < arg_index_.size(); ++i) { | |||||
arg_base[arg_index_[i]] = reinterpret_cast<uintptr_t>(inputs[i].data); | |||||
} | } | ||||
if (memcpy_s(args_.get(), arg_size_, args.data(), args.size() * sizeof(void *)) != EOK) { | |||||
GELOGE(ACL_ERROR_GE_MEMORY_OPERATE_FAILED, "[Update][KernelArgs] failed for [%s].", node_->GetName().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "update kernel args failed for %s.", node_->GetName().c_str()); | |||||
return ACL_ERROR_GE_MEMORY_OPERATE_FAILED; | |||||
for (size_t i = 0; i < op_desc_->GetOutputsSize(); ++i) { | |||||
arg_base[input_num_ + i] = reinterpret_cast<uintptr_t>(outputs[i].data); | |||||
} | } | ||||
return SUCCESS; | |||||
} | |||||
Status TbeOpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc, | |||||
const vector<DataBuffer> &input_buffers, | |||||
vector<GeTensorDesc> &output_desc, | |||||
vector<DataBuffer> &output_buffers, | |||||
rtStream_t stream) { | |||||
GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); | |||||
GE_CHK_STATUS_RET(UpdateIoAddr(input_buffers, output_buffers), "[Update][IoAddr] failed."); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateRunInfo()); | |||||
GE_CHK_STATUS_RET(AllocateWorkspaces(run_info_workspaces_), "[Allocate][Workspaces] failed."); | |||||
GE_CHK_STATUS_RET(UpdateTilingArgs(stream), "[Update][TilingArgs] failed."); | |||||
GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); | GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); | ||||
GE_CHK_STATUS_RET(DoLaunchKernel(stream), "Failed to do launch kernel."); | GE_CHK_STATUS_RET(DoLaunchKernel(stream), "Failed to do launch kernel."); | ||||
@@ -44,6 +44,7 @@ class OpTask { | |||||
virtual Status UpdateArgTable(const SingleOpModelParam ¶m); | virtual Status UpdateArgTable(const SingleOpModelParam ¶m); | ||||
void SetModelArgs(std::string model_name, uint32_t model_id); | void SetModelArgs(std::string model_name, uint32_t model_id); | ||||
Status GetProfilingArgs(TaskDescInfo &task_desc_info, uint32_t &model_id); | Status GetProfilingArgs(TaskDescInfo &task_desc_info, uint32_t &model_id); | ||||
const std::string &GetTaskName() const {return task_name_;} | |||||
void SetOpDesc(const OpDescPtr &op_desc) { | void SetOpDesc(const OpDescPtr &op_desc) { | ||||
op_desc_ = op_desc; | op_desc_ = op_desc; | ||||
} | } | ||||
@@ -66,6 +67,7 @@ class OpTask { | |||||
std::string model_name_; | std::string model_name_; | ||||
uint32_t model_id_ = 0; | uint32_t model_id_ = 0; | ||||
uint32_t block_dim_ = 1; | uint32_t block_dim_ = 1; | ||||
std::string task_name_; | |||||
}; | }; | ||||
class TbeOpTask : public OpTask { | class TbeOpTask : public OpTask { | ||||
@@ -85,6 +87,7 @@ class TbeOpTask : public OpTask { | |||||
const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle); | const OpDescPtr &op_desc, const domi::KernelDefWithHandle& kernel_def_with_handle); | ||||
Status UpdateRunInfo() override; | Status UpdateRunInfo() override; | ||||
Status SetArgIndex(); | |||||
const void *GetArgs() const; | const void *GetArgs() const; | ||||
size_t GetArgSize() const; | size_t GetArgSize() const; | ||||
@@ -100,7 +103,9 @@ class TbeOpTask : public OpTask { | |||||
Status UpdateNodeByShape(const vector<GeTensorDesc> &input_desc, | Status UpdateNodeByShape(const vector<GeTensorDesc> &input_desc, | ||||
const vector<GeTensorDesc> &output_desc); | const vector<GeTensorDesc> &output_desc); | ||||
Status AllocateWorkspaces(const std::vector<int64_t> &workspace_sizes); | Status AllocateWorkspaces(const std::vector<int64_t> &workspace_sizes); | ||||
Status UpdateTilingArgs(rtStream_t stream); | |||||
Status DoLaunchKernel(rtStream_t stream); | Status DoLaunchKernel(rtStream_t stream); | ||||
Status UpdateIoAddr(const vector<DataBuffer> &inputs, const vector<DataBuffer> &outputs); | |||||
const void *stub_func_ = nullptr; | const void *stub_func_ = nullptr; | ||||
std::unique_ptr<uint8_t[]> args_; | std::unique_ptr<uint8_t[]> args_; | ||||
@@ -120,6 +125,9 @@ class TbeOpTask : public OpTask { | |||||
void* handle_ = nullptr; | void* handle_ = nullptr; | ||||
std::string original_kernel_key_; | std::string original_kernel_key_; | ||||
std::string node_info_; | std::string node_info_; | ||||
std::vector<size_t> arg_index_; // data index in args | |||||
size_t input_num_; // include const input | |||||
size_t output_num_; | |||||
}; | }; | ||||
class AiCpuBaseTask : public OpTask { | class AiCpuBaseTask : public OpTask { | ||||
@@ -387,6 +387,9 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ | |||||
} | } | ||||
task.SetStubFunc(stub_name_, stub_func); | task.SetStubFunc(stub_name_, stub_func); | ||||
} | } | ||||
GE_CHK_STATUS_RET(task.SetArgIndex(), "[Set][ArgTable] failed."); | |||||
task.input_num_ = op_desc_->GetInputsSize(); | |||||
task.output_num_ = op_desc_->GetOutputsSize(); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -22,6 +22,7 @@ add_subdirectory(depends/runtime) | |||||
add_subdirectory(depends/hccl) | add_subdirectory(depends/hccl) | ||||
add_subdirectory(depends/profiler) | add_subdirectory(depends/profiler) | ||||
add_subdirectory(depends/error_manager) | add_subdirectory(depends/error_manager) | ||||
add_subdirectory(depends/opt_info) | |||||
if (ENABLE_GE_COV OR ENABLE_GE_UT) | if (ENABLE_GE_COV OR ENABLE_GE_UT) | ||||
add_subdirectory(ut) | add_subdirectory(ut) | ||||
@@ -345,6 +345,10 @@ INT32 mmIsDir(const CHAR *fileName) | |||||
INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len) | ||||
{ | { | ||||
const char *env = getenv(name); | |||||
if (env != nullptr) { | |||||
strcpy(value, env); | |||||
} | |||||
return 0; | return 0; | ||||
} | } | ||||
@@ -0,0 +1,37 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
project(opt_feature_stub)

# Explicit source list: the original used file(GLOB_RECURSE) to collect a
# single, already-known file, which hides additions until a re-configure.
set(SRCS
    src/opt_info_stub.cc
)

add_library(opt_feature_stub SHARED ${SRCS})

# Scope the opt_info headers to this target instead of directory-wide
# include_directories(), so the path cannot leak into sibling targets.
target_include_directories(opt_feature_stub PRIVATE
    ${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info
)

# Debug info only; this is test-support code.
target_compile_options(opt_feature_stub PRIVATE
    -g
)

target_link_libraries(opt_feature_stub PRIVATE
    $<BUILD_INTERFACE:intf_pub>
    c_sec
)

# Consumers linking the stub pick up its src/ directory (the stub headers).
target_include_directories(opt_feature_stub INTERFACE ${CMAKE_CURRENT_LIST_DIR}/src)
@@ -0,0 +1,46 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "opt_info.h" | |||||
#include <string> | |||||
#include <map> | |||||
#include <vector> | |||||
#include <algorithm> | |||||
namespace gelc { | |||||
namespace { | |||||
const std::vector<std::string> kSocVersions = {"Ascend910"}; | |||||
} | |||||
// Enable every optimization module ("all") in the output map; used for SoC
// versions outside the supported list. emplace() leaves existing keys intact.
void SetAllOptInfo(std::map<std::string, std::string> &opt_infos) {
  static const char *const kAllModules[] = {"opt_module.fe", "opt_module.pass", "opt_module.op_tune",
                                            "opt_module.rl_tune", "opt_module.aoe"};
  for (const char *const module_name : kAllModules) {
    opt_infos.emplace(module_name, "all");
  }
}
// Stub implementation: reports which optimization modules are enabled for the
// given SoC version. Versions not in kSocVersions get every module enabled;
// listed versions get fe/pass/op_tune only. `mode` is part of the real API's
// signature and is ignored by this stub.
Status GetOptInfo(WorkMode mode, const std::string &soc_ver,
                  std::map<std::string, std::string> &opt_infos) {
  const bool soc_listed =
      std::find(kSocVersions.cbegin(), kSocVersions.cend(), soc_ver) != kSocVersions.cend();
  if (!soc_listed) {
    SetAllOptInfo(opt_infos);
    return SUCCESS;
  }
  for (const char *const module_name : {"opt_module.fe", "opt_module.pass", "opt_module.op_tune"}) {
    opt_infos.emplace(module_name, "all");
  }
  return SUCCESS;
}
} // namespace gelc |
@@ -23,13 +23,46 @@ | |||||
void dav_log(int module_id, const char *fmt, ...) {} | void dav_log(int module_id, const char *fmt, ...) {} | ||||
void DlogErrorInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
static int log_level = DLOG_ERROR; | |||||
#define __DO_PRINT() \ | |||||
do { \ | |||||
const int FMT_BUFF_SIZE = 1024; \ | |||||
char fmt_buff[FMT_BUFF_SIZE] = {0}; \ | |||||
va_list valist; \ | |||||
va_start(valist, fmt); \ | |||||
vsnprintf(fmt_buff, FMT_BUFF_SIZE, fmt, valist); \ | |||||
va_end(valist); \ | |||||
printf("%s \n", fmt_buff); \ | |||||
} while (0) | |||||
void DlogErrorInner(int module_id, const char *fmt, ...) { | |||||
if (log_level > DLOG_ERROR) { | |||||
return; | |||||
} | |||||
__DO_PRINT(); | |||||
} | |||||
void DlogWarnInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
void DlogWarnInner(int module_id, const char *fmt, ...) { | |||||
if (log_level > DLOG_WARN) { | |||||
return; | |||||
} | |||||
__DO_PRINT(); | |||||
} | |||||
void DlogInfoInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
void DlogInfoInner(int module_id, const char *fmt, ...) { | |||||
if (log_level > DLOG_INFO) { | |||||
return; | |||||
} | |||||
__DO_PRINT(); | |||||
} | |||||
void DlogDebugInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
void DlogDebugInner(int module_id, const char *fmt, ...) { | |||||
if (log_level > DLOG_DEBUG) { | |||||
return; | |||||
} | |||||
__DO_PRINT(); | |||||
} | |||||
void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | ||||
@@ -39,30 +72,25 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu | |||||
dav_log(module_id, fmt); | dav_log(module_id, fmt); | ||||
} | } | ||||
int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } | |||||
int dlog_setlevel(int module_id, int level, int enable_event) { | |||||
log_level = level; | |||||
return log_level; | |||||
} | |||||
int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } | |||||
int dlog_getlevel(int module_id, int *enable_event) { return log_level; } | |||||
int CheckLogLevel(int moduleId, int logLevel) | |||||
{ | |||||
return 1; | |||||
} | |||||
int CheckLogLevel(int moduleId, int log_level_check) { return log_level >= log_level_check; } | |||||
/** | /** | ||||
* @ingroup plog | * @ingroup plog | ||||
* @brief DlogReportInitialize: init log in service process before all device setting. | * @brief DlogReportInitialize: init log in service process before all device setting. | ||||
* @return: 0: SUCCEED, others: FAILED | * @return: 0: SUCCEED, others: FAILED | ||||
*/ | */ | ||||
int DlogReportInitialize() { | |||||
return 0; | |||||
} | |||||
int DlogReportInitialize() { return 0; } | |||||
/** | /** | ||||
* @ingroup plog | * @ingroup plog | ||||
* @brief DlogReportFinalize: release log resource in service process after all device reset. | * @brief DlogReportFinalize: release log resource in service process after all device reset. | ||||
* @return: 0: SUCCEED, others: FAILED | * @return: 0: SUCCEED, others: FAILED | ||||
*/ | */ | ||||
int DlogReportFinalize() { | |||||
return 0; | |||||
} | |||||
int DlogReportFinalize() { return 0; } |
@@ -15,8 +15,8 @@ | |||||
include(cmake/graphengine.cmake) | include(cmake/graphengine.cmake) | ||||
add_subdirectory(easy_graph) | add_subdirectory(easy_graph) | ||||
add_subdirectory(stub_engine) | |||||
add_subdirectory(ge_graph_dsl) | add_subdirectory(ge_graph_dsl) | ||||
add_subdirectory(ge_running_env) | |||||
file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | ||||
"utils/*.cc" | "utils/*.cc" | ||||
@@ -29,4 +29,4 @@ target_include_directories(framework | |||||
) | ) | ||||
set_target_properties(framework PROPERTIES CXX_STANDARD 11) | set_target_properties(framework PROPERTIES CXX_STANDARD 11) | ||||
target_link_libraries(framework PUBLIC ge_graph_dsl graphengine fe) | |||||
target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) |
@@ -103,6 +103,7 @@ list(APPEND INCLUDE_DIRECTORIES | |||||
"${GE_CODE_DIR}/third_party/fwkacllib/inc/cce" | "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce" | ||||
"${GE_CODE_DIR}/third_party/fwkacllib/inc/ops" | "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops" | ||||
"${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain" | "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain" | ||||
"${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info" | |||||
"${GE_CODE_DIR}/tests/ut/ge" | "${GE_CODE_DIR}/tests/ut/ge" | ||||
"${GE_CODE_DIR}/tests/ut/common" | "${GE_CODE_DIR}/tests/ut/common" | ||||
"${CMAKE_BINARY_DIR}" | "${CMAKE_BINARY_DIR}" | ||||
@@ -117,6 +118,7 @@ list(APPEND STUB_LIBS | |||||
runtime_stub | runtime_stub | ||||
profiler_stub | profiler_stub | ||||
hccl_stub | hccl_stub | ||||
opt_feature_stub | |||||
error_manager_stub | error_manager_stub | ||||
ascend_protobuf | ascend_protobuf | ||||
json | json | ||||
@@ -150,7 +152,7 @@ set_target_properties(metadef_graph PROPERTIES CXX_STANDARD 11) | |||||
# ---- Target : Local engine ---- | # ---- Target : Local engine ---- | ||||
add_library(ge_local_engine SHARED ${LOCAL_ENGINE_SRC} ${METADEF_REGISTER_SRCS}) | |||||
add_library(ge_local_engine SHARED ${LOCAL_ENGINE_SRC}) | |||||
target_include_directories(ge_local_engine | target_include_directories(ge_local_engine | ||||
PUBLIC | PUBLIC | ||||
@@ -169,38 +171,11 @@ target_compile_options(ge_local_engine PRIVATE | |||||
target_link_libraries(ge_local_engine PUBLIC | target_link_libraries(ge_local_engine PUBLIC | ||||
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} | $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} | ||||
metadef_graph | |||||
-lrt -ldl -lpthread -lgcov | -lrt -ldl -lpthread -lgcov | ||||
) | ) | ||||
set_target_properties(ge_local_engine PROPERTIES CXX_STANDARD 11) | set_target_properties(ge_local_engine PROPERTIES CXX_STANDARD 11) | ||||
# ---- Target : Host engine ---- | |||||
add_library(host_cpu_engine SHARED ${HOST_ENGINE_SRC}) | |||||
target_include_directories(host_cpu_engine | |||||
PUBLIC | |||||
"${INCLUDE_DIRECTORIES}" | |||||
"${GE_CODE_DIR}/ge/host_cpu_engine" | |||||
) | |||||
target_compile_definitions(host_cpu_engine PRIVATE | |||||
google=ascend_private | |||||
FMK_SUPPORT_DUMP | |||||
) | |||||
target_compile_options(host_cpu_engine PRIVATE | |||||
-g --coverage -fprofile-arcs -ftest-coverage | |||||
-Werror=format | |||||
) | |||||
target_link_libraries(host_cpu_engine PUBLIC | |||||
$<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lrt -ldl -lpthread -lgcov | |||||
) | |||||
set_target_properties(host_cpu_engine PROPERTIES CXX_STANDARD 11) | |||||
# ---- Target : engine plugin---- | # ---- Target : engine plugin---- | ||||
# | # | ||||
@@ -273,4 +248,4 @@ target_link_libraries(graphengine PUBLIC | |||||
) | ) | ||||
set_target_properties(graphengine PROPERTIES CXX_STANDARD 11) | set_target_properties(graphengine PROPERTIES CXX_STANDARD 11) | ||||
add_dependencies(graphengine host_cpu_engine ge_local_engine nnengine engine_conf.json optimizer_priority.pbtxt) | |||||
add_dependencies(graphengine ge_local_engine nnengine engine_conf.json optimizer_priority.pbtxt) |
@@ -21,6 +21,7 @@ | |||||
#include "ge_graph_dsl/ge.h" | #include "ge_graph_dsl/ge.h" | ||||
#include "ge_graph_dsl/op_desc/op_box.h" | #include "ge_graph_dsl/op_desc/op_box.h" | ||||
#include "ge_graph_dsl/op_desc/op_desc_cfg.h" | #include "ge_graph_dsl/op_desc/op_desc_cfg.h" | ||||
#include "graph/ge_attr_value.h" | |||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
GE_NS_BEGIN | GE_NS_BEGIN | ||||
@@ -29,19 +30,32 @@ struct OpDescCfgBox : OpBox, private OpDescCfg { | |||||
OpDescCfgBox(const OpType &opType); | OpDescCfgBox(const OpType &opType); | ||||
OpDescCfgBox &InCnt(int in_cnt); | OpDescCfgBox &InCnt(int in_cnt); | ||||
OpDescCfgBox &OutCnt(int out_cnt); | OpDescCfgBox &OutCnt(int out_cnt); | ||||
OpDescCfgBox &ParentNodeIndex(int node_index); | |||||
OpDescCfgBox &TensorDesc(Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | OpDescCfgBox &TensorDesc(Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | ||||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
template<typename Type> | |||||
OpDescCfgBox& Attr(const std::string &name, Type value) { | |||||
auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(value); | |||||
attrs_.emplace(std::make_pair(name, attrvalue)); | |||||
return *this; | |||||
} | |||||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
OpDescCfgBox &Weight(GeTensorPtr &); | |||||
private: | |||||
template <typename Type> | |||||
OpDescCfgBox &Attr(const std::string &name, Type &&value) { | |||||
auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(std::forward<Type>(value)); | |||||
attrs_.emplace(std::make_pair(name, attrvalue)); | |||||
return *this; | |||||
} | |||||
template <typename Type> | |||||
OpDescCfgBox &Attr(const std::string &name, Type &value) { | |||||
auto attrvalue = ge::GeAttrValue::CreateFrom<Type>(value); | |||||
attrs_.emplace(std::make_pair(name, attrvalue)); | |||||
return *this; | |||||
} | |||||
OpDescCfgBox &Attr(const std::string &name, int value); | |||||
OpDescCfgBox &Attr(const std::string &name, const char *value); | |||||
OpDescPtr Build(const ::EG_NS::NodeId &id) const override; | OpDescPtr Build(const ::EG_NS::NodeId &id) const override; | ||||
void UpdateAttrs(OpDescPtr&) const; | |||||
std::map<std::string, GeAttrValue> attrs_; | |||||
private: | |||||
void UpdateAttrs(OpDescPtr &) const; | |||||
std::map<std::string, GeAttrValue> attrs_; | |||||
}; | }; | ||||
#define OP_CFG(optype) ::GE_NS::OpDescCfgBox(optype) | #define OP_CFG(optype) ::GE_NS::OpDescCfgBox(optype) | ||||
@@ -17,8 +17,8 @@ | |||||
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | #include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | ||||
#include "easy_graph/infra/status.h" | #include "easy_graph/infra/status.h" | ||||
#include "ge_graph_dsl/op_desc/op_desc_cfg_repo.h" | #include "ge_graph_dsl/op_desc/op_desc_cfg_repo.h" | ||||
#include "ge_graph_dsl/op_desc/op_desc_cfg.h" | |||||
#include "external/graph/gnode.h" | #include "external/graph/gnode.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
using ::EG_NS::Status; | using ::EG_NS::Status; | ||||
@@ -44,6 +44,26 @@ OpDescCfgBox &OpDescCfgBox::OutCnt(int out_cnt) { | |||||
return *this; | return *this; | ||||
} | } | ||||
OpDescCfgBox &OpDescCfgBox::ParentNodeIndex(int node_index) { | |||||
this->Attr(ATTR_NAME_PARENT_NODE_INDEX, node_index); | |||||
return *this; | |||||
} | |||||
OpDescCfgBox &OpDescCfgBox::Attr(const std::string &name, int value) { | |||||
this->Attr(name, (int64_t)value); | |||||
return *this; | |||||
} | |||||
OpDescCfgBox &OpDescCfgBox::Attr(const std::string &name, const char *value) { | |||||
this->Attr(name, std::string(value)); | |||||
return *this; | |||||
} | |||||
OpDescCfgBox &OpDescCfgBox::Weight(GeTensorPtr &tensor_ptr) { | |||||
this->Attr<GeAttrValue::TENSOR>(ATTR_NAME_WEIGHTS, tensor_ptr); | |||||
return *this; | |||||
} | |||||
OpDescCfgBox &OpDescCfgBox::TensorDesc(Format format, DataType data_type, std::vector<int64_t> shape) { | OpDescCfgBox &OpDescCfgBox::TensorDesc(Format format, DataType data_type, std::vector<int64_t> shape) { | ||||
default_tensor_.format_ = format; | default_tensor_.format_ = format; | ||||
default_tensor_.data_type_ = data_type; | default_tensor_.data_type_ = data_type; | ||||
@@ -51,10 +71,9 @@ OpDescCfgBox &OpDescCfgBox::TensorDesc(Format format, DataType data_type, std::v | |||||
return *this; | return *this; | ||||
} | } | ||||
void OpDescCfgBox::UpdateAttrs(OpDescPtr& op_desc) const { | |||||
std::for_each(attrs_.begin(), attrs_.end(), [&op_desc](const auto &attr){ | |||||
op_desc->SetAttr(attr.first, attr.second); | |||||
}); | |||||
void OpDescCfgBox::UpdateAttrs(OpDescPtr &op_desc) const { | |||||
std::for_each(attrs_.begin(), attrs_.end(), | |||||
[&op_desc](const auto &attr) { op_desc->SetAttr(attr.first, attr.second); }); | |||||
} | } | ||||
OpDescPtr OpDescCfgBox::Build(const ::EG_NS::NodeId &id) const { | OpDescPtr OpDescCfgBox::Build(const ::EG_NS::NodeId &id) const { | ||||
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "gtest/gtest.h" | |||||
#include "framework/common/types.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||||
#include "graph/ge_tensor.h" | |||||
#include "graph/utils/attr_utils.h" | |||||
GE_NS_BEGIN | |||||
class OpDescCfgTest : public testing::Test {}; | |||||
// A string attribute set via Attr() must round-trip through the built OpDesc.
TEST_F(OpDescCfgTest, test_attr_set_string_success) {
  auto op_ptr = OP_CFG(DATA).Attr(ENTER_ATTR_FRAME_NAME, "1").Build("data1");

  ge::GeAttrValue attr_holder;
  op_ptr->GetAttr(ENTER_ATTR_FRAME_NAME, attr_holder);

  std::string actual;
  attr_holder.GetValue<std::string>(actual);
  ASSERT_EQ(actual, "1");
}
// An int attribute is stored as int64_t (via the int overload of Attr()).
TEST_F(OpDescCfgTest, test_attr_set_int_success) {
  auto op_ptr = OP_CFG(DATA).Attr(ENTER_ATTR_FRAME_NAME, 2).Build("data1");

  ge::GeAttrValue attr_holder;
  op_ptr->GetAttr(ENTER_ATTR_FRAME_NAME, attr_holder);

  int64_t actual;
  attr_holder.GetValue<int64_t>(actual);
  ASSERT_EQ(actual, 2);
}
// ParentNodeIndex() must surface as ATTR_NAME_PARENT_NODE_INDEX on the OpDesc.
// NOTE(review): "perent" in the test name is a typo for "parent"; kept as-is
// because renaming would change the registered test id used by filters.
TEST_F(OpDescCfgTest, test_attr_set_perent_node_index_success) {
  auto op_ptr = OP_CFG(DATA).ParentNodeIndex(2).Build("data1");

  ge::GeAttrValue attr_holder;
  op_ptr->GetAttr(ATTR_NAME_PARENT_NODE_INDEX, attr_holder);

  int64_t actual;
  attr_holder.GetValue<int64_t>(actual);
  ASSERT_EQ(actual, 2);
}
// Weight() must store the tensor under ATTR_NAME_WEIGHTS with its dtype intact.
TEST_F(OpDescCfgTest, test_attr_set_weight_success) {
  // Build a 1-D int32 tensor of shape {5} filled with ones to use as weight.
  vector<int64_t> dims = {5};
  int64_t element_count = 1;
  for (int64_t dim : dims) {
    element_count *= dim;
  }
  vector<int32_t> values(element_count, 1);
  GeTensorDesc weight_desc(GeShape(dims), FORMAT_NCHW, DT_INT32);
  GeTensorPtr weight =
      std::make_shared<GeTensor>(weight_desc, (uint8_t *)values.data(), values.size() * sizeof(int32_t));

  auto op_ptr = OP_CFG(CONSTANT).Weight(weight).Build("const1");

  ConstGeTensorPtr stored_tensor;
  ASSERT_TRUE(AttrUtils::GetTensor(op_ptr, ge::ATTR_NAME_WEIGHTS, stored_tensor));
  ASSERT_EQ(stored_tensor->GetTensorDesc().GetDataType(), DT_INT32);
}
GE_NS_END |
@@ -23,6 +23,7 @@ GE_NS_BEGIN | |||||
REGISTER_OPTYPE_DEFINE(DATA, "Data"); | REGISTER_OPTYPE_DEFINE(DATA, "Data"); | ||||
REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather"); | REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather"); | ||||
REGISTER_OPTYPE_DEFINE(VARIABLE, "Variable"); | REGISTER_OPTYPE_DEFINE(VARIABLE, "Variable"); | ||||
REGISTER_OPTYPE_DEFINE(CONSTANT, "Const"); | |||||
REGISTER_OPTYPE_DEFINE(CONSTANTOP, "Constant"); | REGISTER_OPTYPE_DEFINE(CONSTANTOP, "Constant"); | ||||
REGISTER_OPTYPE_DEFINE(LESS, "Less"); | REGISTER_OPTYPE_DEFINE(LESS, "Less"); | ||||
REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | ||||
@@ -0,0 +1,18 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
# ge_running_env test-support component, assembled from three parts:
add_subdirectory(include)  # public headers (exported via an interface target)
add_subdirectory(src)      # implementation of the fake running environment
add_subdirectory(tests)    # self-tests for the fakes
@@ -0,0 +1,17 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
# Header-only usage target: anything linking ge_running_env_inc gets this
# directory on its include path.
add_library(ge_running_env_inc INTERFACE)
# Use an absolute path instead of "./": relative INTERFACE include dirs rely
# on configure-time conversion and break if the target is ever exported.
target_include_directories(ge_running_env_inc INTERFACE ${CMAKE_CURRENT_LIST_DIR})
@@ -0,0 +1,35 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef H1D9F4FDE_BB21_4DE4_AC7E_751920B45039 | |||||
#define H1D9F4FDE_BB21_4DE4_AC7E_751920B45039 | |||||
#include "fake_ns.h" | |||||
#include "opskernel_manager/ops_kernel_manager.h" | |||||
#include "register/ops_kernel_builder_registry.h" | |||||
FAKE_NS_BEGIN | |||||
struct EnvInstaller { | |||||
virtual void InstallTo(std::map<string, OpsKernelInfoStorePtr>&) const {} | |||||
virtual void InstallTo(std::map<string, GraphOptimizerPtr>&) const {} | |||||
virtual void InstallTo(std::map<string, OpsKernelBuilderPtr>&) const {} | |||||
virtual void Install() const {} | |||||
}; | |||||
FAKE_NS_END | |||||
#endif |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef HAF5E9BF2_752F_4E03_B0A5_E1B912A5FA24 | |||||
#define HAF5E9BF2_752F_4E03_B0A5_E1B912A5FA24 | |||||
#include <string> | |||||
#include "fake_ns.h" | |||||
#include "ge_running_env/env_installer.h" | |||||
#include "common/opskernel/ops_kernel_info_types.h" | |||||
#include "opskernel_manager/ops_kernel_manager.h" | |||||
#include "register/ops_kernel_builder_registry.h" | |||||
#include "fake_ops_kernel_builder.h" | |||||
#include "fake_ops_kernel_info_store.h" | |||||
FAKE_NS_BEGIN | |||||
using FakeOpsKernelBuilderPtr = std::shared_ptr<FakeOpsKernelBuilder>; | |||||
using FakeOpsKernelInfoStorePtr = std::shared_ptr<FakeOpsKernelInfoStore>; | |||||
// Fluent builder/installer for a fake GE engine: configure custom kernel
// builders and kernel-info stores, then the EnvInstaller hooks copy them into
// the maps the ops-kernel manager reads.
struct FakeEngine : EnvInstaller {
  // engine_name: key under which this engine's components are registered.
  FakeEngine(const std::string& engine_name);
  // Attach a custom kernel builder / info store instance (fluent interface).
  FakeEngine& KernelBuilder(FakeOpsKernelBuilderPtr);
  FakeEngine& KernelInfoStore(FakeOpsKernelInfoStorePtr);
  // Register an info store by name only — presumably a default fake is created
  // for it at install time; confirm against fake_engine.cc.
  FakeEngine& KernelInfoStore(const std::string&);

 private:
  // EnvInstaller overrides: populate the manager-owned registration maps.
  void InstallTo(std::map<string, OpsKernelInfoStorePtr>&) const override;
  void InstallTo(std::map<string, OpsKernelBuilderPtr>&) const override;

 private:
  // Shared install logic over both component kinds: BasePtr is the manager's
  // pointer type, SubClass the fake type being installed into `maps`.
  template <typename BasePtr, typename SubClass>
  void InstallFor(std::map<string, BasePtr>& maps, const std::map<std::string, std::shared_ptr<SubClass>>&) const;

 private:
  std::string engine_name_;                 // registration key for this engine
  std::set<std::string> info_store_names_;  // name-only info-store registrations
  std::map<std::string, FakeOpsKernelBuilderPtr> custom_builders_;
  std::map<std::string, FakeOpsKernelInfoStorePtr> custom_info_stores_;
};
FAKE_NS_END | |||||
#endif |
@@ -0,0 +1,28 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef H7AEFF0EA_9FDE_487F_8562_2917A2D48EA2
#define H7AEFF0EA_9FDE_487F_8562_2917A2D48EA2

// Namespace helpers: all fakes live in the ge namespace, but the whole suite
// can be relocated by editing the single FAKE_NS macro.
#define FAKE_NS ge
#define FAKE_NS_BEGIN namespace FAKE_NS {
#define FAKE_NS_END }
#define USING_STUB_NS using namespace FAKE_NS;

// Forward-declare a struct inside the fake namespace.
#define FWD_DECL_STUB(type) \
  namespace FAKE_NS {       \
  struct type;              \
  }

#endif
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef H737AD661_27C0_400F_8B08_29701308C5D0
#define H737AD661_27C0_400F_8B08_29701308C5D0

#include <map>
#include <set>
#include <string>
#include <vector>
#include "fake_ns.h"
#include "ge_running_env/env_installer.h"
#include "graph/operator_factory.h"

FAKE_NS_BEGIN

// Fluent builder that fakes a GE operator: Install() registers the op type
// (with optional inputs/outputs/infer-shape), and InstallTo() records the op
// in the named kernel info stores.
struct FakeOp : EnvInstaller {
  FakeOp(const std::string& op_type);

  FakeOp& Inputs(const std::vector<std::string>&);
  FakeOp& Outputs(const std::vector<std::string>&);
  FakeOp& InferShape(InferShapeFunc);
  // Name of the info store (and matching builder) this op is registered in.
  FakeOp& InfoStoreAndBuilder(const std::string&);

 private:
  void Install() const override;
  void InstallTo(std::map<string, OpsKernelInfoStorePtr>&) const override;

 private:
  const std::string op_type_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  InferShapeFunc info_fun_;
  std::set<std::string> info_store_names_;
};

FAKE_NS_END
#endif /* H737AD661_27C0_400F_8B08_29701308C5D0 */
@@ -13,39 +13,26 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef H39E4E719_91F4_4D0F_BA4F_6BA56CB1E20D
#define H39E4E719_91F4_4D0F_BA4F_6BA56CB1E20D

#include "fake_ns.h"
#include "common/opskernel/ops_kernel_builder.h"
#include "info_store_holder.h"

FAKE_NS_BEGIN

// Fake OpsKernelBuilder for ST: every build step succeeds without touching a
// device; the InfoStoreHolder base supplies the lib/engine naming state.
struct FakeOpsKernelBuilder : OpsKernelBuilder, InfoStoreHolder {
  FakeOpsKernelBuilder(const std::string &kernel_lib_name);
  // Default ctor picks a generated unique store name (see InfoStoreHolder).
  FakeOpsKernelBuilder();

 private:
  Status Initialize(const map<std::string, std::string> &options) override;
  Status Finalize() override;
  Status CalcOpRunningParam(Node &node) override;
  Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override;
};

FAKE_NS_END
#endif
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef H1EBABA85_7056_48F0_B496_E4DB68E5FED3
#define H1EBABA85_7056_48F0_B496_E4DB68E5FED3

#include <map>
#include <string>
#include "fake_ns.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "ge/ge_api_types.h"
#include "info_store_holder.h"

FAKE_NS_BEGIN

// In-memory OpsKernelInfoStore: supports exactly the op types registered via
// InfoStoreHolder::RegistOp.
struct FakeOpsKernelInfoStore : OpsKernelInfoStore, InfoStoreHolder {
  FakeOpsKernelInfoStore(const std::string &kernel_lib_name);
  // Default ctor picks a generated unique store name (see InfoStoreHolder).
  FakeOpsKernelInfoStore();

 private:
  Status Initialize(const std::map<std::string, std::string> &options) override;
  Status Finalize() override;
  bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override;
  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override;
};

FAKE_NS_END
#endif
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef H99C11FC4_700E_4D4D_B073_7808FA88BEBC
#define H99C11FC4_700E_4D4D_B073_7808FA88BEBC

#include <map>
#include <string>
#include <vector>
#include "ge_running_env/fake_engine.h"
#include "fake_ns.h"
#include "opskernel_manager/ops_kernel_manager.h"
#include "register/ops_kernel_builder_registry.h"

FAKE_NS_BEGIN

// Splices fakes into GE's globally-registered kernel stores / builders /
// optimizers, and can restore the snapshot taken at first use.
struct GeRunningEnvFaker {
  GeRunningEnvFaker();
  // Remove every fake installed since the environment snapshot.
  GeRunningEnvFaker &Reset();
  // Apply one installer (FakeEngine, FakeOp, ...) to the live GE registries.
  GeRunningEnvFaker &Install(const EnvInstaller &);
  // Reset, then install the stock set of fake engines and ops.
  GeRunningEnvFaker &InstallDefault();
  // Take the registry-name snapshot now; also done implicitly on first use.
  static void BackupEnv();

 private:
  void flush();

 private:
  // References into GE's singleton registries; mutated in place.
  std::map<string, vector<OpInfo>> &op_kernel_info_;
  std::map<string, OpsKernelInfoStorePtr> &ops_kernel_info_stores_;
  std::map<string, GraphOptimizerPtr> &ops_kernel_optimizers_;
  std::map<string, OpsKernelBuilderPtr> &ops_kernel_builders_;
};

FAKE_NS_END
#endif /* H99C11FC4_700E_4D4D_B073_7808FA88BEBC */
@@ -13,33 +13,27 @@ | |||||
* See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#ifndef H7992249B_058D_40A1_94EA_52BBCB76434E
#define H7992249B_058D_40A1_94EA_52BBCB76434E

#include <map>
#include <string>
#include "fake_ns.h"
#include "common/opskernel/ops_kernel_info_types.h"

FAKE_NS_BEGIN

// Shared state for the fake info stores / builders: the op registry plus the
// kernel-lib and engine names the store is published under.
struct InfoStoreHolder {
  // Default ctor generates a unique kernel-lib name.
  InfoStoreHolder();
  InfoStoreHolder(const std::string&);

  // Set the engine this store belongs to (used for OpInfo entries).
  void EngineName(std::string engine_name);
  // Register an op type with a default OpInfo; first registration wins.
  void RegistOp(std::string op_type);
  std::string GetLibName();

 protected:
  std::map<std::string, ge::OpInfo> op_info_map_;
  std::string kernel_lib_name_;
  std::string engine_name_;
};

FAKE_NS_END
#endif
@@ -0,0 +1,45 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
# Sources are globbed; CONFIGURE_DEPENDS re-runs the glob at build time so new
# files are picked up without a manual re-configure (costs one stat per build).
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")

# ---- Target : stub Host engine ----
add_library(ge_with_env STATIC ${SOURCES})

# Public headers under include/; internal headers resolved from this directory.
target_include_directories(ge_with_env
  PUBLIC
    include
  PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}
)

target_compile_definitions(ge_with_env PRIVATE
  google=ascend_private
  FMK_SUPPORT_DUMP
)

# Coverage instrumentation is always on: this library exists only for tests.
target_compile_options(ge_with_env PRIVATE
  -g --coverage -fprofile-arcs -ftest-coverage
  -Werror=format
)

target_link_libraries(ge_with_env PUBLIC
  $<BUILD_INTERFACE:intf_pub> ge_running_env_inc graphengine -lrt -ldl -lpthread -lgcov
)

set_target_properties(ge_with_env PROPERTIES CXX_STANDARD 17)
@@ -0,0 +1,81 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_running_env/fake_engine.h"
#include "ge_running_env/fake_ops_kernel_builder.h"
#include "ge_running_env/fake_ops_kernel_info_store.h"
#include "opskernel_manager/ops_kernel_manager.h"

FAKE_NS_BEGIN

FakeEngine::FakeEngine(const std::string &engine_name) : engine_name_(engine_name) {}

// Declare an info store by name only; a default fake is created on install.
FakeEngine &FakeEngine::KernelInfoStore(const std::string &info_store) {
  info_store_names_.insert(info_store);
  return *this;
}

// Register a custom info store instance, keyed by its lib name.
FakeEngine &FakeEngine::KernelInfoStore(FakeOpsKernelInfoStorePtr ptr) {
  info_store_names_.insert(ptr->GetLibName());
  custom_info_stores_.insert(std::make_pair(ptr->GetLibName(), ptr));
  return *this;
}

// Register a custom kernel builder instance, keyed by its lib name.
FakeEngine &FakeEngine::KernelBuilder(FakeOpsKernelBuilderPtr builder) {
  info_store_names_.insert(builder->GetLibName());
  custom_builders_.insert(std::make_pair(builder->GetLibName(), builder));
  return *this;
}

namespace {
// Create a default SubClass named after the info store and publish it in maps
// under the given engine name.
template <typename BasePtr, typename SubClass>
void InstallDefault(std::map<string, BasePtr> &maps, const std::string &info_store_name,
                    const std::string &engine_name) {
  auto default_obj = std::make_shared<SubClass>(info_store_name);
  if (default_obj == nullptr) {
    return;
  }
  default_obj->EngineName(engine_name);
  maps.insert(std::make_pair(default_obj->GetLibName(), default_obj));
}
}  // namespace

// For every declared info-store name, install the custom instance when one was
// registered, otherwise a default. With no names at all, fall back to a single
// default store named after the engine itself.
template <typename BasePtr, typename SubClass>
void FakeEngine::InstallFor(std::map<string, BasePtr> &maps,
                            const std::map<std::string, std::shared_ptr<SubClass>> &child_maps) const {
  if (info_store_names_.empty()) {
    InstallDefault<BasePtr, SubClass>(maps, engine_name_, engine_name_);
    return;
  }
  for (const auto &info_store_name : info_store_names_) {
    auto iter = child_maps.find(info_store_name);
    if (iter == child_maps.end()) {
      InstallDefault<BasePtr, SubClass>(maps, info_store_name, engine_name_);
    } else {
      maps.insert(std::make_pair(iter->second->GetLibName(), iter->second));
    }
  }
}

void FakeEngine::InstallTo(std::map<string, OpsKernelInfoStorePtr> &ops_kernel_info_stores) const {
  InstallFor<OpsKernelInfoStorePtr, FakeOpsKernelInfoStore>(ops_kernel_info_stores, custom_info_stores_);
}

void FakeEngine::InstallTo(std::map<string, OpsKernelBuilderPtr> &ops_kernel_builders) const {
  InstallFor<OpsKernelBuilderPtr, FakeOpsKernelBuilder>(ops_kernel_builders, custom_builders_);
}

FAKE_NS_END
@@ -14,40 +14,25 @@ | |||||
* limitations under the License. | * limitations under the License. | ||||
*/ | */ | ||||
#include "stub_ops_kernel_builder.h" | |||||
#include <memory> | |||||
#include "ge_running_env/fake_ops_kernel_builder.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include <securec.h> | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "host_cpu_engine/common/constant/constant.h" | |||||
#include "register/ops_kernel_builder_registry.h" | |||||
#include "inc/st_types.h" | |||||
FAKE_NS_BEGIN | |||||
namespace ge { | |||||
namespace st { | |||||
REGISTER_OPS_KERNEL_BUILDER(kAicoreLibName, StubOpsKernelBuilder); | |||||
REGISTER_OPS_KERNEL_BUILDER(kVectorLibName, StubOpsKernelBuilder); | |||||
REGISTER_OPS_KERNEL_BUILDER(kAicpuLibName, StubOpsKernelBuilder); | |||||
REGISTER_OPS_KERNEL_BUILDER(kAicpuAscendLibName, StubOpsKernelBuilder); | |||||
REGISTER_OPS_KERNEL_BUILDER(kHcclLibName, StubOpsKernelBuilder); | |||||
REGISTER_OPS_KERNEL_BUILDER(kRTSLibName, StubOpsKernelBuilder); | |||||
FakeOpsKernelBuilder::FakeOpsKernelBuilder(const std::string &info_store_name) : InfoStoreHolder(info_store_name) {} | |||||
FakeOpsKernelBuilder::FakeOpsKernelBuilder() : InfoStoreHolder() {} | |||||
Status StubOpsKernelBuilder::Finalize() { | |||||
return SUCCESS; | |||||
} | |||||
Status StubOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) { | |||||
return SUCCESS; | |||||
} | |||||
Status FakeOpsKernelBuilder::Finalize() { return SUCCESS; } | |||||
Status FakeOpsKernelBuilder::Initialize(const map<std::string, std::string> &options) { return SUCCESS; } | |||||
Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { | |||||
Status FakeOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { | |||||
OpDescPtr op_desc = ge_node.GetOpDesc(); | OpDescPtr op_desc = ge_node.GetOpDesc(); | ||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
GELOGE(FAILED, "[Get][OpDesc]CalcOpRunningParam failed, as op desc is null"); | |||||
REPORT_INNER_ERROR("E19999", "GetOpDesc failed."); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -86,9 +71,9 @@ Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { | |||||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | ||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
REPORT_CALL_ERROR( | REPORT_CALL_ERROR( | ||||
"E19999", "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", | |||||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
"E19999", "CalcTensorMemSize failed for op[%s:%s] out[%zu] mem size, mem_size=%ld, format=%s, data_type=%s.", | |||||
name.c_str(), type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", name.c_str(), type.c_str(), i, | GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", name.c_str(), type.c_str(), i, | ||||
@@ -111,9 +96,9 @@ Status StubOpsKernelBuilder::CalcOpRunningParam(Node &ge_node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status StubOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) { | |||||
Status FakeOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, vector<domi::TaskDef> &tasks) { | |||||
// no need to generate device task | // no need to generate device task | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace st | |||||
} // namespace ge | |||||
FAKE_NS_END |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "external/ge/ge_api_error_codes.h"
#include "ge_running_env/fake_ops_kernel_info_store.h"

FAKE_NS_BEGIN

FakeOpsKernelInfoStore::FakeOpsKernelInfoStore(const std::string &info_store_name)
    : InfoStoreHolder(info_store_name) {}

FakeOpsKernelInfoStore::FakeOpsKernelInfoStore() : InfoStoreHolder() {}

// Finalize drops all registered ops so the store can be reused across tests.
Status FakeOpsKernelInfoStore::Finalize() {
  op_info_map_.clear();
  return SUCCESS;
}

// Nothing to initialize: the op registry is filled through RegistOp().
Status FakeOpsKernelInfoStore::Initialize(const std::map<std::string, std::string> &options) { return SUCCESS; }

void FakeOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; }

// An op is supported iff its type was registered into this store.
bool FakeOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const {
  if (op_desc == nullptr) {
    return false;
  }
  return op_info_map_.count(op_desc->GetType()) > 0;
}

FAKE_NS_END
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_running_env/info_store_holder.h"

FAKE_NS_BEGIN

namespace {
// Generates unique default store names: store_0, store_1, ...
// NOTE(review): not thread-safe; assumed to run single-threaded in tests.
std::string GenStoreName() {
  static int store_id = 0;
  return "store_" + std::to_string(store_id++);
}
}  // namespace

InfoStoreHolder::InfoStoreHolder(const std::string& kernel_lib_name) : kernel_lib_name_(kernel_lib_name) {}

InfoStoreHolder::InfoStoreHolder() : kernel_lib_name_(GenStoreName()) {}

// Register op_type with a default OpInfo pointing at this store's engine and
// lib; an existing registration is left untouched (first one wins).
void InfoStoreHolder::RegistOp(std::string op_type) {
  OpInfo default_op_info = {.engine = engine_name_,
                            .opKernelLib = kernel_lib_name_,
                            .computeCost = 0,
                            .flagPartial = false,
                            .flagAsync = false,
                            .isAtomic = false};
  auto iter = op_info_map_.find(op_type);
  if (iter == op_info_map_.end()) {
    op_info_map_.emplace(op_type, default_op_info);
  }
}

void InfoStoreHolder::EngineName(std::string engine_name) { engine_name_ = engine_name; }

std::string InfoStoreHolder::GetLibName() { return kernel_lib_name_; }

FAKE_NS_END
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_default_running_env.h"
#include "ge_running_env/ge_running_env_faker.h"
#include "ge_running_env/fake_op.h"

FAKE_NS_BEGIN

namespace {
// Stock engines mirroring a real GE deployment, each with one fake info store.
std::vector<FakeEngine> default_engines = {FakeEngine("AIcoreEngine").KernelInfoStore("AiCoreLib"),
                                           FakeEngine("VectorEngine").KernelInfoStore("VectorLib"),
                                           FakeEngine("DNN_VM_AICPU").KernelInfoStore("AicpuLib"),
                                           FakeEngine("DNN_VM_AICPU_ASCEND").KernelInfoStore("AicpuAscendLib"),
                                           FakeEngine("DNN_HCCL").KernelInfoStore("HcclLib"),
                                           FakeEngine("DNN_VM_RTS").KernelInfoStore("RTSLib")};

// Common graph ops, routed to the lib their real engine would own them in.
std::vector<FakeOp> fake_ops = {
    FakeOp(ENTER).InfoStoreAndBuilder("RTSLib"),        FakeOp(MERGE).InfoStoreAndBuilder("RTSLib"),
    FakeOp(SWITCH).InfoStoreAndBuilder("RTSLib"),       FakeOp(LOOPCOND).InfoStoreAndBuilder("RTSLib"),
    FakeOp(STREAMMERGE).InfoStoreAndBuilder("RTSLib"),  FakeOp(STREAMSWITCH).InfoStoreAndBuilder("RTSLib"),
    FakeOp(STREAMACTIVE).InfoStoreAndBuilder("RTSLib"), FakeOp(EXIT).InfoStoreAndBuilder("RTSLib"),

    FakeOp(LESS).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(NEXTITERATION).InfoStoreAndBuilder("AiCoreLib"),
    FakeOp(CAST).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(TRANSDATA).InfoStoreAndBuilder("AiCoreLib"),
    FakeOp(NOOP).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(VARIABLE).InfoStoreAndBuilder("AiCoreLib"),
    FakeOp(CONSTANT).InfoStoreAndBuilder("AiCoreLib"),  FakeOp(ASSIGN).InfoStoreAndBuilder("AiCoreLib"),
    FakeOp(ADD).InfoStoreAndBuilder("AiCoreLib"),       FakeOp(MUL).InfoStoreAndBuilder("AiCoreLib"),
    FakeOp(DATA).InfoStoreAndBuilder("AiCoreLib"),      FakeOp(NETOUTPUT).InfoStoreAndBuilder("AiCoreLib"),
};
}  // namespace

// Install every stock engine, then every stock op, into the faked environment.
void GeDefaultRunningEnv::InstallTo(GeRunningEnvFaker& ge_env) {
  for (auto& fake_engine : default_engines) {
    ge_env.Install(fake_engine);
  }
  for (auto& fake_op : fake_ops) {
    ge_env.Install(fake_op);
  }
}

FAKE_NS_END
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_5D044B8760CB41ABA108AE2E37E8EBDE
#define INC_5D044B8760CB41ABA108AE2E37E8EBDE

#include "ge_running_env/fake_ns.h"

FAKE_NS_BEGIN

struct GeRunningEnvFaker;

// Installs the stock set of fake engines and ops (see the .cc for the list).
struct GeDefaultRunningEnv {
  static void InstallTo(GeRunningEnvFaker&);
};

FAKE_NS_END
#endif
@@ -0,0 +1,109 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <map>
#include <set>
#include <utility>
#include <algorithm>
#include "external/ge/ge_api.h"
#include "opskernel_manager/ops_kernel_builder_manager.h"
#include "init/gelib.h"
#include "ge_running_env/ge_running_env_faker.h"
#include "ge_default_running_env.h"
#include "ge_running_env/env_installer.h"
#include "op/fake_op_repo.h"

FAKE_NS_BEGIN

namespace {
// NOTE(review): assumes GELib::Initialize has already run; a null instance
// would be dereferenced here — confirm the test fixtures guarantee init.
OpsKernelManager &getKernelManger() {
  std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  return instance_ptr->OpsKernelManagerObj();
}

// Snapshot of the kernel-info-store names present at first use; reset()
// removes everything registered after the snapshot was taken.
struct InitEnv {
  static InitEnv &GetInstance() {
    static InitEnv instance;
    return instance;
  }

  // Erase every store (and its builder) whose name is not in the snapshot.
  void reset(std::map<string, OpsKernelInfoStorePtr> &ops_kernel_info_stores,
             std::map<string, OpsKernelBuilderPtr> &builders) {
    std::set<string> remove_info_names;
    for (const auto &entry : ops_kernel_info_stores) {
      if (kernel_info_names.find(entry.first) == kernel_info_names.end()) {
        remove_info_names.insert(entry.first);
      }
    }
    for (const auto &info_name : remove_info_names) {
      ops_kernel_info_stores.erase(info_name);
      builders.erase(info_name);
    }
  }

 private:
  InitEnv() {
    for (const auto &entry : getKernelManger().GetAllOpsKernelInfoStores()) {
      kernel_info_names.insert(entry.first);
    }
  }

 private:
  std::set<string> kernel_info_names;
};
}  // namespace

// Bind references straight into GE's singleton registries. The const_casts
// strip the const the getters add, so fakes can be spliced in.
GeRunningEnvFaker::GeRunningEnvFaker()
    : op_kernel_info_(const_cast<std::map<string, vector<OpInfo>> &>(getKernelManger().GetAllOpsKernelInfo())),
      ops_kernel_info_stores_(
          const_cast<std::map<string, OpsKernelInfoStorePtr> &>(getKernelManger().GetAllOpsKernelInfoStores())),
      ops_kernel_optimizers_(
          const_cast<std::map<string, GraphOptimizerPtr> &>(getKernelManger().GetAllGraphOptimizerObjs())),
      ops_kernel_builders_(const_cast<std::map<string, OpsKernelBuilderPtr> &>(
          OpsKernelBuilderManager::Instance().GetAllOpsKernelBuilders())) {
  Reset();
}

// Remove every fake registered since the snapshot and rebuild op-info caches.
GeRunningEnvFaker &GeRunningEnvFaker::Reset() {
  InitEnv &init_env = InitEnv::GetInstance();
  FakeOpRepo::Reset();
  init_env.reset(ops_kernel_info_stores_, ops_kernel_builders_);
  flush();
  return *this;
}

// Force the snapshot to be taken now (before any fakes are installed).
void GeRunningEnvFaker::BackupEnv() { InitEnv::GetInstance(); }

// Let the installer hook itself into each registry, then refresh caches.
GeRunningEnvFaker &GeRunningEnvFaker::Install(const EnvInstaller &installer) {
  installer.Install();
  installer.InstallTo(ops_kernel_info_stores_);
  installer.InstallTo(ops_kernel_optimizers_);
  installer.InstallTo(ops_kernel_builders_);
  flush();
  return *this;
}

// GetOpsKernelInfo("") repopulates op_kernel_info_ from the info stores.
void GeRunningEnvFaker::flush() {
  op_kernel_info_.clear();
  getKernelManger().GetOpsKernelInfo("");
}

GeRunningEnvFaker &GeRunningEnvFaker::InstallDefault() {
  Reset();
  GeDefaultRunningEnv::InstallTo(*this);
  return *this;
}

FAKE_NS_END
@@ -0,0 +1,95 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_running_env/fake_op.h" | |||||
#include "fake_op_repo.h" | |||||
#include "ge_running_env/info_store_holder.h" | |||||
#include "graph/operator_factory.h" | |||||
FAKE_NS_BEGIN | |||||
// Builds a fake-op descriptor for the given operator type; IO names, infer
// function and target info stores are added via the fluent setters below.
FakeOp::FakeOp(const std::string& op_type) : op_type_(op_type) {}
// Declares the op's input names (applied at Install() time).
FakeOp& FakeOp::Inputs(const std::vector<std::string>& inputs) {
  inputs_ = inputs;
  return *this;
}
// Declares the op's output names (applied at Install() time).
FakeOp& FakeOp::Outputs(const std::vector<std::string>& outputs) {
  outputs_ = outputs;
  return *this;
}
// Attaches an infer-shape function, registered with the op repo on Install().
FakeOp& FakeOp::InferShape(InferShapeFunc infer_fun) {
  info_fun_ = infer_fun;  // member is spelled info_fun_ in the header
  return *this;
}
// Marks a kernel info store (by name) this op should be registered into;
// may be called repeatedly to target several stores.
FakeOp& FakeOp::InfoStoreAndBuilder(const std::string& name) {
  info_store_names_.insert(name);
  return *this;
}
namespace { | |||||
void RegistOpToInfoStore(OpsKernelInfoStorePtr& info_store, const std::string& op_type) { | |||||
if (info_store == nullptr) { | |||||
return; | |||||
} | |||||
auto holder = dynamic_cast<InfoStoreHolder*>(info_store.get()); | |||||
holder->RegistOp(op_type); | |||||
} | |||||
struct FakeOperator : Operator { | |||||
FakeOperator(const std::string& op_type) : Operator(op_type) {} | |||||
FakeOperator& RegistInputs(const std::vector<std::string>& inputs) { | |||||
for (auto& input : inputs) { | |||||
Operator::InputRegister(input); | |||||
} | |||||
return *this; | |||||
} | |||||
FakeOperator& RegistOutputs(const std::vector<std::string>& outputs) { | |||||
for (auto& output : outputs) { | |||||
Operator::OutputRegister(output); | |||||
} | |||||
return *this; | |||||
} | |||||
}; | |||||
} // namespace | |||||
void FakeOp::InstallTo(std::map<string, OpsKernelInfoStorePtr>& info_stores) const { | |||||
std::for_each(info_store_names_.begin(), info_store_names_.end(), [=, &info_stores](auto& info_store_name) { | |||||
auto iter = info_stores.find(info_store_name); | |||||
if (iter != info_stores.end()) { | |||||
RegistOpToInfoStore(iter->second, op_type_); | |||||
} | |||||
}); | |||||
} | |||||
void FakeOp::Install() const { | |||||
FakeOpRepo::Regist( | |||||
op_type_, | |||||
[op_type = this->op_type_, inputs = this->inputs_, outputs = this->outputs_](const std::string&) -> Operator { | |||||
return FakeOperator(op_type).RegistInputs(inputs).RegistOutputs(outputs); | |||||
}); | |||||
if (info_fun_) { | |||||
FakeOpRepo::Regist(op_type_, info_fun_); | |||||
} | |||||
} | |||||
FAKE_NS_END |
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/operator_factory_impl.h" | |||||
#include "ge_running_env/fake_op.h" | |||||
#include "fake_op_repo.h" | |||||
FAKE_NS_BEGIN | |||||
void FakeOpRepo::Reset() { | |||||
if (OperatorFactoryImpl::operator_creators_) { | |||||
OperatorFactoryImpl::operator_creators_->clear(); | |||||
} | |||||
if (OperatorFactoryImpl::operator_infershape_funcs_) { | |||||
OperatorFactoryImpl::operator_infershape_funcs_->clear(); | |||||
} | |||||
} | |||||
// Forwards a creator registration to the global operator factory.
void FakeOpRepo::Regist(const std::string &operator_type, const OpCreator creator) {
  OperatorFactoryImpl::RegisterOperatorCreator(operator_type, creator);
}
// Forwards an infer-shape registration to the global operator factory.
void FakeOpRepo::Regist(const std::string &operator_type, const InferShapeFunc infer_fun) {
  OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_fun);
}
FAKE_NS_END |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef DBF6CE7CD4AC4A83BA4ED4B372FC66E4 | |||||
#define DBF6CE7CD4AC4A83BA4ED4B372FC66E4 | |||||
#include "ge_running_env/fake_ns.h" | |||||
#include "graph/operator_factory.h" | |||||
FAKE_NS_BEGIN | |||||
// Thin facade over OperatorFactoryImpl's global registries, letting tests
// wipe and repopulate operator creators / infer-shape functions.
struct FakeOpRepo {
  // Clears all registered creators and infer-shape functions.
  static void Reset();
  // Registers a creator for the given op type.
  static void Regist(const std::string &operator_type, const OpCreator);
  // Registers an infer-shape function for the given op type.
  static void Regist(const std::string &operator_type, const InferShapeFunc);
};
FAKE_NS_END | |||||
#endif |
@@ -0,0 +1,33 @@ | |||||
# Copyright 2021 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================ | |||||
# Gather the test sources. NOTE(review): an explicit source list is more
# robust than GLOB; CONFIGURE_DEPENDS mitigates the stale-glob problem but
# still costs a re-stat on every build.
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP")

add_executable(ge_running_env_test ${SOURCES})

target_include_directories(ge_running_env_test
    PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}
)

target_compile_options(ge_running_env_test PRIVATE
    -g
)

# Require C++17 outright instead of silently falling back to an older
# standard when the compiler cannot provide it.
set_target_properties(ge_running_env_test PROPERTIES
    CXX_STANDARD 17
    CXX_STANDARD_REQUIRED ON
)

target_link_libraries(ge_running_env_test PUBLIC gtest ge_with_env)

include(CTest)
enable_testing()
add_test(NAME test COMMAND ge_running_env_test)
@@ -0,0 +1,148 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "graph/operator_factory_impl.h" | |||||
#include "init/gelib.h" | |||||
#include "external/ge/ge_api.h" | |||||
#include "opskernel_manager/ops_kernel_builder_manager.h" | |||||
#include "ge_running_env/fake_ops_kernel_builder.h" | |||||
#include "ge_running_env/fake_ns.h" | |||||
#include "ge_running_env/ge_running_env_faker.h" | |||||
#include "ge_running_env/fake_op.h" | |||||
FAKE_NS_BEGIN | |||||
// Asserts that the operator factory currently knows exactly `list_size` op
// types. Wrapped in do { } while (0) so the macro behaves as one statement
// and the helper vector neither leaks into the calling scope nor collides
// when the macro is used twice in the same test body (the original form
// declared `ops_list` directly in the caller's scope).
#define ASSERT_OPS_LIST_SIZE(list_size)          \
  do {                                           \
    std::vector<AscendString> ops_list;          \
    OperatorFactory::GetOpsTypeList(ops_list);   \
    ASSERT_EQ(ops_list.size(), list_size);       \
  } while (0)
// Shared fixture: grabs the global kernel/builder managers so each test can
// assert on the environment GeRunningEnvFaker installs.
// NOTE(review): fixture name spells "Evn" (sic) — kept, since every TEST_F
// below references it.
class GeRunningEvnFakerTest : public testing::Test {
 protected:
  void SetUp() {}
  OpsKernelManager &kernel_manager = ge::GELib::GetInstance()->OpsKernelManagerObj();
  OpsKernelBuilderManager &builder_manager = OpsKernelBuilderManager::Instance();
};
// Reset() must leave the operator factory empty while the baseline
// environment (one info store, one builder) remains available.
TEST_F(GeRunningEvnFakerTest, test_reset_running_env_is_success) {
  GeRunningEnvFaker ge_env;
  ge_env.Reset();
  ASSERT_OPS_LIST_SIZE(0);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 1);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 1);
  // Magic counts (52 kernels, 1 SWITCH entry) presumably come from the
  // baseline registration set — keep in sync with it.
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 1);
}
// Installing FakeOps makes them visible through the OperatorFactory.
TEST_F(GeRunningEvnFakerTest, test_install_fake_op_success) {
  GeRunningEnvFaker ge_env;
  ge_env.Install(FakeOp(DATA)).Install(FakeOp(SWITCH));
  ASSERT_OPS_LIST_SIZE(2);
  ASSERT_TRUE(OperatorFactory::IsExistOp(DATA));
  ASSERT_TRUE(OperatorFactory::IsExistOp(SWITCH));
}
// A fake op's declared IO names must be reflected on operators the factory
// creates for that type.
TEST_F(GeRunningEvnFakerTest, test_install_fake_op_with_inputs_and_outputs_success) {
  GeRunningEnvFaker ge_env;
  ge_env.Install(FakeOp(ADD).Inputs({"x1", "x2"}).Outputs({"y"}));

  auto add_op = OperatorFactory::CreateOperator("add1", ADD);
  ASSERT_EQ(add_op.GetInputsSize(), 2);
  ASSERT_EQ(add_op.GetOutputsSize(), 1);
  ASSERT_OPS_LIST_SIZE(1);
}
// Installing a fake op with an infer-shape function registers that function
// with the factory; before installation no function is present.
TEST_F(GeRunningEvnFakerTest, test_install_fake_op_with_infer_shape_success) {
  GeRunningEnvFaker ge_env;
  // Shape function only reads its input; registration is what we test here.
  auto shape_fn = [](Operator &op) -> graphStatus {
    TensorDesc input_desc = op.GetInputDescByName("data");
    return GRAPH_SUCCESS;
  };

  ASSERT_TRUE(OperatorFactoryImpl::GetInferShapeFunc(DATA) == nullptr);
  ge_env.Install(FakeOp(DATA).Inputs({"data"}).InferShape(shape_fn));
  ASSERT_TRUE(OperatorFactoryImpl::GetInferShapeFunc(DATA) != nullptr);
}
// Installing a FakeEngine with its default info store adds one store/builder
// pair on top of the baseline.
TEST_F(GeRunningEvnFakerTest, test_install_engine_with_default_info_store) {
  GeRunningEnvFaker ge_env;
  ge_env.Install(FakeEngine("DNN_HCCL"));
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 1);
}
// A named info store ("AiCoreLib2") plus an op targeted at it gives SWITCH a
// second kernel-info entry.
TEST_F(GeRunningEvnFakerTest, test_install_engine_with_info_store_name) {
  GeRunningEnvFaker ge_env;
  ge_env.Install(FakeEngine("DNN_HCCL").KernelInfoStore("AiCoreLib2"))
    .Install(FakeOp(SWITCH).InfoStoreAndBuilder("AiCoreLib2"));
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
  ASSERT_EQ(kernel_manager.GetOpsKernelInfo(SWITCH).size(), 2);
}
// A custom kernel builder supplied via FakeEngine::KernelBuilder is installed
// alongside the baseline environment.
TEST_F(GeRunningEvnFakerTest, test_install_custom_kernel_builder_success) {
  // Builder whose CalcOpRunningParam only validates the op descriptor.
  struct StubKernelBuilder : FakeOpsKernelBuilder {
    Status CalcOpRunningParam(Node &node) override {
      return node.GetOpDesc() == nullptr ? FAILED : SUCCESS;
    }
  };

  GeRunningEnvFaker ge_env;
  auto hccl_engine = FakeEngine("DNN_HCCL").KernelBuilder(std::make_shared<StubKernelBuilder>());
  ge_env.Reset().Install(hccl_engine);

  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
}
// A custom kernel info store supplied via FakeEngine::KernelInfoStore is
// installed alongside the baseline environment.
TEST_F(GeRunningEvnFakerTest, test_install_custom_kernel_info_store_success) {
  // Info store whose CheckSupported always reports a failure.
  // NOTE(review): returning FAILED (a Status constant) from a bool function
  // converts to true — confirm the intended result; behavior preserved here.
  struct RejectingInfoStore : FakeOpsKernelInfoStore {
    explicit RejectingInfoStore(const std::string &kernel_lib_name) : FakeOpsKernelInfoStore(kernel_lib_name) {}
    bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override { return FAILED; }
  };

  GeRunningEnvFaker ge_env;
  auto hccl_engine = FakeEngine("DNN_HCCL").KernelInfoStore(std::make_shared<RejectingInfoStore>("AiCoreLib2"));
  ge_env.Reset().Install(hccl_engine);

  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 2);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 2);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 52);
}
// InstallDefault() populates the full stock environment.
// NOTE(review): the counts (7 stores/builders, 66 kernels) mirror
// GeDefaultRunningEnv's registration set — keep in sync with it.
TEST_F(GeRunningEvnFakerTest, test_install_default_fake_engine_success) {
  GeRunningEnvFaker ge_env;
  ge_env.InstallDefault();
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfoStores().size(), 7);
  ASSERT_EQ(builder_manager.GetAllOpsKernelBuilders().size(), 7);
  ASSERT_EQ(kernel_manager.GetAllOpsKernelInfo().size(), 66);
}
FAKE_NS_END |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/debug/log.h" | |||||
#include "external/ge/ge_api.h" | |||||
#include "ge_running_env/ge_running_env_faker.h" | |||||
using namespace std; | |||||
using namespace ge; | |||||
int main(int argc, char **argv) { | |||||
map<AscendString, AscendString> options; | |||||
ge::GEInitialize(options); | |||||
GeRunningEnvFaker::BackupEnv(); | |||||
testing::InitGoogleTest(&argc, argv); | |||||
int ret = RUN_ALL_TESTS(); | |||||
return ret; | |||||
} |
@@ -1,58 +0,0 @@ | |||||
# Accumulate every include path the stub FE engine needs: project roots,
# metadef/parser trees, prebuilt third-party (fwkacllib) headers and the
# binary dir for generated protobuf sources.
# NOTE(review): the variable name INCLUDE_DIRECTORIES shadows the
# include_directories() command's name — confusing but harmless.
list(APPEND INCLUDE_DIRECTORIES
    "${CMAKE_CURRENT_SOURCE_DIR}"
    "${GE_CODE_DIR}"
    "${GE_CODE_DIR}/inc"
    "${GE_CODE_DIR}/metadef/inc"
    "${GE_CODE_DIR}/ge"
    "${GE_CODE_DIR}/ge/inc"
    "${GE_CODE_DIR}/ge/ir_build"
    "${GE_CODE_DIR}/metadef"
    "${GE_CODE_DIR}/metadef/graph"
    "${GE_CODE_DIR}/inc/external"
    "${GE_CODE_DIR}/inc/framework/common"
    "${GE_CODE_DIR}/metadef/inc/external"
    "${GE_CODE_DIR}/metadef/inc/external/graph"
    "${GE_CODE_DIR}/metadef/inc/graph"
    "${GE_CODE_DIR}/inc/framework"
    "${GE_CODE_DIR}/metadef/inc/common"
    "${GE_CODE_DIR}/metadef/third_party"
    "${GE_CODE_DIR}/metadef/third_party/transformer/inc"
    "${GE_CODE_DIR}/parser"
    "${GE_CODE_DIR}/parser/parser"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/cce"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/ops"
    "${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain"
    "${GE_CODE_DIR}/tests/ut/ge"
    "${GE_CODE_DIR}/tests/ut/common"
    "${CMAKE_BINARY_DIR}"
    "${CMAKE_BINARY_DIR}/proto/ge"
    "${CMAKE_BINARY_DIR}/proto/ge/proto"
)
# Collect all stub-engine sources under this directory.
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS "*.cc" "*.CC" "*.cpp" "*.CPP" "*.c++")

# ---- Target : stub Host engine ----
add_library(fe SHARED ${SOURCES})
target_include_directories(fe
    PUBLIC
    ${INCLUDE_DIRECTORIES}
    ${CMAKE_CURRENT_SOURCE_DIR}
)
target_compile_definitions(fe PRIVATE
    google=ascend_private
    FMK_SUPPORT_DUMP
)
# Coverage instrumentation is always on for this test stub.
target_compile_options(fe PRIVATE
    -g --coverage -fprofile-arcs -ftest-coverage
    -Werror=format
)
# NOTE(review): raw -l/-L flags bypass CMake's dependency model; imported
# targets (or ${CMAKE_DL_LIBS} / Threads::Threads) would be more robust.
target_link_libraries(fe PUBLIC
    $<BUILD_INTERFACE:intf_pub> ${STUB_LIBS} metadef_graph -lmmpa -L${GE_CODE_DIR}/third_party/prebuild/x86_64 -lrt -ldl -lpthread -lgcov
)
set_target_properties(fe PROPERTIES CXX_STANDARD 11)
@@ -1,74 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "stub_engine.h" | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <securec.h> | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "common/ge/ge_util.h" | |||||
#include "inc/st_types.h" | |||||
namespace ge { | |||||
namespace st { | |||||
StubEngine &StubEngine::Instance() {
  // Meyers singleton — construction is thread-safe since C++11.
  static StubEngine engine;
  return engine;
}
// Builds one StubOpsKernelInfoStore per stub kernel lib listed in
// kStubEngine2KernelLib, keyed by the lib name.
// @param options  unused engine options passed through by GE
// @return FAILED if any store allocation fails, SUCCESS otherwise
Status StubEngine::Initialize(const std::map<string, string> &options) {
  // const auto & — the original copied every map pair per iteration.
  for (const auto &engine_2_lib : kStubEngine2KernelLib) {
    auto ops_kernel_store = MakeShared<StubOpsKernelInfoStore>(engine_2_lib.second);
    if (ops_kernel_store == nullptr) {
      return FAILED;
    }
    ops_kernel_store_map_.emplace(engine_2_lib.second, ops_kernel_store);
  }
  return SUCCESS;
}
void StubEngine::GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) { | |||||
for (const auto name_2_ops_kernel_store : ops_kernel_store_map_) { | |||||
ops_kernel_map[name_2_ops_kernel_store.first] = name_2_ops_kernel_store.second; | |||||
} | |||||
} | |||||
// The stub engine provides no graph optimizers; intentionally a no-op.
void StubEngine::GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &) {
  // no optimizer for host cpu engine
}
// Nothing to tear down: the kernel-store map only holds shared_ptrs.
Status StubEngine::Finalize() {
  return SUCCESS;
}
} // namespace st | |||||
} // namespace ge | |||||
// C-ABI shim: forwards engine initialization to the StubEngine singleton.
ge::Status Initialize(const std::map<string, string> &options) {
  return ge::st::StubEngine::Instance().Initialize(options);
}
// C-ABI shim: forwards info-store retrieval to the StubEngine singleton.
void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map) {
  ge::st::StubEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map);
}
// C-ABI shim: forwards optimizer retrieval (a no-op for the stub engine).
void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers) {
  ge::st::StubEngine::Instance().GetGraphOptimizerObjs(graph_optimizers);
}
// C-ABI shim: forwards finalization to the StubEngine singleton.
ge::Status Finalize() {
  return ge::st::StubEngine::Instance().Finalize();
}
@@ -1,127 +0,0 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPH_ENGINE_LLT_STUB_ENGINE_H_
#define GRAPH_ENGINE_LLT_STUB_ENGINE_H_

// Export/visibility macro: __declspec(dllexport) on MSVC (fixed: the MSVC
// keyword is double-underscore __declspec, the original `_declspec` is not
// a recognized keyword), default ELF visibility on other toolchains.
#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __declspec(dllexport)
#else
#define GE_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define GE_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define GE_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <memory>
#include <string>
#include "inc/st_types.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/optimizer/graph_optimizer.h"
#include "stub_engine/ops_kernel_store/stub_ops_kernel_store.h"

using OpsKernelInfoStorePtr = std::shared_ptr<ge::OpsKernelInfoStore>;
using StubOpsKernelInfoStorePtr = std::shared_ptr<ge::st::StubOpsKernelInfoStore>;
using GraphOptimizerPtr = std::shared_ptr<ge::GraphOptimizer>;
namespace ge { | |||||
namespace st { | |||||
/** | |||||
* host cpu engine. | |||||
* Used for the ops which executes on host. | |||||
*/ | |||||
class GE_FUNC_VISIBILITY StubEngine { | |||||
public: | |||||
/** | |||||
* get StubEngine instance. | |||||
* @return StubEngine instance. | |||||
*/ | |||||
static StubEngine &Instance(); | |||||
virtual ~StubEngine() = default; | |||||
/** | |||||
* When Ge start, GE will invoke this interface | |||||
* @return The status whether initialize successfully | |||||
*/ | |||||
Status Initialize(const std::map<string, string> &options); | |||||
/** | |||||
* After the initialize, GE will invoke this interface | |||||
* to get the Ops kernel Store. | |||||
* @param ops_kernel_map The host cpu's ops kernel info | |||||
*/ | |||||
void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map); | |||||
/** | |||||
* After the initialize, GE will invoke this interface | |||||
* to get the Graph Optimizer. | |||||
* @param graph_optimizers The host cpu's Graph Optimizer objs | |||||
*/ | |||||
void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers); | |||||
/** | |||||
* When the graph finished, GE will invoke this interface | |||||
* @return The status whether initialize successfully | |||||
*/ | |||||
Status Finalize(); | |||||
StubEngine(const StubEngine &StubEngine) = delete; | |||||
StubEngine(const StubEngine &&StubEngine) = delete; | |||||
StubEngine &operator=(const StubEngine &StubEngine) = delete; | |||||
StubEngine &operator=(StubEngine &&StubEngine) = delete; | |||||
private: | |||||
StubEngine() = default; | |||||
map<string, OpsKernelInfoStorePtr> ops_kernel_store_map_; | |||||
}; | |||||
} // namespace st | |||||
} // namespace ge | |||||
extern "C" { | |||||
/** | |||||
* When Ge start, GE will invoke this interface | |||||
* @return The status whether initialize successfully | |||||
*/ | |||||
GE_FUNC_VISIBILITY ge::Status Initialize(const map<string, string> &options); | |||||
/** | |||||
* After the initialize, GE will invoke this interface to get the Ops kernel Store | |||||
* @param ops_kernel_map The host cpu's ops kernel info | |||||
*/ | |||||
GE_FUNC_VISIBILITY void GetOpsKernelInfoStores(std::map<std::string, OpsKernelInfoStorePtr> &ops_kernel_map); | |||||
/** | |||||
* After the initialize, GE will invoke this interface to get the Graph Optimizer | |||||
* @param graph_optimizers The host cpu's Graph Optimizer objs | |||||
*/ | |||||
GE_FUNC_VISIBILITY void GetGraphOptimizerObjs(std::map<std::string, GraphOptimizerPtr> &graph_optimizers); | |||||
/** | |||||
* When the graph finished, GE will invoke this interface | |||||
* @return The status whether initialize successfully | |||||
*/ | |||||
GE_FUNC_VISIBILITY ge::Status Finalize(); | |||||
} | |||||
#endif // GRAPH_ENGINE_LLT_STUB_ENGINE_H_ |
@@ -1,33 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_ST_TYPES_H
#define GRAPHENGINE_ST_TYPES_H
#include <map>
#include <string>  // was missing: std::string used with only <map> included
namespace ge {
namespace st {
// Kernel-lib names exposed by the stub engine.
const std::string kAicoreLibName = "AiCoreLib";
const std::string kVectorLibName = "VectorLib";
const std::string kAicpuLibName = "AicpuLib";
const std::string kAicpuAscendLibName = "AicpuAscendLib";
const std::string kHcclLibName = "HcclLib";
const std::string kRTSLibName = "RTSLib";
// Engine name -> kernel-lib name. Uses the named constants above instead of
// repeating the string literals, so the two can never drift apart.
const std::map<std::string, std::string> kStubEngine2KernelLib = {
    {"AIcoreEngine", kAicoreLibName}, {"VectorEngine", kVectorLibName},
    {"DNN_VM_AICPU", kAicpuLibName},  {"DNN_VM_AICPU_ASCEND", kAicpuAscendLibName},
    {"DNN_HCCL", kHcclLibName},       {"DNN_VM_RTS", kRTSLibName}};
}  // namespace st
}  // namespace ge
#endif  // GRAPHENGINE_ST_TYPES_H
@@ -1,41 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "inc/st_types.h" | |||||
#include "stub_engine/ops_kernel_store/op/host_op.h" | |||||
#include "framework/common/util.h" | |||||
#include "stub_engine/ops_kernel_store/op/stub_op_factory.h" | |||||
namespace ge {
namespace st {
// Stub op implementation: succeeds immediately without generating any
// device task.
Status HostOp::Run() {
  // no need to generate device task
  return SUCCESS;
}

// Route control-flow ops through the RTS stub lib and compute/format ops
// through the AiCore stub lib (lib names defined in inc/st_types.h).
REGISTER_OP_CREATOR(Enter, RTSLib, HostOp);
REGISTER_OP_CREATOR(Merge, RTSLib, HostOp);
REGISTER_OP_CREATOR(Switch, RTSLib, HostOp);
REGISTER_OP_CREATOR(Less, AiCoreLib, HostOp);
REGISTER_OP_CREATOR(NextIteration, AiCoreLib, HostOp);
REGISTER_OP_CREATOR(LoopCond, RTSLib, HostOp);
REGISTER_OP_CREATOR(Exit, RTSLib, HostOp);
REGISTER_OP_CREATOR(StreamMerge, RTSLib, HostOp);
REGISTER_OP_CREATOR(StreamSwitch, RTSLib, HostOp);
REGISTER_OP_CREATOR(StreamActive, RTSLib, HostOp);
REGISTER_OP_CREATOR(Cast, AiCoreLib, HostOp);
REGISTER_OP_CREATOR(Transdata, AiCoreLib, HostOp);
}  // namespace st
}  // namespace ge
@@ -1,51 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "stub_op_factory.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "graph/op_desc.h" | |||||
namespace ge { | |||||
namespace st { | |||||
OpFactory &OpFactory::Instance() {
  // Meyers singleton — construction is thread-safe since C++11.
  static OpFactory factory;
  return factory;
}
std::shared_ptr<Op> OpFactory::CreateOp(const Node &node, RunContext &run_context) { | |||||
auto iter = op_creator_map_.find(node.GetType()); | |||||
if (iter != op_creator_map_.end()) { | |||||
return iter->second(node, run_context); | |||||
} | |||||
GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
void OpFactory::RegisterCreator(const std::string &type, const std::string &kernel_lib, const OP_CREATOR_FUNC &func) { | |||||
if (func == nullptr) { | |||||
GELOGW("Func is NULL."); | |||||
return; | |||||
} | |||||
if (all_store_ops_.find(kernel_lib) != all_store_ops_.end()) { | |||||
all_store_ops_[kernel_lib].emplace_back(type); | |||||
} else { | |||||
all_store_ops_[kernel_lib] = {type}; | |||||
} | |||||
} | |||||
} // namespace st | |||||
} // namespace ge |
@@ -1,109 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ | |||||
#include <functional> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "common/ge/ge_util.h" | |||||
#include "stub_engine/ops_kernel_store/op/op.h" | |||||
#include "inc/st_types.h" | |||||
namespace ge { | |||||
namespace st { | |||||
using OP_CREATOR_FUNC = std::function<std::shared_ptr<Op>(const Node &, RunContext &)>; | |||||
/** | |||||
* manage all the op, support create op. | |||||
*/ | |||||
class GE_FUNC_VISIBILITY OpFactory { | |||||
public: | |||||
static OpFactory &Instance(); | |||||
/** | |||||
* @brief create Op. | |||||
* @param [in] node share ptr of node | |||||
* @param [in] run_context run context | |||||
* @return not nullptr success | |||||
* @return nullptr fail | |||||
*/ | |||||
std::shared_ptr<Op> CreateOp(const Node &node, RunContext &run_context); | |||||
/** | |||||
* @brief Register Op create function. | |||||
* @param [in] type Op type | |||||
* @param [in] func Op create func | |||||
*/ | |||||
void RegisterCreator(const std::string &type, const std::string &lib_name, const OP_CREATOR_FUNC &func); | |||||
const std::vector<std::string> &GetAllOps() const { | |||||
return all_ops_; | |||||
} | |||||
const std::vector<std::string> &GetAllOps(std::string lib_name) const { | |||||
auto iter = all_store_ops_.find(lib_name); | |||||
if (iter == all_store_ops_.end()) { | |||||
return all_ops_; | |||||
} | |||||
return iter->second; | |||||
} | |||||
bool CheckSupported(const std::string &type) { | |||||
return op_creator_map_.find(type) != op_creator_map_.end(); | |||||
} | |||||
OpFactory(const OpFactory &) = delete; | |||||
OpFactory &operator=(const OpFactory &) = delete; | |||||
OpFactory(OpFactory &&) = delete; | |||||
OpFactory &operator=(OpFactory &&) = delete; | |||||
private: | |||||
OpFactory() = default; | |||||
~OpFactory() = default; | |||||
// the op creator function map | |||||
std::map<std::string, OP_CREATOR_FUNC> op_creator_map_; | |||||
std::map<std::string, std::map<std::string, OP_CREATOR_FUNC>> lib_op_creator_map_; | |||||
std::vector<std::string> all_ops_; | |||||
std::map<std::string, vector<std::string>> all_store_ops_; | |||||
}; | |||||
class GE_FUNC_VISIBILITY OpRegistrar { | |||||
public: | |||||
OpRegistrar(const std::string &type, const std::string &kernel_lib, const OP_CREATOR_FUNC &func) { | |||||
OpFactory::Instance().RegisterCreator(type, kernel_lib, func); | |||||
} | |||||
~OpRegistrar() = default; | |||||
OpRegistrar(const OpRegistrar &) = delete; | |||||
OpRegistrar &operator=(const OpRegistrar &) = delete; | |||||
OpRegistrar(OpRegistrar &&) = delete; | |||||
OpRegistrar &operator=(OpRegistrar &&) = delete; | |||||
}; | |||||
// Expands to a free function Creator_<type>Op plus a namespace-scope
// OpRegistrar whose construction registers that function with OpFactory
// during static initialization. Use at namespace scope inside ge::st.
#define REGISTER_OP_CREATOR(type, lib_name, clazz) \
  std::shared_ptr<Op> Creator_##type##Op(const Node &node, RunContext &run_context) { \
    return MakeShared<clazz>(node, run_context); \
  } \
  OpRegistrar g_##type##Op_creator(#type, #lib_name, Creator_##type##Op)
} // namespace st | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ |
@@ -1,77 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "stub_ops_kernel_store.h" | |||||
#include <memory> | |||||
#include "ge/ge_api_types.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "graph/utils/type_utils.h" | |||||
#include "op/stub_op_factory.h" | |||||
namespace ge { | |||||
namespace st { | |||||
using domi::TaskDef; | |||||
using std::map; | |||||
using std::string; | |||||
using std::vector; | |||||
// Initializes the store: resolves which engine owns this kernel lib, then
// fills op_info_map_ with a default OpInfo for every op registered under
// store_name_. Returns FAILED when no engine maps to this store.
Status StubOpsKernelInfoStore::Initialize(const map<string, string> &options) {
  GELOGI("StubOpsKernelInfoStore init start.");

  // Reverse lookup: find the engine whose kernel lib is this store.
  // The whole table is scanned (original behaviour: the last match wins).
  string engine_name;
  for (const auto &engine_2_lib : kStubEngine2KernelLib) {
    if (engine_2_lib.second == store_name_) {
      engine_name = engine_2_lib.first;
    }
  }
  if (engine_name.empty()) {
    return FAILED;
  }

  // Fix: designated initializers are a GNU extension in C++17; use portable
  // value-initialization (zeroes every member) followed by assignments.
  OpInfo default_op_info = {};
  default_op_info.engine = engine_name;
  default_op_info.opKernelLib = store_name_;
  default_op_info.computeCost = 0;
  default_op_info.flagPartial = false;
  default_op_info.flagAsync = false;
  default_op_info.isAtomic = false;

  // Iterate the factory's list directly (fix: previous code copied the vector).
  for (const auto &op : OpFactory::Instance().GetAllOps(store_name_)) {
    op_info_map_[op] = default_op_info;
  }
  GELOGI("StubOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size());
  return SUCCESS;
}
// Releases all registered op info. Safe to call repeatedly; the store can be
// re-initialized afterwards. Always returns SUCCESS.
Status StubOpsKernelInfoStore::Finalize() {
  op_info_map_.clear();
  return SUCCESS;
}
// Copies the complete op -> OpInfo table into the out-parameter `infos`,
// replacing its previous contents.
void StubOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const {
  infos = op_info_map_;
}
bool StubOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { | |||||
if (op_desc == nullptr) { | |||||
return false; | |||||
} | |||||
return op_info_map_.count(op_desc->GetType()) > 0; | |||||
} | |||||
} // namespace st | |||||
} // namespace ge |
@@ -1,73 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||||
#define GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ | |||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define GE_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define GE_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define GE_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "common/opskernel/ops_kernel_info_store.h"
namespace ge { | |||||
namespace st { | |||||
/*const vector<std::string> kStubOpKernelLibNameVec = { | |||||
"AiCoreLib", | |||||
"AicpuLib", | |||||
"HcclLib", | |||||
"RTSLib" | |||||
};*/ | |||||
// Stub implementation of OpsKernelInfoStore used by the ST test engines.
// One instance represents a single stub kernel lib identified by store_name_.
class GE_FUNC_VISIBILITY StubOpsKernelInfoStore : public OpsKernelInfoStore {
 public:
  // `store_name` names the kernel lib this store represents.
  // Fix: move the by-value parameter instead of copying it.
  StubOpsKernelInfoStore(std::string store_name) : store_name_(std::move(store_name)) {}
  ~StubOpsKernelInfoStore() override = default;

  // Fills op_info_map_ from the ops registered for store_name_.
  Status Initialize(const std::map<std::string, std::string> &options) override;
  // Clears op_info_map_; always succeeds.
  Status Finalize() override;
  // True iff op_desc is non-null and its type is in op_info_map_.
  bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override;
  // Copies op_info_map_ into `infos`.
  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override;

  std::string GetOpsKernelStoreName() const {
    return store_name_;
  }

  // Non-copyable, non-movable.
  // Fix: the deleted move constructor took `const &&`; normalized to `&&`
  // for consistency with the deleted move assignment below.
  StubOpsKernelInfoStore(const StubOpsKernelInfoStore &ops_kernel_store) = delete;
  StubOpsKernelInfoStore(StubOpsKernelInfoStore &&ops_kernel_store) = delete;
  StubOpsKernelInfoStore &operator=(const StubOpsKernelInfoStore &ops_kernel_store) = delete;
  StubOpsKernelInfoStore &operator=(StubOpsKernelInfoStore &&ops_kernel_store) = delete;

 private:
  // store op name and OpInfo key-value pair
  std::map<std::string, ge::OpInfo> op_info_map_;
  std::string store_name_;
};
} // namespace st | |||||
} // namespace ge | |||||
#endif // GE_HOST_CPU_ENGINE_OPS_KERNEL_STORE_HOST_CPU_OPS_KERNEL_INFO_H_ |
@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||||
set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | ||||
target_link_libraries(graph_engine_test PRIVATE gtest gtest_main framework) | |||||
target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||||
include(CTest) | include(CTest) | ||||
enable_testing() | enable_testing() |
@@ -17,9 +17,13 @@ | |||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include <map> | #include <map> | ||||
#include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
#include "ge_running_env/fake_engine.h" | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "builder/graph_builder_utils.h" | #include "builder/graph_builder_utils.h" | ||||
#include "ge_running_env/ge_running_env_faker.h" | |||||
#include "graph/operator_reg.h" | #include "graph/operator_reg.h" | ||||
#include "graph/operator.h" | #include "graph/operator.h" | ||||
#define protected public | #define protected public | ||||
@@ -109,8 +113,8 @@ Graph BuildV1ControlFlowGraph() { | |||||
for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | ||||
vector<int32_t> data_value_vec(dims_size, 1); | vector<int32_t> data_value_vec(dims_size, 1); | ||||
GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | ||||
GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *) data_value_vec.data(), | |||||
data_value_vec.size() * sizeof(int32_t)); | |||||
GeTensorPtr data_tensor = | |||||
make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||||
OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | ||||
OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | ||||
OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | ||||
@@ -120,13 +124,9 @@ Graph BuildV1ControlFlowGraph() { | |||||
} // namespace | } // namespace | ||||
class FrameworkTest : public testing::Test { | class FrameworkTest : public testing::Test { | ||||
protected: | protected: | ||||
void SetUp() { | |||||
// ge initialize | |||||
map<AscendString, AscendString> options; | |||||
auto ret = ge::GEInitialize(options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
void SetUp() { ge_env.InstallDefault(); } | |||||
void TearDown() {} | void TearDown() {} | ||||
GeRunningEnvFaker ge_env; | |||||
}; | }; | ||||
/// data data | /// data data | ||||
@@ -0,0 +1,123 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "easy_graph/graph/box.h" | |||||
#include "easy_graph/graph/node.h" | |||||
#include "easy_graph/builder/graph_dsl.h" | |||||
#include "easy_graph/builder/box_builder.h" | |||||
#include "easy_graph/layout/graph_layout.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
#include "graph/graph.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "framework/common/types.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "ge_graph_dsl/graph_dsl.h" | |||||
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "ge_opt_info/ge_opt_info.h" | |||||
#undef private | |||||
#undef protected | |||||
namespace ge { | |||||
class STEST_opt_info : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
// Building a graph for Ascend310 must publish every opt module as "all"
// in the thread-local graph options.
TEST_F(STEST_opt_info, get_opt_info_all) {
  std::map<std::string, std::string> options = {{ge::SOC_VERSION, "Ascend310"}};
  GetThreadLocalContext().SetGlobalOption(options);

  /// data1 data2
  ///   \   /
  ///    add
  DEF_GRAPH(g1) {
    CHAIN(NODE("data1", DATA)->NODE("add", ADD));
    CHAIN(NODE("data2", DATA)->NODE("add"));
  });
  auto graph = ToGeGraph(g1);

  // New session, add the graph, then build it through the session.
  Session session(options);
  auto ret = session.AddGraph(1, graph, options);
  EXPECT_EQ(ret, SUCCESS);
  std::vector<InputTensorInfo> inputs;
  ret = session.BuildGraph(1, inputs);
  EXPECT_EQ(ret, SUCCESS);

  // Every opt module must be present and set to "all".
  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
  for (const char *module : {"opt_module.fe", "opt_module.pass", "opt_module.op_tune",
                             "opt_module.rl_tune", "opt_module.aoe"}) {
    auto itr = graph_options.find(module);
    EXPECT_NE(itr, graph_options.end());
    EXPECT_EQ(itr->second, "all");
  }
}
// Building a graph for Ascend910 must publish fe/pass/op_tune as "all".
TEST_F(STEST_opt_info, get_opt_info_success) {
  std::map<std::string, std::string> options = {{ge::SOC_VERSION, "Ascend910"}};
  GetThreadLocalContext().SetGlobalOption(options);

  /// data1 data2
  ///   \   /
  ///    add
  DEF_GRAPH(g1) {
    CHAIN(NODE("data1", DATA)->NODE("add", ADD));
    CHAIN(NODE("data2", DATA)->NODE("add"));
  });
  auto graph = ToGeGraph(g1);

  // New session, add the graph, then build it through the session.
  Session session(options);
  auto ret = session.AddGraph(1, graph, options);
  EXPECT_EQ(ret, SUCCESS);
  std::vector<InputTensorInfo> inputs;
  ret = session.BuildGraph(1, inputs);
  EXPECT_EQ(ret, SUCCESS);

  // The three modules supported on this SoC must be present and "all".
  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
  for (const char *module : {"opt_module.fe", "opt_module.pass", "opt_module.op_tune"}) {
    auto itr = graph_options.find(module);
    EXPECT_NE(itr, graph_options.end());
    EXPECT_EQ(itr->second, "all");
  }
}
} // namespace ge |
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "common/debug/log.h" | |||||
#include "external/ge/ge_api.h" | |||||
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | |||||
using namespace std; | |||||
using namespace ge; | |||||
int main(int argc, char **argv) { | |||||
// init the logging | |||||
map<AscendString, AscendString> options; | |||||
auto init_status = ge::GEInitialize(options); | |||||
if (init_status != SUCCESS) { | |||||
std::cout << "ge init failed , ret code:" << init_status << endl; | |||||
} | |||||
GeRunningEnvFaker::BackupEnv(); | |||||
testing::InitGoogleTest(&argc, argv); | |||||
int ret = RUN_ALL_TESTS(); | |||||
return ret; | |||||
} |
@@ -62,6 +62,7 @@ include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | |||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | ||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/toolchain) | ||||
include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/opt_info) | |||||
include_directories(${GE_CODE_DIR}/tests/ut/ge) | include_directories(${GE_CODE_DIR}/tests/ut/ge) | ||||
include_directories(${GE_CODE_DIR}/tests/ut/common) | include_directories(${GE_CODE_DIR}/tests/ut/common) | ||||
include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
@@ -349,6 +350,7 @@ set(COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | ||||
"${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | ||||
"${GE_CODE_DIR}/ge/session/omg.cc" | "${GE_CODE_DIR}/ge/session/omg.cc" | ||||
"${GE_CODE_DIR}/ge/ge_opt_info/ge_opt_info.cc" | |||||
) | ) | ||||
set(COMMON_FORMAT_SRC_FILES | set(COMMON_FORMAT_SRC_FILES | ||||
@@ -456,6 +458,7 @@ set(GRAPH_EXECUTE_COMMON_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/graph/manager/graph_manager.cc" | "${GE_CODE_DIR}/ge/graph/manager/graph_manager.cc" | ||||
"${GE_CODE_DIR}/ge/graph/manager/graph_context.cc" | "${GE_CODE_DIR}/ge/graph/manager/graph_context.cc" | ||||
"${GE_CODE_DIR}/ge/graph/manager/util/rt_context_util.cc" | "${GE_CODE_DIR}/ge/graph/manager/util/rt_context_util.cc" | ||||
"${GE_CODE_DIR}/ge/ge_opt_info/ge_opt_info.cc" | |||||
"${GE_CODE_DIR}/ge/graph/manager/graph_context.h" | "${GE_CODE_DIR}/ge/graph/manager/graph_context.h" | ||||
) | ) | ||||
@@ -633,6 +636,10 @@ set(SINGLE_OP_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model.cc" | "${GE_CODE_DIR}/ge/hybrid/hybrid_davinci_model.cc" | ||||
) | ) | ||||
set(GE_OPT_INFO_SRC_FILES | |||||
"${GE_CODE_DIR}/ge/ge_opt_info/ge_opt_info.cc" | |||||
) | |||||
# test files | # test files | ||||
set(COMMON_TEST_FILES | set(COMMON_TEST_FILES | ||||
"graph/passes/graph_builder_utils.cc" | "graph/passes/graph_builder_utils.cc" | ||||
@@ -773,6 +780,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"common/util_unittest.cc" | "common/util_unittest.cc" | ||||
"common/dump_manager_unittest.cc" | "common/dump_manager_unittest.cc" | ||||
"common/dump_op_unittest.cc" | "common/dump_op_unittest.cc" | ||||
"common/dump_properties_unittest.cc" | |||||
"common/dump_exception_unittest.cc" | "common/dump_exception_unittest.cc" | ||||
"common/opdebug_register_unittest.cc" | "common/opdebug_register_unittest.cc" | ||||
"common/format_transfer_unittest.cc" | "common/format_transfer_unittest.cc" | ||||
@@ -820,6 +828,10 @@ set(MULTI_PARTS_TEST_FILES | |||||
"common/tbe_plugin_manager_unittest.cc" | "common/tbe_plugin_manager_unittest.cc" | ||||
) | ) | ||||
set(GE_OPT_INFO_TEST_FILES | |||||
"ge_opt_info/ge_opt_info_unittest.cc" | |||||
) | |||||
set(GENERATOR_TEST_FILES | set(GENERATOR_TEST_FILES | ||||
"generator/ge_generator_unittest.cc" | "generator/ge_generator_unittest.cc" | ||||
) | ) | ||||
@@ -855,7 +867,6 @@ set(HYBRID_TEST_FILES | |||||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | "hybrid/executor/hybrid_model_async_executor_unittest.cc" | ||||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | ||||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | ||||
) | ) | ||||
set(OTHERS_TEST_FILES | set(OTHERS_TEST_FILES | ||||
@@ -871,6 +882,7 @@ list(APPEND COMMON_SHARED_LIBRARIES | |||||
mmpa_stub | mmpa_stub | ||||
hccl_stub | hccl_stub | ||||
error_manager_stub | error_manager_stub | ||||
opt_feature_stub | |||||
ascend_protobuf | ascend_protobuf | ||||
json | json | ||||
) | ) | ||||
@@ -882,6 +894,7 @@ add_library(ge_ut_graph STATIC | |||||
target_compile_definitions(ge_ut_graph PRIVATE | target_compile_definitions(ge_ut_graph PRIVATE | ||||
google=ascend_private | google=ascend_private | ||||
FMK_SUPPORT_DUMP | |||||
) | ) | ||||
target_compile_options(ge_ut_graph PRIVATE | target_compile_options(ge_ut_graph PRIVATE | ||||
@@ -1116,10 +1129,12 @@ target_link_libraries(ut_libge_multiparts_utest | |||||
# libge_others_utest | # libge_others_utest | ||||
add_executable(ut_libge_others_utest | add_executable(ut_libge_others_utest | ||||
${GE_OPT_INFO_SRC_FILES} | |||||
${COMMON_TEST_FILES} | ${COMMON_TEST_FILES} | ||||
${PASS_TEST_FILES} | ${PASS_TEST_FILES} | ||||
${EXECUTE_TEST_FILES} | ${EXECUTE_TEST_FILES} | ||||
${OTHERS_TEST_FILES} | ${OTHERS_TEST_FILES} | ||||
${GE_OPT_INFO_TEST_FILES} | |||||
) | ) | ||||
target_compile_options(ut_libge_others_utest PRIVATE | target_compile_options(ut_libge_others_utest PRIVATE | ||||
@@ -0,0 +1,126 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "common/dump/dump_properties.h" | |||||
#include "ge_local_context.h" | |||||
#include "ge/ge_api_types.h" | |||||
#include "common/debug/log.h" | |||||
#include "common/ge_inner_error_codes.h" | |||||
namespace ge { | |||||
class UTEST_dump_properties : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
// CheckDumpStep accepts '|'-separated steps/ranges and rejects descending
// ranges, non-numeric tokens, and over-long step lists.
TEST_F(UTEST_dump_properties, check_dump_step) {
  DumpProperties dp;
  std::string valid_steps{"0|3-5|10"};
  std::string descending_range{"0|5-3|10"};
  std::string non_numeric{"one"};
  // 200 segments: exceeds the supported number of step entries.
  std::string too_many_segments;
  for (int i = 0; i < 200; ++i) {
    too_many_segments += std::to_string(i) + "|";
  }
  too_many_segments.pop_back();

  Status st = dp.CheckDumpStep(valid_steps);
  EXPECT_EQ(st, SUCCESS);
  st = dp.CheckDumpStep(descending_range);
  EXPECT_NE(st, SUCCESS);
  st = dp.CheckDumpStep(non_numeric);
  EXPECT_NE(st, SUCCESS);
  st = dp.CheckDumpStep(too_many_segments);
  EXPECT_NE(st, SUCCESS);
}
// CheckDumpMode accepts exactly "input", "output", "all"; everything else fails.
TEST_F(UTEST_dump_properties, check_dump_mode) {
  DumpProperties dp;
  for (const std::string valid_mode : {"input", "output", "all"}) {
    EXPECT_EQ(dp.CheckDumpMode(valid_mode), SUCCESS);
  }
  std::string invalid_mode{"mode1"};
  EXPECT_NE(dp.CheckDumpMode(invalid_mode), SUCCESS);
}
// CheckDumpPath accepts a plain directory path and rejects one with
// whitespace/backslash characters.
TEST_F(UTEST_dump_properties, check_dump_path) {
  DumpProperties dp;
  std::string valid_path{"/tmp/"};
  std::string invalid_path{" \\unsupported"};
  EXPECT_EQ(dp.CheckDumpPath(valid_path), SUCCESS);
  EXPECT_NE(dp.CheckDumpPath(invalid_path), SUCCESS);
}
// CheckEnableDump accepts only the literal flags "1"/"0"; the words
// "true"/"false" are rejected.
TEST_F(UTEST_dump_properties, check_enable_dump) {
  DumpProperties dp;
  std::string flag_on{"1"};
  std::string flag_off{"0"};
  std::string word_true{"true"};
  std::string word_false{"false"};
  EXPECT_EQ(dp.CheckEnableDump(flag_on), SUCCESS);
  EXPECT_EQ(dp.CheckEnableDump(flag_off), SUCCESS);
  EXPECT_NE(dp.CheckEnableDump(word_true), SUCCESS);
  EXPECT_NE(dp.CheckEnableDump(word_false), SUCCESS);
}
// A complete data-dump configuration (enable flag, path, step filter, mode)
// lets InitByOptions succeed.
TEST_F(UTEST_dump_properties, init_by_options_success_1) {
  DumpProperties dp;
  std::map<std::string, std::string> options{
      {OPTION_EXEC_ENABLE_DUMP, "1"},
      {OPTION_EXEC_DUMP_PATH, "/tmp/"},
      {OPTION_EXEC_DUMP_STEP, "0|1-3|10"},
      {OPTION_EXEC_DUMP_MODE, "all"}};
  GetThreadLocalContext().SetGlobalOption(options);
  EXPECT_EQ(dp.InitByOptions(), SUCCESS);
}
// A debug-dump configuration (debug enable flag, path, debug mode) lets
// InitByOptions succeed.
TEST_F(UTEST_dump_properties, init_by_options_success_2) {
  DumpProperties dp;
  std::map<std::string, std::string> options{
      {OPTION_EXEC_ENABLE_DUMP_DEBUG, "1"},
      {OPTION_EXEC_DUMP_PATH, "/tmp/"},
      {OPTION_EXEC_DUMP_DEBUG_MODE, "aicore_overflow"}};
  GetThreadLocalContext().SetGlobalOption(options);
  EXPECT_EQ(dp.InitByOptions(), SUCCESS);
}
// Debug dump enabled without OPTION_EXEC_DUMP_DEBUG_MODE: InitByOptions
// must report failure.
TEST_F(UTEST_dump_properties, init_by_options_failed) {
  DumpProperties dp;
  std::map<std::string, std::string> options{
      {OPTION_EXEC_ENABLE_DUMP_DEBUG, "1"},
      {OPTION_EXEC_DUMP_PATH, "/tmp/"}};
  GetThreadLocalContext().SetGlobalOption(options);
  EXPECT_NE(dp.InitByOptions(), SUCCESS);
}
} // namespace ge |
@@ -0,0 +1,82 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <gmock/gmock.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "ge_opt_info/ge_opt_info.h" | |||||
#include "graph/ge_local_context.h" | |||||
#include "external/ge/ge_api_types.h" | |||||
#undef private | |||||
#undef protected | |||||
namespace ge { | |||||
class UTEST_opt_info : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
// With SOC_VERSION=Ascend910, SetOptInfo succeeds and publishes the
// fe/pass/op_tune modules as "all" in the graph options.
TEST_F(UTEST_opt_info, get_opt_info_success) {
  std::map<std::string, std::string> options = {{ge::SOC_VERSION, "Ascend910"}};
  GetThreadLocalContext().SetGlobalOption(options);
  EXPECT_EQ(GeOptInfo::SetOptInfo(), ge::SUCCESS);

  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
  for (const char *module : {"opt_module.fe", "opt_module.pass", "opt_module.op_tune"}) {
    auto itr = graph_options.find(module);
    EXPECT_NE(itr, graph_options.end());
    EXPECT_EQ(itr->second, "all");
  }
}
// With SOC_VERSION=Ascend310, SetOptInfo succeeds and publishes every opt
// module (fe/pass/op_tune/rl_tune/aoe) as "all".
TEST_F(UTEST_opt_info, get_opt_info_all) {
  std::map<std::string, std::string> global_options = {{ge::SOC_VERSION, "Ascend310"}};
  GetThreadLocalContext().SetGlobalOption(global_options);
  EXPECT_EQ(GeOptInfo::SetOptInfo(), ge::SUCCESS);

  std::map<std::string, std::string> graph_options = GetThreadLocalContext().GetAllGraphOptions();
  for (const char *module : {"opt_module.fe", "opt_module.pass", "opt_module.op_tune",
                             "opt_module.rl_tune", "opt_module.aoe"}) {
    auto itr = graph_options.find(module);
    EXPECT_NE(itr, graph_options.end());
    EXPECT_EQ(itr->second, "all");
  }
}
// Without SOC_VERSION in the global options, SetOptInfo must fail.
TEST_F(UTEST_opt_info, get_opt_info_failed) {
  std::map<std::string, std::string> empty_options;
  GetThreadLocalContext().SetGlobalOption(empty_options);
  EXPECT_EQ(GeOptInfo::SetOptInfo(), ge::FAILED);
}
} // namespace ge |
@@ -349,7 +349,7 @@ class UtestLogicalStreamAllocator : public testing::Test { | |||||
/// B --> C(AllReduce) --- D | /// B --> C(AllReduce) --- D | ||||
/// / | /// / | ||||
/// stream id: 0 A | /// stream id: 0 A | ||||
/// \ | |||||
/// \. | |||||
/// E --> F(AllReduce) --- G | /// E --> F(AllReduce) --- G | ||||
/// stream id: 2 2 2 | /// stream id: 2 2 2 | ||||
/// | /// | ||||
@@ -599,7 +599,7 @@ TEST_F(UtestLogicalStreamAllocator, test_label_not_reusable2) { | |||||
/// case of multi-output, then unuse stream | /// case of multi-output, then unuse stream | ||||
/// sub1 | /// sub1 | ||||
/// / | \ | |||||
/// / | \. | |||||
/// sub2 sub3 sub4 | /// sub2 sub3 sub4 | ||||
TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -624,7 +624,7 @@ TEST_F(UtestLogicalStreamAllocator, test_multiOut_new_stream) { | |||||
/// if paralle id 1, then use stream | /// if paralle id 1, then use stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -653,7 +653,7 @@ TEST_F(UtestLogicalStreamAllocator, test_parallel_one) { | |||||
/// if the param of engine independent is true, then set independent stream | /// if the param of engine independent is true, then set independent stream | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent) { | TEST_F(UtestLogicalStreamAllocator, test_independent) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -692,7 +692,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { | |||||
/// set stream based on stream label, and then based on independent | /// set stream based on stream label, and then based on independent | ||||
/// sub1 | /// sub1 | ||||
/// / | | \ | |||||
/// / | | \. | |||||
/// sub2 sub3 sub4 sub5 | /// sub2 sub3 sub4 sub5 | ||||
TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { | ||||
SubGraphInfoPtr data = CreateDataSubgraph(); | SubGraphInfoPtr data = CreateDataSubgraph(); | ||||
@@ -36,7 +36,7 @@ class UtestStreamAllocator : public testing::Test { | |||||
/// | /// | ||||
/// A | /// A | ||||
/// / \ | |||||
/// / \. | |||||
/// B C | /// B C | ||||
/// | | | /// | | | ||||
/// D 400 | /// D 400 | ||||
@@ -55,7 +55,7 @@ class UtestGraphPassesAssertPass : public Test { | |||||
}; | }; | ||||
/// D E | /// D E | ||||
/// | \ | \ | |||||
/// | \ | \. | |||||
/// F C G | /// F C G | ||||
/// : | : | /// : | : | ||||
/// H A I | /// H A I | ||||
@@ -134,8 +134,8 @@ TEST_F(UtestGraphPassesAssertPass, assert_pass_test2) { | |||||
EXPECT_EQ(graph->FindNode("D"), nullptr); | EXPECT_EQ(graph->FindNode("D"), nullptr); | ||||
} | } | ||||
/// E F | |||||
/// | \ | \ | |||||
/// E F | |||||
/// | \ | \. | |||||
/// H C -> D G | /// H C -> D G | ||||
/// \ | : | /// \ | : | ||||
/// A I | /// A I | ||||
@@ -130,7 +130,7 @@ class UTESTGraphPassesBasePass : public testing::Test { | |||||
/// reshape1 | /// reshape1 | ||||
/// | | /// | | ||||
/// add1 | /// add1 | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// data1 const1 | /// data1 const1 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -148,9 +148,9 @@ ComputeGraphPtr BuildGraph1() { | |||||
} | } | ||||
/// sum1 | /// sum1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// / \. | |||||
/// reshape1 addn1 | /// reshape1 addn1 | ||||
/// | c | | /// | c | | ||||
/// add1 <--- shape1 | /// add1 <--- shape1 | ||||
@@ -217,7 +217,7 @@ void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::str | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -245,7 +245,7 @@ TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) { | |||||
/// Op1 | /// Op1 | ||||
/// | | /// | | ||||
/// Merge | /// Merge | ||||
/// / \ | |||||
/// / \. | |||||
/// Op2 Op3 | /// Op2 Op3 | ||||
TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | TEST_F(UTESTGraphPassesBasePass, del_isolate_success) { | ||||
auto builder = ut::GraphBuilder("g1"); | auto builder = ut::GraphBuilder("g1"); | ||||
@@ -459,7 +459,7 @@ TEST_F(UTESTGraphPassesBasePass, while_loop) { | |||||
/// data1 const | /// data1 const | ||||
/// \ / | /// \ / | ||||
/// while | /// while | ||||
/// / \ | |||||
/// / \. | |||||
/// | | | /// | | | ||||
/// cast1 cast2 | /// cast1 cast2 | ||||
ComputeGraphPtr BuildWhileGraph1() { | ComputeGraphPtr BuildWhileGraph1() { | ||||
@@ -34,11 +34,11 @@ namespace { | |||||
/// net_output | /// net_output | ||||
/// | | /// | | ||||
/// merge | /// merge | ||||
/// / \ | |||||
/// / \. | |||||
/// square add | /// square add | ||||
/// F| T/ T\ | |||||
/// F| T/ T\. | |||||
/// switch1 switch2 | /// switch1 switch2 | ||||
/// / \ / \ | |||||
/// / \ / \. | |||||
/// var1 var2 var3 | /// var1 var2 var3 | ||||
/// | /// | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
@@ -173,8 +173,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph1() { | ComputeGraphPtr BuildGraph1() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -223,8 +223,8 @@ ComputeGraphPtr BuildGraph2() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph3() { | ComputeGraphPtr BuildGraph3() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -249,8 +249,8 @@ ComputeGraphPtr BuildGraph3() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <--------- | /// addnYes1 <--------- | ||||
/// / \ \ | |||||
/// / \ c \ | |||||
/// / \ \. | |||||
/// / \ c \. | |||||
/// const1 const2 <----- dataNo1 | /// const1 const2 <----- dataNo1 | ||||
ComputeGraphPtr BuildGraph4() { | ComputeGraphPtr BuildGraph4() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -276,7 +276,7 @@ ComputeGraphPtr BuildGraph4() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | c | /// | c | ||||
/// addnYes1 <----- dataNo1 | /// addnYes1 <----- dataNo1 | ||||
/// / \ | |||||
/// / \. | |||||
/// / \ c | /// / \ c | ||||
/// const1 const2 <----- dataNo2 | /// const1 const2 <----- dataNo2 | ||||
ComputeGraphPtr BuildGraph5() { | ComputeGraphPtr BuildGraph5() { | ||||
@@ -306,8 +306,8 @@ ComputeGraphPtr BuildGraph5() { | |||||
/// addYes1 <---- const3 | /// addYes1 <---- const3 | ||||
/// | | /// | | ||||
/// addnYes1 <- | /// addnYes1 <- | ||||
/// / \ \ | |||||
/// / \ \ | |||||
/// / \ \. | |||||
/// / \ \. | |||||
/// const1 const2 const4 | /// const1 const2 const4 | ||||
ComputeGraphPtr BuildGraph6() { | ComputeGraphPtr BuildGraph6() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -332,12 +332,12 @@ ComputeGraphPtr BuildGraph6() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// shapeNo1 ShpaeNo2 | /// shapeNo1 ShpaeNo2 | ||||
/// \ / | /// \ / | ||||
/// huberLoss1 | /// huberLoss1 | ||||
/// / | \ | |||||
/// / | \ | |||||
/// / | \. | |||||
/// / | \. | |||||
/// const1 const2 const3 | /// const1 const2 const3 | ||||
ComputeGraphPtr BuildGraph7() { | ComputeGraphPtr BuildGraph7() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -365,8 +365,8 @@ ComputeGraphPtr BuildGraph7() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -389,8 +389,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 data1 | /// const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -409,12 +409,12 @@ ComputeGraphPtr BuildGraph9() { | |||||
} | } | ||||
/// netoutput1 | /// netoutput1 | ||||
/// / \ | |||||
/// / \. | |||||
/// addDim sqrt1 | /// addDim sqrt1 | ||||
/// \ / | /// \ / | ||||
/// switch1 | /// switch1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph10() { | ComputeGraphPtr BuildGraph10() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -63,8 +63,8 @@ namespace { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnNo1 | /// addnNo1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
/// const1 const2 | /// const1 const2 | ||||
ComputeGraphPtr BuildGraph8() { | ComputeGraphPtr BuildGraph8() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||
@@ -87,8 +87,8 @@ ComputeGraphPtr BuildGraph8() { | |||||
/// shapeNo1 | /// shapeNo1 | ||||
/// | | /// | | ||||
/// addnYes1 | /// addnYes1 | ||||
/// / \ | |||||
/// / \ | |||||
/// / \. | |||||
/// / \. | |||||
///const1 data1 | ///const1 data1 | ||||
ComputeGraphPtr BuildGraph9() { | ComputeGraphPtr BuildGraph9() { | ||||
auto builder = ut::GraphBuilder("test"); | auto builder = ut::GraphBuilder("test"); | ||||