Browse Source

!1407 update commit id to r1.2 0330

From: @shenwei41
Reviewed-by: @xsmq,@lilongfei15
Signed-off-by: @lilongfei15
tags/v1.2.0
Committed by: mindspore-ci-bot (via Gitee), 4 years ago
Parent commit: de47249a72
36 changed files with 647 additions and 142 deletions
  1. +1
    -1
      build.sh
  2. +7
    -1
      ge/CMakeLists.txt
  3. +3
    -2
      ge/common/CMakeLists.txt
  4. +0
    -5
      ge/common/dump/opdebug_register.cc
  5. +3
    -2
      ge/executor/CMakeLists.txt
  6. +8
    -6
      ge/ge_local_engine/CMakeLists.txt
  7. +6
    -2
      ge/graph/build/task_generator.cc
  8. +8
    -0
      ge/graph/manager/graph_caching_allocator.cc
  9. +7
    -0
      ge/graph/manager/graph_caching_allocator.h
  10. +6
    -6
      ge/graph/passes/atomic_addr_clean_pass.cc
  11. +0
    -1
      ge/graph/passes/attach_stream_label_pass.cc
  12. +7
    -1
      ge/graph/passes/pass_utils.cc
  13. +2
    -0
      ge/graph/passes/pass_utils.h
  14. +1
    -1
      ge/graph/passes/subexpression_migration_pass.cc
  15. +8
    -2
      ge/graph/passes/switch_dead_branch_elimination.cc
  16. +2
    -0
      ge/graph/passes/switch_to_stream_switch_pass.cc
  17. +3
    -2
      ge/host_cpu_engine/CMakeLists.txt
  18. +1
    -0
      ge/hybrid/executor/hybrid_execution_context.h
  19. +6
    -0
      ge/hybrid/executor/hybrid_model_executor.cc
  20. +1
    -6
      ge/hybrid/executor/worker/execution_engine.cc
  21. +167
    -65
      ge/hybrid/model/hybrid_model_builder.cc
  22. +7
    -2
      ge/hybrid/model/hybrid_model_builder.h
  23. +4
    -0
      ge/hybrid/model/node_item.cc
  24. +2
    -0
      ge/hybrid/model/node_item.h
  25. +24
    -25
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
  26. +5
    -3
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
  27. +4
    -0
      ge/single_op/single_op_manager.cc
  28. +35
    -6
      ge/single_op/single_op_model.cc
  29. +1
    -1
      metadef
  30. +1
    -1
      parser
  31. +4
    -0
      tests/depends/runtime/src/runtime_stub.cc
  32. +3
    -0
      tests/ut/ge/CMakeLists.txt
  33. +68
    -0
      tests/ut/ge/graph/build/task_generator_unittest.cc
  34. +65
    -0
      tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc
  35. +163
    -0
      tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc
  36. +14
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 1
build.sh View File

@@ -229,7 +229,7 @@ if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then
rm -rf ${BASEPATH}/cov
mkdir ${BASEPATH}/cov
lcov -c -d build/tests/ut/ge -d build/tests/ut/common/graph/ -o cov/tmp.info
lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' -o cov/coverage.info
lcov -r cov/tmp.info '*/output/*' '*/build/opensrc/*' '*/build/proto/*' '*/third_party/*' '*/tests/*' '/usr/local/*' '/usr/include/*' '*/metadef/*' '*/parser/*' -o cov/coverage.info
cd ${BASEPATH}/cov
genhtml coverage.info
fi


+ 7
- 1
ge/CMakeLists.txt View File

@@ -31,6 +31,7 @@ set(PROTO_HEADER_LIST
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST})
protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST})
protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST})

if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES)
############ libge_proto_common.a ############
@@ -56,7 +57,7 @@ target_link_libraries(ge_proto_common PRIVATE

############ libge_proto_client.a ############
add_library(ge_proto_client STATIC
${PROTO_HEADER_HDRS}
${PROTO_CLIENT_HEADER_HDRS}
${PROTO_CLIENT_SRCS}
)

@@ -65,6 +66,11 @@ target_compile_definitions(ge_proto_client PRIVATE
google=ascend_private
)

target_include_directories(ge_proto_client PRIVATE
${CMAKE_BINARY_DIR}/proto/ge_client
${CMAKE_BINARY_DIR}/proto/ge_client/proto
)

target_compile_options(ge_proto_client PRIVATE
-O2
-fno-common


+ 3
- 2
ge/common/CMakeLists.txt View File

@@ -16,6 +16,7 @@ set(PROTO_LIST
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST})

set(SRC_LIST
"context/ctx.cc"
@@ -127,7 +128,7 @@ target_link_libraries(ge_common PRIVATE
)

############ libge_common.a ############
add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_HDRS})
add_library(ge_common_static STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS})
target_compile_definitions(ge_common_static PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
HOST_VISIBILITY
@@ -158,7 +159,7 @@ target_include_directories(ge_common_static PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_static
#### yellow zone ####
${GE_DEPEND_DIR}/inc
${GE_DEPEND_DIR}/inc/cce


+ 0
- 5
ge/common/dump/opdebug_register.cc View File

@@ -80,13 +80,11 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de

uint32_t debug_stream_id = 0;
uint32_t debug_task_id = 0;
#ifdef ONLY_COMPILE_OPEN_SRC
auto rt_ret = rtDebugRegisterForStream(stream, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtDebugRegisterForStream error, ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
#endif
GELOGD("debug_task_id:%u, debug_stream_id:%u in stream overflow.", debug_task_id, debug_stream_id);
data_dumper.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, true);
return SUCCESS;
@@ -94,7 +92,6 @@ Status OpdebugRegister::RegisterDebugForStream(rtStream_t stream, uint32_t op_de

void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) {
rtError_t rt_ret = RT_ERROR_NONE;
#ifdef ONLY_COMPILE_OPEN_SRC
if (stream != nullptr) {
GELOGD("start call rtDebugUnRegisterForStream in unknown shape over flow.");
rt_ret = rtDebugUnRegisterForStream(stream);
@@ -102,8 +99,6 @@ void OpdebugRegister::UnregisterDebugForStream(rtStream_t stream) {
GELOGW("rtDebugUnRegisterForStream failed, ret: 0x%X", rt_ret);
}
}
#endif

if (op_debug_addr_ != nullptr) {
rt_ret = rtFree(op_debug_addr_);
if (rt_ret != RT_ERROR_NONE) {


+ 3
- 2
ge/executor/CMakeLists.txt View File

@@ -8,6 +8,7 @@ set(PROTO_LIST
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST})

set(SRC_LIST
"ge_executor.cc"
@@ -162,7 +163,7 @@ set(SRC_LIST
)

######## libge_executor.a ########
add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_HDRS})
add_library(ge_executor STATIC ${SRC_LIST} ${PROTO_STATIC_HDRS})

target_compile_options(ge_executor PRIVATE
$<$<OR:$<STREQUAL:${TARGET_SYSTEM_NAME},Linux>,$<STREQUAL:${TARGET_SYSTEM_NAME},Android>>:-fvisibility=hidden -O2 -Werror -Wno-deprecated-declarations -fno-common>
@@ -191,7 +192,7 @@ target_include_directories(ge_executor SYSTEM PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_static
#### yellow zone ####
${GE_CODE_DIR}/../inc
${GE_CODE_DIR}/../inc/cce


+ 8
- 6
ge/ge_local_engine/CMakeLists.txt View File

@@ -20,6 +20,8 @@ set(OPS_KERNEL_SRC_LIST
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST})
protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST})

############ libge_local_engine.so ############
add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS})
@@ -119,7 +121,7 @@ set_target_properties(atc_ge_local_engine PROPERTIES
)

############ libge_local_opskernel_builder.so ############
add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS})
add_library(ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS})

target_compile_options(ge_local_opskernel_builder PRIVATE
-Werror
@@ -143,7 +145,7 @@ target_include_directories(ge_local_opskernel_builder PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_ops_shared
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
@@ -166,7 +168,7 @@ target_link_libraries(ge_local_opskernel_builder PRIVATE
)

############ atclib/libge_local_opskernel_builder.so ############
add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS})
add_library(atc_ge_local_opskernel_builder SHARED ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_SHARED_HDRS})

target_compile_options(atc_ge_local_opskernel_builder PRIVATE
-Werror
@@ -190,7 +192,7 @@ target_include_directories(atc_ge_local_opskernel_builder PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_ops_shared
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####
@@ -218,7 +220,7 @@ set_target_properties(atc_ge_local_opskernel_builder PROPERTIES
)

############ libge_local_opskernel_builder.a ############
add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_HDRS})
add_library(ge_local_opskernel_builder_static STATIC ${OPS_KERNEL_SRC_LIST} ${PROTO_OPS_STATIC_HDRS})

target_compile_options(ge_local_opskernel_builder_static PRIVATE
-Werror
@@ -243,7 +245,7 @@ target_include_directories(ge_local_opskernel_builder_static PRIVATE
${METADEF_DIR}/inc/external/graph
${METADEF_DIR}/inc/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_ops_static
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####


+ 6
- 2
ge/graph/build/task_generator.cc View File

@@ -49,6 +49,7 @@ const char *const kIsLastNode = "is_last_node";
const char *const kIsInputVar = "INPUT_IS_VAR";
const char *const kIsOutputVar = "OUTPUT_IS_VAR";
const char *const kProfilingMode = "PROFILING_MODE";
const char *const kIteratorV2 = "IteratorV2";
const uint32_t kProfilingArStep = 2;
const uint64_t kProfilingFpStartLogid = 1;
const uint64_t kProfilingBpEndLogid = 2;
@@ -57,6 +58,7 @@ const uint64_t kProfilingArEndLogid = 4;
const uint64_t kProfilingIterEndLogid = 65535;
const int64_t kHashFactor = 100000;
const int64_t kInvalidGroupId = -1;
const std::set<std::string> kFpNodeTypes = {ge::DATA, ge::GETNEXT, kIteratorV2};
} // namespace
namespace ge {
TaskGenerator::TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size) {
@@ -621,8 +623,10 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP
if (op_kernel_lib_name.empty()) {
continue;
}

if (op_desc->GetType() == GETNEXT || op_desc->GetType() == DATA) {
auto type = op_desc->GetType();
std::string original_type;
(void)AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);
if (kFpNodeTypes.find(type) != kFpNodeTypes.end() || kFpNodeTypes.find(original_type) != kFpNodeTypes.end()) {
auto out_anchor = node->GetOutDataAnchor(0);
for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
GE_CHECK_NOTNULL(peer_in_anchor);


+ 8
- 0
ge/graph/manager/graph_caching_allocator.cc View File

@@ -356,6 +356,14 @@ void CachingAllocator::FreeBlocks() {
(void) FreeCachedBlocks();
}

void CachingAllocator::TryFreeBlocks() {
GELOGI("Try free blocks.");
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (allocated_blocks_.empty()) {
(void) FreeCachedBlocks();
}
}

void CachingAllocator::FreeBlockBins() {
GELOGI("Free block bins.");
std::lock_guard<std::recursive_mutex> lock(mutex_);


+ 7
- 0
ge/graph/manager/graph_caching_allocator.h View File

@@ -94,6 +94,13 @@ class CachingAllocator {
///
Status Free(uint8_t *memory_addr, uint32_t device_id = 0);

///
/// @ingroup ge_graph
/// @brief try to free memory when no memory is referenced
/// @return void
///
void TryFreeBlocks();

private:

///


+ 6
- 6
ge/graph/passes/atomic_addr_clean_pass.cc View File

@@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6

bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) {
OpDescPtr op_desc = node->GetOpDesc();
std::map<string, std::map<int, int>> node_workspace_offset;
std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) {
atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size);
if (!has_atomic_input && has_atomic_output && atomic_workspace_index_size.empty()) {
std::vector<int64_t> atomic_output_index;
(void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index);
bool is_all_output_peer_also_atomic = true;
@@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) {
}

// 2.Check atomic attr in node
std::map<string, std::map<int, int>> node_workspace_offset;
std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) {
atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size);
if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) {
return false;
}



+ 0
- 1
ge/graph/passes/attach_stream_label_pass.cc View File

@@ -137,7 +137,6 @@ Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &strea
return INTERNAL_ERROR;
}
stream_label = node->GetInDataNodes().at(0)->GetName();
GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed.");
bool value = false;
OpDescPtr op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);


+ 7
- 1
ge/graph/passes/pass_utils.cc View File

@@ -35,9 +35,9 @@
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "utils/node_utils.h"

namespace ge {

Status PassUtils::ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data,
std::vector<GeTensorPtr> &v_output, const bool scalar_output) {
Status ret = SUCCESS;
@@ -246,6 +246,12 @@ NodePtr PassUtils::GetInDataNode(const ConstNodePtr &node, int index) {
return src_node;
}

NodePtr PassUtils::GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index) {
auto src_node = GetInDataNode(node, index);

return NodeUtils::GetInNodeCrossSubgraph(src_node);
}

bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) {
if (compute_graph == nullptr) {
return false;


+ 2
- 0
ge/graph/passes/pass_utils.h View File

@@ -30,6 +30,8 @@ class PassUtils {

static NodePtr GetInDataNode(const ConstNodePtr &node, int index);

static NodePtr GetInNodeCrossSubgraphByIndex(const ConstNodePtr &node, int index);

static bool IsConstant(const ConstNodePtr &node);

static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node);


+ 1
- 1
ge/graph/passes/subexpression_migration_pass.cc View File

@@ -279,7 +279,7 @@ Status SubexpressionMigrationPass::GraphNodeMigration(const ComputeGraphPtr &gra
const auto &in_anchor = in_anchors.at(i);
const auto &base_node = in_anchor->GetOwnerNode();
GELOGD("Get Data direct node: %s", base_node->GetName().c_str());
if (!base_node->GetHostNode()) {
if (!base_node->GetHostNode() || base_node->GetType() == SWITCH) {
continue;
}



+ 8
- 2
ge/graph/passes/switch_dead_branch_elimination.cc View File

@@ -94,6 +94,12 @@ Status SwitchDeadBranchElimination::DeleteSwitchNode(NodePtr &node, NodePtr &pre
GELOGE(FAILED, "parameter is null.");
return FAILED;
}

// If two nodes aren't in same graph, get node's direct in_node instead of pred_node.
if (node->GetOwnerComputeGraph() != pred_node->GetOwnerComputeGraph()) {
pred_node = PassUtils::GetInDataNode(node, kPredInputIndex);
}

// link pred's in control nodes to switch
if (GraphUtils::CopyInCtrlEdges(pred_node, node) != GRAPH_SUCCESS) {
return FAILED;
@@ -131,7 +137,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) {
return SUCCESS;
}

auto pred_node = PassUtils::GetInDataNode(node, kPredInputIndex);
auto pred_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kPredInputIndex);
if (pred_node == nullptr) {
GELOGD("[%s] Pred input is null.", node->GetName().c_str());
return SUCCESS;
@@ -143,7 +149,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) {
return SUCCESS;
}

auto input_node = PassUtils::GetInDataNode(node, kDataInputIndex);
auto input_node = PassUtils::GetInNodeCrossSubgraphByIndex(node, kDataInputIndex);
if (input_node == nullptr) {
GELOGD("[%s] Data input is null.", node->GetName().c_str());
return SUCCESS;


+ 2
- 0
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -448,6 +448,8 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)

// select first stream_switch
NodePtr stream_switch = switch_list.front();
// set stream_label
GE_CHK_STATUS_RET(SetStreamLabel(stream_switch, cast_node->GetName()), "Set stream label failed.");
OpDescPtr switch_desc = stream_switch->GetOpDesc();
GE_CHECK_NOTNULL(switch_desc);
switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f")));


+ 3
- 2
ge/host_cpu_engine/CMakeLists.txt View File

@@ -3,6 +3,7 @@ set(PROTO_LIST
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST})

set(SRC_LIST
"engine/host_cpu_engine.cc"
@@ -61,7 +62,7 @@ target_link_libraries(host_cpu_engine PRIVATE
)

############ atcstub/libhost_cpu_engine.so ############
add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_HDRS})
add_library(atc_host_cpu_engine SHARED ${SRC_LIST} ${PROTO_ATCSTUB_HDRS})

target_compile_options(atc_host_cpu_engine PRIVATE
-Werror
@@ -84,7 +85,7 @@ target_include_directories(atc_host_cpu_engine PRIVATE
${METADEF_DIR}/inc/external
${METADEF_DIR}/inc/external/graph
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
${CMAKE_BINARY_DIR}/proto/ge_atcstub
#### yellow zone ####
${GE_CODE_DIR}/../inc
#### blue zone ####


+ 1
- 0
ge/hybrid/executor/hybrid_execution_context.h View File

@@ -71,6 +71,7 @@ struct GraphExecutionContext {
std::atomic_bool is_eos_;
long profiling_level = 0;
long iteration = 0;
void *global_step = nullptr;

private:
Status status = SUCCESS;


+ 6
- 0
ge/hybrid/executor/hybrid_model_executor.cc View File

@@ -33,6 +33,9 @@ HybridModelExecutor::~HybridModelExecutor() {
if (context_.rt_gen_context != nullptr) {
(void) rtCtxDestroy(context_.rt_gen_context);
}
if (context_.global_step != nullptr) {
(void) rtFree(context_.global_step);
}
}

Status HybridModelExecutor::Init() {
@@ -47,6 +50,8 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
auto root_graph_item = model_->GetRootGraphItem();
GE_CHECK_NOTNULL(root_graph_item);

GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
auto ret = ExecuteGraphInternal(executor, args);
Cleanup();
@@ -97,6 +102,7 @@ Status HybridModelExecutor::InitExecutionContext() {
GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context));
GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0));
GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context));
GE_CHK_RT_RET(rtMalloc(&context_.global_step, sizeof(uint64_t), RT_MEMORY_HBM));

context_.stream = stream_;
context_.model = model_;


+ 1
- 6
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -231,12 +231,6 @@ Status NodeDoneCallback::DumpDynamicNode() {
uint32_t model_id = model->GetModelId();
dump_op_.SetDynamicModelInfo(dynamic_model_name, model_id);

void *global_step = nullptr;
TensorValue *varible_global_step = context_->GetVariable(NODE_NAME_GLOBAL_STEP);
if (varible_global_step != nullptr) {
global_step = const_cast<void *>(varible_global_step->GetData());
}

void *loop_per_iter = nullptr;
TensorValue *varible_loop_per_iter = context_->GetVariable(NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
if (varible_loop_per_iter != nullptr) {
@@ -248,6 +242,7 @@ Status NodeDoneCallback::DumpDynamicNode() {
if (varible_loop_cond != nullptr) {
loop_cond = const_cast<void *>(varible_loop_cond->GetData());
}
void *global_step = context_->GetExecutionContext()->global_step;
dump_op_.SetLoopAddr(global_step, loop_per_iter, loop_cond);

GE_CHK_STATUS_RET(dump_op_.LaunchDumpOp(), "Failed to launch dump op in hybird model");


+ 167
- 65
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -255,9 +255,7 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
(void) AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false);
(void) AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false);

new_node->node_id = node_index;
new_node->op_desc->SetId(node_index);
node_index += 1;
new_node->node_id = static_cast<int>(new_node->op_desc->GetId());
NodeExecutorManager::ExecutorType executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
new_node->is_profiling_report = (executor_type == NodeExecutorManager::ExecutorType::AICORE) ||
(executor_type == NodeExecutorManager::ExecutorType::AICPU_TF) ||
@@ -273,16 +271,16 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d",
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);
return SUCCESS;
}

Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_input_nodes;
std::set<NodePtr> dependent_for_shape_inference;
std::set<NodePtr> dependent_for_execution;
auto &ge_node = node_item.node;
bool is_hccl_op =
NodeExecutorManager::GetInstance().ResolveExecutorType(*ge_node) == NodeExecutorManager::ExecutorType::HCCL;
bool is_hccl_op = node_item.IsHcclOp();

// The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE.
// Wait for these parent nodes before execution.
@@ -297,29 +295,15 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);

if (is_hccl_op) {
GELOGD("[%s] Add input data dependent node [%s] due to engine type is HCCL",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
node_item.has_observer = true;
for (auto &dst_node : ge_node->GetOutNodes()) {
if (dst_node == nullptr) {
continue;
}

NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item));
dst_node_item->dependents_for_execution.emplace_back(ge_node);
}
} else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) {
GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());

if (src_node_item->shape_inference_type == DEPEND_COMPUTE || is_hccl_op || src_node_item->IsHcclOp()) {
GELOGD("[%s](%s) Add input data dependent node [%s](%s), shape inference type = %d",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str(),
static_cast<int>(src_node_item->shape_inference_type));
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
dependent_for_execution.emplace(src_node);
}

if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) {
@@ -327,22 +311,17 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
dependent_input_nodes.emplace(src_node);
dependent_for_shape_inference.emplace(src_node);
}
}

// cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
const auto &in_anchor = ge_node->GetInDataAnchor(0);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(peer_anchor);
auto src_node = peer_anchor->GetOwnerNode();
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(src_node);
dependent_for_execution.emplace(src_node);
GELOGD("[%s] Dependent added from %s for control op's cond/branch",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str());
@@ -366,24 +345,32 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node);
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());
src_node_item->has_observer = true;

dependent_input_nodes.emplace(src_node);
dependent_for_shape_inference.emplace(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}

for (const auto &dep_node : dependent_input_nodes) {
GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item, dependent_for_shape_inference));
for (const auto &dep_node : dependent_for_shape_inference) {
auto src_node_item = MutableNodeItem(dep_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_shape_inference.emplace_back(dep_node);
}

GE_CHK_STATUS_RET(ParseDependentForFusedSubgraph(node_item));
for (const auto &dep_node : dependent_for_execution) {
auto src_node_item = MutableNodeItem(dep_node);
GE_CHECK_NOTNULL(src_node_item);
src_node_item->has_observer = true;
node_item.dependents_for_execution.emplace_back(dep_node);
}

return SUCCESS;
}

Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies) {
if (node_item.fused_subgraph == nullptr) {
return SUCCESS;
}
@@ -413,17 +400,12 @@ Status HybridModelBuilder::ParseDependentForFusedSubgraph(NodeItem &node_item) {
node_item.NodeName().c_str(),
op_desc->GetName().c_str(),
src_node_item->NodeName().c_str());
src_node_item->has_observer = true;
src_node_item->to_const_output_id_list.emplace(peer_out_anchor->GetIdx());

auto &depends = node_item.dependents_for_shape_inference;
if (std::find(depends.begin(), depends.end(), src_node) == depends.end()) {
depends.emplace_back(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}
dependencies.emplace(src_node);
GELOGD("[%s] Dependent added from output of [%s:%d]",
node_item.NodeName().c_str(),
src_node_item->NodeName().c_str(),
peer_out_anchor->GetIdx());
}

return SUCCESS;
@@ -770,9 +752,23 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
GE_DUMP(root_graph, "hybrid_merged_graph");
}

root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
auto parent_graph = node->GetOwnerComputeGraph();
// No need to update nodes in known subgraph
if (parent_graph != nullptr && !parent_graph->GetGraphUnknownFlag()) {
continue;
}
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
op_desc->SetId(index++);
}
GE_DUMP(root_graph, "hybrid_merged_graph");
GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph.");
GELOGD("Done loading root graph successfully.");
GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph");
@@ -810,6 +806,7 @@ Status HybridModelBuilder::LoadGraph() {
}
}

GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "Failed to establish dependencies for hccl ops");
GELOGI("Done loading all subgraphs successfully.");
return SUCCESS;
}
@@ -1075,25 +1072,41 @@ Status HybridModelBuilder::InitWeights() {
return SUCCESS;
}

Status HybridModelBuilder::LoadTask(NodeItem &node_item) {
auto &node_ptr = node_item.node;
GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item.node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item.kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
}

GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
return SUCCESS;
}

Status HybridModelBuilder::LoadTasks() {
GE_CHK_STATUS_RET(CheckAicpuOpList(), "Check Aicpu op failed.");
std::map<int, std::map<std::string, NodeItem *>> ordered_partitioned_calls;
for (auto &it : hybrid_model_.node_items_) {
auto &node_item = it.second;
auto &node_ptr = node_item->node;
if (node_item->node_type == NETOUTPUT) {
continue;
}

GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str());
auto load_ret = node_item->node_executor->LoadTask(hybrid_model_,
node_ptr,
node_item->kernel_task);
if (load_ret != UNSUPPORTED && load_ret != SUCCESS) {
GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str());
return load_ret;
if (node_item->node_type == PARTITIONEDCALL) {
ordered_partitioned_calls[node_item->node_id][node_item->node_name] = node_item.get();
continue;
}
GE_CHK_STATUS_RET_NOLOG(LoadTask(*node_item));
}

GELOGD("[%s] Done loading task successfully.", node_ptr->GetName().c_str());
// HCCL operators need to be loaded in the same order across different processes
for (auto &it : ordered_partitioned_calls) {
for (auto &it2 : it.second) {
GE_CHK_STATUS_RET_NOLOG(LoadTask(*it2.second));
}
}

return SUCCESS;
@@ -1626,6 +1639,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem
auto temp_graph = MakeShared<ComputeGraph>("temp");
GE_CHECK_NOTNULL(temp_graph);
auto wrapper_node = temp_graph->AddNode(wrapper_op_desc);
wrapper_op_desc->SetId(parent_node_item->node_id);
GeModelPtr ge_model = subgraph_models_[subgraph_name];
GE_CHECK_NOTNULL(ge_model);
hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model);
@@ -2011,5 +2025,93 @@ Status HybridModelBuilder::CheckAicpuOpList() {
"Launch check aicpu op type failed.");
return SUCCESS;
}

// Records which parallel group(s) |node_item| belongs to.
// HCCL nodes carry ATTR_NAME_PARALLEL_GROUP on their own OpDesc; known-shaped
// (compiled) subgraphs are scanned node-by-node for the attribute. Results are
// accumulated into parallel_group_to_nodes_ and node_to_parallel_groups_ for
// later dependency wiring in ParseDependentByParallelGroup().
// Defect fixed: removed the unused local `std::set<std::string> group{...}`
// that was constructed and immediately discarded.
Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
  const auto &node = node_item->node;
  auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
  if (executor_type == NodeExecutorManager::ExecutorType::HCCL) {
    std::string parallel_group;
    if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
      GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str());
      parallel_group_to_nodes_[parallel_group].emplace(node_item);
      node_to_parallel_groups_[node_item].emplace(parallel_group);
    }
  } else if (executor_type == NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH) {
    // A known-shaped subgraph may contain nodes from several parallel groups;
    // collect them all before touching the member maps.
    std::set<std::string> parallel_groups;
    GELOGD("[%s] To collect parallel group for known-shaped subgraph", node_item->NodeName().c_str());
    for (const auto &subgraph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) {
      GELOGD("[%s] Start to get parallel group from subgraph: %s",
             node_item->NodeName().c_str(),
             subgraph_name.c_str());
      auto subgraph = root_graph_->GetSubgraph(subgraph_name);
      GE_CHECK_NOTNULL(subgraph);
      for (const auto &sub_node : subgraph->GetAllNodes()) {
        std::string parallel_group;
        if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
          GELOGD("[%s::%s] Got parallel group = %s",
                 subgraph_name.c_str(),
                 sub_node->GetName().c_str(),
                 parallel_group.c_str());
          parallel_groups.emplace(parallel_group);
        }
      }
    }

    if (!parallel_groups.empty()) {
      for (const auto &parallel_group : parallel_groups) {
        parallel_group_to_nodes_[parallel_group].emplace(node_item);
        GELOGD("[%s] has parallel group: %s", node_item->NodeName().c_str(), parallel_group.c_str());
      }
      node_to_parallel_groups_.emplace(node_item, std::move(parallel_groups));
    }
  }

  return SUCCESS;
}

// Adds execution-order dependencies between nodes that share a parallel group
// but run on different executor types. For each node, the nearest *preceding*
// node (largest node_id smaller than its own) within the same group becomes a
// dependency recorded in dependents_for_execution, and that predecessor is
// flagged as observed (has_observer) so its completion is signalled.
Status HybridModelBuilder::ParseDependentByParallelGroup() {
// First pass: populate parallel_group_to_nodes_ / node_to_parallel_groups_.
for (auto &it : hybrid_model_.node_items_) {
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get()));
}
for (const auto &it : node_to_parallel_groups_) {
auto node_item = it.first;
auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
for (const auto &parallel_group : it.second) {
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
// Scan the group's members for the closest predecessor by node_id.
NodeItem *nearest_dep_node = nullptr;
int max_id = -1;
for (auto &dep_node : dependent_nodes) {
if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
nearest_dep_node = dep_node;
max_id = dep_node->node_id;
}
}

if (nearest_dep_node != nullptr) {
GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str());
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node);
// Same executor type implies serialized execution already; no edge needed.
if (src_engine_type == dst_executor_type) {
GELOGD("No need to add dependency for nodes with same executor type");
continue;
}
auto &deps = node_item->dependents_for_execution;
// Avoid inserting a duplicate dependency edge.
if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
GELOGD("%s->%s Already has dependency, skip it",
nearest_dep_node->node->GetName().c_str(),
node_item->NodeName().c_str());
continue;
}
nearest_dep_node->has_observer = true;
deps.emplace_back(nearest_dep_node->node);
GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
}
}
}
return SUCCESS;
}
} // namespace hybrid
} // namespace ge

+ 7
- 2
ge/hybrid/model/hybrid_model_builder.h View File

@@ -57,14 +57,17 @@ class HybridModelBuilder {
Status ValidateParams();
Status LoadGraph();
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
Status LoadTask(NodeItem &node_item);
Status LoadTasks();
Status IdentifyVariableOutputs(NodeItem &node_item);
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status CollectParallelGroups(NodeItem *node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status ParseDependentForFusedSubgraph(NodeItem &node_item, std::set<ge::NodePtr> &dependencies);
Status ParseDependentByParallelGroup();
Status IndexTaskDefs();
Status IndexTaskDefs(const ComputeGraphPtr &sub_graph, const GeModelPtr &ge_model);
Status IndexSpecialNodes();
@@ -97,12 +100,14 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node);

GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;
std::map<NodeItem *, std::set<std::string>> node_to_parallel_groups_;

HybridModel &hybrid_model_;
std::map<NodePtr, std::vector<std::pair<int, NodePtr>>> node_ref_inputs_;
int node_index = 0;

RuntimeParam &runtime_param_;
VarManager *var_manager_ = nullptr;


+ 4
- 0
ge/hybrid/model/node_item.cc View File

@@ -251,6 +251,10 @@ bool NodeItem::IsControlOp() const {
return ge::hybrid::IsControlOp(op_desc->GetType());
}

// Returns true when this node resolves to the HCCL (collective-communication)
// executor, as decided by NodeExecutorManager.
bool NodeItem::IsHcclOp() const {
return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
}

std::string NodeItem::DebugString() const {
std::stringstream ss;
ss << "Node: ";


+ 2
- 0
ge/hybrid/model/node_item.h View File

@@ -67,6 +67,8 @@ struct NodeItem {

bool IsControlOp() const;

bool IsHcclOp() const;

void SetToDynamic();

std::string DebugString() const;


+ 24
- 25
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc View File

@@ -95,13 +95,6 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) {
Status KnownNodeTask::Init(TaskContext &context) {
// allocate output mem
GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed.");

// init davinci model
if (!load_flag_) {
davinci_model_->InitRuntimeParams();
GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed.");
}

// allocate mem base
void *buffer = nullptr;
if (davinci_model_->TotalMemSize() != 0) {
@@ -126,30 +119,34 @@ Status KnownNodeTask::Init(TaskContext &context) {
auto dump_properties = context.GetDumpProperties();
if (dump_properties.IsDumpOpen() || dump_properties.IsOpDebugOpen()) {
davinci_model_->SetDumpProperties(dump_properties);
void *global_step = nullptr;
TensorValue *varible_global_step = context.GetVariable(NODE_NAME_GLOBAL_STEP);
if (varible_global_step != nullptr) {
global_step = varible_global_step->MutableData();
}
void *global_step = context.GetExecutionContext()->global_step;
davinci_model_->SetKnownShapeGlobalStep(global_step);
}
int32_t device_id = 0;
rtError_t rt_ret = rtGetDevice(&device_id);
if (rt_ret != RT_ERROR_NONE || device_id < 0) {
GELOGE(rt_ret, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}
davinci_model_->SetDeviceId(device_id);
GE_CHK_STATUS_RET(davinci_model_->Init(), "KnownNodeExecutor::InitDavinciModel failed.");
load_flag_ = true;
} else {
GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
davinci_model_->Id(), davinci_model_->SubModelId()), "KnownNodeTask::Init destroy aicpu kernel failed.");
}
GE_CHK_STATUS_RET(ModelManager::GetInstance()->DestroyAicpuKernel(davinci_model_->GetSessionId(),
davinci_model_->Id(), davinci_model_->SubModelId()),
"KnownNodeTask::Init destroy aicpu kernel failed.");
GELOGI("[%s] KnownNodeExecutor::Init success.", context.GetNodeName());
return SUCCESS;
}

// One-time initialization of the wrapped DavinciModel: runtime params,
// variable memory, device binding, then the (overridable) model Init().
// Returns a failure Status if variable-memory or model init fails.
Status KnownNodeTask::InitDavinciModel() {
GELOGD("[Init][Model] start");
davinci_model_->InitRuntimeParams();
GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed");
// Bind the model to the device currently selected for this thread.
int32_t device_id = 0;
GE_CHK_RT_RET(rtGetDevice(&device_id));
davinci_model_->SetDeviceId(static_cast<uint32_t>(device_id));
GE_CHK_STATUS_RET(DoInitDavinciModel(), "[Init][Model] Failed to init davinci model.");
GELOGD("[Init][Model] success");
return SUCCESS;
}

// Virtual seam around DavinciModel::Init() so tests can override the
// heavyweight initialization (declared virtual in the header).
Status KnownNodeTask::DoInitDavinciModel() {
return davinci_model_->Init();
}

Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
GELOGD("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName());
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorPrepareTask] Start");
@@ -186,9 +183,11 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node

GE_CHK_STATUS_RET(davinci_model->Assign(ge_model), "KnownNodeExecutor::LoadTask davincimodel assign failed.");

task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(task);
auto known_node_task = MakeShared<KnownNodeTask>(davinci_model);
GE_CHECK_NOTNULL(known_node_task);
GE_CHK_STATUS_RET_NOLOG(known_node_task->InitDavinciModel());
GELOGI("[%s] KnownNodeExecutor::LoadTask success.", node->GetName().c_str());
task = std::move(known_node_task);
return SUCCESS;
}



+ 5
- 3
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h View File

@@ -31,11 +31,15 @@ class KnownNodeTask : public NodeTask {
: davinci_model_(davinci_model)
{}

~KnownNodeTask() {}
~KnownNodeTask() = default;

Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
Status Init(TaskContext &context) override;
Status InitDavinciModel();

protected:
virtual Status DoInitDavinciModel();
private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
bool load_flag_ = false;
@@ -47,8 +51,6 @@ class KnownNodeExecutor : public NodeExecutor {
Status PrepareTask(NodeTask &task, TaskContext &context) const;
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
~KnownNodeExecutor() {}
private:
std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
};
} // namespace hybrid
} // namespace ge


+ 4
- 0
ge/single_op/single_op_manager.cc View File

@@ -19,6 +19,9 @@
#include <mutex>
#include <string>

#include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/graph_caching_allocator.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() {
for (auto &it : stream_resources_) {
@@ -67,6 +70,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::Release
delete it->second;
it->second = nullptr;
(void)stream_resources_.erase(it);
MemManager::Instance().CachingInstance(RT_MEMORY_HBM).TryFreeBlocks();
return SUCCESS;
}



+ 35
- 6
ge/single_op/single_op_model.cc View File

@@ -44,19 +44,46 @@ namespace ge {
namespace {
const size_t kDataOutputNum = 1;

bool NeedHybridModel(GeModelPtr &ge_model) {
// Sets |flag| to true when any op in |ge_model|'s compute graph declares an
// infer-shape data dependency; |flag| is never set to false here. Returns a
// failure Status only when the graph or an OpDesc is null.
Status IfInferDepend(GeModelPtr &ge_model, bool &flag) {
  auto graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
  GE_CHECK_NOTNULL(graph);
  for (const auto &graph_node : graph->GetAllNodes()) {
    auto desc = graph_node->GetOpDesc();
    GE_CHECK_NOTNULL(desc);
    if (!desc->GetOpInferDepends().empty()) {
      // One dependent op is enough; no need to keep scanning.
      flag = true;
      break;
    }
  }
  return SUCCESS;
}

// Decides whether the dynamic single op must run through the hybrid model
// path: (a) a TE kernel exists while some op declares an infer-shape data
// dependency, or (b) the model contains more than one TE kernel task.
// Defect fixed: the previous text was a corrupted diff merge that left the old
// bool-returning statements (`return true;` / `return false;`) interleaved
// with the new Status-based flow, which cannot compile.
Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) {
  bool infer_depend_flag = false;
  GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed.");
  auto tasks = ge_model->GetModelTaskDefPtr()->task();
  int32_t kernel_task_num = 0;
  for (int i = 0; i < tasks.size(); ++i) {
    auto task_type = static_cast<rtModelTaskType_t>(tasks[i].type());
    if (task_type == RT_MODEL_TASK_KERNEL || task_type == RT_MODEL_TASK_ALL_KERNEL) {
      // The kernel context lives in a different proto field per task type.
      const auto &context = task_type == RT_MODEL_TASK_KERNEL ? tasks[i].kernel().context() :
                                                                tasks[i].kernel_with_handle().context();
      auto kernel_type = static_cast<ccKernelType>(context.kernel_type());
      if (kernel_type == ccKernelType::TE) {
        if (infer_depend_flag) {
          flag = true;
          return SUCCESS;
        }
        kernel_task_num++;
        if (kernel_task_num > 1) {
          flag = true;
          return SUCCESS;
        }
      }
    }
  }
  return SUCCESS;
}
} // namespace

@@ -503,7 +530,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &

auto ge_model = model_helper_.GetGeModel();
GE_CHECK_NOTNULL(ge_model);
if (NeedHybridModel(ge_model)) {
bool need_hybrid_model = false;
GE_CHK_STATUS_RET(NeedHybridModel(ge_model, need_hybrid_model), "[Check][NeedHybridModel] failed.");
if (need_hybrid_model) {
GELOGD("Build single op HybridModel.");
GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized());
auto root_model = model_helper_.GetGeRootModel();


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit e68940202b874ccec77d621f59b34fc4404bede2
Subproject commit 0c4602a4615a9368b06633a5087e2114518f29ca

+ 1
- 1
parser

@@ -1 +1 @@
Subproject commit b203d47837421b2c149f353fc0808f6a29fa584e
Subproject commit d851e1d467768b6cefd8f5f44745be1c5312121a

+ 4
- 0
tests/depends/runtime/src/runtime_stub.cc View File

@@ -435,3 +435,7 @@ rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId)
rtError_t rtDebugRegisterForStream(rtStream_t stream, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId) {
return RT_ERROR_NONE;
}

// UT stub: op-debug unregister is a no-op in tests and always succeeds.
rtError_t rtDebugUnRegisterForStream(rtStream_t stream) {
return RT_ERROR_NONE;
}

+ 3
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -667,6 +667,7 @@ set(PASS_TEST_FILES
"graph/passes/merge_pass_unittest.cc"
#"graph/passes/switch_pass_unittest.cc"
"graph/passes/switch_logic_remove_pass_unittest.cc"
"graph/passes/switch_dead_branch_elimination_unittest.cc"
"graph/passes/assert_pass_unittest.cc"
"graph/passes/dropout_pass_unittest.cc"
"graph/passes/unused_const_pass_unittest.cc"
@@ -731,6 +732,7 @@ set(KERNEL_TEST_FILES
"graph/passes/folding_kernel/gather_v2_kernel_unittest.cc"
"graph/passes/folding_kernel/slice_kernel_unittest.cc"
"graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc"
"graph/passes/atomic_addr_clean_pass_unittest.cc"
)

set(MULTI_PARTS_TEST_FILES
@@ -760,6 +762,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/variable_accelerate_ctrl_unittest.cc"
"graph/build/logical_stream_allocator_unittest.cc"
"graph/build/mem_assigner_unittest.cc"
"graph/build/task_generator_unittest.cc"
"graph/preprocess/graph_preprocess_unittest.cc"
"graph/manager/hcom_util_unittest.cc"
"graph/manager/graph_caching_allocator_unittest.cc"


+ 68
- 0
tests/ut/ge/graph/build/task_generator_unittest.cc View File

@@ -0,0 +1,68 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <memory>

#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "omg/omg_inner_types.h"
#include "../passes/graph_builder_utils.h"

#define protected public
#define private public
#include "graph/build/task_generator.h"
#undef protected
#undef private

using namespace std;
using namespace testing;
using namespace ge;

// Fixture for TaskGenerator unit tests.
class UtestTaskGeneratorTest : public testing::Test {
public:
// Builds data -> addn1 -> netoutput. "data" is tagged with framework original
// type "IteratorV2" and kernel lib "GE" so profiling skips it as an input op.
ge::ComputeGraphPtr BuildGraphFpProfiling() {
ge::ut::GraphBuilder builder("graph");
auto data = builder.AddNode("data", "phony", 1, 1);
auto addn1 = builder.AddNode("addn1", "AddN", 1, 1);
auto netoutput = builder.AddNode("netoutput", "NetOutput", 2, 0);
auto op_desc = data->GetOpDesc();
(void)AttrUtils::SetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "IteratorV2");
op_desc->SetOpKernelLibName("GE");
builder.AddDataEdge(data, 0, addn1, 0);
builder.AddDataEdge(addn1, 0, netoutput, 0);
return builder.GetGraph();
}

protected:
void SetUp() {}
void TearDown() {}
};

// AutoFindFpOpIndex should succeed on a graph whose input is an Iterator-typed
// op and report addn1 (op index 2) as the forward-propagation start point.
TEST_F(UtestTaskGeneratorTest, AutoFindFpOpIndex) {
auto graph = BuildGraphFpProfiling();
TaskGenerator task_generator(nullptr, 0);
ProfilingPoint profiling_point;
profiling_point.fp_index = -1;
EXPECT_EQ(task_generator.AutoFindFpOpIndex(graph, profiling_point), SUCCESS);
// addn1 is fp
EXPECT_EQ(profiling_point.fp_index, 2);
}

+ 65
- 0
tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc View File

@@ -0,0 +1,65 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "graph/passes/atomic_addr_clean_pass.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/pass_manager.h"
using namespace testing;

namespace ge {
// Fixture for AtomicAddrCleanPass tests; owns the graph under test.
class UtestGraphPassesAtomicAddrCleanPass : public Test {
public:
UtestGraphPassesAtomicAddrCleanPass() {
graph_ = std::make_shared<ComputeGraph>("test");
}

// Adds a node named |name| of |type| to graph_ with the requested number of
// default (unshaped) input/output tensor descs.
NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < input_cnt; ++i) {
op_desc->AddInputDesc(GeTensorDesc());
}
for (int i = 0; i < output_cnt; ++i) {
op_desc->AddOutputDesc(GeTensorDesc());
}
NodePtr node = graph_->AddNode(op_desc);
return node;
}

// Graph shared by all tests in this fixture.
ComputeGraphPtr graph_;
};

// node1 -> node2 -> node3
// The pass should run successfully on a minimal DATA -> RELU -> NETOUTPUT
// chain that contains no atomic nodes.
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) {
auto node1 = NewNode("node1", DATA, 0, 1);
auto node2 = NewNode("node2", RELU, 1, 1);
auto node3 = NewNode("node3", NETOUTPUT, 1, 0);
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
AtomicAddrCleanPass atomi_addr_clean_pass;
Status ret = atomi_addr_clean_pass.Run(graph_);
EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge

+ 163
- 0
tests/ut/ge/graph/passes/switch_dead_branch_elimination_unittest.cc View File

@@ -0,0 +1,163 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <string>
#include <gtest/gtest.h>
#include "common/ge_inner_error_codes.h"
#include "graph/passes/switch_dead_branch_elimination.h"
#include "graph_builder_utils.h"
namespace ge {
// Fixture for SwitchDeadBranchElimination pass tests; no shared state needed.
class UtestSwitchDeadBranchElimination : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {
/*
* data1 const1
* \ /
* case1
* |
* relu1
* |
* netoutput
*/
// Builds the parent graph sketched above: case1 takes data1 and a scalar
// int32 const (value 1, the branch selector), feeding relu1 -> netoutput.
ut::GraphBuilder ParentGraphBuilder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto case1 = builder.AddNode("case1", CASE, 2, 1);
auto relu1 = builder.AddNode("relu1", "Relu", 1, 1);
auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
// Attach weight {1} to const1 so the pass sees a constant predicate.
int32_t weight[1] = {1};
GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
OpDescUtils::SetWeights(const1, {tensor});
builder.AddDataEdge(data1, 0, case1, 0);
builder.AddDataEdge(const1, 0, case1, 1);
builder.AddDataEdge(case1, 0, relu1, 0);
builder.AddDataEdge(relu1, 0, netoutput, 0);
return builder;
}
/*
* data1 data2
* \ /
* switch
* / \
* relu1 relu2
* \ /
* merge
* |
* netoutput
*/
// Builds one branch subgraph per the diagram above. |num| is suffixed onto
// every node name so multiple branch instances stay uniquely named.
// _parent_node_index maps each subgraph Data node to a case1 input anchor.
ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) {
ut::GraphBuilder builder = ut::GraphBuilder(graph_name);
string data1_name = "data1_" + std::to_string(num);
auto data1 = builder.AddNode(data1_name, "Data", 0, 1);
auto data1_desc = data1->GetOpDesc();
EXPECT_NE(data1_desc, nullptr);
AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);
string data2_name = "data2_" + std::to_string(num);
auto data2 = builder.AddNode(data2_name, "Data", 0, 1);
auto data2_desc = data2->GetOpDesc();
EXPECT_NE(data2_desc, nullptr);
AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);
string switch_name = "switch_" + std::to_string(num);
auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2);
string relu1_name = "relu1_" + std::to_string(num);
auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1);
string relu2_name = "relu2_" + std::to_string(num);
auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1);
string merge_name = "merge_" + std::to_string(num);
auto merge = builder.AddNode(merge_name, "Merge", 2, 1);
string output_name = "output_" + std::to_string(num);
auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0);
// Wire data -> switch -> (relu1 | relu2) -> merge -> netoutput.
builder.AddDataEdge(data1, 0, switch1, 0);
builder.AddDataEdge(data2, 0, switch1, 1);
builder.AddDataEdge(switch1, 0, relu1, 0);
builder.AddDataEdge(switch1, 1, relu2, 0);
builder.AddDataEdge(relu1, 0, merge, 0);
builder.AddDataEdge(relu2, 0, merge, 1);
builder.AddDataEdge(merge, 0, netoutput, 0);
return builder;
}
// Attaches |branch_num| switch-style branch subgraphs to the "case1" node of
// |parent_graph|, registering each as a named subgraph instance.
void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
auto case_node = parent_graph->FindNode("case1");
EXPECT_NE(case_node, nullptr);
for (uint32_t i = 0; i < branch_num; ++i) {
string name = "Branch_Graph_" + std::to_string(i);
auto builder_subgraph = SwitchSubgraphBuilder(name, i);
auto switch_subgraph = builder_subgraph.GetGraph();
case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName());
case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName());
// Link parent/child both ways so GetAllSubgraphs() can see the branch.
switch_subgraph->SetParentNode(case_node);
switch_subgraph->SetParentGraph(parent_graph);
EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS);
}
}
} // namespace
// Running the pass on the Switch node inside each Case branch should remove
// the Switch and its dead branch, leaving Merge with a single live input.
TEST_F(UtestSwitchDeadBranchElimination, switch_dead_branch_elimination_across_case_success) {
auto builder = ParentGraphBuilder();
auto parent_graph = builder.GetGraph();
AddCaseSubgraph(parent_graph, 2);
auto subgraphs = parent_graph->GetAllSubgraphs();
EXPECT_EQ(subgraphs.size(), 2);
SwitchDeadBranchElimination switch_pass;
for (auto &subgraph : subgraphs) {
auto switch_node = subgraph->FindFirstNodeMatchType("Switch");
if (switch_node != nullptr) {
EXPECT_EQ(switch_pass.Run(switch_node), SUCCESS);
}
}
// 5 parent nodes + 6 per branch after elimination = 17 nodes in total.
auto all_nodes = parent_graph->GetAllNodes();
EXPECT_EQ(all_nodes.size(), 17);
for (auto &subgraph : subgraphs) {
EXPECT_EQ(subgraph->GetDirectNode().size(), 6);
EXPECT_EQ(subgraph->FindFirstNodeMatchType("Switch"), nullptr);
auto merge_node = subgraph->FindFirstNodeMatchType("Merge");
EXPECT_NE(merge_node, nullptr);
// Only the surviving branch still feeds the Merge node.
auto merge_innode = merge_node->GetInDataNodes();
EXPECT_EQ(merge_innode.size(), 1);
}
}
} // namespace ge

+ 14
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -30,6 +30,7 @@
#include "framework/common/debug/log.h"
#include "graph/ge_context.h"
#include "hybrid/executor/hybrid_execution_context.h"
#include "hybrid/executor/hybrid_model_executor.h"
#include "hybrid/node_executor/aicore/aicore_task_builder.h"
#include "graph/load/model_manager/tbe_handle_store.h"
#include "graph/manager/graph_mem_allocator.h"
@@ -242,4 +243,16 @@ TEST_F(UtestGeHybrid, init_weight_success) {
ge_sub_model->SetWeight(weight_buffer);
ret = hybrid_model_builder.InitWeights();
ASSERT_EQ(ret,PARAM_INVALID);
}
}

// Smoke test: constructing and Init()-ing a HybridModelExecutor on an empty
// model must not crash.
// NOTE(review): `stream` is passed uninitialized — looks deliberate for this
// smoke test, but confirm Init() never dereferences the stream handle.
TEST_F(UtestGeHybrid, hybrid_model_executor) {
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("abc");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
HybridModel *model_ptr = &model;

uint32_t device_id = 0;
rtStream_t stream;
HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init();
}

Loading…
Cancel
Save