Browse Source

Pre Merge pull request !2034 from zhaoxinxin/master

pull/2034/MERGE
zhaoxinxin Gitee 3 years ago
parent
commit
807e93487e
1 changed files with 10 additions and 5 deletions
  1. +10
    -5
      ge/graph/passes/infershape_pass.cc

+ 10
- 5
ge/graph/passes/infershape_pass.cc View File

@@ -138,7 +138,9 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
if (!is_unknown_graph) {
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
GE_CHECK_NOTNULL(inference_context);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
std::vector<AscendString> marks;
inference_context->GetMarks(marks);
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size());
op.SetInferenceContext(inference_context);
}
@@ -151,10 +153,12 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) {
if (!is_unknown_graph) {
auto ctx_after_infer = op.GetInferenceContext();
if (ctx_after_infer != nullptr) {
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
std::vector<AscendString> marks;
ctx_after_infer->GetMarks(marks);
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size());
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) {
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
ctx_after_infer->GetMarks().size());
marks.size());
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
}
}
@@ -254,7 +258,8 @@ graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) {
auto ret = op_desc->CallInferFunc(op);
if (ret == GRAPH_PARAM_INVALID) {
// Op ir no infer func, try to get infer func from operator factory
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str());
if (node_op.IsEmpty()) {
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
return ret;


Loading…
Cancel
Save