From 9de1ea6a4881894a4ebf57fc96fb2cd06d40c326 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Dec 2020 13:56:44 +0800 Subject: [PATCH] perf(imperative): add apply_on_physical_tensor for Elemwise GitOrigin-RevId: 27087d90e431d0fbdb0439827f8bf6088781f6a5 --- imperative/src/impl/ops/elemwise.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 40c6aeab..4811150f 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -65,10 +65,30 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; } +SmallVector apply_on_physical_tensor( + const OpDef& def, + const SmallVector& inputs) { + auto&& op_def = def.cast_final_safe(); + auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); + mgb_assert(inputs.size() == trait.arity, + "%s expects %u inputs; got %zu actually", trait.name, + trait.arity, inputs.size()); + + DeviceTensorND out; + SmallVector dt_inputs(inputs.size()); + for (unsigned i = 0; i < inputs.size(); ++i){ + dt_inputs[i] = inputs[i]->dev_tensor(); + } + auto&& dnn_opr = opr::intl::create_megdnn_opr(inputs[0]->comp_node()); + opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr); + return {Tensor::make(out)}; +} + OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) + .apply_on_physical_tensor(apply_on_physical_tensor) .fallback(); } // anonymous namespace