GitOrigin-RevId: 9d80f2009d
release-1.7
@@ -682,7 +682,8 @@ public: | |||
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html | |||
* | |||
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$, | |||
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. | |||
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), | |||
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
@@ -724,7 +725,8 @@ protected: | |||
}; | |||
class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { | |||
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); | |||
DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, | |||
1); | |||
public: | |||
/** | |||
@@ -744,7 +746,8 @@ protected: | |||
using SlidingWindowTranspose = SlidingWindowTransposeForward; | |||
class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { | |||
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); | |||
DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, | |||
1); | |||
public: | |||
/** | |||
@@ -975,7 +978,7 @@ protected: | |||
}; | |||
class BNForward : public BNBase { | |||
DEF_OPR_IMPL(BNForward, BNBase, 6, 5); | |||
DEF_OPR_IMPL(BNForward, BNBase, 6, 6); | |||
public: | |||
/** | |||
@@ -986,10 +989,11 @@ public: | |||
* \param[out] dst (n, c, h, w) | |||
* \param[out] mean (see m_param.ParamDim) Global mean. | |||
* \param[out] variance (see m_param.ParamDim) Global variance. | |||
* \Param[out] batch_mean (see m_param.ParamDim) | |||
* \param[out] batch_mean (see m_param.ParamDim) | |||
* Optionally cached intermediate mean from forward pass | |||
* \Param[out] batch_inv_variance (see m_param.ParamDim) | |||
* \param[out] batch_inv_variance (see m_param.ParamDim) | |||
* Optionally cached intermediate variance from forward pass | |||
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx) | |||
* src and dst must have the same shape. | |||
* src and dst must be contiguous. | |||
*/ | |||
@@ -998,17 +1002,20 @@ public: | |||
_megdnn_tensor_inout variance, | |||
_megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& bn_scale, | |||
TensorLayout& bn_bias, TensorLayout& mean, | |||
_megdnn_tensor_out reserve, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale, | |||
const TensorLayout& bn_bias, TensorLayout& mean, | |||
TensorLayout& variance, TensorLayout& batch_mean, | |||
TensorLayout& batch_inv_variance, TensorLayout& dst); | |||
TensorLayout& batch_inv_variance, TensorLayout& reserve, | |||
TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& bn_scale, | |||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||
const TensorLayout& variance, const TensorLayout& batch_mean, | |||
const TensorLayout& batch_inv_variance, | |||
const TensorLayout& batch_inv_variance, const TensorLayout& reserve, | |||
const TensorLayout& dst) = 0; | |||
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | |||
@@ -1016,12 +1023,13 @@ protected: | |||
const TensorLayout& variance, | |||
const TensorLayout& batch_mean, | |||
const TensorLayout& batch_inv_variance, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
const TensorLayout& dst, size_t workspace_in_bytes, | |||
size_t reserve_in_bytes = 0); | |||
}; | |||
using BN = BNForward; | |||
class BNBackward : public BNBase { | |||
DEF_OPR_IMPL(BNBackward, BNBase, 5, 3); | |||
DEF_OPR_IMPL(BNBackward, BNBase, 6, 3); | |||
public: | |||
/** | |||
@@ -1035,19 +1043,23 @@ public: | |||
Calculated in the forwardpropagation. | |||
* \param[in] saved_batch_variance of the input batch. | |||
Calculated in the forwardpropagation. | |||
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx) | |||
*/ | |||
virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_variance, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||
_megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& dy, | |||
const TensorLayout& saved_batch_mean, | |||
const TensorLayout& saved_batch_variance, | |||
const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, | |||
const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0; | |||
const TensorLayout& bn_scale, const TensorLayout& reserve, | |||
const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, | |||
const TensorLayout& dx) = 0; | |||
virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& x, const TensorLayout& dy, | |||
@@ -1056,7 +1068,7 @@ protected: | |||
const TensorLayout& bn_scale, | |||
const TensorLayout& d_bn_scale, | |||
const TensorLayout& d_bn_bias, const TensorLayout& dx, | |||
size_t workspace_in_bytes); | |||
size_t workspace_in_bytes, size_t reserve_in_bytes = 0); | |||
}; | |||
class LRNBase : public OperatorBase { | |||
@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
add_enum_alias('Format', 'Convolution') | |||
) | |||
(pdef('AdaptivePooling', version=0,is_legacy=True). | |||
(pdef('AdaptivePooling', version=0, is_legacy=True). | |||
add_enum_alias('Mode', 'PoolingV0'). | |||
add_enum_alias('Format', 'ConvolutionV0') | |||
) | |||
@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), | |||
Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), | |||
Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), | |||
Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'), | |||
name_field='param_dim' | |||
). | |||
add_enum( | |||
@@ -4,9 +4,9 @@ | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
@@ -14,28 +14,32 @@ | |||
namespace megdnn { | |||
void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&, | |||
TensorLayout&, TensorLayout&, TensorLayout&, | |||
TensorLayout&, TensorLayout&, TensorLayout& dst) { | |||
void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&, | |||
const TensorLayout&, TensorLayout&, TensorLayout&, | |||
TensorLayout&, TensorLayout&, | |||
TensorLayout& reserve, TensorLayout& dst) { | |||
reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()}; | |||
dst = src; | |||
} | |||
void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | |||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||
const TensorLayout& variance, | |||
const TensorLayout& batch_mean, | |||
const TensorLayout& batch_inv_variance, | |||
const TensorLayout& dst, size_t workspace_in_bytes) { | |||
void BNForward::check_exec( | |||
const TensorLayout& src, const TensorLayout& bn_scale, | |||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||
const TensorLayout& variance, const TensorLayout& batch_mean, | |||
const TensorLayout& batch_inv_variance, const TensorLayout& dst, | |||
size_t workspace_in_bytes, size_t reserve_in_bytes) { | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_eq_layout(src, dst); | |||
megdnn_assert_eq_layout(bn_scale, bn_bias); | |||
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | |||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance, | |||
batch_mean, batch_inv_variance, dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
src, bn_scale, bn_bias, mean, variance, batch_mean, | |||
batch_inv_variance, {}, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
auto required_reserve_in_bytes = get_reserve_in_bytes(src); | |||
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); | |||
} | |||
void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||
@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||
const TensorLayout& bn_scale, | |||
const TensorLayout& d_bn_scale, | |||
const TensorLayout& d_bn_bias, | |||
const TensorLayout& dx, size_t workspace_in_bytes) { | |||
const TensorLayout& dx, size_t workspace_in_bytes, | |||
size_t reserve_in_bytes) { | |||
megdnn_assert_contiguous(x); | |||
megdnn_assert_eq_layout(x, dy); | |||
megdnn_assert_eq_layout(x, dx); | |||
@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||
megdnn_assert_eq_layout(saved_batch_mean, bn_scale); | |||
megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); | |||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, | |||
bn_scale, d_bn_scale, d_bn_bias, dx); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {}, | |||
d_bn_scale, d_bn_bias, dx); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, "BNBackward only support TRAINING mode"); | |||
auto required_reserve_in_bytes = get_reserve_in_bytes(x); | |||
megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); | |||
megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, | |||
"BNBackward only support TRAINING mode"); | |||
} | |||
} // namespace megdnn | |||
@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false); | |||
DEF(GroupLocalBackwardFilter, 3, true, false); | |||
DEF(LRNForward, 2, true, true); | |||
DEF(LRNBackward, 4, true, false); | |||
DEF(BNForward, 8, true, true); | |||
DEF(BNBackward, 8, true, false); | |||
DEF(BNForward, 9, true, true); | |||
DEF(BNBackward, 9, true, false); | |||
DEF(ROIPoolingForward, 4, true, false); | |||
DEF(ROIPoolingBackward, 5, true, false); | |||
DEF(CorrelationForward, 3, true, true); | |||
@@ -4,9 +4,9 @@ | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./opr_impl.h" | |||
@@ -17,9 +17,11 @@ namespace cuda { | |||
namespace batch_normalization { | |||
void BNTensorDescHolder::setup(const TensorLayout& x, | |||
const ParamDim& param_dim) { | |||
BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x, | |||
const ParamDim& param_dim, | |||
const FwdMode& fwd_mode) { | |||
TensorShape xy_shape(x); | |||
Format xy_format = Format::NCHW; | |||
switch (param_dim) { | |||
case ParamDim::DIM_11HW: | |||
@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x, | |||
case ParamDim::DIM_1C11: | |||
bn_mode = CUDNN_BATCHNORM_SPATIAL; | |||
break; | |||
case ParamDim::DIM_111C: | |||
bn_mode = CUDNN_BATCHNORM_SPATIAL; | |||
xy_format = Format::NHWC; | |||
#if CUDNN_VERSION >= 7410 | |||
if (fwd_mode == FwdMode::TRAINING) { | |||
bn_mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; | |||
} | |||
#endif // CUDNN_VERSION >= 7400 | |||
break; | |||
default: | |||
megdnn_throw("Unknown param dim type of batch normalization."); | |||
} | |||
xy_desc.set(TensorLayout(xy_shape, x.dtype)); | |||
xy_desc.set(TensorLayout(xy_shape, x.dtype), xy_format); | |||
param_desc.set(xy_desc.desc, bn_mode); | |||
} | |||
size_t get_reserve_size(const cudnnHandle_t& handle, | |||
const BNTensorDescHolder& tensor_desc) { | |||
#if CUDNN_VERSION >= 7410 | |||
size_t reserve_size; | |||
cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( | |||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||
nullptr, // activationDesc | |||
tensor_desc.xy_desc.desc, // xDesc | |||
&reserve_size)); | |||
return reserve_size; | |||
#else | |||
return 0; | |||
#endif // CUDNN_VERSION >= 7410 | |||
} | |||
} // namespace batch_normalization | |||
using batch_normalization::BNTensorDescHolder; | |||
size_t BNForwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||
#if CUDNN_VERSION >= 7410 | |||
auto handle = cudnn_handle(this->handle()); | |||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||
size_t workspace_size; | |||
cudnn_check(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( | |||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||
tensor_desc.xy_desc.desc, // xDesc | |||
tensor_desc.xy_desc.desc, // yDesc | |||
tensor_desc.xy_desc.desc, // zDesc | |||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||
nullptr, // activationDesc | |||
&workspace_size)); | |||
return workspace_size; | |||
#else | |||
return 0; | |||
#endif // CUDNN_VERSION >= 7410 | |||
} | |||
size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) { | |||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), | |||
tensor_desc); | |||
} | |||
void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||
_megdnn_tensor_out variance, | |||
_megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
_megdnn_tensor_out reserve, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | |||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | |||
dst.layout, workspace.size); | |||
dst.layout, workspace.size, reserve.layout.access_bytes()); | |||
auto handle = cudnn_handle(this->handle()); | |||
m_tensor_desc.setup(src.layout, m_param.param_dim); | |||
BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, | |||
m_param.fwd_mode); | |||
float alpha = 1.0f, beta = 0.0f; | |||
switch (m_param.fwd_mode) { | |||
case param::BN::FwdMode::TRAINING: | |||
#if CUDNN_VERSION >= 7410 | |||
cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | |||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||
&alpha, &beta, // one & zero | |||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||
nullptr, nullptr, // zDesc & z | |||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | |||
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | |||
batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr, | |||
workspace.raw_ptr, workspace.size, reserve.raw_ptr, | |||
reserve.layout.access_bytes())); | |||
#else | |||
cudnn_check(cudnnBatchNormalizationForwardTraining( | |||
handle, m_tensor_desc.bn_mode, | |||
&alpha, &beta, | |||
m_tensor_desc.xy_desc.desc, // xDesc | |||
src.raw_ptr, // x | |||
m_tensor_desc.xy_desc.desc, // yDesc | |||
dst.raw_ptr, // y | |||
m_tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||
handle, tensor_desc.bn_mode, &alpha, &beta, | |||
tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||
tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||
tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, | |||
mean.raw_ptr, variance.raw_ptr, m_param.epsilon, | |||
batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); | |||
#endif // CUDNN_VERSION >= 7410 | |||
break; | |||
case param::BN::FwdMode::INFERENCE: | |||
cudnn_check(cudnnBatchNormalizationForwardInference( | |||
handle, m_tensor_desc.bn_mode, | |||
&alpha, &beta, | |||
m_tensor_desc.xy_desc.desc, src.raw_ptr, | |||
m_tensor_desc.xy_desc.desc, dst.raw_ptr, | |||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||
handle, tensor_desc.bn_mode, &alpha, &beta, | |||
tensor_desc.xy_desc.desc, src.raw_ptr, | |||
tensor_desc.xy_desc.desc, dst.raw_ptr, | |||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||
bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, | |||
m_param.epsilon)); | |||
break; | |||
@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
} | |||
} | |||
size_t BNBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||
#if CUDNN_VERSION >= 7410 | |||
auto handle = cudnn_handle(this->handle()); | |||
BNTensorDescHolder tensor_desc(x, m_param.param_dim, m_param.fwd_mode); | |||
size_t workspace_size; | |||
cudnn_check(cudnnGetBatchNormalizationBackwardExWorkspaceSize( | |||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, | |||
tensor_desc.xy_desc.desc, // xDesc | |||
tensor_desc.xy_desc.desc, // yDesc | |||
tensor_desc.xy_desc.desc, // dyDesc | |||
nullptr, // dzDesc | |||
tensor_desc.xy_desc.desc, // dxDesc | |||
tensor_desc.param_desc.desc, // dBnScaleBiasDesc | |||
nullptr, // activationDesc | |||
&workspace_size)); | |||
return workspace_size; | |||
#else | |||
return 0; | |||
#endif // CUDNN_VERSION >= 7410 | |||
} | |||
size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) { | |||
BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); | |||
return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), | |||
tensor_desc); | |||
} | |||
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||
_megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, | |||
_megdnn_tensor_out dx, _megdnn_workspace workspace) { | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) { | |||
check_exec(x.layout, dy.layout, saved_batch_mean.layout, | |||
saved_batch_inv_variance.layout, bn_scale.layout, | |||
d_bn_scale.layout, d_bn_bias.layout, dx.layout, | |||
workspace.size); | |||
d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size, | |||
reserve.layout.access_bytes()); | |||
auto handle = cudnn_handle(this->handle()); | |||
m_tensor_desc.setup(x.layout, m_param.param_dim); | |||
BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, | |||
m_param.fwd_mode); | |||
float alpha = 1.0, beta = 0.0; | |||
#if CUDNN_VERSION >= 7410 | |||
cudnn_check(cudnnBatchNormalizationBackwardEx( | |||
handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, | |||
&alpha, &beta, tensor_desc.xy_desc.desc, | |||
x.raw_ptr, // xDesc & x | |||
nullptr, nullptr, // yDesc & y | |||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||
nullptr, nullptr, // dzDesc & dz | |||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||
nullptr, // bnBias | |||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||
m_param.epsilon, saved_batch_mean.raw_ptr, | |||
saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, | |||
workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); | |||
#else | |||
cudnn_check(cudnnBatchNormalizationBackward( | |||
handle, m_tensor_desc.bn_mode, | |||
&alpha, &beta, &alpha, &beta, | |||
m_tensor_desc.xy_desc.desc, x.raw_ptr, | |||
m_tensor_desc.xy_desc.desc, dy.raw_ptr, | |||
m_tensor_desc.xy_desc.desc, dx.raw_ptr, | |||
m_tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, m_param.epsilon, | |||
saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr)); | |||
handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | |||
tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x | |||
tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||
tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||
tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||
d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||
m_param.epsilon, saved_batch_mean.raw_ptr, | |||
saved_batch_inv_variance.raw_ptr)); | |||
#endif | |||
} | |||
} // namespace cuda | |||
@@ -4,9 +4,9 @@ | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -20,14 +20,20 @@ namespace batch_normalization { | |||
struct BNTensorDescHolder { | |||
using ParamDim = param::BN::ParamDim; | |||
using FwdMode = param::BN::FwdMode; | |||
using Format = param::Convolution::Format; | |||
TensorDesc xy_desc; | |||
BNParamDesc param_desc; | |||
cudnnBatchNormMode_t bn_mode; | |||
void setup(const TensorLayout& x, const ParamDim& param_dim); | |||
BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim, | |||
const FwdMode& fwd_mode); | |||
}; | |||
size_t get_reserve_size(const cudnnHandle_t& handle, | |||
const BNTensorDescHolder& tensor_desc); | |||
} // namespace batch_normalization | |||
class BNForwardImpl final : public BNForward { | |||
@@ -36,19 +42,15 @@ public: | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override; | |||
size_t get_reserve_in_bytes(const TensorLayout& src) override; | |||
}; | |||
class BNBackwardImpl final : public BNBackward { | |||
@@ -57,20 +59,16 @@ public: | |||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override; | |||
size_t get_reserve_in_bytes(const TensorLayout& src) override; | |||
}; | |||
} // namespace cuda | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/naive/batch_normalization/opr_impl.h" | |||
@@ -219,13 +220,14 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_inout variance, | |||
_megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
_megdnn_tensor_out, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | |||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | |||
dst.layout, workspace.size); | |||
DNN_INC_FLOAT16(if (src.layout.dtype == dtype::Float16() && | |||
bn_scale.layout.dtype == dtype::Float32()) { | |||
bn_scale.layout.dtype == dtype::Float32()) { | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(({ | |||
using T0 = typename DTypeTrait<dtype::Float16>::ctype; | |||
using T1 = typename DTypeTrait<dtype::Float32>::ctype; | |||
@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size, | |||
size_t BNBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout& bn_scale, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&) { | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||
auto x_size = x.total_nr_elems(), param_size = bn_scale.total_nr_elems(); | |||
return get_workspace_bundle(x_size, param_size).total_size_in_bytes(); | |||
} | |||
@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes( | |||
void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in, | |||
_megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, | |||
_megdnn_tensor_out dx_out, | |||
@@ -286,7 +288,7 @@ void BNBackwardImpl::exec(_megdnn_tensor_in x_in, _megdnn_tensor_in dy_in, | |||
workspace.raw_ptr); | |||
DNN_INC_FLOAT16(if (x_in.layout.dtype == dtype::Float16() && | |||
bn_scale.layout.dtype == dtype::Float32()) { | |||
bn_scale.layout.dtype == dtype::Float32()) { | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(({ | |||
using T0 = typename DTypeTrait<dtype::Float16>::ctype; | |||
using T1 = typename DTypeTrait<dtype::Float32>::ctype; | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
@@ -21,16 +22,17 @@ public: | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||
}; | |||
class BNBackwardImpl final : public BNBackward { | |||
@@ -39,15 +41,17 @@ public: | |||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout& bn_scale, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override; | |||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||
private: | |||
WorkspaceBundle get_workspace_bundle(size_t x_size, size_t param_size, | |||
@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_out variance, | |||
_megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
_megdnn_tensor_out, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, | |||
variance.layout, batch_mean.layout, batch_inv_variance.layout, | |||
dst.layout, workspace.size); | |||
@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in, | |||
_megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) { | |||
@@ -37,16 +37,17 @@ public: | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, | |||
_megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, | |||
_megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||
@@ -58,17 +59,18 @@ public: | |||
void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, | |||
_megdnn_tensor_in saved_batch_mean, | |||
_megdnn_tensor_in saved_batch_inv_variance, | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_out d_bn_scale, | |||
_megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, | |||
_megdnn_workspace workspace) override; | |||
_megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, | |||
_megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, | |||
_megdnn_tensor_out dx, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
size_t get_reserve_in_bytes(const TensorLayout&) override { return 0; } | |||
private: | |||
batch_normalization::BNTensorDescHolder m_tensor_desc; | |||