#include "megbrain/graph/symbol_var.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/physical_tensor.h" #include "megbrain/imperative/proxy_graph_detail.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_manip.h" #include "megdnn/dtype.h" #include "../blob_manager_impl.h" #include "../dnn_op_helper.h" #include "../op_trait.h" namespace mgb { namespace imperative { namespace { namespace padding { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); return opr::Padding::make(inputs[0], op.param()); } SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { auto comp_node = inputs[0]->comp_node(); auto&& op_def = def.cast_final_safe(); DnnOprCaller dnn_op(comp_node, op_def.param()); auto dst = [&] { if (validated) { return output_descs[0].layout; } else { return dnn_op.deduce_layout(inputs[0]->layout()); } }(); auto out = Tensor::make(dst, comp_node); dnn_op.exec(inputs[0], out); return {out}; } std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { auto&& op_def = def.cast_final_safe(); auto&& inp = inputs[0]; if (inp.layout.ndim == 0) { return {{{TensorLayout{inp.layout.dtype}, inp.comp_node, {}}}, false}; } DnnOprHelper dnn_op(op_def.param()); auto oup_layout = dnn_op.deduce_layout(inp.layout); return {{{oup_layout, inp.comp_node}}, true}; } OP_TRAIT_REG(Padding, Padding, opr::Padding) .apply_on_var_node(apply_on_var_node) .apply_on_physical_tensor(apply_on_physical_tensor) .infer_output_attrs_fallible(infer_output_attrs_fallible) .fallback(); } // namespace padding } // namespace } // namespace imperative } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}