Browse Source

fix dump

tags/v1.2.0
wjm 4 years ago
parent
commit
46156bf04f
4 changed files with 9 additions and 3 deletions
  1. +1
    -0
      ge/common/helper/model_helper.cc
  2. +3
    -3
      ge/hybrid/model/hybrid_model_builder.cc
  3. +3
    -0
      ge/model/ge_root_model.h
  4. +2
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 0
ge/common/helper/model_helper.cc View File

@@ -599,6 +599,7 @@ Status ModelHelper::GenerateGeRootModel(OmFileLoadHelper &om_load_helper) {
is_first_model = false;
root_model_->SetRootGraph(GraphUtils::GetComputeGraph(cur_model->GetGraph()));
root_model_->SetModelId(cur_model->GetModelId());
root_model_->SetModelName(cur_model->GetName());
model_ = cur_model;
continue;
}


+ 3
- 3
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -130,7 +130,7 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model)

Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel");
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
hybrid_model_.model_name_ = ge_root_model_->GetName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName());
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName());
@@ -154,7 +154,7 @@ Status HybridModelBuilder::Build() {

Status HybridModelBuilder::BuildForSingleOp() {
GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel");
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
hybrid_model_.model_name_ = ge_root_model_->GetName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel();
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()];
@@ -272,7 +272,7 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
(void) AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d",
op_desc->GetName().c_str(),
node_item.is_need_force_infershape);


+ 3
- 0
ge/model/ge_root_model.h View File

@@ -35,12 +35,15 @@ class GeRootModel {
const ComputeGraphPtr &GetRootGraph() const { return root_graph_; };
void SetModelId(uint32_t model_id) { model_id_ = model_id; }
uint32_t GetModelId() const { return model_id_; }
void SetModelName(const std::string &model_name) { model_name_ = model_name; }
const std::string &GetModelName() const { return model_name_; }
Status CheckIsUnknownShape(bool &is_dynamic_shape);
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }
private:
ComputeGraphPtr root_graph_ = nullptr;
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_;
uint32_t model_id_ = 0;
std::string model_name_;
};
} // namespace ge
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;


+ 2
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -150,9 +150,11 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) {

ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
ge_root_model->SetModelName("test_name");
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);

ASSERT_EQ(hybrid_model_builder.Build(), INTERNAL_ERROR);
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}



Loading…
Cancel
Save