|
|
@@ -65,10 +65,30 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( |
|
|
|
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; |
|
|
|
} |
|
|
|
|
|
|
|
// Eagerly evaluates an Elemwise op on already-materialized tensors.
//
// \param def     op definition; it is cast to Elemwise via cast_final_safe.
// \param inputs  operand tensors; their count must equal the arity reported
//                by the mode trait of the op, otherwise mgb_assert fires.
// \return a single-element vector holding the tensor made from the output
//         that opr::Elemwise::perform was given.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& elemwise = def.cast_final_safe<Elemwise>();
    const auto& mode_trait =
            megdnn::Elemwise::ModeTrait::from_mode(elemwise.mode);

    // Validate the operand count up front; inputs[0] is dereferenced below.
    const auto nr_inputs = inputs.size();
    mgb_assert(nr_inputs == mode_trait.arity,
               "%s expects %u inputs; got %zu actually", mode_trait.name,
               mode_trait.arity, nr_inputs);

    // Collect the device-side tensor of every operand, in operand order.
    SmallVector<DeviceTensorND> dev_inputs(nr_inputs);
    for (size_t idx = 0; idx < nr_inputs; ++idx) {
        dev_inputs[idx] = inputs[idx]->dev_tensor();
    }

    // The megdnn operator is created on the comp node of the first operand.
    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(
            inputs[0]->comp_node());

    // `result` is handed over default-constructed; presumably perform()
    // allocates and fills it -- confirm against opr::Elemwise::perform.
    DeviceTensorND result;
    opr::Elemwise::perform(elemwise.mode, result, dev_inputs, dnn_opr);
    return {Tensor::make(result)};
}
|
|
|
|
|
|
|
// Registers the op trait for Elemwise (associated with opr::Elemwise) and
// attaches the handlers named below, including the physical-tensor path.
// NOTE(review): .fallback() presumably supplies defaults for any handler not
// listed here -- confirm against the OP_TRAIT_REG definition.
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
|
|
|
} // anonymous namespace |
|
|
|
|
|
|
|