
perf(mgb): use midout in megbrain to reduce binary size

GitOrigin-RevId: ddc8af79af
tags/v1.0.0-rc1
Megvii Engine Team, 4 years ago
commit 80c4705317
32 changed files with 361 additions and 25 deletions
 1. +1  -0   src/CMakeLists.txt
 2. +22 -1   src/gopt/impl/basic_arith/chain.cpp
 3. +12 -0   src/gopt/impl/basic_arith/inplace.cpp
 4. +18 -0   src/gopt/impl/basic_arith/trans.cpp
 5. +31 -0   src/gopt/impl/inference.cpp
 6. +28 -0   src/gopt/impl/misc.cpp
 7. +25 -1   src/gopt/impl/tensor_reformat.cpp
 8. +12 -0   src/gopt/impl/weights_preprocess.cpp
 9. +2  -0   src/opr-mm/impl/collective_comm.cpp
10. +2  -0   src/opr-mm/impl/io_remote.cpp
11. +10 -1   src/opr/impl/basic_arith.cpp
12. +13 -0   src/opr/impl/blas.cpp
13. +4  -0   src/opr/impl/cond.cpp
14. +2  -0   src/opr/impl/dnn/batch_norm.cpp
15. +51 -12  src/opr/impl/dnn/convolution.cpp
16. +2  -0   src/opr/impl/dnn/images2neibs.cpp
17. +6  -0   src/opr/impl/dnn/local.cpp
18. +2  -0   src/opr/impl/dnn/lrn.cpp
19. +2  -0   src/opr/impl/dnn/pooling.cpp
20. +2  -0   src/opr/impl/dnn/roi_align.cpp
21. +4  -0   src/opr/impl/dnn/roi_pooling.cpp
22. +4  -0   src/opr/impl/imgproc.cpp
23. +26 -0   src/opr/impl/indexing.cpp
24. +2  -0   src/opr/impl/io.cpp
25. +2  -0   src/opr/impl/loop/forward.cpp
26. +12 -1   src/opr/impl/misc.cpp
27. +2  -0   src/opr/impl/muxing.cpp
28. +9  -7   src/opr/impl/rand.cpp
29. +6  -1   src/opr/impl/tensor_gen.cpp
30. +22 -1   src/opr/impl/tensor_manip.cpp
31. +12 -0   src/opr/impl/utility.cpp
32. +13 -0   src/plugin/impl/opr_footprint.cpp
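All the changes below follow one pattern: each instrumented region is declared with MIDOUT_DECL and wrapped in a local MIDOUT_B/MIDOUT_E pair keyed by a compile-time string hash, so that regions never reached during a midout trace run can be compiled out of a later build. The code below is not MegEngine code; it is a minimal stand-in (the tag_hash and region_enabled helpers and the SKETCH_B/SKETCH_E macros are hypothetical) that only illustrates how such a begin/end macro pair can make an untraced region disappear from the binary.

// Minimal illustrative sketch, not the real midout implementation (C++17).
#include <cstdint>
#include <cstdio>

// Stand-in for MGB_HASH_STR: a compile-time FNV-1a hash of the region tag.
constexpr std::uint64_t tag_hash(const char* s,
                                 std::uint64_t h = 0xcbf29ce484222325ull) {
    return *s ? tag_hash(s + 1,
                         (h ^ static_cast<std::uint64_t>(*s)) * 0x100000001b3ull)
              : h;
}

// A trace build enables every region; a trimmed build would specialize this to
// false for every tag that never showed up in the recorded trace.
template <std::uint64_t Hash>
constexpr bool region_enabled = true;

#define SKETCH_B(tag) if constexpr (region_enabled<tag_hash(tag)>) {
#define SKETCH_E }

void some_pass_apply() {
    SKETCH_B("SomePass::apply")
    // Pass body: compiled out entirely when the tag is marked unused, so
    // everything reachable only from here can be dropped from the binary.
    std::puts("running SomePass");
    SKETCH_E
}

int main() {
    some_pass_apply();
}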

src/CMakeLists.txt (+1 -0)

@@ -43,6 +43,7 @@ add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES})
target_link_libraries(megbrain PUBLIC mgb_opr_param_defs)
target_include_directories(megbrain
PUBLIC $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
PRIVATE ${PROJECT_SOURCE_DIR}/third_party/midout/src
)
foreach (INCPATH IN LISTS MGB_INC)
target_include_directories(megbrain


src/gopt/impl/basic_arith/chain.cpp (+22 -1)

@@ -15,6 +15,20 @@

#include <deque>

//! TODO: here has to be know some megdnn::opr when there is produced midout.h
//! fix it if there is another graceful way.
#include "megdnn/oprs.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_chain)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;
using namespace opr;

@@ -132,6 +146,7 @@ const char* ExpandFusedArithPass::name() const {
}

void ExpandFusedArithPass::apply(OptState &opt) const {
MIDOUT_B("ExpandFusedArithPass::apply")
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&](OperatorNodeBase *opr) {
using Mode = Elemwise::Mode;
@@ -172,6 +187,7 @@ void ExpandFusedArithPass::apply(OptState &opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ================ NormalizeArithChainPass ================ */
@@ -529,7 +545,9 @@ const char* NormalizeArithChainPass::name() const {
}

void NormalizeArithChainPass::apply(OptState &opt) const {
MIDOUT_B("NormalizeArithChainPass::apply")
Impl{opt};
MIDOUT_E
}

/* ================ ReorderArithChainPass ================ */
@@ -737,7 +755,9 @@ const char* ReorderArithChainPass::name() const {
}

void ReorderArithChainPass::apply(OptState &opt) const {
MIDOUT_B("ReorderArithChainPass::apply")
Impl{*this, opt};
MIDOUT_E
}

/* ================ ArithFusePass ================ */
@@ -944,8 +964,9 @@ const char* ArithFusePass::name() const {
}

void ArithFusePass::apply(OptState &opt) const {
MIDOUT_B("ArithFusePass::apply")
Impl{opt};
MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
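Expanding the MIDOUT_B/MIDOUT_E wrappers defined in this file (a sketch of the macro substitution, derived only from the #define lines above), an instrumented pass body ends up looking like this:

// ArithFusePass::apply after the local wrappers expand (sketch):
void ArithFusePass::apply(OptState& opt) const {
    MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR("ArithFusePass::apply"))) {
        Impl{opt};
    }
    MIDOUT_END();
}

The intent is that a region whose tag was never reached during a trace run can be compiled out in a later midout-trimmed build, together with everything only that region references.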


src/gopt/impl/basic_arith/inplace.cpp (+12 -0)

@@ -19,6 +19,16 @@

#include <cmath>

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_inplace)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_inplace, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace opr;
using namespace gopt;
@@ -150,8 +160,10 @@ bool gopt::has_inplace_basic_arith_opt(const cg::OperatorNodeBase& opr) {

const inplace_optimize::OptimizerRegistry&
inplace_optimize::optimizer_registry() {
MIDOUT_B("inplace_optimize::optimizer_registry")
static OptimizerRegistry ret = make_optimizer_registry();
return ret;
MIDOUT_E
}

inplace_optimize::OptimizerRegistry


src/gopt/impl/basic_arith/trans.cpp (+18 -0)

@@ -13,6 +13,20 @@
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/serializer.h"

//! TODO: here has to be know some megdnn::opr when there is produced midout.h
//! fix it if there is another graceful way.
#include "megdnn/oprs.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_trans)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_trans, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;

@@ -284,7 +298,9 @@ const char* ArithMulDistributePass::name() const {
}

void ArithMulDistributePass::apply(OptState &opt) const {
MIDOUT_B("ArithMulDistributePass::apply")
Impl{*this, opt};
MIDOUT_E
}

/* ================ FinalArithTransformPass ================ */
@@ -488,7 +504,9 @@ const char* FinalArithTransformPass::name() const {
}

void FinalArithTransformPass::apply(OptState &opt) const {
MIDOUT_B("FinalArithTransformPass::apply")
Impl{*this, opt};
MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}


src/gopt/impl/inference.cpp (+31 -0)

@@ -27,6 +27,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/utils/hash_ct.h"

#include "megdnn/tensor_format.h"

@@ -36,6 +37,16 @@

#include "megbrain/gopt/misc.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_inference)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;

@@ -430,7 +441,9 @@ ParamRedistributePass::Impl::Impl(OptState &state):
}

void ParamRedistributePass::apply(OptState &state) const {
MIDOUT_B("ParamRedistributePass::apply")
Impl{state};
MIDOUT_E
}

/* ================ ParamFusePass ================ */
@@ -512,6 +525,7 @@ const char* ParamFusePass::name() const {
}

void ParamFusePass::apply(OptState &state) const {
MIDOUT_B("ParamFusePass::apply")
auto rewriter = state.graph().make_rewriter();
auto cg = state.graph().comp_graph();

@@ -613,6 +627,7 @@ void ParamFusePass::apply(OptState &state) const {

state.graph().iter(replace_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ================ One2OneOprReplacePass ================ */
@@ -621,6 +636,7 @@ const char* ConvertF32ToF16Pass::name() const {
}

void ConvertF32ToF16Pass::apply(OptState& state) const {
MIDOUT_B("ConvertF32ToF16Pass::apply")
state.set_var_replace_check_flag(m_var_replace_check_flag);
auto rewriter = state.graph().make_rewriter();
VarNodeArray new_inp_cache;
@@ -674,6 +690,7 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
auto opr = endpoints[0].node()->owner_opr();
state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
rewriter.apply_inplace();
MIDOUT_E
}

std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
@@ -940,6 +957,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
/* ================ ConvertFormatPass ================ */

void ConvertFormatPass::apply(OptState& state) const {
MIDOUT_B("ConvertFormatPass::apply")
state.set_var_replace_check_flag(m_var_replace_check_flag);
auto rewriter = state.graph().make_rewriter();
VarNodeArray new_inp_cache;
@@ -994,9 +1012,11 @@ void ConvertFormatPass::apply(OptState& state) const {
};
state.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
MIDOUT_B("ConvertFormatPass::make")
auto filter_mode =
[](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
@@ -1551,6 +1571,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::GroupLocalForward::typeinfo()] =
relayout_first_inp_to_chw;
return ret;
MIDOUT_E
}

/* ================ ConvertBatchNormPass ================ */
@@ -1559,6 +1580,7 @@ const char* ConvertBatchNormToElemwisePass::name() const {
}

void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
auto rewriter = state.graph().make_rewriter();
auto on_opr = [&](OperatorNodeBase* opr) {
if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
@@ -1586,6 +1608,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
state.graph().iter(on_opr);

rewriter.apply_inplace();
MIDOUT_E
}

/* ================ FuseConvBiasNonlinPass ================ */
@@ -1594,6 +1617,7 @@ const char* FuseConvBiasNonlinPass::name() const {
}

void FuseConvBiasNonlinPass::apply(OptState& state) const {
MIDOUT_B("FuseConvBiasNonlinPass::apply")
std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
state.graph().iter([&m_deps](OperatorNodeBase* opr) {
for (auto& inp : opr->input()) {
@@ -1843,6 +1867,7 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
state.graph().iter(on_opr);

rewriter.apply_inplace();
MIDOUT_E
}

/* ================ FuseConvBiasZPass ================ */
@@ -1851,6 +1876,7 @@ const char* FuseConvBiasZPass::name() const {
}

void FuseConvBiasZPass::apply(OptState& state) const {
MIDOUT_B("FuseConvBiasZPass::apply")
UniqReaderCheck uniq_reader_check{state.graph()};

auto rewriter = state.graph().make_rewriter();
@@ -1977,6 +2003,7 @@ void FuseConvBiasZPass::apply(OptState& state) const {
state.graph().iter(on_opr);

rewriter.apply_inplace();
MIDOUT_E
}

/* ================ FuseDeconvCvtPass ================ */
@@ -1986,6 +2013,7 @@ const char* FuseDeconvCvtPass::name() const {

void FuseDeconvCvtPass::apply(OptState& state) const {
MIDOUT_B("FuseDeconvCvtPass::apply")
std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
state.graph().iter([&m_deps](OperatorNodeBase* opr) {
for (auto& inp : opr->input()) {
@@ -2036,6 +2064,7 @@ void FuseDeconvCvtPass::apply(OptState& state) const {
state.graph().iter(on_opr);

rewriter.apply_inplace();
MIDOUT_E
}

/* ================ ParamMergePass ================ */
@@ -2044,10 +2073,12 @@ const char* ParamMergePass::name() const {
}

void ParamMergePass::apply(OptState& opt_state) const {
MIDOUT_B("ParamMergePass::apply")
param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
opt_state);
param_merge<opr::SharedDeviceTensorWithFormat,
opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/gopt/impl/misc.cpp (+28 -0)

@@ -19,6 +19,16 @@
#include "megbrain/serialization/opr_shallow_copy.h"
#include "../../core/impl/graph/cg_impl.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_misc)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_misc, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;

@@ -29,6 +39,7 @@ const char* RemoveNonComputingOprPass::name() const {
}

void RemoveNonComputingOprPass::apply(OptState& opt) const {
MIDOUT_B("RemoveNonComputingOprPass::apply")
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&](OperatorNodeBase* opr) {
auto type = opr->dyn_typeinfo();
@@ -75,6 +86,7 @@ void RemoveNonComputingOprPass::apply(OptState& opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ================ ExpandVirtualGradPass ================ */
@@ -84,6 +96,7 @@ const char* ExpandVirtualGradPass::name() const {
}

void ExpandVirtualGradPass::apply(OptState& opt) const {
MIDOUT_B("ExpandVirtualGradPass::apply")
#if MGB_ENABLE_GRAD
opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
auto rewriter = opt.graph().make_rewriter();
@@ -111,6 +124,7 @@ void ExpandVirtualGradPass::apply(OptState& opt) const {
#else
MGB_MARK_USED_VAR(opt);
#endif
MIDOUT_E
}

/* ================= DelayBroadcastPass ================ */
@@ -144,6 +158,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {
// remove them from the chain, and add them back right after the endpoint.

// TypeCvt's order may change, so disable the check.
MIDOUT_B("DelayBroadcastPass::apply")
opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);

auto unique_reader_chk = UniqReaderCheck{opt.graph()};
@@ -325,6 +340,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ======================= RecompTypeCvtPass ====================== */
@@ -334,6 +350,7 @@ const char* RecompTypeCvtPass::name() const {
}

void RecompTypeCvtPass::apply(OptState& opt) const {
MIDOUT_B("RecompTypeCvtPass::apply")
auto rewriter = opt.graph().make_rewriter();

auto allowed_typecvt = [](OperatorNodeBase* opr) -> OperatorNodeBase* {
@@ -399,6 +416,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const {
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ======================= CombineAstypeAndReducePass ====================== */
@@ -408,6 +426,7 @@ const char* CombineAstypeAndReducePass::name() const {
}

void CombineAstypeAndReducePass::apply(OptState& opt) const {
MIDOUT_B("CombineAstypeAndReducePass::apply")
auto rewriter = opt.graph().make_rewriter();

using DataType = opr::Reduce::Param::DataType;
@@ -453,6 +472,7 @@ void CombineAstypeAndReducePass::apply(OptState& opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/* ================ CondExecConstPredicateFolding ================ */
@@ -462,6 +482,7 @@ const char* CondExecConstPredicateFolding::name() const {

void CondExecConstPredicateFolding::apply(OptState& opt) const {
#if MGB_ENABLE_COND_EXEC
MIDOUT_B("CondExecConstPredicateFolding::apply")
if (!cg::ExecutionMask::have_alive_instance()) {
return;
}
@@ -605,6 +626,7 @@ void CondExecConstPredicateFolding::apply(OptState& opt) const {
}

rewriter.apply_inplace();
MIDOUT_E

#endif // MGB_ENABLE_COND_EXEC
}
@@ -632,6 +654,7 @@ bool RemoveRedundantTypeCvtPass::should_remove(DType A, DType B) {
}

void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
MIDOUT_B("RemoveRedundantTypeCvtPass::apply")
auto rewriter = opt.graph().make_rewriter();

auto on_opr = [&](OperatorNodeBase* opr) {
@@ -656,6 +679,7 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

#if MGB_ENABLE_OPR_MM
@@ -668,6 +692,7 @@ const char* PackAllReduceScanPass::name() const {
}

void PackAllReduceScanPass::apply(OptState& opt) const {
MIDOUT_B("PackAllReduceScanPass::apply")
auto comp_graph = opt.graph().comp_graph();
if (comp_graph->options().allreduce_pack_max_size == 0) return;
auto cb_scan = [this] (OperatorNodeBase* opr) {
@@ -682,6 +707,7 @@ void PackAllReduceScanPass::apply(OptState& opt) const {
}
};
opt.graph().iter(cb_scan);
MIDOUT_E
}

bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
@@ -856,6 +882,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
}

void PackAllReduceReplacePass::apply(OptState& opt) const {
MIDOUT_B("PackAllReduceReplacePass::apply")
// get graph options
auto comp_graph = opt.graph().comp_graph();
size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024;
@@ -917,6 +944,7 @@ void PackAllReduceReplacePass::apply(OptState& opt) const {
};
opt.graph().iter(cb_replace);
rewriter.apply_inplace();
MIDOUT_E
}

#else


src/gopt/impl/tensor_reformat.cpp (+25 -1)

@@ -36,6 +36,16 @@
#endif

#include "megbrain/gopt/misc.h"
#include "megbrain/utils/hash_ct.h"

#include "midout.h"

MIDOUT_DECL(megbrain_tensor_reformat)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;
@@ -755,8 +765,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
}

void TensorReformatPass::apply(OptState& opt) const {
MIDOUT_B("TensorReformatPass::apply")
insert_pass(opt);
translate_pass(opt);
MIDOUT_E
}

/* ================ EnableTensorCorePass =============== */
@@ -773,6 +785,7 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,

std::unique_ptr<EnableTensorCorePass>
EnableTensorCorePass::make_tensorcore_converter() {
MIDOUT_B("EnableTensorCorePass::make")
// replace rule for conv bias opr
auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
@@ -1111,6 +1124,7 @@ EnableTensorCorePass::make_tensorcore_converter() {
replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
return ret;
MIDOUT_E
}

/* ================ EnableCHWN4Pass =============== */
@@ -1125,6 +1139,7 @@ VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
}

std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
MIDOUT_B("EnableCHWN4Pass::make")
auto ret = std::make_unique<EnableCHWN4Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
auto&& replace_func = ret->m_opr_replace_func;
@@ -1381,6 +1396,7 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
return ret;
MIDOUT_E
}

/* ================ EnableNCHW4Pass ================ */
@@ -1395,6 +1411,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
}

std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
MIDOUT_B("EnableNCHW4Pass::make")
auto ret = std::make_unique<EnableNCHW4Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
using RelayoutMode = RelayoutPlaceholder::LayoutType;
@@ -1772,6 +1789,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
return ret;
MIDOUT_E
}

/* ================ EnableNchwxxPass =============== */
@@ -2140,6 +2158,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){

std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
size_t pack_c_size) {
MIDOUT_B("EnableNchwxxPass::make")
auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
std::string convter_pass_name = "conv_format_nchw88";
@@ -2149,6 +2168,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
ret->fill_opr_convert_fun(pack_c_size);
ret->set_name(convter_pass_name);
return ret;
MIDOUT_E
}

/* ================ EnableNchw44DotPass =============== */
@@ -2164,6 +2184,7 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,

std::unique_ptr<EnableNchw44DotPass>
EnableNchw44DotPass::make_nchw44_dot_converter() {
MIDOUT_B("EnableNchw44DotPass::make")
auto ret = std::make_unique<EnableNchw44DotPass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter
@@ -2384,6 +2405,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
return ret;
MIDOUT_E
}

/* ==================== ShuffleShuffleRemovePass ================= */
@@ -2961,9 +2983,11 @@ const char* ShuffleShuffleRemovePass::name() const {
}

void ShuffleShuffleRemovePass::apply(OptState& opt) const {
MIDOUT_B("ShuffleShuffleRemovePass::apply")
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
VarReplaceCheckFlag::CHECK_DTYPE);
Impl{opt};
MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/gopt/impl/weights_preprocess.cpp (+12 -0)

@@ -14,6 +14,16 @@
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_weight_preprocess)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_weight_preprocess, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;
using namespace cg;
@@ -23,6 +33,7 @@ const char* WinogradTransformReplacePass::name() const {
}

void WinogradTransformReplacePass::apply(OptState& opt) const {
MIDOUT_B("WinogradTransformReplacePass::apply")
auto rewriter = opt.graph().make_rewriter();
ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
@@ -174,6 +185,7 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {

opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

/**


src/opr-mm/impl/collective_comm.cpp (+2 -0)

@@ -855,10 +855,12 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CollectiveComm) {
mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
return opr.grad(out_grad[0]);
}
#endif

/* ===================== shallow copy ===================== */




src/opr-mm/impl/io_remote.cpp (+2 -0)

@@ -109,6 +109,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
return prop;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RemoteSend) {
mgb_assert(opr.is_grad());
return RemoteRecv::make(opr.key() + ":grad",
@@ -118,6 +119,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
opr.input(0)->shape(), opr.input(0)->dtype())
.node();
}
#endif

/* ===================== RemoteRecv ===================== */




src/opr/impl/basic_arith.cpp (+10 -1)

@@ -552,6 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
opr->exec(inp, out);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Elemwise) {
SymbolVar i[5];
SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
@@ -730,6 +731,7 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
result = -result;
return result.node();
}
#endif

VarNode* Elemwise::sum_grad_list(VarNode *wrt, VarNodeArray &grads) {
mgb_assert(!grads.empty());
@@ -814,6 +816,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
return ret;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TypeCvt) {
MGB_MARK_USED_VAR(wrt_idx);
auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
@@ -826,6 +829,7 @@ MGB_IMPL_OPR_GRAD(TypeCvt) {
}
return TypeCvt::make(out_grad[0], opr.input(0)->dtype()).node();
}
#endif

void TypeCvt::mem_plan_fwd_in2out_writable() {
if (input(0)->dtype().size() == output(0)->dtype().size() &&
@@ -963,10 +967,12 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AddUpdate) {
// actually valid, just not implemented
return InvalidGrad::make(opr, wrt_idx);
}
#endif

/* =========================== Reduce =========================== */

@@ -1698,6 +1704,7 @@ void Reduce::create_megdnn_opr() {
create_operator<megdnn::Reduce>());
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reduce) {
for (size_t i = 1; i < opr.output().size(); ++ i)
mgb_assert(!out_grad[i]);
@@ -1733,7 +1740,7 @@ MGB_IMPL_OPR_GRAD(Reduce) {
grad = TypeCvt::make(grad, iv.dtype());
return grad.node();
}
#endif

void Reduce::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
@@ -1783,11 +1790,13 @@ void PowC::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PowC) {
auto exp = opr.param().exp;
return (exp * SymbolVar{out_grad[0]} *
PowC::make(opr.input(0), exp - 1, opr.config()))
.node();
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

src/opr/impl/blas.cpp (+13 -0)

@@ -106,6 +106,7 @@ void MatrixMul::scn_do_execute() {
MGB_FINALLY({ tparam = this->param(); });
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixMul) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
@@ -128,6 +129,7 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
}
return grad.node();
}
#endif

/* ================= BatchedMatrixMul ================= */

@@ -224,6 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
MGB_FINALLY({ tparam = this->param(); });
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad");
@@ -251,6 +254,7 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
}
return grad.node();
}
#endif

/* ================= Dot ================= */

@@ -327,6 +331,7 @@ void Dot::add_input_layout_constraint() {
input(1)->add_layout_constraint(check);
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dot) {
auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
auto ishp0 = opr::GetVarShape::make(opr.input(0)),
@@ -336,6 +341,7 @@ MGB_IMPL_OPR_GRAD(Dot) {
Broadcast::make(mul(out_grad[0], other_input), max_ishp),
wrt_idx ? ishp1 : ishp0).node();
}
#endif

SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1,
const OperatorNodeConfig &config) {
@@ -350,6 +356,8 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MatrixInverse) {
SymbolVar a = opr.output(0);
// TODO: use unified MatrixMul interface when we have it
@@ -364,6 +372,7 @@ MGB_IMPL_OPR_GRAD(MatrixInverse) {
a_bnn);
return da.reshape(a.symshape()).node();
}
#endif

/* ================= SVD ================= */

@@ -386,6 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
}
}

#ifdef MGB_ENABLE_GRAD
namespace {

/*!
@@ -477,7 +487,9 @@ OP(*, {}, {})
#undef OP

} // anonymous namespace
#endif

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SVD) {
/**
* The formula is copied from
@@ -555,6 +567,7 @@ MGB_IMPL_OPR_GRAD(SVD) {
I_n - matmul(v, v, param01)));
return ret.reshape(a.symshape()).node();
}
#endif

SymbolVarArray SVD::make(const SymbolVar& src, const Param& param,
const OperatorNodeConfig& config) {


src/opr/impl/cond.cpp (+4 -0)

@@ -818,6 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
return input;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondExecMark) {
if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
return nullptr;
@@ -841,6 +842,7 @@ MGB_IMPL_OPR_GRAD(CondExecMark) {
{1, grad_mode}, OperatorNodeConfig{})
->output(0);
}
#endif

/* ============================= CondExecMerge ============================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge);
@@ -1225,6 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
return ret;
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondExecMerge) {
using Mode = CondExecMerge::Param::Mode;
if (opr.param().mode == Mode::SUM_COND_OUT &&
@@ -1259,6 +1262,7 @@ MGB_IMPL_OPR_GRAD(CondExecMerge) {
OperatorNodeConfig{og->comp_node()})
->output(0);
}
#endif

void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) {
if (!ExecutionMask::have_alive_instance()) {


src/opr/impl/dnn/batch_norm.cpp (+2 -0)

@@ -230,6 +230,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
}
}

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchNormForward) {
mgb_assert(wrt_idx < 5);
if (wrt_idx < 3) {
@@ -242,6 +243,7 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
return nullptr;
}
}
#endif

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);




+ 51
- 12
src/opr/impl/dnn/convolution.cpp View File

@@ -18,6 +18,19 @@


#include "megdnn/oprs/utils.h" #include "megdnn/oprs/utils.h"


//! TODO: here has to be know some megdnn::opr when there is produced midout.h
//! fix it if there is another graceful way.
#include "megdnn/oprs.h"

#include "midout.h"

MIDOUT_DECL(megbrain_opr_convolution)
#define MIDOUT_B(...) \
MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) {
#define MIDOUT_E \
} \
MIDOUT_END();

#include "../internal/megdnn_opr_wrapper.inl" #include "../internal/megdnn_opr_wrapper.inl"


#include <array> #include <array>
@@ -230,6 +243,7 @@ class TimedProfiler {
static constexpr int arity_in = OprArityTrait<Opr>::arity_in; static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
static constexpr int arity_out = OprArityTrait<Opr>::arity_out; static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
static constexpr int arity = OprArityTrait<Opr>::arity; static constexpr int arity = OprArityTrait<Opr>::arity;

using ConvTensorShapes = std::array<TensorShape, arity>; using ConvTensorShapes = std::array<TensorShape, arity>;


public: public:
@@ -295,6 +309,7 @@ double TimedProfiler<Opr>::init_timeout_setting() {
template <typename Opr> template <typename Opr>
typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
const TParam& raw_param) { const TParam& raw_param) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl")))
auto&& param = raw_param.as_single_pod<Param>(); auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn); auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);
@@ -401,14 +416,17 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(


mgb_assert(ev_start->finished()); mgb_assert(ev_start->finished());
return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)}); return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)});
MIDOUT_E
}; };


template <typename Opr> template <typename Opr>
void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) { void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device")))
auto&& param = raw_param.as_single_pod<Param>(); auto&& param = raw_param.as_single_pod<Param>();
CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc); CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
// wait for cuda init, so its time does not get accounted in timeout // wait for cuda init, so its time does not get accounted in timeout
cn.sync(); cn.sync();
MIDOUT_E
} }


/* =================== AlgoChooser =================== */ /* =================== AlgoChooser =================== */
@@ -426,6 +444,7 @@ class AlgoChooser {
static constexpr int arity_in = OprArityTrait<Opr>::arity_in; static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
static constexpr int arity_out = OprArityTrait<Opr>::arity_out; static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
static constexpr int arity = OprArityTrait<Opr>::arity; static constexpr int arity = OprArityTrait<Opr>::arity;

using ImplAlgo = typename Opr::Algorithm*; using ImplAlgo = typename Opr::Algorithm*;
using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr; using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr;
using ConvTensorLayouts = std::array<TensorLayout, arity>; using ConvTensorLayouts = std::array<TensorLayout, arity>;
@@ -473,8 +492,8 @@ class AlgoChooser {
//! put first //! put first
std::vector<ImplAlgo> get_all_candidates() const { std::vector<ImplAlgo> get_all_candidates() const {
auto heu = choose_by_heuristic(); auto heu = choose_by_heuristic();
auto&& ret = OprArityTrait<Opr>::get_all_algorithms(
m_megdnn_opr, m_layouts);
auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr,
m_layouts);
bool found = false; bool found = false;
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
if (ret[i] == heu) { if (ret[i] == heu) {
@@ -491,7 +510,7 @@ class AlgoChooser {


//! get candidate algos with workspace limit. //! get candidate algos with workspace limit.
std::vector<ImplAlgo> get_all_candidates_with_workspace_limit() const { std::vector<ImplAlgo> get_all_candidates_with_workspace_limit() const {
auto && all_algos = get_all_candidates();
auto&& all_algos = get_all_candidates();
auto opr = m_mgb_opr; auto opr = m_mgb_opr;
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
opr->owner_graph(), opr->comp_node(), opr->owner_graph(), opr->comp_node(),
@@ -633,16 +652,16 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
algo->name(), str_on_inp_shape.c_str()); algo->name(), str_on_inp_shape.c_str());
timer.reset(); timer.reset();
MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); } MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); }
MGB_CATCH(std::exception & exc,
{
mgb_log_warn("caught exception during %s: %s",
msg.c_str(), exc.what());
continue;
})
MGB_CATCH(std::exception & exc, {
mgb_log_warn("caught exception during %s: %s", msg.c_str(),
exc.what());
continue;
})
MGB_CATCH(..., { MGB_CATCH(..., {
mgb_log_warn("caught exception during %s", msg.c_str()); mgb_log_warn("caught exception during %s", msg.c_str());
continue; continue;
}) if (!cur_rst.valid()) {
})
if (!cur_rst.valid()) {
mgb_log_warn("timeout when %s; timeout setting: %.3fsec", mgb_log_warn("timeout when %s; timeout setting: %.3fsec",
msg.c_str(), cur_timeout); msg.c_str(), cur_timeout);
continue; continue;
@@ -680,6 +699,7 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
template <typename Opr> template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile( typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
ExeContext& ctx, bool require_reproducible, bool enable_update) { ExeContext& ctx, bool require_reproducible, bool enable_update) {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
auto opr = ctx.mgb_opr(); auto opr = ctx.mgb_opr();
if (opr->owner_graph()->options().no_profiling_on_shape_change) { if (opr->owner_graph()->options().no_profiling_on_shape_change) {
auto algo = ctx.megdnn_opr()->execution_policy().algorithm; auto algo = ctx.megdnn_opr()->execution_policy().algorithm;
@@ -720,6 +740,7 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
opr->owner_graph(), opr->comp_node(), opr->owner_graph(), opr->comp_node(),
opr->execution_policy().workspace_limit)); opr->execution_policy().workspace_limit));
mgb_trap(); mgb_trap();
MIDOUT_E
} }


template <> template <>
@@ -748,7 +769,7 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext::
if (m_layouts[1].dtype.enumv() == DTypeEnum::QuantizedS8 && if (m_layouts[1].dtype.enumv() == DTypeEnum::QuantizedS8 &&
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44) { param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44) {
if (winograd_preprocess_opr->param().format == if (winograd_preprocess_opr->param().format ==
megdnn::param::MatrixMul::Format::MK4){
megdnn::param::MatrixMul::Format::MK4) {
winograd_preprocess_opr->param().compute_mode = winograd_preprocess_opr->param().compute_mode =
ConvBias::Param::ComputeMode::FLOAT32; ConvBias::Param::ComputeMode::FLOAT32;
param.opr_param.compute_mode = param.opr_param.compute_mode =
@@ -941,6 +962,7 @@ void ConvolutionForward::init_output_dtype() {
output(0)->dtype(output_dtype); output(0)->dtype(output_dtype);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionForward) { MGB_IMPL_OPR_GRAD(ConvolutionForward) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad"); "only float data type supported for grad");
@@ -960,6 +982,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionForward) {
return grad.node(); return grad.node();
} }
} }
#endif


size_t ConvolutionForward::get_workspace_size_bytes( size_t ConvolutionForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes, const TensorShapeArray& input_shapes,
@@ -1086,6 +1109,7 @@ void ConvolutionBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
mgb_assert(!out_grad[1]); mgb_assert(!out_grad[1]);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -1101,6 +1125,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
} }
return nullptr; return nullptr;
} }
#endif


/* ==================== ConvolutionBackwardFilter ==================== */ /* ==================== ConvolutionBackwardFilter ==================== */
IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter"); IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter");
@@ -1138,6 +1163,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
megdnn_opr(), this); megdnn_opr(), this);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
mgb_assert(!out_grad[1]); mgb_assert(!out_grad[1]);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -1153,6 +1179,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
} }
return nullptr; return nullptr;
} }
#endif
/* ==================== Convolution3DForward ==================== */ /* ==================== Convolution3DForward ==================== */


IMPL_CONV(Convolution3DForward, "conv3d_fwd"); IMPL_CONV(Convolution3DForward, "conv3d_fwd");
@@ -1192,6 +1219,7 @@ void Convolution3DForward::init_output_dtype() {
} }
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Convolution3DForward) { MGB_IMPL_OPR_GRAD(Convolution3DForward) {
mgb_assert(opr.param().data_type == mgb_assert(opr.param().data_type ==
Convolution3DForward::Param::DataType::FLOAT, Convolution3DForward::Param::DataType::FLOAT,
@@ -1212,6 +1240,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DForward) {
return grad.node(); return grad.node();
} }
} }
#endif


size_t Convolution3DForward::get_workspace_size_bytes( size_t Convolution3DForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes, const TensorShapeArray& input_shapes,
@@ -1285,6 +1314,7 @@ void Convolution3DBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
mgb_assert(!out_grad[1]); mgb_assert(!out_grad[1]);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -1300,6 +1330,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
} }
return nullptr; return nullptr;
} }
#endif


/* ==================== Convolution3DBackwardFilter ==================== */ /* ==================== Convolution3DBackwardFilter ==================== */
IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter"); IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter");
@@ -1658,6 +1689,7 @@ size_t LocalShareForward::get_workspace_size_bytes(
megdnn_opr(), this); megdnn_opr(), this);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareForward) { MGB_IMPL_OPR_GRAD(LocalShareForward) {
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
"only float data type supported for grad"); "only float data type supported for grad");
@@ -1677,6 +1709,7 @@ MGB_IMPL_OPR_GRAD(LocalShareForward) {
return grad.node(); return grad.node();
} }
} }
#endif


/* ===================== LocalShareBackwardData ==================== */ /* ===================== LocalShareBackwardData ==================== */


@@ -1737,6 +1770,7 @@ void LocalShareBackwardData::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
mgb_assert(!out_grad[1]); mgb_assert(!out_grad[1]);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -1752,6 +1786,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
} }
return nullptr; return nullptr;
} }
#endif


/* ==================== LocalShareBackwardFilter ==================== */ /* ==================== LocalShareBackwardFilter ==================== */


@@ -1792,6 +1827,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes(
megdnn_opr(), this); megdnn_opr(), this);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
mgb_assert(!out_grad[1]); mgb_assert(!out_grad[1]);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -1805,6 +1841,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
} }
return nullptr; return nullptr;
} }
#endif


/* ===================== DeformableConvForward ==================== */ /* ===================== DeformableConvForward ==================== */


@@ -1869,6 +1906,7 @@ size_t DeformableConvForward::get_workspace_size_bytes(
megdnn_opr(), this); megdnn_opr(), this);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DeformableConvForward) { MGB_IMPL_OPR_GRAD(DeformableConvForward) {
mgb_assert(opr.input(0)->dtype() == dtype::Float32(), mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
"only float data type supported for grad"); "only float data type supported for grad");
@@ -1888,6 +1926,7 @@ MGB_IMPL_OPR_GRAD(DeformableConvForward) {
SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]}; SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]};
return grads[wrt_idx].node(); return grads[wrt_idx].node();
} }
#endif


/* ==================== DeformableConvBackwardData ==================== */ /* ==================== DeformableConvBackwardData ==================== */


@@ -2265,4 +2304,4 @@ void BatchConvBiasForward::init_output_format() {
#undef IMPL_CONV #undef IMPL_CONV
#undef MGB_FOREACH_FASTRUN_OPR #undef MGB_FOREACH_FASTRUN_OPR


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
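Note: the hunks in this file, and most of the ones in the files below, apply one and the same transformation: an MGB_IMPL_OPR_GRAD body is wrapped in #ifdef MGB_ENABLE_GRAD ... #endif so that inference-only builds which leave MGB_ENABLE_GRAD undefined never compile the backward-path code at all, which is part of how this commit shrinks the binary. A minimal sketch of the guarded form follows; FooForward, FooBackward and the param() call are hypothetical placeholders, while opr, wrt_idx and out_grad are the implicit arguments visible in the surrounding hunks.

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(FooForward) {
    // in this illustrative operator only the data input has a gradient
    mgb_assert(wrt_idx == 0);
    SymbolVar grad = FooBackward::make(
            opr.input(0), opr.output(0), out_grad[0], opr.param());
    return grad.node();
}
#endif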

+ 2
- 0
src/opr/impl/dnn/images2neibs.cpp View File

@@ -20,11 +20,13 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward);
MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs")


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Images2NeibsForward) { MGB_IMPL_OPR_GRAD(Images2NeibsForward) {
mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]);
return Images2NeibsBackward::make( return Images2NeibsBackward::make(
out_grad[0], opr.input(0), opr.param()).node(); out_grad[0], opr.input(0), opr.param()).node();
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward);
MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false); MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false);


+ 6
- 0
src/opr/impl/dnn/local.cpp View File

@@ -20,10 +20,13 @@ using namespace opr;


MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward);
MEGDNN_OPR_INIT2(LocalForward, "local") MEGDNN_OPR_INIT2(LocalForward, "local")

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalForward) { MGB_IMPL_OPR_GRAD(LocalForward) {
return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>( return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>(
opr, wrt_idx, out_grad); opr, wrt_idx, out_grad);
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData); MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData);
MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false); MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false);
@@ -34,10 +37,13 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false);


MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward);
MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") MEGDNN_OPR_INIT2(GroupLocalForward, "glocal")

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GroupLocalForward) { MGB_IMPL_OPR_GRAD(GroupLocalForward) {
return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>( return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>(
opr, wrt_idx, out_grad); opr, wrt_idx, out_grad);
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData); MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData);
MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false); MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false);


+ 2
- 0
src/opr/impl/dnn/lrn.cpp View File

@@ -20,12 +20,14 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward);
MEGDNN_OPR_INIT1(LRNForward, "lrn") MEGDNN_OPR_INIT1(LRNForward, "lrn")


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LRNForward) { MGB_IMPL_OPR_GRAD(LRNForward) {
mgb_assert(wrt_idx == 0); mgb_assert(wrt_idx == 0);
SymbolVar grad = LRNBackward::make( SymbolVar grad = LRNBackward::make(
opr.input(0), opr.output(0), out_grad[0], opr.param()); opr.input(0), opr.output(0), out_grad[0], opr.param());
return grad.node(); return grad.node();
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward);
MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true); MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true);


+ 2
- 0
src/opr/impl/dnn/pooling.cpp View File

@@ -19,12 +19,14 @@ using namespace opr;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);
MEGDNN_OPR_INIT1(PoolingForward, "pooling") MEGDNN_OPR_INIT1(PoolingForward, "pooling")


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PoolingForward) { MGB_IMPL_OPR_GRAD(PoolingForward) {
mgb_assert(wrt_idx == 0); mgb_assert(wrt_idx == 0);
SymbolVar grad = PoolingBackward::make( SymbolVar grad = PoolingBackward::make(
opr.input(0), opr.output(0), out_grad[0], opr.param()); opr.input(0), opr.output(0), out_grad[0], opr.param());
return grad.node(); return grad.node();
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward);
MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true); MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true);


+ 2
- 0
src/opr/impl/dnn/roi_align.cpp View File

@@ -40,6 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois,
src.node(), rois.node(), param, config); src.node(), rois.node(), param, config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ROIAlignForward) { MGB_IMPL_OPR_GRAD(ROIAlignForward) {
if (out_grad[1]) { if (out_grad[1]) {
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -55,6 +56,7 @@ MGB_IMPL_OPR_GRAD(ROIAlignForward) {
return nullptr; return nullptr;
} }
} }
#endif


/* ==================== ROIAlignBackward ==================== */ /* ==================== ROIAlignBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward);


+ 4
- 0
src/opr/impl/dnn/roi_pooling.cpp View File

@@ -84,6 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes(
input_shapes, output_shapes); input_shapes, output_shapes);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ROIPoolingForward) { MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
if (out_grad[1] || wrt_idx == 2) { if (out_grad[1] || wrt_idx == 2) {
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -98,6 +99,7 @@ MGB_IMPL_OPR_GRAD(ROIPoolingForward) {
return nullptr; return nullptr;
} }
} }
#endif


void ROIPoolingForward::scn_do_execute() { void ROIPoolingForward::scn_do_execute() {
return intl::MegDNNOprMethInvoker<megdnn::ROIPoolingForward>:: return intl::MegDNNOprMethInvoker<megdnn::ROIPoolingForward>::
@@ -146,6 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make(
return all[0]; return all[0];
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2


@@ -168,6 +171,7 @@ MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) {
} }
return nullptr; return nullptr;
} }
#endif


/* ==================== DeformablePSROIPoolingBackward ==================== */ /* ==================== DeformablePSROIPoolingBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward);


+ 4
- 0
src/opr/impl/imgproc.cpp View File

@@ -127,6 +127,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps); record_megdnn_opr(deps);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
mgb_assert(opr.input().size() == 3, mgb_assert(opr.input().size() == 3,
"backward with mat_idx is currently unsupported"); "backward with mat_idx is currently unsupported");
@@ -145,6 +146,7 @@ MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) {
} else } else
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
} }
#endif


/* ====================== WarpPerspectiveBackwardData ====================== */ /* ====================== WarpPerspectiveBackwardData ====================== */


@@ -234,6 +236,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray &deps) {
record_megdnn_opr(deps); record_megdnn_opr(deps);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ResizeForward) { MGB_IMPL_OPR_GRAD(ResizeForward) {
mgb_assert(opr.input().size() == 2); mgb_assert(opr.input().size() == 2);
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -243,6 +246,7 @@ MGB_IMPL_OPR_GRAD(ResizeForward) {
} else } else
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
} }
#endif


/* ====================== ResizeBackward ====================== */ /* ====================== ResizeBackward ====================== */




+ 26
- 0
src/opr/impl/indexing.cpp View File

@@ -83,6 +83,7 @@ void IndexingOneHot::init_output_dtype() {
output(0)->dtype(input(0)->dtype()); output(0)->dtype(input(0)->dtype());
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingOneHot) { MGB_IMPL_OPR_GRAD(IndexingOneHot) {
if (wrt_idx == 0) { if (wrt_idx == 0) {
return IndexingSetOneHot::make( return IndexingSetOneHot::make(
@@ -91,6 +92,7 @@ MGB_IMPL_OPR_GRAD(IndexingOneHot) {
} }
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
} }
#endif


/* ==================== IndexingSetOneHot ==================== */ /* ==================== IndexingSetOneHot ==================== */


@@ -133,6 +135,7 @@ void IndexingSetOneHot::scn_do_execute() {
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)};
if (wrt_idx == 0) { if (wrt_idx == 0) {
@@ -144,6 +147,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetOneHot) {
} }
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
} }
#endif


size_t IndexingSetOneHot::get_workspace_size_bytes( size_t IndexingSetOneHot::get_workspace_size_bytes(
const TensorShapeArray &input_shapes, const TensorShapeArray &input_shapes,
@@ -165,6 +169,7 @@ void IndexingRemap::init_output_dtype() {
output(0)->dtype(input(0)->dtype()); output(0)->dtype(input(0)->dtype());
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingRemap) { MGB_IMPL_OPR_GRAD(IndexingRemap) {
if (wrt_idx == 1) if (wrt_idx == 1)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -172,6 +177,7 @@ MGB_IMPL_OPR_GRAD(IndexingRemap) {
return IndexingRemapBackward::make( return IndexingRemapBackward::make(
out_grad[0], opr.input(1), opr.input(0), opr.param()).node(); out_grad[0], opr.input(1), opr.input(0), opr.param()).node();
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward); MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward);
MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false); MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false);
@@ -460,6 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
if (wrt_idx) if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -468,7 +475,9 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
SymbolVar{opr.input(0)}.fill_retain_dtype(0), SymbolVar{opr.input(0)}.fill_retain_dtype(0),
out_grad.at(0), opr.index_desc()).node(); out_grad.at(0), opr.index_desc()).node();
} }
#endif


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
if (wrt_idx >= 2) if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -479,7 +488,9 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) {
} }
return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
if (wrt_idx >= 2) if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -488,6 +499,7 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
} }
return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


/* ============================= Mesh Indexing ============================ */ /* ============================= Mesh Indexing ============================ */


@@ -498,6 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
BatchedMeshIndexing, "batched_mesh_indexing", false, BatchedMeshIndexing, "batched_mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MeshIndexing) { MGB_IMPL_OPR_GRAD(MeshIndexing) {
if (wrt_idx != 0) { if (wrt_idx != 0) {
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -507,6 +520,9 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) {
opr.index_desc()) opr.index_desc())
.node(); .node();
} }
#endif

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
if (wrt_idx != 0) { if (wrt_idx != 0) {
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -516,11 +532,14 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) {
opr.index_desc()) opr.index_desc())
.node(); .node();
} }
#endif


/* ========================= IncrMeshIndexing ========================= */ /* ========================= IncrMeshIndexing ========================= */


MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing",
false); false);

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
if (wrt_idx > 2) { if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx); return opr::InvalidGrad::make(opr, wrt_idx);
@@ -530,9 +549,11 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) {
} }
return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing,
"batched_incr_mesh_indexing", false); "batched_incr_mesh_indexing", false);
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
if (wrt_idx > 2) { if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx); return opr::InvalidGrad::make(opr, wrt_idx);
@@ -542,10 +563,12 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) {
} }
return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


/* ======================== SetMeshIndexing =========================== */ /* ======================== SetMeshIndexing =========================== */
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false);


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetMeshIndexing) { MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
if (wrt_idx >= 2) { if (wrt_idx >= 2) {
return opr::InvalidGrad::make(opr, wrt_idx); return opr::InvalidGrad::make(opr, wrt_idx);
@@ -560,9 +583,11 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) {
return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node();
} }
} }
#endif


MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing,
"batched_set_mesh_indexing", false); "batched_set_mesh_indexing", false);
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
if (wrt_idx > 2) { if (wrt_idx > 2) {
return opr::InvalidGrad::make(opr, wrt_idx); return opr::InvalidGrad::make(opr, wrt_idx);
@@ -578,5 +603,6 @@ MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) {
.node(); .node();
} }
} }
#endif


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
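Two return conventions recur in the indexing grad implementations above. InvalidGrad::make(opr, wrt_idx) appears to flag an input for which requesting a gradient is an error (typically the integer index inputs), whereas returning nullptr simply reports that no gradient flows along that path. A hypothetical example combining both cases; HypotheticalGather and its backward construction are illustrative only and mirror the IndexingMultiAxisVec grad shown above.

#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(HypotheticalGather) {
    if (wrt_idx == 1) {
        // gradient w.r.t. an integer index tensor is not defined:
        // surface it as an invalid-grad request instead of a silent zero
        return InvalidGrad::make(opr, wrt_idx);
    }
    if (!out_grad[0]) {
        // no gradient arrives through output 0, so there is nothing to build
        return nullptr;
    }
    // scatter the incoming gradient back onto the data input
    return IndexingIncrMultiAxisVec::make(
            SymbolVar{opr.input(0)}.fill_retain_dtype(0),
            out_grad.at(0), opr.index_desc()).node();
}
#endif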

+ 2
- 0
src/opr/impl/io.cpp View File

@@ -764,11 +764,13 @@ Copy::NodeProp* Copy::do_make_node_prop() const {
return rst; return rst;
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Copy) { MGB_IMPL_OPR_GRAD(Copy) {
mgb_assert(wrt_idx == 0); mgb_assert(wrt_idx == 0);
return Copy::make(out_grad[0], return Copy::make(out_grad[0],
OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node();
} }
#endif


void Copy::add_input_layout_constraint() { void Copy::add_input_layout_constraint() {
if (input(0)->comp_node() != output(0)->comp_node()) { if (input(0)->comp_node() != output(0)->comp_node()) {


+ 2
- 0
src/opr/impl/loop/forward.cpp View File

@@ -268,9 +268,11 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) {
return gopr->get_grad_var(wrt_idx); return gopr->get_grad_var(wrt_idx);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Loop) { MGB_IMPL_OPR_GRAD(Loop) {
return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad); return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad);
} }
#endif


cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const {
auto prop = LoopImpl::do_make_node_prop(); auto prop = LoopImpl::do_make_node_prop();


+ 12
- 1
src/opr/impl/misc.cpp View File

@@ -48,23 +48,26 @@ namespace intl {


/* ================= Argmxx ================= */ /* ================= Argmxx ================= */


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmax) { MGB_IMPL_OPR_GRAD(Argmax) {
MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(out_grad);
MGB_MARK_USED_VAR(opr); MGB_MARK_USED_VAR(opr);
mgb_assert(!wrt_idx); mgb_assert(!wrt_idx);
return nullptr; return nullptr;
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax);
MEGDNN_OPR_INIT1(Argmax, "argmax") MEGDNN_OPR_INIT1(Argmax, "argmax")


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Argmin) { MGB_IMPL_OPR_GRAD(Argmin) {
MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(out_grad);
MGB_MARK_USED_VAR(opr); MGB_MARK_USED_VAR(opr);
mgb_assert(!wrt_idx); mgb_assert(!wrt_idx);
return nullptr; return nullptr;
} }
#endif


MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin); MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin);
MEGDNN_OPR_INIT1(Argmin, "argmin") MEGDNN_OPR_INIT1(Argmin, "argmin")
@@ -84,12 +87,14 @@ std::array<SymbolVar, 2> ArgsortForward::make(
return {node->output(0), node->output(1)}; return {node->output(0), node->output(1)};
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ArgsortForward) { MGB_IMPL_OPR_GRAD(ArgsortForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
if (!out_grad[0]) if (!out_grad[0])
return nullptr; return nullptr;
return ArgsortBackward::make(out_grad[0], opr.output(1)).node(); return ArgsortBackward::make(out_grad[0], opr.output(1)).node();
} }
#endif


/* ================= ArgsortBackward ================= */ /* ================= ArgsortBackward ================= */


@@ -107,12 +112,14 @@ Cumsum::Cumsum(VarNode* opr, const Param& param,
add_input({opr}, AddInputSortType::CUR_ADDED); add_input({opr}, AddInputSortType::CUR_ADDED);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Cumsum) { MGB_IMPL_OPR_GRAD(Cumsum) {
mgb_assert(out_grad[0] && !out_grad[1]); mgb_assert(out_grad[0] && !out_grad[1]);
auto param = opr.param(); auto param = opr.param();
param.reverse = !param.reverse; param.reverse = !param.reverse;
return Cumsum::make(out_grad[0], param).node(); return Cumsum::make(out_grad[0], param).node();
} }
#endif


SymbolVar Cumsum::make(SymbolVar opr, const Param& param, SymbolVar Cumsum::make(SymbolVar opr, const Param& param,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
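The Cumsum gradient above is a small self-contained identity worth spelling out: with param.reverse = false the forward pass computes y_i = x_1 + ... + x_i, so dL/dx_j = dL/dy_j + dL/dy_{j+1} + ... + dL/dy_n, that is, a cumulative sum of the output gradient taken in the opposite direction. For three elements, dL/dx_2 = dL/dy_2 + dL/dy_3. Flipping param.reverse and reusing Cumsum on out_grad[0] computes exactly that, which is why no dedicated backward operator is needed here.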
@@ -170,6 +177,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
} }
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CondTake) { MGB_IMPL_OPR_GRAD(CondTake) {
mgb_assert(out_grad.size() == 3 && !out_grad[2]); mgb_assert(out_grad.size() == 3 && !out_grad[2]);
if (wrt_idx == 0 && out_grad[0]) { if (wrt_idx == 0 && out_grad[0]) {
@@ -181,6 +189,7 @@ MGB_IMPL_OPR_GRAD(CondTake) {
} }
return nullptr; return nullptr;
} }
#endif


std::array<SymbolVar, 2> CondTake::make( std::array<SymbolVar, 2> CondTake::make(
SymbolVar data, SymbolVar mask, SymbolVar data, SymbolVar mask,
@@ -318,6 +327,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps); record_megdnn_opr(deps);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(TopK) { MGB_IMPL_OPR_GRAD(TopK) {
if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) {
mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]);
@@ -334,5 +344,6 @@ MGB_IMPL_OPR_GRAD(TopK) {
return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0)) return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0))
.node(); .node();
} }
#endif


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 2
- 0
src/opr/impl/muxing.cpp View File

@@ -316,9 +316,11 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) {
OperatorNodeConfig().comp_node_arr(sp_cn))); OperatorNodeConfig().comp_node_arr(sp_cn)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AllGather) { MGB_IMPL_OPR_GRAD(AllGather) {
return const_cast<AllGather&>(opr).grad(out_grad); return const_cast<AllGather&>(opr).grad(out_grad);
} }
#endif


void AllGather::on_output_comp_node_stream_changed() { void AllGather::on_output_comp_node_stream_changed() {
} }


+ 9
- 7
src/opr/impl/rand.cpp View File

@@ -112,19 +112,21 @@ UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() {
return opr; return opr;
} }


#define IMPL(_cls) \
template class RNGOpr<::megdnn::_cls>; \
MGB_IMPL_OPR_GRAD(_cls) { \
MGB_MARK_USED_VAR(out_grad); \
return InvalidGrad::make(opr, wrt_idx); \
} \

#define IMPL(_cls) \
MGB_IMPL_OPR_GRAD(_cls) { \
MGB_MARK_USED_VAR(out_grad); \
return InvalidGrad::make(opr, wrt_idx); \
}


namespace mgb { namespace mgb {
namespace opr { namespace opr {
namespace intl { namespace intl {
template class RNGOpr<::megdnn::GaussianRNG>;
template class RNGOpr<::megdnn::UniformRNG>;
#ifdef MGB_ENABLE_GRAD
IMPL(GaussianRNG); IMPL(GaussianRNG);
IMPL(UniformRNG); IMPL(UniformRNG);
#endif
} }
} }
} }
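The rand.cpp change differs slightly from the plain guards elsewhere: the old IMPL macro both instantiated the RNGOpr template and registered its (invalid) gradient, so guarding the macro alone would also have dropped the template instantiations that inference-only builds still need. The new code instantiates RNGOpr for GaussianRNG and UniformRNG unconditionally and keeps only the grad registrations behind MGB_ENABLE_GRAD. After preprocessing, the guarded IMPL(GaussianRNG) expands to roughly the following sketch:

template class RNGOpr<::megdnn::GaussianRNG>;  // always compiled

#ifdef MGB_ENABLE_GRAD
// expansion of IMPL(GaussianRNG): random sampling has no meaningful
// gradient, so any gradient request is reported as invalid
MGB_IMPL_OPR_GRAD(GaussianRNG) {
    MGB_MARK_USED_VAR(out_grad);
    return InvalidGrad::make(opr, wrt_idx);
}
#endif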


+ 6
- 1
src/opr/impl/tensor_gen.cpp View File

@@ -46,11 +46,13 @@ void Alloc::outshape_by_symvar_do_get_output_shape(
void Alloc::scn_do_execute() { void Alloc::scn_do_execute() {
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Alloc) { MGB_IMPL_OPR_GRAD(Alloc) {
MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(out_grad);
return InvalidGrad::make(opr, 0); return InvalidGrad::make(opr, 0);
} }
#endif


/* ======================= Linspace ======================= */ /* ======================= Linspace ======================= */


@@ -123,6 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) {
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Linspace) { MGB_IMPL_OPR_GRAD(Linspace) {
if (wrt_idx == 2) if (wrt_idx == 2)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -134,6 +137,7 @@ MGB_IMPL_OPR_GRAD(Linspace) {
return opr::Dot::make(og, return opr::Dot::make(og,
opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node(); opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node();
} }
#endif


/* ======================= Eye ======================= */ /* ======================= Eye ======================= */


@@ -195,9 +199,10 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) {
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr)));
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Eye) { MGB_IMPL_OPR_GRAD(Eye) {
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
} }
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}



+ 22
- 1
src/opr/impl/tensor_manip.cpp View File

@@ -165,12 +165,13 @@ void GetVarShape::init_output_static_infer_desc() {
mgr.register_value_infer(output(0), mgr.register_value_infer(output(0),
{SourceType::DEP, deps, infer_value}); {SourceType::DEP, deps, infer_value});
} }
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GetVarShape) { MGB_IMPL_OPR_GRAD(GetVarShape) {
MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(out_grad);
return nullptr; return nullptr;
} }
#endif


SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param, SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
@@ -362,11 +363,13 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
inp.node(), tshp.node(), unspec_axis, config); inp.node(), tshp.node(), unspec_axis, config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reshape) { MGB_IMPL_OPR_GRAD(Reshape) {
if (wrt_idx) if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
} }
#endif


Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout( Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
const TensorLayout &src, const TensorShape &tshape) const { const TensorLayout &src, const TensorShape &tshape) const {
@@ -429,12 +432,14 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
inp.node(), tshp.node(), config); inp.node(), tshp.node(), config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Broadcast) { MGB_IMPL_OPR_GRAD(Broadcast) {
if (wrt_idx) if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
return Reduce::make(out_grad.at(0), Reduce::Mode::SUM, return Reduce::make(out_grad.at(0), Reduce::Mode::SUM,
GetVarShape::make(opr.input(0))).node(); GetVarShape::make(opr.input(0))).node();
} }
#endif


Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout( Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout(
const TensorLayout &src, const TensorShape &tshape) const { const TensorLayout &src, const TensorShape &tshape) const {
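The Broadcast gradient above encodes a standard identity: broadcasting copies each input element into several output positions, so the gradient w.r.t. that element is the sum of the gradients of all its copies. For example, broadcasting a (1, 3) tensor x to a (4, 3) output y gives y[i][j] = x[0][j] for i = 0..3, hence dL/dx[0][j] = sum over i of dL/dy[i][j], which is exactly the Reduce::Mode::SUM back to GetVarShape(input(0)) used here. The Reshape gradient just above is the dual case: reshaping is a bijection on elements, so its gradient is simply the output gradient reshaped back to the input shape.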
@@ -562,9 +567,11 @@ VarNode* Dimshuffle::grad(
return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Dimshuffle) { MGB_IMPL_OPR_GRAD(Dimshuffle) {
return opr.grad(wrt_idx, out_grad); return opr.grad(wrt_idx, out_grad);
} }
#endif


// f}}} // f}}}


@@ -631,10 +638,12 @@ AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
return ret; return ret;
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(AxisAddRemove) { MGB_IMPL_OPR_GRAD(AxisAddRemove) {
MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(wrt_idx);
return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
} }
#endif


// f}}} // f}}}


@@ -642,6 +651,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {


MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Subtensor) { MGB_IMPL_OPR_GRAD(Subtensor) {
if (wrt_idx) if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -650,6 +660,7 @@ MGB_IMPL_OPR_GRAD(Subtensor) {
SymbolVar{opr.input(0)}.fill_retain_dtype(0), SymbolVar{opr.input(0)}.fill_retain_dtype(0),
out_grad.at(0), opr.index_desc()).node(); out_grad.at(0), opr.index_desc()).node();
} }
#endif


void Subtensor::init_output_static_infer_desc() { void Subtensor::init_output_static_infer_desc() {
using namespace cg::static_infer; using namespace cg::static_infer;
@@ -783,6 +794,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
sub.copy_from_fixlayout(val); sub.copy_from_fixlayout(val);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetSubtensor) { MGB_IMPL_OPR_GRAD(SetSubtensor) {
if (wrt_idx >= 2) if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -793,6 +805,7 @@ MGB_IMPL_OPR_GRAD(SetSubtensor) {
} }
return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


// f}}} // f}}}


@@ -813,6 +826,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
opr->exec(sub.as_megdnn(), val.as_megdnn()); opr->exec(sub.as_megdnn(), val.as_megdnn());
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IncrSubtensor) { MGB_IMPL_OPR_GRAD(IncrSubtensor) {
if (wrt_idx >= 2) if (wrt_idx >= 2)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -821,6 +835,7 @@ MGB_IMPL_OPR_GRAD(IncrSubtensor) {
} }
return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
} }
#endif


// f}}} // f}}}


@@ -1085,6 +1100,7 @@ void Split::do_execute(ExecEnv &env) {
} }
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Split) { MGB_IMPL_OPR_GRAD(Split) {
if (wrt_idx) if (wrt_idx)
return InvalidGrad::make(opr, wrt_idx); return InvalidGrad::make(opr, wrt_idx);
@@ -1100,6 +1116,7 @@ MGB_IMPL_OPR_GRAD(Split) {
return Concat::make(grad, opr.options().axis, return Concat::make(grad, opr.options().axis,
OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node();
} }
#endif


void Split::mem_plan_fwd_in2out_readonly() { void Split::mem_plan_fwd_in2out_readonly() {
m_readonly_fwd_called = true; m_readonly_fwd_called = true;
@@ -1236,6 +1253,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
axis, config); axis, config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Concat) { MGB_IMPL_OPR_GRAD(Concat) {
auto axis = opr.axis(); auto axis = opr.axis();
mgb_assert(out_grad.size() == 1); mgb_assert(out_grad.size() == 1);
@@ -1250,6 +1268,7 @@ MGB_IMPL_OPR_GRAD(Concat) {
OperatorNodeConfig().comp_node_arr(comp_node)); OperatorNodeConfig().comp_node_arr(comp_node));
return cg::to_var_node_array(ret); return cg::to_var_node_array(ret);
} }
#endif


void Concat::scn_do_execute() { void Concat::scn_do_execute() {
auto&& out = output(0)->dev_tensor(); auto&& out = output(0)->dev_tensor();
@@ -1507,6 +1526,7 @@ void ParamPackSplit::init_output_static_infer_desc() {
} }
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ParamPackSplit) { MGB_IMPL_OPR_GRAD(ParamPackSplit) {
mgb_assert(out_grad.size() == opr.output().size()); mgb_assert(out_grad.size() == opr.output().size());
SmallVector<SymbolVar> grad; SmallVector<SymbolVar> grad;
@@ -1531,6 +1551,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
OperatorNodeConfig{}.follow_comp_node(opr.input(0))) OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
.node(); .node();
} }
#endif
// f}}} // f}}}


/* f{{{ ======================= RelayoutFormat ======================= */ /* f{{{ ======================= RelayoutFormat ======================= */


+ 12
- 0
src/opr/impl/utility.cpp View File

@@ -255,9 +255,11 @@ void MarkDynamicVar::scn_do_execute() {
o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); o->dev_tensor().copy_from_fixlayout(i->dev_tensor());
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkDynamicVar) { MGB_IMPL_OPR_GRAD(MarkDynamicVar) {
return MarkDynamicVar::make(out_grad.at(0)).node(); return MarkDynamicVar::make(out_grad.at(0)).node();
} }
#endif


MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config): MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config):
Super{node->owner_graph(), config, "mark_dyn", {node}} Super{node->owner_graph(), config, "mark_dyn", {node}}
@@ -381,10 +383,12 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) {
} }
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(CallbackInjector) { MGB_IMPL_OPR_GRAD(CallbackInjector) {
MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(wrt_idx);
return out_grad.at(0); return out_grad.at(0);
} }
#endif


/* ===================== MarkNoBroadcastElemwise ===================== */ /* ===================== MarkNoBroadcastElemwise ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise); MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise);
@@ -404,9 +408,11 @@ SymbolVar MarkNoBroadcastElemwise::make(
input.node(), config); input.node(), config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) {
return out_grad.at(0); return out_grad.at(0);
} }
#endif


/* ===================== Identity ===================== */ /* ===================== Identity ===================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity);
@@ -429,9 +435,11 @@ SymbolVar Identity::make(
return input.insert_single_output_opr<Identity>(input.node(), config); return input.insert_single_output_opr<Identity>(input.node(), config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Identity) { MGB_IMPL_OPR_GRAD(Identity) {
return out_grad.at(0); return out_grad.at(0);
} }
#endif


/* ===================== AssertEqual ===================== */ /* ===================== AssertEqual ===================== */


@@ -530,6 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter,
input.node(), grad_getter, config); input.node(), grad_getter, config);
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SetGrad) { MGB_IMPL_OPR_GRAD(SetGrad) {
MGB_MARK_USED_VAR(wrt_idx); MGB_MARK_USED_VAR(wrt_idx);
MGB_MARK_USED_VAR(out_grad); MGB_MARK_USED_VAR(out_grad);
@@ -538,6 +547,7 @@ MGB_IMPL_OPR_GRAD(SetGrad) {
"var returned by grad_getter belongs to a different comp graph"); "var returned by grad_getter belongs to a different comp graph");
return grad.node(); return grad.node();
} }
#endif


/* ===================== InvalidGrad ===================== */ /* ===================== InvalidGrad ===================== */


@@ -690,6 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const {
return ret; return ret;
} }


#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(VirtualLoss) { MGB_IMPL_OPR_GRAD(VirtualLoss) {
mgb_assert(out_grad.size() == 1); mgb_assert(out_grad.size() == 1);
auto mid = opr.input().size() / 2; auto mid = opr.input().size() / 2;
@@ -698,6 +709,7 @@ MGB_IMPL_OPR_GRAD(VirtualLoss) {
} }
return nullptr; return nullptr;
} }
#endif


#else #else
VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) { VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) {


+ 13
- 0
src/plugin/impl/opr_footprint.cpp View File

@@ -24,6 +24,16 @@
#include "megdnn/opr_param_json.h" #include "megdnn/opr_param_json.h"
#endif #endif


#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_opr_footprint)
#define MIDOUT_B(...) \
MIDOUT_BEGIN(megbrain_opr_footprint, __VA_ARGS__) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb; using namespace mgb;


namespace { namespace {
@@ -581,9 +591,12 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(


template <class OprType> template <class OprType>
void OprFootprint::add_single_comp_footprint() { void OprFootprint::add_single_comp_footprint() {
MIDOUT_B(OprType,
midout_iv(MGB_HASH_STR("OprFootprint::add_single_comp_footprint")))
auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(), auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(),
opr_footprint_func<OprType>); opr_footprint_func<OprType>);
mgb_assert(record.second, "duplicate opr typeinfo"); mgb_assert(record.second, "duplicate opr typeinfo");
MIDOUT_E
} }


#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
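opr_footprint.cpp also shows the other half of this commit, the midout instrumentation. MIDOUT_DECL names a trace domain, and the MIDOUT_B / MIDOUT_E helpers added above open and close a region keyed by the operator type plus a compile-time hash of a tag string, so each add_single_comp_footprint<OprType>() instantiation becomes an individually trackable region. The usual midout workflow, which this diff relies on but does not show, is two-pass: build with instrumentation, run a representative workload to record which regions are entered, then rebuild against that record so untouched regions reduce to a trap and their code can be discarded at link time; treat that description as an assumption about midout rather than something stated in the patch. A minimal sketch of the same region pattern around a hypothetical registration helper:

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_example)  // hypothetical trace domain, not part of this commit

template <class OprType>
void register_example_footprint() {
    // one region per OprType instantiation, mirroring
    // OprFootprint::add_single_comp_footprint() above
    MIDOUT_BEGIN(megbrain_example, OprType,
                 midout_iv(MGB_HASH_STR("register_example_footprint"))) {
        do_register<OprType>();  // hypothetical registration body
    }
    MIDOUT_END();
}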

