Browse Source

Fix multi control from one node

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
7b1331770a
6 changed files with 14 additions and 28 deletions
  1. ge/ge_local_engine/engine/host_cpu_engine.cc (+4, -7)
  2. ge/hybrid/model/hybrid_model_builder.cc (+0, -4)
  3. ge/hybrid/model/node_item.cc (+7, -10)
  4. ge/hybrid/model/node_item.h (+1, -1)
  5. ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+1, -0)
  6. ge/hybrid/node_executor/rts/rts_node_executor.cc (+1, -6)

+ 4
- 7
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -13,15 +13,15 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "host_cpu_engine.h"
#include "graph/common/omg_util.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/op_desc_utils.h" #include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_adapter.h" #include "graph/utils/tensor_adapter.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/type_utils.h"
#include "register/op_kernel_registry.h" #include "register/op_kernel_registry.h"
#include "register/host_cpu_context.h" #include "register/host_cpu_context.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/ge/plugin_manager.h" #include "common/ge/plugin_manager.h"
#include "graph/utils/type_utils.h"
#include "common/fp16_t.h" #include "common/fp16_t.h"
#include "common/math/math_util.h" #include "common/math/math_util.h"


@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) {
} }


Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) { Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) {
std::string op_type;
auto status = GetOriginalType(node, op_type);
GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status);

const std::string op_type = NodeUtils::GetNodeType(node);
auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type); auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type);
if (kernel == nullptr) { if (kernel == nullptr) {
GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str()); GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str());


+ 0
- 4
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -288,10 +288,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS; return SUCCESS;
} }


if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity.
node->GetOpDesc()->SetType(IDENTITY);
}

std::unique_ptr<NodeItem> new_node; std::unique_ptr<NodeItem> new_node;
GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor));


+ 7
- 10
ge/hybrid/model/node_item.cc View File

@@ -14,10 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */


#include "node_item.h"
#include <sstream>
#include "common/debug/log.h"
#include "graph/common/omg_util.h"
#include "hybrid/model/node_item.h"

#include "graph/compute_graph.h" #include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "hybrid/executor/worker/shape_inference_engine.h" #include "hybrid/executor/worker/shape_inference_engine.h"
@@ -98,8 +96,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc); GE_CHECK_NOTNULL(op_desc);
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type));
const std::string node_type = NodeUtils::GetNodeType(node);
if (node_type == DATA) { if (node_type == DATA) {
GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph)); GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
} else if (node_type == kNodeTypeRetVal) { } else if (node_type == kNodeTypeRetVal) {
@@ -409,8 +406,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {


void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
if (switch_index < switch_groups_.size()) { if (switch_index < switch_groups_.size()) {
std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index];
switch_group.emplace_back(node_item);
auto &switch_group = switch_groups_[switch_index];
switch_group.emplace(node_item);
} else { } else {
ctrl_send_.insert(node_item); ctrl_send_.insert(node_item);
} }
@@ -433,8 +430,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
} }


// this is StreamMerge node, node_item is StreamActive node. // this is StreamMerge node, node_item is StreamActive node.
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index];
switch_group.emplace_back(node_item);
auto &switch_group = switch_groups_[merge_index];
switch_group.emplace(node_item);


node_item->ctrl_send_.emplace(this); node_item->ctrl_send_.emplace(this);
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());


+ 1
- 1
ge/hybrid/model/node_item.h View File

@@ -155,7 +155,7 @@ struct NodeItem {
std::map<const NodeItem *, int> data_recv_; // Recv data notify from std::map<const NodeItem *, int> data_recv_; // Recv data notify from
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to
std::vector<std::set<const NodeItem *>> switch_groups_; // Send ctrl notify to


std::shared_ptr<NodeTask> kernel_task; std::shared_ptr<NodeTask> kernel_task;
std::unique_ptr<FusedSubgraph> fused_subgraph; std::unique_ptr<FusedSubgraph> fused_subgraph;


+ 1
- 0
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
GE_CHK_RT_RET(rtEventDestroy(evt)); GE_CHK_RT_RET(rtEventDestroy(evt));
} }
GELOGI("rdma callback success."); GELOGI("rdma callback success.");
return SUCCESS;
}; };


HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback);


+ 1
- 6
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -17,13 +17,9 @@
#include "hybrid/node_executor/rts/rts_node_executor.h" #include "hybrid/node_executor/rts/rts_node_executor.h"
#include "hybrid/node_executor/rts/rts_task_factory.h" #include "hybrid/node_executor/rts/rts_task_factory.h"


#include "common/debug/log.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/types.h"
#include "graph/common/omg_util.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h" #include "hybrid/model/hybrid_model.h"
#include "runtime/rt.h"


namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
@@ -133,8 +129,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("[%s] Load for local task.", node->GetName().c_str()); GELOGD("[%s] Load for local task.", node->GetName().c_str());
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
const std::string node_type = NodeUtils::GetNodeType(node);
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type);
if (rts_task == nullptr) { if (rts_task == nullptr) {
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str());


Loading…
Cancel
Save