/**
 * \file imperative/src/impl/ops/elemwise.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"

#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {

/*!
 * \brief Rebuild an imperative Elemwise OpDef from a graph operator node.
 *
 * Only the elemwise mode is carried over from the node's param.
 * NOTE(review): cast_final_safe presumably rejects nodes that are not
 * opr::Elemwise -- confirm against its definition.
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Elemwise>();
    return Elemwise::make(node->param().mode);
}

/*!
 * \brief Insert an opr::Elemwise into the computing graph.
 *
 * \param def    must be an imperative Elemwise OpDef; its mode is used
 * \param inputs graph vars fed to the operator
 * \return the operator node owning the newly created output var
 */
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
}

/*!
 * \brief Infer the single output's layout and comp node without executing.
 *
 * Every input must be on the same comp node and share one dtype (both taken
 * from input 0). The output dtype is that shared dtype. The output shape is
 * computed by opr::Elemwise::get_output_var_shape from the input shapes.
 *
 * \return {output descs, validated}. When any input shape is still unknown
 *      (ndim == 0), the output is a placeholder layout with ndim == 0 and
 *      validated == false, so callers do not treat the placeholder as final.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Elemwise>();
    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
    mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually",
               trait.name, trait.arity, inputs.size());

    // Input 0 defines the reference comp node / dtype. The arity assertion
    // above guarantees inputs is non-empty whenever trait.arity >= 1.
    const CompNode out_cn = inputs[0].comp_node;
    const DType out_dt = inputs[0].layout.dtype;

    TensorShapeArray inp_shapes;
    bool all_shapes_known = true;
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto&& t = inputs[i];
        // Validate every input, not just those before the first
        // shape-unknown one, so the checks do not depend on input order.
        mgb_assert(t.comp_node == out_cn,
                   "%s: input %zu is on a different comp node than input 0",
                   trait.name, i);
        mgb_assert(t.layout.dtype == out_dt,
                   "%s: input %zu has a different dtype than input 0",
                   trait.name, i);
        if (t.layout.ndim > 0) {
            inp_shapes.push_back(t.layout);
        } else {
            all_shapes_known = false;
        }
    }

    if (!all_shapes_known) {
        // Shape cannot be inferred yet. Report dtype/comp node only, and mark
        // the result as not validated.
        TensorLayout out_layout;
        out_layout.ndim = 0;
        out_layout.dtype = out_dt;
        return {{{out_layout, out_cn}}, false};
    }

    auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
    return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
}

OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
    .infer_output_attrs_fallible(infer_output_attrs_fallible)
    .fallback();

} // anonymous namespace

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}