Browse Source

perf(imperative): add apply_on_physical_tensor for Elemwise

GitOrigin-RevId: 27087d90e4
release-1.2
Megvii Engine Team 4 years ago
parent
commit
9de1ea6a48
1 changed files with 20 additions and 0 deletions
  1. +20
    -0
      imperative/src/impl/ops/elemwise.cpp

+ 20
- 0
imperative/src/impl/ops/elemwise.cpp View File

@@ -65,10 +65,30 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(inputs.size() == trait.arity,
"%s expects %u inputs; got %zu actually", trait.name,
trait.arity, inputs.size());

DeviceTensorND out;
SmallVector<DeviceTensorND> dt_inputs(inputs.size());
for (unsigned i = 0; i < inputs.size(); ++i){
dt_inputs[i] = inputs[i]->dev_tensor();
}
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
return {Tensor::make(out)};
}

OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // anonymous namespace



Loading…
Cancel
Save