GitOrigin-RevId: 4738136e4a
tags/v1.7.2.m1
@@ -183,7 +183,7 @@ namespace pooling { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
auto&& pool = static_cast<const Pooling&>(def); | auto&& pool = static_cast<const Pooling&>(def); | ||||
OperatorNodeConfig config{pool.make_name()}; | OperatorNodeConfig config{pool.make_name()}; | ||||
return opr::Pooling::make(inputs[0], pool.param(), config); | |||||
return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config); | |||||
} | } | ||||
OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); | OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); | ||||
} // namespace pooling | } // namespace pooling | ||||
@@ -63,7 +63,7 @@ def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, Executio | |||||
def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; | ||||
def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>; | |||||
def Pooling: MgbHashableOp<"Pooling", [PoolingParam, ExecutionPolicyParamBase<"policy">]>; | |||||
def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>; | def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>; | ||||
@@ -31,7 +31,21 @@ using namespace opr; | |||||
namespace { | namespace { | ||||
template <class MegDNNConv = megdnn::Convolution> | template <class MegDNNConv = megdnn::Convolution> | ||||
struct MakeConvCaller2 { | |||||
struct MakeOprWithPolicyCaller1 { | |||||
template <typename Opr> | |||||
static VarNode* make( | |||||
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | |||||
const megdnn::param::ExecutionPolicy& execution_policy, | |||||
const OperatorNodeConfig& config) { | |||||
if (inputs.size() == 1) { | |||||
return Opr::make(inputs[0], param, execution_policy, config).node(); | |||||
} | |||||
return nullptr; | |||||
} | |||||
}; | |||||
template <class MegDNNConv = megdnn::Convolution> | |||||
struct MakeOprWithPolicyCaller2 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | ||||
@@ -46,7 +60,7 @@ struct MakeConvCaller2 { | |||||
}; | }; | ||||
template <class MegDNNConv = megdnn::Convolution> | template <class MegDNNConv = megdnn::Convolution> | ||||
struct MakeConvCaller3 { | |||||
struct MakeOprWithPolicyCaller3 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | ||||
@@ -63,7 +77,7 @@ struct MakeConvCaller3 { | |||||
}; | }; | ||||
template <class MegDNNConv = megdnn::Convolution> | template <class MegDNNConv = megdnn::Convolution> | ||||
struct MakeConvCaller4 { | |||||
struct MakeOprWithPolicyCaller4 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | ||||
@@ -80,7 +94,7 @@ struct MakeConvCaller4 { | |||||
}; | }; | ||||
template <class MegDNNConv = megdnn::Convolution> | template <class MegDNNConv = megdnn::Convolution> | ||||
struct MakeConvCaller5 { | |||||
struct MakeOprWithPolicyCaller5 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param, | ||||
@@ -97,7 +111,7 @@ struct MakeConvCaller5 { | |||||
}; | }; | ||||
template <class MegDNNConv = megdnn::Convolution> | template <class MegDNNConv = megdnn::Convolution> | ||||
struct MakeConvCallerEmpty { | |||||
struct MakeOprWithPolicyCallerEmpty { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray&, const typename MegDNNConv::Param&, | const cg::VarNodeArray&, const typename MegDNNConv::Param&, | ||||
@@ -108,10 +122,10 @@ struct MakeConvCallerEmpty { | |||||
template < | template < | ||||
class Opr, class Maker0, class MegDNNConv, | class Opr, class Maker0, class MegDNNConv, | ||||
class Maker1 = MakeConvCallerEmpty<MegDNNConv>, | |||||
class Maker2 = MakeConvCallerEmpty<MegDNNConv>, | |||||
typename ConvParam = megdnn::param::Convolution> | |||||
struct ConvMakerImpl { | |||||
class Maker1 = MakeOprWithPolicyCallerEmpty<MegDNNConv>, | |||||
class Maker2 = MakeOprWithPolicyCallerEmpty<MegDNNConv>, | |||||
typename ConvParam = typename MegDNNConv::Param> | |||||
struct OprWithPolicyMakerImpl { | |||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const ConvParam& param, | const cg::VarNodeArray& inputs, const ConvParam& param, | ||||
const megdnn::param::ExecutionPolicy& execution_policy, | const megdnn::param::ExecutionPolicy& execution_policy, | ||||
@@ -130,33 +144,43 @@ struct ConvMakerImpl { | |||||
}; | }; | ||||
template <typename Opr> | template <typename Opr> | ||||
struct ConvMaker; | |||||
struct OprWithPolicyMaker; | |||||
template <> | |||||
struct OprWithPolicyMaker<opr::Pooling> | |||||
: public OprWithPolicyMakerImpl< | |||||
opr::Pooling, MakeOprWithPolicyCaller1<megdnn::Pooling>, | |||||
megdnn::Pooling> {}; | |||||
template <> | template <> | ||||
struct ConvMaker<opr::Convolution> | |||||
: public ConvMakerImpl< | |||||
opr::Convolution, MakeConvCaller2<megdnn::Convolution>, | |||||
struct OprWithPolicyMaker<opr::Convolution> | |||||
: public OprWithPolicyMakerImpl< | |||||
opr::Convolution, MakeOprWithPolicyCaller2<megdnn::Convolution>, | |||||
megdnn::Convolution> {}; | megdnn::Convolution> {}; | ||||
template <> | template <> | ||||
struct ConvMaker<opr::ConvolutionBackwardData> | |||||
: public ConvMakerImpl< | |||||
opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>, | |||||
megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {}; | |||||
struct OprWithPolicyMaker<opr::ConvolutionBackwardData> | |||||
: public OprWithPolicyMakerImpl< | |||||
opr::ConvolutionBackwardData, | |||||
MakeOprWithPolicyCaller2<megdnn::Convolution>, megdnn::Convolution, | |||||
MakeOprWithPolicyCaller3<megdnn::Convolution>> {}; | |||||
template <> | template <> | ||||
struct ConvMaker<opr::ConvBiasForward> | |||||
: public ConvMakerImpl< | |||||
opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>, | |||||
megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>, | |||||
MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {}; | |||||
struct OprWithPolicyMaker<opr::ConvBiasForward> | |||||
: public OprWithPolicyMakerImpl< | |||||
opr::ConvBiasForward, | |||||
MakeOprWithPolicyCaller2<megdnn::ConvBiasForward>, | |||||
megdnn::ConvBiasForward, | |||||
MakeOprWithPolicyCaller3<megdnn::ConvBiasForward>, | |||||
MakeOprWithPolicyCaller4<megdnn::ConvBiasForward>, | |||||
megdnn::param::ConvBias> {}; | |||||
template <> | template <> | ||||
struct ConvMaker<opr::BatchConvBiasForward> | |||||
: public ConvMakerImpl< | |||||
struct OprWithPolicyMaker<opr::BatchConvBiasForward> | |||||
: public OprWithPolicyMakerImpl< | |||||
opr::BatchConvBiasForward, | opr::BatchConvBiasForward, | ||||
MakeConvCaller2<megdnn::BatchConvBiasForward>, | |||||
MakeOprWithPolicyCaller2<megdnn::BatchConvBiasForward>, | |||||
megdnn::BatchConvBiasForward, | megdnn::BatchConvBiasForward, | ||||
MakeConvCaller3<megdnn::BatchConvBiasForward>, | |||||
MakeConvCaller4<megdnn::BatchConvBiasForward>, | |||||
MakeOprWithPolicyCaller3<megdnn::BatchConvBiasForward>, | |||||
MakeOprWithPolicyCaller4<megdnn::BatchConvBiasForward>, | |||||
megdnn::param::BatchConvBias> {}; | megdnn::param::BatchConvBias> {}; | ||||
#include "../../opr/impl/internal/invoke.h" | #include "../../opr/impl/internal/invoke.h" | ||||
@@ -254,7 +278,7 @@ struct OprFormatModifier; | |||||
auto&& opr = opr_->cast_final_safe<_Opr>(); \ | auto&& opr = opr_->cast_final_safe<_Opr>(); \ | ||||
auto param = opr.param(); \ | auto param = opr.param(); \ | ||||
param.format = opr_format; \ | param.format = opr_format; \ | ||||
return ConvMaker<_Opr>::make( \ | |||||
return OprWithPolicyMaker<_Opr>::make( \ | |||||
i, param, opr.execution_policy(), opr.config()); \ | i, param, opr.execution_policy(), opr.config()); \ | ||||
MIDOUT_E \ | MIDOUT_E \ | ||||
} \ | } \ | ||||
@@ -263,6 +287,7 @@ INST(Convolution); | |||||
INST(ConvBiasForward); | INST(ConvBiasForward); | ||||
INST(ConvolutionBackwardData); | INST(ConvolutionBackwardData); | ||||
INST(BatchConvBiasForward); | INST(BatchConvBiasForward); | ||||
INST(Pooling); | |||||
#undef INST | #undef INST | ||||
template <> | template <> | ||||
@@ -303,7 +328,6 @@ struct OprFormatModifier<WarpPerspective> { | |||||
MIDOUT_E \ | MIDOUT_E \ | ||||
} \ | } \ | ||||
}; | }; | ||||
INST(PoolingForward, 1); | |||||
INST(Resize, 2); | INST(Resize, 2); | ||||
#undef INST | #undef INST | ||||
@@ -1492,7 +1492,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() { | |||||
} | } | ||||
auto new_param = pooling_opr.param(); | auto new_param = pooling_opr.param(); | ||||
new_param.format = megdnn::param::Pooling::Format::NHWCD4; | new_param.format = megdnn::param::Pooling::Format::NHWCD4; | ||||
auto new_pooling_opr = opr::PoolingForward::make(inp, new_param, opr->config()); | |||||
auto new_pooling_opr = opr::PoolingForward::make( | |||||
inp, new_param, pooling_opr.execution_policy(), opr->config()); | |||||
return new_pooling_opr.node()->owner_opr(); | return new_pooling_opr.node()->owner_opr(); | ||||
}; | }; | ||||
@@ -525,8 +525,8 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: | |||||
} | } | ||||
auto new_param = pooling.param(); | auto new_param = pooling.param(); | ||||
new_param.format = Format::NCHW32; | new_param.format = Format::NCHW32; | ||||
auto new_pooling = | |||||
opr::PoolingForward::make(new_inp_var, new_param, opr->config()); | |||||
auto new_pooling = opr::PoolingForward::make( | |||||
new_inp_var, new_param, pooling.execution_policy(), opr->config()); | |||||
return new_pooling.node()->owner_opr(); | return new_pooling.node()->owner_opr(); | ||||
} | } | ||||
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | ||||
@@ -795,8 +795,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() { | |||||
if (varshape_changed.count(new_inp[0])) { | if (varshape_changed.count(new_inp[0])) { | ||||
auto new_param = pooling.param(); | auto new_param = pooling.param(); | ||||
new_param.format = Format::CHWN4; | new_param.format = Format::CHWN4; | ||||
auto new_pooling = | |||||
opr::PoolingForward::make(new_inp[0], new_param, opr->config()); | |||||
auto new_pooling = opr::PoolingForward::make( | |||||
new_inp[0], new_param, pooling.execution_policy(), opr->config()); | |||||
varshape_changed.insert(new_pooling.node()); | varshape_changed.insert(new_pooling.node()); | ||||
return new_pooling.node()->owner_opr(); | return new_pooling.node()->owner_opr(); | ||||
} | } | ||||
@@ -1174,8 +1174,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { | |||||
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8); | ||||
auto new_param = pooling.param(); | auto new_param = pooling.param(); | ||||
new_param.format = Format::NCHW4; | new_param.format = Format::NCHW4; | ||||
auto new_pooling = | |||||
opr::PoolingForward::make(new_inp[0], new_param, opr->config()); | |||||
auto new_pooling = opr::PoolingForward::make( | |||||
new_inp[0], new_param, pooling.execution_policy(), opr->config()); | |||||
mgb_assert( | mgb_assert( | ||||
new_pooling.shape().ndim == 5, | new_pooling.shape().ndim == 5, | ||||
"out var of Pooling opr after transform must be 5 (got: " | "out var of Pooling opr after transform must be 5 (got: " | ||||
@@ -1646,8 +1646,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||||
if (inp->shape().ndim == 5) { | if (inp->shape().ndim == 5) { | ||||
auto new_param = pooling_opr.param(); | auto new_param = pooling_opr.param(); | ||||
new_param.format = pooling_format; | new_param.format = pooling_format; | ||||
auto new_pooling_opr = | |||||
opr::PoolingForward::make(inp, new_param, opr->config()); | |||||
auto new_pooling_opr = opr::PoolingForward::make( | |||||
inp, new_param, pooling_opr.execution_policy(), opr->config()); | |||||
mgb_assert( | mgb_assert( | ||||
new_pooling_opr.shape().ndim == 5, | new_pooling_opr.shape().ndim == 5, | ||||
"The pooling dst dim is not trans to nchwxx"); | "The pooling dst dim is not trans to nchwxx"); | ||||
@@ -3003,7 +3003,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() { | |||||
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; | auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC; | ||||
auto param = pooling.param(); | auto param = pooling.param(); | ||||
param.format = target_format; | param.format = target_format; | ||||
auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config()); | |||||
auto new_pool = opr::PoolingForward::make( | |||||
inps[0], param, pooling.execution_policy(), pooling.config()); | |||||
auto ret = new_pool.node()->owner_opr(); | auto ret = new_pool.node()->owner_opr(); | ||||
format_map.insert(std::make_pair(ret, target_format)); | format_map.insert(std::make_pair(ret, target_format)); | ||||
return ret; | return ret; | ||||
@@ -3055,7 +3056,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() { | |||||
auto param = pooling.param(); | auto param = pooling.param(); | ||||
param.format = out_format; | param.format = out_format; | ||||
auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config()); | |||||
auto new_pool = opr::PoolingForward::make( | |||||
inps[0], param, pooling.execution_policy(), pooling.config()); | |||||
auto ret = new_pool.node()->owner_opr(); | auto ret = new_pool.node()->owner_opr(); | ||||
format_map.insert(std::make_pair(ret, out_format)); | format_map.insert(std::make_pair(ret, out_format)); | ||||
return ret; | return ret; | ||||
@@ -281,7 +281,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) { | |||||
auto new_out_var = new_output[0]; | auto new_out_var = new_output[0]; | ||||
/// check global layout transform pass | /// check global layout transform pass | ||||
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | ||||
ASSERT_EQ(nr_dimshuffle, 3u); | |||||
ASSERT_EQ(nr_dimshuffle, 5u); | |||||
/// check pass fuse conv bias with z | /// check pass fuse conv bias with z | ||||
auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var); | auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var); | ||||
ASSERT_EQ(nr_elemwise_mult_type, 4u); | ASSERT_EQ(nr_elemwise_mult_type, 4u); | ||||
@@ -822,7 +822,7 @@ TEST(TestLayoutTransform, Resnet18_F16) { | |||||
auto new_out_var = new_output[0]; | auto new_out_var = new_output[0]; | ||||
/// check global layout transform pass | /// check global layout transform pass | ||||
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var); | ||||
ASSERT_EQ(nr_dimshuffle, 4u); | |||||
ASSERT_EQ(nr_dimshuffle, 2u); | |||||
/// check pass fuse conv bias with z | /// check pass fuse conv bias with z | ||||
auto nr_elemwise = find_opr_num<opr::Elemwise>(new_out_var); | auto nr_elemwise = find_opr_num<opr::Elemwise>(new_out_var); | ||||
ASSERT_EQ(nr_elemwise, 4u); | ASSERT_EQ(nr_elemwise, 4u); | ||||
@@ -80,14 +80,26 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2> | |||||
opr::BatchedMatrixMul, MakeMatrixMulCaller<megdnn::BatchedMatrixMul>, | opr::BatchedMatrixMul, MakeMatrixMulCaller<megdnn::BatchedMatrixMul>, | ||||
megdnn::BatchedMatrixMul> {}; | megdnn::BatchedMatrixMul> {}; | ||||
template <typename Opr> | |||||
cg::OperatorNodeBase* opr_shallow_copy_matmul( | |||||
const serialization::OprShallowCopyContext& ctx, | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(ctx); | |||||
auto&& opr = opr_.cast_final_safe<Opr>(); | |||||
return OprLoadDumpImpl<Opr, 2>::make( | |||||
inputs, opr.param(), opr.execution_policy_transient(), config) | |||||
->owner_opr(); | |||||
} | |||||
} // namespace serialization | } // namespace serialization | ||||
namespace opr { | namespace opr { | ||||
using MatrixMulV2 = MatrixMul; | using MatrixMulV2 = MatrixMul; | ||||
using BatchedMatrixMulV2 = BatchedMatrixMul; | using BatchedMatrixMulV2 = BatchedMatrixMul; | ||||
MGB_SEREG_OPR(MatrixMulV2, 2); | |||||
MGB_SEREG_OPR(BatchedMatrixMulV2, 2); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(MatrixMulV2, 2, opr_shallow_copy_matmul); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmul); | |||||
MGB_SEREG_OPR(Dot, 2); | MGB_SEREG_OPR(Dot, 2); | ||||
MGB_SEREG_OPR(MatrixInverse, 1); | MGB_SEREG_OPR(MatrixInverse, 1); | ||||
MGB_SEREG_OPR(SVD, 1); | MGB_SEREG_OPR(SVD, 1); | ||||
@@ -36,9 +36,10 @@ struct MakePoolingCaller1 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param, | ||||
const megdnn::param::ExecutionPolicy& execution_policy, | |||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
if (inputs.size() == 1) { | if (inputs.size() == 1) { | ||||
return Opr::make(inputs[0], param, config).node(); | |||||
return Opr::make(inputs[0], param, execution_policy, config).node(); | |||||
} | } | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -78,9 +79,13 @@ struct MakePoolingBackwardCaller3 { | |||||
template <typename Opr> | template <typename Opr> | ||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param, | const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param, | ||||
const megdnn::param::ExecutionPolicy& execution_policy, | |||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
if (inputs.size() == 3) { | if (inputs.size() == 3) { | ||||
return Opr::make(inputs[0], inputs[1], inputs[2], param, config).node(); | |||||
return Opr::make( | |||||
inputs[0], inputs[1], inputs[2], param, execution_policy, | |||||
config) | |||||
.node(); | |||||
} | } | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -223,8 +228,10 @@ struct PoolingLoadDumpImpl { | |||||
static VarNode* make( | static VarNode* make( | ||||
const cg::VarNodeArray& inputs, const PoolingParam& param, | const cg::VarNodeArray& inputs, const PoolingParam& param, | ||||
const megdnn::param::ExecutionPolicy& execution_policy, | |||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
VarNode* ret = Maker0::template make<Opr>(inputs, param, config); | |||||
VarNode* ret = | |||||
Maker0::template make<Opr>(inputs, param, execution_policy, config); | |||||
mgb_assert(ret); | mgb_assert(ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -233,6 +240,29 @@ struct PoolingLoadDumpImpl { | |||||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | OprLoadContext& ctx, const cg::VarNodeArray& inputs, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
auto param = ctx.read_param<PoolingParam>(); | auto param = ctx.read_param<PoolingParam>(); | ||||
return make(inputs, param, {}, config)->owner_opr(); | |||||
} | |||||
}; | |||||
template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign> | |||||
struct GeneralOprLoadDumpImpl { | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||||
auto&& opr = opr_.cast_final_safe<Opr>(); | |||||
ctx.write_param<GeneralOprParam>(opr.param()); | |||||
} | |||||
static VarNode* make( | |||||
const cg::VarNodeArray& inputs, const GeneralOprParam& param, | |||||
const OperatorNodeConfig& config) { | |||||
VarNode* ret = Maker0::template make<Opr>(inputs, param, config); | |||||
mgb_assert(ret); | |||||
return ret; | |||||
} | |||||
static cg::OperatorNodeBase* load( | |||||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) { | |||||
auto param = ctx.read_param<GeneralOprParam>(); | |||||
return make(inputs, param, config)->owner_opr(); | return make(inputs, param, config)->owner_opr(); | ||||
} | } | ||||
}; | }; | ||||
@@ -264,26 +294,26 @@ struct OprMaker<opr::LSQBackward, 5> { | |||||
}; | }; | ||||
template <> | template <> | ||||
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | ||||
: public PoolingLoadDumpImpl< | |||||
: public GeneralOprLoadDumpImpl< | |||||
opr::AdaptivePoolingBackward, | opr::AdaptivePoolingBackward, | ||||
MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>, | MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>, | ||||
megdnn::param::AdaptivePooling> {}; | megdnn::param::AdaptivePooling> {}; | ||||
template <> | template <> | ||||
struct OprLoadDumpImpl<opr::AdaptivePooling, 0> | struct OprLoadDumpImpl<opr::AdaptivePooling, 0> | ||||
: public PoolingLoadDumpImpl< | |||||
: public GeneralOprLoadDumpImpl< | |||||
opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>, | opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>, | ||||
megdnn::param::AdaptivePooling> {}; | megdnn::param::AdaptivePooling> {}; | ||||
template <> | template <> | ||||
struct OprLoadDumpImpl<opr::ROIAlign, 0> | struct OprLoadDumpImpl<opr::ROIAlign, 0> | ||||
: public PoolingLoadDumpImpl< | |||||
: public GeneralOprLoadDumpImpl< | |||||
opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>, | opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>, | ||||
megdnn::param::ROIAlign> {}; | megdnn::param::ROIAlign> {}; | ||||
template <> | template <> | ||||
struct OprLoadDumpImpl<opr::ROIAlignBackward, 0> | struct OprLoadDumpImpl<opr::ROIAlignBackward, 0> | ||||
: public PoolingLoadDumpImpl< | |||||
: public GeneralOprLoadDumpImpl< | |||||
opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>, | opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>, | ||||
megdnn::param::ROIAlign> {}; | megdnn::param::ROIAlign> {}; | ||||
@@ -500,15 +530,29 @@ struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0> | |||||
opr::DeformableConvBackwardFilter, | opr::DeformableConvBackwardFilter, | ||||
MakeConvCaller5<megdnn::DeformableConvBackwardFilter>, | MakeConvCaller5<megdnn::DeformableConvBackwardFilter>, | ||||
megdnn::Convolution> {}; | megdnn::Convolution> {}; | ||||
template <typename Opr> | |||||
cg::OperatorNodeBase* opr_shallow_copy_conv( | |||||
const serialization::OprShallowCopyContext& ctx, | |||||
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(ctx); | |||||
auto&& opr = opr_.cast_final_safe<Opr>(); | |||||
return OprLoadDumpImpl<Opr, 0>::make( | |||||
inputs, opr.param(), opr.execution_policy_transient(), config) | |||||
->owner_opr(); | |||||
} | |||||
} // namespace serialization | } // namespace serialization | ||||
namespace opr { | namespace opr { | ||||
using ConvolutionV2 = Convolution; | using ConvolutionV2 = Convolution; | ||||
using ConvolutionBackwardDataV2 = ConvolutionBackwardData; | using ConvolutionBackwardDataV2 = ConvolutionBackwardData; | ||||
using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter; | using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter; | ||||
MGB_SEREG_OPR(ConvolutionV2, 0); | |||||
MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0); | |||||
MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY( | |||||
ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR(Images2Neibs, 1); | MGB_SEREG_OPR(Images2Neibs, 1); | ||||
MGB_SEREG_OPR(Images2NeibsBackward, 2); | MGB_SEREG_OPR(Images2NeibsBackward, 2); | ||||
@@ -534,8 +578,8 @@ MGB_SEREG_OPR(LRN, 1); | |||||
MGB_SEREG_OPR(LRNBackward, 3); | MGB_SEREG_OPR(LRNBackward, 3); | ||||
using PoolingV1 = Pooling; | using PoolingV1 = Pooling; | ||||
using PoolingBackwardV1 = PoolingBackward; | using PoolingBackwardV1 = PoolingBackward; | ||||
MGB_SEREG_OPR(PoolingV1, 1); | |||||
MGB_SEREG_OPR(PoolingBackwardV1, 3); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv); | |||||
using AdaptivePoolingV1 = AdaptivePooling; | using AdaptivePoolingV1 = AdaptivePooling; | ||||
using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward; | using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward; | ||||
MGB_SEREG_OPR(AdaptivePoolingV1, 2); | MGB_SEREG_OPR(AdaptivePoolingV1, 2); | ||||
@@ -548,12 +592,13 @@ using MaskConvolutionV2 = MaskConvolution; | |||||
MGB_SEREG_OPR(MaskConvolutionV2, 3); | MGB_SEREG_OPR(MaskConvolutionV2, 3); | ||||
MGB_SEREG_OPR(MaskPropagate, 1); | MGB_SEREG_OPR(MaskPropagate, 1); | ||||
MGB_SEREG_OPR(Convolution3D, 0); | |||||
MGB_SEREG_OPR(Convolution3DBackwardData, 0); | |||||
MGB_SEREG_OPR(Convolution3DBackwardFilter, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY( | |||||
Convolution3DBackwardFilter, 0, opr_shallow_copy_conv); | |||||
using ConvBiasForwardV4 = ConvBiasForward; | using ConvBiasForwardV4 = ConvBiasForward; | ||||
MGB_SEREG_OPR(ConvBiasForwardV4, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv); | |||||
using BatchNormV1 = BatchNorm; | using BatchNormV1 = BatchNorm; | ||||
using BatchNormBackwardV1 = BatchNormBackward; | using BatchNormBackwardV1 = BatchNormBackward; | ||||
@@ -563,9 +608,10 @@ MGB_SEREG_OPR(BatchNormBackwardV1, 6); | |||||
using LocalShareForwardV1 = LocalShareForward; | using LocalShareForwardV1 = LocalShareForward; | ||||
using LocalShareBackwardDataV1 = LocalShareBackwardData; | using LocalShareBackwardDataV1 = LocalShareBackwardData; | ||||
using LocalShareBackwardFilterV1 = LocalShareBackwardFilter; | using LocalShareBackwardFilterV1 = LocalShareBackwardFilter; | ||||
MGB_SEREG_OPR(LocalShareForwardV1, 0); | |||||
MGB_SEREG_OPR(LocalShareBackwardDataV1, 0); | |||||
MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY( | |||||
LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv); | |||||
using ROIAlignV1 = ROIAlign; | using ROIAlignV1 = ROIAlign; | ||||
using ROIAlignBackwardV1 = ROIAlignBackward; | using ROIAlignBackwardV1 = ROIAlignBackward; | ||||
@@ -574,9 +620,11 @@ MGB_SEREG_OPR(ROIAlignBackwardV1, 4); | |||||
using DeformableConvForwardV1 = DeformableConvForward; | using DeformableConvForwardV1 = DeformableConvForward; | ||||
using DeformableConvBackwardDataV1 = DeformableConvBackwardData; | using DeformableConvBackwardDataV1 = DeformableConvBackwardData; | ||||
using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter; | using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter; | ||||
MGB_SEREG_OPR(DeformableConvForwardV1, 0); | |||||
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); | |||||
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY( | |||||
DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY( | |||||
DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR(CorrelationForward, 2); | MGB_SEREG_OPR(CorrelationForward, 2); | ||||
MGB_SEREG_OPR(CorrelationBackwardData1, 3); | MGB_SEREG_OPR(CorrelationBackwardData1, 3); | ||||
@@ -586,7 +634,7 @@ MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); | |||||
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | ||||
using BatchConvBiasForwardV1 = BatchConvBiasForward; | using BatchConvBiasForwardV1 = BatchConvBiasForward; | ||||
MGB_SEREG_OPR(BatchConvBiasForwardV1, 0); | |||||
MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv); | |||||
MGB_SEREG_OPR(FakeQuant, 3); | MGB_SEREG_OPR(FakeQuant, 3); | ||||
MGB_SEREG_OPR(FakeQuantBackward, 4); | MGB_SEREG_OPR(FakeQuantBackward, 4); | ||||
MGB_SEREG_OPR(TQT, 2); | MGB_SEREG_OPR(TQT, 2); | ||||
@@ -32,8 +32,8 @@ PoolingForward::PoolingForward( | |||||
} | } | ||||
SymbolVar PoolingForward::make( | SymbolVar PoolingForward::make( | ||||
SymbolVar i0, const Param& param, const OperatorNodeConfig& config, | |||||
const ExecutionPolicy& policy) { | |||||
SymbolVar i0, const Param& param, const ExecutionPolicy& policy, | |||||
const OperatorNodeConfig& config) { | |||||
intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0}); | intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0}); | ||||
return i0.insert_single_output_opr<PoolingForward>( | return i0.insert_single_output_opr<PoolingForward>( | ||||
i0.node(), param, policy, config); | i0.node(), param, policy, config); | ||||
@@ -75,12 +75,13 @@ PoolingBackward::PoolingBackward( | |||||
0, true) { | 0, true) { | ||||
init_megdnn_opr(*this, param); | init_megdnn_opr(*this, param); | ||||
add_input({i0, i1, i2}); | add_input({i0, i1, i2}); | ||||
m_policy = policy; | |||||
intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this); | intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this); | ||||
} | } | ||||
SymbolVar PoolingBackward::make( | SymbolVar PoolingBackward::make( | ||||
SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param, | SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param, | ||||
const OperatorNodeConfig& config, const ExecutionPolicy& policy) { | |||||
const ExecutionPolicy& policy, const OperatorNodeConfig& config) { | |||||
intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param, {&i0, &i1, &i2}); | intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param, {&i0, &i1, &i2}); | ||||
return i0.insert_single_output_opr<PoolingBackward>( | return i0.insert_single_output_opr<PoolingBackward>( | ||||
i0.node(), i1.node(), i2.node(), param, policy, config); | i0.node(), i1.node(), i2.node(), param, policy, config); | ||||
@@ -26,8 +26,8 @@ MGE_WIN_DECLSPEC_FUC PoolingForward( | |||||
VarNode* src, const Param& param, const ExecutionPolicy& policy, | VarNode* src, const Param& param, const ExecutionPolicy& policy, | ||||
const OperatorNodeConfig& config); | const OperatorNodeConfig& config); | ||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | MGE_WIN_DECLSPEC_FUC static SymbolVar make( | ||||
SymbolVar src, const Param& param, const OperatorNodeConfig& config = {}, | |||||
const ExecutionPolicy& policy = {}); | |||||
SymbolVar src, const Param& param, const ExecutionPolicy& policy = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
@@ -47,7 +47,7 @@ MGE_WIN_DECLSPEC_FUC PoolingBackward( | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | MGE_WIN_DECLSPEC_FUC static SymbolVar make( | ||||
SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | ||||
const OperatorNodeConfig& config = {}, const ExecutionPolicy& policy = {}); | |||||
const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | ||||
const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
@@ -15,7 +15,9 @@ | |||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
#include "megbrain/opr/blas.h" | #include "megbrain/opr/blas.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | |||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/test/autocheck.h" | #include "megbrain/test/autocheck.h" | ||||
#include "megbrain/test/helper.h" | #include "megbrain/test/helper.h" | ||||
@@ -32,39 +34,24 @@ using namespace mgb; | |||||
namespace { | namespace { | ||||
#if MGB_CUDA | |||||
#if MGB_ENABLE_FASTRUN | |||||
template <typename MgbOpr, int arith> | template <typename MgbOpr, int arith> | ||||
struct GraphMaker; | struct GraphMaker; | ||||
template <typename MgbOpr> | |||||
struct GraphMaker<MgbOpr, 2> { | |||||
SymbolVar operator()( | |||||
const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param, | |||||
typename MgbOpr::ExecutionPolicy& policy) { | |||||
return MgbOpr::make(inputs[0], inputs[1], param, policy); | |||||
} | |||||
}; | |||||
template <> | template <> | ||||
struct GraphMaker<opr::ConvolutionBackwardData, 2> { | |||||
struct GraphMaker<opr::Pooling, 1> { | |||||
SymbolVar operator()( | SymbolVar operator()( | ||||
const std::array<cg::SymbolVar, 2>& inputs, | |||||
opr::ConvolutionBackwardData::Param& param, | |||||
opr::ConvolutionBackwardData::ExecutionPolicy& policy) { | |||||
return opr::ConvolutionBackwardData::make_deconv( | |||||
inputs[0], inputs[1], param, policy); | |||||
const std::array<cg::SymbolVar, 1>& inputs, opr::Pooling::Param& param, | |||||
opr::Pooling::ExecutionPolicy& policy) { | |||||
return opr::Pooling::make(inputs[0], param, policy); | |||||
} | } | ||||
}; | }; | ||||
template <> | |||||
struct GraphMaker<opr::Convolution3DBackwardData, 2> { | |||||
template <typename MgbOpr> | |||||
struct GraphMaker<MgbOpr, 2> { | |||||
SymbolVar operator()( | SymbolVar operator()( | ||||
const std::array<cg::SymbolVar, 2>& inputs, | |||||
opr::Convolution3DBackwardData::Param& param, | |||||
opr::Convolution3DBackwardData::ExecutionPolicy& policy) { | |||||
return opr::Convolution3DBackwardData::make_deconv( | |||||
inputs[0], inputs[1], param, policy); | |||||
const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param, | |||||
typename MgbOpr::ExecutionPolicy& policy) { | |||||
return MgbOpr::make(inputs[0], inputs[1], param, policy); | |||||
} | } | ||||
}; | }; | ||||
@@ -99,6 +86,37 @@ struct GraphMaker<MgbOpr, 5> { | |||||
}; | }; | ||||
template <typename MgbOpr, int arith, typename dtype = dtype::Float32> | template <typename MgbOpr, int arith, typename dtype = dtype::Float32> | ||||
void test_execution_policy_shallow_copy( | |||||
std::array<TensorShape, arith> shapes, typename MgbOpr::Param param = {}) { | |||||
using Policy = typename MgbOpr::ExecutionPolicy; | |||||
Policy policy; | |||||
policy.strategy = Policy::Strategy::PROFILE; | |||||
auto cn = CompNode::load("cpu0"); | |||||
auto graph0 = ComputingGraph::make(), graph1 = ComputingGraph::make(); | |||||
std::array<cg::SymbolVar, arith> inputs0; | |||||
VarNodeArray inputs1; | |||||
for (size_t i = 0; i < arith; ++i) { | |||||
HostTensorND hi{cn, shapes[i], dtype()}; | |||||
inputs0[i] = opr::ImmutableTensor::make(*graph0, hi); | |||||
inputs1.push_back(opr::ImmutableTensor::make(*graph1, hi).node()); | |||||
} | |||||
GraphMaker<MgbOpr, arith> graph_maker; | |||||
auto opr0 = graph_maker(inputs0, param, policy).node()->owner_opr(); | |||||
auto opr1 = serialization::copy_opr_shallow(*opr0, inputs1, OperatorNodeConfig{}); | |||||
auto m0 = &(opr0->template cast_final<MgbOpr>()); | |||||
auto m1 = &(opr1->template cast_final<MgbOpr>()); | |||||
ASSERT_EQ(policy.strategy, m0->execution_policy().strategy); | |||||
ASSERT_EQ(policy.strategy, m1->execution_policy().strategy); | |||||
} | |||||
#if MGB_CUDA | |||||
#if MGB_ENABLE_FASTRUN | |||||
template <typename MgbOpr, int arith, typename dtype = dtype::Float32> | |||||
void test_fastrun_opr( | void test_fastrun_opr( | ||||
std::array<TensorShape, arith> inps0, std::array<TensorShape, arith> inps1, | std::array<TensorShape, arith> inps0, std::array<TensorShape, arith> inps1, | ||||
size_t expect_nr_cache_set_inp0 = 0, size_t expect_nr_cache_set_inp1 = 0, | size_t expect_nr_cache_set_inp0 = 0, size_t expect_nr_cache_set_inp1 = 0, | ||||
@@ -162,16 +180,24 @@ void test_fastrun_opr( | |||||
size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0; | size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0; | ||||
ASSERT_EQ(cache_set_history.size(), nr_set_total); | ASSERT_EQ(cache_set_history.size(), nr_set_total); | ||||
} | } | ||||
#endif // MGB_ENABLE_FASTRUN | |||||
#endif // MGB_CUDA | |||||
} // anonymous namespace | |||||
#if MGB_CUDA | |||||
#if MGB_ENABLE_FASTRUN | |||||
TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) { | TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) { | ||||
REQUIRE_GPU(1); | REQUIRE_GPU(1); | ||||
test_fastrun_opr<opr::Convolution, 2>( | test_fastrun_opr<opr::Convolution, 2>( | ||||
{TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}}, | {TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}}, | ||||
{TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}}); | {TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}}); | ||||
test_fastrun_opr<opr::ConvolutionBackwardData, 2>( | |||||
{TensorShape{12, 4, 23, 29}, TensorShape{4, 5, 3, 2}}, | |||||
{TensorShape{2, 4, 23, 29}, TensorShape{4, 5, 3, 2}}); | |||||
test_fastrun_opr<opr::ConvolutionBackwardData, 3>( | |||||
{TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29}, | |||||
TensorShape{12, 5, 25, 30}}, | |||||
{TensorShape{4, 5, 3, 2}, TensorShape{2, 4, 23, 29}, | |||||
TensorShape{2, 5, 25, 30}}); | |||||
test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>( | test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>( | ||||
{TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28}, | {TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28}, | ||||
@@ -195,9 +221,11 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution3D) { | |||||
{TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}}, | {TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}}, | ||||
{TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}}); | {TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}}); | ||||
test_fastrun_opr<opr::Convolution3DBackwardData, 2>( | |||||
{TensorShape{14, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}}, | |||||
{TensorShape{4, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}}); | |||||
test_fastrun_opr<opr::Convolution3DBackwardData, 3>( | |||||
{TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16}, | |||||
TensorShape{14, 5, 14, 14, 18}}, | |||||
{TensorShape{5, 5, 3, 3, 3}, TensorShape{4, 5, 12, 12, 16}, | |||||
TensorShape{4, 5, 14, 14, 18}}); | |||||
test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>( | test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>( | ||||
{TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18}, | {TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18}, | ||||
@@ -295,6 +323,87 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) { | |||||
#endif // MGB_ENABLE_FASTRUN | #endif // MGB_ENABLE_FASTRUN | ||||
#endif // MGB_CUDA | #endif // MGB_CUDA | ||||
} // anonymous namespace | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution) { | |||||
test_execution_policy_shallow_copy<opr::Convolution, 2>( | |||||
{TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}}); | |||||
test_execution_policy_shallow_copy<opr::ConvolutionBackwardData, 3>( | |||||
{TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29}, | |||||
TensorShape{12, 5, 25, 30}}); | |||||
test_execution_policy_shallow_copy<opr::ConvolutionBackwardFilter, 3>( | |||||
{TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28}, | |||||
TensorShape{5, 4, 3, 2}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvBias) { | |||||
test_execution_policy_shallow_copy<opr::ConvBias, 3>( | |||||
{TensorShape{20, 16, 50, 50}, TensorShape{24, 16, 3, 3}, | |||||
TensorShape{1, 24, 1, 1}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution3D) { | |||||
test_execution_policy_shallow_copy<opr::Convolution3D, 2>( | |||||
{TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}}); | |||||
test_execution_policy_shallow_copy<opr::Convolution3DBackwardData, 3>( | |||||
{TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16}, | |||||
TensorShape{14, 5, 14, 14, 18}}); | |||||
test_execution_policy_shallow_copy<opr::Convolution3DBackwardFilter, 3>( | |||||
{TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18}, | |||||
TensorShape{16, 16, 1, 1, 1}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyLocalShare) { | |||||
opr::LocalShare::Param local_share_param; | |||||
local_share_param.mode = opr::LocalShare::Param::Mode::CROSS_CORRELATION; | |||||
local_share_param.pad_h = local_share_param.pad_w = 1; | |||||
local_share_param.stride_h = local_share_param.stride_w = 1; | |||||
local_share_param.spatial_groups_h = local_share_param.spatial_groups_w = 2; | |||||
test_execution_policy_shallow_copy<opr::LocalShareForward, 2>( | |||||
{TensorShape{32, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}}, | |||||
local_share_param); | |||||
test_execution_policy_shallow_copy<opr::LocalShareBackwardData, 3>( | |||||
{TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{32, 128, 24, 24}, | |||||
TensorShape{32, 128, 24, 24}}); | |||||
test_execution_policy_shallow_copy<opr::LocalShareBackwardFilter, 3>( | |||||
{TensorShape{12, 3, 36, 36}, TensorShape{12, 4, 35, 35}, | |||||
TensorShape{3, 3, 3, 3, 3, 4}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyDeformableConv) { | |||||
test_execution_policy_shallow_copy<opr::DeformableConvForward, 4>( | |||||
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3}, | |||||
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}}); | |||||
test_execution_policy_shallow_copy<opr::DeformableConvBackwardData, 5>( | |||||
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3}, | |||||
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}, | |||||
TensorShape{12, 6, 18, 18}}); | |||||
test_execution_policy_shallow_copy<opr::DeformableConvBackwardFilter, 5>( | |||||
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3}, | |||||
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}, | |||||
TensorShape{12, 6, 18, 18}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyMatrixMul) { | |||||
test_execution_policy_shallow_copy<opr::MatrixMul, 2>( | |||||
{TensorShape{10, 12}, TensorShape{12, 12}}); | |||||
test_execution_policy_shallow_copy<opr::BatchedMatrixMul, 2>( | |||||
{TensorShape{12, 6, 8}, TensorShape{12, 8, 4}}); | |||||
} | |||||
TEST(TestOprDNN, ExecutionPolicyShallowCopyPooling) { | |||||
test_execution_policy_shallow_copy<opr::Pooling, 1>({TensorShape{1, 20, 24, 24}}); | |||||
test_execution_policy_shallow_copy<opr::PoolingBackward, 3>( | |||||
{TensorShape{1, 20, 24, 24}, TensorShape{1, 20, 12, 12}, | |||||
TensorShape{1, 20, 12, 12}}); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -135,7 +135,7 @@ TEST(TestOprDNN, PoolingExePolicy) { | |||||
Policy policy; | Policy policy; | ||||
policy.strategy = strategy; | policy.strategy = strategy; | ||||
auto pooling = opr::PoolingForward::make(input, param, {}, policy); | |||||
auto pooling = opr::PoolingForward::make(input, param, policy); | |||||
auto loss0 = opr::reduce_sum_sqr(pooling, pooling.make_scalar(1)); | auto loss0 = opr::reduce_sum_sqr(pooling, pooling.make_scalar(1)); | ||||
auto grad = cg::grad(loss0, input, true, false); | auto grad = cg::grad(loss0, input, true, false); | ||||
@@ -187,7 +187,7 @@ TEST(TestOprDNN, PoolingForwardFastrun) { | |||||
Policy policy; | Policy policy; | ||||
policy.strategy = strategy; | policy.strategy = strategy; | ||||
auto pooling = opr::PoolingForward::make(input, param, {}, policy); | |||||
auto pooling = opr::PoolingForward::make(input, param, policy); | |||||
auto func = graph->compile({make_callback_copy(pooling, host_y)}); | auto func = graph->compile({make_callback_copy(pooling, host_y)}); | ||||
func->execute().wait(); | func->execute().wait(); | ||||
@@ -253,4 +253,11 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {}; | |||||
__caller_OprRegShallowCopy##_cls##_ins; \ | __caller_OprRegShallowCopy##_cls##_ins; \ | ||||
} | } | ||||
/*! | |||||
* \brief register opr serialization and shallow copy methods | |||||
*/ | |||||
#define MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(_cls, _arity, _copy) \ | |||||
MGB_SEREG_OPR(_cls, _arity) \ | |||||
MGB_REG_OPR_SHALLOW_COPY(_cls, ::mgb::serialization::_copy<_cls>) | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |