GitOrigin-RevId: e66dabee01
release-1.1
@@ -760,7 +760,7 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const { | |||||
return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); | return ModeTrait::from_mode(m_param.mode).grad(out_grad, this); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CollectiveComm) { | MGB_IMPL_OPR_GRAD(CollectiveComm) { | ||||
mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); | mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad"); | ||||
return opr.grad(out_grad[0]); | return opr.grad(out_grad[0]); | ||||
@@ -119,7 +119,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const { | |||||
return prop; | return prop; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(RemoteSend) { | MGB_IMPL_OPR_GRAD(RemoteSend) { | ||||
mgb_assert(opr.is_grad()); | mgb_assert(opr.is_grad()); | ||||
return RemoteRecv::make(opr.key() + ":grad", | return RemoteRecv::make(opr.key() + ":grad", | ||||
@@ -552,7 +552,7 @@ void Elemwise::call_megdnn_opr_exec( | |||||
opr->exec(inp, out); | opr->exec(inp, out); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Elemwise) { | MGB_IMPL_OPR_GRAD(Elemwise) { | ||||
SymbolVar i[5]; | SymbolVar i[5]; | ||||
SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)), | SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)), | ||||
@@ -822,7 +822,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const { | |||||
return ret; | return ret; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(TypeCvt) { | MGB_IMPL_OPR_GRAD(TypeCvt) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype(); | auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype(); | ||||
@@ -973,7 +973,7 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(AddUpdate) { | MGB_IMPL_OPR_GRAD(AddUpdate) { | ||||
// actually valid, just not implemented | // actually valid, just not implemented | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -1712,7 +1712,7 @@ void Reduce::create_megdnn_opr() { | |||||
create_operator<megdnn::Reduce>()); | create_operator<megdnn::Reduce>()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Reduce) { | MGB_IMPL_OPR_GRAD(Reduce) { | ||||
for (size_t i = 1; i < opr.output().size(); ++ i) | for (size_t i = 1; i < opr.output().size(); ++ i) | ||||
mgb_assert(!out_grad[i]); | mgb_assert(!out_grad[i]); | ||||
@@ -1798,7 +1798,7 @@ void PowC::init_output_static_infer_desc() { | |||||
{SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); | {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(PowC) { | MGB_IMPL_OPR_GRAD(PowC) { | ||||
auto exp = opr.param().exp; | auto exp = opr.param().exp; | ||||
return (exp * SymbolVar{out_grad[0]} * | return (exp * SymbolVar{out_grad[0]} * | ||||
@@ -106,7 +106,7 @@ void MatrixMul::scn_do_execute() { | |||||
MGB_FINALLY({ tparam = this->param(); }); | MGB_FINALLY({ tparam = this->param(); }); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(MatrixMul) { | MGB_IMPL_OPR_GRAD(MatrixMul) { | ||||
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
"only float data type supported for grad"); | "only float data type supported for grad"); | ||||
@@ -226,7 +226,7 @@ void BatchedMatrixMul::scn_do_execute() { | |||||
MGB_FINALLY({ tparam = this->param(); }); | MGB_FINALLY({ tparam = this->param(); }); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { | MGB_IMPL_OPR_GRAD(BatchedMatrixMul) { | ||||
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
"only float data type supported for grad"); | "only float data type supported for grad"); | ||||
@@ -331,7 +331,7 @@ void Dot::add_input_layout_constraint() { | |||||
input(1)->add_layout_constraint(check); | input(1)->add_layout_constraint(check); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Dot) { | MGB_IMPL_OPR_GRAD(Dot) { | ||||
auto other_input = opr.input(wrt_idx == 0 ? 1 : 0); | auto other_input = opr.input(wrt_idx == 0 ? 1 : 0); | ||||
auto ishp0 = opr::GetVarShape::make(opr.input(0)), | auto ishp0 = opr::GetVarShape::make(opr.input(0)), | ||||
@@ -357,7 +357,7 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse); | ||||
MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv") | MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(MatrixInverse) { | MGB_IMPL_OPR_GRAD(MatrixInverse) { | ||||
SymbolVar a = opr.output(0); | SymbolVar a = opr.output(0); | ||||
// TODO: use unified MatrixMul interface when we have it | // TODO: use unified MatrixMul interface when we have it | ||||
@@ -395,7 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) : | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
namespace { | namespace { | ||||
/*! | /*! | ||||
@@ -489,7 +489,7 @@ OP(*, {}, {}) | |||||
} // anonymous namespace | } // anonymous namespace | ||||
#endif | #endif | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(SVD) { | MGB_IMPL_OPR_GRAD(SVD) { | ||||
/** | /** | ||||
* The formula is copied from | * The formula is copied from | ||||
@@ -818,7 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input, | |||||
return input; | return input; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CondExecMark) { | MGB_IMPL_OPR_GRAD(CondExecMark) { | ||||
if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) { | if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) { | ||||
return nullptr; | return nullptr; | ||||
@@ -1227,7 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const { | |||||
return ret; | return ret; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CondExecMerge) { | MGB_IMPL_OPR_GRAD(CondExecMerge) { | ||||
using Mode = CondExecMerge::Param::Mode; | using Mode = CondExecMerge::Param::Mode; | ||||
if (opr.param().mode == Mode::SUM_COND_OUT && | if (opr.param().mode == Mode::SUM_COND_OUT && | ||||
@@ -91,7 +91,7 @@ void AdaptivePoolingForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) { | MGB_IMPL_OPR_GRAD(AdaptivePoolingForward) { | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
// wrt src | // wrt src | ||||
@@ -240,7 +240,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() { | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(BatchNormForward) { | MGB_IMPL_OPR_GRAD(BatchNormForward) { | ||||
mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING, | mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING, | ||||
"batch norm could only take grad in training mode"); | "batch norm could only take grad in training mode"); | ||||
@@ -1012,7 +1012,7 @@ void ConvolutionForward::init_output_dtype() { | |||||
output(0)->dtype(output_dtype); | output(0)->dtype(output_dtype); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ConvolutionForward) { | MGB_IMPL_OPR_GRAD(ConvolutionForward) { | ||||
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
"only float data type supported for grad"); | "only float data type supported for grad"); | ||||
@@ -1175,7 +1175,7 @@ void ConvolutionBackwardData::scn_do_execute() { | |||||
intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { | MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { | ||||
mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -1229,7 +1229,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes( | |||||
megdnn_opr(), this); | megdnn_opr(), this); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { | MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { | ||||
mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -1285,7 +1285,7 @@ void Convolution3DForward::init_output_dtype() { | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Convolution3DForward) { | MGB_IMPL_OPR_GRAD(Convolution3DForward) { | ||||
mgb_assert(opr.param().data_type == | mgb_assert(opr.param().data_type == | ||||
Convolution3DForward::Param::DataType::FLOAT, | Convolution3DForward::Param::DataType::FLOAT, | ||||
@@ -1380,7 +1380,7 @@ void Convolution3DBackwardData::scn_do_execute() { | |||||
intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { | MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { | ||||
mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -1781,7 +1781,7 @@ size_t LocalShareForward::get_workspace_size_bytes( | |||||
megdnn_opr(), this); | megdnn_opr(), this); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LocalShareForward) { | MGB_IMPL_OPR_GRAD(LocalShareForward) { | ||||
mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
"only float data type supported for grad"); | "only float data type supported for grad"); | ||||
@@ -1862,7 +1862,7 @@ void LocalShareBackwardData::scn_do_execute() { | |||||
intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { | MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { | ||||
mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -1919,7 +1919,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes( | |||||
megdnn_opr(), this); | megdnn_opr(), this); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { | MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { | ||||
mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -1998,7 +1998,7 @@ size_t DeformableConvForward::get_workspace_size_bytes( | |||||
megdnn_opr(), this); | megdnn_opr(), this); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(DeformableConvForward) { | MGB_IMPL_OPR_GRAD(DeformableConvForward) { | ||||
mgb_assert(opr.input(0)->dtype() == dtype::Float32(), | mgb_assert(opr.input(0)->dtype() == dtype::Float32(), | ||||
"only float data type supported for grad"); | "only float data type supported for grad"); | ||||
@@ -20,7 +20,7 @@ using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); | ||||
MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") | MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Images2NeibsForward) { | MGB_IMPL_OPR_GRAD(Images2NeibsForward) { | ||||
mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); | mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); | ||||
return Images2NeibsBackward::make( | return Images2NeibsBackward::make( | ||||
@@ -21,7 +21,7 @@ using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); | ||||
MEGDNN_OPR_INIT2(LocalForward, "local") | MEGDNN_OPR_INIT2(LocalForward, "local") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LocalForward) { | MGB_IMPL_OPR_GRAD(LocalForward) { | ||||
return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>( | return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>( | ||||
opr, wrt_idx, out_grad); | opr, wrt_idx, out_grad); | ||||
@@ -38,7 +38,7 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false); | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); | ||||
MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") | MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(GroupLocalForward) { | MGB_IMPL_OPR_GRAD(GroupLocalForward) { | ||||
return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>( | return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>( | ||||
opr, wrt_idx, out_grad); | opr, wrt_idx, out_grad); | ||||
@@ -20,7 +20,7 @@ using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); | ||||
MEGDNN_OPR_INIT1(LRNForward, "lrn") | MEGDNN_OPR_INIT1(LRNForward, "lrn") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LRNForward) { | MGB_IMPL_OPR_GRAD(LRNForward) { | ||||
mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
SymbolVar grad = LRNBackward::make( | SymbolVar grad = LRNBackward::make( | ||||
@@ -19,7 +19,7 @@ using namespace opr; | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); | ||||
MEGDNN_OPR_INIT1(PoolingForward, "pooling") | MEGDNN_OPR_INIT1(PoolingForward, "pooling") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(PoolingForward) { | MGB_IMPL_OPR_GRAD(PoolingForward) { | ||||
mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
SymbolVar grad = PoolingBackward::make( | SymbolVar grad = PoolingBackward::make( | ||||
@@ -40,7 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, | |||||
src.node(), rois.node(), param, config); | src.node(), rois.node(), param, config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ROIAlignForward) { | MGB_IMPL_OPR_GRAD(ROIAlignForward) { | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
// wrt src | // wrt src | ||||
@@ -84,7 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( | |||||
input_shapes, output_shapes); | input_shapes, output_shapes); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ROIPoolingForward) { | MGB_IMPL_OPR_GRAD(ROIPoolingForward) { | ||||
if (wrt_idx == 2) { | if (wrt_idx == 2) { | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -148,7 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make( | |||||
return all[0]; | return all[0]; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { | MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { | ||||
mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 | mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 | ||||
@@ -126,7 +126,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { | MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { | ||||
if (opr.input().size() == 4) { | if (opr.input().size() == 4) { | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -351,7 +351,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ResizeForward) { | MGB_IMPL_OPR_GRAD(ResizeForward) { | ||||
mgb_assert(opr.input().size() == 2); | mgb_assert(opr.input().size() == 2); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -443,7 +443,7 @@ void RemapForward::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(RemapForward) { | MGB_IMPL_OPR_GRAD(RemapForward) { | ||||
mgb_assert(opr.input().size() == 2); | mgb_assert(opr.input().size() == 2); | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -83,7 +83,7 @@ void IndexingOneHot::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingOneHot) { | MGB_IMPL_OPR_GRAD(IndexingOneHot) { | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
return IndexingSetOneHot::make( | return IndexingSetOneHot::make( | ||||
@@ -135,7 +135,7 @@ void IndexingSetOneHot::scn_do_execute() { | |||||
intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { | MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { | ||||
SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; | SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
@@ -169,7 +169,7 @@ void IndexingRemap::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingRemap) { | MGB_IMPL_OPR_GRAD(IndexingRemap) { | ||||
if (wrt_idx == 1) | if (wrt_idx == 1) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -466,7 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | |||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | ||||
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -477,7 +477,7 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | |||||
} | } | ||||
#endif | #endif | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | ||||
if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -490,7 +490,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | |||||
} | } | ||||
#endif | #endif | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { | ||||
if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -510,7 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( | |||||
BatchedMeshIndexing, "batched_mesh_indexing", false, | BatchedMeshIndexing, "batched_mesh_indexing", false, | ||||
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); | output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(MeshIndexing) { | MGB_IMPL_OPR_GRAD(MeshIndexing) { | ||||
if (wrt_idx != 0) { | if (wrt_idx != 0) { | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -522,7 +522,7 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) { | |||||
} | } | ||||
#endif | #endif | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | ||||
if (wrt_idx != 0) { | if (wrt_idx != 0) { | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -539,7 +539,7 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | |||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", | ||||
false); | false); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | ||||
if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
@@ -553,7 +553,7 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | |||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, | ||||
"batched_incr_mesh_indexing", false); | "batched_incr_mesh_indexing", false); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | ||||
if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
@@ -568,7 +568,7 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | |||||
/* ======================== SetMeshIndexing =========================== */ | /* ======================== SetMeshIndexing =========================== */ | ||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | ||||
if (wrt_idx >= 2) { | if (wrt_idx >= 2) { | ||||
return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
@@ -587,7 +587,7 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | |||||
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, | ||||
"batched_set_mesh_indexing", false); | "batched_set_mesh_indexing", false); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { | ||||
if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
@@ -766,7 +766,7 @@ Copy::NodeProp* Copy::do_make_node_prop() const { | |||||
return rst; | return rst; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Copy) { | MGB_IMPL_OPR_GRAD(Copy) { | ||||
mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
return Copy::make(out_grad[0], | return Copy::make(out_grad[0], | ||||
@@ -268,7 +268,7 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) { | |||||
return gopr->get_grad_var(wrt_idx); | return gopr->get_grad_var(wrt_idx); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Loop) { | MGB_IMPL_OPR_GRAD(Loop) { | ||||
return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad); | return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad); | ||||
} | } | ||||
@@ -48,7 +48,7 @@ namespace intl { | |||||
/* ================= Argmxx ================= */ | /* ================= Argmxx ================= */ | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Argmax) { | MGB_IMPL_OPR_GRAD(Argmax) { | ||||
MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
MGB_MARK_USED_VAR(opr); | MGB_MARK_USED_VAR(opr); | ||||
@@ -60,7 +60,7 @@ MGB_IMPL_OPR_GRAD(Argmax) { | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); | ||||
MEGDNN_OPR_INIT1(Argmax, "argmax") | MEGDNN_OPR_INIT1(Argmax, "argmax") | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Argmin) { | MGB_IMPL_OPR_GRAD(Argmin) { | ||||
MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
MGB_MARK_USED_VAR(opr); | MGB_MARK_USED_VAR(opr); | ||||
@@ -87,7 +87,7 @@ std::array<SymbolVar, 2> ArgsortForward::make( | |||||
return {node->output(0), node->output(1)}; | return {node->output(0), node->output(1)}; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ArgsortForward) { | MGB_IMPL_OPR_GRAD(ArgsortForward) { | ||||
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | ||||
if (!out_grad[0]) | if (!out_grad[0]) | ||||
@@ -112,7 +112,7 @@ Cumsum::Cumsum(VarNode* opr, const Param& param, | |||||
add_input({opr}, AddInputSortType::CUR_ADDED); | add_input({opr}, AddInputSortType::CUR_ADDED); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Cumsum) { | MGB_IMPL_OPR_GRAD(Cumsum) { | ||||
mgb_assert(out_grad[0] && !out_grad[1]); | mgb_assert(out_grad[0] && !out_grad[1]); | ||||
auto param = opr.param(); | auto param = opr.param(); | ||||
@@ -263,7 +263,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask, | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CondTake) { | MGB_IMPL_OPR_GRAD(CondTake) { | ||||
mgb_assert(out_grad.size() == 3 && !out_grad[2]); | mgb_assert(out_grad.size() == 3 && !out_grad[2]); | ||||
if (wrt_idx == 0 && out_grad[0]) { | if (wrt_idx == 0 && out_grad[0]) { | ||||
@@ -413,7 +413,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) { | |||||
record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(TopK) { | MGB_IMPL_OPR_GRAD(TopK) { | ||||
if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { | if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { | ||||
mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); | mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); | ||||
@@ -316,7 +316,7 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) { | |||||
OperatorNodeConfig().comp_node_arr(sp_cn))); | OperatorNodeConfig().comp_node_arr(sp_cn))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(AllGather) { | MGB_IMPL_OPR_GRAD(AllGather) { | ||||
return const_cast<AllGather&>(opr).grad(out_grad); | return const_cast<AllGather&>(opr).grad(out_grad); | ||||
} | } | ||||
@@ -123,7 +123,7 @@ namespace opr { | |||||
namespace intl { | namespace intl { | ||||
template class RNGOpr<::megdnn::GaussianRNG>; | template class RNGOpr<::megdnn::GaussianRNG>; | ||||
template class RNGOpr<::megdnn::UniformRNG>; | template class RNGOpr<::megdnn::UniformRNG>; | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
IMPL(GaussianRNG); | IMPL(GaussianRNG); | ||||
IMPL(UniformRNG); | IMPL(UniformRNG); | ||||
#endif | #endif | ||||
@@ -46,7 +46,7 @@ void Alloc::outshape_by_symvar_do_get_output_shape( | |||||
void Alloc::scn_do_execute() { | void Alloc::scn_do_execute() { | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Alloc) { | MGB_IMPL_OPR_GRAD(Alloc) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
@@ -125,7 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) { | |||||
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Linspace) { | MGB_IMPL_OPR_GRAD(Linspace) { | ||||
if (wrt_idx == 2) | if (wrt_idx == 2) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -199,7 +199,7 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) { | |||||
std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Eye) { | MGB_IMPL_OPR_GRAD(Eye) { | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
} | } | ||||
@@ -165,7 +165,7 @@ void GetVarShape::init_output_static_infer_desc() { | |||||
mgr.register_value_infer(output(0), | mgr.register_value_infer(output(0), | ||||
{SourceType::DEP, deps, infer_value}); | {SourceType::DEP, deps, infer_value}); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(GetVarShape) { | MGB_IMPL_OPR_GRAD(GetVarShape) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
@@ -372,7 +372,7 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp, | |||||
inp.node(), tshp.node(), unspec_axis, config); | inp.node(), tshp.node(), unspec_axis, config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Reshape) { | MGB_IMPL_OPR_GRAD(Reshape) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -441,7 +441,7 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp, | |||||
inp.node(), tshp.node(), config); | inp.node(), tshp.node(), config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Broadcast) { | MGB_IMPL_OPR_GRAD(Broadcast) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -586,7 +586,7 @@ VarNode* Dimshuffle::grad( | |||||
return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); | return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Dimshuffle) { | MGB_IMPL_OPR_GRAD(Dimshuffle) { | ||||
return opr.grad(wrt_idx, out_grad); | return opr.grad(wrt_idx, out_grad); | ||||
} | } | ||||
@@ -649,7 +649,7 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout( | |||||
return layout; | return layout; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(AxisAddRemove) { | MGB_IMPL_OPR_GRAD(AxisAddRemove) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | ||||
@@ -662,7 +662,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) { | |||||
MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); | MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Subtensor) { | MGB_IMPL_OPR_GRAD(Subtensor) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -806,7 +806,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { | |||||
sub.copy_from_fixlayout(val); | sub.copy_from_fixlayout(val); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(SetSubtensor) { | MGB_IMPL_OPR_GRAD(SetSubtensor) { | ||||
if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -838,7 +838,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { | |||||
opr->exec(sub.as_megdnn(), val.as_megdnn()); | opr->exec(sub.as_megdnn(), val.as_megdnn()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(IncrSubtensor) { | MGB_IMPL_OPR_GRAD(IncrSubtensor) { | ||||
if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -1112,7 +1112,7 @@ void Split::do_execute(ExecEnv &env) { | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Split) { | MGB_IMPL_OPR_GRAD(Split) { | ||||
if (wrt_idx) | if (wrt_idx) | ||||
return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
@@ -1265,7 +1265,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis, | |||||
axis, config); | axis, config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Concat) { | MGB_IMPL_OPR_GRAD(Concat) { | ||||
auto axis = opr.axis(); | auto axis = opr.axis(); | ||||
mgb_assert(out_grad.size() == 1); | mgb_assert(out_grad.size() == 1); | ||||
@@ -1549,7 +1549,7 @@ void ParamPackSplit::scn_do_execute() { | |||||
mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets"); | mgb_assert(inp_size == m_offsets.back(), "input shape should match offsets"); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(ParamPackSplit) { | MGB_IMPL_OPR_GRAD(ParamPackSplit) { | ||||
mgb_assert(out_grad.size() == opr.output().size()); | mgb_assert(out_grad.size() == opr.output().size()); | ||||
SmallVector<SymbolVar> grad; | SmallVector<SymbolVar> grad; | ||||
@@ -255,7 +255,7 @@ void MarkDynamicVar::scn_do_execute() { | |||||
o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); | o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(MarkDynamicVar) { | MGB_IMPL_OPR_GRAD(MarkDynamicVar) { | ||||
return MarkDynamicVar::make(out_grad.at(0)).node(); | return MarkDynamicVar::make(out_grad.at(0)).node(); | ||||
} | } | ||||
@@ -383,7 +383,7 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) { | |||||
} | } | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CallbackInjector) { | MGB_IMPL_OPR_GRAD(CallbackInjector) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
return out_grad.at(0); | return out_grad.at(0); | ||||
@@ -408,7 +408,7 @@ SymbolVar MarkNoBroadcastElemwise::make( | |||||
input.node(), config); | input.node(), config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { | MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { | ||||
return out_grad.at(0); | return out_grad.at(0); | ||||
} | } | ||||
@@ -435,7 +435,7 @@ SymbolVar Identity::make( | |||||
return input.insert_single_output_opr<Identity>(input.node(), config); | return input.insert_single_output_opr<Identity>(input.node(), config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(Identity) { | MGB_IMPL_OPR_GRAD(Identity) { | ||||
return out_grad.at(0); | return out_grad.at(0); | ||||
} | } | ||||
@@ -538,7 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter, | |||||
input.node(), grad_getter, config); | input.node(), grad_getter, config); | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(SetGrad) { | MGB_IMPL_OPR_GRAD(SetGrad) { | ||||
MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
@@ -700,7 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const { | |||||
return ret; | return ret; | ||||
} | } | ||||
#ifdef MGB_ENABLE_GRAD | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(VirtualLoss) { | MGB_IMPL_OPR_GRAD(VirtualLoss) { | ||||
mgb_assert(out_grad.size() == 1); | mgb_assert(out_grad.size() == 1); | ||||
auto mid = opr.input().size() / 2; | auto mid = opr.input().size() / 2; | ||||