@@ -1869,6 +1869,13 @@ table LayerNorm { | |||
normalized_size:ulong = 1; | |||
} | |||
table GroupNorm { | |||
affine:bool = true; | |||
eps:float = 1e-5; | |||
group:uint = 1; | |||
format:ConvolutionFormat = NCHW; | |||
} | |||
table Dropout { | |||
drop_prob:float = 0; | |||
seed:ulong = 0; | |||
@@ -140,6 +140,7 @@ union OperatorParam { | |||
param.LSTM = 89, | |||
param.Softmax = 90, | |||
param.Diag = 91, | |||
param.GroupNorm = 92, | |||
} | |||
table Operator { | |||
@@ -2430,6 +2430,76 @@ protected: | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, | |||
size_t workspace_in_bytes); | |||
}; | |||
class GroupNormBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(GroupNormBase, OperatorBase); | |||
DEF_OPR_PARAM(GroupNorm); | |||
protected: | |||
void deduce_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||
TensorLayout& rstd); | |||
void check_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
const TensorLayout& rstd); | |||
}; | |||
class GroupNormForward : public GroupNormBase { | |||
DEF_OPR_IMPL(GroupNormForward, GroupNormBase, 3, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
_megdnn_workspace workspace) = 0; | |||
MGE_WIN_DECLSPEC_FUC void deduce_layout( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, | |||
TensorLayout& rstd); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
const TensorLayout& rstd) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, | |||
const TensorLayout& rstd, size_t workspace_in_bytes); | |||
}; | |||
using GroupNorm = GroupNormForward; | |||
class GroupNormBackward : public GroupNormBase { | |||
DEF_OPR_IMPL(GroupNormBackward, GroupNormBase, 5, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& diff, const TensorLayout& data, | |||
const TensorLayout& weight, const TensorLayout& mean, | |||
const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, | |||
TensorLayout& dbias); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& diff, const TensorLayout& data, | |||
const TensorLayout& weight, const TensorLayout& mean, | |||
const TensorLayout& rstd, const TensorLayout& ddata, | |||
const TensorLayout& dweight, const TensorLayout& dbias) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& diff, const TensorLayout& data, | |||
const TensorLayout& weight, const TensorLayout& mean, | |||
const TensorLayout& rstd, const TensorLayout& ddata, | |||
const TensorLayout& dweight, const TensorLayout& dbias, | |||
size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -1247,6 +1247,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
.add_fields('uint64', 'normalized_size', '1') | |||
) | |||
(pdef('GroupNorm') | |||
.add_fields('bool', 'affine', 'true') | |||
.add_fields('float32', 'eps', '1e-5f') | |||
.add_fields('uint32', 'group', '1') | |||
.add_enum_alias('Format', 'Convolution') | |||
) | |||
(pdef('Dropout') | |||
.add_fields('float32', 'drop_prob', '0') | |||
.add_fields('uint64', 'seed', '0') | |||
@@ -0,0 +1,121 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
using Param = GroupNormBase::Param; | |||
void GroupNormBase::deduce_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
MEGDNN_MARK_USED_VAR(weight); | |||
MEGDNN_MARK_USED_VAR(bias); | |||
size_t N = data.shape[0]; | |||
size_t group = param().group; | |||
TensorLayout unnormalized_layout({N, group}, dtype::Float32()); | |||
dst = data; | |||
mean = unnormalized_layout; | |||
rstd = unnormalized_layout; | |||
} | |||
void GroupNormBase::check_layout_fwd( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||
megdnn_assert_contiguous(data); | |||
megdnn_assert_contiguous(weight); | |||
megdnn_assert_contiguous(bias); | |||
megdnn_assert_contiguous(dst); | |||
megdnn_assert_contiguous(mean); | |||
megdnn_assert_contiguous(rstd); | |||
auto errmsg = [&]() { | |||
return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + | |||
megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + | |||
megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert(data.eq_layout(dst), "%s", errmsg().c_str()); | |||
megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str()); | |||
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); | |||
auto p = param(); | |||
size_t C = data.shape[1]; | |||
size_t group = p.group; | |||
megdnn_assert( | |||
group > 0, "Expected num groups to be greater than 0, got %zu", group); | |||
megdnn_assert( | |||
C % group == 0, | |||
"Expected number of channels in input to be divisible by num_groups, but " | |||
"got Channel of shape %zu and num_groups= %zu", | |||
C, group); | |||
} | |||
void GroupNormForward::deduce_layout( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
deduce_layout_fwd(data, weight, bias, dst, mean, rstd); | |||
} | |||
void GroupNormForward::check_exec( | |||
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, | |||
size_t workspace_in_bytes) { | |||
check_layout_fwd(data, weight, bias, dst, mean, rstd); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void GroupNormBackward::deduce_layout( | |||
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, | |||
TensorLayout& dweight, TensorLayout& dbias) { | |||
MEGDNN_MARK_USED_VAR(diff); | |||
MEGDNN_MARK_USED_VAR(mean); | |||
MEGDNN_MARK_USED_VAR(rstd); | |||
ddata = data; | |||
dweight = weight; | |||
dbias = weight; | |||
} | |||
void GroupNormBackward::check_exec( | |||
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, | |||
const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, | |||
const TensorLayout& dweight, const TensorLayout& dbias, | |||
size_t workspace_in_bytes) { | |||
auto p = param(); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
diff, data, weight, mean, rstd, ddata, dweight, dbias); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
megdnn_assert_contiguous(diff); | |||
megdnn_assert_contiguous(data); | |||
megdnn_assert_contiguous(mean); | |||
megdnn_assert_contiguous(rstd); | |||
megdnn_assert_contiguous(ddata); | |||
if (p.affine) { | |||
megdnn_assert_contiguous(weight); | |||
megdnn_assert_contiguous(dweight); | |||
megdnn_assert_contiguous(dbias); | |||
} | |||
auto errmsg = [&]() { | |||
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + | |||
megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + | |||
megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + | |||
megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str()); | |||
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); | |||
if (p.affine) { | |||
megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str()); | |||
megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str()); | |||
} | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -216,7 +216,9 @@ private: | |||
cb(NormForward) \ | |||
cb(RegionRestrictedConvolutionForward) \ | |||
cb(RegionRestrictedConvolutionBackwardData) \ | |||
cb(RegionRestrictedConvolutionBackwardFilter) | |||
cb(RegionRestrictedConvolutionBackwardFilter) \ | |||
cb(GroupNormForward) \ | |||
cb(GroupNormBackward) | |||
// clang-format on | |||
/*! | |||
@@ -142,6 +142,8 @@ DEF(SoftmaxBackward, 3, true, false); | |||
DEF(RegionRestrictedConvolutionForward, 5, true, true); | |||
DEF(RegionRestrictedConvolutionBackwardData, 5, true, false); | |||
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); | |||
DEF(GroupNormForward, 6, true, true); | |||
DEF(GroupNormBackward, 8, true, true); | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,529 @@ | |||
#include <stdio.h> | |||
#include <thrust/pair.h> | |||
#include <thrust/tuple.h> | |||
#include <cfloat> | |||
#include "megdnn/arch.h" | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/cuda_shfl_compat.cuh" | |||
#include "src/cuda/group_norm/group_norm_cuda.cuh" | |||
#include "src/cuda/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace group_norm { | |||
// warp size may be used as array length, or used in host function, | |||
// so we define WARP_SIZE rather than using warpSize | |||
#define WARP_SIZE 32 | |||
template <size_t kStart, size_t kEnd, bool kStop> | |||
struct Compare { | |||
template <typename T> | |||
__host__ __device__ inline static bool Run(const T* d1, const T* d2) { | |||
return d1[kStart] == d2[kStart] && | |||
Compare<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2); | |||
} | |||
}; | |||
template <size_t kStart, size_t kEnd> | |||
struct Compare<kStart, kEnd, true> { | |||
template <typename T> | |||
__host__ __device__ inline constexpr static bool Run(const T* d1, const T* d2) { | |||
return true; | |||
} | |||
}; | |||
template <size_t N> | |||
using UnrollCompare = Compare<0, N, N == 0>; | |||
template <typename T, size_t kStart, size_t kEnd, bool kStop> | |||
struct UnrollVarArgsAssignImpl { | |||
template <typename... Args> | |||
__host__ __device__ inline static void Run(T* d, T val, Args... args) { | |||
static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument"); | |||
d[kStart] = val; | |||
UnrollVarArgsAssignImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd>::Run( | |||
d, args...); | |||
} | |||
}; | |||
template <typename T, size_t kStart, size_t kEnd> | |||
struct UnrollVarArgsAssignImpl<T, kStart, kEnd, true> { | |||
__host__ __device__ inline static void Run(T* d) {} | |||
}; | |||
template <typename T> | |||
struct UnrollVarArgsAssign { | |||
template <typename... Args> | |||
__host__ __device__ inline static void Run(T* d, Args... args) { | |||
UnrollVarArgsAssignImpl<T, 0, sizeof...(Args), sizeof...(Args) == 0>::Run( | |||
d, args...); | |||
} | |||
}; | |||
template <typename T, size_t N> | |||
class Array { | |||
public: | |||
static constexpr size_t kSize = N; | |||
__host__ __device__ inline Array() {} | |||
template <typename... Args> | |||
__host__ __device__ inline explicit Array(const T& val, Args... args) { | |||
static_assert(N == sizeof...(Args) + 1, "Invalid argument"); | |||
UnrollVarArgsAssign<T>::Run(data_, val, args...); | |||
} | |||
__host__ __device__ inline T& operator[](size_t i) { return *(data_ + i); } | |||
__host__ __device__ inline const T& operator[](size_t i) const { | |||
return *(data_ + i); | |||
} | |||
private: | |||
template <typename U> | |||
__host__ __device__ static inline U* advance(U* ptr, size_t i) { | |||
return ptr + i; | |||
} | |||
T data_[N]; | |||
}; | |||
// ================================ group_norm forward =========================== | |||
// implementation of groupnorm_forward from | |||
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_kernel.cu#L115 | |||
template <typename T> | |||
__forceinline__ __device__ T | |||
CudaShuffleDownSync(T val, int delta, int width = warpSize) { | |||
return __shfl_down(val, static_cast<unsigned>(delta), width); | |||
} | |||
template <> | |||
__forceinline__ __device__ dt_float16 | |||
CudaShuffleDownSync(dt_float16 val, int delta, int width) { | |||
return dt_float16(__shfl_down(val, static_cast<unsigned>(delta), width)); | |||
} | |||
template <> | |||
__forceinline__ __device__ dt_bfloat16 | |||
CudaShuffleDownSync(dt_bfloat16 val, int delta, int width) { | |||
return dt_bfloat16(__shfl_down(val, static_cast<unsigned>(delta), width)); | |||
} | |||
template <typename T, int VecSize> | |||
struct alignas(sizeof(T) * VecSize) VectorType { | |||
T val[VecSize]; | |||
}; | |||
template <typename T> | |||
struct AddFunctor { | |||
inline T initial() { return static_cast<T>(0.0f); } | |||
__device__ __forceinline__ T operator()(const T a, const T b) const { | |||
return b + a; | |||
} | |||
}; | |||
template <typename T, typename ReduceOp> | |||
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) { | |||
for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) { | |||
T temp = CudaShuffleDownSync<T>(val, stride); | |||
val = reducer(val, temp); | |||
} | |||
return val; | |||
} | |||
template <typename T, typename ReduceOp> | |||
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { | |||
__syncthreads(); | |||
__shared__ T shared[64]; | |||
int block_dim_x = blockDim.x; | |||
if (blockDim.x > WARP_SIZE) { | |||
block_dim_x = blockDim.x / WARP_SIZE; | |||
int lane = threadIdx.x % WARP_SIZE; | |||
int tid = threadIdx.y * blockDim.x + threadIdx.x; | |||
int wid = tid / WARP_SIZE; | |||
int bid = threadIdx.y; | |||
val = WarpReduce<T, ReduceOp>(val, reducer); | |||
if (lane == 0) { | |||
shared[wid] = val; | |||
} | |||
__syncthreads(); | |||
val = shared[bid * block_dim_x + lane]; | |||
} | |||
for (int stride = 1; stride < block_dim_x; stride <<= 1) { | |||
T temp = CudaShuffleDownSync(val, stride); | |||
val = reducer(val, temp); | |||
} | |||
if (threadIdx.x == 0) { | |||
shared[threadIdx.y] = val; | |||
} | |||
__syncthreads(); | |||
return shared[threadIdx.y]; | |||
} | |||
template <typename T> | |||
__device__ __forceinline__ void ReduceMeanAndVar( | |||
T* mean, T* var, T x_mean, T x_var, int size) { | |||
const int nc = blockIdx.x; | |||
x_mean = BlockXReduce<T, AddFunctor<T>>(x_mean, AddFunctor<T>()); | |||
x_var = BlockXReduce<T, AddFunctor<T>>(x_var, AddFunctor<T>()); | |||
__syncthreads(); | |||
if (threadIdx.x == 0) { | |||
mean[nc] = static_cast<T>(x_mean / size); | |||
var[nc] = static_cast<T>(x_var / size); | |||
} | |||
} | |||
template <typename T, typename T_ACC, int VecSize, int Num> | |||
__device__ __forceinline__ void ThreadReduce( | |||
Array<const T*, Num> arrs, int size, const int offset, T_ACC* out_mean, | |||
T_ACC* out_var) { | |||
const T* x = arrs[0]; | |||
const T* y; | |||
if (Num == 2) { | |||
y = arrs[1]; | |||
} | |||
using VecT = VectorType<T, VecSize>; | |||
int tid = threadIdx.x; | |||
if (offset > 0) { | |||
x -= offset; | |||
if (Num == 2) { | |||
y -= offset; | |||
} | |||
size += offset; | |||
if (tid >= offset) { | |||
if (Num == 1) { | |||
*out_mean += x[tid]; | |||
*out_var += x[tid] * x[tid]; | |||
} else if (Num == 2) { | |||
*out_mean += y[tid]; | |||
*out_var += y[tid] * x[tid]; | |||
} | |||
} | |||
size -= blockDim.x; | |||
x += blockDim.x; | |||
if (Num == 2) { | |||
y += blockDim.x; | |||
} | |||
} | |||
int remain = size % (VecSize * blockDim.x); | |||
T ins_x[VecSize]; | |||
T ins_y[VecSize]; | |||
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x); | |||
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y); | |||
// vector part | |||
for (; VecSize * tid < (size - remain); tid += blockDim.x) { | |||
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid]; | |||
if (Num == 2) { | |||
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid]; | |||
} | |||
#pragma unroll | |||
for (int i = 0; i < VecSize; ++i) { | |||
if (Num == 1) { | |||
*out_mean += ins_x[i]; | |||
*out_var += ins_x[i] * ins_x[i]; | |||
} else if (Num == 2) { | |||
*out_mean += ins_y[i]; | |||
*out_var += ins_y[i] * ins_x[i]; | |||
} | |||
} | |||
} | |||
// scalar part | |||
tid = size - remain + threadIdx.x; | |||
for (; tid < size; tid += blockDim.x) { | |||
if (Num == 1) { | |||
*out_mean += x[tid]; | |||
*out_var += x[tid] * x[tid]; | |||
} else if (Num == 2) { | |||
*out_mean += y[tid]; | |||
*out_var += y[tid] * x[tid]; | |||
} | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
__global__ void ScalarGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { | |||
int i = blockIdx.x; | |||
T_ACC x_mean = static_cast<T_ACC>(0); | |||
T_ACC x_var = static_cast<T_ACC>(0); | |||
for (int j = threadIdx.x; j < size; j += blockDim.x) { | |||
T val; | |||
val = x[i * size + j]; | |||
x_mean += val; | |||
x_var += val * val; | |||
} | |||
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size); | |||
} | |||
template <typename T, typename T_ACC, int VecSize> | |||
__global__ void VectorizedGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { | |||
int i = blockIdx.x; | |||
T_ACC x_mean = static_cast<T_ACC>(0); | |||
T_ACC x_var = static_cast<T_ACC>(0); | |||
x += i * size; | |||
const int input_offset = ((uint64_t)x) % 16 / sizeof(T); | |||
Array<const T*, 1> ins; | |||
ins[0] = x; | |||
ThreadReduce<T, T_ACC, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var); | |||
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size); | |||
} | |||
template <typename T, typename T_ACC> | |||
__global__ void GroupNormForward( | |||
const T* x, const T_ACC* mean, const T_ACC* var, const T* scale, const T* bias, | |||
int N, int C, int W, int imsize, int groups, int group_size, float epsilon, | |||
T* y, T_ACC* real_var) { | |||
int gid = blockIdx.y; | |||
int cid = blockIdx.x; | |||
int bid = blockIdx.z; | |||
int ccid = gid * group_size + cid; | |||
if (ccid >= C) | |||
return; | |||
auto ng = bid * groups + gid; | |||
T_ACC x_mean = mean[ng]; | |||
T_ACC x_var = var[ng]; | |||
x_var = x_var - x_mean * x_mean; | |||
T_ACC var_inv = rsqrt(x_var + epsilon); | |||
if (cid == 0 && threadIdx.x == 0) { | |||
real_var[ng] = x_var; | |||
} | |||
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { | |||
T val; | |||
int index = (bid * C + ccid) * imsize + imid; | |||
val = x[index]; | |||
val = (val - x_mean) * var_inv; | |||
if (scale != nullptr) { | |||
val *= scale[ccid]; | |||
} | |||
if (bias != nullptr) { | |||
val += bias[ccid]; | |||
} | |||
y[index] = val; | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
void forward( | |||
T* src, T* weight, T* bias, T* dst, T_ACC* mean, T_ACC* rstd, T_ACC* temp_rstd, | |||
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream) { | |||
auto group_size = C / group; | |||
int block_size = std::min(1024, imsize); | |||
dim3 grid(group_size, group, N); | |||
dim3 threads(block_size, 1, 1); | |||
int size = group_size * imsize; | |||
constexpr int vec_size = sizeof(float4) / sizeof(T); | |||
int max_block_size = std::min(size / vec_size, 1024); | |||
int block_size_temp = 1; | |||
while (block_size_temp < max_block_size) { | |||
block_size_temp *= 2; | |||
} | |||
block_size_temp = std::max(block_size_temp, WARP_SIZE); | |||
dim3 grids(N * group); | |||
dim3 blocks(block_size_temp); | |||
if (size < vec_size * block_size_temp) { | |||
ScalarGetMeanAndVar<T, T_ACC> | |||
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size); | |||
after_kernel_launch(); | |||
} else { | |||
VectorizedGetMeanAndVar<T, T_ACC, vec_size> | |||
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size); | |||
after_kernel_launch(); | |||
} | |||
GroupNormForward<T, T_ACC><<<grid, threads, 0, stream>>>( | |||
src, mean, temp_rstd, weight, bias, N, C, W, imsize, group, group_size, eps, | |||
dst, rstd); | |||
after_kernel_launch(); | |||
} | |||
// ================================ group_norm backward =========================== | |||
// implementation of groupnorm_backward from | |||
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu#L253 | |||
template <typename T, typename T_ACC> | |||
__global__ void GetDsDbCUDAKernel(int imsize, const T* x, const T* dy, T* ds, T* db) { | |||
const int nc = blockIdx.x; | |||
T ds_sum = static_cast<T>(0); | |||
T db_sum = static_cast<T>(0); | |||
for (int i = threadIdx.x; i < imsize; i += blockDim.x) { | |||
const int index = nc * imsize + i; | |||
ds_sum += dy[index] * x[index]; | |||
db_sum += dy[index]; | |||
} | |||
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1); | |||
} | |||
template <typename T, typename T_ACC> | |||
__global__ void GetBiasGradientCUDAKernel( | |||
int N, int C, int group, T_ACC epsilon, const T_ACC* mean, const T_ACC* var, | |||
const T* ds, const T* db, T* d_scale, T* d_bias) { | |||
const int c = blockIdx.x * blockDim.x + threadIdx.x; | |||
if (c < C) { | |||
const int G = group; | |||
const int D = C / G; | |||
T sum1 = static_cast<T>(0); | |||
T sum2 = static_cast<T>(0); | |||
for (int n = 0; n < N; ++n) { | |||
const int nc = n * C + c; | |||
const int ng = n * G + c / D; | |||
sum1 += (d_scale == nullptr) | |||
? T(0) | |||
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) * | |||
static_cast<T>(rsqrt((float)(var[ng] + epsilon)))); | |||
sum2 += (d_bias == nullptr) ? T(0) : db[nc]; | |||
} | |||
if (d_scale != nullptr) { | |||
d_scale[c] = sum1; | |||
} | |||
if (d_bias != nullptr) { | |||
d_bias[c] = sum2; | |||
} | |||
} | |||
} | |||
template <typename T> | |||
__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { | |||
#pragma unroll | |||
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { | |||
val += __shfl_down(val, offset, warpSize); | |||
} | |||
return val; | |||
} | |||
template <typename T> | |||
__inline__ MEGDNN_DEVICE T BlockReduceSum(T val, T* shared) { | |||
const int lid = threadIdx.x % warpSize; | |||
const int wid = threadIdx.x / warpSize; | |||
val = warp_reduce_sum(val); | |||
__syncthreads(); | |||
if (lid == 0) { | |||
shared[wid] = val; | |||
} | |||
__syncthreads(); | |||
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); | |||
if (wid == 0) { | |||
val = warp_reduce_sum(val); | |||
} | |||
return val; | |||
} | |||
template <typename T, typename T_ACC, int BlockDim> | |||
__global__ void GetBackwardParamsCUDAKernel( | |||
int imsize, int groups, int group_size, T_ACC epsilon, const T_ACC* mean, | |||
const T_ACC* var, const T* scale, const T* ds, const T* db, T* p1, T* p2, | |||
T* p3) { | |||
const int n = blockIdx.x; | |||
const int g = blockIdx.y; | |||
const int ng = n * groups + g; | |||
T sum1 = static_cast<T>(0); | |||
T sum2 = static_cast<T>(0); | |||
T var_inv = static_cast<T>(rsqrt(var[ng] + epsilon)); | |||
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) { | |||
const int64_t index = ng * group_size + i; | |||
const int64_t c = g * group_size + i; | |||
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]); | |||
sum1 += ds[index] * scale_v; | |||
sum2 += db[index] * scale_v; | |||
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]); | |||
p1[index] = scale_c * var_inv; | |||
} | |||
__shared__ T ds_shared[WARP_SIZE]; | |||
__shared__ T db_shared[WARP_SIZE]; | |||
sum1 = BlockReduceSum<T>(sum1, ds_shared); | |||
sum2 = BlockReduceSum<T>(sum2, db_shared); | |||
if (threadIdx.x == 0) { | |||
const T s = T(1) / static_cast<T>(group_size * imsize); | |||
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) * static_cast<T>(var_inv) * | |||
static_cast<T>(var_inv) * static_cast<T>(var_inv) * s; | |||
p2[ng] = x; | |||
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s; | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
__global__ void GetXGradientCUDAKernel( | |||
int imsize, int C, int group_size, int groups, T* p1, T* p2, T* p3, const T* x, | |||
const T* dy, T* dx) { | |||
int cid = blockIdx.x; | |||
int gid = blockIdx.y; | |||
int bid = blockIdx.z; | |||
int ccid = bid * C + gid * group_size + cid; | |||
int ng = bid * groups + gid; | |||
int nc = gid * group_size + cid; | |||
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { | |||
int index = (bid * C + nc) * imsize + imid; | |||
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng]; | |||
} | |||
} | |||
template <typename T, typename T_ACC> | |||
void backward( | |||
const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||
const T_ACC* rstd_data, const T* weight_data, T* dX_data, T* dweight_data, | |||
T* dbias_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, | |||
T* p1, T* p2, T* p3, cudaStream_t stream) { | |||
auto group_size = C / group; | |||
int block_size = std::min(1024, imsize); | |||
const int block_dims = 1024; | |||
dim3 grid(group_size, group, N); | |||
dim3 threads(block_size, 1, 1); | |||
const int max_num_threads = 1024; | |||
int max_block_size = std::min(imsize, max_num_threads); | |||
int block_size_temp = 1; | |||
while (block_size_temp < max_block_size) { | |||
block_size_temp *= 2; | |||
} | |||
block_size_temp = std::max(block_size_temp, WARP_SIZE); | |||
dim3 blocks(block_size_temp); | |||
GetDsDbCUDAKernel<T, T_ACC> | |||
<<<N * C, blocks, 0, stream>>>(imsize, X_data, dY_data, ds, db); | |||
after_kernel_launch(); | |||
bool flag = weight_data != nullptr ? true : false; | |||
if (flag) { | |||
const int block = 256; | |||
GetBiasGradientCUDAKernel<T, T_ACC> | |||
<<<(C + block - 1) / block, block, 0, stream>>>( | |||
N, C, group, eps, mean_data, rstd_data, ds, db, dweight_data, | |||
dbias_data); | |||
after_kernel_launch(); | |||
} | |||
GetBackwardParamsCUDAKernel<T, T_ACC, block_dims> | |||
<<<dim3(N, group), block_dims, 0, stream>>>( | |||
imsize, group, group_size, eps, mean_data, rstd_data, weight_data, | |||
ds, db, p1, p2, p3); | |||
after_kernel_launch(); | |||
GetXGradientCUDAKernel<T, T_ACC><<<grid, threads, 0, stream>>>( | |||
imsize, C, group_size, group, p1, p2, p3, X_data, dY_data, dX_data); | |||
after_kernel_launch(); | |||
} | |||
#define INST(T, T_ACC) \ | |||
template void forward<T, T_ACC>( \ | |||
T*, T*, T*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC, int, int, int, int, int, \ | |||
cudaStream_t); \ | |||
template void backward<T, T_ACC>( \ | |||
const T*, const T*, const T_ACC*, const T_ACC*, const T*, T*, T*, T*, \ | |||
T_ACC, int, int, int, int, T*, T*, T*, T*, T*, cudaStream_t); | |||
INST(dt_float32, dt_float32) | |||
INST(dt_float16, dt_float32) | |||
INST(dt_bfloat16, dt_float32) | |||
#undef INST | |||
} // namespace group_norm | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,24 @@ | |||
#pragma once | |||
#include <cuda_runtime_api.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace group_norm { | |||
template <typename T, typename T_ACC> | |||
void forward( | |||
T* X, T* gamma, T* beta, T* Y, T_ACC* mean, T_ACC* rstd, T_ACC* tesmp_rstd, | |||
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream); | |||
template <typename T, typename T_ACC> | |||
void backward( | |||
const T* dY_data, const T* X_data, const T_ACC* mean_data, | |||
const T_ACC* rstd_data, const T* gamma_data, T* dX_data, T* dgamma_data, | |||
T* dbeta_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, | |||
T* p1, T* p2, T* p3, cudaStream_t stream); | |||
} // namespace group_norm | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,143 @@ | |||
#include "src/cuda/group_norm/opr_impl.h" | |||
#include "src/cuda/group_norm/group_norm_cuda.cuh" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
size_t GroupNormForwardImpl::get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout& rstd) { | |||
size_t N = rstd.shape[0]; | |||
size_t G = rstd.shape[1]; | |||
return get_workspace_bundle(N, G, rstd.dtype.size()).total_size_in_bytes(); | |||
} | |||
WorkspaceBundle GroupNormForwardImpl::get_workspace_bundle( | |||
size_t N, size_t G, size_t dtype_size, void* raw_ptr) { | |||
return {raw_ptr, {N * G * dtype_size}, handle()->alignment_requirement()}; | |||
} | |||
void GroupNormForwardImpl::exec( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
_megdnn_workspace workspace) { | |||
check_exec( | |||
data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||
rstd.layout, workspace.size); | |||
auto p = param(); | |||
using Format = param::GroupNorm::Format; | |||
float eps = p.eps; | |||
int group = p.group; | |||
bool affine = p.affine; | |||
auto layout = data.layout; | |||
auto format = p.format; | |||
size_t N, C, H, W, imsize; | |||
if (data.layout.ndim == 4 && format == Format::NCHW) { | |||
N = layout.shape[0]; | |||
C = layout.shape[1]; | |||
H = layout.shape[2]; | |||
W = layout.shape[3]; | |||
imsize = H * W; | |||
} else { | |||
megdnn_throw(ssprintf("Unspport groupnorm input")); | |||
} | |||
auto stream = cuda_stream(handle()); | |||
using namespace ::megdnn::cuda::group_norm; | |||
auto wbundle = | |||
get_workspace_bundle(N, group, rstd.layout.dtype.size(), workspace.raw_ptr); | |||
#define cb(DType) \ | |||
if (data.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
using T_ACC = float; \ | |||
T_ACC* temp_rstd = wbundle.get_workspace(0).ptr<T_ACC>(); \ | |||
forward<T, T_ACC>( \ | |||
data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \ | |||
affine ? bias.ptr<T>() : nullptr, dst.ptr<T>(), mean.ptr<T_ACC>(), \ | |||
rstd.ptr<T_ACC>(), temp_rstd, static_cast<T_ACC>(eps), \ | |||
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \ | |||
static_cast<int>(W), static_cast<int>(imsize), stream); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
size_t GroupNormBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||
const TensorLayout& mean, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&) { | |||
size_t N = data.shape[0]; | |||
size_t C = data.shape[1]; | |||
size_t G = mean.shape[1]; | |||
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes(); | |||
} | |||
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle( | |||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) { | |||
return {raw_ptr, | |||
{N * C * dtype_size, N * C * dtype_size, N * C * dtype_size, | |||
N * G * dtype_size, N * G * dtype_size}, | |||
handle()->alignment_requirement()}; | |||
} | |||
void GroupNormBackwardImpl::exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) { | |||
check_exec( | |||
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||
ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||
auto p = param(); | |||
using Format = param::GroupNorm::Format; | |||
bool affine = p.affine; | |||
float eps = p.eps; | |||
int group = p.group; | |||
auto layout = data.layout; | |||
auto format = p.format; | |||
size_t N, C, H, W, imsize; | |||
if (layout.ndim == 4 && format == Format::NCHW) { | |||
N = layout.shape[0]; | |||
C = layout.shape[1]; | |||
H = layout.shape[2]; | |||
W = layout.shape[3]; | |||
imsize = H * W; | |||
} else { | |||
megdnn_throw(ssprintf("Unspport groupnorm input")); | |||
} | |||
auto stream = cuda_stream(handle()); | |||
using namespace ::megdnn::cuda::group_norm; | |||
auto wbundle = get_workspace_bundle( | |||
N, C, group, data.layout.dtype.size(), workspace.raw_ptr); | |||
#define cb(DType) \ | |||
if (data.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
using T_ACC = float; \ | |||
T* ds = wbundle.get_workspace(0).ptr<T>(); \ | |||
T* db = wbundle.get_workspace(1).ptr<T>(); \ | |||
T* p1 = wbundle.get_workspace(2).ptr<T>(); \ | |||
T* p2 = wbundle.get_workspace(3).ptr<T>(); \ | |||
T* p3 = wbundle.get_workspace(4).ptr<T>(); \ | |||
backward<T, T_ACC>( \ | |||
diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \ | |||
affine ? weight.ptr<T>() : nullptr, ddata.ptr<T>(), \ | |||
affine ? dweight.ptr<T>() : nullptr, \ | |||
affine ? dbias.ptr<T>() : nullptr, static_cast<T_ACC>(eps), \ | |||
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \ | |||
static_cast<int>(imsize), ds, db, p1, p2, p3, stream); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,47 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class GroupNormForwardImpl final : public GroupNormForward { | |||
public: | |||
using GroupNormForward::GroupNormForward; | |||
void exec( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, | |||
const TensorLayout& rstd) override; | |||
private: | |||
WorkspaceBundle get_workspace_bundle( | |||
size_t N, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||
}; | |||
class GroupNormBackwardImpl final : public GroupNormBackward { | |||
public: | |||
using GroupNormBackward::GroupNormBackward; | |||
void exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||
const TensorLayout& mean, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&) override; | |||
private: | |||
WorkspaceBundle get_workspace_bundle( | |||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -32,6 +32,7 @@ | |||
#include "src/cuda/flip/opr_impl.h" | |||
#include "src/cuda/gaussian_blur/opr_impl.h" | |||
#include "src/cuda/group_local/opr_impl.h" | |||
#include "src/cuda/group_norm/opr_impl.h" | |||
#include "src/cuda/images2neibs/opr_impl.h" | |||
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h" | |||
#include "src/cuda/indexing_one_hot/opr_impl.h" | |||
@@ -163,6 +164,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormBackward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy); | |||
@@ -0,0 +1,206 @@ | |||
#include "src/naive/group_norm/opr_impl.h" | |||
#include <algorithm> | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
using namespace megdnn; | |||
using namespace naive; | |||
namespace { | |||
using Param = megdnn::GroupNorm::Param; | |||
template <typename T, typename T_ACC = float> | |||
void forward( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
const Param& param) { | |||
float eps = param.eps; | |||
bool affine = param.affine; | |||
size_t N = data.layout.shape[0]; | |||
size_t C = data.layout.shape[1]; | |||
size_t HxW = data.layout.shape[2] * data.layout.shape[3]; | |||
const int64_t G = param.group; | |||
size_t D = C / G; | |||
size_t inner_size = D * HxW; | |||
for (size_t i = 0; i < N * G; i++) { | |||
T_ACC slice_sum = static_cast<T>(0.0f); | |||
for (size_t j = 0; j < inner_size; j++) { | |||
auto value = data.ptr<T>()[i * inner_size + j]; | |||
slice_sum += value; | |||
} | |||
T_ACC slice_mean = static_cast<T>(slice_sum / inner_size); | |||
T_ACC slice_var = static_cast<T>(0.0f); | |||
for (size_t j = 0; j < inner_size; j++) { | |||
slice_var += (data.ptr<T>()[i * inner_size + j] - slice_mean) * | |||
(data.ptr<T>()[i * inner_size + j] - slice_mean); | |||
} | |||
slice_var = slice_var / inner_size; | |||
T_ACC slice_std = static_cast<T>(1.0f) / static_cast<T>(sqrt(slice_var + eps)); | |||
if (affine) { | |||
const int64_t g = i % G; | |||
for (size_t j = 0; j < D; j++) { | |||
const int64_t c = g * D + j; | |||
T_ACC s = slice_std * weight.ptr<T>()[c]; | |||
T_ACC b = -s * slice_mean + bias.ptr<T>()[c]; | |||
for (size_t k = 0; k < HxW; k++) { | |||
dst.ptr<T>()[(i * D + j) * HxW + k] = | |||
s * data.ptr<T>()[(i * D + j) * HxW + k] + b; | |||
} | |||
} | |||
} else { | |||
for (size_t j = 0; j < inner_size; j++) { | |||
dst.ptr<T>()[i * inner_size + j] = | |||
(data.ptr<T>()[i * inner_size + j] - slice_mean) / slice_std; | |||
} | |||
} | |||
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean); | |||
rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_var); | |||
} | |||
} | |||
template <typename T, typename T_ACC = float> | |||
void backward( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param, | |||
WorkspaceBundle wbundle) { | |||
bool affine = param.affine; | |||
size_t N = data.layout.shape[0]; | |||
size_t C = data.layout.shape[1]; | |||
size_t G = param.group; | |||
float eps = param.eps; | |||
size_t HxW = data.layout.shape[2] * data.layout.shape[3]; | |||
T* ds = wbundle.get_workspace(0).ptr<T>(); | |||
T* db = wbundle.get_workspace(1).ptr<T>(); | |||
T* slice_std = wbundle.get_workspace(2).ptr<T>(); | |||
for (size_t i = 0; i < N * G; i++) { | |||
slice_std[i] = | |||
static_cast<T>(1.0f) / static_cast<T>(sqrt(rstd.ptr<T_ACC>()[i] + eps)); | |||
} | |||
for (size_t i = 0; i < N * C; i++) { | |||
T ds_data = static_cast<T>(0.0f); | |||
T db_data = static_cast<T>(0.0f); | |||
for (size_t j = 0; j < HxW; j++) { | |||
db_data += diff.ptr<T>()[i * HxW + j]; | |||
ds_data += data.ptr<T>()[i * HxW + j] * diff.ptr<T>()[i * HxW + j]; | |||
} | |||
ds[i] = ds_data; | |||
db[i] = db_data; | |||
} | |||
size_t D = C / G; | |||
const T s = T(1) / static_cast<T>(D * HxW); | |||
for (size_t i = 0; i < N * G; i++) { | |||
const int64_t g = i % G; | |||
T ds_v = static_cast<T>(0.0f); | |||
T db_v = static_cast<T>(0.0f); | |||
for (size_t j = 0; j < D; j += 1) { | |||
auto weight_v = affine ? weight.ptr<T>()[g * D + j] : static_cast<T>(1.0f); | |||
ds_v += ds[i * D + j] * weight_v; | |||
db_v += db[i * D + j] * weight_v; | |||
} | |||
auto c2 = (db_v * mean.ptr<T_ACC>()[i] - ds_v) * slice_std[i] * slice_std[i] * | |||
slice_std[i] * s; | |||
auto c3 = -c2 * mean.ptr<T_ACC>()[i] - db_v * slice_std[i] * s; | |||
for (size_t j = 0; j < D; j++) { | |||
const int64_t c = g * D + j; | |||
auto weight_v = affine ? weight.ptr<T>()[c] : static_cast<T>(1.0f); | |||
auto c1 = slice_std[i] * weight_v; | |||
for (size_t k = 0; k < HxW; k++) { | |||
ddata.ptr<T>()[(i * D + j) * HxW + k] = | |||
c1 * diff.ptr<T>()[(i * D + j) * HxW + k] + | |||
c2 * data.ptr<T>()[(i * D + j) * HxW + k] + c3; | |||
} | |||
} | |||
} | |||
if (affine) { | |||
for (size_t i = 0; i < C; ++i) { | |||
dweight.ptr<T>()[i] = 0; | |||
dbias.ptr<T>()[i] = 0; | |||
} | |||
for (size_t i = 0; i < N * G; i++) { | |||
auto g = i % G; | |||
for (size_t j = 0; j < D; j++) { | |||
auto c = g * D + j; | |||
dweight.ptr<T>()[c] += | |||
(ds[i * D + j] - db[i * D + j] * mean.ptr<T_ACC>()[i]) * | |||
slice_std[i]; | |||
} | |||
} | |||
for (size_t i = 0; i < N; i++) { | |||
for (size_t j = 0; j < C; j++) { | |||
dbias.ptr<T>()[j] += db[i * C + j]; | |||
} | |||
} | |||
} | |||
} | |||
} // namespace | |||
namespace megdnn { | |||
namespace naive { | |||
size_t GroupNormBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&) { | |||
size_t N = data.shape[0]; | |||
size_t C = data.shape[1]; | |||
size_t G = rstd.shape[1]; | |||
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes(); | |||
} | |||
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle( | |||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) { | |||
return {raw_ptr, | |||
{N * C * dtype_size, N * C * dtype_size, N * G * dtype_size}, | |||
handle()->alignment_requirement()}; | |||
} | |||
void GroupNormForwardImpl::exec( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
_megdnn_workspace workspace) { | |||
check_exec( | |||
data.layout, weight.layout, bias.layout, dst.layout, mean.layout, | |||
rstd.layout, workspace.size); | |||
#define cb(DType) \ | |||
if (data.layout.dtype == DType()) { \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \ | |||
data, weight, bias, dst, mean, rstd, param())); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
void GroupNormBackwardImpl::exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) { | |||
check_exec( | |||
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, | |||
ddata.layout, dweight.layout, dbias.layout, workspace.size); | |||
#define cb(DType) \ | |||
if (data.layout.dtype == DType()) { \ | |||
auto wbundle = get_workspace_bundle( \ | |||
data.layout.shape[0], data.layout.shape[1], rstd.layout.shape[1], \ | |||
sizeof(DTypeTrait<DType>::ctype), workspace.raw_ptr); \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \ | |||
diff, data, weight, mean, rstd, ddata, dweight, dbias, param(), \ | |||
wbundle)); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,44 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class GroupNormForwardImpl final : public GroupNormForward { | |||
public: | |||
using GroupNormForward::GroupNormForward; | |||
void exec( | |||
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, | |||
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&, const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
class GroupNormBackwardImpl final : public GroupNormBackward { | |||
public: | |||
using GroupNormBackward::GroupNormBackward; | |||
void exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, | |||
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, | |||
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout&, const TensorLayout& data, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, | |||
const TensorLayout&, const TensorLayout&) override; | |||
private: | |||
WorkspaceBundle get_workspace_bundle( | |||
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr); | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -34,6 +34,7 @@ | |||
#include "src/naive/flip/opr_impl.h" | |||
#include "src/naive/gaussian_blur/opr_impl.h" | |||
#include "src/naive/group_local/opr_impl.h" | |||
#include "src/naive/group_norm/opr_impl.h" | |||
#include "src/naive/images2neibs/opr_impl.h" | |||
#include "src/naive/indexing_multi_axis_vec/opr_impl.h" | |||
#include "src/naive/indexing_one_hot/opr_impl.h" | |||
@@ -0,0 +1,44 @@ | |||
#include "test/cuda/fixture.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, GROUPNORM_FORWARD) { | |||
using Param = GroupNormForward::Param; | |||
Param param; | |||
param.affine = true; | |||
param.eps = 1e-6; | |||
Checker<GroupNormForward> checker(handle_cuda()); | |||
checker.set_epsilon(1e-2); | |||
auto run = [&](DType d) { | |||
for (size_t group : {1, 3}) | |||
for (size_t C : {6, 9}) { | |||
param.group = group; | |||
checker.set_param(param) | |||
.set_dtype(0, d) | |||
.set_dtype(1, d) | |||
.set_dtype(2, d) | |||
.set_dtype(3, d) | |||
.set_dtype(4, dtype::Float32()) | |||
.set_dtype(5, dtype::Float32()) | |||
.execs({{2, C, 2, 1}, | |||
{C}, | |||
{C}, | |||
{2, C, 2, 1}, | |||
{2, group}, | |||
{2, group}}); | |||
} | |||
}; | |||
run(dtype::Float32()); | |||
run(dtype::Float16()); | |||
run(dtype::BFloat16()); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,70 @@ | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
#include "test/naive/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(NAIVE, GROUPNORM_FORWARD) { | |||
Checker<GroupNorm> checker(handle(), true); | |||
GroupNorm::Param param; | |||
param.affine = true; | |||
param.group = 3; | |||
checker.set_param(param).exect( | |||
Testcase{ | |||
TensorValue( | |||
{2, 3, 2, 1}, dtype::Float32(), | |||
{3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587, | |||
0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input | |||
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx | |||
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx | |||
{}, | |||
{}, | |||
{}}, | |||
Testcase{ | |||
{}, | |||
{}, | |||
{}, | |||
TensorValue( | |||
{2, 3, 2, 1}, dtype::Float32(), | |||
{1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999, | |||
-0.9998, 0.9998}), // output | |||
TensorValue( | |||
{2, 3}, dtype::Float32(), | |||
{1.7135, -0.1645, -0.0107, -0.9938, 0.067, | |||
-0.0758}), // mean | |||
TensorValue( | |||
{2, 3}, dtype::Float32(), | |||
{2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}), // var | |||
}); | |||
checker.set_param(param).exect( | |||
Testcase{ | |||
TensorValue( | |||
{1, 3, 1, 2}, dtype::Float32(), | |||
{-2.4348, -1.7948, 0.5223, 0.0932, -0.2955, | |||
-0.0492}), // input | |||
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx | |||
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx | |||
{}, | |||
{}, | |||
{}}, | |||
Testcase{ | |||
{}, | |||
{}, | |||
{}, | |||
TensorValue( | |||
{1, 3, 1, 2}, dtype::Float32(), | |||
{-0.9999, 0.9999, 0.9999, -0.9999, -0.9997, | |||
0.9997}), // output | |||
TensorValue( | |||
{1, 3}, dtype::Float32(), | |||
{-2.1148, 0.3077, -0.1724}), // mean | |||
TensorValue( | |||
{1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}), // var | |||
}); | |||
} | |||
} // namespace test | |||
} // namespace megdnn |
@@ -60,6 +60,7 @@ __all__ = [ | |||
"dropout", | |||
"embedding", | |||
"gelu", | |||
"group_norm", | |||
"hsigmoid", | |||
"hswish", | |||
"indexing_one_hot", | |||
@@ -1202,6 +1203,33 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||
return output | |||
def group_norm( | |||
inp: Tensor, | |||
num_groups: int, | |||
affine: bool, | |||
weight: Optional[Tensor] = None, | |||
bias: Optional[Tensor] = None, | |||
eps: float = 1e-5, | |||
): | |||
r"""Applies Group Normalization over a mini-batch of inputs as described in | |||
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__ | |||
Args: | |||
inp: input tensor. | |||
num_groups: number of groups to separate the channels into | |||
affine: whether to use weight and bias | |||
weight: must not be None when the affine is true | |||
bias: must not be None when the affine is true | |||
eps: a value added to the denominator for numerical stability. Default: 1e-5 | |||
""" | |||
op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,) | |||
if affine: | |||
assert weight is not None and bias is not None | |||
return apply(op, inp, weight, bias)[0] | |||
else: | |||
return apply(op, inp)[0] | |||
def layer_norm( | |||
inp: Tensor, | |||
normalized_shape: tuple, | |||
@@ -34,21 +34,9 @@ class GroupNorm(Module): | |||
zeros_(self.bias) | |||
def forward(self, x): | |||
N, C, H, W = x.shape | |||
format = x.format | |||
assert C == self.num_channels | |||
x = x.reshape(N, self.num_groups, -1) | |||
mean = x.mean(axis=2, keepdims=True) | |||
var = (x * x).mean(axis=2, keepdims=True) - mean * mean | |||
x = (x - mean) / F.sqrt(var + self.eps) | |||
x = x.reshape(N, C, H, W) | |||
if self.affine: | |||
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1) | |||
# FIXME(czh): remove this after making it a builtin op. | |||
if format == "nhwc": | |||
x = mge.amp.convert_tensor_format(x, inplace=False) | |||
x = F.nn.group_norm( | |||
x, self.num_groups, self.affine, self.weight, self.bias, self.eps | |||
) | |||
return x | |||
def _module_info_string(self) -> str: | |||
@@ -8,12 +8,14 @@ import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine import Parameter, Tensor, tensor | |||
from megengine.device import get_device_count | |||
from megengine.module import ( | |||
BatchNorm1d, | |||
BatchNorm2d, | |||
Conv1d, | |||
Conv2d, | |||
Dropout, | |||
GroupNorm, | |||
Linear, | |||
MaxPool2d, | |||
Module, | |||
@@ -698,3 +700,67 @@ def test_module_compatible(): | |||
assert ( | |||
old_attributes == current_attributes | |||
), "Add or delete attributes in Module class may break compatibility of pickle serialization" | |||
def test_grou_norm(): | |||
class OriginGroupNormFunc(Module): | |||
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): | |||
super().__init__(**kwargs) | |||
assert num_channels % num_groups == 0 | |||
self.num_groups = num_groups | |||
self.num_channels = num_channels | |||
self.eps = eps | |||
self.affine = affine | |||
if self.affine: | |||
self.weight = Parameter(np.ones(num_channels, dtype=np.float32)) | |||
self.bias = Parameter(np.zeros(num_channels, dtype=np.float32)) | |||
else: | |||
self.weight = None | |||
self.bias = None | |||
def forward(self, x): | |||
N, C, H, W = x.shape | |||
x = x.reshape(N, self.num_groups, -1) | |||
mean = x.mean(axis=2, keepdims=True) | |||
var = (x * x).mean(axis=2, keepdims=True) - mean * mean | |||
x = (x - mean) / F.sqrt(var + self.eps) | |||
x = x.reshape(N, C, H, W) | |||
if self.affine: | |||
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape( | |||
1, -1, 1, 1 | |||
) | |||
return x | |||
inp = np.random.randn(2, 256, 10, 16).astype("float32") | |||
mge_inp = Tensor(inp) | |||
mge_m = GroupNorm(32, 256) | |||
ori_inp = Tensor(inp) | |||
ori_m = OriginGroupNormFunc(32, 256) | |||
targets = np.array(2) | |||
mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters()) | |||
ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters()) | |||
for i in range(2): | |||
with mge_gm: | |||
mge_output = mge_m(mge_inp) | |||
loss = F.loss.square_loss( | |||
mge_output.sum(), mge.tensor(targets, dtype=np.float32) | |||
) | |||
mge_gm.backward(loss) | |||
with ori_gm: | |||
ori_output = ori_m(ori_inp) | |||
loss = F.loss.square_loss( | |||
ori_output.sum(), mge.tensor(targets, dtype=np.float32) | |||
) | |||
ori_gm.backward(loss) | |||
np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05) | |||
np.testing.assert_allclose( | |||
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03 | |||
) | |||
np.testing.assert_allclose( | |||
mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03 | |||
) |
@@ -0,0 +1,97 @@ | |||
#include "megbrain/opr/dnn/group_norm.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace group_norm { | |||
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const GroupNorm&>(def); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op.param(); | |||
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
OperatorNodeConfig config{op.make_name()}; | |||
if (nr_inp == 3) { | |||
return opr::GroupNorm::make( | |||
inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
return opr::GroupNorm::make(inputs[0], op.param(), config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto&& group_norm = def.cast_final_safe<GroupNorm>(); | |||
size_t nr_inp = inputs.size(); | |||
auto affine = group_norm.affine; | |||
mgb_assert( | |||
(nr_inp == 3 && affine) || (nr_inp == 1 && !affine), | |||
"num of inputs of pooling should be 1 or 3 but you give %zu", | |||
inputs.size()); | |||
auto&& inp = inputs[0]; | |||
auto& inp_cn = inp.comp_node; | |||
if (inp.layout.ndim == 0) { | |||
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}, | |||
{TensorLayout{dtype::Float32()}, inp_cn, {}}, | |||
{TensorLayout{dtype::Float32()}, inp_cn, {}}}, | |||
false}; | |||
} | |||
DnnOprHelper<megdnn::GroupNorm> dnn_opr(group_norm.param()); | |||
auto&& [oup_layout, mean_layout, rstd_layout] = | |||
dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{}); | |||
return {{{oup_layout, inp_cn, {}}, | |||
{mean_layout, inp_cn, {}}, | |||
{rstd_layout, inp_cn, {}}}, | |||
true}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto&& op_def = def.cast_final_safe<GroupNorm>(); | |||
size_t nr_inp = inputs.size(); | |||
auto p = op_def.param(); | |||
mgb_assert( | |||
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||
"num of inputs of groupnorm should be 1 or 3 but you give %zu", | |||
inputs.size()); | |||
auto cn = inputs[0]->comp_node(); | |||
DnnOprCaller<megdnn::GroupNorm> caller(cn, op_def.param()); | |||
auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>( | |||
inputs[0]->layout(), TensorLayout{}, TensorLayout{}); | |||
auto out = Tensor::make(oup_layout, cn); | |||
auto mean = Tensor::make(mean_layout, cn); | |||
auto rstd = Tensor::make(rstd_layout, cn); | |||
if (p.affine) { | |||
caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd); | |||
} else { | |||
megdnn::TensorND empty_dnn; | |||
caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd); | |||
} | |||
return {out, mean, rstd}; | |||
} | |||
OP_TRAIT_REG(GroupNorm, GroupNorm) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
} // namespace group_norm | |||
} // namespace mgb::imperative |
@@ -1,7 +1,7 @@ | |||
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | |||
da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td | |||
5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl | |||
98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl | |||
b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl | |||
3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl | |||
e38b68be4e2aaf3de2f22e3dddbeaac4 ../../dnn/scripts/opr_param_defs.py | |||
cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td | |||
9248d42a9b3e770693306992156f6015 generated/opdef.h.inl | |||
5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl | |||
30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl | |||
d10455217f5f01e3d2668e5689068920 generated/opdef.cpy.inl | |||
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h |
@@ -3775,6 +3775,110 @@ OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
.props(GroupLocal_props_impl) | |||
.make_name(GroupLocal_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNorm); | |||
namespace { | |||
size_t GroupNorm_hash_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||
static_cast<void>(op_); | |||
size_t val = mgb::hash(op_.dyn_typeinfo()); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.affine)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.eps)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.group)); | |||
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); | |||
return val; | |||
} | |||
bool GroupNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||
auto &&a_ = lhs_.cast_final_safe<GroupNorm>(), | |||
&&b_ = rhs_.cast_final_safe<GroupNorm>(); | |||
static_cast<void>(a_); | |||
static_cast<void>(b_); | |||
if (a_.affine != b_.affine) return false; | |||
if (a_.eps != b_.eps) return false; | |||
if (a_.group != b_.group) return false; | |||
if (a_.format != b_.format) return false; | |||
return true; | |||
} | |||
std::vector<std::pair<const char*, std::string>> GroupNorm_props_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||
static_cast<void>(op_); | |||
std::vector<std::pair<const char*, std::string>> props_; | |||
props_.emplace_back("affine", std::to_string(op_.affine)); | |||
props_.emplace_back("eps", std::to_string(op_.eps)); | |||
props_.emplace_back("group", std::to_string(op_.group)); | |||
switch (op_.format){ | |||
case GroupNorm::Format::NCHW: | |||
props_.emplace_back("format", "NCHW"); | |||
break; | |||
case GroupNorm::Format::NHWC: | |||
props_.emplace_back("format", "NHWC"); | |||
break; | |||
case GroupNorm::Format::NHWCD4: | |||
props_.emplace_back("format", "NHWCD4"); | |||
break; | |||
case GroupNorm::Format::NCHW4: | |||
props_.emplace_back("format", "NCHW4"); | |||
break; | |||
case GroupNorm::Format::NCHW8: | |||
props_.emplace_back("format", "NCHW8"); | |||
break; | |||
case GroupNorm::Format::NCHW32: | |||
props_.emplace_back("format", "NCHW32"); | |||
break; | |||
case GroupNorm::Format::NCHW88: | |||
props_.emplace_back("format", "NCHW88"); | |||
break; | |||
case GroupNorm::Format::NCHW44: | |||
props_.emplace_back("format", "NCHW44"); | |||
break; | |||
case GroupNorm::Format::NCHW44_DOT: | |||
props_.emplace_back("format", "NCHW44_DOT"); | |||
break; | |||
case GroupNorm::Format::NCHW4_NCHW32: | |||
props_.emplace_back("format", "NCHW4_NCHW32"); | |||
break; | |||
case GroupNorm::Format::NCHW32_NCHW4: | |||
props_.emplace_back("format", "NCHW32_NCHW4"); | |||
break; | |||
case GroupNorm::Format::NCHW4_NCHW: | |||
props_.emplace_back("format", "NCHW4_NCHW"); | |||
break; | |||
case GroupNorm::Format::NHWC_NCHW: | |||
props_.emplace_back("format", "NHWC_NCHW"); | |||
break; | |||
case GroupNorm::Format::NHWC_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); | |||
break; | |||
case GroupNorm::Format::NCHW_NCHW4_IC_SMALL: | |||
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); | |||
break; | |||
case GroupNorm::Format::CHWN4: | |||
props_.emplace_back("format", "CHWN4"); | |||
break; | |||
case GroupNorm::Format::NCHW64: | |||
props_.emplace_back("format", "NCHW64"); | |||
break; | |||
case GroupNorm::Format::NCHW4_NHWC: | |||
props_.emplace_back("format", "NCHW4_NHWC"); | |||
break; | |||
default: | |||
props_.emplace_back("format", "INVALID"); | |||
break; | |||
} | |||
return props_; | |||
} | |||
std::string GroupNorm_make_name_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<GroupNorm>(); | |||
static_cast<void>(op_); | |||
return "GroupNorm"; | |||
} | |||
} // anonymous namespace | |||
OP_TRAIT_REG(GroupNorm, GroupNorm) | |||
.hash(GroupNorm_hash_impl) | |||
.is_same_st(GroupNorm_is_same_st_impl) | |||
.props(GroupNorm_props_impl) | |||
.make_name(GroupNorm_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); | |||
namespace { | |||
@@ -10075,6 +10075,158 @@ void _init_py_GroupLocal(py::module m) { | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupLocal::typeinfo(), &py_type).second); | |||
} | |||
void _init_py_GroupNorm_Format(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<GroupNorm::Format>::type; | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0); | |||
} | |||
PyOpDefBegin(GroupNorm) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
static PyObject* getstate(PyObject* self, PyObject*) { | |||
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
std::unordered_map<std::string, py::object> state { | |||
{"affine", serialization<decltype(opdef.affine)>::dump(opdef.affine)}, | |||
{"eps", serialization<decltype(opdef.eps)>::dump(opdef.eps)}, | |||
{"group", serialization<decltype(opdef.group)>::dump(opdef.group)}, | |||
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)} | |||
}; | |||
return py::cast(state).release().ptr(); | |||
} | |||
static PyObject* setstate(PyObject* self, PyObject* args) { | |||
PyObject* dict = PyTuple_GetItem(args, 0); | |||
if (!dict) return NULL; | |||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
{ | |||
auto&& iter = state.find("affine"); | |||
if (iter != state.end()) { | |||
opdef.affine = serialization<decltype(opdef.affine)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("eps"); | |||
if (iter != state.end()) { | |||
opdef.eps = serialization<decltype(opdef.eps)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("group"); | |||
if (iter != state.end()) { | |||
opdef.group = serialization<decltype(opdef.group)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("format"); | |||
if (iter != state.end()) { | |||
opdef.format = serialization<decltype(opdef.format)>::load(iter->second); | |||
} | |||
} | |||
Py_RETURN_NONE; | |||
} | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd(GroupNorm) | |||
int PyOp(GroupNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL}; | |||
PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL; | |||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &affine, &eps, &group, &format, &scope)) | |||
return -1; | |||
if (affine) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().affine = | |||
py::cast<decltype(GroupNorm::affine)>(py::handle(affine)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (eps) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().eps = | |||
py::cast<decltype(GroupNorm::eps)>(py::handle(eps)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (group) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().group = | |||
py::cast<decltype(GroupNorm::group)>(py::handle(group)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (format) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().format = | |||
py::cast<decltype(GroupNorm::format)>(py::handle(format)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (scope) { | |||
try { | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(py::cast<std::string>(py::handle(scope))); | |||
} CATCH_ALL(-1) | |||
} | |||
return 0; | |||
} | |||
PyGetSetDef PyOp(GroupNorm)::py_getsetters[] = { | |||
{const_cast<char*>("affine"), py_get_generic(GroupNorm, affine), py_set_generic(GroupNorm, affine), const_cast<char*>("affine"), NULL}, | |||
{const_cast<char*>("eps"), py_get_generic(GroupNorm, eps), py_set_generic(GroupNorm, eps), const_cast<char*>("eps"), NULL}, | |||
{const_cast<char*>("group"), py_get_generic(GroupNorm, group), py_set_generic(GroupNorm, group), const_cast<char*>("group"), NULL}, | |||
{const_cast<char*>("format"), py_get_generic(GroupNorm, format), py_set_generic(GroupNorm, format), const_cast<char*>("format"), NULL}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
PyMethodDef PyOp(GroupNorm)::tp_methods[] = { | |||
{const_cast<char*>("__getstate__"), PyOp(GroupNorm)::getstate, METH_NOARGS, "GroupNorm getstate"}, | |||
{const_cast<char*>("__setstate__"), PyOp(GroupNorm)::setstate, METH_VARARGS, "GroupNorm setstate"}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
void _init_py_GroupNorm(py::module m) { | |||
using py_op = PyOp(GroupNorm); | |||
auto& py_type = PyOpType(GroupNorm); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.GroupNorm"; | |||
py_type.tp_basicsize = sizeof(PyOp(GroupNorm)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "GroupNorm"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_methods = py_op::tp_methods; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
_init_py_GroupNorm_Format(py_type); | |||
PyType_Modified(&py_type); | |||
m.add_object("GroupNorm", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupNorm::typeinfo(), &py_type).second); | |||
} | |||
PyOpDefBegin(Identity) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
@@ -19237,6 +19389,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { | |||
_init_py_GaussianRNG(m); \ | |||
_init_py_GetVarShape(m); \ | |||
_init_py_GroupLocal(m); \ | |||
_init_py_GroupNorm(m); \ | |||
_init_py_Identity(m); \ | |||
_init_py_Images2Neibs(m); \ | |||
_init_py_IncrMeshIndexing(m); \ | |||
@@ -988,6 +988,23 @@ public: | |||
} | |||
}; | |||
class GroupNorm : public OpDefImplBase<GroupNorm> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
using Format = ::megdnn::param::GroupNorm::Format; | |||
bool affine = true; | |||
float eps = 1e-5f; | |||
uint32_t group = 1; | |||
Format format = ::megdnn::param::GroupNorm::Format::NCHW; | |||
GroupNorm() = default; | |||
GroupNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); } | |||
GroupNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {} | |||
::megdnn::param::GroupNorm param() const { | |||
return {affine, eps, group, format}; | |||
} | |||
}; | |||
class Identity : public OpDefImplBase<Identity> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
@@ -1193,6 +1193,17 @@ GroupLocalInst | |||
.def_readwrite("format", &GroupLocal::format) | |||
.def_readwrite("compute_mode", &GroupLocal::compute_mode); | |||
py::class_<GroupNorm, std::shared_ptr<GroupNorm>, OpDef> GroupNormInst(m, "GroupNorm"); | |||
GroupNormInst.attr("Format") = AdaptivePoolingInst.attr("Format"); | |||
GroupNormInst | |||
.def(py::init<bool, float, uint32_t, ::megdnn::param::GroupNorm::Format, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {}) | |||
.def_readwrite("affine", &GroupNorm::affine) | |||
.def_readwrite("eps", &GroupNorm::eps) | |||
.def_readwrite("group", &GroupNorm::group) | |||
.def_readwrite("format", &GroupNorm::format); | |||
py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity"); | |||
IdentityInst | |||
@@ -490,6 +490,8 @@ def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | |||
def GroupNorm: MgbHashableOp<"GroupNorm", [GroupNormParam]>; | |||
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; | |||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | |||
@@ -1,8 +1,10 @@ | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/correlation.h" | |||
#include "megbrain/opr/dnn/fake_quant.h" | |||
#include "megbrain/opr/dnn/group_norm.h" | |||
#include "megbrain/opr/dnn/images2neibs.h" | |||
#include "megbrain/opr/dnn/layer_norm.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
@@ -15,6 +17,9 @@ | |||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
#include "megbrain/opr/dnn/softmax.h" | |||
#include "megbrain/opr/dnn/tqt.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/serialization/oss_opr_load_dump.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs/nn.h" | |||
@@ -524,6 +529,213 @@ struct OprMaker<opr::LayerNormBackward, 0> { | |||
} | |||
}; | |||
template <> | |||
struct OprMaker<opr::GroupNorm, 0> { | |||
using Param = opr::GroupNorm::Param; | |||
static cg::OperatorNodeBase* make( | |||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
if (i.size() == 3) { | |||
return opr::GroupNorm::make(i[0], i[1], i[2], param, config)[0] | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
mgb_assert(i.size() == 1); | |||
return opr::GroupNorm::make(i[0], param, config)[0].node()->owner_opr(); | |||
} | |||
} | |||
}; | |||
template <> | |||
struct OprLoadDumpImplV2<opr::GroupNorm, 0> { | |||
using Opr = opr::GroupNorm; | |||
using Param = opr::GroupNorm::Param; | |||
using ElemwiseParam = opr::Elemwise::Param; | |||
using ReduceParam = opr::Reduce::Param; | |||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param()); | |||
} | |||
static cg::OperatorNodeBase* replace_opr( | |||
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||
auto graph = inputs[0]->owner_graph(); | |||
auto comp_node = inputs[0]->comp_node(); | |||
// std::unique_ptr<StaticInferManager> m_static_infer_manager; | |||
auto opr_param = opr->cast_final_safe<opr::GroupNorm>().param(); | |||
float eps = opr_param.eps; | |||
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5)); | |||
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps)); | |||
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); | |||
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); | |||
auto origin_shape = opr::GetVarShape::make(inputs[0]).node(); | |||
TensorShape input_shape = | |||
inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); | |||
size_t N = input_shape[0]; | |||
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; | |||
int group = opr_param.group; | |||
int size = inner_size / group; | |||
HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
ptr[0] = N; | |||
ptr[1] = group; | |||
ptr[2] = size; | |||
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); | |||
auto inp = opr::Reshape::make(inputs[0], target_shape); | |||
auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2}); | |||
auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL}); | |||
auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2}); | |||
auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL}); | |||
auto var = | |||
opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB}); | |||
auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD}); | |||
auto sqrt = | |||
opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW}); | |||
auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB}); | |||
auto temp_inp = | |||
opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); | |||
auto res = opr::Reshape::make(temp_inp, origin_shape); | |||
if (inputs.size() == 3) { | |||
auto mul_temp = | |||
opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL}); | |||
auto res = opr::Elemwise::make( | |||
{mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD}); | |||
return res.node()->owner_opr(); | |||
} else { | |||
return res.node()->owner_opr(); | |||
} | |||
} | |||
static cg::OperatorNodeBase* load( | |||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
const OperatorNodeConfig& config) { | |||
// auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); | |||
return OprMaker<opr::GroupNorm, 0>::make( | |||
ctx.read_param<Param>(), inputs, ctx.graph(), config); | |||
} | |||
}; | |||
// OprMaker in MGB_SEREG_OPR only support unique output opr | |||
template <> | |||
struct OprMaker<opr::GroupNormBackward, 0> { | |||
using Param = opr::GroupNormBackward::Param; | |||
static cg::OperatorNodeBase* make( | |||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
if (i.size() == 5) { | |||
return opr::GroupNormBackward::make( | |||
i[0], i[1], i[2], i[3], i[4], param, config)[0] | |||
.node() | |||
->owner_opr(); | |||
} else { | |||
mgb_assert(i.size() == 4); | |||
return opr::GroupNormBackward::make( | |||
i[0], i[1], i[2], i[3], param, config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
} | |||
}; | |||
template <> | |||
struct OprLoadDumpImplV2<opr::GroupNormBackward, 0> { | |||
using Opr = opr::GroupNormBackward; | |||
using Param = opr::GroupNormBackward::Param; | |||
using ElemwiseParam = opr::Elemwise::Param; | |||
using ReduceParam = opr::Reduce::Param; | |||
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { | |||
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param()); | |||
} | |||
static cg::OperatorNodeBase* replace_opr( | |||
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||
auto rstd = inputs[4]; | |||
auto graph = inputs[1]->owner_graph(); | |||
auto comp_node = inputs[1]->comp_node(); | |||
auto opr_param = opr->cast_final_safe<opr::GroupNormBackward>().param(); | |||
float eps = opr_param.eps; | |||
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5)); | |||
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps)); | |||
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); | |||
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); | |||
auto const_node = | |||
opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node}); | |||
TensorShape input_shape = | |||
inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); | |||
auto origin_shape = opr::GetVarShape::make(inputs[1]).node(); | |||
size_t N = input_shape[0]; | |||
size_t C = input_shape[1]; | |||
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; | |||
int group = opr_param.group; | |||
int size = inner_size / group; | |||
HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
ptr[0] = N; | |||
ptr[1] = group; | |||
ptr[2] = size; | |||
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); | |||
auto inp = opr::Reshape::make(inputs[1], target_shape); | |||
auto temp_rstd = | |||
opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD}); | |||
auto sqrt = | |||
opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW}); | |||
auto slice_std = opr::Elemwise::make( | |||
{const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); | |||
auto sub_mean = | |||
opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB}); | |||
auto x_hat = | |||
opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL}); | |||
x_hat = opr::Reshape::make(x_hat, origin_shape); | |||
auto size_node = | |||
opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node}); | |||
auto temp1 = opr::Elemwise::make( | |||
{slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV}); | |||
auto dx_hat = | |||
opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL}); | |||
HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32()); | |||
auto* ptr2 = tshape.ptr<dt_int32>(); | |||
ptr2[0] = N; | |||
ptr2[1] = group; | |||
ptr2[2] = C / group; | |||
ptr2[3] = input_shape[2]; | |||
ptr2[4] = input_shape[3]; | |||
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); | |||
x_hat = opr::Reshape::make(x_hat, target_shape); | |||
dx_hat = opr::Reshape::make(dx_hat, target_shape); | |||
auto temp2 = | |||
opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL}); | |||
ptr2[2] = 1; | |||
ptr2[3] = 1; | |||
ptr2[4] = 1; | |||
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); | |||
auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape); | |||
auto sum_dx_hat = | |||
opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape); | |||
auto temp4 = | |||
opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL}); | |||
auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB}); | |||
auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB}); | |||
auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL}); | |||
auto dx = opr::Reshape::make(dx_temp, origin_shape); | |||
return dx.node()->owner_opr(); | |||
} | |||
static cg::OperatorNodeBase* load( | |||
OprLoadContext& ctx, const cg::VarNodeArray& inputs, | |||
const OperatorNodeConfig& config) { | |||
return OprMaker<opr::GroupNormBackward, 0>::make( | |||
ctx.read_param<Param>(), inputs, ctx.graph(), config); | |||
} | |||
}; | |||
template <class MegDNNConv = megdnn::LocalShare> | |||
struct MakeLocalShareCaller2 { | |||
template <typename Opr> | |||
@@ -747,6 +959,8 @@ MGB_SEREG_OPR(LSQ, 4); | |||
MGB_SEREG_OPR(LSQBackward, 5); | |||
MGB_SEREG_OPR(LayerNorm, 0); | |||
MGB_SEREG_OPR(LayerNormBackward, 0); | |||
MGB_SEREG_OPR(GroupNorm, 0); | |||
MGB_SEREG_OPR(GroupNormBackward, 0); | |||
MGB_SEREG_OPR(RNNCellForward, 6); | |||
MGB_SEREG_OPR(LSTMCellForward, 7); | |||
MGB_SEREG_OPR(RNNForward, 3); | |||
@@ -755,6 +969,14 @@ MGB_SEREG_OPR(LSTMForward, 4); | |||
MGB_SEREG_OPR(LSTMBackward, 9); | |||
MGB_SEREG_OPR(Softmax, 1); | |||
MGB_SEREG_OPR(SoftmaxBackward, 2); | |||
MGB_SEREG_OPR_V2( | |||
GroupNorm, 0, | |||
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNorm, 0>::replace_opr), | |||
VERSION_2, CURRENT_VERSION); | |||
MGB_SEREG_OPR_V2( | |||
GroupNormBackward, 0, | |||
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNormBackward, 0>::replace_opr), | |||
VERSION_2, CURRENT_VERSION); | |||
} // namespace opr | |||
} // namespace mgb | |||
@@ -0,0 +1,239 @@ | |||
#include "megbrain/opr/dnn/group_norm.h" | |||
#include "megbrain/graph/grad_impl.h" | |||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "../internal/megdnn_opr_wrapper.inl" | |||
using namespace mgb; | |||
using namespace opr; | |||
/* ==================== GroupNormForward ==================== */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormForward); | |||
GroupNormForward::GroupNormForward( | |||
VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||
const OperatorNodeConfig& config) | |||
: Super{data->owner_graph(), config, "group_norm", {data, weight, bias}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({data, weight, bias}); | |||
output(0)->dtype(data->dtype()); | |||
output(1)->dtype(dtype::Float32()); | |||
output(2)->dtype(dtype::Float32()); | |||
} | |||
GroupNormForward::GroupNormForward( | |||
VarNode* data, const Param& param, const OperatorNodeConfig& config) | |||
: Super{data->owner_graph(), config, "group_norm", {data}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({data}); | |||
output(0)->dtype(data->dtype()); | |||
output(1)->dtype(dtype::Float32()); | |||
output(2)->dtype(dtype::Float32()); | |||
} | |||
SymbolVarArray GroupNormForward::make( | |||
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, | |||
const OperatorNodeConfig& config) { | |||
auto outs = data.node() | |||
->owner_graph() | |||
->insert_opr(std::make_unique<GroupNormForward>( | |||
data.node(), weight.node(), bias.node(), param, config)) | |||
->output(); | |||
SymbolVarArray ret; | |||
for (auto&& out : outs) { | |||
ret.emplace_back(out); | |||
} | |||
return ret; | |||
} | |||
SymbolVarArray GroupNormForward::make( | |||
SymbolVar data, const Param& param, const OperatorNodeConfig& config) { | |||
auto outs = data.node() | |||
->owner_graph() | |||
->insert_opr(std::make_unique<GroupNormForward>( | |||
data.node(), param, config)) | |||
->output(); | |||
SymbolVarArray ret; | |||
for (auto&& out : outs) { | |||
ret.emplace_back(out); | |||
} | |||
return ret; | |||
} | |||
void GroupNormForward::get_output_var_shape( | |||
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { | |||
size_t group = param().group; | |||
out_shape[0] = inp_shape[0]; | |||
size_t N = inp_shape[0].shape[0]; | |||
TensorShape unnormalized_shape{N, group}; | |||
out_shape[1] = unnormalized_shape; | |||
out_shape[2] = unnormalized_shape; | |||
} | |||
size_t GroupNormForward::get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const { | |||
return intl::MegDNNOprMethInvoker<megdnn::GroupNormForward>::get_workspace_in_bytes( | |||
megdnn_opr(), this, input_shapes, output_shapes); | |||
} | |||
void GroupNormForward::scn_do_execute() { | |||
if (param().affine) { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
output(1)->dev_tensor().as_megdnn(), | |||
output(2)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output().back())); | |||
} else { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), {}, {}, | |||
output(0)->dev_tensor().as_megdnn(), | |||
output(1)->dev_tensor().as_megdnn(), | |||
output(2)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output().back())); | |||
} | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(GroupNormForward) { | |||
auto p = opr.param(); | |||
SymbolVarArray grad; | |||
VarNodeArray ret; | |||
if (p.affine) { | |||
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx); | |||
grad = GroupNormBackward::make( | |||
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2), | |||
opr.param()); | |||
} else { | |||
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx); | |||
grad = GroupNormBackward::make( | |||
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param()); | |||
} | |||
uint32_t nr_ret = p.affine ? 3 : 1; | |||
for (uint32_t i = 0; i < nr_ret; ++i) { | |||
ret.push_back(grad[i].node()); | |||
} | |||
return ret; | |||
} | |||
#endif | |||
/* ==================== GroupNormBackward ==================== */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormBackward); | |||
GroupNormBackward::GroupNormBackward( | |||
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||
const Param& param, const OperatorNodeConfig& config) | |||
: Super({diff->owner_graph(), | |||
config, | |||
"group_norm_backward", | |||
{diff, data, weight, mean, rstd}}, | |||
0, true) { | |||
init_megdnn_opr(*this, param); | |||
add_input({diff, data, weight, mean, rstd}); | |||
} | |||
GroupNormBackward::GroupNormBackward( | |||
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param, | |||
const OperatorNodeConfig& config) | |||
: Super({diff->owner_graph(), | |||
config, | |||
"group_norm_backward", | |||
{diff, data, mean, rstd}}, | |||
0, true) { | |||
init_megdnn_opr(*this, param); | |||
add_input({diff, data, mean, rstd}); | |||
auto mark_empty_var = [&](VarNode* var) { | |||
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) | |||
.add_flag(VarNode::Flag::VOLATILE_CONTENT); | |||
}; | |||
mark_empty_var(output(1)); | |||
mark_empty_var(output(2)); | |||
} | |||
SymbolVarArray GroupNormBackward::make( | |||
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) { | |||
auto outs = diff.node() | |||
->owner_graph() | |||
->insert_opr(std::make_unique<GroupNormBackward>( | |||
diff.node(), data.node(), weight.node(), mean.node(), | |||
rstd.node(), param, config)) | |||
->output(); | |||
SymbolVarArray ret; | |||
for (auto&& out : outs) { | |||
ret.emplace_back(out); | |||
} | |||
return ret; | |||
} | |||
SymbolVarArray GroupNormBackward::make( | |||
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||
const Param& param, const OperatorNodeConfig& config) { | |||
auto outs = diff.node() | |||
->owner_graph() | |||
->insert_opr(std::make_unique<GroupNormBackward>( | |||
diff.node(), data.node(), mean.node(), rstd.node(), | |||
param, config)) | |||
->output(); | |||
SymbolVarArray ret; | |||
for (auto&& out : outs) { | |||
ret.emplace_back(out); | |||
} | |||
return ret; | |||
} | |||
void GroupNormBackward::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto&& mgr = owner_graph()->static_infer_manager(); | |||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); | |||
if (param().affine) { | |||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2))); | |||
} else { | |||
TensorShape empty; | |||
empty.ndim = 0; | |||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty)); | |||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty)); | |||
} | |||
this->init_output_static_infer_desc_workspace( | |||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::GroupNormBackward>::val); | |||
} | |||
void GroupNormBackward::init_output_dtype() { | |||
output(0)->dtype(input(1)->dtype()); | |||
output(1)->dtype(input(2)->dtype()); | |||
output(2)->dtype(input(2)->dtype()); | |||
} | |||
size_t GroupNormBackward::get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const { | |||
return intl::MegDNNOprMethInvoker<megdnn::GroupNormBackward>:: | |||
get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes); | |||
} | |||
void GroupNormBackward::scn_do_execute() { | |||
if (param().affine) { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), | |||
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
output(1)->dev_tensor().as_megdnn(), | |||
output(2)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output(3))); | |||
} else { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), | |||
{}, input(2)->dev_tensor().as_megdnn(), | |||
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
{}, {}, intl::get_megdnn_workspace_from_var(output(3))); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,67 @@ | |||
#pragma once | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megdnn/oprs.h" | |||
namespace mgb { | |||
namespace opr { | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
GroupNormForward, intl::MegDNNOprWrapperFwd<megdnn::GroupNormForward>) // { | |||
public: | |||
MGE_WIN_DECLSPEC_FUC GroupNormForward( | |||
VarNode* data, VarNode* weight, VarNode* bias, const Param& param, | |||
const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC GroupNormForward( | |||
VarNode* data, const Param& param, const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
SymbolVar data, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
private: | |||
void get_output_var_shape( | |||
const TensorShapeArray& inp_shape, | |||
TensorShapeArray& out_shape) const override; | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const override; | |||
void scn_do_execute() override; | |||
}; | |||
using GroupNorm = GroupNormForward; | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
GroupNormBackward, intl::MegDNNOprWrapperBwd<megdnn::GroupNormBackward>) // { | |||
public: | |||
MGE_WIN_DECLSPEC_FUC GroupNormBackward( | |||
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, | |||
const Param& param, const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC GroupNormBackward( | |||
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, | |||
const Param& param, const OperatorNodeConfig& config); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, | |||
SymbolVar rstd, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( | |||
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, | |||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||
private: | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
const TensorShapeArray& output_shapes) const override; | |||
void scn_do_execute() override; | |||
}; | |||
} // namespace opr | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,90 @@ | |||
#include "megbrain/opr/dnn/group_norm.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megbrain/test/autocheck.h" | |||
#include "megbrain/test/helper.h" | |||
#include "megbrain/test/megdnn_helper.h" | |||
#include "megdnn/oprs.h" | |||
#include <cmath> | |||
#include <iomanip> | |||
#include <random> | |||
#include <sstream> | |||
using namespace mgb; | |||
namespace { | |||
using Param = opr::GroupNormForward::Param; | |||
void run_forward(bool is_affine) { | |||
using Checker = AutoOprChecker<3, 3>; | |||
Param param; | |||
param.eps = 1e-5; | |||
param.affine = is_affine; | |||
param.group = 3; | |||
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
auto out = opr::GroupNormForward::make(inputs[0], inputs[1], inputs[2], param); | |||
return {out[0], out[1], out[2]}; | |||
}; | |||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
auto opr = | |||
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||
->create_operator<megdnn::GroupNormForward>(); | |||
auto inp_shape = inp[0]->shape(); | |||
auto n_slices = inp_shape[0]; | |||
opr->param() = param; | |||
dest[0].dtype(dtype::Float32()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize(inp_shape); | |||
dest[1].dtype(dtype::Float32()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize({n_slices, param.group}); | |||
dest[2].dtype(dtype::Float32()) | |||
.comp_node(inp[0]->comp_node()) | |||
.resize({n_slices, param.group}); | |||
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes( | |||
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(), | |||
dest[1].layout(), dest[2].layout())); | |||
opr->exec( | |||
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), | |||
{workspace.data(), workspace.size()}); | |||
}; | |||
auto gen = [&](HostTensorND& src) { | |||
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f); | |||
src = *src_gen(src.shape(), src.comp_node()); | |||
}; | |||
Checker::RunOptions option; | |||
option.numdiff_max_err = 1e-4; | |||
Checker checker{make_graph, fwd}; | |||
checker.set_input_generator(0, gen); | |||
checker.set_input_generator(1, gen); | |||
checker.set_input_generator(2, gen); | |||
checker.set_input_allow_grad(0, false); | |||
checker.set_input_allow_grad(1, false); | |||
checker.set_input_allow_grad(2, false); | |||
checker.set_output_allow_grad(0, false); | |||
checker.set_output_allow_grad(1, false); | |||
checker.set_output_allow_grad(2, false); | |||
checker.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) | |||
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) | |||
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option); | |||
} | |||
TEST(TestOprDNN, GroupNormForward) { | |||
REQUIRE_GPU(1); | |||
run_forward(true); | |||
} | |||
} // anonymous namespace | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -123,6 +123,7 @@ union OperatorParam { | |||
param.LSTM = 89, | |||
param.Softmax = 90, | |||
param.Diag = 91, | |||
param.GroupNorm = 92, | |||
} | |||
table Operator { | |||
@@ -140,6 +140,7 @@ union OperatorParam { | |||
param.LSTM = 89, | |||
param.Softmax = 90, | |||
param.Diag = 91, | |||
param.GroupNorm = 92, | |||
} | |||
table Operator { | |||