GitOrigin-RevId: 8c2b6a2aed
tags/v1.8.0
@@ -158,70 +158,71 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
ForceInplaceElemwise, | ForceInplaceElemwise, | ||||
cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | |||||
cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) // { | |||||
public: | public: | ||||
struct Param { | |||||
using Mode = megdnn::Elemwise::Param::Mode; | |||||
Mode mode; | |||||
size_t inplace_index; | |||||
}; | |||||
using Mode = Param::Mode; | |||||
ForceInplaceElemwise( | |||||
const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {}) | |||||
: Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), | |||||
m_param{param} { | |||||
for (auto* input : inputs) { | |||||
add_input({input}); | |||||
struct Param { | |||||
using Mode = megdnn::Elemwise::Param::Mode; | |||||
Mode mode; | |||||
size_t inplace_index; | |||||
}; | |||||
using Mode = Param::Mode; | |||||
ForceInplaceElemwise( | |||||
const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {}) | |||||
: Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), | |||||
m_param{param} { | |||||
for (auto* input : inputs) { | |||||
add_input({input}); | |||||
} | |||||
add_output(None) | |||||
->set_fwd_in2out_writable_force(input(param.inplace_index)) | |||||
.add_flag(VarNode::Flag::NO_MEM_RECLAIM); | |||||
} | } | ||||
add_output(None) | |||||
->set_fwd_in2out_writable_force(input(param.inplace_index)) | |||||
.add_flag(VarNode::Flag::NO_MEM_RECLAIM); | |||||
} | |||||
static SymbolVar make(const VarNodeArray& inputs, Param param) { | |||||
return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>( | |||||
inputs, param); | |||||
} | |||||
static cg::OperatorNodeBase* shallow_copy( | |||||
const serialization::OprShallowCopyContext& ctx, | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(const VarNodeArray& inputs, Param param) { | |||||
return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>( | |||||
inputs, param); | |||||
} | |||||
static cg::OperatorNodeBase* shallow_copy( | |||||
const serialization::OprShallowCopyContext& ctx, | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config); | |||||
protected: | protected: | ||||
NodeProp* do_make_node_prop() const override { | |||||
auto ret = Super::do_make_node_prop(); | |||||
ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); | |||||
return ret; | |||||
} | |||||
void create_megdnn_opr() override { | |||||
auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node()); | |||||
opr->param().mode = m_param.mode; | |||||
set_megdnn_opr(std::move(opr)); | |||||
} | |||||
void scn_do_execute() override { | |||||
auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); }; | |||||
megdnn::TensorNDArray inputs_dnnnd; | |||||
for (auto* input : input()) { | |||||
inputs_dnnnd.push_back(to_dnnnd(input)); | |||||
NodeProp* do_make_node_prop() const override { | |||||
auto ret = Super::do_make_node_prop(); | |||||
ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); | |||||
return ret; | |||||
} | } | ||||
mgb_assert( | |||||
input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), | |||||
"ForceInplaceElemwise cannot be applied in internal tensor"); | |||||
auto* out_dest = output(0); | |||||
auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr()); | |||||
opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest)); | |||||
} | |||||
void init_output_static_infer_desc() override { | |||||
using namespace cg::static_infer; | |||||
void create_megdnn_opr() override { | |||||
auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node()); | |||||
opr->param().mode = m_param.mode; | |||||
set_megdnn_opr(std::move(opr)); | |||||
} | |||||
void scn_do_execute() override { | |||||
auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); }; | |||||
megdnn::TensorNDArray inputs_dnnnd; | |||||
for (auto* input : input()) { | |||||
inputs_dnnnd.push_back(to_dnnnd(input)); | |||||
} | |||||
mgb_assert( | |||||
input(m_param.inplace_index) | |||||
->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), | |||||
"ForceInplaceElemwise cannot be applied in internal tensor"); | |||||
auto* out_dest = output(0); | |||||
auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr()); | |||||
opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest)); | |||||
} | |||||
void init_output_static_infer_desc() override { | |||||
using namespace cg::static_infer; | |||||
owner_graph()->static_infer_manager().register_shape_infer( | |||||
output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); | |||||
} | |||||
owner_graph()->static_infer_manager().register_shape_infer( | |||||
output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); | |||||
} | |||||
private: | private: | ||||
Param m_param; | |||||
void record_execute_deps(ExecDependencyArray& deps) override { | |||||
record_megdnn_opr(deps); | |||||
} | |||||
Param m_param; | |||||
void record_execute_deps(ExecDependencyArray& deps) override { | |||||
record_megdnn_opr(deps); | |||||
} | |||||
}; | }; | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); | ||||
@@ -1013,13 +1013,13 @@ using OprNodeArray = SmallVector<OperatorNodeBase*>; | |||||
* | * | ||||
* Note that opening brace is included | * Note that opening brace is included | ||||
*/ | */ | ||||
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \ | |||||
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
#define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \ | |||||
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \ | |||||
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT; | |||||
#define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \ | |||||
MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT; | |||||
} // namespace cg | } // namespace cg | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -495,18 +495,18 @@ private: | |||||
} // namespace mgb | } // namespace mgb | ||||
#define _MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \ | |||||
class _name : public _base, ##__VA_ARGS__ { \ | |||||
public: \ | |||||
using Super = _tpl _base; \ | |||||
\ | |||||
#define MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \ | |||||
class _name : public _base, ##__VA_ARGS__ { \ | |||||
public: \ | |||||
using Super = _tpl _base; \ | |||||
\ | |||||
private: | private: | ||||
/*! | /*! | ||||
* \brief define a class which has Super defined to base | * \brief define a class which has Super defined to base | ||||
*/ | */ | ||||
#define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \ | #define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \ | ||||
_MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__) | |||||
MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__) | |||||
/*! | /*! | ||||
* \brief define a class which has Super defined to base | * \brief define a class which has Super defined to base | ||||
@@ -514,5 +514,5 @@ private: | |||||
* Used when this class is a template and base class has template | * Used when this class is a template and base class has template | ||||
*/ | */ | ||||
#define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \ | #define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \ | ||||
_MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__) | |||||
MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__) | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -99,7 +99,7 @@ float GraphPartitionProfiler::duration_in_usec() const { | |||||
* \brief An operator that indicates its input var node is contiguous | * \brief An operator that indicates its input var node is contiguous | ||||
*/ | */ | ||||
// clang-format off | // clang-format off | ||||
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) //{ | |||||
MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) // { | |||||
void scn_do_execute() override {}; | void scn_do_execute() override {}; | ||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
void add_input_layout_constraint() override { | void add_input_layout_constraint() override { | ||||
@@ -20,38 +20,38 @@ namespace opr { | |||||
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>, | PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>, | ||||
public mixin::AlgoChooserHelper) //{ | |||||
public mixin::AlgoChooserHelper) // { | |||||
public: | public: | ||||
MGE_WIN_DECLSPEC_FUC PoolingForward( | |||||
VarNode* src, const Param& param, const ExecutionPolicy& policy, | |||||
const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
void init_output_static_infer_desc() override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
MGE_WIN_DECLSPEC_FUC PoolingForward( | |||||
VarNode* src, const Param& param, const ExecutionPolicy& policy, | |||||
const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
void init_output_static_infer_desc() override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
}; | }; | ||||
using Pooling = PoolingForward; | using Pooling = PoolingForward; | ||||
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>, | PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>, | ||||
public mixin::AlgoChooserHelper) //{ | |||||
public mixin::AlgoChooserHelper) // { | |||||
public: | public: | ||||
MGE_WIN_DECLSPEC_FUC PoolingBackward( | |||||
VarNode* src, VarNode* dst, VarNode* diff, const Param& param, | |||||
const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC PoolingBackward( | |||||
VarNode* src, VarNode* dst, VarNode* diff, const Param& param, | |||||
const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | |||||
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | |||||
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override final; | |||||
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override final; | |||||
}; | }; | ||||
} // namespace opr | } // namespace opr | ||||
@@ -86,7 +86,7 @@ MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint_contig(OperatorNodeBase& o | |||||
//! called in constructor to add output vars | //! called in constructor to add output vars | ||||
MGE_WIN_DECLSPEC_FUC void add_output_vars( | MGE_WIN_DECLSPEC_FUC void add_output_vars( | ||||
OperatorNodeBase& opr, size_t nr_output, bool add_workspace); | OperatorNodeBase& opr, size_t nr_output, bool add_workspace); | ||||
} | |||||
} // namespace megdnn_utils | |||||
/*! | /*! | ||||
* \brief mixin for infer workspace size based on input and output shapes | * \brief mixin for infer workspace size based on input and output shapes | ||||
@@ -344,34 +344,34 @@ private: | |||||
} // namespace mgb | } // namespace mgb | ||||
//! define a megdnn opr wrapper class with 1 input for forward | //! define a megdnn opr wrapper class with 1 input for forward | ||||
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \ | |||||
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
public: \ | |||||
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \ | |||||
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
public: \ | |||||
_name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
} | } | ||||
//! define a megdnn opr wrapper class with 2 inputs for forward | //! define a megdnn opr wrapper class with 2 inputs for forward | ||||
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \ | |||||
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
public: \ | |||||
_name(VarNode* p0, VarNode* p1, const Param& param, \ | |||||
const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, SymbolVar p1, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \ | |||||
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
public: \ | |||||
_name(VarNode* p0, VarNode* p1, const Param& param, \ | |||||
const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, SymbolVar p1, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
} | } | ||||
//! define a megdnn opr wrapper class with 3 inputs for grad | //! define a megdnn opr wrapper class with 3 inputs for grad | ||||
#define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \ | #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \ | ||||
MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \ | MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \ | ||||
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, const Param& param, \ | |||||
const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
_extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, \ | |||||
const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
} | } | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -40,25 +40,25 @@ protected: | |||||
}; | }; | ||||
/* ================= RNG with shape ================= */ | /* ================= RNG with shape ================= */ | ||||
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
\ | |||||
public: \ | |||||
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar shape, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
static SymbolVar make( \ | |||||
ComputingGraph& graph, const TensorShape& shape, \ | |||||
const OperatorNodeConfig& config, const Param& param = {}) { \ | |||||
return make( \ | |||||
var_from_tensor_shape(graph, config, "rng", shape), param, config); \ | |||||
} \ | |||||
void init_output_static_infer_desc() override; \ | |||||
void scn_do_execute() override; \ | |||||
} \ | |||||
; | |||||
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
\ | |||||
public: \ | |||||
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
SymbolVar shape, const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
static SymbolVar make( \ | |||||
ComputingGraph& graph, const TensorShape& shape, \ | |||||
const OperatorNodeConfig& config, const Param& param = {}) { \ | |||||
return make( \ | |||||
var_from_tensor_shape(graph, config, "rng", shape), param, \ | |||||
config); \ | |||||
} \ | |||||
void init_output_static_infer_desc() override; \ | |||||
void scn_do_execute() override; \ | |||||
}; | |||||
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | ||||
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | ||||
@@ -66,20 +66,19 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG) | |||||
#undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS | #undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS | ||||
/* ================= RNG with input ================= */ | /* ================= RNG with input ================= */ | ||||
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
void add_input_layout_constraint() override; \ | |||||
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
\ | |||||
public: \ | |||||
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \ | |||||
_INPUTS(SymbolVar), const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
void init_output_static_infer_desc() override; \ | |||||
void scn_do_execute() override; \ | |||||
} \ | |||||
; | |||||
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
void add_input_layout_constraint() override; \ | |||||
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
\ | |||||
public: \ | |||||
RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \ | |||||
MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \ | |||||
_INPUTS(SymbolVar), const Param& param = {}, \ | |||||
const OperatorNodeConfig& config = {}); \ | |||||
void init_output_static_infer_desc() override; \ | |||||
void scn_do_execute() override; \ | |||||
}; | |||||
/* ================= 1 input ================= */ | /* ================= 1 input ================= */ | ||||
#define _INPUTS(preifx) preifx i0 | #define _INPUTS(preifx) preifx i0 | ||||
@@ -100,7 +99,7 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) | |||||
#undef _INPUTS | #undef _INPUTS | ||||
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | ||||
} // intl | |||||
} // namespace intl | |||||
using UniformRNG = intl::UniformRNG; | using UniformRNG = intl::UniformRNG; | ||||
using GaussianRNG = intl::GaussianRNG; | using GaussianRNG = intl::GaussianRNG; | ||||
@@ -111,16 +110,15 @@ using BetaRNG = intl::BetaRNG; | |||||
using ShuffleRNG = intl::ShuffleRNGForward; | using ShuffleRNG = intl::ShuffleRNGForward; | ||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | ||||
ShuffleRNGBackward, | |||||
intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{ | |||||
ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // { | |||||
public: | public: | ||||
ShuffleRNGBackward( | |||||
VarNode* out_diff, VarNode* indices, VarNode* result_shape, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
ShuffleRNGBackward( | |||||
VarNode* out_diff, VarNode* indices, VarNode* result_shape, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape, | |||||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape, | |||||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
}; | }; | ||||
} // namespace opr | } // namespace opr | ||||
@@ -19,7 +19,8 @@ failed_files = Manager().list() | |||||
def process_file(file, clang_format, write): | def process_file(file, clang_format, write): | ||||
source = open(file, "r").read() | source = open(file, "r").read() | ||||
source = re.sub(r"MGB_DEFINE(?P<r>(.|\n)*?)// +{", "class MGB_DEFINE\g<r>{", source) | |||||
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||||
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||||
result = subprocess.check_output( | result = subprocess.check_output( | ||||
[ | [ | ||||
@@ -33,6 +34,8 @@ def process_file(file, clang_format, write): | |||||
) | ) | ||||
result = result.decode("utf-8") | result = result.decode("utf-8") | ||||
if count: | |||||
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||||
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | ||||
if write: | if write: | ||||