GitOrigin-RevId: 9d80f2009d
release-1.7
@@ -682,7 +682,8 @@ public: | |||||
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html | * http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html | ||||
* | * | ||||
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$, | * \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$, | ||||
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. | |||||
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), | |||||
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. | |||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
_megdnn_workspace workspace) = 0; | _megdnn_workspace workspace) = 0; | ||||
@@ -724,7 +725,8 @@ protected: | |||||
}; | }; | ||||
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { | class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { | ||||
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); | |||||
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, | |||||
1); | |||||
public: | public: | ||||
/** | /** | ||||
@@ -744,7 +746,8 @@ protected: | |||||
using SlidingWindowTranspose = SlidingWindowTransposeForward; | using SlidingWindowTranspose = SlidingWindowTransposeForward; | ||||
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { | class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { | ||||
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); | |||||
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, | |||||
1); | |||||
public: | public: | ||||
/** | /** | ||||
@@ -975,7 +978,7 @@ protected: | |||||
}; | }; | ||||
class BNForward : public BNBase { | class BNForward : public BNBase { | ||||
DEF_OPR_IMPL(BNForward, BNBase, 6, 5); | |||||
DEF_OPR_IMPL(BNForward, BNBase, 6, 6); | |||||
public: | public: | ||||
/** | /** | ||||
@@ -986,10 +989,11 @@ public: | |||||
* \param[out] dst (n, c, h, w) | * \param[out] dst (n, c, h, w) | ||||
* \param[out] mean (see m_param.ParamDim) Global mean. | * \param[out] mean (see m_param.ParamDim) Global mean. | ||||
* \param[out] variance (see m_param.ParamDim) Global variance. | * \param[out] variance (see m_param.ParamDim) Global variance. | ||||
* \Param[out] batch_mean (see m_param.ParamDim) | |||||
* \param[out] batch_mean (see m_param.ParamDim) | |||||
* Optionally cached intermediate mean from forward pass | * Optionally cached intermediate mean from forward pass | ||||
* \Param[out] batch_inv_variance (see m_param.ParamDim) | |||||
* \param[out] batch_inv_variance (see m_param.ParamDim) | |||||
* Optionally cached intermediate variance from forward pass | * Optionally cached intermediate variance from forward pass | ||||
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx) | |||||
* src and dst must have the same shape. | * src and dst must have the same shape. | ||||
* src and dst must be contiguous. | * src and dst must be contiguous. | ||||
*/ | */ | ||||
@@ -998,17 +1002,20 @@ public: | |||||
_megdnn_tensor_inout variance, | _megdnn_tensor_inout variance, | ||||
_megdnn_tensor_out batch_mean, | _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, | _megdnn_tensor_out batch_inv_variance, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale, | |||||
TensorLayout& bn_bias, TensorLayout& mean, | |||||
_megdnn_tensor_out reserve, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale, | |||||
const TensorLayout& bn_bias, TensorLayout& mean, | |||||
TensorLayout& variance, TensorLayout& batch_mean, | TensorLayout& variance, TensorLayout& batch_mean, | ||||
TensorLayout& batch_inv_variance, TensorLayout& dst); | |||||
TensorLayout& batch_inv_variance, TensorLayout& reserve, | |||||
TensorLayout& dst); | |||||
virtual size_t get_workspace_in_bytes( | virtual size_t get_workspace_in_bytes( | ||||
const TensorLayout& src, const TensorLayout& bn_scale, | const TensorLayout& src, const TensorLayout& bn_scale, | ||||
const TensorLayout& bn_bias, const TensorLayout& mean, | const TensorLayout& bn_bias, const TensorLayout& mean, | ||||
const TensorLayout& variance, const TensorLayout& batch_mean, | const TensorLayout& variance, const TensorLayout& batch_mean, | ||||
const TensorLayout& batch_inv_variance, | |||||
const TensorLayout& batch_inv_variance, const TensorLayout& reserve, | |||||
const TensorLayout& dst) = 0; | const TensorLayout& dst) = 0; | ||||
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | ||||
@@ -1016,12 +1023,13 @@ protected: | |||||
const TensorLayout& variance, | const TensorLayout& variance, | ||||
const TensorLayout& batch_mean, | const TensorLayout& batch_mean, | ||||
const TensorLayout& batch_inv_variance, | const TensorLayout& batch_inv_variance, | ||||
const TensorLayout& dst, size_t workspace_in_bytes); | |||||
const TensorLayout& dst, size_t workspace_in_bytes, | |||||
size_t reserve_in_bytes = 0); | |||||
}; | }; | ||||
using BN = BNForward; | using BN = BNForward; | ||||
class BNBackward : public BNBase { | class BNBackward : public BNBase { | ||||
DEF_OPR_IMPL(BNBackward, BNBase, 5, 3); | |||||
DEF_OPR_IMPL(BNBackward, BNBase, 6, 3); | |||||
public: | public: | ||||
/** | /** | ||||
@@ -1035,19 +1043,23 @@ public: | |||||
Calculated in the forwardpropagation. | Calculated in the forwardpropagation. | ||||
* \param[in] saved_batch_variance of the input batch. | * \param[in] saved_batch_variance of the input batch. | ||||
Calculated in the forwardpropagation. | Calculated in the forwardpropagation. | ||||
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx) | |||||
*/ | */ | ||||
virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_variance, | _megdnn_tensor_in saved_batch_variance, | ||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||||
_megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | ||||
_megdnn_workspace workspace) = 0; | _megdnn_workspace workspace) = 0; | ||||
virtual size_t get_workspace_in_bytes( | virtual size_t get_workspace_in_bytes( | ||||
const TensorLayout& x, const TensorLayout& dy, | const TensorLayout& x, const TensorLayout& dy, | ||||
const TensorLayout& saved_batch_mean, | const TensorLayout& saved_batch_mean, | ||||
const TensorLayout& saved_batch_variance, | const TensorLayout& saved_batch_variance, | ||||
const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, | |||||
const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0; | |||||
const TensorLayout& bn_scale, const TensorLayout& reserve, | |||||
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, | |||||
const TensorLayout& dx) = 0; | |||||
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; | |||||
protected: | protected: | ||||
void check_exec(const TensorLayout& x, const TensorLayout& dy, | void check_exec(const TensorLayout& x, const TensorLayout& dy, | ||||
@@ -1056,7 +1068,7 @@ protected: | |||||
const TensorLayout& bn_scale, | const TensorLayout& bn_scale, | ||||
const TensorLayout& d_bn_scale, | const TensorLayout& d_bn_scale, | ||||
const TensorLayout& d_bn_bias, const TensorLayout& dx, | const TensorLayout& d_bn_bias, const TensorLayout& dx, | ||||
size_t workspace_in_bytes); | |||||
size_t workspace_in_bytes, size_t reserve_in_bytes = 0); | |||||
}; | }; | ||||
class LRNBase : public OperatorBase { | class LRNBase : public OperatorBase { | ||||
@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
add_enum_alias('Format', 'Convolution') | add_enum_alias('Format', 'Convolution') | ||||
) | ) | ||||
(pdef('AdaptivePooling', version=0,is_legacy=True). | |||||
(pdef('AdaptivePooling', version=0, is_legacy=True). | |||||
add_enum_alias('Mode', 'PoolingV0'). | add_enum_alias('Mode', 'PoolingV0'). | ||||
add_enum_alias('Format', 'ConvolutionV0') | add_enum_alias('Format', 'ConvolutionV0') | ||||
) | ) | ||||
@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), | Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), | ||||
Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), | Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), | ||||
Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), | Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), | ||||
Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'), | |||||
name_field='param_dim' | name_field='param_dim' | ||||
). | ). | ||||
add_enum( | add_enum( | ||||
@@ -4,9 +4,9 @@ | |||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | */ | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -14,28 +14,32 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&, | |||||
TensorLayout&, TensorLayout&, TensorLayout&, | |||||
TensorLayout&, TensorLayout&, TensorLayout& dst) { | |||||
void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&, | |||||
const TensorLayout&, TensorLayout&, TensorLayout&, | |||||
TensorLayout&, TensorLayout&, | |||||
TensorLayout& reserve, TensorLayout& dst) { | |||||
reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()}; | |||||
dst = src; | dst = src; | ||||
} | } | ||||
void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | |||||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||||
const TensorLayout& variance, | |||||
const TensorLayout& batch_mean, | |||||
const TensorLayout& batch_inv_variance, | |||||
const TensorLayout& dst, size_t workspace_in_bytes) { | |||||
void BNForward::check_exec( | |||||
const TensorLayout& src, const TensorLayout& bn_scale, | |||||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||||
const TensorLayout& variance, const TensorLayout& batch_mean, | |||||
const TensorLayout& batch_inv_variance, const TensorLayout& dst, | |||||
size_t workspace_in_bytes, size_t reserve_in_bytes) { | |||||
megdnn_assert_contiguous(src); | megdnn_assert_contiguous(src); | ||||
megdnn_assert_eq_layout(src, dst); | megdnn_assert_eq_layout(src, dst); | ||||
megdnn_assert_eq_layout(bn_scale, bn_bias); | megdnn_assert_eq_layout(bn_scale, bn_bias); | ||||
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | ||||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | ||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance, | |||||
batch_mean, batch_inv_variance, dst); | |||||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||||
src, bn_scale, bn_bias, mean, variance, batch_mean, | |||||
batch_inv_variance, {}, dst); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
auto required_reserve_in_bytes = get_reserve_in_bytes(src); | |||||
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); | |||||
} | } | ||||
void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | ||||
@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||||
const TensorLayout& bn_scale, | const TensorLayout& bn_scale, | ||||
const TensorLayout& d_bn_scale, | const TensorLayout& d_bn_scale, | ||||
const TensorLayout& d_bn_bias, | const TensorLayout& d_bn_bias, | ||||
const TensorLayout& dx, size_t workspace_in_bytes) { | |||||
const TensorLayout& dx, size_t workspace_in_bytes, | |||||
size_t reserve_in_bytes) { | |||||
megdnn_assert_contiguous(x); | megdnn_assert_contiguous(x); | ||||
megdnn_assert_eq_layout(x, dy); | megdnn_assert_eq_layout(x, dy); | ||||
megdnn_assert_eq_layout(x, dx); | megdnn_assert_eq_layout(x, dx); | ||||
@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||||
megdnn_assert_eq_layout(saved_batch_mean, bn_scale); | megdnn_assert_eq_layout(saved_batch_mean, bn_scale); | ||||
megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); | megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); | ||||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | ||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, | |||||
bn_scale, d_bn_scale, d_bn_bias, dx); | |||||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||||
x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {}, | |||||
d_bn_scale, d_bn_bias, dx); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode"); | |||||
auto required_reserve_in_bytes = get_reserve_in_bytes(x); | |||||
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); | |||||
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, | |||||
"BNBackward only support TRAINING mode"); | |||||
} | } | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false); | |||||
DEF(GroupLocalBackwardFilter, 3, true, false); | DEF(GroupLocalBackwardFilter, 3, true, false); | ||||
DEF(LRNForward, 2, true, true); | DEF(LRNForward, 2, true, true); | ||||
DEF(LRNBackward, 4, true, false); | DEF(LRNBackward, 4, true, false); | ||||
DEF(BNForward, 8, true, true); | |||||
DEF(BNBackward, 8, true, false); | |||||
DEF(BNForward, 9, true, true); | |||||
DEF(BNBackward, 9, true, false); | |||||
DEF(ROIPoolingForward, 4, true, false); | DEF(ROIPoolingForward, 4, true, false); | ||||
DEF(ROIPoolingBackward, 5, true, false); | DEF(ROIPoolingBackward, 5, true, false); | ||||
DEF(CorrelationForward, 3, true, true); | DEF(CorrelationForward, 3, true, true); | ||||
@@ -4,9 +4,9 @@ | |||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | */ | ||||
#include "./opr_impl.h" | #include "./opr_impl.h" | ||||
@@ -17,9 +17,11 @@ namespace cuda { | |||||
namespace batch_normalization { | namespace batch_normalization { | ||||
void BNTensorDescHolder::setup(const TensorLayout& x, | |||||
const ParamDim& param_dim) { | |||||
BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x, | |||||
const ParamDim& param_dim, | |||||
const FwdMode& fwd_mode) { | |||||
TensorShape xy_shape(x); | TensorShape xy_shape(x); | ||||
Format xy_format = Format::NCHW; | |||||
switch (param_dim) { | switch (param_dim) { | ||||
case ParamDim::DIM_11HW: | case ParamDim::DIM_11HW: | ||||
@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x, | |||||
case ParamDim::DIM_1C11: | case ParamDim::DIM_1C11: | ||||
bn_mode = CUDNN_BATCHNORM_SPATIAL; | bn_mode = CUDNN_BATCHNORM_SPATIAL; | ||||
break; | break; | ||||
case ParamDim::DIM_111C: | |||||
bn_mode = CUDNN_BATCHNORM_SPATIAL; | |||||
xy_format = Format::NHWC; | |||||
#if CUDNN_VERSION >= 7410 | |||||
if (fwd_mode == FwdMode::TRAINING) { | |||||
bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; | |||||
} | |||||
#endif // CUDNN_VERSION >= 7400 | |||||
break; | |||||
default: | default: | ||||
megdnn_throw("Unknown param dim type of batch normalization."); | megdnn_throw("Unknown param dim type of batch normalization."); | ||||
} | } | ||||
xy_desc.set(TensorLayout(xy_shape, x.dtype)); | |||||
xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format); | |||||
param_desc.set(xy_desc.desc, bn_mode); | param_desc.set(xy_desc.desc, bn_mode); | ||||
} | } | ||||
size_t get_reserve_size(const cudnnHandle_t& handle, | |||||
const BNTensorDescHolder& tensor_desc) { | |||||
#if CUDNN_VERSION >= 7410 | |||||
size_t reserve_size; | |||||
cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( | |||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||||
nullptr, // activationDesc | |||||
tensor_desc.xy_desc.desc, // xDesc | |||||
&reserve_size)); | |||||
return reserve_size; | |||||
#else | |||||
return 0; | |||||
#endif // CUDNN_VERSION >= 7410 | |||||
} | |||||
} // namespace batch_normalization | } // namespace batch_normalization | ||||
using batch_normalization::BNTensorDescHolder; | |||||
size_t BNForwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& src, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||||
#if CUDNN_VERSION >= 7410 | |||||
auto handle = cudnn_handle(this->handle()); | |||||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||||
size_t workspace_size; | |||||
cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( | |||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||||
tensor_desc.xy_desc.desc, // xDesc | |||||
tensor_desc.xy_desc.desc, // yDesc | |||||
tensor_desc.xy_desc.desc, // zDesc | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
nullptr, // activationDesc | |||||
&workspace_size)); | |||||
return workspace_size; | |||||
#else | |||||
return 0; | |||||
#endif // CUDNN_VERSION >= 7410 | |||||
} | |||||
size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) { | |||||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||||
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), | |||||
tensor_desc); | |||||
} | |||||
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | ||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | ||||
_megdnn_tensor_out variance, | _megdnn_tensor_out variance, | ||||
_megdnn_tensor_out batch_mean, | _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, | _megdnn_tensor_out batch_inv_variance, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
_megdnn_tensor_out reserve, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | ||||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | variance.layout, batch_mean.layout, batch_inv_variance.layout, | ||||
dst.layout, workspace.size); | |||||
dst.layout, workspace.size, reserve.layout.access_bytes()); | |||||
auto handle = cudnn_handle(this->handle()); | auto handle = cudnn_handle(this->handle()); | ||||
m_tensor_desc.setup(src.layout, m_param.param_dim); | |||||
BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, | |||||
m_param.fwd_mode); | |||||
float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
switch (m_param.fwd_mode) { | switch (m_param.fwd_mode) { | ||||
case param::BN::FwdMode::TRAINING: | case param::BN::FwdMode::TRAINING: | ||||
#if CUDNN_VERSION >= 7410 | |||||
cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | |||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||||
&alpha, &beta, // one & zero | |||||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
nullptr, nullptr, // zDesc & z | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | |||||
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | |||||
batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr, | |||||
workspace.raw_ptr, workspace.size, reserve.raw_ptr, | |||||
reserve.layout.access_bytes())); | |||||
#else | |||||
cudnn_check(cudnnBatchNormalizationForwardTraining( | cudnn_check(cudnnBatchNormalizationForwardTraining( | ||||
handle, m_tensor_desc.bn_mode, | |||||
&alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, // xDesc | |||||
src.raw_ptr, // x | |||||
m_tensor_desc.xy_desc.desc, // yDesc | |||||
dst.raw_ptr, // y | |||||
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
handle, tensor_desc.bn_mode, &alpha, &beta, | |||||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | ||||
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | ||||
batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); | batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); | ||||
#endif // CUDNN_VERSION >= 7410 | |||||
break; | break; | ||||
case param::BN::FwdMode::INFERENCE: | case param::BN::FwdMode::INFERENCE: | ||||
cudnn_check(cudnnBatchNormalizationForwardInference( | cudnn_check(cudnnBatchNormalizationForwardInference( | ||||
handle, m_tensor_desc.bn_mode, | |||||
&alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, src.raw_ptr, | |||||
m_tensor_desc.xy_desc.desc, dst.raw_ptr, | |||||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
handle, tensor_desc.bn_mode, &alpha, &beta, | |||||
tensor_desc.xy_desc.desc, src.raw_ptr, | |||||
tensor_desc.xy_desc.desc, dst.raw_ptr, | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, | bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, | ||||
m_param.epsilon)); | m_param.epsilon)); | ||||
break; | break; | ||||
@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
} | } | ||||
} | } | ||||
size_t BNBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||||
#if CUDNN_VERSION >= 7410 | |||||
auto handle = cudnn_handle(this->handle()); | |||||
BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode); | |||||
size_t workspace_size; | |||||
cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize( | |||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||||
tensor_desc.xy_desc.desc, // xDesc | |||||
tensor_desc.xy_desc.desc, // yDesc | |||||
tensor_desc.xy_desc.desc, // dyDesc | |||||
nullptr, // dzDesc | |||||
tensor_desc.xy_desc.desc, // dxDesc | |||||
tensor_desc.param_desc.desc, // dBnScaleBiasDesc | |||||
nullptr, // activationDesc | |||||
&workspace_size)); | |||||
return workspace_size; | |||||
#else | |||||
return 0; | |||||
#endif // CUDNN_VERSION >= 7410 | |||||
} | |||||
size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) { | |||||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||||
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), | |||||
tensor_desc); | |||||
} | |||||
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||||
_megdnn_tensor_out d_bn_scale, | _megdnn_tensor_out d_bn_scale, | ||||
_megdnn_tensor_out d_bn_bias, | |||||
_megdnn_tensor_out dx, _megdnn_workspace workspace) { | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(x.layout, dy.layout, saved_batch_mean.layout, | check_exec(x.layout, dy.layout, saved_batch_mean.layout, | ||||
saved_batch_inv_variance.layout, bn_scale.layout, | saved_batch_inv_variance.layout, bn_scale.layout, | ||||
d_bn_scale.layout, d_bn_bias.layout, dx.layout, | |||||
workspace.size); | |||||
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size, | |||||
reserve.layout.access_bytes()); | |||||
auto handle = cudnn_handle(this->handle()); | auto handle = cudnn_handle(this->handle()); | ||||
m_tensor_desc.setup(x.layout, m_param.param_dim); | |||||
BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, | |||||
m_param.fwd_mode); | |||||
float alpha = 1.0, beta = 0.0; | float alpha = 1.0, beta = 0.0; | ||||
#if CUDNN_VERSION >= 7410 | |||||
cudnn_check(cudnnBatchNormalizationBackwardEx( | |||||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, | |||||
&alpha, &beta, tensor_desc.xy_desc.desc, | |||||
x.raw_ptr, // xDesc & x | |||||
nullptr, nullptr, // yDesc & y | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
nullptr, nullptr, // dzDesc & dz | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
nullptr, // bnBias | |||||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr, | |||||
saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, | |||||
workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); | |||||
#else | |||||
cudnn_check(cudnnBatchNormalizationBackward( | cudnn_check(cudnnBatchNormalizationBackward( | ||||
handle, m_tensor_desc.bn_mode, | |||||
&alpha, &beta, &alpha, &beta, | |||||
m_tensor_desc.xy_desc.desc, x.raw_ptr, | |||||
m_tensor_desc.xy_desc.desc, dy.raw_ptr, | |||||
m_tensor_desc.xy_desc.desc, dx.raw_ptr, | |||||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, m_param.epsilon, | |||||
saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr)); | |||||
handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | |||||
tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x | |||||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
m_param.epsilon, saved_batch_mean.raw_ptr, | |||||
saved_batch_inv_variance.raw_ptr)); | |||||
#endif | |||||
} | } | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -4,9 +4,9 @@ | |||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -20,14 +20,20 @@ namespace batch_normalization { | |||||
struct BNTensorDescHolder { | struct BNTensorDescHolder { | ||||
using ParamDim = param::BN::ParamDim; | using ParamDim = param::BN::ParamDim; | ||||
using FwdMode = param::BN::FwdMode; | |||||
using Format = param::Convolution::Format; | |||||
TensorDesc xy_desc; | TensorDesc xy_desc; | ||||
BNParamDesc param_desc; | BNParamDesc param_desc; | ||||
cudnnBatchNormMode_t bn_mode; | cudnnBatchNormMode_t bn_mode; | ||||
void setup(const TensorLayout& x, const ParamDim& param_dim); | |||||
BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim, | |||||
const FwdMode& fwd_mode); | |||||
}; | }; | ||||
size_t get_reserve_size(const cudnnHandle_t& handle, | |||||
const BNTensorDescHolder& tensor_desc); | |||||
} // namespace batch_normalization | } // namespace batch_normalization | ||||
class BNForwardImpl final : public BNForward { | class BNForwardImpl final : public BNForward { | ||||
@@ -36,19 +42,15 @@ public: | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | ||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | ||||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override; | |||||
size_t get_reserve_in_bytes(const TensorLayout& src) override; | |||||
}; | }; | ||||
class BNBackwardImpl final : public BNBackward { | class BNBackwardImpl final : public BNBackward { | ||||
@@ -57,20 +59,16 @@ public: | |||||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
private: | |||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override; | |||||
size_t get_reserve_in_bytes(const TensorLayout& src) override; | |||||
}; | }; | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/naive/batch_normalization/opr_impl.h" | #include "src/naive/batch_normalization/opr_impl.h" | ||||
@@ -219,13 +220,14 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_inout variance, | _megdnn_tensor_inout variance, | ||||
_megdnn_tensor_out batch_mean, | _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, | _megdnn_tensor_out batch_inv_variance, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
_megdnn_tensor_out, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | ||||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | variance.layout, batch_mean.layout, batch_inv_variance.layout, | ||||
dst.layout, workspace.size); | dst.layout, workspace.size); | ||||
DNN_INC_FLOAT16(if (src.layout.dtype == dtype::Float16() && | DNN_INC_FLOAT16(if (src.layout.dtype == dtype::Float16() && | ||||
bn_scale.layout.dtype == dtype::Float32()) { | |||||
bn_scale.layout.dtype == dtype::Float32()) { | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(({ | MEGDNN_DISPATCH_CPU_KERN_OPR(({ | ||||
using T0 = typename DTypeTrait<dtype::Float16>::ctype; | using T0 = typename DTypeTrait<dtype::Float16>::ctype; | ||||
using T1 = typename DTypeTrait<dtype::Float32>::ctype; | using T1 = typename DTypeTrait<dtype::Float32>::ctype; | ||||
@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size, | |||||
size_t BNBackwardImpl::get_workspace_in_bytes( | size_t BNBackwardImpl::get_workspace_in_bytes( | ||||
const TensorLayout& x, const TensorLayout&, const TensorLayout&, | const TensorLayout& x, const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, | const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&) { | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||||
auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems(); | auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems(); | ||||
return get_workspace_bundle(x_size, param_size).total_size_in_bytes(); | return get_workspace_bundle(x_size, param_size).total_size_in_bytes(); | ||||
} | } | ||||
@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes( | |||||
void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, | void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in, | |||||
_megdnn_tensor_out d_bn_scale, | _megdnn_tensor_out d_bn_scale, | ||||
_megdnn_tensor_out d_bn_bias, | _megdnn_tensor_out d_bn_bias, | ||||
_megdnn_tensor_out dx_out, | _megdnn_tensor_out dx_out, | ||||
@@ -286,7 +288,7 @@ void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, | |||||
workspace.raw_ptr); | workspace.raw_ptr); | ||||
DNN_INC_FLOAT16(if (x_in.layout.dtype == dtype::Float16() && | DNN_INC_FLOAT16(if (x_in.layout.dtype == dtype::Float16() && | ||||
bn_scale.layout.dtype == dtype::Float32()) { | |||||
bn_scale.layout.dtype == dtype::Float32()) { | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(({ | MEGDNN_DISPATCH_CPU_KERN_OPR(({ | ||||
using T0 = typename DTypeTrait<dtype::Float16>::ctype; | using T0 = typename DTypeTrait<dtype::Float16>::ctype; | ||||
using T1 = typename DTypeTrait<dtype::Float32>::ctype; | using T1 = typename DTypeTrait<dtype::Float32>::ctype; | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -21,16 +22,17 @@ public: | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | ||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | ||||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||||
}; | }; | ||||
class BNBackwardImpl final : public BNBackward { | class BNBackwardImpl final : public BNBackward { | ||||
@@ -39,15 +41,17 @@ public: | |||||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout& bn_scale, | const TensorLayout& bn_scale, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&) override; | const TensorLayout&) override; | ||||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size, | WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size, | ||||
@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_out variance, | _megdnn_tensor_out variance, | ||||
_megdnn_tensor_out batch_mean, | _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, | _megdnn_tensor_out batch_inv_variance, | ||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
_megdnn_tensor_out, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | ||||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | variance.layout, batch_mean.layout, batch_inv_variance.layout, | ||||
dst.layout, workspace.size); | dst.layout, workspace.size); | ||||
@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||||
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in, | |||||
_megdnn_tensor_out d_bn_scale, | _megdnn_tensor_out d_bn_scale, | ||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
@@ -37,16 +37,17 @@ public: | |||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | ||||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | ||||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | ||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||||
private: | private: | ||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | batch_normalization::BNTensorDescHolder m_tensor_desc; | ||||
@@ -58,17 +59,18 @@ public: | |||||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | ||||
_megdnn_tensor_in saved_batch_mean, | _megdnn_tensor_in saved_batch_mean, | ||||
_megdnn_tensor_in saved_batch_inv_variance, | _megdnn_tensor_in saved_batch_inv_variance, | ||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||||
_megdnn_workspace workspace) override; | |||||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, const TensorLayout&, | const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||||
private: | private: | ||||
batch_normalization::BNTensorDescHolder m_tensor_desc; | batch_normalization::BNTensorDescHolder m_tensor_desc; | ||||