|
@@ -138,7 +138,9 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { |
|
|
if (!is_unknown_graph) {
|
|
|
if (!is_unknown_graph) {
|
|
|
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
|
|
|
auto inference_context = ShapeRefiner::CreateInferenceContext(node);
|
|
|
GE_CHECK_NOTNULL(inference_context);
|
|
|
GE_CHECK_NOTNULL(inference_context);
|
|
|
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
|
|
|
|
|
|
|
|
|
std::vector<AscendString> marks;
|
|
|
|
|
|
inference_context->GetMarks(marks);
|
|
|
|
|
|
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size());
|
|
|
op.SetInferenceContext(inference_context);
|
|
|
op.SetInferenceContext(inference_context);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
@@ -151,10 +153,12 @@ graphStatus InferShapePass::InferShapeAndType(NodePtr &node) { |
|
|
if (!is_unknown_graph) {
|
|
|
if (!is_unknown_graph) {
|
|
|
auto ctx_after_infer = op.GetInferenceContext();
|
|
|
auto ctx_after_infer = op.GetInferenceContext();
|
|
|
if (ctx_after_infer != nullptr) {
|
|
|
if (ctx_after_infer != nullptr) {
|
|
|
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
|
|
|
|
|
|
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
|
|
|
|
|
|
|
|
|
std::vector<AscendString> marks;
|
|
|
|
|
|
ctx_after_infer->GetMarks(marks);
|
|
|
|
|
|
GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size());
|
|
|
|
|
|
if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !marks.empty()) {
|
|
|
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
|
|
|
GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
|
|
|
ctx_after_infer->GetMarks().size());
|
|
|
|
|
|
|
|
|
marks.size());
|
|
|
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
|
|
|
ShapeRefiner::PushToContextMap(node, ctx_after_infer);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -254,7 +258,8 @@ graphStatus InferShapePass::CallInferShapeFunc(NodePtr &node, Operator &op) { |
|
|
auto ret = op_desc->CallInferFunc(op);
|
|
|
auto ret = op_desc->CallInferFunc(op);
|
|
|
if (ret == GRAPH_PARAM_INVALID) {
|
|
|
if (ret == GRAPH_PARAM_INVALID) {
|
|
|
// Op ir no infer func, try to get infer func from operator factory
|
|
|
// Op ir no infer func, try to get infer func from operator factory
|
|
|
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str());
|
|
|
if (node_op.IsEmpty()) {
|
|
|
if (node_op.IsEmpty()) {
|
|
|
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
|
|
|
GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
|
|
|
return ret;
|
|
|
return ret;
|
|
|