@@ -1869,6 +1869,13 @@ table LayerNorm { | |||||
normalized_size:ulong = 1; | normalized_size:ulong = 1; | ||||
} | } | ||||
table GroupNorm { | |||||
affine:bool = true; | |||||
eps:float = 1e-5; | |||||
group:uint = 1; | |||||
format:ConvolutionFormat = NCHW; | |||||
} | |||||
table Dropout { | table Dropout { | ||||
drop_prob:float = 0; | drop_prob:float = 0; | ||||
seed:ulong = 0; | seed:ulong = 0; | ||||
@@ -140,6 +140,7 @@ union OperatorParam { | |||||
param.LSTM = 89, | param.LSTM = 89, | ||||
param.Softmax = 90, | param.Softmax = 90, | ||||
param.Diag = 91, | param.Diag = 91, | ||||
param.GroupNorm = 92, | |||||
} | } | ||||
table Operator { | table Operator { | ||||
@@ -2430,6 +2430,76 @@ protected: | |||||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, | const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, | ||||
size_t workspace_in_bytes); | size_t workspace_in_bytes); | ||||
}; | }; | ||||
class GroupNormBase : public OperatorBase { | |||||
DEF_OPR_IMPL_CTOR(GroupNormBase, OperatorBase); | |||||
DEF_OPR_PARAM(GroupNorm); | |||||
protected: | |||||
void deduce_layout_fwd( | |||||
const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||||
TensorLayout& rstd); | |||||
void check_layout_fwd( | |||||
const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
const TensorLayout& rstd); | |||||
}; | |||||
class GroupNormForward : public GroupNormBase { | |||||
DEF_OPR_IMPL(GroupNormForward, GroupNormBase, 3, 3); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
_megdnn_workspace workspace) = 0; | |||||
MGE_WIN_DECLSPEC_FUC void deduce_layout( | |||||
const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||||
TensorLayout& rstd); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
const TensorLayout& rstd) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||||
const TensorLayout& rstd, size_t workspace_in_bytes); | |||||
}; | |||||
using GroupNorm = GroupNormForward; | |||||
class GroupNormBackward : public GroupNormBase { | |||||
DEF_OPR_IMPL(GroupNormBackward, GroupNormBase, 5, 3); | |||||
public: | |||||
virtual void exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) = 0; | |||||
void deduce_layout( | |||||
const TensorLayout& diff, const TensorLayout& data, | |||||
const TensorLayout& weight, const TensorLayout& mean, | |||||
const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, | |||||
TensorLayout& dbias); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& diff, const TensorLayout& data, | |||||
const TensorLayout& weight, const TensorLayout& mean, | |||||
const TensorLayout& rstd, const TensorLayout& ddata, | |||||
const TensorLayout& dweight, const TensorLayout& dbias) = 0; | |||||
protected: | |||||
void check_exec( | |||||
const TensorLayout& diff, const TensorLayout& data, | |||||
const TensorLayout& weight, const TensorLayout& mean, | |||||
const TensorLayout& rstd, const TensorLayout& ddata, | |||||
const TensorLayout& dweight, const TensorLayout& dbias, | |||||
size_t workspace_in_bytes); | |||||
}; | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
@@ -1247,6 +1247,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||||
.add_fields('uint64', 'normalized_size', '1') | .add_fields('uint64', 'normalized_size', '1') | ||||
) | ) | ||||
(pdef('GroupNorm') | |||||
.add_fields('bool', 'affine', 'true') | |||||
.add_fields('float32', 'eps', '1e-5f') | |||||
.add_fields('uint32', 'group', '1') | |||||
.add_enum_alias('Format', 'Convolution') | |||||
) | |||||
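# GroupNorm parameters: `group` is the number of channel groups, `eps` is added to
# the variance for numerical stability, `affine` toggles the learnable per-channel
# weight and bias, and `Format` reuses the Convolution layout enum (the CUDA and
# naive kernels in this change only handle NCHW).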
(pdef('Dropout') | (pdef('Dropout') | ||||
.add_fields('float32', 'drop_prob', '0') | .add_fields('float32', 'drop_prob', '0') | ||||
.add_fields('uint64', 'seed', '0') | .add_fields('uint64', 'seed', '0') | ||||
@@ -0,0 +1,121 @@ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
using Param = GroupNormBase::Param; | |||||
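// Layout deduction (see deduce_layout_fwd below): for input `data` of shape
// (N, C, H, W), `dst` copies data's layout, while `mean` and `rstd` are
// float32 tensors of shape (N, group).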
void GroupNormBase::deduce_layout_fwd( | |||||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||||
MEGDNN_MARK_USED_VAR(weight); | |||||
MEGDNN_MARK_USED_VAR(bias); | |||||
size_t N = data.shape[0]; | |||||
size_t group = param().group; | |||||
TensorLayout unnormalized_layout({N, group}, dtype::Float32()); | |||||
dst = data; | |||||
mean = unnormalized_layout; | |||||
rstd = unnormalized_layout; | |||||
} | |||||
void GroupNormBase::check_layout_fwd( | |||||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||||
megdnn_assert_contiguous(data); | |||||
megdnn_assert_contiguous(weight); | |||||
megdnn_assert_contiguous(bias); | |||||
megdnn_assert_contiguous(dst); | |||||
megdnn_assert_contiguous(mean); | |||||
megdnn_assert_contiguous(rstd); | |||||
auto errmsg = [&]() { | |||||
return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + | |||||
megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + | |||||
megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); | |||||
}; | |||||
MEGDNN_MARK_USED_VAR(errmsg); | |||||
megdnn_assert(data.eq_layout(dst), "%s", errmsg().c_str()); | |||||
megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str()); | |||||
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); | |||||
auto p = param(); | |||||
size_t C = data.shape[1]; | |||||
size_t group = p.group; | |||||
megdnn_assert( | |||||
group > 0, "Expected num groups to be greater than 0, got %zu", group); | |||||
megdnn_assert( | |||||
C % group == 0, | |||||
"Expected number of channels in input to be divisible by num_groups, but " | |||||
"got Channel of shape %zu and num_groups= %zu", | |||||
C, group); | |||||
} | |||||
void GroupNormForward::deduce_layout( | |||||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||||
deduce_layout_fwd(data, weight, bias, dst, mean, rstd); | |||||
} | |||||
void GroupNormForward::check_exec( | |||||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||||
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(data, weight, bias, dst, mean, rstd); | |||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
void GroupNormBackward::deduce_layout( | |||||
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, | |||||
TensorLayout& dweight, TensorLayout& dbias) { | |||||
MEGDNN_MARK_USED_VAR(diff); | |||||
MEGDNN_MARK_USED_VAR(mean); | |||||
MEGDNN_MARK_USED_VAR(rstd); | |||||
ddata = data; | |||||
dweight = weight; | |||||
dbias = weight; | |||||
} | |||||
void GroupNormBackward::check_exec( | |||||
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||||
const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, | |||||
const TensorLayout& dweight, const TensorLayout& dbias, | |||||
size_t workspace_in_bytes) { | |||||
auto p = param(); | |||||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||||
diff, data, weight, mean, rstd, ddata, dweight, dbias); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
megdnn_assert_contiguous(diff); | |||||
megdnn_assert_contiguous(data); | |||||
megdnn_assert_contiguous(mean); | |||||
megdnn_assert_contiguous(rstd); | |||||
megdnn_assert_contiguous(ddata); | |||||
if (p.affine) { | |||||
megdnn_assert_contiguous(weight); | |||||
megdnn_assert_contiguous(dweight); | |||||
megdnn_assert_contiguous(dbias); | |||||
} | |||||
auto errmsg = [&]() { | |||||
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + | |||||
megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + | |||||
megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + | |||||
megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); | |||||
}; | |||||
MEGDNN_MARK_USED_VAR(errmsg); | |||||
megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str()); | |||||
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); | |||||
if (p.affine) { | |||||
megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str()); | |||||
megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str()); | |||||
} | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -216,7 +216,9 @@ private: | |||||
cb(NormForward) \ | cb(NormForward) \ | ||||
cb(RegionRestrictedConvolutionForward) \ | cb(RegionRestrictedConvolutionForward) \ | ||||
cb(RegionRestrictedConvolutionBackwardData) \ | cb(RegionRestrictedConvolutionBackwardData) \ | ||||
cb(RegionRestrictedConvolutionBackwardFilter) | |||||
cb(RegionRestrictedConvolutionBackwardFilter) \ | |||||
cb(GroupNormForward) \ | |||||
cb(GroupNormBackward) | |||||
// clang-format on | // clang-format on | ||||
/*! | /*! | ||||
@@ -142,6 +142,8 @@ DEF(SoftmaxBackward, 3, true, false); | |||||
DEF(RegionRestrictedConvolutionForward, 5, true, true); | DEF(RegionRestrictedConvolutionForward, 5, true, true); | ||||
DEF(RegionRestrictedConvolutionBackwardData, 5, true, false); | DEF(RegionRestrictedConvolutionBackwardData, 5, true, false); | ||||
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); | DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); | ||||
DEF(GroupNormForward, 6, true, true); | |||||
DEF(GroupNormBackward, 8, true, true); | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -0,0 +1,529 @@ | |||||
#include <stdio.h> | |||||
#include <thrust/pair.h> | |||||
#include <thrust/tuple.h> | |||||
#include <cfloat> | |||||
#include "megdnn/arch.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "src/cuda/cuda_shfl_compat.cuh" | |||||
#include "src/cuda/group_norm/group_norm_cuda.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace group_norm { | |||||
// The warp size may be used as an array length or inside host functions,
// so we define WARP_SIZE rather than relying on the device-only warpSize.
#define WARP_SIZE 32 | |||||
template <size_t kStart, size_t kEnd, bool kStop> | |||||
struct Compare { | |||||
template <typename T> | |||||
__host__ __device__ inline static bool Run(const T* d1, const T* d2) { | |||||
return d1[kStart] == d2[kStart] && | |||||
Compare<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2); | |||||
} | |||||
}; | |||||
template <size_t kStart, size_t kEnd> | |||||
struct Compare<kStart, kEnd, true> { | |||||
template <typename T> | |||||
__host__ __device__ inline constexpr static bool Run(const T* d1, const T* d2) { | |||||
return true; | |||||
} | |||||
}; | |||||
template <size_t N> | |||||
using UnrollCompare = Compare<0, N, N == 0>; | |||||
template <typename T, size_t kStart, size_t kEnd, bool kStop> | |||||
struct UnrollVarArgsAssignImpl { | |||||
template <typename... Args> | |||||
__host__ __device__ inline static void Run(T* d, T val, Args... args) { | |||||
static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong number of arguments");
d[kStart] = val; | |||||
UnrollVarArgsAssignImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd>::Run( | |||||
d, args...); | |||||
} | |||||
}; | |||||
template <typename T, size_t kStart, size_t kEnd> | |||||
struct UnrollVarArgsAssignImpl<T, kStart, kEnd, true> { | |||||
__host__ __device__ inline static void Run(T* d) {} | |||||
}; | |||||
template <typename T> | |||||
struct UnrollVarArgsAssign { | |||||
template <typename... Args> | |||||
__host__ __device__ inline static void Run(T* d, Args... args) { | |||||
UnrollVarArgsAssignImpl<T, 0, sizeof...(Args), sizeof...(Args) == 0>::Run( | |||||
d, args...); | |||||
} | |||||
}; | |||||
template <typename T, size_t N> | |||||
class Array { | |||||
public: | |||||
static constexpr size_t kSize = N; | |||||
__host__ __device__ inline Array() {} | |||||
template <typename... Args> | |||||
__host__ __device__ inline explicit Array(const T& val, Args... args) { | |||||
static_assert(N == sizeof...(Args) + 1, "Invalid argument"); | |||||
UnrollVarArgsAssign<T>::Run(data_, val, args...); | |||||
} | |||||
__host__ __device__ inline T& operator[](size_t i) { return *(data_ + i); } | |||||
__host__ __device__ inline const T& operator[](size_t i) const { | |||||
return *(data_ + i); | |||||
} | |||||
private: | |||||
template <typename U> | |||||
__host__ __device__ static inline U* advance(U* ptr, size_t i) { | |||||
return ptr + i; | |||||
} | |||||
T data_[N]; | |||||
}; | |||||
// ================================ group_norm forward =========================== | |||||
// implementation of groupnorm_forward from | |||||
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_kernel.cu#L115 | |||||
template <typename T> | |||||
__forceinline__ __device__ T | |||||
CudaShuffleDownSync(T val, int delta, int width = warpSize) { | |||||
return __shfl_down(val, static_cast<unsigned>(delta), width); | |||||
} | |||||
template <> | |||||
__forceinline__ __device__ dt_float16 | |||||
CudaShuffleDownSync(dt_float16 val, int delta, int width) { | |||||
return dt_float16(__shfl_down(val, static_cast<unsigned>(delta), width)); | |||||
} | |||||
template <> | |||||
__forceinline__ __device__ dt_bfloat16 | |||||
CudaShuffleDownSync(dt_bfloat16 val, int delta, int width) { | |||||
return dt_bfloat16(__shfl_down(val, static_cast<unsigned>(delta), width)); | |||||
} | |||||
template <typename T, int VecSize> | |||||
struct alignas(sizeof(T) * VecSize) VectorType { | |||||
T val[VecSize]; | |||||
}; | |||||
template <typename T> | |||||
struct AddFunctor { | |||||
inline T initial() { return static_cast<T>(0.0f); } | |||||
__device__ __forceinline__ T operator()(const T a, const T b) const { | |||||
return b + a; | |||||
} | |||||
}; | |||||
template <typename T, typename ReduceOp> | |||||
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) { | |||||
for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) { | |||||
T temp = CudaShuffleDownSync<T>(val, stride); | |||||
val = reducer(val, temp); | |||||
} | |||||
return val; | |||||
} | |||||
template <typename T, typename ReduceOp> | |||||
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { | |||||
__syncthreads(); | |||||
__shared__ T shared[64]; | |||||
int block_dim_x = blockDim.x; | |||||
if (blockDim.x > WARP_SIZE) { | |||||
block_dim_x = blockDim.x / WARP_SIZE; | |||||
int lane = threadIdx.x % WARP_SIZE; | |||||
int tid = threadIdx.y * blockDim.x + threadIdx.x; | |||||
int wid = tid / WARP_SIZE; | |||||
int bid = threadIdx.y; | |||||
val = WarpReduce<T, ReduceOp>(val, reducer); | |||||
if (lane == 0) { | |||||
shared[wid] = val; | |||||
} | |||||
__syncthreads(); | |||||
val = shared[bid * block_dim_x + lane]; | |||||
} | |||||
for (int stride = 1; stride < block_dim_x; stride <<= 1) { | |||||
T temp = CudaShuffleDownSync(val, stride); | |||||
val = reducer(val, temp); | |||||
} | |||||
if (threadIdx.x == 0) { | |||||
shared[threadIdx.y] = val; | |||||
} | |||||
__syncthreads(); | |||||
return shared[threadIdx.y]; | |||||
} | |||||
template <typename T> | |||||
__device__ __forceinline__ void ReduceMeanAndVar( | |||||
T* mean, T* var, T x_mean, T x_var, int size) { | |||||
const int nc = blockIdx.x; | |||||
x_mean = BlockXReduce<T, AddFunctor<T>>(x_mean, AddFunctor<T>()); | |||||
x_var = BlockXReduce<T, AddFunctor<T>>(x_var, AddFunctor<T>()); | |||||
__syncthreads(); | |||||
if (threadIdx.x == 0) { | |||||
mean[nc] = static_cast<T>(x_mean / size); | |||||
var[nc] = static_cast<T>(x_var / size); | |||||
} | |||||
} | |||||
template <typename T, typename T_ACC, int VecSize, int Num> | |||||
__device__ __forceinline__ void ThreadReduce( | |||||
Array<const T*, Num> arrs, int size, const int offset, T_ACC* out_mean, | |||||
T_ACC* out_var) { | |||||
const T* x = arrs[0]; | |||||
const T* y; | |||||
if (Num == 2) { | |||||
y = arrs[1]; | |||||
} | |||||
using VecT = VectorType<T, VecSize>; | |||||
int tid = threadIdx.x; | |||||
if (offset > 0) { | |||||
x -= offset; | |||||
if (Num == 2) { | |||||
y -= offset; | |||||
} | |||||
size += offset; | |||||
if (tid >= offset) { | |||||
if (Num == 1) { | |||||
*out_mean += x[tid]; | |||||
*out_var += x[tid] * x[tid]; | |||||
} else if (Num == 2) { | |||||
*out_mean += y[tid]; | |||||
*out_var += y[tid] * x[tid]; | |||||
} | |||||
} | |||||
size -= blockDim.x; | |||||
x += blockDim.x; | |||||
if (Num == 2) { | |||||
y += blockDim.x; | |||||
} | |||||
} | |||||
int remain = size % (VecSize * blockDim.x); | |||||
T ins_x[VecSize]; | |||||
T ins_y[VecSize]; | |||||
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x); | |||||
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y); | |||||
// vector part | |||||
for (; VecSize * tid < (size - remain); tid += blockDim.x) { | |||||
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid]; | |||||
if (Num == 2) { | |||||
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid]; | |||||
} | |||||
#pragma unroll | |||||
for (int i = 0; i < VecSize; ++i) { | |||||
if (Num == 1) { | |||||
*out_mean += ins_x[i]; | |||||
*out_var += ins_x[i] * ins_x[i]; | |||||
} else if (Num == 2) { | |||||
*out_mean += ins_y[i]; | |||||
*out_var += ins_y[i] * ins_x[i]; | |||||
} | |||||
} | |||||
} | |||||
// scalar part | |||||
tid = size - remain + threadIdx.x; | |||||
for (; tid < size; tid += blockDim.x) { | |||||
if (Num == 1) { | |||||
*out_mean += x[tid]; | |||||
*out_var += x[tid] * x[tid]; | |||||
} else if (Num == 2) { | |||||
*out_mean += y[tid]; | |||||
*out_var += y[tid] * x[tid]; | |||||
} | |||||
} | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
__global__ void ScalarGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { | |||||
int i = blockIdx.x; | |||||
T_ACC x_mean = static_cast<T_ACC>(0); | |||||
T_ACC x_var = static_cast<T_ACC>(0); | |||||
for (int j = threadIdx.x; j < size; j += blockDim.x) { | |||||
T val; | |||||
val = x[i * size + j]; | |||||
x_mean += val; | |||||
x_var += val * val; | |||||
} | |||||
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size); | |||||
} | |||||
template <typename T, typename T_ACC, int VecSize> | |||||
__global__ void VectorizedGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { | |||||
int i = blockIdx.x; | |||||
T_ACC x_mean = static_cast<T_ACC>(0); | |||||
T_ACC x_var = static_cast<T_ACC>(0); | |||||
x += i * size; | |||||
const int input_offset = ((uint64_t)x) % 16 / sizeof(T); | |||||
Array<const T*, 1> ins; | |||||
ins[0] = x; | |||||
ThreadReduce<T, T_ACC, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var); | |||||
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size); | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
__global__ void GroupNormForward( | |||||
const T* x, const T_ACC* mean, const T_ACC* var, const T* scale, const T* bias, | |||||
int N, int C, int W, int imsize, int groups, int group_size, float epsilon, | |||||
T* y, T_ACC* real_var) { | |||||
int gid = blockIdx.y; | |||||
int cid = blockIdx.x; | |||||
int bid = blockIdx.z; | |||||
int ccid = gid * group_size + cid; | |||||
if (ccid >= C) | |||||
return; | |||||
auto ng = bid * groups + gid; | |||||
T_ACC x_mean = mean[ng]; | |||||
T_ACC x_var = var[ng]; | |||||
x_var = x_var - x_mean * x_mean; | |||||
T_ACC var_inv = rsqrt(x_var + epsilon); | |||||
if (cid == 0 && threadIdx.x == 0) { | |||||
real_var[ng] = x_var; | |||||
} | |||||
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { | |||||
T val; | |||||
int index = (bid * C + ccid) * imsize + imid; | |||||
val = x[index]; | |||||
val = (val - x_mean) * var_inv; | |||||
if (scale != nullptr) { | |||||
val *= scale[ccid]; | |||||
} | |||||
if (bias != nullptr) { | |||||
val += bias[ccid]; | |||||
} | |||||
y[index] = val; | |||||
} | |||||
} | |||||
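// Host-side launcher for the forward pass: a first kernel reduces the
// per-(N, group) sums into `mean` and a temporary second-moment buffer
// (using the vectorized variant when the reduction size permits float4-wide
// loads), and GroupNormForward then normalizes each element, applies the
// optional per-channel scale/bias, and writes the final variance to `real_var`.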
template <typename T, typename T_ACC> | |||||
void forward( | |||||
T* src, T* weight, T* bias, T* dst, T_ACC* mean, T_ACC* rstd, T_ACC* temp_rstd, | |||||
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream) { | |||||
auto group_size = C / group; | |||||
int block_size = std::min(1024, imsize); | |||||
dim3 grid(group_size, group, N); | |||||
dim3 threads(block_size, 1, 1); | |||||
int size = group_size * imsize; | |||||
constexpr int vec_size = sizeof(float4) / sizeof(T); | |||||
int max_block_size = std::min(size / vec_size, 1024); | |||||
int block_size_temp = 1; | |||||
while (block_size_temp < max_block_size) { | |||||
block_size_temp *= 2; | |||||
} | |||||
block_size_temp = std::max(block_size_temp, WARP_SIZE); | |||||
dim3 grids(N * group); | |||||
dim3 blocks(block_size_temp); | |||||
if (size < vec_size * block_size_temp) { | |||||
ScalarGetMeanAndVar<T, T_ACC> | |||||
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size); | |||||
after_kernel_launch(); | |||||
} else { | |||||
VectorizedGetMeanAndVar<T, T_ACC, vec_size> | |||||
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size); | |||||
after_kernel_launch(); | |||||
} | |||||
GroupNormForward<T, T_ACC><<<grid, threads, 0, stream>>>( | |||||
src, mean, temp_rstd, weight, bias, N, C, W, imsize, group, group_size, eps, | |||||
dst, rstd); | |||||
after_kernel_launch(); | |||||
} | |||||
// ================================ group_norm backward =========================== | |||||
// implementation of groupnorm_backward from | |||||
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu#L253 | |||||
template <typename T, typename T_ACC> | |||||
__global__ void GetDsDbCUDAKernel(int imsize, const T* x, const T* dy, T* ds, T* db) { | |||||
const int nc = blockIdx.x; | |||||
T ds_sum = static_cast<T>(0); | |||||
T db_sum = static_cast<T>(0); | |||||
for (int i = threadIdx.x; i < imsize; i += blockDim.x) { | |||||
const int index = nc * imsize + i; | |||||
ds_sum += dy[index] * x[index]; | |||||
db_sum += dy[index]; | |||||
} | |||||
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1); | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
__global__ void GetBiasGradientCUDAKernel( | |||||
int N, int C, int group, T_ACC epsilon, const T_ACC* mean, const T_ACC* var, | |||||
const T* ds, const T* db, T* d_scale, T* d_bias) { | |||||
const int c = blockIdx.x * blockDim.x + threadIdx.x; | |||||
if (c < C) { | |||||
const int G = group; | |||||
const int D = C / G; | |||||
T sum1 = static_cast<T>(0); | |||||
T sum2 = static_cast<T>(0); | |||||
for (int n = 0; n < N; ++n) { | |||||
const int nc = n * C + c; | |||||
const int ng = n * G + c / D; | |||||
sum1 += (d_scale == nullptr) | |||||
? T(0) | |||||
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) * | |||||
static_cast<T>(rsqrt((float)(var[ng] + epsilon)))); | |||||
sum2 += (d_bias == nullptr) ? T(0) : db[nc]; | |||||
} | |||||
if (d_scale != nullptr) { | |||||
d_scale[c] = sum1; | |||||
} | |||||
if (d_bias != nullptr) { | |||||
d_bias[c] = sum2; | |||||
} | |||||
} | |||||
} | |||||
template <typename T> | |||||
__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { | |||||
#pragma unroll | |||||
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||||
val += __shfl_down(val, offset, warpSize); | |||||
} | |||||
return val; | |||||
} | |||||
template <typename T> | |||||
__inline__ MEGDNN_DEVICE T BlockReduceSum(T val, T* shared) { | |||||
const int lid = threadIdx.x % warpSize; | |||||
const int wid = threadIdx.x / warpSize; | |||||
val = warp_reduce_sum(val); | |||||
__syncthreads(); | |||||
if (lid == 0) { | |||||
shared[wid] = val; | |||||
} | |||||
__syncthreads(); | |||||
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); | |||||
if (wid == 0) { | |||||
val = warp_reduce_sum(val); | |||||
} | |||||
return val; | |||||
} | |||||
template <typename T, typename T_ACC, int BlockDim> | |||||
__global__ void GetBackwardParamsCUDAKernel( | |||||
int imsize, int groups, int group_size, T_ACC epsilon, const T_ACC* mean, | |||||
const T_ACC* var, const T* scale, const T* ds, const T* db, T* p1, T* p2, | |||||
T* p3) { | |||||
const int n = blockIdx.x; | |||||
const int g = blockIdx.y; | |||||
const int ng = n * groups + g; | |||||
T sum1 = static_cast<T>(0); | |||||
T sum2 = static_cast<T>(0); | |||||
T var_inv = static_cast<T>(rsqrt(var[ng] + epsilon)); | |||||
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) { | |||||
const int64_t index = ng * group_size + i; | |||||
const int64_t c = g * group_size + i; | |||||
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]); | |||||
sum1 += ds[index] * scale_v; | |||||
sum2 += db[index] * scale_v; | |||||
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]); | |||||
p1[index] = scale_c * var_inv; | |||||
} | |||||
__shared__ T ds_shared[WARP_SIZE]; | |||||
__shared__ T db_shared[WARP_SIZE]; | |||||
sum1 = BlockReduceSum<T>(sum1, ds_shared); | |||||
sum2 = BlockReduceSum<T>(sum2, db_shared); | |||||
if (threadIdx.x == 0) { | |||||
const T s = T(1) / static_cast<T>(group_size * imsize); | |||||
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) * static_cast<T>(var_inv) * | |||||
static_cast<T>(var_inv) * static_cast<T>(var_inv) * s; | |||||
p2[ng] = x; | |||||
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s; | |||||
} | |||||
} | |||||
template <typename T, typename T_ACC> | |||||
__global__ void GetXGradientCUDAKernel( | |||||
int imsize, int C, int group_size, int groups, T* p1, T* p2, T* p3, const T* x, | |||||
const T* dy, T* dx) { | |||||
int cid = blockIdx.x; | |||||
int gid = blockIdx.y; | |||||
int bid = blockIdx.z; | |||||
int ccid = bid * C + gid * group_size + cid; | |||||
int ng = bid * groups + gid; | |||||
int nc = gid * group_size + cid; | |||||
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { | |||||
int index = (bid * C + nc) * imsize + imid; | |||||
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng]; | |||||
} | |||||
} | |||||
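// Host-side launcher for the backward pass: GetDsDbCUDAKernel accumulates the
// per-(N, C) sums ds = sum(dy * x) and db = sum(dy); the weight/bias gradients
// are computed only when a weight tensor is given; GetBackwardParamsCUDAKernel
// folds ds/db into the coefficients p1/p2/p3, which GetXGradientCUDAKernel then
// uses to assemble the input gradient dx.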
template <typename T, typename T_ACC> | |||||
void backward( | |||||
const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||||
const T_ACC* rstd_data, const T* weight_data, T* dX_data, T* dweight_data, | |||||
T* dbias_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, | |||||
T* p1, T* p2, T* p3, cudaStream_t stream) { | |||||
auto group_size = C / group; | |||||
int block_size = std::min(1024, imsize); | |||||
const int block_dims = 1024; | |||||
dim3 grid(group_size, group, N); | |||||
dim3 threads(block_size, 1, 1); | |||||
const int max_num_threads = 1024; | |||||
int max_block_size = std::min(imsize, max_num_threads); | |||||
int block_size_temp = 1; | |||||
while (block_size_temp < max_block_size) { | |||||
block_size_temp *= 2; | |||||
} | |||||
block_size_temp = std::max(block_size_temp, WARP_SIZE); | |||||
dim3 blocks(block_size_temp); | |||||
GetDsDbCUDAKernel<T, T_ACC> | |||||
<<<N * C, blocks, 0, stream>>>(imsize, X_data, dY_data, ds, db); | |||||
after_kernel_launch(); | |||||
bool has_affine = weight_data != nullptr;
if (has_affine) {
const int block = 256; | |||||
GetBiasGradientCUDAKernel<T, T_ACC> | |||||
<<<(C + block - 1) / block, block, 0, stream>>>( | |||||
N, C, group, eps, mean_data, rstd_data, ds, db, dweight_data, | |||||
dbias_data); | |||||
after_kernel_launch(); | |||||
} | |||||
GetBackwardParamsCUDAKernel<T, T_ACC, block_dims> | |||||
<<<dim3(N, group), block_dims, 0, stream>>>( | |||||
imsize, group, group_size, eps, mean_data, rstd_data, weight_data, | |||||
ds, db, p1, p2, p3); | |||||
after_kernel_launch(); | |||||
GetXGradientCUDAKernel<T, T_ACC><<<grid, threads, 0, stream>>>( | |||||
imsize, C, group_size, group, p1, p2, p3, X_data, dY_data, dX_data); | |||||
after_kernel_launch(); | |||||
} | |||||
#define INST(T, T_ACC) \ | |||||
template void forward<T, T_ACC>( \ | |||||
T*, T*, T*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC, int, int, int, int, int, \ | |||||
cudaStream_t); \ | |||||
template void backward<T, T_ACC>( \ | |||||
const T*, const T*, const T_ACC*, const T_ACC*, const T*, T*, T*, T*, \ | |||||
T_ACC, int, int, int, int, T*, T*, T*, T*, T*, cudaStream_t); | |||||
INST(dt_float32, dt_float32) | |||||
INST(dt_float16, dt_float32) | |||||
INST(dt_bfloat16, dt_float32) | |||||
#undef INST | |||||
} // namespace group_norm | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,24 @@ | |||||
#pragma once | |||||
#include <cuda_runtime_api.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace group_norm { | |||||
template <typename T, typename T_ACC> | |||||
void forward( | |||||
T* X, T* gamma, T* beta, T* Y, T_ACC* mean, T_ACC* rstd, T_ACC* temp_rstd,
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream); | |||||
template <typename T, typename T_ACC> | |||||
void backward( | |||||
const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||||
const T_ACC* rstd_data, const T* gamma_data, T* dX_data, T* dgamma_data, | |||||
T* dbeta_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, | |||||
T* p1, T* p2, T* p3, cudaStream_t stream); | |||||
} // namespace group_norm | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,143 @@ | |||||
#include "src/cuda/group_norm/opr_impl.h" | |||||
#include "src/cuda/group_norm/group_norm_cuda.cuh" | |||||
#include "src/cuda/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
size_t GroupNormForwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout& rstd) { | |||||
size_t N = rstd.shape[0]; | |||||
size_t G = rstd.shape[1]; | |||||
return get_workspace_bundle(N, G, rstd.dtype.size()).total_size_in_bytes(); | |||||
} | |||||
WorkspaceBundle GroupNormForwardImpl::get_workspace_bundle( | |||||
size_t N, size_t G, size_t dtype_size, void* raw_ptr) { | |||||
return {raw_ptr, {N * G * dtype_size}, handle()->alignment_requirement()}; | |||||
} | |||||
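// The forward workspace is a single (N, group) float buffer (`temp_rstd` in
// exec) that holds the raw second moment E[x^2] produced by the reduction
// kernels before GroupNormForward writes the final variance into rstd.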
void GroupNormForwardImpl::exec( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
_megdnn_workspace workspace) { | |||||
check_exec( | |||||
data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||||
rstd.layout, workspace.size); | |||||
auto p = param(); | |||||
using Format = param::GroupNorm::Format; | |||||
float eps = p.eps; | |||||
int group = p.group; | |||||
bool affine = p.affine; | |||||
auto layout = data.layout; | |||||
auto format = p.format; | |||||
size_t N, C, H, W, imsize; | |||||
if (data.layout.ndim == 4 && format == Format::NCHW) { | |||||
N = layout.shape[0]; | |||||
C = layout.shape[1]; | |||||
H = layout.shape[2]; | |||||
W = layout.shape[3]; | |||||
imsize = H * W; | |||||
} else { | |||||
megdnn_throw(ssprintf("Unspport groupnorm input")); | |||||
} | |||||
auto stream = cuda_stream(handle()); | |||||
using namespace ::megdnn::cuda::group_norm; | |||||
auto wbundle = | |||||
get_workspace_bundle(N, group, rstd.layout.dtype.size(), workspace.raw_ptr); | |||||
#define cb(DType) \ | |||||
if (data.layout.dtype == DType()) { \ | |||||
using T = typename DTypeTrait<DType>::ctype; \ | |||||
using T_ACC = float; \ | |||||
T_ACC* temp_rstd = wbundle.get_workspace(0).ptr<T_ACC>(); \ | |||||
forward<T, T_ACC>( \ | |||||
data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \ | |||||
affine ? bias.ptr<T>() : nullptr, dst.ptr<T>(), mean.ptr<T_ACC>(), \ | |||||
rstd.ptr<T_ACC>(), temp_rstd, static_cast<T_ACC>(eps), \ | |||||
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \ | |||||
static_cast<int>(W), static_cast<int>(imsize), stream); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
size_t GroupNormBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||||
const TensorLayout& mean, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&) { | |||||
size_t N = data.shape[0]; | |||||
size_t C = data.shape[1]; | |||||
size_t G = mean.shape[1]; | |||||
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes(); | |||||
} | |||||
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle( | |||||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) { | |||||
return {raw_ptr, | |||||
{N * C * dtype_size, N * C * dtype_size, N * C * dtype_size, | |||||
N * G * dtype_size, N * G * dtype_size}, | |||||
handle()->alignment_requirement()}; | |||||
} | |||||
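// The backward workspace holds five buffers: ds and db are (N, C) partial sums
// of dy * x and dy, p1 is a per-(N, C) scale coefficient, and p2/p3 are
// per-(N, group) coefficients consumed by GetXGradientCUDAKernel.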
void GroupNormBackwardImpl::exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) { | |||||
check_exec( | |||||
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||||
ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||||
auto p = param(); | |||||
using Format = param::GroupNorm::Format; | |||||
bool affine = p.affine; | |||||
float eps = p.eps; | |||||
int group = p.group; | |||||
auto layout = data.layout; | |||||
auto format = p.format; | |||||
size_t N, C, H, W, imsize; | |||||
if (layout.ndim == 4 && format == Format::NCHW) { | |||||
N = layout.shape[0]; | |||||
C = layout.shape[1]; | |||||
H = layout.shape[2]; | |||||
W = layout.shape[3]; | |||||
imsize = H * W; | |||||
} else { | |||||
megdnn_throw(ssprintf("Unspport groupnorm input")); | |||||
} | |||||
auto stream = cuda_stream(handle()); | |||||
using namespace ::megdnn::cuda::group_norm; | |||||
auto wbundle = get_workspace_bundle( | |||||
N, C, group, data.layout.dtype.size(), workspace.raw_ptr); | |||||
#define cb(DType) \ | |||||
if (data.layout.dtype == DType()) { \ | |||||
using T = typename DTypeTrait<DType>::ctype; \ | |||||
using T_ACC = float; \ | |||||
T* ds = wbundle.get_workspace(0).ptr<T>(); \ | |||||
T* db = wbundle.get_workspace(1).ptr<T>(); \ | |||||
T* p1 = wbundle.get_workspace(2).ptr<T>(); \ | |||||
T* p2 = wbundle.get_workspace(3).ptr<T>(); \ | |||||
T* p3 = wbundle.get_workspace(4).ptr<T>(); \ | |||||
backward<T, T_ACC>( \ | |||||
diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \ | |||||
affine ? weight.ptr<T>() : nullptr, ddata.ptr<T>(), \ | |||||
affine ? dweight.ptr<T>() : nullptr, \ | |||||
affine ? dbias.ptr<T>() : nullptr, static_cast<T_ACC>(eps), \ | |||||
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \ | |||||
static_cast<int>(imsize), ds, db, p1, p2, p3, stream); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,47 @@ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class GroupNormForwardImpl final : public GroupNormForward { | |||||
public: | |||||
using GroupNormForward::GroupNormForward; | |||||
void exec( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout& rstd) override; | |||||
private: | |||||
WorkspaceBundle get_workspace_bundle( | |||||
size_t N, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||||
}; | |||||
class GroupNormBackwardImpl final : public GroupNormBackward { | |||||
public: | |||||
using GroupNormBackward::GroupNormBackward; | |||||
void exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||||
const TensorLayout& mean, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&) override; | |||||
private: | |||||
WorkspaceBundle get_workspace_bundle( | |||||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -32,6 +32,7 @@ | |||||
#include "src/cuda/flip/opr_impl.h" | #include "src/cuda/flip/opr_impl.h" | ||||
#include "src/cuda/gaussian_blur/opr_impl.h" | #include "src/cuda/gaussian_blur/opr_impl.h" | ||||
#include "src/cuda/group_local/opr_impl.h" | #include "src/cuda/group_local/opr_impl.h" | ||||
#include "src/cuda/group_norm/opr_impl.h" | |||||
#include "src/cuda/images2neibs/opr_impl.h" | #include "src/cuda/images2neibs/opr_impl.h" | ||||
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | ||||
#include "src/cuda/indexing_one_hot/opr_impl.h" | #include "src/cuda/indexing_one_hot/opr_impl.h" | ||||
@@ -163,6 +164,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy); | ||||
@@ -0,0 +1,206 @@ | |||||
#include "src/naive/group_norm/opr_impl.h" | |||||
#include <algorithm> | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
using namespace megdnn; | |||||
using namespace naive; | |||||
namespace { | |||||
using Param = megdnn::GroupNorm::Param; | |||||
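// Reference (naive) forward: the input is viewed as N * group slices of
// D * H * W elements (D = C / group); each slice is normalized with its own
// mean and variance, and the per-channel weight/bias are applied when
// param.affine is set. Note that the `rstd` output stores the per-slice
// variance, matching the CUDA implementation.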
template <typename T, typename T_ACC = float> | |||||
void forward( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
const Param& param) { | |||||
float eps = param.eps; | |||||
bool affine = param.affine; | |||||
size_t N = data.layout.shape[0]; | |||||
size_t C = data.layout.shape[1]; | |||||
size_t HxW = data.layout.shape[2] * data.layout.shape[3]; | |||||
const int64_t G = param.group; | |||||
size_t D = C / G; | |||||
size_t inner_size = D * HxW; | |||||
for (size_t i = 0; i < N * G; i++) { | |||||
T_ACC slice_sum = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < inner_size; j++) { | |||||
auto value = data.ptr<T>()[i * inner_size + j]; | |||||
slice_sum += value; | |||||
} | |||||
T_ACC slice_mean = static_cast<T_ACC>(slice_sum / inner_size);
T_ACC slice_var = static_cast<T_ACC>(0.0f);
for (size_t j = 0; j < inner_size; j++) { | |||||
slice_var += (data.ptr<T>()[i * inner_size + j] - slice_mean) * | |||||
(data.ptr<T>()[i * inner_size + j] - slice_mean); | |||||
} | |||||
slice_var = slice_var / inner_size; | |||||
T_ACC slice_std = static_cast<T_ACC>(1.0f) / static_cast<T_ACC>(sqrt(slice_var + eps));
if (affine) { | |||||
const int64_t g = i % G; | |||||
for (size_t j = 0; j < D; j++) { | |||||
const int64_t c = g * D + j; | |||||
T_ACC s = slice_std * weight.ptr<T>()[c]; | |||||
T_ACC b = -s * slice_mean + bias.ptr<T>()[c]; | |||||
for (size_t k = 0; k < HxW; k++) { | |||||
dst.ptr<T>()[(i * D + j) * HxW + k] = | |||||
s * data.ptr<T>()[(i * D + j) * HxW + k] + b; | |||||
} | |||||
} | |||||
} else { | |||||
for (size_t j = 0; j < inner_size; j++) { | |||||
dst.ptr<T>()[i * inner_size + j] =
(data.ptr<T>()[i * inner_size + j] - slice_mean) * slice_std;
} | |||||
} | |||||
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean); | |||||
rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_var); | |||||
} | |||||
} | |||||
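// Reference (naive) backward: ds/db hold the per-(N, C) sums of dy * x and dy,
// slice_std caches 1 / sqrt(var + eps) per slice, and the input gradient is
// assembled from the coefficients c1/c2/c3; weight/bias gradients are
// accumulated only when affine is enabled.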
template <typename T, typename T_ACC = float> | |||||
void backward( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param, | |||||
WorkspaceBundle wbundle) { | |||||
bool affine = param.affine; | |||||
size_t N = data.layout.shape[0]; | |||||
size_t C = data.layout.shape[1]; | |||||
size_t G = param.group; | |||||
float eps = param.eps; | |||||
size_t HxW = data.layout.shape[2] * data.layout.shape[3]; | |||||
T* ds = wbundle.get_workspace(0).ptr<T>(); | |||||
T* db = wbundle.get_workspace(1).ptr<T>(); | |||||
T* slice_std = wbundle.get_workspace(2).ptr<T>(); | |||||
for (size_t i = 0; i < N * G; i++) { | |||||
slice_std[i] = | |||||
static_cast<T>(1.0f) / static_cast<T>(sqrt(rstd.ptr<T_ACC>()[i] + eps)); | |||||
} | |||||
for (size_t i = 0; i < N * C; i++) { | |||||
T ds_data = static_cast<T>(0.0f); | |||||
T db_data = static_cast<T>(0.0f); | |||||
for (size_t j = 0; j < HxW; j++) { | |||||
db_data += diff.ptr<T>()[i * HxW + j]; | |||||
ds_data += data.ptr<T>()[i * HxW + j] * diff.ptr<T>()[i * HxW + j]; | |||||
} | |||||
ds[i] = ds_data; | |||||
db[i] = db_data; | |||||
} | |||||
size_t D = C / G; | |||||
const T s = T(1) / static_cast<T>(D * HxW); | |||||
for (size_t i = 0; i < N * G; i++) { | |||||
const int64_t g = i % G; | |||||
T ds_v = static_cast<T>(0.0f); | |||||
T db_v = static_cast<T>(0.0f); | |||||
for (size_t j = 0; j < D; j += 1) { | |||||
auto weight_v = affine ? weight.ptr<T>()[g * D + j] : static_cast<T>(1.0f); | |||||
ds_v += ds[i * D + j] * weight_v; | |||||
db_v += db[i * D + j] * weight_v; | |||||
} | |||||
auto c2 = (db_v * mean.ptr<T_ACC>()[i] - ds_v) * slice_std[i] * slice_std[i] * | |||||
slice_std[i] * s; | |||||
auto c3 = -c2 * mean.ptr<T_ACC>()[i] - db_v * slice_std[i] * s; | |||||
for (size_t j = 0; j < D; j++) { | |||||
const int64_t c = g * D + j; | |||||
auto weight_v = affine ? weight.ptr<T>()[c] : static_cast<T>(1.0f); | |||||
auto c1 = slice_std[i] * weight_v; | |||||
for (size_t k = 0; k < HxW; k++) { | |||||
ddata.ptr<T>()[(i * D + j) * HxW + k] = | |||||
c1 * diff.ptr<T>()[(i * D + j) * HxW + k] + | |||||
c2 * data.ptr<T>()[(i * D + j) * HxW + k] + c3; | |||||
} | |||||
} | |||||
} | |||||
if (affine) { | |||||
for (size_t i = 0; i < C; ++i) { | |||||
dweight.ptr<T>()[i] = 0; | |||||
dbias.ptr<T>()[i] = 0; | |||||
} | |||||
for (size_t i = 0; i < N * G; i++) { | |||||
auto g = i % G; | |||||
for (size_t j = 0; j < D; j++) { | |||||
auto c = g * D + j; | |||||
dweight.ptr<T>()[c] += | |||||
(ds[i * D + j] - db[i * D + j] * mean.ptr<T_ACC>()[i]) * | |||||
slice_std[i]; | |||||
} | |||||
} | |||||
for (size_t i = 0; i < N; i++) { | |||||
for (size_t j = 0; j < C; j++) { | |||||
dbias.ptr<T>()[j] += db[i * C + j]; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace naive { | |||||
size_t GroupNormBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&) { | |||||
size_t N = data.shape[0]; | |||||
size_t C = data.shape[1]; | |||||
size_t G = rstd.shape[1]; | |||||
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes(); | |||||
} | |||||
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle( | |||||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) { | |||||
return {raw_ptr, | |||||
{N * C * dtype_size, N * C * dtype_size, N * G * dtype_size}, | |||||
handle()->alignment_requirement()}; | |||||
} | |||||
void GroupNormForwardImpl::exec( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
_megdnn_workspace workspace) { | |||||
check_exec( | |||||
data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||||
rstd.layout, workspace.size); | |||||
#define cb(DType) \ | |||||
if (data.layout.dtype == DType()) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \ | |||||
data, weight, bias, dst, mean, rstd, param())); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
void GroupNormBackwardImpl::exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) { | |||||
check_exec( | |||||
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||||
ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||||
#define cb(DType) \ | |||||
if (data.layout.dtype == DType()) { \ | |||||
auto wbundle = get_workspace_bundle( \ | |||||
data.layout.shape[0], data.layout.shape[1], rstd.layout.shape[1], \ | |||||
sizeof(DTypeTrait<DType>::ctype), workspace.raw_ptr); \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \ | |||||
diff, data, weight, mean, rstd, ddata, dweight, dbias, param(), \ | |||||
wbundle)); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,44 @@ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class GroupNormForwardImpl final : public GroupNormForward { | |||||
public: | |||||
using GroupNormForward::GroupNormForward; | |||||
void exec( | |||||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class GroupNormBackwardImpl final : public GroupNormBackward { | |||||
public: | |||||
using GroupNormBackward::GroupNormBackward; | |||||
void exec( | |||||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&) override; | |||||
private: | |||||
WorkspaceBundle get_workspace_bundle( | |||||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -34,6 +34,7 @@ | |||||
#include "src/naive/flip/opr_impl.h" | #include "src/naive/flip/opr_impl.h" | ||||
#include "src/naive/gaussian_blur/opr_impl.h" | #include "src/naive/gaussian_blur/opr_impl.h" | ||||
#include "src/naive/group_local/opr_impl.h" | #include "src/naive/group_local/opr_impl.h" | ||||
#include "src/naive/group_norm/opr_impl.h" | |||||
#include "src/naive/images2neibs/opr_impl.h" | #include "src/naive/images2neibs/opr_impl.h" | ||||
#include "src/naive/indexing_multi_axis_vec/opr_impl.h" | #include "src/naive/indexing_multi_axis_vec/opr_impl.h" | ||||
#include "src/naive/indexing_one_hot/opr_impl.h" | #include "src/naive/indexing_one_hot/opr_impl.h" | ||||
@@ -0,0 +1,44 @@ | |||||
#include "test/cuda/fixture.h" | |||||
#include "test/common/checker.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
TEST_F(CUDA, GROUPNORM_FORWARD) { | |||||
using Param = GroupNormForward::Param; | |||||
Param param; | |||||
param.affine = true; | |||||
param.eps = 1e-6; | |||||
Checker<GroupNormForward> checker(handle_cuda()); | |||||
checker.set_epsilon(1e-2); | |||||
auto run = [&](DType d) { | |||||
for (size_t group : {1, 3}) | |||||
for (size_t C : {6, 9}) { | |||||
param.group = group; | |||||
checker.set_param(param) | |||||
.set_dtype(0, d) | |||||
.set_dtype(1, d) | |||||
.set_dtype(2, d) | |||||
.set_dtype(3, d) | |||||
.set_dtype(4, dtype::Float32()) | |||||
.set_dtype(5, dtype::Float32()) | |||||
.execs({{2, C, 2, 1}, | |||||
{C}, | |||||
{C}, | |||||
{2, C, 2, 1}, | |||||
{2, group}, | |||||
{2, group}}); | |||||
} | |||||
}; | |||||
run(dtype::Float32()); | |||||
run(dtype::Float16()); | |||||
run(dtype::BFloat16()); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,70 @@ | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/naive/fixture.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
TEST_F(NAIVE, GROUPNORM_FORWARD) { | |||||
Checker<GroupNorm> checker(handle(), true); | |||||
GroupNorm::Param param; | |||||
param.affine = true; | |||||
param.group = 3; | |||||
checker.set_param(param).exect( | |||||
Testcase{ | |||||
TensorValue( | |||||
{2, 3, 2, 1}, dtype::Float32(), | |||||
{3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587, | |||||
0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input | |||||
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // weight
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // bias
{}, | |||||
{}, | |||||
{}}, | |||||
Testcase{ | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue( | |||||
{2, 3, 2, 1}, dtype::Float32(), | |||||
{1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999, | |||||
-0.9998, 0.9998}), // output | |||||
TensorValue( | |||||
{2, 3}, dtype::Float32(), | |||||
{1.7135, -0.1645, -0.0107, -0.9938, 0.067, | |||||
-0.0758}), // mean | |||||
TensorValue( | |||||
{2, 3}, dtype::Float32(), | |||||
{2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}), // var | |||||
}); | |||||
checker.set_param(param).exect( | |||||
Testcase{ | |||||
TensorValue( | |||||
{1, 3, 1, 2}, dtype::Float32(), | |||||
{-2.4348, -1.7948, 0.5223, 0.0932, -0.2955, | |||||
-0.0492}), // input | |||||
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // weight
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // bias
{}, | |||||
{}, | |||||
{}}, | |||||
Testcase{ | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue( | |||||
{1, 3, 1, 2}, dtype::Float32(), | |||||
{-0.9999, 0.9999, 0.9999, -0.9999, -0.9997, | |||||
0.9997}), // output | |||||
TensorValue( | |||||
{1, 3}, dtype::Float32(), | |||||
{-2.1148, 0.3077, -0.1724}), // mean | |||||
TensorValue( | |||||
{1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}), // var | |||||
}); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn |
@@ -60,6 +60,7 @@ __all__ = [ | |||||
"dropout", | "dropout", | ||||
"embedding", | "embedding", | ||||
"gelu", | "gelu", | ||||
"group_norm", | |||||
"hsigmoid", | "hsigmoid", | ||||
"hswish", | "hswish", | ||||
"indexing_one_hot", | "indexing_one_hot", | ||||
@@ -1202,6 +1203,33 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||||
return output | return output | ||||
def group_norm( | |||||
inp: Tensor, | |||||
num_groups: int, | |||||
affine: bool, | |||||
weight: Optional[Tensor] = None, | |||||
bias: Optional[Tensor] = None, | |||||
eps: float = 1e-5, | |||||
): | |||||
r"""Applies Group Normalization over a mini-batch of inputs as described in | |||||
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__ | |||||
Args: | |||||
inp: input tensor. | |||||
num_groups: number of groups to separate the channels into.
affine: whether to apply the learnable per-channel weight and bias.
weight: scale tensor applied per channel; must not be None when affine is True.
bias: shift tensor applied per channel; must not be None when affine is True.
eps: a value added to the denominator for numerical stability. Default: 1e-5 | |||||
""" | |||||
op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups)
if affine: | |||||
assert weight is not None and bias is not None | |||||
return apply(op, inp, weight, bias)[0] | |||||
else: | |||||
return apply(op, inp)[0] | |||||
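# A minimal usage sketch for ``group_norm`` (shapes and values are illustrative):
#
#     x = Tensor(np.random.randn(2, 6, 4, 4).astype("float32"))
#     w = Tensor(np.ones(6, dtype="float32"))
#     b = Tensor(np.zeros(6, dtype="float32"))
#     y = group_norm(x, num_groups=3, affine=True, weight=w, bias=b)
#     # y has the same shape as x; mean/var are computed per (N, group) slice.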
def layer_norm( | def layer_norm( | ||||
inp: Tensor, | inp: Tensor, | ||||
normalized_shape: tuple, | normalized_shape: tuple, | ||||
@@ -34,21 +34,9 @@ class GroupNorm(Module): | |||||
zeros_(self.bias) | zeros_(self.bias) | ||||
def forward(self, x): | def forward(self, x): | ||||
N, C, H, W = x.shape | |||||
format = x.format | |||||
assert C == self.num_channels | |||||
x = x.reshape(N, self.num_groups, -1) | |||||
mean = x.mean(axis=2, keepdims=True) | |||||
var = (x * x).mean(axis=2, keepdims=True) - mean * mean | |||||
x = (x - mean) / F.sqrt(var + self.eps) | |||||
x = x.reshape(N, C, H, W) | |||||
if self.affine: | |||||
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1) | |||||
# FIXME(czh): remove this after making it a builtin op. | |||||
if format == "nhwc": | |||||
x = mge.amp.convert_tensor_format(x, inplace=False) | |||||
x = F.nn.group_norm( | |||||
x, self.num_groups, self.affine, self.weight, self.bias, self.eps | |||||
) | |||||
return x | return x | ||||
def _module_info_string(self) -> str: | def _module_info_string(self) -> str: | ||||
@@ -8,12 +8,14 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import Parameter, Tensor, tensor | from megengine import Parameter, Tensor, tensor | ||||
from megengine.device import get_device_count | |||||
from megengine.module import ( | from megengine.module import ( | ||||
BatchNorm1d, | BatchNorm1d, | ||||
BatchNorm2d, | BatchNorm2d, | ||||
Conv1d, | Conv1d, | ||||
Conv2d, | Conv2d, | ||||
Dropout, | Dropout, | ||||
GroupNorm, | |||||
Linear, | Linear, | ||||
MaxPool2d, | MaxPool2d, | ||||
Module, | Module, | ||||
@@ -698,3 +700,67 @@ def test_module_compatible(): | |||||
assert ( | assert ( | ||||
old_attributes == current_attributes | old_attributes == current_attributes | ||||
), "Add or delete attributes in Module class may break compatibility of pickle serialization" | ), "Add or delete attributes in Module class may break compatibility of pickle serialization" | ||||
def test_group_norm(): | |||||
class OriginGroupNormFunc(Module): | |||||
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): | |||||
super().__init__(**kwargs) | |||||
assert num_channels % num_groups == 0 | |||||
self.num_groups = num_groups | |||||
self.num_channels = num_channels | |||||
self.eps = eps | |||||
self.affine = affine | |||||
if self.affine: | |||||
self.weight = Parameter(np.ones(num_channels, dtype=np.float32)) | |||||
self.bias = Parameter(np.zeros(num_channels, dtype=np.float32)) | |||||
else: | |||||
self.weight = None | |||||
self.bias = None | |||||
def forward(self, x): | |||||
N, C, H, W = x.shape | |||||
x = x.reshape(N, self.num_groups, -1) | |||||
mean = x.mean(axis=2, keepdims=True) | |||||
var = (x * x).mean(axis=2, keepdims=True) - mean * mean | |||||
x = (x - mean) / F.sqrt(var + self.eps) | |||||
x = x.reshape(N, C, H, W) | |||||
if self.affine: | |||||
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape( | |||||
1, -1, 1, 1 | |||||
) | |||||
return x | |||||
inp = np.random.randn(2, 256, 10, 16).astype("float32") | |||||
mge_inp = Tensor(inp) | |||||
mge_m = GroupNorm(32, 256) | |||||
ori_inp = Tensor(inp) | |||||
ori_m = OriginGroupNormFunc(32, 256) | |||||
targets = np.array(2) | |||||
mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters()) | |||||
ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters()) | |||||
for i in range(2): | |||||
with mge_gm: | |||||
mge_output = mge_m(mge_inp) | |||||
loss = F.loss.square_loss( | |||||
mge_output.sum(), mge.tensor(targets, dtype=np.float32) | |||||
) | |||||
mge_gm.backward(loss) | |||||
with ori_gm: | |||||
ori_output = ori_m(ori_inp) | |||||
loss = F.loss.square_loss( | |||||
ori_output.sum(), mge.tensor(targets, dtype=np.float32) | |||||
) | |||||
ori_gm.backward(loss) | |||||
np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05) | |||||
np.testing.assert_allclose( | |||||
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03 | |||||
) | |||||
np.testing.assert_allclose( | |||||
mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03 | |||||
) |
@@ -0,0 +1,97 @@ | |||||
#include "megbrain/opr/dnn/group_norm.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "../blob_manager_impl.h" | |||||
#include "../dnn_op_helper.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace group_norm { | |||||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const GroupNorm&>(def); | |||||
size_t nr_inp = inputs.size(); | |||||
auto p = op.param(); | |||||
mgb_assert( | |||||
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||||
"num of inputs of group_norm should be 1 or 3 but got %zu", nr_inp); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
if (nr_inp == 3) { | |||||
return opr::GroupNorm::make( | |||||
inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} else { | |||||
return opr::GroupNorm::make(inputs[0], op.param(), config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto&& group_norm = def.cast_final_safe<GroupNorm>(); | |||||
size_t nr_inp = inputs.size(); | |||||
auto affine = group_norm.affine; | |||||
mgb_assert( | |||||
(nr_inp == 3 && affine) || (nr_inp == 1 && !affine), | |||||
"num of inputs of pooling should be 1 or 3 but you give %zu", | |||||
inputs.size()); | |||||
auto&& inp = inputs[0]; | |||||
auto& inp_cn = inp.comp_node; | |||||
if (inp.layout.ndim == 0) { | |||||
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}, | |||||
{TensorLayout{dtype::Float32()}, inp_cn, {}}, | |||||
{TensorLayout{dtype::Float32()}, inp_cn, {}}}, | |||||
false}; | |||||
} | |||||
DnnOprHelper<megdnn::GroupNorm> dnn_opr(group_norm.param()); | |||||
auto&& [oup_layout, mean_layout, rstd_layout] = | |||||
dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{}); | |||||
return {{{oup_layout, inp_cn, {}}, | |||||
{mean_layout, inp_cn, {}}, | |||||
{rstd_layout, inp_cn, {}}}, | |||||
true}; | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto&& op_def = def.cast_final_safe<GroupNorm>(); | |||||
size_t nr_inp = inputs.size(); | |||||
auto p = op_def.param(); | |||||
mgb_assert( | |||||
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||||
"num of inputs of groupnorm should be 1 or 3 but you give %zu", | |||||
inputs.size()); | |||||
auto cn = inputs[0]->comp_node(); | |||||
DnnOprCaller<megdnn::GroupNorm> caller(cn, op_def.param()); | |||||
auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>( | |||||
inputs[0]->layout(), TensorLayout{}, TensorLayout{}); | |||||
auto out = Tensor::make(oup_layout, cn); | |||||
auto mean = Tensor::make(mean_layout, cn); | |||||
auto rstd = Tensor::make(rstd_layout, cn); | |||||
if (p.affine) { | |||||
caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd); | |||||
} else { | |||||
megdnn::TensorND empty_dnn; | |||||
caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd); | |||||
} | |||||
return {out, mean, rstd}; | |||||
} | |||||
OP_TRAIT_REG(GroupNorm, GroupNorm) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.fallback(); | |||||
} // namespace group_norm | |||||
} // namespace mgb::imperative |
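As a rough shape illustration of the layouts deduced above (dst mirrors the input, while mean and rstd are Float32 tensors of shape (N, group)), the builtin op could hypothetically be applied directly as below. The import paths and the apply/builtin pattern are assumptions borrowed from the functional wrapper earlier in this patch.

import numpy as np
import megengine as mge
from megengine.core.ops import builtin                   # assumed import path
from megengine.core._imperative_rt.core2 import apply    # assumed import path

x = mge.Tensor(np.random.randn(2, 6, 4, 4).astype("float32"))
op = builtin.GroupNorm(affine=False, eps=1e-5, group=3)   # 1-input form (no weight/bias)
dst, mean, rstd = apply(op, x)                            # three outputs, as registered above
print(dst.shape, mean.shape, rstd.shape)  # (2, 6, 4, 4) (2, 3) (2, 3)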
@@ -1,7 +1,7 @@ | |||||
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | |||||
da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td | |||||
5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl | |||||
98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl | |||||
b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl | |||||
3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl | |||||
e38b68be4e2aaf3de2f22e3dddbeaac4 ../../dnn/scripts/opr_param_defs.py | |||||
cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td | |||||
9248d42a9b3e770693306992156f6015 generated/opdef.h.inl | |||||
5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl | |||||
30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl | |||||
d10455217f5f01e3d2668e5689068920 generated/opdef.cpy.inl | |||||
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h | 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h |
@@ -3775,6 +3775,110 @@ OP_TRAIT_REG(GroupLocal, GroupLocal) | |||||
.props(GroupLocal_props_impl) | .props(GroupLocal_props_impl) | ||||
.make_name(GroupLocal_make_name_impl); | .make_name(GroupLocal_make_name_impl); | ||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNorm); | |||||
namespace { | |||||
size_t GroupNorm_hash_impl(const OpDef& def_) { | |||||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||||
static_cast<void>(op_); | |||||
size_t val = mgb::hash(op_.dyn_typeinfo()); | |||||
val = mgb::hash_pair_combine(val, mgb::hash(op_.affine)); | |||||
val = mgb::hash_pair_combine(val, mgb::hash(op_.eps)); | |||||
val = mgb::hash_pair_combine(val, mgb::hash(op_.group)); | |||||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); | |||||
return val; | |||||
} | |||||
bool GroupNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||||
auto &&a_ = lhs_.cast_final_safe<GroupNorm>(), | |||||
&&b_ = rhs_.cast_final_safe<GroupNorm>(); | |||||
static_cast<void>(a_); | |||||
static_cast<void>(b_); | |||||
if (a_.affine != b_.affine) return false; | |||||
if (a_.eps != b_.eps) return false; | |||||
if (a_.group != b_.group) return false; | |||||
if (a_.format != b_.format) return false; | |||||
return true; | |||||
} | |||||
std::vector<std::pair<const char*, std::string>> GroupNorm_props_impl(const OpDef& def_) { | |||||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||||
static_cast<void>(op_); | |||||
std::vector<std::pair<const char*, std::string>> props_; | |||||
props_.emplace_back("affine", std::to_string(op_.affine)); | |||||
props_.emplace_back("eps", std::to_string(op_.eps)); | |||||
props_.emplace_back("group", std::to_string(op_.group)); | |||||
switch (op_.format){ | |||||
case GroupNorm::Format::NCHW: | |||||
props_.emplace_back("format", "NCHW"); | |||||
break; | |||||
case GroupNorm::Format::NHWC: | |||||
props_.emplace_back("format", "NHWC"); | |||||
break; | |||||
case GroupNorm::Format::NHWCD4: | |||||
props_.emplace_back("format", "NHWCD4"); | |||||
break; | |||||
case GroupNorm::Format::NCHW4: | |||||
props_.emplace_back("format", "NCHW4"); | |||||
break; | |||||
case GroupNorm::Format::NCHW8: | |||||
props_.emplace_back("format", "NCHW8"); | |||||
break; | |||||
case GroupNorm::Format::NCHW32: | |||||
props_.emplace_back("format", "NCHW32"); | |||||
break; | |||||
case GroupNorm::Format::NCHW88: | |||||
props_.emplace_back("format", "NCHW88"); | |||||
break; | |||||
case GroupNorm::Format::NCHW44: | |||||
props_.emplace_back("format", "NCHW44"); | |||||
break; | |||||
case GroupNorm::Format::NCHW44_DOT: | |||||
props_.emplace_back("format", "NCHW44_DOT"); | |||||
break; | |||||
case GroupNorm::Format::NCHW4_NCHW32: | |||||
props_.emplace_back("format", "NCHW4_NCHW32"); | |||||
break; | |||||
case GroupNorm::Format::NCHW32_NCHW4: | |||||
props_.emplace_back("format", "NCHW32_NCHW4"); | |||||
break; | |||||
case GroupNorm::Format::NCHW4_NCHW: | |||||
props_.emplace_back("format", "NCHW4_NCHW"); | |||||
break; | |||||
case GroupNorm::Format::NHWC_NCHW: | |||||
props_.emplace_back("format", "NHWC_NCHW"); | |||||
break; | |||||
case GroupNorm::Format::NHWC_NCHW4_IC_SMALL: | |||||
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); | |||||
break; | |||||
case GroupNorm::Format::NCHW_NCHW4_IC_SMALL: | |||||
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); | |||||
break; | |||||
case GroupNorm::Format::CHWN4: | |||||
props_.emplace_back("format", "CHWN4"); | |||||
break; | |||||
case GroupNorm::Format::NCHW64: | |||||
props_.emplace_back("format", "NCHW64"); | |||||
break; | |||||
case GroupNorm::Format::NCHW4_NHWC: | |||||
props_.emplace_back("format", "NCHW4_NHWC"); | |||||
break; | |||||
default: | |||||
props_.emplace_back("format", "INVALID"); | |||||
break; | |||||
} | |||||
return props_; | |||||
} | |||||
std::string GroupNorm_make_name_impl(const OpDef& def_) { | |||||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||||
static_cast<void>(op_); | |||||
return "GroupNorm"; | |||||
} | |||||
} // anonymous namespace | |||||
OP_TRAIT_REG(GroupNorm, GroupNorm) | |||||
.hash(GroupNorm_hash_impl) | |||||
.is_same_st(GroupNorm_is_same_st_impl) | |||||
.props(GroupNorm_props_impl) | |||||
.make_name(GroupNorm_make_name_impl); | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); | ||||
namespace { | namespace { | ||||
@@ -10075,6 +10075,158 @@ void _init_py_GroupLocal(py::module m) { | |||||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupLocal::typeinfo(), &py_type).second); | mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupLocal::typeinfo(), &py_type).second); | ||||
} | } | ||||
void _init_py_GroupNorm_Format(PyTypeObject& py_type) { | |||||
auto& e_type = EnumWrapper<GroupNorm::Format>::type; | |||||
Py_INCREF(e_type); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||||
} | |||||
PyOpDefBegin(GroupNorm) // { | |||||
static PyGetSetDef py_getsetters[]; | |||||
static PyMethodDef tp_methods[]; | |||||
static PyObject* getstate(PyObject* self, PyObject*) { | |||||
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst(); | |||||
static_cast<void>(opdef); | |||||
std::unordered_map<std::string, py::object> state { | |||||
{"affine", serialization<decltype(opdef.affine)>::dump(opdef.affine)}, | |||||
{"eps", serialization<decltype(opdef.eps)>::dump(opdef.eps)}, | |||||
{"group", serialization<decltype(opdef.group)>::dump(opdef.group)}, | |||||
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)} | |||||
}; | |||||
return py::cast(state).release().ptr(); | |||||
} | |||||
static PyObject* setstate(PyObject* self, PyObject* args) { | |||||
PyObject* dict = PyTuple_GetItem(args, 0); | |||||
if (!dict) return NULL; | |||||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||||
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst(); | |||||
static_cast<void>(opdef); | |||||
{ | |||||
auto&& iter = state.find("affine"); | |||||
if (iter != state.end()) { | |||||
opdef.affine = serialization<decltype(opdef.affine)>::load(iter->second); | |||||
} | |||||
} | |||||
{ | |||||
auto&& iter = state.find("eps"); | |||||
if (iter != state.end()) { | |||||
opdef.eps = serialization<decltype(opdef.eps)>::load(iter->second); | |||||
} | |||||
} | |||||
{ | |||||
auto&& iter = state.find("group"); | |||||
if (iter != state.end()) { | |||||
opdef.group = serialization<decltype(opdef.group)>::load(iter->second); | |||||
} | |||||
} | |||||
{ | |||||
auto&& iter = state.find("format"); | |||||
if (iter != state.end()) { | |||||
opdef.format = serialization<decltype(opdef.format)>::load(iter->second); | |||||
} | |||||
} | |||||
Py_RETURN_NONE; | |||||
} | |||||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||||
// }; | |||||
PyOpDefEnd(GroupNorm) | |||||
int PyOp(GroupNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||||
static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL}; | |||||
PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL; | |||||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &affine, &eps, &group, &format, &scope)) | |||||
return -1; | |||||
if (affine) { | |||||
try { | |||||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||||
py::detail::loader_life_support guard{}; | |||||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().affine = | |||||
py::cast<decltype(GroupNorm::affine)>(py::handle(affine)); | |||||
} CATCH_ALL(-1) | |||||
} | |||||
if (eps) { | |||||
try { | |||||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||||
py::detail::loader_life_support guard{}; | |||||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().eps = | |||||
py::cast<decltype(GroupNorm::eps)>(py::handle(eps)); | |||||
} CATCH_ALL(-1) | |||||
} | |||||
if (group) { | |||||
try { | |||||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||||
py::detail::loader_life_support guard{}; | |||||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().group = | |||||
py::cast<decltype(GroupNorm::group)>(py::handle(group)); | |||||
} CATCH_ALL(-1) | |||||
} | |||||
if (format) { | |||||
try { | |||||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||||
py::detail::loader_life_support guard{}; | |||||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().format = | |||||
py::cast<decltype(GroupNorm::format)>(py::handle(format)); | |||||
} CATCH_ALL(-1) | |||||
} | |||||
if (scope) { | |||||
try { | |||||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||||
->set_scope(py::cast<std::string>(py::handle(scope))); | |||||
} CATCH_ALL(-1) | |||||
} | |||||
return 0; | |||||
} | |||||
PyGetSetDef PyOp(GroupNorm)::py_getsetters[] = { | |||||
{const_cast<char*>("affine"), py_get_generic(GroupNorm, affine), py_set_generic(GroupNorm, affine), const_cast<char*>("affine"), NULL}, | |||||
{const_cast<char*>("eps"), py_get_generic(GroupNorm, eps), py_set_generic(GroupNorm, eps), const_cast<char*>("eps"), NULL}, | |||||
{const_cast<char*>("group"), py_get_generic(GroupNorm, group), py_set_generic(GroupNorm, group), const_cast<char*>("group"), NULL}, | |||||
{const_cast<char*>("format"), py_get_generic(GroupNorm, format), py_set_generic(GroupNorm, format), const_cast<char*>("format"), NULL}, | |||||
{NULL} /* Sentinel */ | |||||
}; | |||||
PyMethodDef PyOp(GroupNorm)::tp_methods[] = { | |||||
{const_cast<char*>("__getstate__"), PyOp(GroupNorm)::getstate, METH_NOARGS, "GroupNorm getstate"}, | |||||
{const_cast<char*>("__setstate__"), PyOp(GroupNorm)::setstate, METH_VARARGS, "GroupNorm setstate"}, | |||||
{NULL} /* Sentinel */ | |||||
}; | |||||
void _init_py_GroupNorm(py::module m) { | |||||
using py_op = PyOp(GroupNorm); | |||||
auto& py_type = PyOpType(GroupNorm); | |||||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
py_type.tp_name = "megengine.core._imperative_rt.ops.GroupNorm"; | |||||
py_type.tp_basicsize = sizeof(PyOp(GroupNorm)); | |||||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
py_type.tp_doc = "GroupNorm"; | |||||
py_type.tp_base = &PyOpType(OpDef); | |||||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
py_type.tp_new = py_new_generic<py_op>; | |||||
py_type.tp_init = py_op::py_init; | |||||
py_type.tp_methods = py_op::tp_methods; | |||||
py_type.tp_getset = py_op::py_getsetters; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
_init_py_GroupNorm_Format(py_type); | |||||
PyType_Modified(&py_type); | |||||
m.add_object("GroupNorm", reinterpret_cast<PyObject*>(&py_type)); | |||||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupNorm::typeinfo(), &py_type).second); | |||||
} | |||||
PyOpDefBegin(Identity) // { | PyOpDefBegin(Identity) // { | ||||
static PyGetSetDef py_getsetters[]; | static PyGetSetDef py_getsetters[]; | ||||
static PyMethodDef tp_methods[]; | static PyMethodDef tp_methods[]; | ||||
@@ -19237,6 +19389,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { | |||||
_init_py_GaussianRNG(m); \ | _init_py_GaussianRNG(m); \ | ||||
_init_py_GetVarShape(m); \ | _init_py_GetVarShape(m); \ | ||||
_init_py_GroupLocal(m); \ | _init_py_GroupLocal(m); \ | ||||
_init_py_GroupNorm(m); \ | |||||
_init_py_Identity(m); \ | _init_py_Identity(m); \ | ||||
_init_py_Images2Neibs(m); \ | _init_py_Images2Neibs(m); \ | ||||
_init_py_IncrMeshIndexing(m); \ | _init_py_IncrMeshIndexing(m); \ | ||||
@@ -988,6 +988,23 @@ public: | |||||
} | } | ||||
}; | }; | ||||
class GroupNorm : public OpDefImplBase<GroupNorm> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
public: | |||||
using Format = ::megdnn::param::GroupNorm::Format; | |||||
bool affine = true; | |||||
float eps = 1e-5f; | |||||
uint32_t group = 1; | |||||
Format format = ::megdnn::param::GroupNorm::Format::NCHW; | |||||
GroupNorm() = default; | |||||
GroupNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); } | |||||
GroupNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {} | |||||
::megdnn::param::GroupNorm param() const { | |||||
return {affine, eps, group, format}; | |||||
} | |||||
}; | |||||
class Identity : public OpDefImplBase<Identity> { | class Identity : public OpDefImplBase<Identity> { | ||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
@@ -1193,6 +1193,17 @@ GroupLocalInst | |||||
.def_readwrite("format", &GroupLocal::format) | .def_readwrite("format", &GroupLocal::format) | ||||
.def_readwrite("compute_mode", &GroupLocal::compute_mode); | .def_readwrite("compute_mode", &GroupLocal::compute_mode); | ||||
py::class_<GroupNorm, std::shared_ptr<GroupNorm>, OpDef> GroupNormInst(m, "GroupNorm"); | |||||
GroupNormInst.attr("Format") = AdaptivePoolingInst.attr("Format"); | |||||
GroupNormInst | |||||
.def(py::init<bool, float, uint32_t, ::megdnn::param::GroupNorm::Format, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {}) | |||||
.def_readwrite("affine", &GroupNorm::affine) | |||||
.def_readwrite("eps", &GroupNorm::eps) | |||||
.def_readwrite("group", &GroupNorm::group) | |||||
.def_readwrite("format", &GroupNorm::format); | |||||
py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity"); | py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity"); | ||||
IdentityInst | IdentityInst | ||||
@@ -490,6 +490,8 @@ def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||||
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | ||||
def GroupNorm: MgbHashableOp<"GroupNorm", [GroupNormParam]>; | |||||
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | ||||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | ||||
@@ -1,8 +1,10 @@ | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | #include "megbrain/opr/dnn/adaptive_pooling.h" | ||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/correlation.h" | #include "megbrain/opr/dnn/correlation.h" | ||||
#include "megbrain/opr/dnn/fake_quant.h" | #include "megbrain/opr/dnn/fake_quant.h" | ||||
#include "megbrain/opr/dnn/group_norm.h" | |||||
#include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
#include "megbrain/opr/dnn/layer_norm.h" | #include "megbrain/opr/dnn/layer_norm.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
@@ -15,6 +17,9 @@ | |||||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | #include "megbrain/opr/dnn/sliding_window_transpose.h" | ||||
#include "megbrain/opr/dnn/softmax.h" | #include "megbrain/opr/dnn/softmax.h" | ||||
#include "megbrain/opr/dnn/tqt.h" | #include "megbrain/opr/dnn/tqt.h" | ||||
#include "megbrain/opr/io.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||||
#include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
#include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
@@ -524,6 +529,213 @@ struct OprMaker<opr::LayerNormBackward, 0> { | |||||
} | } | ||||
}; | }; | ||||
template <> | |||||
struct OprMaker<opr::GroupNorm, 0> { | |||||
using Param = opr::GroupNorm::Param; | |||||
static cg::OperatorNodeBase* make( | |||||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
if (i.size() == 3) { | |||||
return opr::GroupNorm::make(i[0], i[1], i[2], param, config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} else { | |||||
mgb_assert(i.size() == 1); | |||||
return opr::GroupNorm::make(i[0], param, config)[0].node()->owner_opr(); | |||||
} | |||||
} | |||||
}; | |||||
template <> | |||||
struct OprLoadDumpImplV2<opr::GroupNorm, 0> { | |||||
using Opr = opr::GroupNorm; | |||||
using Param = opr::GroupNorm::Param; | |||||
using ElemwiseParam = opr::Elemwise::Param; | |||||
using ReduceParam = opr::Reduce::Param; | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||||
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param()); | |||||
} | |||||
static cg::OperatorNodeBase* replace_opr( | |||||
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||||
auto graph = inputs[0]->owner_graph(); | |||||
auto comp_node = inputs[0]->comp_node(); | |||||
auto opr_param = opr->cast_final_safe<opr::GroupNorm>().param(); | |||||
float eps = opr_param.eps; | |||||
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5)); | |||||
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps)); | |||||
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); | |||||
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); | |||||
auto origin_shape = opr::GetVarShape::make(inputs[0]).node(); | |||||
TensorShape input_shape = | |||||
inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); | |||||
size_t N = input_shape[0]; | |||||
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; | |||||
int group = opr_param.group; | |||||
int size = inner_size / group; | |||||
HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32()); | |||||
auto* ptr = hv.ptr<dt_int32>(); | |||||
ptr[0] = N; | |||||
ptr[1] = group; | |||||
ptr[2] = size; | |||||
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); | |||||
auto inp = opr::Reshape::make(inputs[0], target_shape); | |||||
auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2}); | |||||
auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL}); | |||||
auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2}); | |||||
auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL}); | |||||
auto var = | |||||
opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB}); | |||||
auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD}); | |||||
auto sqrt = | |||||
opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW}); | |||||
auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB}); | |||||
auto temp_inp = | |||||
opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); | |||||
auto res = opr::Reshape::make(temp_inp, origin_shape); | |||||
if (inputs.size() == 3) { | |||||
// apply the optional affine transform: res * weight + bias | |||||
auto mul_temp = | |||||
opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL}); | |||||
auto affine_res = opr::Elemwise::make( | |||||
{mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD}); | |||||
return affine_res.node()->owner_opr(); | |||||
} else { | |||||
return res.node()->owner_opr(); | |||||
} | |||||
} | |||||
static cg::OperatorNodeBase* load( | |||||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) { | |||||
return OprMaker<opr::GroupNorm, 0>::make( | |||||
ctx.read_param<Param>(), inputs, ctx.graph(), config); | |||||
} | |||||
}; | |||||
// OprMaker in MGB_SEREG_OPR only support unique output opr | |||||
template <> | |||||
struct OprMaker<opr::GroupNormBackward, 0> { | |||||
using Param = opr::GroupNormBackward::Param; | |||||
static cg::OperatorNodeBase* make( | |||||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
if (i.size() == 5) { | |||||
return opr::GroupNormBackward::make( | |||||
i[0], i[1], i[2], i[3], i[4], param, config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} else { | |||||
mgb_assert(i.size() == 4); | |||||
return opr::GroupNormBackward::make( | |||||
i[0], i[1], i[2], i[3], param, config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} | |||||
} | |||||
}; | |||||
template <> | |||||
struct OprLoadDumpImplV2<opr::GroupNormBackward, 0> { | |||||
using Opr = opr::GroupNormBackward; | |||||
using Param = opr::GroupNormBackward::Param; | |||||
using ElemwiseParam = opr::Elemwise::Param; | |||||
using ReduceParam = opr::Reduce::Param; | |||||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||||
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param()); | |||||
} | |||||
static cg::OperatorNodeBase* replace_opr( | |||||
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||||
auto rstd = inputs[4]; | |||||
auto graph = inputs[1]->owner_graph(); | |||||
auto comp_node = inputs[1]->comp_node(); | |||||
auto opr_param = opr->cast_final_safe<opr::GroupNormBackward>().param(); | |||||
float eps = opr_param.eps; | |||||
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5)); | |||||
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps)); | |||||
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); | |||||
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); | |||||
auto const_node = | |||||
opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node}); | |||||
TensorShape input_shape = | |||||
inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); | |||||
auto origin_shape = opr::GetVarShape::make(inputs[1]).node(); | |||||
size_t N = input_shape[0]; | |||||
size_t C = input_shape[1]; | |||||
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; | |||||
int group = opr_param.group; | |||||
int size = inner_size / group; | |||||
HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32()); | |||||
auto* ptr = hv.ptr<dt_int32>(); | |||||
ptr[0] = N; | |||||
ptr[1] = group; | |||||
ptr[2] = size; | |||||
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); | |||||
auto inp = opr::Reshape::make(inputs[1], target_shape); | |||||
auto temp_rstd = | |||||
opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD}); | |||||
auto sqrt = | |||||
opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW}); | |||||
auto slice_std = opr::Elemwise::make( | |||||
{const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); | |||||
auto sub_mean = | |||||
opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB}); | |||||
auto x_hat = | |||||
opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL}); | |||||
x_hat = opr::Reshape::make(x_hat, origin_shape); | |||||
auto size_node = | |||||
opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node}); | |||||
auto temp1 = opr::Elemwise::make( | |||||
{slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV}); | |||||
auto dx_hat = | |||||
opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL}); | |||||
HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32()); | |||||
auto* ptr2 = tshape.ptr<dt_int32>(); | |||||
ptr2[0] = N; | |||||
ptr2[1] = group; | |||||
ptr2[2] = C / group; | |||||
ptr2[3] = input_shape[2]; | |||||
ptr2[4] = input_shape[3]; | |||||
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); | |||||
x_hat = opr::Reshape::make(x_hat, target_shape); | |||||
dx_hat = opr::Reshape::make(dx_hat, target_shape); | |||||
auto temp2 = | |||||
opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL}); | |||||
ptr2[2] = 1; | |||||
ptr2[3] = 1; | |||||
ptr2[4] = 1; | |||||
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); | |||||
auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape); | |||||
auto sum_dx_hat = | |||||
opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape); | |||||
auto temp4 = | |||||
opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL}); | |||||
auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB}); | |||||
auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB}); | |||||
auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL}); | |||||
auto dx = opr::Reshape::make(dx_temp, origin_shape); | |||||
return dx.node()->owner_opr(); | |||||
} | |||||
static cg::OperatorNodeBase* load( | |||||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) { | |||||
return OprMaker<opr::GroupNormBackward, 0>::make( | |||||
ctx.read_param<Param>(), inputs, ctx.graph(), config); | |||||
} | |||||
}; | |||||
template <class MegDNNConv = megdnn::LocalShare> | template <class MegDNNConv = megdnn::LocalShare> | ||||
struct MakeLocalShareCaller2 { | struct MakeLocalShareCaller2 { | ||||
template <typename Opr> | template <typename Opr> | ||||
@@ -747,6 +959,8 @@ MGB_SEREG_OPR(LSQ, 4); | |||||
MGB_SEREG_OPR(LSQBackward, 5); | MGB_SEREG_OPR(LSQBackward, 5); | ||||
MGB_SEREG_OPR(LayerNorm, 0); | MGB_SEREG_OPR(LayerNorm, 0); | ||||
MGB_SEREG_OPR(LayerNormBackward, 0); | MGB_SEREG_OPR(LayerNormBackward, 0); | ||||
MGB_SEREG_OPR(GroupNorm, 0); | |||||
MGB_SEREG_OPR(GroupNormBackward, 0); | |||||
MGB_SEREG_OPR(RNNCellForward, 6); | MGB_SEREG_OPR(RNNCellForward, 6); | ||||
MGB_SEREG_OPR(LSTMCellForward, 7); | MGB_SEREG_OPR(LSTMCellForward, 7); | ||||
MGB_SEREG_OPR(RNNForward, 3); | MGB_SEREG_OPR(RNNForward, 3); | ||||
@@ -755,6 +969,14 @@ MGB_SEREG_OPR(LSTMForward, 4); | |||||
MGB_SEREG_OPR(LSTMBackward, 9); | MGB_SEREG_OPR(LSTMBackward, 9); | ||||
MGB_SEREG_OPR(Softmax, 1); | MGB_SEREG_OPR(Softmax, 1); | ||||
MGB_SEREG_OPR(SoftmaxBackward, 2); | MGB_SEREG_OPR(SoftmaxBackward, 2); | ||||
MGB_SEREG_OPR_V2( | |||||
GroupNorm, 0, | |||||
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNorm, 0>::replace_opr), | |||||
VERSION_2, CURRENT_VERSION); | |||||
MGB_SEREG_OPR_V2( | |||||
GroupNormBackward, 0, | |||||
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNormBackward, 0>::replace_opr), | |||||
VERSION_2, CURRENT_VERSION); | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
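The V2 load path above (replace_opr) rewrites GroupNorm into Reshape/Reduce/Elemwise primitives so that older loaders without the fused opr can still execute the graph. The sketch below is a NumPy paraphrase of that forward decomposition (per-group statistics, normalize, reshape back, optional affine); it mirrors the math only and is not the actual graph rewrite.

import numpy as np

def group_norm_ref(x, num_groups, eps=1e-5, weight=None, bias=None):
    # reshape to (N, group, -1) and compute per-group statistics
    N, C, H, W = x.shape
    g = x.reshape(N, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = (g * g).mean(axis=2, keepdims=True) - mean * mean  # E[x^2] - E[x]^2
    g = (g - mean) / np.sqrt(var + eps)
    y = g.reshape(N, C, H, W)
    if weight is not None:  # optional per-channel affine, as in the 3-input form
        y = weight.reshape(1, -1, 1, 1) * y + bias.reshape(1, -1, 1, 1)
    return y

y = group_norm_ref(np.random.randn(2, 6, 4, 4).astype("float32"), num_groups=3)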
@@ -0,0 +1,239 @@ | |||||
#include "megbrain/opr/dnn/group_norm.h" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
/* ==================== GroupNormForward ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormForward); | |||||
GroupNormForward::GroupNormForward( | |||||
VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super{data->owner_graph(), config, "group_norm", {data, weight, bias}} { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({data, weight, bias}); | |||||
output(0)->dtype(data->dtype()); | |||||
output(1)->dtype(dtype::Float32()); | |||||
output(2)->dtype(dtype::Float32()); | |||||
} | |||||
GroupNormForward::GroupNormForward( | |||||
VarNode* data, const Param& param, const OperatorNodeConfig& config) | |||||
: Super{data->owner_graph(), config, "group_norm", {data}} { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({data}); | |||||
output(0)->dtype(data->dtype()); | |||||
output(1)->dtype(dtype::Float32()); | |||||
output(2)->dtype(dtype::Float32()); | |||||
} | |||||
SymbolVarArray GroupNormForward::make( | |||||
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
auto outs = data.node() | |||||
->owner_graph() | |||||
->insert_opr(std::make_unique<GroupNormForward>( | |||||
data.node(), weight.node(), bias.node(), param, config)) | |||||
->output(); | |||||
SymbolVarArray ret; | |||||
for (auto&& out : outs) { | |||||
ret.emplace_back(out); | |||||
} | |||||
return ret; | |||||
} | |||||
SymbolVarArray GroupNormForward::make( | |||||
SymbolVar data, const Param& param, const OperatorNodeConfig& config) { | |||||
auto outs = data.node() | |||||
->owner_graph() | |||||
->insert_opr(std::make_unique<GroupNormForward>( | |||||
data.node(), param, config)) | |||||
->output(); | |||||
SymbolVarArray ret; | |||||
for (auto&& out : outs) { | |||||
ret.emplace_back(out); | |||||
} | |||||
return ret; | |||||
} | |||||
void GroupNormForward::get_output_var_shape( | |||||
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { | |||||
size_t group = param().group; | |||||
out_shape[0] = inp_shape[0]; | |||||
size_t N = inp_shape[0].shape[0]; | |||||
TensorShape unnormalized_shape{N, group}; | |||||
out_shape[1] = unnormalized_shape; | |||||
out_shape[2] = unnormalized_shape; | |||||
} | |||||
size_t GroupNormForward::get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const { | |||||
return intl::MegDNNOprMethInvoker<megdnn::GroupNormForward>::get_workspace_in_bytes( | |||||
megdnn_opr(), this, input_shapes, output_shapes); | |||||
} | |||||
void GroupNormForward::scn_do_execute() { | |||||
if (param().affine) { | |||||
megdnn_opr()->exec( | |||||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
output(1)->dev_tensor().as_megdnn(), | |||||
output(2)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output().back())); | |||||
} else { | |||||
megdnn_opr()->exec( | |||||
input(0)->dev_tensor().as_megdnn(), {}, {}, | |||||
output(0)->dev_tensor().as_megdnn(), | |||||
output(1)->dev_tensor().as_megdnn(), | |||||
output(2)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output().back())); | |||||
} | |||||
} | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(GroupNormForward) { | |||||
auto p = opr.param(); | |||||
SymbolVarArray grad; | |||||
VarNodeArray ret; | |||||
if (p.affine) { | |||||
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx); | |||||
grad = GroupNormBackward::make( | |||||
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2), | |||||
opr.param()); | |||||
} else { | |||||
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx); | |||||
grad = GroupNormBackward::make( | |||||
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param()); | |||||
} | |||||
uint32_t nr_ret = p.affine ? 3 : 1; | |||||
for (uint32_t i = 0; i < nr_ret; ++i) { | |||||
ret.push_back(grad[i].node()); | |||||
} | |||||
return ret; | |||||
} | |||||
#endif | |||||
/* ==================== GroupNormBackward ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormBackward); | |||||
GroupNormBackward::GroupNormBackward( | |||||
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||||
const Param& param, const OperatorNodeConfig& config) | |||||
: Super({diff->owner_graph(), | |||||
config, | |||||
"group_norm_backward", | |||||
{diff, data, weight, mean, rstd}}, | |||||
0, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({diff, data, weight, mean, rstd}); | |||||
} | |||||
GroupNormBackward::GroupNormBackward( | |||||
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super({diff->owner_graph(), | |||||
config, | |||||
"group_norm_backward", | |||||
{diff, data, mean, rstd}}, | |||||
0, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({diff, data, mean, rstd}); | |||||
auto mark_empty_var = [&](VarNode* var) { | |||||
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||||
.add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||||
}; | |||||
mark_empty_var(output(1)); | |||||
mark_empty_var(output(2)); | |||||
} | |||||
SymbolVarArray GroupNormBackward::make( | |||||
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||||
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) { | |||||
auto outs = diff.node() | |||||
->owner_graph() | |||||
->insert_opr(std::make_unique<GroupNormBackward>( | |||||
diff.node(), data.node(), weight.node(), mean.node(), | |||||
rstd.node(), param, config)) | |||||
->output(); | |||||
SymbolVarArray ret; | |||||
for (auto&& out : outs) { | |||||
ret.emplace_back(out); | |||||
} | |||||
return ret; | |||||
} | |||||
SymbolVarArray GroupNormBackward::make( | |||||
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||||
const Param& param, const OperatorNodeConfig& config) { | |||||
auto outs = diff.node() | |||||
->owner_graph() | |||||
->insert_opr(std::make_unique<GroupNormBackward>( | |||||
diff.node(), data.node(), mean.node(), rstd.node(), | |||||
param, config)) | |||||
->output(); | |||||
SymbolVarArray ret; | |||||
for (auto&& out : outs) { | |||||
ret.emplace_back(out); | |||||
} | |||||
return ret; | |||||
} | |||||
void GroupNormBackward::init_output_static_infer_desc() { | |||||
using namespace cg::static_infer; | |||||
auto&& mgr = owner_graph()->static_infer_manager(); | |||||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); | |||||
if (param().affine) { | |||||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2))); | |||||
} else { | |||||
TensorShape empty; | |||||
empty.ndim = 0; | |||||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty)); | |||||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty)); | |||||
} | |||||
this->init_output_static_infer_desc_workspace( | |||||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::GroupNormBackward>::val); | |||||
} | |||||
void GroupNormBackward::init_output_dtype() { | |||||
output(0)->dtype(input(1)->dtype()); | |||||
output(1)->dtype(input(2)->dtype()); | |||||
output(2)->dtype(input(2)->dtype()); | |||||
} | |||||
size_t GroupNormBackward::get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const { | |||||
return intl::MegDNNOprMethInvoker<megdnn::GroupNormBackward>:: | |||||
get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes); | |||||
} | |||||
void GroupNormBackward::scn_do_execute() { | |||||
if (param().affine) { | |||||
megdnn_opr()->exec( | |||||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), | |||||
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
output(1)->dev_tensor().as_megdnn(), | |||||
output(2)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output(3))); | |||||
} else { | |||||
megdnn_opr()->exec( | |||||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||||
{}, input(2)->dev_tensor().as_megdnn(), | |||||
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||||
{}, {}, intl::get_megdnn_workspace_from_var(output(3))); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,67 @@ | |||||
#pragma once | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megdnn/oprs.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
GroupNormForward, intl::MegDNNOprWrapperFwd<megdnn::GroupNormForward>) // { | |||||
public: | |||||
MGE_WIN_DECLSPEC_FUC GroupNormForward( | |||||
VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC GroupNormForward( | |||||
VarNode* data, const Param& param, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
SymbolVar data, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void get_output_var_shape( | |||||
const TensorShapeArray& inp_shape, | |||||
TensorShapeArray& out_shape) const override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
void scn_do_execute() override; | |||||
}; | |||||
using GroupNorm = GroupNormForward; | |||||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
GroupNormBackward, intl::MegDNNOprWrapperBwd<megdnn::GroupNormBackward>) // { | |||||
public: | |||||
MGE_WIN_DECLSPEC_FUC GroupNormBackward( | |||||
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC GroupNormBackward( | |||||
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||||
SymbolVar rstd, const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||||
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void init_output_static_infer_desc() override; | |||||
void init_output_dtype() override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
void scn_do_execute() override; | |||||
}; | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,90 @@ | |||||
#include "megbrain/opr/dnn/group_norm.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/test/autocheck.h" | |||||
#include "megbrain/test/helper.h" | |||||
#include "megbrain/test/megdnn_helper.h" | |||||
#include "megdnn/oprs.h" | |||||
#include <cmath> | |||||
#include <iomanip> | |||||
#include <random> | |||||
#include <sstream> | |||||
using namespace mgb; | |||||
namespace { | |||||
using Param = opr::GroupNormForward::Param; | |||||
void run_forward(bool is_affine) { | |||||
using Checker = AutoOprChecker<3, 3>; | |||||
Param param; | |||||
param.eps = 1e-5; | |||||
param.affine = is_affine; | |||||
param.group = 3; | |||||
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
auto out = opr::GroupNormForward::make(inputs[0], inputs[1], inputs[2], param); | |||||
return {out[0], out[1], out[2]}; | |||||
}; | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
auto opr = | |||||
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||||
->create_operator<megdnn::GroupNormForward>(); | |||||
auto inp_shape = inp[0]->shape(); | |||||
auto n_slices = inp_shape[0]; | |||||
opr->param() = param; | |||||
dest[0].dtype(dtype::Float32()) | |||||
.comp_node(inp[0]->comp_node()) | |||||
.resize(inp_shape); | |||||
dest[1].dtype(dtype::Float32()) | |||||
.comp_node(inp[0]->comp_node()) | |||||
.resize({n_slices, param.group}); | |||||
dest[2].dtype(dtype::Float32()) | |||||
.comp_node(inp[0]->comp_node()) | |||||
.resize({n_slices, param.group}); | |||||
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes( | |||||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(), | |||||
dest[1].layout(), dest[2].layout())); | |||||
opr->exec( | |||||
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||||
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), | |||||
{workspace.data(), workspace.size()}); | |||||
}; | |||||
auto gen = [&](HostTensorND& src) { | |||||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f); | |||||
src = *src_gen(src.shape(), src.comp_node()); | |||||
}; | |||||
Checker::RunOptions option; | |||||
option.numdiff_max_err = 1e-4; | |||||
Checker checker{make_graph, fwd}; | |||||
checker.set_input_generator(0, gen); | |||||
checker.set_input_generator(1, gen); | |||||
checker.set_input_generator(2, gen); | |||||
checker.set_input_allow_grad(0, false); | |||||
checker.set_input_allow_grad(1, false); | |||||
checker.set_input_allow_grad(2, false); | |||||
checker.set_output_allow_grad(0, false); | |||||
checker.set_output_allow_grad(1, false); | |||||
checker.set_output_allow_grad(2, false); | |||||
checker.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) | |||||
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) | |||||
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option); | |||||
} | |||||
TEST(TestOprDNN, GroupNormForward) { | |||||
REQUIRE_GPU(1); | |||||
run_forward(true); | |||||
} | |||||
} // anonymous namespace | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -123,6 +123,7 @@ union OperatorParam { | |||||
param.LSTM = 89, | param.LSTM = 89, | ||||
param.Softmax = 90, | param.Softmax = 90, | ||||
param.Diag = 91, | param.Diag = 91, | ||||
param.GroupNorm = 92, | |||||
} | } | ||||
table Operator { | table Operator { | ||||
@@ -140,6 +140,7 @@ union OperatorParam { | |||||
param.LSTM = 89, | param.LSTM = 89, | ||||
param.Softmax = 90, | param.Softmax = 90, | ||||
param.Diag = 91, | param.Diag = 91, | ||||
param.GroupNorm = 92, | |||||
} | } | ||||
table Operator { | table Operator { | ||||