|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "graph/passes/dynamic_single_op_reset_shape_pass.h" |
|
|
|
#include "common/ge_inner_error_codes.h" |
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
#include "graph/utils/graph_utils.h" |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
@@ -25,12 +26,25 @@ |
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
const int64_t kDynamicShapeDim = -2; |
|
|
|
} |
|
|
|
const char *const kAICPUKernelLibName = "aicpu_tf_kernel"; |
|
|
|
} // namespace |
|
|
|
Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { |
|
|
|
GE_CHECK_NOTNULL(graph); |
|
|
|
|
|
|
|
std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); |
|
|
|
if (instance == nullptr || !instance->InitFlag()) { |
|
|
|
GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Run CompileNodesPass failed."); |
|
|
|
return ge::GE_CLI_GE_NOT_INITIALIZED; |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc()); |
|
|
|
if (node->GetType() == DATA || node->GetType() == NETOUTPUT) { |
|
|
|
// pass input node |
|
|
|
if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// pass output node |
|
|
|
if (node->GetType() == NETOUTPUT) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
@@ -40,6 +54,17 @@ Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// pass node aicpu node. |
|
|
|
string kernel_lib_name; |
|
|
|
if (GetSupportedKernel(node, instance, kernel_lib_name) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(GRAPH_FAILED, "Get kernel lib failed of node[%s].", node->GetName().c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
if (kernel_lib_name != kAICPUKernelLibName) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// reset aicpu shape to unknown shape |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim}; |
|
|
|
GeShape dynamic_shape(dynamic_shape_dims); |
|
|
@@ -54,4 +79,70 @@ Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { |
|
|
|
GELOGD("Reset dynamic aicpu nodes shape of graph [%s] success!", graph->GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
graphStatus DynamicSingleOpResetShapePass::GetSupportedKernel(const NodePtr &node, |
|
|
|
const std::shared_ptr<GELib> instance, |
|
|
|
string &kernel_lib_name) { |
|
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s opdesc failed", node->GetName().c_str()); |
|
|
|
return ge::GE_GRAPH_PARAM_NULLPTR; |
|
|
|
} |
|
|
|
// reset op kernel lib, find supported kernel |
|
|
|
kernel_lib_name = op_desc->GetOpKernelLibName(); |
|
|
|
if (kernel_lib_name.empty()) { |
|
|
|
(void)instance->DNNEngineManagerObj().GetDNNEngineName(node); |
|
|
|
kernel_lib_name = op_desc->GetOpKernelLibName(); |
|
|
|
if (kernel_lib_name.empty()) { |
|
|
|
GELOGE(GRAPH_FAILED, "Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), |
|
|
|
op_desc->GetType().c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); |
|
|
|
if (kernel_info == nullptr) { |
|
|
|
GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node->GetName().c_str()); |
|
|
|
return ge::GE_GRAPH_PARAM_NULLPTR; |
|
|
|
} |
|
|
|
// begin accuracy supported check |
|
|
|
if (!CheckAccuracySupport(kernel_info, instance, op_desc)) { |
|
|
|
// if check accuracy support failed , try to go to other engine. |
|
|
|
GELOGD("Check Accuracy Supported return not support, node name is %s. Try to go to other engine.", |
|
|
|
op_desc->GetName().c_str()); |
|
|
|
string kernel_name_origin = kernel_lib_name; |
|
|
|
OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj(); |
|
|
|
auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); |
|
|
|
for (auto it = kernel_map.begin(); it != kernel_map.end(); ++it) { |
|
|
|
string tmp_kernel_name = it->first; |
|
|
|
if (tmp_kernel_name == kernel_name_origin) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
OpsKernelInfoStorePtr tmp_kernel_info = it->second; |
|
|
|
if (CheckAccuracySupport(tmp_kernel_info, instance, op_desc)) { |
|
|
|
kernel_lib_name = tmp_kernel_name; |
|
|
|
GELOGD("Find kernel lib %s support node:%s, type:%s , get kernel lib success.", tmp_kernel_name.c_str(), |
|
|
|
node->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGE(GRAPH_FAILED, "Cannot find kernel lib support node:%s, type:%s , get kernel lib failed.", |
|
|
|
node->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
return GRAPH_SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
bool DynamicSingleOpResetShapePass::CheckAccuracySupport(const OpsKernelInfoStorePtr &kernel_info, |
|
|
|
const std::shared_ptr<GELib> instance, OpDescPtr &op_desc) { |
|
|
|
auto ge_desc = MakeShared<ge::OpDescPtr>(op_desc); |
|
|
|
if (ge_desc == nullptr) { |
|
|
|
GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); |
|
|
|
return false; |
|
|
|
} |
|
|
|
string reason; |
|
|
|
if (!(kernel_info->CheckAccuracySupported(*ge_desc, reason, true))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} // namespace ge |