Browse Source

Feature: reset shape of dynamic single op

pull/565/head
l00444296 4 years ago
parent
commit
201680cb59
2 changed files with 5 additions and 2 deletions
  1. +4
    -1
      ge/graph/passes/dynamic_single_op_reset_shape_pass.cc
  2. +1
    -1
      ge/graph/passes/dynamic_single_op_reset_shape_pass.h

+ 4
- 1
ge/graph/passes/dynamic_single_op_reset_shape_pass.cc View File

@@ -23,6 +23,9 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"


namespace ge { namespace ge {
namespace {
const int64_t kDynamicShapeDim = -2;
}
Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) {
GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(graph);
for (const auto &node : graph->GetDirectNode()) { for (const auto &node : graph->GetDirectNode()) {
@@ -38,7 +41,7 @@ Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) {
} }


auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
std::vector<int64_t> dynamic_shape_dims = {-2};
std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim};
GeShape dynamic_shape(dynamic_shape_dims); GeShape dynamic_shape(dynamic_shape_dims);
for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i)); auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));


+ 1
- 1
ge/graph/passes/dynamic_single_op_reset_shape_pass.h View File

@@ -22,7 +22,7 @@
namespace ge { namespace ge {
class DynamicSingleOpResetShapePass : public GraphPass { class DynamicSingleOpResetShapePass : public GraphPass {
public: public:
Status Run(ComputeGraphPtr graph);
Status Run(ComputeGraphPtr graph) override;
}; };
} // namespace ge } // namespace ge
#endif // GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ #endif // GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_

Loading…
Cancel
Save