Browse Source

perf(ops): enable memory forward for reduce in special cases

GitOrigin-RevId: dd6e1664c5
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
e1c7b22ff0
1 changed files with 40 additions and 0 deletions
  1. +40
    -0
      imperative/src/impl/ops/reduce.cpp

+ 40
- 0
imperative/src/impl/ops/reduce.cpp View File

@@ -11,6 +11,7 @@

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/proxy_graph_detail.h"

#include "../op_trait.h"
#include "../dnn_op_helper.h"
@@ -35,9 +36,48 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Reduce::make(node->param());
}

bool memory_forward_success(
const OpDef& def,
SmallVector<TensorPtr> inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
auto shape_tensor = inputs[1]->get_value();
TensorShape shape;
cg::copy_tensor_value_to_shape(shape, shape_tensor.proxy_to_default_cpu());
if (shape.eq_shape(inputs[0]->shape())) {
return true;
}
}
return false;
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
if (memory_forward_success(def, inputs_tensors)) {
auto& src_desc = inputs_mems[0];
return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}}, {}};
}
return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}


void execute(const OpDef& def,
SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (memory_forward_success(def, inputs)) {
return;
}
return proxy_graph_detail::execute(def, inputs, outputs, workspace);
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace reduce
} // namespace


Loading…
Cancel
Save