Browse Source

modified: ge/graph/optimize/mem_rw_conflict_optimize.cc

modified:   ge/host_kernels/strided_slice_kernel.cc
pull/383/head
zhaoxinxin 4 years ago
parent
commit
b9fe1ad1fe
2 changed files with 6 additions and 4 deletions
  1. +4
    -1
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  2. +2
    -3
      ge/host_kernels/strided_slice_kernel.cc

+ 4
- 1
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -625,7 +625,10 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
}
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
for (const auto &node : compute_graph->GetDirectNode()) {
if (node->GetType() == HCOMALLREDUCE) {
bool is_input_continuous = false;
GE_CHECK_NOTNULL(node->GetOpDesc());
(void) ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
if (is_input_continuous) {
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();


+ 2
- 3
ge/host_kernels/strided_slice_kernel.cc View File

@@ -298,7 +298,7 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1));
attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask;
}
for (auto i = 0; i < x_dims_num; ++i) {
for (size_t i = 0; i < x_dims_num; ++i) {
bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i);
if (ellipsis_mask_flag) {
auto ellipsis_dim = i;
@@ -306,7 +306,7 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
orig_end_vec[i] = x_dims.at(i);
orig_stride_vec[i] = 1;
if (orig_begin_vec.size() < x_dims_num) {
for (auto j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
for (size_t j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0);
orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j));
orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1);
@@ -321,7 +321,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en
auto i_temp = static_cast<uint32_t>(i);
bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp));
bool end_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) & (kMaskBitLeftUnit << i_temp));
bool ellipsis_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i_temp));
bool shrink_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK) & (kMaskBitLeftUnit << i_temp));
if (shrink_mask_flag) {
begin_i = (begin_i < 0 ? (dim_i + begin_i) : begin_i);


Loading…
Cancel
Save