Browse Source

add mds cutn

pull/2000/head
gengchao4@huawei.com 3 years ago
parent
commit
ff722e25b8
10 changed files with 218 additions and 259 deletions
  1. +5
    -0
      ge/CMakeLists.txt
  2. +5
    -3
      ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc
  3. +9
    -11
      ge/graph/execute/graph_execute.cc
  4. +7
    -11
      ge/graph/load/graph_loader.cc
  5. +19
    -37
      ge/graph/passes/mds_kernels/base_mds_kernel.cc
  6. +6
    -2
      ge/graph/passes/mds_kernels/base_mds_kernel.h
  7. +102
    -116
      ge/graph/passes/mds_kernels/mds_utils.cc
  8. +27
    -36
      ge/graph/passes/mds_kernels/mds_utils.h
  9. +29
    -37
      ge/graph/passes/mds_pass.cc
  10. +9
    -6
      ge/graph/passes/mds_pass.h

+ 5
- 0
ge/CMakeLists.txt View File

@@ -166,6 +166,11 @@ set(EXECUTOR_SRC_LIST
"graph/manager/util/debug.cc"
#"graph/manager/util/hcom_util.cc" # Just for runner.
"graph/passes/pass_utils.cc"
"graph/passes/mds_pass.cc"
"graph/passes/mds_kernels/mds_utils.cc"
"graph/passes/mds_kernels/variable_mds_kernel.cc"
"graph/passes/mds_kernels/conv2d_mds_kernel.cc"
"graph/passes/mds_kernels/base_mds_kernel.cc"
"host_kernels/add_kernel.cc"
"host_kernels/broadcast_args_kernel.cc"
"host_kernels/broadcast_gradient_args_kernel.cc"


+ 5
- 3
ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc View File

@@ -77,9 +77,11 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map<string, string> &sess
return SUCCESS;
}
Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) {
//TODO:1,针对变量类型是否标识为可训练变量
//2,是否开启smdp1和3
//满足上述两点,在变量切分信息里面设置为当前变量节点可切
// TODO:
// 1. Whether the variable type is identified as a trainable variable
// 2. Whether smdp1 and smdp3 are enabled
// If both conditions are met, mark the current variable
// node as sliceable in the variable slicing info
return SUCCESS;
}
} // namespace ge_local


+ 9
- 11
ge/graph/execute/graph_execute.cc View File

@@ -540,14 +540,12 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro
return SUCCESS;
}
Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs,
const RunAsyncCallback &callback) {
const RunAsyncCallback &callback) {
// get deploy number of model instance
auto root_graph = ge_root_model->GetRootGraph();
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) {
GELOGE(FAILED,
"[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s",
root_graph->GetName().c_str(),
GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
return FAILED;
}
@@ -567,7 +565,7 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model
std::vector<GeTensorPtr> graph_inputs;
if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) {
std::vector<ge::Tensor> graph_input_updated(inputs.begin(), inputs.end());
for (const auto &ge_tensor_ptr:graph_inputs) {
for (const auto &ge_tensor_ptr : graph_inputs) {
graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr));
}
GraphExecutor graph_executor;
@@ -575,12 +573,12 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model
std::future<Status> f;
bool need_return_result = false;
if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) {
f = executor.commit(execute_model_func, &graph_executor, ge_root_model,
model_ids[i], graph_input_updated, callback);
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated,
callback);
} else {
RunAsyncCallback callback_stub;
f = executor.commit(execute_model_func, &graph_executor, ge_root_model,
model_ids[i], graph_input_updated, callback_stub);
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated,
callback_stub);
}
if (!f.valid()) {
GELOGE(FAILED, "[Call][Commit] failed, Future is invalid");
@@ -600,8 +598,8 @@ Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model
return SUCCESS;
}

Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector<ge::Tensor> &inputs,
const RunAsyncCallback &callback) {
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id,
const std::vector<ge::Tensor> &inputs, const RunAsyncCallback &callback) {
if (model_id == kInvalidModelId) {
GELOGE(INTERNAL_ERROR, "No valid model id.");
return INTERNAL_ERROR;


+ 7
- 11
ge/graph/load/graph_loader.cc View File

@@ -149,17 +149,14 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel>
auto root_graph = ge_root_model->GetRootGraph();
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) {
GELOGE(FAILED,
"[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s",
root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s",
root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str());
return FAILED;
}
auto thread_instances_size = deploy_info.size();
auto device_id_fission_from = GetContext().DeviceId();
GELOGI("Graph %s need to load model %zu times, and fission from device %u.",
root_graph->GetName().c_str(), thread_instances_size,
device_id_fission_from);
GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(),
thread_instances_size, device_id_fission_from);
ThreadPool executor(thread_instances_size);
std::vector<std::future<Status>> vector_future;
GE_TIMESTAMP_START(LoadModelOnline);
@@ -167,17 +164,16 @@ Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel>
auto thread_instance = deploy_info[i];
std::string device_type;
ModelIdInfo model_id_info;
//TODO: listener要区分同步异步
std::shared_ptr<ModelListener> listener;
if (is_async) {
listener = MakeShared<RunAsyncListener>();
GE_CHECK_NOTNULL(listener);
} else {
// TODO: GraphModelListener for sync
}
int64_t device_id_fissioned = kInvalidDieId;
if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned)
|| device_id_fissioned == kInvalidDieId) {
if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) ||
device_id_fissioned == kInvalidDieId) {
REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),


+ 19
- 37
ge/graph/passes/mds_kernels/base_mds_kernel.cc View File

@@ -26,8 +26,7 @@ shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) {
string type = node->GetType();
if (type == FRAMEWORKOP) {
if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed",
ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
node->GetName().c_str(), node->GetType().c_str());
return nullptr;
}
@@ -35,10 +34,10 @@ shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) {
return factory.Create(type);
}
}
} // namespace mds_cut_pass
shared_ptr<DeploySchedulerKernel> DeploySchedulerKernel::Instance() {
static const std::shared_ptr<DeploySchedulerKernel> instance_ptr =
shared_ptr<DeploySchedulerKernel>(new(std::nothrow) DeploySchedulerKernel());
shared_ptr<DeploySchedulerKernel>(new (std::nothrow) DeploySchedulerKernel());
return instance_ptr;
}
@@ -63,20 +62,15 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
tensor_desc->SetShape(src_tensor_desc->GetShape());
} else {
MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor),
"[CutN] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
in_anchor->GetIdx());
MDS_REQUIRE_SUCCESS(
MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
} else {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
"[CutN] failed to slice between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
tensor_desc->SetShape(src_tensor_desc->GetShape());
@@ -85,14 +79,11 @@ Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) {
// insert hcomallreduce for cutn
bool is_grad_compute_node = false;
if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node)
&& is_grad_compute_node) {
MDS_REQUIRE_SUCCESS(MdsUtils::DataReduce(src_anchor, in_anchor),
"[CutN] failed to reduce between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
in_anchor->GetIdx());
if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
is_grad_compute_node) {
MDS_REQUIRE_SUCCESS(
MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
}
@@ -122,32 +113,23 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()),
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
MDS_REQUIRE_SUCCESS(MdsUtils::DataGather(src_anchor, in_anchor),
"[CutH] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
in_anchor->GetIdx());
MDS_REQUIRE_SUCCESS(
MdsUtils::DataGather(src_anchor, in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
} else {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
"[CutH] failed to slice between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true),
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(),
src_anchor->GetIdx(),
op_desc->GetName().c_str(),
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
}
}
@@ -157,4 +139,4 @@ Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) {
MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] call infer shape failed", node->GetName().c_str());
return SUCCESS;
}
}
} // namespace ge

+ 6
- 2
ge/graph/passes/mds_kernels/base_mds_kernel.h View File

@@ -55,12 +55,16 @@ class DeploySchedulerKernel {
// halo exchange process
Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false);
NodePtr GetInputNode() { return input_node_; }
NodePtr GetInputNode() {
return input_node_;
}
DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete;
DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete;
protected:
DeploySchedulerKernel() = default;
virtual ~DeploySchedulerKernel() = default;
private:
NodePtr input_node_ = nullptr;
};
@@ -69,4 +73,4 @@ namespace mds_cut_pass {
shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node);
}
} // namespace ge
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_

+ 102
- 116
ge/graph/passes/mds_kernels/mds_utils.cc View File

@@ -21,39 +21,48 @@ thread_local int64_t data_slice_count = 0;
thread_local int64_t data_gather_count = 0;
thread_local int64_t data_reduce_count = 0;
const std::string kPrefix = "mds";
}
} // namespace
int64_t MdsUtils::GetNLocation(Format fmt) {
int64_t loc = kNInvalidLocation;
switch (fmt) {
case FORMAT_NCHW :
case FORMAT_NHWC :loc = kNLocation0;
case FORMAT_NCHW:
case FORMAT_NHWC:
loc = kNLocation0;
break;
case FORMAT_CHWN:
case FORMAT_HWCN:loc = kNLocation3;
case FORMAT_HWCN:
loc = kNLocation3;
break;
default:GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
default:
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
}
return loc;
}
int64_t MdsUtils::GetHLocation(Format fmt) {
int64_t loc = kHInvalidLocation;
switch (fmt) {
case FORMAT_HWCN :loc = kHLocation0;
case FORMAT_HWCN:
loc = kHLocation0;
break;
case FORMAT_NHWC :
case FORMAT_CHWN :loc = kHLocation1;
case FORMAT_NHWC:
case FORMAT_CHWN:
loc = kHLocation1;
break;
case FORMAT_NCHW :loc = kHLocation2;
default :GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
case FORMAT_NCHW:
loc = kHLocation2;
default:
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
}
return loc;
}
int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
Format fmt = ge_tensor_desc->GetFormat();
switch (type) {
case kCutN :return GetNLocation(fmt);
case kCutH :return GetHLocation(fmt);
default :;
case kCutN:
return GetNLocation(fmt);
case kCutH:
return GetHLocation(fmt);
default:;
}
GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
return kInvalidIndex;
@@ -64,7 +73,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des
GELOGE(FAILED, "[MDS]invalid input param: tensor is null!");
return false;
}
if (type != kCutN || type != kCutH) {
if (type != kCutN && type != kCutH) {
REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type);
GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
return false;
@@ -75,6 +84,18 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des
GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index);
return false;
}
auto dims = ge_tensor_desc->GetShape().GetDims();
if (cut_index < 0 || cut_index >= dims.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
return false;
}
if (dims[cut_index] % kDeployNumber != 0) {
GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]);
return false;
}
vector<int64_t> cut_support_info;
if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) {
REPORT_INNER_ERROR("E19999", "call GetlistInt failed");
@@ -82,10 +103,10 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des
return false;
}
if (cut_index < 0 || cut_index >= cut_support_info.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu",
cut_index, type, cut_support_info.size());
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu",
cut_index, type, cut_support_info.size());
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index,
type, cut_support_info.size());
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index,
type, cut_support_info.size());
return false;
}
if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) {
@@ -95,8 +116,7 @@ bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_des
}
return cut_support_info[cut_index] & kSplitCutSupported;
}
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc,
CutType type, int64_t deploy_number) {
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) {
GE_CHECK_NOTNULL(ge_tensor_desc);
auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type);
auto dims = ge_tensor_desc->GetShape().GetDims();
@@ -109,14 +129,10 @@ Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc,
}
Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) {
GE_CHECK_NOTNULL(hcom_op);
REQUIRE(fission_factor > kDefaultFissionFactor,
"fission_factor %ld need be bigger than %ld",
fission_factor,
REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor,
kDefaultFissionFactor);
REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor),
"Failed to set attr fission_factor %ld for op:%s(%s)",
fission_factor,
hcom_op->GetName().c_str(),
"Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(),
hcom_op->GetType().c_str());
if (!group_name.empty()) {
@@ -131,7 +147,7 @@ bool MdsUtils::IsMDSNeeded() {
GELOGI("[MDS]device type is %s, skip mds", device_type.c_str());
return false;
}
//TODO:解析系统的配置文件得到exe unit
// TODO: Parse the configuration file of the system to get the sys_config_exe_unit
std::string sys_config_exe_unit = "DIE";
return device_type != sys_config_exe_unit;
}
@@ -153,18 +169,17 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodeP
GeAttrValue::NAMED_ATTRS thread_instance;
thread_instance.SetName(std::to_string(device_id));
(void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
// TODO:跟rts确认属性值
(void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode"));
(void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs));
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
// TODO:Change to enumeration from RTS header file
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode"));
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs));
deploy_info.emplace_back(thread_instance);
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id);
}
// set deploy info
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
"Set attr failed for graph %s",
compute_graph->GetName().c_str());
"Set attr failed for graph %s", compute_graph->GetName().c_str());
return SUCCESS;
}
@@ -176,7 +191,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) {
is_unknown_graph = true;
}
CutType selected_cut_type = kNoCut;
for (const auto &data: compute_graph->GetInputNodes()) {
for (const auto &data : compute_graph->GetInputNodes()) {
GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str());
auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN);
auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index);
@@ -197,8 +212,7 @@ CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) {
return selected_cut_type;
}
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph,
const std::multimap<DeviceId, GraphInputs> &deploys,
const std::string &device_type) {
const std::multimap<DeviceId, GraphInputs> &deploys, const std::string &device_type) {
GE_CHECK_NOTNULL(compute_graph);
GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
@@ -208,19 +222,18 @@ Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph,
int64_t device_id = pair.first;
GeAttrValue::NAMED_ATTRS thread_instance;
thread_instance.SetName(std::to_string(device_id));
(void) thread_instance.SetAttr(kAttrNeedReturnResult,
GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty() ? true : false));
(void) thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
(void) thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type));
(void) thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void) thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second));
(void)thread_instance.SetAttr(kAttrNeedReturnResult,
GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty() ? true : false));
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type));
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second));
deploy_info.emplace_back(thread_instance);
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id);
}
// set deploy info
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
"Set attr failed for graph %s",
compute_graph->GetName().c_str());
"Set attr failed for graph %s", compute_graph->GetName().c_str());
return SUCCESS;
}
@@ -233,24 +246,20 @@ Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &
auto src_graph = src_node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(src_graph);
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count));
auto
hcom_allgather_node = AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1);
auto hcom_allgather_node =
AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1);
GE_CHECK_NOTNULL(hcom_allgather_node);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node),
"[DataGather] failed between %s and %s",
src_node->GetName().c_str(),
"[DataGather] failed between %s and %s", src_node->GetName().c_str(),
dst_node->GetName().c_str());
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
"[DataGather]set attr for node for %s(%s) failed",
hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str());
"[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(),
hcom_allgather_node->GetType().c_str());
REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize),
"Failed to set attr reduction type %s for op:%s(%s)",
kDefaultReduction.c_str(),
hcom_allgather_node->GetName().c_str(),
hcom_allgather_node->GetType().c_str());
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false),
"[DataGather] %s call infershape failed",
hcom_allgather_node->GetName().c_str());
"[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str());
data_gather_count++;
return SUCCESS;
}
@@ -269,8 +278,7 @@ Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &
NodePtr all_reduce_node = nullptr;
if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) {
MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node),
"[DataReduce] construct allreduce node for %s failed",
all_reduce_node->GetName().c_str());
"[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str());
GE_CHECK_NOTNULL(all_reduce_node);
} else {
GE_CHECK_NOTNULL(all_reduce_node);
@@ -300,25 +308,19 @@ Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &d
GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx());
NodePtr slice_node = nullptr;
MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node),
"[DataSlice] construct slice node for %s failed",
src_node->GetName().c_str());
"[DataSlice] construct slice node for %s failed", src_node->GetName().c_str());
GE_CHECK_NOTNULL(slice_node);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node),
"[DataSlice] failed between %s and %s",
src_node->GetName().c_str(),
dst_node->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false),
"[DataSlice] %s call infer shape failed",
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s",
src_node->GetName().c_str(), dst_node->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed",
slice_node->GetName().c_str());
return SUCCESS;
}
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph,
const GeTensorDesc &tensor,
Node *input_node,
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node,
NodePtr &slice_node) {
vector<int64_t> slice_sizes = tensor.GetShape().GetDims();
// TODO: 用图结构表达
// TODO: Express with graph structure
slice_sizes[0] /= kDeployNumber;
vector<GeTensorPtr> ge_tensors;
GeTensorDesc ge_tensor_desc;
@@ -346,10 +348,10 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph,
GE_CHECK_NOTNULL(input_node);
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0),
mul_node->GetInDataAnchor(1)), "[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false),
"[DataSlice] %s call infer shape failed",
MDS_REQUIRE_SUCCESS(
GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed",
mul_node->GetName().c_str());
NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1);
bool is_first_input = true;
@@ -363,8 +365,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph,
"[ConstructSliceNode] add edge failed");
}
}
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false),
"[DataSlice] %s call infer shape failed",
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed",
pack_node->GetName().c_str());
slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1);
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)),
@@ -375,9 +376,7 @@ Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph,
return SUCCESS;
}
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph,
const string &name,
const string &type,
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
const GeTensorDesc &tensor) {
GELOGI("Begin to create op: %s", name.c_str());
OpDescBuilder op_desc_builder(name, type);
@@ -389,20 +388,17 @@ NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph,
}
NodePtr node = graph->AddNode(op_desc);
if (node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
graph->GetName().c_str());
return nullptr;
}
return node;
}
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph,
const std::string &type,
const std::string &node_name,
size_t input_num,
size_t output_num) {
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type,
const std::string &node_name, size_t input_num, size_t output_num) {
GELOGI("Begin to create op: %s", node_name.c_str());
OpDescBuilder op_desc_builder(node_name, type);
OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build();
@@ -413,10 +409,10 @@ NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph,
}
NodePtr node = graph->AddNode(op_desc);
if (node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
graph->GetName().c_str());
return nullptr;
}
return node;
@@ -437,27 +433,20 @@ NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr
return graph->AddNodeFront(const_desc);
}
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph,
const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst,
NodePtr &reduce_node) {
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst, NodePtr &reduce_node) {
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count));
reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node),
"[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(),
src->GetOwnerNode()->GetName().c_str(),
dst->GetOwnerNode()->GetName().c_str());
src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
"[DataReduce][Create] set attr for allreduce node for %s failed",
reduce_node->GetName().c_str());
"[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str());
REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction),
"Failed to set attr reduction type %s for op:%s(%s)",
kDefaultReduction.c_str(),
reduce_node->GetName().c_str(),
reduce_node->GetType().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false),
"[DataReduce] %s call infershape failed",
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
reduce_node->GetName().c_str(), reduce_node->GetType().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed",
reduce_node->GetName().c_str());
auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1);
vector<int64_t> slice_sizes{kDeployNumber};
@@ -471,20 +460,17 @@ Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph,
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node),
"[DataReduce] failed insert %s between %s and %s",
div_node->GetName().c_str(),
reduce_node->GetName().c_str(),
dst->GetOwnerNode()->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false),
"[DataReduce] %s call infershape failed",
"[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(),
reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed",
div_node->GetName().c_str());
return SUCCESS;
}
bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) {
//TODO: 识别到网络中本来就是多p的模型,即已经有了allreduce节点,则无需插入
// TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node,
// so there is no need to insert it
return true;
}
}
} // namespace ge

+ 27
- 36
ge/graph/passes/mds_kernels/mds_utils.h View File

@@ -30,13 +30,13 @@
#include "graph/utils/op_desc_utils.h"
#include "../pass_utils.h"
#define REQUIRE(cond, ...) \
do { \
if (!(cond)) { \
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \
return FAILED; \
} \
#define REQUIRE(cond, ...) \
do { \
if (!(cond)) { \
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \
return FAILED; \
} \
} while (0)
#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__)
@@ -44,7 +44,7 @@
#define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)
namespace ge {
namespace {
//Invalid location index
// Invalid location index
const int64_t kInvalidIndex = -1;
enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 };
enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 };
@@ -69,10 +69,9 @@ const std::string kDefaultReduction = "sum";
const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE";
const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE";
//deploy info
// deploy info
const char *const kAttrNeedReturnResult = "_need_return_result";
const char *const kAttrDeviceType = "_device_type";
// TODO: confirm the attribute value with rts
const char *const kDieDeviceTypeValue = "MultiMode";
const char *const kAttrDeviceId = "_device_id";
const char *const kAttrGraphName = "_graph_name";
@@ -80,7 +79,7 @@ const char *const kAttrGraphInputs = "_graph_inputs";
using GraphInputs = vector<GeTensorPtr>;
using DeviceId = int64_t;
using GraphInputNodes = vector<NodePtr>;
}
} // namespace
class MdsUtils {
public:
// Parse the configuration file and determine whether to enable MDS based on the value of device_type.
@@ -90,24 +89,25 @@ class MdsUtils {
static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type);
static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type);
static Status SetAttrForHcomNode(const OpDescPtr &hcom_op,
int64_t fission_factor,
static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor,
const std::string &group_name = "");
/// @param [in] index The axis along which to cut
/// @param [in] deploy_number The number of pieces to cut into
static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc,
CutType type,
static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type,
int64_t deploy_number = kDeployNumber);
// Sets the information, notifies the number of threads to be started during the
// loading phase, the device on which each thread should run, and constructs different input data on each device.
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node);
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph,
const std::multimap<DeviceId, GraphInputs> &deploys,
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys,
const std::string &device_type = kDieDeviceTypeValue);
// Get cut policy for whole graph
static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph);
static GraphInputNodes GetInputNodes() { return input_nodes_; }
static void AddInputNode(const NodePtr &input_node) { input_nodes_.push_back(input_node); }
static GraphInputNodes GetInputNodes() {
return input_nodes_;
}
static void AddInputNode(const NodePtr &input_node) {
input_nodes_.push_back(input_node);
}
static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
@@ -115,25 +115,16 @@ class MdsUtils {
private:
static GraphInputNodes input_nodes_;
static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph,
const string &type,
const string &node_name,
size_t input_num,
size_t output_num);
static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph,
const string &name,
const string &type,
static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name,
size_t input_num, size_t output_num);
static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
const GeTensorDesc &tensor = GeTensorDesc());
static Status ConstructReduceNode(const ComputeGraphPtr &src_graph,
const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst,
NodePtr &reduce_node);
static Status ConstructSliceNode(const ComputeGraphPtr &src_graph,
const GeTensorDesc &tensor,
Node *node,
static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst, NodePtr &reduce_node);
static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node,
NodePtr &slice_node);
static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node);
static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph);
};
}
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
} // namespace ge
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_

+ 29
- 37
ge/graph/passes/mds_pass.cc View File

@@ -43,30 +43,30 @@ Status ModelDeploySchedulerPass::CutProcess() {
}
auto type = MdsUtils::TryGetGraphCutType(compute_graph_);
switch (type) {
case kCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_),
"[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
case kCutN:
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_),
"[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
case kCutH:
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutN:MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true),
"[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
case kDynamicCutN:
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutH:MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true),
"[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
case kDynamicCutH:
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutAll:MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_),
"[MDS][DynamicCutAll] failed, graph_name:[%s]",
GetGraphName());
case kDynamicCutAll:
MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]",
GetGraphName());
break;
default:GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName());
default:
GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName());
return SUCCESS;
}
}
Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
@@ -78,21 +78,19 @@ Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute
op_kernel = DeploySchedulerKernel::Instance();
}
if (is_dynamic) {
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node),
"[MDS][DYNAMIC_CUTN] failed, node:[%s]",
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]",
node->GetName().c_str());
} else {
MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str());
}
bool is_grad_compute_node = false;
if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node)
&& is_grad_compute_node) {
if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
is_grad_compute_node) {
grad_compute_nodes_.push_back(node);
}
}
//TODO:针对单输出多引用插入的allgather,allreduce节点做广度融合优化
MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(),
"[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]",
// TODO:for single output multi reference insertion allgather, allreduce nodes, do breadth fusion optimization
MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]",
compute_graph->GetName().c_str());
return SUCCESS;
}
@@ -105,12 +103,10 @@ Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute
op_kernel = DeploySchedulerKernel::Instance();
}
if (is_dynamic) {
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node),
"[MDS][DYNAMIC_CUTH] failed, node:[%s]",
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]",
node->GetName().c_str());
} else {
MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str());
}
}
return SUCCESS;
@@ -121,13 +117,11 @@ Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_gr
std::vector<NodePtr> output_nodes;
auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true),
"[MDS][CutNProcessImply] failed, graph_name:[%s]",
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
compute_graph0->GetName().c_str());
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true),
"[MDS][CutHProcessImply] failed, graph_name:[%s]",
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
compute_graph1->GetName().c_str());
//TODO:创建case节点,把两个图放在case的两个分支下,case节点添加到原来的compute_graph中,构造case节点的输入
// TODO: Create a case node, put the two graphs under its two branches, add the case node
// to the original compute_graph, and construct the inputs of the case node
return SUCCESS;
}
@@ -143,9 +137,8 @@ Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) {
Status ModelDeploySchedulerPass::SetDeployInfo() {
vector<GeAttrValue::NAMED_ATTRS> deployInfo;
REQUIRE (!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo),
"%s already has deployed before!",
GetGraphName());
REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo),
"%s already has deployed before!", GetGraphName());
std::multimap<DeviceId, GraphInputs> deploys;
for (int64_t j = 0; j < kDeployNumber; j++) {
int64_t device_id = j;
@@ -179,7 +172,6 @@ Status ModelDeploySchedulerPass::SMDPWeight() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::SMDPGradient() {
// TODO: identify the buffer pool id
return SUCCESS;
}
}
} // namespace ge

+ 9
- 6
ge/graph/passes/mds_pass.h View File

@@ -28,6 +28,7 @@ namespace ge {
class ModelDeploySchedulerPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;
private:
// Part0:Process Func
// cut and dynamic cut
@@ -47,22 +48,24 @@ class ModelDeploySchedulerPass : public GraphPass {
// set delpoyinfo
Status SetDeployInfo();
//Part1: Utils Func
// std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index);
// std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index);
// Part1: Utils Func
// std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index);
// std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index);
Status HcomNodeFusionProcess();
Status GetAllModelStateVar();
Status GetAllWeightVar();
std::vector<NodePtr> GetAllGradComputeNodes() { return grad_compute_nodes_; }
std::vector<NodePtr> GetAllGradComputeNodes() {
return grad_compute_nodes_;
}
const char *GetGraphName() const {
return compute_graph_->GetName().c_str();
}
//members
// members
std::vector<NodePtr> model_state_vars_;
std::vector<NodePtr> model_weight_vars_;
std::vector<NodePtr> grad_compute_nodes_;
ComputeGraphPtr compute_graph_ = nullptr;
};
} // namespace ge
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_

Loading…
Cancel
Save