Browse Source

Pre Merge pull request !2000 from gengchao/mds

pull/2000/MERGE
gengchao Gitee 3 years ago
parent
commit
7c8aa8cbc6
31 changed files with 1698 additions and 106 deletions
  1. +5
    -0
      ge/CMakeLists.txt
  2. +8
    -0
      ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc
  3. +2
    -0
      ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h
  4. +78
    -4
      ge/graph/execute/graph_execute.cc
  5. +8
    -1
      ge/graph/execute/graph_execute.h
  6. +18
    -8
      ge/graph/execute/model_executor.cc
  7. +1
    -2
      ge/graph/execute/model_executor.h
  8. +137
    -37
      ge/graph/load/graph_loader.cc
  9. +12
    -2
      ge/graph/load/graph_loader.h
  10. +8
    -10
      ge/graph/load/model_manager/davinci_model.cc
  11. +8
    -1
      ge/graph/load/model_manager/davinci_model.h
  12. +3
    -2
      ge/graph/load/model_manager/model_manager.cc
  13. +1
    -1
      ge/graph/load/model_manager/model_manager.h
  14. +33
    -1
      ge/graph/manager/graph_manager.cc
  15. +1
    -1
      ge/graph/manager/graph_manager.h
  16. +52
    -22
      ge/graph/manager/graph_mem_allocator.cc
  17. +9
    -4
      ge/graph/manager/graph_mem_allocator.h
  18. +7
    -7
      ge/graph/manager/graph_var_manager.cc
  19. +4
    -3
      ge/graph/manager/graph_var_manager.h
  20. +142
    -0
      ge/graph/passes/mds_kernels/base_mds_kernel.cc
  21. +76
    -0
      ge/graph/passes/mds_kernels/base_mds_kernel.h
  22. +30
    -0
      ge/graph/passes/mds_kernels/conv2d_mds_kernel.cc
  23. +29
    -0
      ge/graph/passes/mds_kernels/conv2d_mds_kernel.h
  24. +102
    -0
      ge/graph/passes/mds_kernels/mds_kernel_factory.h
  25. +476
    -0
      ge/graph/passes/mds_kernels/mds_utils.cc
  26. +130
    -0
      ge/graph/passes/mds_kernels/mds_utils.h
  27. +41
    -0
      ge/graph/passes/mds_kernels/variable_mds_kernel.cc
  28. +28
    -0
      ge/graph/passes/mds_kernels/variable_mds_kernel.h
  29. +177
    -0
      ge/graph/passes/mds_pass.cc
  30. +71
    -0
      ge/graph/passes/mds_pass.h
  31. +1
    -0
      inc/external/ge/ge_api_types.h

+ 5
- 0
ge/CMakeLists.txt View File

@@ -168,6 +168,11 @@ set(EXECUTOR_SRC_LIST
"graph/manager/util/debug.cc"
#"graph/manager/util/hcom_util.cc" # Just for runner.
"graph/passes/pass_utils.cc"
"graph/passes/mds_pass.cc"
"graph/passes/mds_kernels/mds_utils.cc"
"graph/passes/mds_kernels/variable_mds_kernel.cc"
"graph/passes/mds_kernels/conv2d_mds_kernel.cc"
"graph/passes/mds_kernels/base_mds_kernel.cc"
"host_kernels/add_kernel.cc"
"host_kernels/broadcast_args_kernel.cc"
"host_kernels/broadcast_gradient_args_kernel.cc"


+ 8
- 0
ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc View File

@@ -76,5 +76,13 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map<string, string> &sess
// Do nothing
return SUCCESS;
}
Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) {
// TODO:
// 1. Whether the variable type is identified as a trainable variable
// 2, whether to turn on smdp1 and 3
// To meet the above two points, set the current variable
// node to be tangent in the variable segmentation information
return SUCCESS;
}
} // namespace ge_local
} // namespace ge

+ 2
- 0
ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h View File

@@ -86,6 +86,8 @@ class GE_FUNC_VISIBILITY GeLocalOpsKernelInfoStore : public OpsKernelInfoStore {
*/
Status DestroySession(const std::map<std::string, std::string> &session_options) override;

Status SetCutSupportedInfo(const ge::NodePtr &node) override;

// Copy prohibited
GeLocalOpsKernelInfoStore(const GeLocalOpsKernelInfoStore &ops_kernel_store) = delete;



+ 78
- 4
ge/graph/execute/graph_execute.cc View File

@@ -22,8 +22,18 @@
#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#include "common/profiling/profiling_manager.h"
#include "graph/debug/ge_attr_define.h"
#include "common/thread_pool.h"

namespace ge {
namespace {
//deploy info
const char *const kAttrDeviceType = "_device_type";
const char *const kAttrDeviceId = "_device_id";
const char *const kAttrGraphName = "_graph_name";
const char *const kAttrGraphInputs = "_graph_inputs";
const char *const kAttrNeedReturnResult = "_need_return_result";
}
using Uint32Pair = pair<uint32_t, uint32_t>;
const uint32_t kInvalidModelId = UINT32_MAX;
GraphExecutor::GraphExecutor()
@@ -386,7 +396,14 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &
}
last_graph_id_ = graph_id;
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback);
vector<GeAttrValue::NAMED_ATTRS> deployInfo;
ModelIdInfo model_id_info;
Status ret;
if (ge::AttrUtils::GetListNamedAttrs(ge_root_model->GetRootGraph(), ATTR_NAME_DEPLOY_INFO, deployInfo)) {
ret = AsyncMultiExecuteModel(ge_root_model, input_tensor, callback);
} else {
ret = AsyncExecuteModel(ge_root_model, GetExecuteModelId(ge_root_model), input_tensor, callback);
}
if (ret != SUCCESS) {
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[AsyncExecute][Model] Error! graph id:%u", graph_id);
return GE_GRAPH_SYNC_MODEL_FAILED;
@@ -522,10 +539,67 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro
}
return SUCCESS;
}
Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs,
const RunAsyncCallback &callback) {
// get deploy number of model instance
auto root_graph = ge_root_model->GetRootGraph();
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) {
GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
return FAILED;
}
auto thread_instances_size = deploy_info.size();
auto model_ids = ge_root_model->GetAllModelId();
if (model_ids.size() != thread_instances_size) {
GELOGE(FAILED,
"[AsyncMultiExecuteModel] something wrong, attr deploy numbers %zu should be equal to loaded models %zu",
thread_instances_size, model_ids.size());
return FAILED;
}

ThreadPool executor(thread_instances_size);
std::vector<std::future<Status>> vector_future;
for (size_t i = 0; i < thread_instances_size; ++i) {
auto thread_instance = deploy_info[i];
std::vector<GeTensorPtr> graph_inputs;
if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) {
std::vector<ge::Tensor> graph_input_updated(inputs.begin(), inputs.end());
for (const auto &ge_tensor_ptr : graph_inputs) {
graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr));
}
GraphExecutor graph_executor;
ExecuteModelFunc execute_model_func(&GraphExecutor::AsyncExecuteModel);
std::future<Status> f;
bool need_return_result = false;
if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) {
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated,
callback);
} else {
RunAsyncCallback callback_stub;
f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated,
callback_stub);
}
if (!f.valid()) {
GELOGE(FAILED, "[Call][Commit] failed, Future is invalid");
return FAILED;
}
vector_future.emplace_back(std::move(f));
}
}
for (size_t i = 0; i < vector_future.size(); ++i) {
Status ret_status = vector_future[i].get();
if (ret_status != SUCCESS) {
REPORT_CALL_ERROR("E19999", " Execute multi model %zu failed", i);
GELOGE(ret_status, "[AsyncMultiExecuteModel] Execute multi model failed", i);
return ret_status;
}
}
return SUCCESS;
}

Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs,
const RunAsyncCallback &callback) {
uint32_t model_id = GetExecuteModelId(ge_root_model);
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id,
const std::vector<ge::Tensor> &inputs, const RunAsyncCallback &callback) {
if (model_id == kInvalidModelId) {
GELOGE(INTERNAL_ERROR, "No valid model id.");
return INTERNAL_ERROR;


+ 8
- 1
ge/graph/execute/graph_execute.h View File

@@ -136,8 +136,10 @@ class GraphExecutor {
Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
std::vector<GeTensor> &output_tensor);

Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &input_tensor,
Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector<ge::Tensor> &input_tensor,
const RunAsyncCallback &callback);
Status AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &input_tensor,
const RunAsyncCallback &callback);

void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec,
uint32_t output_size);
@@ -170,6 +172,11 @@ class GraphExecutor {
std::vector<void *> buffer_addr_;
std::vector<uint64_t> buffer_size_;
};
using ExecuteModelFunc = std::function<Status(GraphExecutor *,
const GeRootModelPtr &ge_root_model,
uint32_t model_id,
const std::vector<ge::Tensor> &inputs,
const RunAsyncCallback &callback)>;
} // namespace ge

#endif // GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_

+ 18
- 8
ge/graph/execute/model_executor.cc View File

@@ -325,34 +325,45 @@ Status ModelExecutor::RunGraphWithStream(const GraphNodePtr &graph_node, GraphId

Status ModelExecutor::ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream());
return ModelLoad(ge_root_model, graph_node, graph_run_listener_);
return ModelLoad(ge_root_model, graph_node, false);
}

Status ModelExecutor::ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) {
auto listener = MakeShared<RunAsyncListener>();
GE_CHECK_NOTNULL(listener);
return ModelLoad(ge_root_model, graph_node, listener);
return ModelLoad(ge_root_model, graph_node, true);
}

Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node,
const std::shared_ptr<ModelListener> &listener) {
bool is_async) {
ge_root_model->SetTrainFlag(train_graph_flag_);
bool is_unknown_shape = false;
GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape));
auto root_graph = ge_root_model->GetRootGraph();
if (!is_unknown_shape) {
if (getenv(kEnvGeuseStaticMemory) != nullptr) {
GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted.");
} else {
auto root_graph = ge_root_model->GetRootGraph();
GE_CHECK_NOTNULL(root_graph);
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr ge_model = name_to_model[root_graph->GetName()];
GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node));
}
}
std::shared_ptr<ModelListener> listener =
is_async ? std::dynamic_pointer_cast<ModelListener>(MakeShared<RunAsyncListener>()) : std::dynamic_pointer_cast<
ModelListener>(graph_run_listener_);
GE_TIMESTAMP_START(LoadModelOnline);
uint32_t model_id = INVALID_MODEL_ID;
Status ret = GraphLoader::LoadModelOnline(model_id, ge_root_model, listener);
vector<GeAttrValue::NAMED_ATTRS> deployInfo;
Status ret;
if (ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deployInfo)) {
ret = GraphLoader::LoadMultiModelOnline(ge_root_model, is_async);
} else {
ret = GraphLoader::LoadModelOnline(model_id,
ge_root_model,
listener,
GetContext().DeviceId(),
kInvalidDieId);
}
GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline");
if (ret != SUCCESS) {
GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id);
@@ -360,7 +371,6 @@ Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const Graph
return ret;
}
graph_node->SetLoadFlag(true);
ge_root_model->SetModelId(model_id);
graph_node->SetGeRootModel(ge_root_model);
AddGraphNode(graph_node->GetGraphId(), graph_node);
return SUCCESS;


+ 1
- 2
ge/graph/execute/model_executor.h View File

@@ -98,8 +98,7 @@ class ModelExecutor : public Executor {

Status ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
Status ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node);
Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node,
const std::shared_ptr<ModelListener> &listener);
Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, bool is_async);

Status UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id);



+ 137
- 37
ge/graph/load/graph_loader.cc View File

@@ -18,14 +18,24 @@

#include <string>
#include <vector>
#include <thread>

#include "framework/common/helper/model_helper.h"
#include "common/model_parser/model_parser.h"
#include "graph/ge_context.h"
#include "graph/load/model_manager/model_manager.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/debug/ge_attr_define.h"
#include "common/thread_pool.h"

namespace ge {
namespace {
//deploy info
const char *const kAttrDeviceType = "_device_type";
const char *const kAttrDeviceId = "_device_id";
const char *const kAttrGraphName = "_graph_name";
const char *const kAttrGraphInputs = "_graph_inputs";
}
Status GraphLoader::UnloadModel(uint32_t model_id) {
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
@@ -45,43 +55,81 @@ Status GraphLoader::UnloadModel(uint32_t model_id) {
return SUCCESS;
}

Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr,
const std::shared_ptr<ModelListener> &listener) {
GELOGI("Load model online begin.");
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
Status GraphLoader::SetDevice(uint32_t device_id, int64_t die_id) {
if (device_id != kInvalidDeviceId && die_id != kInvalidDieId) {
rtError_t rt_ret = rtSetDevice(device_id, kMultiMode);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret);
GELOGE(RT_FAILED, "[Call][rtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
return RT_FAILED;
}
rt_ret = rtSetDieId(die_id);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDieId failed, device_id:%u, ret:0x%X", die_id, rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] rtSetDieId, device_id:%u, ret:0x%X", die_id, rt_ret);
return RT_FAILED;
}
} else if (device_id != kInvalidDeviceId && die_id == kInvalidDieId) {
rtError_t rt_ret = rtSetDevice(device_id);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
return RT_FAILED;
}
} else {
REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u, die_id:%ld", device_id, die_id);
GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u, die_id:%ld", device_id, die_id);
return RT_FAILED;
}
return SUCCESS;
}

Status GraphLoader::ResetDevice(uint32_t device_id, int64_t die_id) {
if (die_id != kInvalidDieId) {
rtError_t rt_ret = rtDieReset(die_id);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", die_id, rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", die_id, rt_ret);
return RT_FAILED;
}
} else {
rtError_t rt_ret = rtDeviceReset(device_id);
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret);
GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret);
return RT_FAILED;
}
}
return SUCCESS;
}

Status GraphLoader::LoadModelOnline(uint32_t &model_id,
const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr,
const std::shared_ptr<ModelListener> &listener,
uint32_t device_id,
int64_t die_id) {
GELOGI("Load model online begin.");
if (ge_root_model_ptr == nullptr) {
REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid");
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph][Check][Param] GE load graph model_ptr is nullptr.");
return GE_GRAPH_PARAM_NULLPTR;
}

if (SetDevice(device_id, die_id) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u", device_id);
GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u", device_id);
return RT_FAILED;
}
GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(ResetDevice(device_id, die_id)); });
auto model_manager = ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener);
Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener,device_id, die_id);
if (ret != SUCCESS) {
GELOGE(ret, "[Load][Model] Online failed. ret = %u, model_id:%u", ret, model_id);
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
}
return ret;
}

ge_root_model_ptr->SetModelId(model_id);
if (ge_root_model_ptr->IsSpecificStream()) {
GELOGI("No need to start a new thread to run model in specific scene.");
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
}
return SUCCESS;
}
ret = model_manager->Start(model_id);
@@ -89,25 +137,77 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge
if (model_manager->Unload(model_id) != SUCCESS) {
GELOGE(ret, "[Unload][Model] failed while trying to unload after a failed start, model_id:%u.", model_id);
}

rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
}

GELOGE(ret, "[Start][Model] failed, model_id:%u.", model_id);
return ret;
}
rt_ret = rtDeviceReset(GetContext().DeviceId());
if (rt_ret != RT_ERROR_NONE) {
REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
GetContext().DeviceId(), rt_ret);
GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
return RT_FAILED;
}
GELOGI("Load model online success, model_id:%u.", model_id);
return SUCCESS;
}

Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> &ge_root_model, bool is_async) {
// get deploy number of model instance
auto root_graph = ge_root_model->GetRootGraph();
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) {
GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s",
root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str());
return FAILED;
}
auto thread_instances_size = deploy_info.size();
auto device_id_fission_from = GetContext().DeviceId();
GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(),
thread_instances_size, device_id_fission_from);
ThreadPool executor(thread_instances_size);
std::vector<std::future<Status>> vector_future;
GE_TIMESTAMP_START(LoadModelOnline);
for (size_t i = 0; i < thread_instances_size; ++i) {
auto thread_instance = deploy_info[i];
std::string device_type;
ModelIdInfo model_id_info;
std::shared_ptr<ModelListener> listener;
if (is_async) {
listener = MakeShared<RunAsyncListener>();
GE_CHECK_NOTNULL(listener);
} else {
// TODO: GraphModelListener for sync
}
int64_t device_id_fissioned = kInvalidDieId;
if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) ||
device_id_fissioned == kInvalidDieId) {
REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(),
ATTR_NAME_DEPLOY_INFO.c_str());
return GRAPH_FAILED;
};
if (ge::AttrUtils::GetStr(thread_instance, kAttrDeviceType, device_type) && device_type == kMultiMode) {
std::future<Status> f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model,
listener, device_id_fission_from, device_id_fissioned);
if (!f.valid()) {
GELOGE(FAILED, "[Call][Commit] failed, Future is invalid");
return FAILED;
}
vector_future.emplace_back(std::move(f));
} else {
std::future<Status> f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model,
listener, device_id_fissioned, kInvalidDieId);
if (!f.valid()) {
GELOGE(FAILED, "[Call][Commit] failed, Future is invalid");
return FAILED;
}
vector_future.emplace_back(std::move(f));
}
}
GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline");

for (size_t i = 0; i < vector_future.size(); ++i) {
Status ret_status = vector_future[i].get();
if (ret_status != SUCCESS) {
REPORT_CALL_ERROR("E19999", " Load multi model %zu failed", i);
GELOGE(ret_status, "[LoadMultiModelOnline] Load multi model failed", i);
return ret_status;
}
}

return SUCCESS;
}


+ 12
- 2
ge/graph/load/graph_loader.h View File

@@ -30,6 +30,13 @@
#include "runtime/mem.h"

namespace ge {
namespace {
const int64_t kInvalidDieId = -1;
const uint32_t kInvalidDeviceId = UINT32_MAX;
const char* kMultiMode ="MultiMode";
const char* kSingleMode ="SingleMode";
}

class GraphLoader {
public:
GraphLoader() = default;
@@ -64,9 +71,12 @@ class GraphLoader {
static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id);

static Status DestroyAicpuSessionForInfer(uint32_t model_id);

static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model,
const std::shared_ptr<ModelListener> &listener);
const std::shared_ptr<ModelListener> &listener, uint32_t device_id,
int64_t die_id = kInvalidDieId);
static Status SetDevice(uint32_t device_id, int64_t die_id);
static Status ResetDevice(uint32_t device_id, int64_t die_id);
static Status LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr, bool is_async);
};
} // namespace ge
#endif // GE_GRAPH_LOAD_GRAPH_LOADER_H_

+ 8
- 10
ge/graph/load/model_manager/davinci_model.cc View File

@@ -444,16 +444,16 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) {

Status DavinciModel::InitVariableMem() {
// malloc variable memory base
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM);
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId());
if (TotalVarMemSize() && (var_mem_base_ == nullptr)) {
Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize());
Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize(), GetDeviceId());
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999", "MallocVarMemory fail, var_size:%zu, model_id:%u, check invalid",
TotalVarMemSize(), model_id_);
GELOGE(ret, "[Malloc][VarMemory] failed, var_size:%zu, model_id:%u", TotalVarMemSize(), model_id_);
return ret;
}
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM);
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId());
GEEVENT("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id,
var_mem_base_, TotalVarMemSize());
}
@@ -2819,18 +2819,16 @@ void *DavinciModel::Run(DavinciModel *model) {
bool seq_end_flag = false;
uint32_t model_id = model->Id();
uint32_t device_id = model->GetDeviceId();
int64_t die_id = model->GetDieId();
ErrorManager::GetInstance().SetErrorContext(model->GetErrorContext());

GELOGI("Model Run thread start, model_id:%u.", model_id);
rtError_t rt_ret = rtSetDevice(static_cast<int32_t>(device_id));
if (rt_ret != RT_ERROR_NONE) {

GELOGE(FAILED, "[Run][Rtsetdevice] failed, model_id:%u, device_id:%u.", model_id, device_id);
if (GraphLoader::SetDevice(device_id, die_id) != SUCCESS) {
GELOGE(FAILED, "[Run][Setdevice] failed, model_id:%u, device_id:%u die_id%ld.", model_id, device_id, die_id);
return nullptr;
}
// DeviceReset before thread run finished!
GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); });

GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(GraphLoader::ResetDevice(device_id, model->GetDieId())); });
ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute);
while (model->RunFlag()) {
// Model hasn't truly started runing before received data
@@ -2886,7 +2884,7 @@ void *DavinciModel::Run(DavinciModel *model) {
GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START));
GE_TIMESTAMP_START(rtModelExecute);
GELOGI("rtModelExecute start.");
rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0);
auto rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0);
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false;
(void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput());
continue);


+ 8
- 1
ge/graph/load/model_manager/davinci_model.h View File

@@ -59,7 +59,9 @@ namespace ge {
// op debug need 2048 bits buffer
const size_t kOpDebugMemorySize = 2048UL;
const size_t kDebugP2pSize = 8UL;
const size_t kDebugP2pSize = 8UL;

const int64_t kInvalidDieId = -1;
typedef enum tagModelProcStage {
MODEL_LOAD_START = 1,
MODEL_LOAD_END,
@@ -441,13 +443,17 @@ class DavinciModel {
/// @return void
///
void SetDeviceId(uint32_t device_id) { device_id_ = device_id; }
void SetDieId(int64_t die_id) { die_id_ = die_id; }

///
/// @ingroup ge
/// @brief Get device Id
/// @return device id
///
uint32_t GetDeviceId() const { return device_id_; }
uint32_t GetDeviceId() const {
return die_id_ == kInvalidDieId ? device_id_ : die_id_;
}
int64_t GetDieId() const { return die_id_; }

bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; }

@@ -1010,6 +1016,7 @@ class DavinciModel {
struct error_message::Context error_context_;

uint32_t device_id_;
int64_t die_id_ = kInvalidDieId;

mutex flowctrl_op_index_internal_map_mutex_;
map<uint32_t, uint32_t> flowctrl_op_index_internal_map_;


+ 3
- 2
ge/graph/load/model_manager/model_manager.cc View File

@@ -324,7 +324,7 @@ bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) {
/// @return Status run result
///
Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::GeRootModel> &ge_root_model,
std::shared_ptr<ModelListener> listener) {
std::shared_ptr<ModelListener> listener, uint32_t &device_id, int64_t die_id) {
GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "[Check][Param] Param incorrect, listener is null");
if (model_id == INVALID_MODEL_ID) {
GenModelId(&model_id);
@@ -342,7 +342,8 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano +
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond
davinci_model->SetId(model_id);
davinci_model->SetDeviceId(GetContext().DeviceId());
davinci_model->SetDeviceId(device_id);
davinci_model->SetDieId(die_id);

auto root_graph = ge_root_model->GetRootGraph();
GE_CHECK_NOTNULL(root_graph);


+ 1
- 1
ge/graph/load/model_manager/model_manager.h View File

@@ -71,7 +71,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
/// @author @
///
ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model,
std::shared_ptr<ModelListener> listener);
std::shared_ptr<ModelListener> listener,uint32_t &device_id, int64_t die_id);

ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name,
const shared_ptr<ge::GeRootModel> &ge_root_model,


+ 33
- 1
ge/graph/manager/graph_manager.cc View File

@@ -98,6 +98,7 @@
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#include "graph/passes/parallel_group_pass.h"
#include "graph/passes/buffer_pool_memory_pass.h"
#include "graph/passes/mds_pass.h"
#include "graph/build/label_allocator.h"
#include "graph/utils/tensor_adapter.h"
#include "inc/pass_manager.h"
@@ -110,6 +111,7 @@
#include "external/graph/types.h"
#include "common/util/error_manager/error_manager.h"
#include "common/profiling/profiling_manager.h"
#include "graph/debug/ge_attr_define.h"

namespace {
const char *const kSummary = "Summary";
@@ -1087,7 +1089,6 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN
if (!options_.run_graph_flag) {
return SUCCESS;
}

ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad);
GE_CHECK_NOTNULL(executor_);
return executor_->LoadGraph(ge_root_model, graph_node);
@@ -2816,9 +2817,40 @@ const map<std::string, std::string> *GraphManager::GetGraphOptions(uint32_t grap
}

void GraphManager::SetOptionsRunGraphFlag(bool run_graph_flag) { options_.run_graph_flag = run_graph_flag; }
Status GraphManager::SetNodeCutInfo(ComputeGraphPtr &compute_graph) {
auto instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
REPORT_INNER_ERROR("E19999", "GeLib is not init before, check invalid");
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized");
return FAILED;
}
for (const auto &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
auto kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName();
OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name);
if (kernel_info == nullptr) {
REPORT_INNER_ERROR("E19999", "Find ops kernel by name:%s failed",
kernel_lib_name.c_str());
GELOGE(FAILED, "[Get][OpsKernelInfoStore] by name:%s failed", kernel_lib_name.c_str());
return FAILED;
}
GE_CHK_STATUS_RET(kernel_info->SetCutSupportedInfo(node));
}
}

Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph,
uint64_t session_id) {
GE_TIMESTAMP_START(MDS);
// Set the cut support information based on the engine of the node
EnginePlacer engine_placer;
engine_placer.SetComputeGraph(compute_graph);
GE_CHK_STATUS_RET(engine_placer.Run());
GE_CHK_STATUS_RET(SetNodeCutInfo(compute_graph));
// mds pass
PassManager graph_pass;
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeSubgraph::MDS", new (std::nothrow) ModelDeploySchedulerPass))
GE_CHK_STATUS_RET(graph_pass.Run(compute_graph));
GE_TIMESTAMP_EVENT_END(MDS, "OptimizeSubgraph::MDS");
// graph partition
// Stage partition, only for root graph
GE_TIMESTAMP_START(StagePartition);


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -242,7 +242,7 @@ class GraphManager {
uint64_t session_id = INVALID_SESSION_ID);

Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id);
Status SetNodeCutInfo (ComputeGraphPtr &compute_graph);
Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph,
GeRootModelPtr &ge_root_model, uint64_t session_id);



+ 52
- 22
ge/graph/manager/graph_mem_allocator.cc View File

@@ -23,12 +23,16 @@ Status MemoryAllocator::Initialize(uint32_t device_id) {
GELOGI("MemoryAllocator::Initialize");

// when redo Initialize free memory
for (auto &it : memory_base_map_) {
if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) {
GELOGW("Initialize: FreeMemory failed");
for (auto &it_map : deviceid_2_memory_bases_map_) {

for (auto &it : it_map.second) {
if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) {
GELOGW("Initialize: FreeMemory failed");
}
}
it_map.second.clear();
}
memory_base_map_.clear();
deviceid_2_memory_bases_map_.clear();
return SUCCESS;
}

@@ -36,12 +40,16 @@ void MemoryAllocator::Finalize(uint32_t device_id) {
GELOGI("MemoryAllocator::Finalize");

// free memory
for (auto &it : memory_base_map_) {
if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) {
GELOGW("Finalize: FreeMemory failed");
for (auto &it_map : deviceid_2_memory_bases_map_) {

for (auto &it : it_map.second) {
if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) {
GELOGW("Finalize: FreeMemory failed");
}
}
it_map.second.clear();
}
memory_base_map_.clear();
deviceid_2_memory_bases_map_.clear();
}

uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const {
@@ -75,12 +83,16 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con

uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size,
uint32_t device_id) {
auto it = memory_base_map_.find(memory_key);
if (it != memory_base_map_.end()) {
it->second.memory_used_num_++;
return it->second.memory_addr_;
map<string, MemoryInfo> memory_base_map;
auto it_map = deviceid_2_memory_bases_map_.find(device_id);
if (it_map != deviceid_2_memory_bases_map_.end()) {
memory_base_map = it_map->second;
auto it = it_map->second.find(memory_key);
if (it != it_map->second.end()) {
it->second.memory_used_num_++;
return it->second.memory_addr_;
}
}

uint8_t *memory_addr = MallocMemory(purpose, memory_size, device_id);

if (memory_addr == nullptr) {
@@ -91,16 +103,27 @@ uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memo
return nullptr;
}

MemoryInfo memory_info(memory_addr, memory_size);
MemoryInfo memory_info(memory_addr, memory_size, device_id);
memory_info.memory_used_num_++;
memory_base_map_[memory_key] = memory_info;
memory_base_map[memory_key] = memory_info;
deviceid_2_memory_bases_map_[device_id] = memory_base_map;
mem_malloced_ = true;
return memory_addr;
}

Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) {
auto it = memory_base_map_.find(memory_key);
if (it == memory_base_map_.end()) {
auto it_map = deviceid_2_memory_bases_map_.find(device_id);
if (it_map == deviceid_2_memory_bases_map_.end()){
if (mem_malloced_) {
GELOGW(
"MemoryAllocator::FreeMemory failed,"
" memory_key[%s] was not exist, device_id = %u.",
memory_key.c_str(), device_id);
}
return ge::INTERNAL_ERROR;
}
auto it = it_map->second.find(memory_key);
if (it == it_map->second.end()) {
if (mem_malloced_) {
GELOGW(
"MemoryAllocator::FreeMemory failed,"
@@ -109,7 +132,6 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id)
}
return ge::INTERNAL_ERROR;
}

if (it->second.memory_used_num_ > 1) {
GELOGW("MemoryAllocator::FreeMemory memory_key[%s] should not be released, reference count %d", memory_key.c_str(),
it->second.memory_used_num_);
@@ -129,20 +151,28 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id)

GELOGI("MemoryAllocator::FreeMemory device_id = %u", device_id);

memory_base_map_.erase(it);
it_map->second.erase(it);
return ge::SUCCESS;
}

uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t device_id) {
auto it = memory_base_map_.find(memory_key);
if (it == memory_base_map_.end()) {
auto it_map = deviceid_2_memory_bases_map_.find(device_id);
if (it_map == deviceid_2_memory_bases_map_.end()) {

GELOGW(
"MemoryAllocator::GetMemoryAddr failed,"
" memory_key[%s] was not exist, device_id = %u.",
memory_key.c_str(), device_id);
return nullptr;
}
auto it = it_map->second.find(memory_key);
if (it == it_map->second.end()) {
GELOGW(
"MemoryAllocator::GetMemoryAddr failed,"
" memory_key[%s] was not exist, device_id = %u.",
memory_key.c_str(), device_id);
return nullptr;
}

return it->second.memory_addr_;
}
} // namespace ge

+ 9
- 4
ge/graph/manager/graph_mem_allocator.h View File

@@ -32,10 +32,13 @@
namespace ge {
class MemoryInfo {
public:
MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0) {}
MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0), device_id_(0) {}

MemoryInfo(uint8_t *memory_addr, size_t memory_size)
: memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0) {}
: memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0), device_id_(0) {}

MemoryInfo(uint8_t *memory_addr, size_t memory_size, uint32_t device_id)
: memory_addr_(memory_addr), memory_size_(memory_size), device_id_(device_id), memory_used_num_(0) {}

MemoryInfo &operator=(const MemoryInfo &op) {
if (&op == this) {
@@ -44,7 +47,7 @@ class MemoryInfo {

this->memory_addr_ = op.memory_addr_;
this->memory_size_ = op.memory_size_;
this->memory_used_num_ = op.memory_used_num_;
this->device_id_ = op.device_id_;
return *this;
}

@@ -52,12 +55,14 @@ class MemoryInfo {
this->memory_addr_ = op.memory_addr_;
this->memory_size_ = op.memory_size_;
this->memory_used_num_ = op.memory_used_num_;
this->device_id_ = op.device_id_;
}
virtual ~MemoryInfo() = default;

uint8_t *memory_addr_;
uint64_t memory_size_;
int32_t memory_used_num_;
uint32_t device_id_;
};

class MemoryAllocator {
@@ -133,7 +138,7 @@ class MemoryAllocator {
private:
rtMemType_t memory_type_;
bool mem_malloced_;
map<string, MemoryInfo> memory_base_map_;
map<uint32_t, map<string, MemoryInfo>> deviceid_2_memory_bases_map_;
};
} // namespace ge



+ 7
- 7
ge/graph/manager/graph_var_manager.cc View File

@@ -348,7 +348,7 @@ ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id,
device_id_ = device_id;
session_id_ = session_id;
job_id_ = job_id;
var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
var_resource_ = std::unique_ptr<VarResource>(new(std::nothrow) VarResource(session_id_));
if (var_resource_ == nullptr) {
GELOGW("VarManager init failed session id = %lu.", session_id);
return ge::INTERNAL_ERROR;
@@ -637,7 +637,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
return var_resource_->GetVarMemType(offset);
}

ge::Status VarManager::MallocVarMemory(size_t memory_size) {
ge::Status VarManager::MallocVarMemory(size_t memory_size, uint32_t device_id) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
uint8_t *var_mem_base = nullptr;
string memory_key = std::to_string(session_id_);
@@ -649,7 +649,7 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) {
var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;

const string purpose("variables and constant op memory in training network.");
var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size);
var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size, device_id);
if (var_mem_base == nullptr) {
GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s",
var_memory_size, memory_key.c_str());
@@ -658,22 +658,22 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) {
return SUCCESS;
}

uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (memory_type == RT_MEMORY_RDMA_HBM) {
return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
}
string memory_key = std::to_string(session_id_);
return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key);
return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key, device_id);
}

uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
if (memory_type == RT_MEMORY_RDMA_HBM) {
return logic_addr;
}
string mem_key = std::to_string(session_id_);
uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key);
uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key, device_id);
if (mem_base == nullptr) {
return nullptr;
}


+ 4
- 3
ge/graph/manager/graph_var_manager.h View File

@@ -231,7 +231,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {

ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc);

ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize);
ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize, uint32_t device_id = 0);

ge::Status FreeVarMemory();

@@ -277,9 +277,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {

rtMemType_t GetVarMemType(const int64_t &offset);

uint8_t *GetVarMemoryBase(rtMemType_t memory_type);
uint8_t *GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id = 0);

uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type);
uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id = 0);

Status GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables);

@@ -293,6 +293,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
size_t var_mem_logic_base_;
size_t use_max_mem_size_;
std::unique_ptr<ge::VarResource> var_resource_;
// map<uint32_t , std::shared_ptr<ge::VarResource>> var_resource_map_;
map<rtMemType_t, MemResource *> mem_resource_map_;
mutable std::recursive_mutex mutex_;



+ 142
- 0
ge/graph/passes/mds_kernels/base_mds_kernel.cc View File

@@ -0,0 +1,142 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./base_mds_kernel.h"
namespace ge {
namespace mds_cut_pass {
shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) {
if (node == nullptr) {
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
GELOGE(FAILED, "[Check][Param] parameter node is nullptr.");
return nullptr;
}
KernelFactory &factory = KernelFactory::Instance();
string type = node->GetType();
if (type == FRAMEWORKOP) {
if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) {
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
node->GetName().c_str(), node->GetType().c_str());
return nullptr;
}
}
return factory.Create(type);
}
} // namespace mds_cut_pass
shared_ptr<DeploySchedulerKernel> DeploySchedulerKernel::Instance() {
static const std::shared_ptr<DeploySchedulerKernel> instance_ptr =
shared_ptr<DeploySchedulerKernel>(new (std::nothrow) DeploySchedulerKernel());
return instance_ptr;
}
Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_anchor : node->GetAllInDataAnchors()) {
GE_CHECK_NOTNULL(in_anchor);
auto src_anchor = in_anchor->GetPeerOutAnchor();
if (src_anchor == nullptr) {
continue;
}
auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx());
auto src_node = src_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto src_op_desc = src_node->GetOpDesc();
auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx());
GE_CHECK_NOTNULL(src_tensor_desc);
// peer out shape is cutted already
if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutN)) {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
tensor_desc->SetShape(src_tensor_desc->GetShape());
} else {
MDS_REQUIRE_SUCCESS(
MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
} else {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) {
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
"[CutN] failed to slice between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
tensor_desc->SetShape(src_tensor_desc->GetShape());
}
}
// insert hcomallreduce for cutn
bool is_grad_compute_node = false;
if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
is_grad_compute_node) {
MDS_REQUIRE_SUCCESS(
MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
}
// call infer shape, update output shape
MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutN] %s call infershape failed", node->GetName().c_str());
return SUCCESS;
}
Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (auto &in_anchor : node->GetAllInDataAnchors()) {
GE_CHECK_NOTNULL(in_anchor);
auto src_anchor = in_anchor->GetPeerOutAnchor();
if (src_anchor == nullptr) {
continue;
}
auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx());
auto src_node = src_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto src_op_desc = src_node->GetOpDesc();
auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx());
GE_CHECK_NOTNULL(src_tensor_desc);
// peer out shape is cutted already
if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutH)) {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()),
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
MDS_REQUIRE_SUCCESS(
MdsUtils::DataGather(src_anchor, in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx());
}
} else {
if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) {
MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_),
"[CutH] failed to slice between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
} else {
MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true),
"[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]",
src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(),
in_anchor->GetIdx());
}
}
}
// call infer shape, update output shape
MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] call infer shape failed", node->GetName().c_str());
return SUCCESS;
}
} // namespace ge

+ 76
- 0
ge/graph/passes/mds_kernels/base_mds_kernel.h View File

@@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
#include <vector>
#include "common/op/ge_op_utils.h"
#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
#include "graph/debug/ge_op_types.h"
#include "framework/common/types.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/shape_refiner.h"
#include "../pass_utils.h"
#include "./mds_utils.h"
#include "./mds_kernel_factory.h"
namespace ge {
class DeploySchedulerKernel {
public:
static shared_ptr<DeploySchedulerKernel> Instance();
/// CutN imply
/// @param [in] node_ptr
virtual Status CutN(const ge::NodePtr &node_ptr);
/// CutH imply
/// @param [in] node_ptr
virtual Status CutH(const ge::NodePtr &node_ptr);
/// DynamicCutN imply
/// @param [in] node_ptr
virtual Status DynamicCutN(const ge::NodePtr &node_ptr);
/// DynamicCutH imply
/// @param [in] node_ptr
virtual Status DynamicCutH(const ge::NodePtr &node_ptr);
// halo exchange process
Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false);
NodePtr GetInputNode() {
return input_node_;
}
DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete;
DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete;
protected:
DeploySchedulerKernel() = default;
virtual ~DeploySchedulerKernel() = default;
private:
NodePtr input_node_ = nullptr;
};
namespace mds_cut_pass {
shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node);
}
} // namespace ge
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_

+ 30
- 0
ge/graph/passes/mds_kernels/conv2d_mds_kernel.cc View File

@@ -0,0 +1,30 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "conv2d_mds_kernel.h"
#include "mds_kernel_factory.h"
namespace ge {
Status Conv2dDeploySchedulerKernel::CutN(const ge::NodePtr node_ptr) {
return DeploySchedulerKernel::CutN(node_ptr);
}
Status Conv2dDeploySchedulerKernel::CutH(const ge::NodePtr node_ptr) {
return DeploySchedulerKernel::CutH(node_ptr);
}
REGISTER_MDS_KERNEL(CONV2D, Conv2dDeploySchedulerKernel);
} // namespace ge

+ 29
- 0
ge/graph/passes/mds_kernels/conv2d_mds_kernel.h View File

@@ -0,0 +1,29 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_
#include "base_mds_kernel.h"
namespace ge {
class Conv2dDeploySchedulerKernel : public DeploySchedulerKernel {
public:
Status CutN(const ge::NodePtr& node_ptr) override;
Status CutH(const ge::NodePtr& node_ptr) override;
};
} // namespace ge
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_

+ 102
- 0
ge/graph/passes/mds_kernels/mds_kernel_factory.h View File

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include "common/ge/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/graph.h"
using std::string;
namespace ge {
class DeploySchedulerKernel;
///
/// @brief DeploySchedulerKernel create factory
///
class KernelFactory {
public:
// KernelCreator(function), type definition
using KERNEL_CREATOR_FUN = std::function<std::shared_ptr<DeploySchedulerKernel>(void)>;
///
/// Get singleton instance
///
static KernelFactory &Instance() {
static KernelFactory instance;
return instance;
}
///
/// create DeploySchedulerKernel
/// @param [in] op_type operation type
///
std::shared_ptr<DeploySchedulerKernel> Create(const std::string &op_type) {
std::map<std::string, KERNEL_CREATOR_FUN>::iterator iter = creator_map_.find(op_type);
if (iter != creator_map_.end()) {
return iter->second();
}
return nullptr;
}
// DeploySchedulerKernel registration function to register different types of DeploySchedulerKernel to the factory
class Register {
public:
///
/// @brief Constructor
/// @param [in] type operation type
/// @param [in| fun DeploySchedulerKernel function of the operation
///
Register(const string &type, const KERNEL_CREATOR_FUN &fun) {
KernelFactory::Instance().RegisterCreator(type, fun);
}
~Register() = default;
};
protected:
KernelFactory() = default;
~KernelFactory() = default;
// register creator, this function will call in the constructor
void RegisterCreator(const string &type, const KERNEL_CREATOR_FUN &fun) {
std::map<std::string, KERNEL_CREATOR_FUN>::iterator iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
GELOGW("KernelFactory::RegisterCreator: %s creator already exist", type.c_str());
return;
}
creator_map_[type] = fun;
}
private:
std::map<std::string, KERNEL_CREATOR_FUN> creator_map_{};
};
#define REGISTER_MDS_KERNEL(type, clazz) \
std::shared_ptr<DeploySchedulerKernel> Creator_##type##_Kernel() { \
std::shared_ptr<clazz> ptr = nullptr; \
ptr = MakeShared<clazz>(); \
return ptr; \
} \
KernelFactory::Register g_##type##_Kernel_Creator(type, Creator_##type##_Kernel)
} // namespace ge
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_

+ 476
- 0
ge/graph/passes/mds_kernels/mds_utils.cc View File

@@ -0,0 +1,476 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./mds_utils.h"
namespace ge {
namespace {
// for count
thread_local int64_t data_slice_count = 0;
thread_local int64_t data_gather_count = 0;
thread_local int64_t data_reduce_count = 0;
const std::string kPrefix = "mds";
} // namespace
int64_t MdsUtils::GetNLocation(Format fmt) {
int64_t loc = kNInvalidLocation;
switch (fmt) {
case FORMAT_NCHW:
case FORMAT_NHWC:
loc = kNLocation0;
break;
case FORMAT_CHWN:
case FORMAT_HWCN:
loc = kNLocation3;
break;
default:
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
}
return loc;
}
int64_t MdsUtils::GetHLocation(Format fmt) {
int64_t loc = kHInvalidLocation;
switch (fmt) {
case FORMAT_HWCN:
loc = kHLocation0;
break;
case FORMAT_NHWC:
case FORMAT_CHWN:
loc = kHLocation1;
break;
case FORMAT_NCHW:
loc = kHLocation2;
default:
GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str());
}
return loc;
}
int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
Format fmt = ge_tensor_desc->GetFormat();
switch (type) {
case kCutN:
return GetNLocation(fmt);
case kCutH:
return GetHLocation(fmt);
default:;
}
GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
return kInvalidIndex;
}
bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type) {
if (ge_tensor_desc == nullptr) {
REPORT_INNER_ERROR("E19999", "invalid input param: tensor is null!");
GELOGE(FAILED, "[MDS]invalid input param: tensor is null!");
return false;
}
if (type != kCutN && type != kCutH) {
REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type);
GELOGE(FAILED, "[MDS]invalid CutType:%d", type);
return false;
}
int64_t cut_index = GetIndexByFormat(ge_tensor_desc, type);
if (cut_index == kInvalidIndex) {
REPORT_INNER_ERROR("E19999", "invalid index param:%ld", cut_index);
GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index);
return false;
}
auto dims = ge_tensor_desc->GetShape().GetDims();
if (cut_index < 0 || cut_index >= dims.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type,
dims.size());
return false;
}
if (dims[cut_index] % kDeployNumber != 0) {
GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]);
return false;
}
vector<int64_t> cut_support_info;
if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) {
REPORT_INNER_ERROR("E19999", "call GetlistInt failed");
GELOGE(FAILED, "[MDS]", "call GetlistInt failed");
return false;
}
if (cut_index < 0 || cut_index >= cut_support_info.size()) {
REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index,
type, cut_support_info.size());
GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index,
type, cut_support_info.size());
return false;
}
if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) {
REPORT_INNER_ERROR("E19999", "invalid cut info value:%ld", cut_support_info[cut_index]);
GELOGE(FAILED, "[MDS]", "invalid cut info value:%ld", cut_support_info[cut_index]);
return false;
}
return cut_support_info[cut_index] & kSplitCutSupported;
}
Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) {
GE_CHECK_NOTNULL(ge_tensor_desc);
auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type);
auto dims = ge_tensor_desc->GetShape().GetDims();
REQUIRE(index < dims.size(), "[DistributedDeploy] failed, index %ld should less than %zu", index, dims.size());
auto dim_after_deploy = dims[index] / deploy_number;
MDS_REQUIRE_SUCCESS(ge_tensor_desc->MutableShape().SetDim(index, dim_after_deploy),
"[DistributedDeploy] update shape failed");
return SUCCESS;
}
Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) {
GE_CHECK_NOTNULL(hcom_op);
REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor,
kDefaultFissionFactor);
REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor),
"Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(),
hcom_op->GetType().c_str());
if (!group_name.empty()) {
REQUIRE(ge::AttrUtils::SetStr(hcom_op, HCOM_ATTR_GROUP, group_name), "Failed to set attr group %s for op:%s(%s)",
group_name.c_str(), hcom_op->GetName().c_str(), hcom_op->GetType().c_str());
}
return SUCCESS;
}
bool MdsUtils::IsMDSNeeded() {
std::string device_type;
if (ge::GetContext().GetOption(ge::OPTION_DEVICE_TYPE, device_type) && device_type == kDefaultDeviceType) {
GELOGI("[MDS]device type is %s, skip mds", device_type.c_str());
return false;
}
// TODO: Parse the configuration file of the system to get the sys_config_exe_unit
std::string sys_config_exe_unit = "DIE";
return device_type != sys_config_exe_unit;
}
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node) {
GE_CHECK_NOTNULL(compute_graph);
GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
// build deploy info
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
GE_CHECK_NOTNULL(input_node);
for (int64_t j = 0; j < kDeployNumber; j++) {
int64_t device_id = j;
GeAttrValue::LIST_TENSOR graph_inputs;
GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0));
vector<uint8_t> data{static_cast<uint8_t>(device_id)};
graph_input->SetData(data);
// For now, only one graph_input
graph_inputs.push_back(graph_input);
GeAttrValue::NAMED_ATTRS thread_instance;
thread_instance.SetName(std::to_string(device_id));
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
// TODO:Change to enumeration from RTS header file
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode"));
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs));
deploy_info.emplace_back(thread_instance);
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id);
}
// set deploy info
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
"Set attr failed for graph %s", compute_graph->GetName().c_str());
return SUCCESS;
}
CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) {
bool is_unknown_graph = false;
if (GraphUtils::IsUnknownShapeGraph(compute_graph)) {
GELOGI("Graph %s is unknown shape graph", compute_graph->GetName().c_str());
is_unknown_graph = true;
}
CutType selected_cut_type = kNoCut;
for (const auto &data : compute_graph->GetInputNodes()) {
GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str());
auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN);
auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index);
auto data_h_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutH);
auto data_h_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_h_index);
if (data_n_dim == -1 && data_h_dim == -1) {
selected_cut_type = kDynamicCutAll;
break;
}
if (data_n_dim % kDeployNumber == 0) {
is_unknown_graph ? selected_cut_type = kDynamicCutN : selected_cut_type = kCutN;
break;
}
if (data_h_dim % kDeployNumber == 0) {
is_unknown_graph ? selected_cut_type = kDynamicCutH : selected_cut_type = kCutH;
}
}
return selected_cut_type;
}
Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph,
const std::multimap<DeviceId, GraphInputs> &deploys, const std::string &device_type) {
GE_CHECK_NOTNULL(compute_graph);
GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str());
// build deploy info
vector<GeAttrValue::NAMED_ATTRS> deploy_info;
for (const auto &pair : deploys) {
int64_t device_id = pair.first;
GeAttrValue::NAMED_ATTRS thread_instance;
thread_instance.SetName(std::to_string(device_id));
(void)thread_instance.SetAttr(kAttrNeedReturnResult,
GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty() ? true : false));
(void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id));
(void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type));
(void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName()));
(void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second));
deploy_info.emplace_back(thread_instance);
GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id);
}
// set deploy info
REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info),
"Set attr failed for graph %s", compute_graph->GetName().c_str());
return SUCCESS;
}
Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
auto src_node = src->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto dst_node = dst->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
auto src_graph = src_node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(src_graph);
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count));
auto hcom_allgather_node =
AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1);
GE_CHECK_NOTNULL(hcom_allgather_node);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node),
"[DataGather] failed between %s and %s", src_node->GetName().c_str(),
dst_node->GetName().c_str());
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
"[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(),
hcom_allgather_node->GetType().c_str());
REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize),
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false),
"[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str());
data_gather_count++;
return SUCCESS;
}
// gradients->ApplyMomentum
// we want to reduce gradients on different device(die), so graph topo changed to
// gradients->hcomallreducemean->ApplyMomentum; Because 'mean' is not currently supported by hcomallreduce,
// topo will end up like gradients->hcomallreducesum->div->ApplyMomentum
Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) {
auto src_node = src->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto dst_node = dst->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
auto src_graph = src_node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(src_graph);
NodePtr all_reduce_node = nullptr;
if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) {
MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node),
"[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str());
GE_CHECK_NOTNULL(all_reduce_node);
} else {
GE_CHECK_NOTNULL(all_reduce_node);
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(all_reduce_node->GetOpDesc(), kDeployNumber),
"[DataReduce][Modify] set attr for allreduce node for %s failed",
all_reduce_node->GetName().c_str());
}
return SUCCESS;
}
// tensor t with shape like [n,c,h,w], we want get [0:2/n, c, h, w] and [2/n : n, c, h, w] on different
// device; To achieve this goal, we use slice nodes.
// slice(t, [i * n/2, 0, 0, 0], [n/2, c, h, w]) i=0,1
// slice three input like : t->slice; data(0,1)->mul(n/2)->pack[i*n/2,0,0,0]->slice; const(n,c,h,w)->slice
Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node) {
auto src_node = src->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
auto dst_node = dst->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
auto src_graph = src_node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(src_graph);
if (input_node == nullptr) {
std::string input_node_name = std::string(DATA) + "_" + kPrefix + "_" + std::to_string(0);
input_node = AddSingleInputOutputNode(src_graph, input_node_name, DATA);
AddInputNode(input_node);
}
GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx());
NodePtr slice_node = nullptr;
MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node),
"[DataSlice] construct slice node for %s failed", src_node->GetName().c_str());
GE_CHECK_NOTNULL(slice_node);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s",
src_node->GetName().c_str(), dst_node->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed",
slice_node->GetName().c_str());
return SUCCESS;
}
Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node,
NodePtr &slice_node) {
vector<int64_t> slice_sizes = tensor.GetShape().GetDims();
// TODO: Express with graph structure
slice_sizes[0] /= kDeployNumber;
vector<GeTensorPtr> ge_tensors;
GeTensorDesc ge_tensor_desc;
ge_tensor_desc.SetDataType(DT_INT64);
MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
"[ConstructTensorDescWithData] failed");
GeTensorPtr slice_size_tensor = ge_tensors[0];
auto const_node_slice_size = AddConstNodeToGraph(slice_size_tensor, src_graph);
vector<int64_t> slice_offset_other_dim{0};
ge_tensors.clear();
MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_other_dim, ge_tensors, true),
"[ConstructTensorDescWithData] failed");
GeTensorPtr slice_offset_tensor = ge_tensors[0];
auto const_node_slice_offset = AddConstNodeToGraph(slice_offset_tensor, src_graph);
vector<int64_t> slice_offset_first_dim{slice_sizes[0]};
ge_tensors.clear();
MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_first_dim, ge_tensors, true),
"[ConstructTensorDescWithData] failed");
GeTensorPtr slice_offset_first_dim_tensor = ge_tensors[0];
auto const_node_slice_offset_first_dim = AddConstNodeToGraph(slice_offset_first_dim_tensor, src_graph);
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_slice_count));
NodePtr mul_node = AddDynamicInputOutputNode(src_graph, MUL, MUL + node_name_suffix, 2, 1);
GE_CHECK_NOTNULL(input_node);
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(
GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed",
mul_node->GetName().c_str());
NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1);
bool is_first_input = true;
for (const auto &in_anchor : pack_node->GetAllInDataAnchors()) {
if (is_first_input) {
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), in_anchor),
"[ConstructSliceNode] add edge failed");
is_first_input = false;
} else {
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset->GetOutDataAnchor(0), in_anchor),
"[ConstructSliceNode] add edge failed");
}
}
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed",
pack_node->GetName().c_str());
slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1);
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_size->GetOutDataAnchor(0), slice_node->GetInDataAnchor(2)),
"[ConstructSliceNode] add edge failed");
++data_slice_count;
return SUCCESS;
}
NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
const GeTensorDesc &tensor) {
GELOGI("Begin to create op: %s", name.c_str());
OpDescBuilder op_desc_builder(name, type);
OpDescPtr op_desc = op_desc_builder.AddInput("x", tensor).AddOutput("y", tensor).Build();
if (op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", name.c_str(), type.c_str());
GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", name.c_str(), type.c_str());
return nullptr;
}
NodePtr node = graph->AddNode(op_desc);
if (node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
graph->GetName().c_str());
return nullptr;
}
return node;
}
NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type,
const std::string &node_name, size_t input_num, size_t output_num) {
GELOGI("Begin to create op: %s", node_name.c_str());
OpDescBuilder op_desc_builder(node_name, type);
OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build();
if (op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", node_name.c_str(), type.c_str());
GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", node_name.c_str(), type.c_str());
return nullptr;
}
NodePtr node = graph->AddNode(op_desc);
if (node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(),
op_desc->GetType().c_str(), graph->GetName().c_str());
GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(),
graph->GetName().c_str());
return nullptr;
}
return node;
}
NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph) {
auto const_desc = OpDescUtils::CreateConstOp(tensor);
if (const_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "Create Const op failed");
GELOGE(OUT_OF_MEMORY, "[Create][ConstOp] failed");
return nullptr;
}
if (graph == nullptr) {
GELOGW("input param graph is null");
return nullptr;
}
return graph->AddNodeFront(const_desc);
}
Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst, NodePtr &reduce_node) {
std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count));
reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1);
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node),
"[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(),
src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup),
"[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str());
REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction),
"Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(),
reduce_node->GetName().c_str(), reduce_node->GetType().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed",
reduce_node->GetName().c_str());
auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1);
vector<int64_t> slice_sizes{kDeployNumber};
vector<GeTensorPtr> ge_tensors;
GeTensorDesc ge_tensor_desc;
ge_tensor_desc.SetDataType(DT_INT64);
MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors),
"[ConstructReduceNode] failed");
REQUIRE(!ge_tensors.empty(), "[ConstructReduceNode] failed");
auto const_node_div_input = AddConstNodeToGraph(ge_tensors[0], src_graph);
MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)),
"[ConstructSliceNode] add edge failed");
MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node),
"[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(),
reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str());
MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed",
div_node->GetName().c_str());
return SUCCESS;
}
bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) {
// TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node,
// so there is no need to insert i
return true;
}
} // namespace ge

+ 130
- 0
ge/graph/passes/mds_kernels/mds_utils.h View File

@@ -0,0 +1,130 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_
#include "graph/ge_context.h"
#include "common/op/ge_op_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "ge/ge_api_types.h"
#include "common/ge/ge_util.h"
#include "graph/compute_graph.h"
#include "graph/shape_refiner.h"
#include "graph/debug/ge_op_types.h"
#include "framework/common/types.h"
#include "graph/utils/op_desc_utils.h"
#include "../pass_utils.h"
#define REQUIRE(cond, ...) \
do { \
if (!(cond)) { \
REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
GELOGE(FAILED, "[MDS]" __VA_ARGS__); \
return FAILED; \
} \
} while (0)
#define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__)
#define MDS_REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__)
#define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)
namespace ge {
namespace {
// Invalid location index
const int64_t kInvalidIndex = -1;
enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 };
enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 };
// NCHW dim N index
const int32_t kNchwDimIdxN = 0;
// NCHW dim C index
const int32_t kNchwDimIdxC = 1;
// NCHW dim H index
const int32_t kNchwDimIdxH = 2;
// NCHW dim W index
const int32_t kNchwDimIdxW = 3;
// default die number
const uint32_t kDeployNumber = 2;
enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll };
enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 };
const int64_t kDefaultFissionFactor = 1;
const int64_t kDefaultRankSize = 1;
const std::string kDefaultGroup = "hccl_world_group";
const std::string kDefaultReduction = "sum";
const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE";
const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE";
// deploy info
const char *const kAttrNeedReturnResult = "_need_return_result";
const char *const kAttrDeviceType = "_device_type";
const char *const kDieDeviceTypeValue = "MultiMode";
const char *const kAttrDeviceId = "_device_id";
const char *const kAttrGraphName = "_graph_name";
const char *const kAttrGraphInputs = "_graph_inputs";
using GraphInputs = vector<GeTensorPtr>;
using DeviceId = int64_t;
using GraphInputNodes = vector<NodePtr>;
} // namespace
class MdsUtils {
public:
// Parse the configuration file and determine whether to enable MDS based on the value of device_type.
static bool IsMDSNeeded();
static int64_t GetNLocation(Format fmt);
static int64_t GetHLocation(Format fmt);
static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type);
static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type);
static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor,
const std::string &group_name = "");
/// @param [in] index 切分的轴
/// @param [in] deploy_number 切分的份数
static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type,
int64_t deploy_number = kDeployNumber);
// Sets the information, notifies the number of threads to be started during the
// loading phase, the device on which each thread should run, and constructs different input data on each device.
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node);
static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys,
const std::string &device_type = kDieDeviceTypeValue);
// Get cut policy for whole graph
static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph);
static GraphInputNodes GetInputNodes() {
return input_nodes_;
}
static void AddInputNode(const NodePtr &input_node) {
input_nodes_.push_back(input_node);
}
static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node);
private:
static GraphInputNodes input_nodes_;
static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name,
size_t input_num, size_t output_num);
static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type,
const GeTensorDesc &tensor = GeTensorDesc());
static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src,
const InDataAnchorPtr &dst, NodePtr &reduce_node);
static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node,
NodePtr &slice_node);
static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node);
static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph);
};
} // namespace ge
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_

+ 41
- 0
ge/graph/passes/mds_kernels/variable_mds_kernel.cc View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "variable_mds_kernel.h"
#include "mds_kernel_factory.h"
namespace ge {
Status VariableDeploySchedulerKernel::CutN(const ge::NodePtr& node_ptr) {
GE_CHECK_NOTNULL(node_ptr);
if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN)) {
return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN);
}
return SUCCESS;
}
Status VariableDeploySchedulerKernel::CutH(const ge::NodePtr& node_ptr) {
GE_CHECK_NOTNULL(node_ptr);
if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH)) {
return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH);
}
return SUCCESS;
}
REGISTER_MDS_KERNEL(VARIABLE, VariableDeploySchedulerKernel);
}

+ 28
- 0
ge/graph/passes/mds_kernels/variable_mds_kernel.h View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_
#include "base_mds_kernel.h"
namespace ge {
class VariableDeploySchedulerKernel : public DeploySchedulerKernel {
public:
Status CutN(const ge::NodePtr& node_ptr) override;
Status CutH(const ge::NodePtr& node_ptr) override;
};
} // namespace ge
#endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_

+ 177
- 0
ge/graph/passes/mds_pass.cc View File

@@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "./mds_pass.h"
namespace ge {
Status ModelDeploySchedulerPass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph);
compute_graph_ = graph;
if (!MdsUtils::IsMDSNeeded()) {
return SUCCESS;
}
GELOGI("[MDS][%s] start to deploy.", GetGraphName());
MDS_REQUIRE_SUCCESS(SMDPProcess(), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(CutProcess(), "[MDS][CutProcess] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(SMDPProcess(false), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(SwapProcess(), "[MDS][SwapProcess] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(PiplineProcess(), "[MDS][PiplineProcess] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(SetDeployInfo(), "[MDS][SetDeployInfo] failed, graph_name:[%s]", GetGraphName());
GELOGI("[MDS][%s] deploy successfully.", graph->GetName().c_str());
return SUCCESS;
}
Status ModelDeploySchedulerPass::CutProcess() {
GE_CHECK_NOTNULL(compute_graph_);
if (!compute_graph_->GetAllSubgraphs().empty() || compute_graph_->GetParentGraph() != nullptr) {
GELOGW("[MDS][CutProcess] graph with subgraphs is not supported now. graph_name:[%s]", GetGraphName());
return SUCCESS;
}
auto type = MdsUtils::TryGetGraphCutType(compute_graph_);
switch (type) {
case kCutN:
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kCutH:
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutN:
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutH:
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
GetGraphName());
break;
case kDynamicCutAll:
MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]",
GetGraphName());
break;
default:
GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName());
return SUCCESS;
}
}
Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
GE_CHECK_NOTNULL(compute_graph);
// step 0: Cut
for (const auto &node : compute_graph->GetDirectNode()) {
auto op_kernel = mds_cut_pass::GetKernelByType(node);
if (op_kernel == nullptr) {
op_kernel = DeploySchedulerKernel::Instance();
}
if (is_dynamic) {
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]",
node->GetName().c_str());
} else {
MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str());
}
bool is_grad_compute_node = false;
if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) &&
is_grad_compute_node) {
grad_compute_nodes_.push_back(node);
}
}
// TODO:for single output multi reference insertion allgather, allreduce nodes, do breadth fusion optimization
MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]",
compute_graph->GetName().c_str());
return SUCCESS;
}
Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) {
GE_CHECK_NOTNULL(compute_graph);
for (NodePtr &node : compute_graph->GetDirectNode()) {
auto op_kernel = mds_cut_pass::GetKernelByType(node);
if (op_kernel == nullptr) {
op_kernel = DeploySchedulerKernel::Instance();
}
if (is_dynamic) {
MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]",
node->GetName().c_str());
} else {
MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str());
}
}
return SUCCESS;
}
Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_graph) {
std::vector<NodePtr> input_nodes;
std::vector<NodePtr> output_nodes;
auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes);
MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]",
compute_graph0->GetName().c_str());
MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]",
compute_graph1->GetName().c_str());
// TODO:Create a case node, put the two graphs under the two branches of case
return SUCCESS;
}
Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) {
if (before_cut) {
MDS_REQUIRE_SUCCESS(SMDPModelState(), "[SMDPProcess][SMDPModelState] failed, graph_name:[%s]", GetGraphName());
MDS_REQUIRE_SUCCESS(SMDPWeight(), "[SMDPProcess][SMDPWeight] failed, graph_name:[%s]", GetGraphName());
} else {
MDS_REQUIRE_SUCCESS(SMDPGradient(), "[SMDPProcess][SMDPGradient] failed, graph_name:[%s]", GetGraphName());
}
return SUCCESS;
}
Status ModelDeploySchedulerPass::SetDeployInfo() {
vector<GeAttrValue::NAMED_ATTRS> deployInfo;
REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo),
"%s already has deployed before!", GetGraphName());
std::multimap<DeviceId, GraphInputs> deploys;
for (int64_t j = 0; j < kDeployNumber; j++) {
int64_t device_id = j;
GraphInputs graph_inputs;
// For now, only one input_node in input_nodes
for (const auto &input_node : MdsUtils::GetInputNodes()) {
GE_CHECK_NOTNULL(input_node);
GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0));
vector<uint8_t> data{static_cast<uint8_t>(device_id)};
graph_input->SetData(data);
graph_inputs.push_back(graph_input);
}
deploys.emplace(j, graph_inputs);
}
return MdsUtils::SetDeployInfo(compute_graph_, deploys);
}
Status ModelDeploySchedulerPass::SwapProcess() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::PiplineProcess() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::HcomNodeFusionProcess() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::SMDPModelState() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::SMDPWeight() {
return SUCCESS;
}
Status ModelDeploySchedulerPass::SMDPGradient() {
return SUCCESS;
}
} // namespace ge

+ 71
- 0
ge/graph/passes/mds_pass.h View File

@@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_
#define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_
#include "graph/types.h"
#include "ge/ge_api.h"
#include "graph/debug/ge_attr_define.h"
#include "inc/graph_pass.h"
#include "./mds_kernels/base_mds_kernel.h"
#include "ge/ge_api_types.h"
#include "./mds_kernels/mds_utils.h"
namespace ge {
class ModelDeploySchedulerPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override;
private:
// Part0:Process Func
// cut and dynamic cut
Status CutProcess();
Status CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false);
Status CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false);
Status DynamicCutAll(const ComputeGraphPtr &compute_graph);
// smdp
Status SMDPProcess(bool before_cut = true);
Status SMDPModelState();
Status SMDPGradient();
Status SMDPWeight();
// swap
Status SwapProcess();
// pipline
Status PiplineProcess();
// set delpoyinfo
Status SetDeployInfo();
// Part1: Utils Func
// std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index);
// std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index);
Status HcomNodeFusionProcess();
Status GetAllModelStateVar();
Status GetAllWeightVar();
std::vector<NodePtr> GetAllGradComputeNodes() {
return grad_compute_nodes_;
}
const char *GetGraphName() const {
return compute_graph_->GetName().c_str();
}
// members
std::vector<NodePtr> model_state_vars_;
std::vector<NodePtr> model_weight_vars_;
std::vector<NodePtr> grad_compute_nodes_;
ComputeGraphPtr compute_graph_ = nullptr;
};
} // namespace ge
#endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_

+ 1
- 0
inc/external/ge/ge_api_types.h View File

@@ -28,6 +28,7 @@
namespace ge {
// Option key: graph run mode
const char *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode";
const char *const OPTION_DEVICE_TYPE = "ge.deviceType";

// Option key: ome init
const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId";


Loading…
Cancel
Save