@@ -2049,6 +2049,198 @@ protected:
            const TensorLayout& dinp, size_t workspace_in_bytes);
};
class RNNCellForward : public OperatorBase {
    DEF_OPR_PARAM(RNNCell);
    DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
    static void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& dst, size_t workspace_in_bytes);
};
using RNNCell = RNNCellForward;
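
// A single RNNCell step (see rnn_cell::exec in dnn/src/common/rnn_cell.cpp)
// computes
//     dst = act(input * weight_ih^T + hx * weight_hh^T + bias_ih + bias_hh)
// where act is param().nonlineMode (IDENTITY, RELU or TANH); input is
// (batch_size, input_size), hx is (batch_size, hidden_size) and dst is
// (batch_size, gate_hidden_size).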
class LSTMCellForward : public OperatorBase {
    // DEF_OPR_PARAM(LSTMCell);
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
            _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
            _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
            _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
            _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
            TensorLayout& gates);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& weight_ih,
            const TensorLayout& bias_ih, const TensorLayout& hx,
            const TensorLayout& weight_hh, const TensorLayout& bias_hh,
            const TensorLayout& cx, const TensorLayout& h_new,
            const TensorLayout& c_new, const TensorLayout& gates,
            size_t workspace_in_bytes);
};
using LSTMCell = LSTMCellForward;
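
// The LSTMCell step (see lstm_cell::exec in dnn/src/common/lstm_cell.cpp)
// first computes gates = input * weight_ih^T + hx * weight_hh^T + bias_ih +
// bias_hh of shape (batch_size, 4 * hidden_size), splits it into chunks
// (i, f, o, g), then
//     c_new = sigmoid(f) * cx + sigmoid(i) * tanh(g)
//     h_new = sigmoid(o) * tanh(c_new)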
class RNNForward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, TensorLayout& output,
            TensorLayout& hy, TensorLayout& reserve_space);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& reserve_space,
            size_t workspace_in_bytes);
};
using RNN = RNNForward;
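
// Layout convention (see RNN::deduce_layout in dnn/src/common/rnn.cpp):
// input is (seq_len, batch_size, input_size) and hx is
// (D * num_layers, batch_size, hidden_size) with D = bidirectional ? 2 : 1;
// output is deduced as (seq_len, batch_size, D * hidden_size) and hy keeps
// the layout of hx.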
class RNNBackward : public OperatorBase {
    DEF_OPR_PARAM(RNN);
    DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);

public:
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
            _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
            _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx,
            const TensorLayout& dw) = 0;

protected:
    void check_exec(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
            const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
            size_t workspace_in_bytes);
};
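
// Gradients mirror the forward layouts: RNNBackward::deduce_layout in
// dnn/src/common/rnn.cpp sets dx = x, dhx = hx and dw = flatten_weights.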
class LSTMForward : public OperatorBase {
    DEF_OPR_PARAM(LSTM);
    DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
            _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
            _megdnn_tensor_out hy, _megdnn_tensor_out cy,
            _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, TensorLayout& output,
            TensorLayout& hy, TensorLayout& cy, TensorLayout& reserve_space);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space) = 0;
    virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
            const TensorLayout& flatten_weights, const TensorLayout& output,
            const TensorLayout& hy, const TensorLayout& cy,
            const TensorLayout& reserve_space, size_t workspace_in_bytes);
};
using LSTM = LSTMForward;
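
// LSTM follows the same layout rules as RNN, with an additional cell state:
// cx/cy are (D * num_layers, batch_size, hidden_size), matching hx/hy (see
// LSTM::deduce_layout and LSTM::check_exec in dnn/src/common/lstm.cpp).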
class LSTMBackward : public OperatorBase {
    DEF_OPR_PARAM(LSTM);
    DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);

public:
    virtual void exec(
            _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
            _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
            _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
            _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
            _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
            TensorLayout& dcx, TensorLayout& dw);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, const TensorLayout& dx,
            const TensorLayout& dhx, const TensorLayout& dcx,
            const TensorLayout& dw) = 0;

protected:
    void check_exec(
            const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
            const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
            const TensorLayout& dcy, const TensorLayout& flatten_weights,
            const TensorLayout& reserve_space, const TensorLayout& dx,
            const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
            size_t workspace_in_bytes);
};

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"
@@ -36,13 +36,18 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    add_enum(Doc('Format', 'convolution data/filter/output format; see '
                 ':class:`RelayoutFormat` for more details'),
             'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
             'NCHW44 = 7', 'NCHW44_DOT = 8',
             Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights transformed by winograd'),
             Doc('NCHW88_WINOGRAD = 10',
                 'NCHW88 layout with weights transformed by winograd'),
             Doc('NCHW44_WINOGRAD = 11',
                 'NCHW44 layout with weights transformed by winograd'),
             Doc('NCHW4_NCHW32 = 12',
                 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
             Doc('NCHW32_NCHW4 = 13',
                 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
             Doc('NCHW4_NCHW = 14',
                 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
             Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, '
                 'output tensor is nchw layout'),
             Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
@@ -96,10 +101,13 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    add_enum(Doc('Format', 'convolution data/filter/output format; see '
                 ':class:`RelayoutFormat` for more details'),
             'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
             'NCHW44 = 7', 'NCHW44_DOT = 8',
             Doc('NCHW4_NCHW32 = 9',
                 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
             Doc('NCHW32_NCHW4 = 10',
                 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
             Doc('NCHW4_NCHW = 11',
                 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
             Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, '
                 'output tensor is nchw layout'),
             Doc('NHWC_NCHW4_IC_SMALL = 13', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
@@ -107,11 +115,11 @@ pdef('Axis').add_fields('int32', 'axis', 0)
             Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
                 'output tensor is nchw4 layout, padding c=4'),
             Doc('CHWN4 = 15', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
                 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
             Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementation to utilize TensorCore '
                 'instructions for 4-bit integers on Nvidia platforms'),
             Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')).
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
    )
@@ -133,7 +141,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode').
    add_enum('PoolMode', 'AVERAGE = 0', 'MAX = 1').
    add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
    add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1,
               'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0))

(pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True).
@@ -216,8 +224,8 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(pdef('SeparableConv').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
             'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
             'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
    add_fields('bool', 'is_symm_kernel', 'true').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
               'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
@@ -247,7 +255,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    )

(pdef('Pooling', version=1).
    add_enum_alias('Mode', 'PoolingV0').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2,
               'window_h', 2, 'window_w', 2).
    add_enum_alias('Format', 'Convolution')
@@ -302,7 +310,8 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    ).
    add_fields('float32', 'scale', '1.f'))

INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1',
                'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4']
BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
                Doc('REFLECT_101 = 2', 'gfedcb|abcdefgh|gfedcba'),
@@ -323,8 +332,8 @@ BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
    add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))

(pdef('WarpPerspective', version=2).
    add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field="imode").
    add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field="bmode").
    add_enum_alias('Format', 'Convolution').
    add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))
@@ -399,7 +408,7 @@ pdef('Elemwise').add_enum(
    Doc('RMULH = 43', 'binary: rounded higher l bits of x * y, where l is the bit '
        'length of x.'),
    Doc('ATAN2 = 44', 'binary: atan2(y,x)'),
    Doc('ERF = 45', 'unary: erf(x)'),
    Doc('ERFINV = 46', 'unary: inverse function of erf(x)'),
    Doc('ERFC = 47', 'unary: erfc(x)'),
@@ -634,7 +643,7 @@ Currently, ```DEFAULT``` mode means:
    Doc('axis',
        'axis along which reduction is performed; if INT_MAX is given, '
        'reduce to given target shape (only used in megbrain)'),
    (1 << 31) - 1).
    add_enum('DataType',
        Doc('DEFAULT = 0',
            '''
@@ -689,7 +698,7 @@ Currently, ```DEFAULT``` mode means:
    add_fields('int32',
        Doc('axis',
            'axis along which cumsum is performed, default with INT_MAX'),
        (1 << 31) - 1).
    add_fields('bool',
        Doc('exclusive',
            'whether the current element is taken into account'),
@@ -761,7 +770,8 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('UniformRNG', version=1).
    add_fields('uint64', 'seed', 0).
    add_fields(
        'dtype', Doc(
            'dtype', 'The dtype of output Tensor. Only support Float32.'),
        'DTypeEnum::Float32'))

(pdef('GaussianRNG', version=0, is_legacy=True).
@@ -772,7 +782,8 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
    add_fields('uint64', 'seed', 0).
    add_fields('float32', 'mean', 0, 'std', 1).
    add_fields(
        'dtype', Doc(
            'dtype', 'The dtype of output Tensor. Only support Float32.'),
        'DTypeEnum::Float32'))

(pdef('GammaRNG').
@@ -819,7 +830,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
             ('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'),
             ('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'),
             ('YUV2GRAY_YU12', 'BT601_YUV2GRAY_YU12')],
            name_field='mode'))

(pdef('WarpAffine', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
@@ -842,7 +853,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('GaussianBlur')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
    .add_fields('uint32', 'kernel_height', 0, 'kernel_width', 0)
    .add_fields('float32', 'sigma_x', '0.f', 'sigma_y', '0.f'))

(pdef('Resize', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode'))
@@ -855,7 +866,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('Format', 'Convolution', default=1))

(pdef('Remap', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
    .add_enum_alias('Format', 'ConvolutionV0', default=1)
@@ -909,8 +920,8 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(pdef('SeparableConv3D').
    add_enum_alias('Mode', 'Convolution3D').
    add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
             'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
             'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
    add_fields('bool', 'is_symm_kernel', 'true').
    add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
               'stride_d', 0, 'stride_h', 1, 'stride_w', 1,
@@ -1023,10 +1034,10 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
             'NCHW_NCHW4 = 24',
             'NCHW4_NCHW = 25',
             'NCHW_NCHW4_WEIGHT = 26',
             'NCHW_NCHW64 = 27',
             'NCHW64_NCHW = 28',
             'NCHW_NHWC = 29',
             'NHWC_NCHW = 30',
             )
    )
@@ -1048,7 +1059,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
    add_fields('bool', 'is_symm_kernel', 'true').
    add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))

(pdef('LocalShare', 'Local share convolution', version=0, is_legacy=True).
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields(
        'uint32',
@@ -1089,7 +1100,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
    )

(pdef('ROIAlign', version=0, is_legacy=True).
    add_enum('Mode', 'MAX = 0', 'AVERAGE = 1', name_field='mode').
    add_enum_alias('Format', 'ConvolutionV0').
    add_fields('float32', 'spatial_scale', '1.0').
@@ -1133,7 +1144,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
        Doc('part_size', 'size of each deformable part'), 1,
        Doc('sample_per_part', 'sample count of each bbox'), 1))

(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=0, is_legacy=True).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields(
@@ -1152,7 +1163,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
    )

(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=1).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields(
@@ -1172,8 +1183,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
    )

(pdef('FakeQuant').
    add_fields('int32', 'qmin', '-2147483648').
    add_fields('int32', 'qmax', '2147483647')
    )

(pdef('TQT').
    add_fields('int32', 'qmin', '-2147483648').
@@ -1192,13 +1203,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                 Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
                 Doc('CONSTANT = 2', 'iiiiii|abcdefgh|iiiiiii')]

(pdef('Padding').
    add_fields('uint32', Doc('front_offset_dim0', 'offset in dim 0'), 0).
    add_fields('uint32', Doc('front_offset_dim1', 'offset in dim 1'), 0).
    add_fields('uint32', Doc('front_offset_dim2', 'offset in dim 2'), 0).
    add_fields('uint32', Doc('front_offset_dim3', 'offset in dim 3'), 0).
    add_fields('uint32', Doc('front_offset_dim4', 'offset in dim 4'), 0).
    add_fields('uint32', Doc('front_offset_dim5', 'offset in dim 5'), 0).
    add_fields('uint32', Doc('front_offset_dim6', 'offset in dim 6'), 0).
    add_fields('uint32', Doc('back_offset_dim0', 'back offset in dim0'), 0).
    add_fields('uint32', Doc('back_offset_dim1', 'back offset in dim1'), 0).
    add_fields('uint32', Doc('back_offset_dim2', 'back offset in dim2'), 0).
@@ -1206,7 +1217,7 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
    add_fields('uint32', Doc('back_offset_dim4', 'back offset in dim4'), 0).
    add_fields('uint32', Doc('back_offset_dim5', 'back offset in dim5'), 0).
    add_fields('uint32', Doc('back_offset_dim6', 'back offset in dim6'), 0).
    add_fields('float32', Doc('padding_val', 'param of padding opr'), 0).
    add_enum('PaddingMode', *PADDING_MODES,
             name_field='padding_mode', default=2,
             member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
@@ -1223,4 +1234,29 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
(pdef('Dropout')
    .add_fields('float32', 'drop_prob', '0')
    .add_fields('uint64', 'seed', '0')
    )

(pdef('RNNCell').
    add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2')
    )

(pdef('RNN').
    add_fields('uint32', 'num_layers', '1').
    add_fields('bool', 'bidirectional', 'false').
    add_fields('bool', 'bias', 'true').
    add_fields('uint32', 'hidden_size', '128').
    add_fields('uint32', 'proj_size', '0').
    add_fields('float32', 'dropout', '0.f').
    add_enum_alias('NonlineMode', 'RNNCell').
    add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
    )

(pdef('LSTM').
    add_fields('uint32', 'num_layers', '1').
    add_fields('bool', 'bidirectional', 'false').
    add_fields('bool', 'bias', 'true').
    add_fields('uint32', 'hidden_size', '128').
    add_fields('uint32', 'proj_size', '0').
    add_fields('float32', 'dropout', '0.f').
    add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
    )
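
# 'FwdMode' is aliased from BN's forward-mode enum (TRAINING / INFERENCE);
# the reserve_space output written by the forward opr is consumed only by the
# backward opr, i.e. it only matters in TRAINING mode.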
@@ -213,7 +213,13 @@ private:
    cb(LayerNormForward) \
    cb(LayerNormBackward) \
    cb(DropoutForward) \
    cb(DropoutBackward) \
    cb(RNNCell) \
    cb(LSTMCell) \
    cb(RNN) \
    cb(RNNBackward) \
    cb(LSTM) \
    cb(LSTMBackward)
// clang-format on

/*!
@@ -0,0 +1,98 @@
/**
 * \file dnn/src/common/lstm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"
// #include "src/cuda/lstm/utils.h"

namespace megdnn {

/*
size_t get_reserve_size(Handle* handle, megdnn::LSTMForward::Param& param,
                        const TensorLayout& input) {
#if CUDNN_MAJOR >= 6
    auto holder = megdnn::cuda::lstm::get_RNNDescHolder_v6(handle, param, input);
    return holder.reserveSpace_size;
#else
    return 0;
#endif
}
*/
void LSTM::deduce_layout(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
        TensorLayout& cy, TensorLayout& reserve_space) {
    // input: [seq_len, batch_size, input_size]
    // hx: [D * num_layers, batch_size, hidden_size]
    size_t seq_len = input.shape[0];
    size_t batch_size = input.shape[1];
    size_t D = param().bidirectional ? 2 : 1;
    size_t hidden_size = hx.shape[2];
    output = TensorLayout(
            TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
    hy = TensorLayout(hx);
    cy = TensorLayout(cx);
    // reserve_space = {{get_reserve_size(this->handle(), param(), input)},
    // dtype::Byte()};
    reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
}
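
// For example (illustrative numbers): seq_len = 10, batch_size = 32,
// hidden_size = 128, num_layers = 1 and bidirectional = true (D = 2) deduce
// output = (10, 32, 256), while hy and cy keep the shapes of hx and cx,
// (2, 32, 128).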
void LSTM::check_exec(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& cy,
        const TensorLayout& reserve_space, size_t workspace_in_bytes) {
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", cx=");
        msg.append(cx.to_string());
        msg.append(", hidden_size=");
        msg.append(std::to_string(param().hidden_size));
        msg.append(", num_layers=");
        msg.append(std::to_string(param().num_layers));
        msg.append(", bidirectional=");
        msg.append(std::to_string(param().bidirectional));
        return msg;
    };
    size_t D = param().bidirectional ? 2 : 1;
    size_t num_layers = param().num_layers;
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
    ASSERT_BRIEF(hx.ndim == 3)
    ASSERT_BRIEF(hx.shape[0] == D * num_layers)
    ASSERT_BRIEF(hx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
    ASSERT_BRIEF(cx.ndim == 3)
    ASSERT_BRIEF(cx.shape[0] == D * num_layers)
    ASSERT_BRIEF(cx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(cx.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}
void LSTMBackward::deduce_layout(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
        const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
        const TensorLayout& dcy, const TensorLayout& flatten_weights,
        const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
        TensorLayout& dcx, TensorLayout& dw) {
    dx = x;
    dhx = hx;
    dcx = cx;
    dw = flatten_weights;
}

void LSTMBackward::check_exec(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
        const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
        const TensorLayout& dcy, const TensorLayout& flatten_weights,
        const TensorLayout& reserve_space, const TensorLayout& dx,
        const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
        size_t workspace_in_bytes) {}

}  // namespace megdnn
@@ -0,0 +1,136 @@
/**
 * \file dnn/src/common/lstm_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/lstm_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
void LSTMCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
        TensorLayout& gates) {
    // size_t batch_size = hx.shape[0];
    // size_t hidden_size = hx.shape[1];
    h_new = TensorLayout(hx, hx.dtype);
    c_new = TensorLayout(cx, cx.dtype);
    auto opr = handle()->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
}

void LSTMCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
        const TensorLayout& gates, size_t workspace_in_bytes) {
    TensorLayout h_new_expected, c_new_expected, gates_expected;
    deduce_layout(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
            c_new_expected, gates_expected);
    megdnn_assert_eq_layout(h_new_expected, h_new);
    megdnn_assert_eq_layout(c_new_expected, c_new);
    megdnn_assert_eq_layout(gates_expected, gates);
    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}  // namespace megdnn
namespace megdnn {
namespace lstm_cell {
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
        const TensorLayout& gates, Handle* handle) {
    TensorLayout tmp_layout;
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
    size_t rnn_cell_need = opr->get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
    size_t lstm_cell_need = tmp_layout.span().dist_byte();
    return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}
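
// The workspace sized above is reused twice by exec below: first as scratch
// for the inner RNNCell that produces the raw gates, then to hold the
// activated copy of the gates, so only the larger of the two sizes is needed.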
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    /*
    TensorLayout tmp_layout;
    opr->deduce_layout(input.layout, weight_ih.layout, hx.layout, weight_hh.layout,
                       bias.layout, tmp_layout);
    auto workspace_ptr = workspace.raw_ptr;
    // TensorND tmp{static_cast<void*>(workspace.raw_ptr), tmp_layout};
    TensorND tmp{workspace_ptr, tmp_layout};
    auto new_workspace = Workspace{
            workspace_ptr + tmp.layout.span().dist_byte(),
            workspace.size - tmp.layout.span().dist_byte()};
    */
    // opr->exec(input, weight_ih, hx, weight_hh, bias, tmp, new_workspace);
    opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);

    // activation
    // size_t batch_size = tmp.layout.shape[0];
    size_t batch_size = hx.layout.shape[0];
    size_t hidden_size = hx.layout.shape[1];
    // sigmoid: i f o
    // TensorLayout gates_ifo_layout{TensorShape({batch_size, hidden_size * 3}),
    //                               tmp.layout.dtype};
    TensorND tmp{static_cast<void*>(workspace.raw_ptr), gates.layout};
    TensorLayout gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 3}), gates.layout.dtype};
    TensorND gates_ifo_origin{gates.raw_ptr(), gates_ifo_layout};
    TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
    auto sigmoid = handle->create_operator<ElemwiseForward>();
    sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
    sigmoid->exec({gates_ifo_origin}, gates_ifo);
    // tanh: g
    TensorLayout g_layout{TensorShape({batch_size, hidden_size}), gates.layout.dtype};
    TensorND g_origin{
            static_cast<char*>(gates.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
            g_layout};
    TensorND g{
            static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
            g_layout};
    auto tanh = handle->create_operator<ElemwiseForward>();
    tanh->param().mode = Elemwise::Param::Mode::TANH;
    tanh->exec({g_origin}, g);
    // extract i f o
    TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout};
    TensorND f{
            static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout};
    TensorND o{
            static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte() * 2,
            g_layout};
    // calculate new cell state
    auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
    elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
    elewise_mul_add->exec({f, cx, i, g}, c_new);
    // calculate new hidden state
    tanh->exec({c_new}, h_new);
    auto elewise_mul = handle->create_operator<ElemwiseForward>();
    elewise_mul->param().mode = Elemwise::Param::Mode::MUL;
    elewise_mul->exec({o, h_new}, h_new);
}
}  // namespace lstm_cell
}  // namespace megdnn
@@ -0,0 +1,32 @@
/**
 * \file dnn/src/common/lstm_cell.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"

namespace megdnn {
namespace lstm_cell {
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
        const TensorLayout& gates, Handle* handle);

void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle);
}  // namespace lstm_cell
}  // namespace megdnn
@@ -139,6 +139,8 @@ DEF(LayerNormForward, 6, true, true);
DEF(LayerNormBackward, 8, true, true);
DEF(DropoutForward, 3, true, true);
DEF(DropoutBackward, 3, true, true);
DEF(RNNCellForward, 7, true, true);
DEF(RNNForward, 6, true, true);
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,82 @@
/**
 * \file dnn/src/common/rnn.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/rnn.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
void RNN::deduce_layout(
        const TensorLayout& input, const TensorLayout& hx,
        const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
        TensorLayout& reserve_space) {
    // input: [seq_len, batch_size, input_size]
    // hx: [D * num_layers, batch_size, hidden_size]
    size_t seq_len = input.shape[0];
    size_t batch_size = input.shape[1];
    size_t D = param().bidirectional ? 2 : 1;
    size_t hidden_size = hx.shape[2];
    output = TensorLayout(
            TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
    hy = TensorLayout(hx);
    // reserve_space = {{get_reserve_size(this->handle(), param(), input)},
    // dtype::Byte()};
    reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
}
void RNN::check_exec(
        const TensorLayout& input, const TensorLayout& hx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& reserve_space,
        size_t workspace_in_bytes) {
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", hidden_size=");
        msg.append(std::to_string(param().hidden_size));
        msg.append(", num_layers=");
        msg.append(std::to_string(param().num_layers));
        msg.append(", bidirectional=");
        msg.append(std::to_string(param().bidirectional));
        return msg;
    };
    size_t D = param().bidirectional ? 2 : 1;
    size_t num_layers = param().num_layers;
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
    ASSERT_BRIEF(hx.ndim == 3)
    ASSERT_BRIEF(hx.shape[0] == D * num_layers)
    ASSERT_BRIEF(hx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}
void RNNBackward::deduce_layout(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
        const TensorLayout& dy, const TensorLayout& dhy,
        const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
        TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) {
    dx = x;
    dhx = hx;
    dw = flatten_weights;
}

void RNNBackward::check_exec(
        const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
        const TensorLayout& dy, const TensorLayout& dhy,
        const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
        const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
        size_t workspace_in_bytes) {}

}  // namespace megdnn
@@ -0,0 +1,33 @@
/**
 * \file dnn/src/common/rnn.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/general.h"

namespace megdnn {
namespace rnn {
using Param = param::RNN;

size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayoutArray& weight_ih,
        const TensorLayoutArray& states, const TensorLayoutArray& weight_hh,
        const TensorLayoutArray& bias, const TensorLayout& output,
        const TensorLayoutArray& states_new, const Param& param, Handle* handle);

void exec(
        _megdnn_tensor_in input, _megdnn_in const TensorNDArray& weight_ih,
        _megdnn_in const TensorNDArray& states,
        _megdnn_in const TensorNDArray& weight_hh, _megdnn_in const TensorNDArray& bias,
        _megdnn_tensor_out output, _megdnn_out const TensorNDArray& states_new,
        _megdnn_workspace workspace, const Param& param, Handle* handle);
}  // namespace rnn
}  // namespace megdnn
@@ -0,0 +1,108 @@
/**
 * \file dnn/src/common/rnn_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/rnn_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
void RNNCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) {
    // megdnn_assert(hx.ndim == 2);
    size_t batch_size = hx.shape[0];
    // size_t hidden_size = weight_hh.shape[1];
    size_t gate_hidden_size = weight_ih.shape[0];
    // size_t input_size = weight_ih.shape[1];
    // megdnn_assert(input.shape[1] == input_size);
    // megdnn_assert(hx.shape[1] == hidden_size);
    // megdnn_assert_eq_dtype(input, hx);
    dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}

void RNNCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(input, dst);
    megdnn_assert_eq_dtype(hx, dst);
    deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}  // namespace megdnn
namespace megdnn {
namespace rnn_cell {
size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, Handle* handle) {
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    return dst.span().dist_byte() + opr->get_workspace_in_bytes(hx, weight_hh, dst);
}
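
// exec below keeps input * weight_ih^T in a workspace tensor with dst's
// layout while dst itself receives hx * weight_hh^T, hence the extra
// dst.span().dist_byte() on top of the MatrixMul workspace.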
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
    TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
    _megdnn_workspace new_workspace = {
            workspace.raw_ptr + dst.layout.span().dist_byte(),
            workspace.size - dst.layout.span().dist_byte()};
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    opr->exec(input, weight_ih, tmp, new_workspace);
    opr->exec(hx, weight_hh, dst, new_workspace);
    // if (this->param().bias) add_bias(dst, tmp, bias, dst);
    // if (this->param().bias) {
    auto add_opr = handle->create_operator<ElemwiseForward>();
    add_opr->param().mode = Elemwise::Param::Mode::ADD;
    add_opr->exec({dst, tmp}, dst);
    add_opr->exec({dst, bias_ih}, dst);
    add_opr->exec({dst, bias_hh}, dst);
    // }

    // activation
    using NonlineMode = param::RNNCell::NonlineMode;
    switch (nonline_mode) {
#define cb(_mode)                                                    \
    case NonlineMode::_mode: {                                       \
        auto nonlinear = handle->create_operator<ElemwiseForward>(); \
        nonlinear->param().mode = Elemwise::Param::Mode::_mode;      \
        nonlinear->exec({dst}, dst);                                 \
        break;                                                       \
    }
        cb(RELU);
        cb(TANH);
#undef cb
        case NonlineMode::IDENTITY:
            break;
        default:
            megdnn_assert(false);
    }
}
}  // namespace rnn_cell
}  // namespace megdnn
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* \file dnn/src/common/rnn_cell.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs/base.h" | |||||
#include "megdnn/oprs/general.h" | |||||
namespace megdnn { | |||||
namespace rnn_cell { | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst, Handle* handle); | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace, | |||||
param::RNNCell::NonlineMode nonline_mode, Handle* handle); | |||||
} // namespace rnn_cell | |||||
} // namespace megdnn |
@@ -160,6 +160,29 @@ void TensorDesc::set( | |||||
    }
}
void TensorDesc::set_nd(const TensorLayout& layout, int pad) {
    // cudnn requires the descriptor to have at least `pad` dimensions, so
    // trailing dims are padded with 1 and strides are recomputed as if the
    // tensor were fully contiguous (any strides on the incoming layout are
    // ignored)
    int nbDims = layout.ndim < static_cast<size_t>(pad)
                       ? pad
                       : static_cast<int>(layout.ndim);
    megdnn_assert(nbDims <= CUDNN_DIM_MAX, "too many dims for cudnn: %d", nbDims);
    int dimA[CUDNN_DIM_MAX], strideA[CUDNN_DIM_MAX];
    for (int i = 0; i < nbDims; ++i) {
        dimA[i] = i < static_cast<int>(layout.ndim)
                        ? static_cast<int>(layout.shape[i])
                        : 1;
    }
    for (int i = nbDims - 1; i >= 0; --i) {
        strideA[i] = i + 1 < nbDims ? strideA[i + 1] * dimA[i + 1] : 1;
    }
    cudnn_check(cudnnSetTensorNdDescriptor(
            desc, to_cudnn_dtype(layout.dtype), nbDims, dimA, strideA));
}
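// Illustration (hypothetical shapes): a contiguous {32, 64} float32 layout
// with the default pad = 3 yields dimA = {32, 64, 1} and strideA = {64, 1, 1}.
// Since the incoming strides are ignored, the resulting descriptor always
// describes a densely packed tensor.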
std::string TensorDesc::to_string() {
    cudnnDataType_t data_type;
    int n;
@@ -433,6 +456,97 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||||
            desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
}
DropoutDesc::DropoutDesc() { | |||||
cudnn_check(cudnnCreateDropoutDescriptor(&desc)); | |||||
} | |||||
DropoutDesc::~DropoutDesc() { | |||||
cudnn_check(cudnnDestroyDropoutDescriptor(desc)); | |||||
} | |||||
void DropoutDesc::set(float dropout, Handle* handle, TensorND& state) { | |||||
cudnn_check(cudnnSetDropoutDescriptor( | |||||
desc, cudnn_handle(handle), dropout, state.raw_ptr(), | |||||
state.layout.span().dist_byte(), 0 // seed | |||||
)); | |||||
} | |||||
void DropoutDesc::set_no_dropout(Handle* handle) { | |||||
cudnn_check( | |||||
cudnnSetDropoutDescriptor(desc, cudnn_handle(handle), 0, nullptr, 0, 0)); | |||||
} | |||||
RNNDesc::RNNDesc() { | |||||
cudnn_check(cudnnCreateRNNDescriptor(&desc)); | |||||
} | |||||
RNNDesc::~RNNDesc() { | |||||
cudnn_check(cudnnDestroyRNNDescriptor(desc)); | |||||
} | |||||
void RNNDesc::set( | |||||
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, | |||||
bool bidirectional, bool bias, const megdnn::DType dtype, cudnnRNNMode_t mode, | |||||
DropoutDesc& dropout_desc, Handle* handle) { | |||||
    cudnnRNNBiasMode_t bias_mode = bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
    cudnnDirectionMode_t dir_mode =
            bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
    // math precision: fp16 tensors use half-precision math, everything else
    // falls back to fp32
    cudnnDataType_t math_prec;
    if (dtype.enumv() == DTypeEnum::Float16)
        math_prec = CUDNN_DATA_HALF;
    else
        math_prec = CUDNN_DATA_FLOAT;
#if false  // TODO: switch to cudnnSetRNNDescriptor_v8 once the v8 path (CUDNN_MAJOR >= 8) is enabled
cudnn_check(cudnnSetRNNDescriptor_v8( | |||||
desc, CUDNN_RNN_ALGO_STANDARD, mode, bias_mode, dir_mode, | |||||
CUDNN_LINEAR_INPUT, to_cudnn_dtype(dtype), math_prec, CUDNN_DEFAULT_MATH, | |||||
input_size, hidden_size, proj_size, num_layers, dropout_desc.desc, | |||||
CUDNN_RNN_PADDED_IO_DISABLED)); | |||||
#else | |||||
cudnn_check(cudnnSetRNNDescriptor_v6( | |||||
cudnn_handle(handle), desc, hidden_size, num_layers, dropout_desc.desc, | |||||
CUDNN_LINEAR_INPUT, dir_mode, mode, CUDNN_RNN_ALGO_STANDARD, math_prec)); | |||||
#endif | |||||
} | |||||
RNNDataDesc::RNNDataDesc() { | |||||
cudnn_check(cudnnCreateRNNDataDescriptor(&desc)); | |||||
} | |||||
RNNDataDesc::~RNNDataDesc() { | |||||
cudnn_check(cudnnDestroyRNNDataDescriptor(desc)); | |||||
} | |||||
void RNNDataDesc::set( | |||||
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, | |||||
DType dtype) { | |||||
    // for now, all tensors are padded on the Python side
// int seqLengthArray[batchSize]; | |||||
// for (int i = 0; i < batchSize; ++i) seqLengthArray[i] = maxSeqLength; | |||||
cudnn_check(cudnnSetRNNDataDescriptor( | |||||
desc, to_cudnn_dtype(dtype), CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, | |||||
maxSeqLength, batchSize, vectorSize, devSeqLengths, nullptr)); | |||||
} | |||||
RNNWeightFilterDesc::RNNWeightFilterDesc() { | |||||
cudnn_check(cudnnCreateFilterDescriptor(&desc)); | |||||
} | |||||
RNNWeightFilterDesc::~RNNWeightFilterDesc() { | |||||
cudnn_check(cudnnDestroyFilterDescriptor(desc)); | |||||
} | |||||
void RNNWeightFilterDesc::set(const TensorLayout& flatten_weights) { | |||||
    int weight_elem_num = static_cast<int>(flatten_weights.total_nr_elems());
int dimW[] = {weight_elem_num, 1, 1}; | |||||
cudnn_check(cudnnSetFilterNdDescriptor( | |||||
desc, to_cudnn_dtype(flatten_weights.dtype), CUDNN_TENSOR_NCHW, 3, dimW)); | |||||
} | |||||
////////////////////////// CudnnAlgoPack //////////////////////////
#define V1(v) #v
@@ -30,6 +30,7 @@ public: | |||||
    void set(
            const TensorLayout& layout,
            const param::Convolution::Format = param::Convolution::Format::NCHW);
void set_nd(const TensorLayout& layout, int pad = 3); // at least 3 dimensions | |||||
    std::string to_string();
    ~TensorDesc();
    cudnnTensorDescriptor_t desc;
@@ -121,6 +122,44 @@ public: | |||||
    static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> conv3d_fwd_algos();
};
class DropoutDesc { | |||||
public: | |||||
DropoutDesc(); | |||||
void set(float dropout, Handle* handle, TensorND& state); | |||||
void set_no_dropout(Handle* handle); | |||||
~DropoutDesc(); | |||||
cudnnDropoutDescriptor_t desc; | |||||
}; | |||||
class RNNDesc { | |||||
public: | |||||
RNNDesc(); | |||||
void set( | |||||
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, | |||||
bool bidirectional, bool bias, const megdnn::DType dtype, | |||||
cudnnRNNMode_t mode, DropoutDesc& dropout_desc, Handle* handle); | |||||
~RNNDesc(); | |||||
cudnnRNNDescriptor_t desc; | |||||
}; | |||||
class RNNDataDesc { | |||||
public: | |||||
RNNDataDesc(); | |||||
void set( | |||||
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, | |||||
DType dtype); | |||||
~RNNDataDesc(); | |||||
cudnnRNNDataDescriptor_t desc; | |||||
}; | |||||
class RNNWeightFilterDesc { | |||||
public: | |||||
RNNWeightFilterDesc(); | |||||
void set(const TensorLayout& flatten_weights); | |||||
~RNNWeightFilterDesc(); | |||||
cudnnFilterDescriptor_t desc; | |||||
}; | |||||
} // namespace cuda
} // namespace megdnn
@@ -52,6 +52,8 @@ | |||||
#include "src/cuda/local_share/opr_impl.h" | #include "src/cuda/local_share/opr_impl.h" | ||||
#include "src/cuda/lrn/opr_impl.h" | #include "src/cuda/lrn/opr_impl.h" | ||||
#include "src/cuda/lsq/opr_impl.h" | #include "src/cuda/lsq/opr_impl.h" | ||||
#include "src/cuda/lstm/opr_impl.h" | |||||
#include "src/cuda/lstm_cell/opr_impl.h" | |||||
#include "src/cuda/mask_conv/opr_impl.h" | #include "src/cuda/mask_conv/opr_impl.h" | ||||
#include "src/cuda/matrix_inverse/opr_impl.h" | #include "src/cuda/matrix_inverse/opr_impl.h" | ||||
#include "src/cuda/matrix_mul/opr_impl.h" | #include "src/cuda/matrix_mul/opr_impl.h" | ||||
@@ -68,6 +70,8 @@ | |||||
#include "src/cuda/repeat/opr_impl.h" | #include "src/cuda/repeat/opr_impl.h" | ||||
#include "src/cuda/resize/opr_impl.h" | #include "src/cuda/resize/opr_impl.h" | ||||
#include "src/cuda/rng/opr_impl.h" | #include "src/cuda/rng/opr_impl.h" | ||||
#include "src/cuda/rnn/opr_impl.h" | |||||
#include "src/cuda/rnn_cell/opr_impl.h" | |||||
#include "src/cuda/roi_align/opr_impl.h" | #include "src/cuda/roi_align/opr_impl.h" | ||||
#include "src/cuda/roi_copy/opr_impl.h" | #include "src/cuda/roi_copy/opr_impl.h" | ||||
#include "src/cuda/roi_pooling/opr_impl.h" | #include "src/cuda/roi_pooling/opr_impl.h" | ||||
@@ -90,7 +94,146 @@ | |||||
namespace megdnn {
namespace cuda {
// MEGDNN_FOREACH_OPR_CLASS(MEGDNN_SPECIALIZE_CREATE_OPERATOR)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvPoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBiasForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Images2NeibsForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Images2NeibsBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SlidingWindowTransposeForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SlidingWindowTransposeBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseMultiType); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LRNForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LRNBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROIPoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROIPoolingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspectiveForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspectiveBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspectiveBackwardMat); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DotForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixInverse); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SVDForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CondTake); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumsumForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TransposeForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConcatForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SplitForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TileForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TileBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RepeatForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RepeatBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgsortBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingRemapForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingRemapBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(IncrMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SetMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GammaRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BetaRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoissonRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PermutationRNG); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ShuffleRNGForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ShuffleRNGBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SeparableConvForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SeparableFilterForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CvtColor); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpAffine); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianBlur); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Resize); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ResizeBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaxTensorDiff); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskConvForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(MaskPropagate); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Convolution3DBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DeformableConvForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DeformableConvBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DeformableConvBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DeformablePSROIPoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DeformablePSROIPoolingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutFormat); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TopK); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalShareForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalShareBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LocalShareBackwardFilter); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROIAlignForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROIAlignBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CorrelationForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CorrelationBackwardData1); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CorrelationBackwardData2); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchConvBiasForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Remap); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RemapBackwardData); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RemapBackwardMat); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePoolingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DctChannelSelectForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(FakeQuantForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(FakeQuantBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TQTForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(TQTBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CheckNonFinite); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSQForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSQBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Fill); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(PaddingBackward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward); | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> HandleImpl::create_operator() { | |||||
megdnn_throw("unsupported cuda opr"); | |||||
return nullptr; | |||||
} | |||||
#define MEGDNN_INST_CREATE_OPERATOR(opr) \ | |||||
template std::unique_ptr<megdnn::opr> HandleImpl::create_operator(); | |||||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||||
} // namespace cuda
} // namespace megdnn
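// For exposition: each MEGDNN_SPECIALIZE_CREATE_OPERATOR(Opr) line above is
// expected to expand to roughly the following (an approximation -- the real
// macro is defined in the handle implementation headers):
//
//   template <>
//   std::unique_ptr<PoolingForward> HandleImpl::create_operator<PoolingForward>() {
//       return std::make_unique<PoolingForwardImpl>(this);
//   }
//
// Any operator class left without a specialization falls through to the
// generic template above and throws "unsupported cuda opr" at creation time
// rather than failing at link time.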
@@ -0,0 +1,112 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/lstm/opr_impl.h" | |||||
#include "src/cuda/lstm/utils.h" | |||||
#include "src/cuda/utils.h" | |||||
#include <cudnn.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
void LSTMImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace) { | |||||
    Handle* handle = this->handle();
    rnn::RNNForwardDescHolder_v6 desc_holder =
            lstm::get_RNNDescHolder_v6(handle, param(), input.layout);
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); | |||||
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); | |||||
RNNWeightFilterDesc w_desc; | |||||
w_desc.set(flatten_weights.layout); | |||||
if (param().fwd_mode == param::LSTM::FwdMode::TRAINING) { | |||||
cudnn_check(cudnnRNNForwardTraining( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||||
hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc, | |||||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||||
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, | |||||
reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||||
} else { | |||||
cudnn_check(cudnnRNNForwardInference( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||||
                hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||||
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size)); | |||||
} | |||||
} | |||||
size_t LSTMImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& cy, | |||||
const TensorLayout& reserve_space) { | |||||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||||
lstm::get_RNNDescHolder_v6(this->handle(), param(), input); | |||||
return desc_holder.workspace_size; | |||||
} | |||||
size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||||
lstm::get_RNNDescHolder_v6(this->handle(), param(), input); | |||||
return desc_holder.reserveSpace_size; | |||||
} | |||||
void LSTMBackwardImpl::exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||||
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||||
Handle* handle = this->handle(); | |||||
size_t seq_len = x.layout.shape[0]; | |||||
auto desc_holder = lstm::get_RNNDescHolder_v6(handle, param(), x.layout); | |||||
    // keep the descriptor vectors alive for the duration of the calls below;
    // taking .data() of a temporary would leave dangling pointers
    auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
    auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
    auto* x_desc_arr_ptr = x_desc_arr.data();
    auto* y_desc_arr_ptr = y_desc_arr.data();
RNNWeightFilterDesc w_desc; | |||||
w_desc.set(flatten_weights.layout); | |||||
cudnn_check(cudnnRNNBackwardData( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr, | |||||
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc, | |||||
dhy.raw_ptr(), desc_holder.cy_desc.desc, dcy.raw_ptr(), w_desc.desc, | |||||
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), | |||||
desc_holder.cx_desc.desc, cx.raw_ptr(), x_desc_arr_ptr, dx.raw_ptr(), | |||||
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, | |||||
dcx.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, | |||||
reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||||
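    // NB: per the cuDNN documentation, cudnnRNNBackwardWeights accumulates
    // into dw rather than overwriting it, so dw is assumed to have been
    // zero-initialized by the caller.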
cudnn_check(cudnnRNNBackwardWeights( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, | |||||
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, | |||||
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, | |||||
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||||
} | |||||
size_t LSTMBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { | |||||
auto desc_holder = lstm::get_RNNDescHolder_v6(this->handle(), param(), x); | |||||
return desc_holder.workspace_size; | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class LSTMImpl : public LSTM { | |||||
public: | |||||
using LSTM::LSTM; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace); | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& cy, | |||||
const TensorLayout& reserve_space); | |||||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||||
}; | |||||
class LSTMBackwardImpl : public LSTMBackward { | |||||
public: | |||||
using LSTMBackward::LSTMBackward; | |||||
virtual void exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, | |||||
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, | |||||
_megdnn_workspace workspace); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm/utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/lstm/utils.h" | |||||
#include "src/cuda/utils.h" | |||||
#include <cudnn.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace lstm { | |||||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||||
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input) { | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
size_t input_size = input.shape[2]; | |||||
    cudnnRNNMode_t mode = CUDNN_LSTM;
RNNForwardDescHolder_v6 desc_holder( | |||||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||||
input.dtype, mode); | |||||
return desc_holder; | |||||
} | |||||
} // namespace lstm | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,23 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
#include "src/cuda/rnn/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace lstm { | |||||
using megdnn::cuda::rnn::RNNForwardDescHolder_v6; | |||||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||||
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input); | |||||
} // namespace lstm | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/lstm_cell/opr_impl.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs/base.h" | |||||
#include "src/common/lstm_cell.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
size_t LSTMCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||||
const TensorLayout& gates) { | |||||
return megdnn::lstm_cell::get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||||
handle()); | |||||
} | |||||
void LSTMCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||||
megdnn::lstm_cell::exec( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||||
workspace, handle()); | |||||
} | |||||
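// NB: unlike the full-sequence LSTM, the cell operator never calls into
// cuDNN -- it reuses the backend-agnostic composition from src/common
// (MatrixMul + Elemwise), since cuDNN only exposes whole-sequence RNN
// entry points.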
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* \file dnn/src/cuda/lstm_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cuda/rnn_cell/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class LSTMCellImpl : public LSTMCell { | |||||
public: | |||||
using LSTMCell::LSTMCell; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, | |||||
const TensorLayout& c_new, const TensorLayout& gates) override; | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,170 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/rnn/opr_impl.h" | |||||
#include "src/common/rnn.h" | |||||
#include "src/cuda/utils.h" | |||||
#include <cudnn.h>
namespace megdnn {
namespace cuda {
void RNNImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace) { | |||||
Handle* handle = this->handle(); | |||||
#if false  // TODO: enable the cudnnRNNForward (v8) path when targeting CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input.layout); | |||||
void* workspace_ptr = workspace.raw_ptr; | |||||
void* reserveSpace_ptr = static_cast<uint8_t*>(workspace_ptr) + desc_holder.workspace_size; | |||||
cudnn_check(cudnnRNNForward( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.fwdMode, desc_holder.devSeqLengths, | |||||
desc_holder.x_desc.desc, input.raw_ptr(), desc_holder.y_desc.desc, output.raw_ptr(), | |||||
desc_holder.h_desc.desc, hx.raw_ptr(), hy.raw_ptr(), | |||||
desc_holder.h_desc.desc, nullptr, nullptr, | |||||
desc_holder.weight_size, flatten_weights.raw_ptr(), desc_holder.workspace_size, workspace_ptr, | |||||
desc_holder.reserveSpace_size, reserveSpace_ptr | |||||
)); | |||||
#else | |||||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input.layout); | |||||
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); | |||||
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); | |||||
RNNWeightFilterDesc w_desc; | |||||
w_desc.set(flatten_weights.layout); | |||||
if (param().fwd_mode == param::RNN::FwdMode::TRAINING) { | |||||
cudnn_check(cudnnRNNForwardTraining( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||||
                hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||||
                desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, nullptr,
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(), | |||||
desc_holder.reserveSpace_size)); | |||||
} else { | |||||
cudnn_check(cudnnRNNForwardInference( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||||
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc, | |||||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||||
nullptr, workspace.raw_ptr, desc_holder.workspace_size)); | |||||
} | |||||
#endif | |||||
} | |||||
size_t RNNImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& reserve_space) { | |||||
#if false  // TODO: enable the v8 descriptor path when targeting CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input); | |||||
#else | |||||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input); | |||||
#endif | |||||
return desc_holder.workspace_size; | |||||
} | |||||
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input); | |||||
return desc_holder.reserveSpace_size; | |||||
} | |||||
/*rnn::RNNForwardDescHolder RNNImpl::get_desc_holder(const TensorLayout& input) { | |||||
Handle* handle = this->handle(); | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
size_t input_size = input.shape[2]; | |||||
auto _param = param(); | |||||
cudnnRNNMode_t mode; | |||||
using NonlineMode = param::RNN::NonlineMode; | |||||
switch (_param.nonlineMode) { | |||||
case NonlineMode::RELU: | |||||
mode = CUDNN_RNN_RELU; | |||||
break; | |||||
case NonlineMode::TANH: | |||||
mode = CUDNN_RNN_TANH; | |||||
break; | |||||
} | |||||
cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING; | |||||
using FwdMode = param::RNN::FwdMode; | |||||
switch (_param.fwd_mode) { | |||||
case FwdMode::TRAINING: | |||||
fwdMode = CUDNN_FWD_MODE_TRAINING; | |||||
break; | |||||
case FwdMode::INFERENCE: | |||||
fwdMode = CUDNN_FWD_MODE_INFERENCE; | |||||
break; | |||||
} | |||||
rnn::RNNForwardDescHolder desc_holder( | |||||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||||
input.dtype, mode, fwdMode); | |||||
return desc_holder; | |||||
}*/ | |||||
void RNNBackwardImpl::exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||||
_megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||||
Handle* handle = this->handle(); | |||||
size_t seq_len = x.layout.shape[0]; | |||||
auto desc_holder = rnn::get_RNNDescHolder_v6(handle, param(), x.layout); | |||||
    // keep the descriptor vectors alive; .data() on a temporary would dangle
    auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
    auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
    auto* x_desc_arr_ptr = x_desc_arr.data();
    auto* y_desc_arr_ptr = y_desc_arr.data();
RNNWeightFilterDesc w_desc; | |||||
w_desc.set(flatten_weights.layout); | |||||
    cudnn_check(cudnnRNNBackwardData(
            cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr,
            y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc,
            dhy.raw_ptr(), desc_holder.cy_desc.desc, nullptr, w_desc.desc,
            flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(),
            desc_holder.cx_desc.desc, nullptr, x_desc_arr_ptr, dx.raw_ptr(),
            desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, nullptr,
            workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(),
            desc_holder.reserveSpace_size));
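    // NB: cudnnRNNBackwardWeights below accumulates into dw instead of
    // overwriting it (cuDNN-documented behavior), so dw is assumed to be
    // zero-filled on entry.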
cudnn_check(cudnnRNNBackwardWeights( | |||||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, | |||||
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, | |||||
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, | |||||
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||||
} | |||||
size_t RNNBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { | |||||
auto desc_holder = rnn::get_RNNDescHolder_v6(this->handle(), param(), x); | |||||
return desc_holder.workspace_size; | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,57 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
#include "src/cuda/rnn/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class RNNImpl : public RNN { | |||||
public: | |||||
using RNN::RNN; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace); | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& reserve_space); | |||||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||||
// private: | |||||
// rnn::RNNForwardDescHolder get_desc_holder(const TensorLayout& input); | |||||
}; | |||||
class RNNBackwardImpl : public RNNBackward { | |||||
public: | |||||
using RNNBackward::RNNBackward; | |||||
virtual void exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, | |||||
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, | |||||
_megdnn_workspace workspace); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,138 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn/utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/rnn/utils.h" | |||||
#include "src/cuda/utils.h" | |||||
#include <cudnn.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace rnn { | |||||
/*RNNForwardDescHolder::RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t | |||||
batch_size, size_t hidden_size, size_t input_size, size_t proj_size, size_t num_layers, | |||||
bool bidirectional, bool bias, DType dtype, cudnnRNNMode_t _mode, cudnnForwardMode_t | |||||
_fwdMode) : mode(_mode), fwdMode(_fwdMode) | |||||
{ | |||||
size_t D = bidirectional ? 2 : 1; | |||||
// TODO: set dropout to 0 in inference mode | |||||
dropout_desc.set_no_dropout(handle); | |||||
// seq len is unified (not packed) | |||||
// cuda_check(cudaMalloc((void**)&devSeqLengths, sizeof(int32_t) * batch_size)); | |||||
devSeqLengths = (int32_t*)malloc(sizeof(int32_t) * batch_size); | |||||
for (size_t i = 0; i < batch_size; ++i) devSeqLengths[i] = seq_len; | |||||
// proj size should be smaller than hidden size according to cudnn api | |||||
// otherwise it is disabled | |||||
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : | |||||
proj_size; rnn_desc.set( input_size, hidden_size, proj_size, num_layers, bidirectional, | |||||
bias, dtype, mode, dropout_desc, handle | |||||
); | |||||
x_desc.set(batch_size, input_size, seq_len, devSeqLengths, dtype); | |||||
y_desc.set(batch_size, D * proj_size, seq_len, | |||||
devSeqLengths, dtype); | |||||
h_desc.set_nd(TensorLayout(TensorShape{D * num_layers, batch_size, proj_size}, | |||||
dtype)); | |||||
cudnn_check(cudnnGetRNNWeightSpaceSize(cudnn_handle(handle), rnn_desc.desc, | |||||
&weight_size)); | |||||
cudnn_check(cudnnGetRNNTempSpaceSizes( | |||||
cudnn_handle(handle), rnn_desc.desc, fwdMode, x_desc.desc, | |||||
&workspace_size, &reserveSpace_size | |||||
)); | |||||
} | |||||
RNNForwardDescHolder::~RNNForwardDescHolder() { | |||||
// cuda_check(cudaFree(devSeqLengths)); | |||||
free(devSeqLengths); | |||||
}*/ | |||||
RNNForwardDescHolder_v6::RNNForwardDescHolder_v6( | |||||
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size, | |||||
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||||
bool bias, DType dtype, cudnnRNNMode_t _mode) | |||||
: mode(_mode), seq_len(seq_len) { | |||||
size_t D = bidirectional ? 2 : 1; | |||||
// TODO: set dropout to 0 in inference mode | |||||
dropout_desc.set_no_dropout(handle); | |||||
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : proj_size; | |||||
rnn_desc.set( | |||||
input_size, hidden_size, proj_size, num_layers, bidirectional, bias, dtype, | |||||
mode, dropout_desc, handle); | |||||
x_descs.resize(seq_len); | |||||
y_descs.resize(seq_len); | |||||
for (size_t i = 0; i < seq_len; ++i) { | |||||
x_descs[i].set_nd(TensorLayout(TensorShape{batch_size, input_size}, dtype), 3); | |||||
y_descs[i].set_nd( | |||||
TensorLayout(TensorShape{batch_size, D * hidden_size}, dtype), 3); | |||||
} | |||||
#define SET_H(_var) \ | |||||
_var.set_nd(TensorLayout( \ | |||||
TensorShape{D * num_layers, batch_size, hidden_size}, dtype)); | |||||
SET_H(hx_desc) | |||||
SET_H(cx_desc) | |||||
SET_H(hy_desc) | |||||
SET_H(cy_desc) | |||||
#undef SET_H | |||||
std::vector<cudnnTensorDescriptor_t> x_desc_arr = get_descs(x_descs); | |||||
cudnn_check(cudnnGetRNNWorkspaceSize( | |||||
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), | |||||
&workspace_size)); | |||||
cudnn_check(cudnnGetRNNTrainingReserveSize( | |||||
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), | |||||
&reserveSpace_size)); | |||||
} | |||||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||||
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input) { | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
size_t input_size = input.shape[2]; | |||||
    cudnnRNNMode_t mode;
    using NonlineMode = param::RNN::NonlineMode;
    switch (_param.nonlineMode) {
        case NonlineMode::RELU:
            mode = CUDNN_RNN_RELU;
            break;
        case NonlineMode::TANH:
            mode = CUDNN_RNN_TANH;
            break;
        default:
            megdnn_assert(false, "unsupported nonlinear mode for cudnn RNN");
    }
RNNForwardDescHolder_v6 desc_holder( | |||||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||||
input.dtype, mode); | |||||
return desc_holder; | |||||
} | |||||
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs) { | |||||
std::vector<cudnnTensorDescriptor_t> r; | |||||
r.reserve(descs.size()); | |||||
for (auto& desc : descs) { | |||||
r.emplace_back(desc.desc); | |||||
} | |||||
return r; | |||||
} | |||||
} // namespace rnn | |||||
} // namespace cuda | |||||
} // namespace megdnn |
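// Shape illustration for the v6 descriptor setup above (hypothetical sizes):
// with seq_len = 4, batch_size = 8, input_size = 16, hidden_size = 32,
// num_layers = 2 and D = 1 (unidirectional):
//   x_descs: 4 identical 3-D descriptors {8, 16, 1}  (padded by set_nd)
//   y_descs: 4 identical 3-D descriptors {8, 32, 1}  (D * hidden_size wide)
//   hx/cx/hy/cy: {2, 8, 32}                          (D * num_layers, N, H)
// The v6 workspace/reserve queries receive the whole x-descriptor array, so
// the sequence length is conveyed by the array length rather than a field.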
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace rnn { | |||||
// v8, not for now | |||||
/*struct RNNForwardDescHolder { | |||||
int32_t* devSeqLengths; | |||||
cudnnRNNMode_t mode; | |||||
cudnnForwardMode_t fwdMode; | |||||
RNNDesc rnn_desc; | |||||
DropoutDesc dropout_desc; | |||||
RNNDataDesc x_desc, y_desc; | |||||
TensorDesc h_desc; | |||||
size_t weight_size, workspace_size, reserveSpace_size; | |||||
RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t batch_size, size_t | |||||
hidden_size, size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||||
bool bias, DType dtype, | |||||
cudnnRNNMode_t _mode, cudnnForwardMode_t _fwdMode); ~RNNForwardDescHolder(); | |||||
};*/ | |||||
struct RNNForwardDescHolder_v6 { | |||||
cudnnRNNMode_t mode; | |||||
RNNDesc rnn_desc; | |||||
int seq_len; | |||||
DropoutDesc dropout_desc; | |||||
std::vector<TensorDesc> x_descs, y_descs; | |||||
TensorDesc hx_desc, cx_desc, hy_desc, cy_desc; | |||||
size_t workspace_size, reserveSpace_size; | |||||
RNNForwardDescHolder_v6( | |||||
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size, | |||||
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||||
bool bias, DType dtype, cudnnRNNMode_t _mode); | |||||
}; | |||||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||||
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input); | |||||
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs); | |||||
} // namespace rnn | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,35 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cuda/rnn_cell/opr_impl.h" | |||||
#include "src/common/rnn_cell.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
size_t RNNCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) { | |||||
    return megdnn::rnn_cell::get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle());
} | |||||
void RNNCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
megdnn::rnn_cell::exec( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace, | |||||
param().nonlineMode, handle()); | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/src/cuda/rnn_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class RNNCellImpl : public RNNCell { | |||||
public: | |||||
using RNNCell::RNNCell; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) override; | |||||
/* | |||||
private: | |||||
void add_bias(_megdnn_tensor_in A, | |||||
_megdnn_tensor_in B, | |||||
_megdnn_tensor_in bias, | |||||
_megdnn_tensor_out C); | |||||
*/ | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn |
@@ -54,6 +54,8 @@ | |||||
#include "src/naive/local_share/opr_impl.h" | #include "src/naive/local_share/opr_impl.h" | ||||
#include "src/naive/lrn/opr_impl.h" | #include "src/naive/lrn/opr_impl.h" | ||||
#include "src/naive/lsq/opr_impl.h" | #include "src/naive/lsq/opr_impl.h" | ||||
#include "src/naive/lstm/opr_impl.h" | |||||
#include "src/naive/lstm_cell/opr_impl.h" | |||||
#include "src/naive/mask_conv/opr_impl.h" | #include "src/naive/mask_conv/opr_impl.h" | ||||
#include "src/naive/matrix_inverse/opr_impl.h" | #include "src/naive/matrix_inverse/opr_impl.h" | ||||
#include "src/naive/matrix_mul/opr_impl.h" | #include "src/naive/matrix_mul/opr_impl.h" | ||||
@@ -70,6 +72,8 @@ | |||||
#include "src/naive/repeat/opr_impl.h" | #include "src/naive/repeat/opr_impl.h" | ||||
#include "src/naive/resize/opr_impl.h" | #include "src/naive/resize/opr_impl.h" | ||||
#include "src/naive/rng/opr_impl.h" | #include "src/naive/rng/opr_impl.h" | ||||
#include "src/naive/rnn/opr_impl.h" | |||||
#include "src/naive/rnn_cell/opr_impl.h" | |||||
#include "src/naive/roi_align/opr_impl.h" | #include "src/naive/roi_align/opr_impl.h" | ||||
#include "src/naive/roi_copy/opr_impl.h" | #include "src/naive/roi_copy/opr_impl.h" | ||||
#include "src/naive/roi_pooling/opr_impl.h" | #include "src/naive/roi_pooling/opr_impl.h" | ||||
@@ -0,0 +1,146 @@ | |||||
/** | |||||
* \file dnn/src/naive/lstm/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/naive/lstm/opr_impl.h" | |||||
#include "src/naive/rnn/funcs.h" | |||||
#include "src/naive/rnn/rnn.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
using rnn::LSTMCellWeightWrapper; | |||||
void LSTMImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace) { | |||||
auto _param = param(); | |||||
size_t D = _param.bidirectional ? 2 : 1; | |||||
size_t num_layers = _param.num_layers; | |||||
size_t input_size = input.layout.shape[2]; | |||||
std::vector<LSTMCellWeightWrapper> cells; | |||||
size_t used_workspace_size = rnn::get_cells<LSTMCellWeightWrapper>( | |||||
D, num_layers, input_size, _param.hidden_size, _param.bias, cells, | |||||
flatten_weights, workspace); | |||||
Workspace new_workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
TensorNDArray states = {hx, cx}, states_new = {hy, cy}; | |||||
rnn::exec_internal<LSTMCellWeightWrapper, LSTMCellForward>( | |||||
cells, input, states, states_new, output, reserve_space, num_layers, D, | |||||
this->handle(), new_workspace); | |||||
} | |||||
size_t LSTMImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& cy, | |||||
const TensorLayout& reserve_space) { | |||||
size_t workspace_size = rnn::get_workspace_in_bytes<LSTMCellForward>( | |||||
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, | |||||
this->handle()); | |||||
if (!param().bias) { // use fake bias (all 0) | |||||
TensorLayout bias_layout = {{param().hidden_size * 4}, flatten_weights.dtype}; | |||||
workspace_size += bias_layout.span().dist_byte(); | |||||
} | |||||
workspace_size += output.span().dist_byte(); | |||||
return workspace_size; | |||||
} | |||||
size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||||
size_t num_layers = param().num_layers; | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; | |||||
// 2 for hidden state and cell state | |||||
return 2 * num_layers * D * seq_len * state_layout.span().dist_byte(); | |||||
} | |||||
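// Quick arithmetic check of the formula above (hypothetical sizes):
// num_layers = 2, D = 1, seq_len = 5, batch_size = 8, hidden_size = 32,
// float32 => state_layout spans 8 * 32 * 4 = 1024 bytes, so reserve_space is
// 2 * 2 * 1 * 5 * 1024 = 20480 bytes -- one hidden state plus one cell state
// per (layer, direction, timestep).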
void LSTMBackwardImpl::exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||||
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||||
TensorNDArray layer_inputs; | |||||
TensorNDArray layer_outputs; | |||||
std::vector<std::vector<TensorNDArray>> cell_seq_states; | |||||
size_t num_layers = param().num_layers; | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t input_size = x.layout.shape[2]; | |||||
size_t hidden_size = param().hidden_size; | |||||
size_t used_workspace_size = 0; | |||||
// get cells | |||||
std::vector<LSTMCellWeightWrapper> cells; | |||||
used_workspace_size += rnn::get_cells<LSTMCellWeightWrapper>( | |||||
D, num_layers, input_size, hidden_size, param().bias, cells, | |||||
flatten_weights, workspace); | |||||
// get formatted inputs | |||||
Workspace new_workspace = Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
used_workspace_size += rnn::get_inputs_for_exec<LSTMCellWeightWrapper>( | |||||
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, | |||||
layer_outputs, cell_seq_states, param::RNNCell::NonlineMode::IDENTITY, | |||||
new_workspace); | |||||
// dhy arr, dhx arr | |||||
TensorNDArray dhy_arr = {dhy, dcy}, dhx_arr = {dhx, dcx}; | |||||
// exec | |||||
new_workspace = Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
rnn::backward_exec_internal<LSTMCellWeightWrapper>( | |||||
cells, D, num_layers, input_size, param().bias, | |||||
param::RNNCell::NonlineMode::IDENTITY, layer_inputs, layer_outputs, | |||||
cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, this->handle(), | |||||
new_workspace); | |||||
} | |||||
size_t LSTMBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t num_layers = param().num_layers; | |||||
size_t hidden_size = param().hidden_size; | |||||
size_t gate_hidden_size = hidden_size * 4; | |||||
size_t max_input_size = std::max(x.shape[2], D * hidden_size); | |||||
size_t workspace_size = LSTMCellWeightWrapper::backward_workspace_size_in_bytes( | |||||
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); | |||||
if (!param().bias) { // use fake bias (all 0) | |||||
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; | |||||
workspace_size += bias_layout.span().dist_byte() * | |||||
2; // times 2 because another bias is allocated in | |||||
// backward_exec_internal | |||||
} | |||||
workspace_size += num_layers * y.span().dist_byte(); | |||||
// add back exec workspace size | |||||
workspace_size += y.span().dist_byte() * 2; | |||||
workspace_size += x.span().dist_byte() * 2; | |||||
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; | |||||
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||||
workspace_size += wih.span().dist_byte(); | |||||
workspace_size += whh.span().dist_byte(); | |||||
workspace_size += bias.span().dist_byte(); | |||||
return workspace_size; | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,56 @@ | |||||
/** | |||||
* \file dnn/src/naive/lstm/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class LSTMImpl : public LSTM { | |||||
public: | |||||
using LSTM::LSTM; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace); | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& cy, | |||||
const TensorLayout& reserve_space); | |||||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||||
}; | |||||
class LSTMBackwardImpl : public LSTMBackward { | |||||
public: | |||||
using LSTMBackward::LSTMBackward; | |||||
virtual void exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, | |||||
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, | |||||
_megdnn_workspace workspace); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,55 @@ | |||||
/** | |||||
* \file dnn/src/naive/lstm/template_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/rnn/funcs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
template <> | |||||
void cell_opr_exec<LSTMCellForward>( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { | |||||
auto opr = handle->create_operator<LSTMCellForward>(); | |||||
TensorLayout gates, h_new, c_new; | |||||
opr->deduce_layout( | |||||
input.layout, weight_ih.layout, bias_ih.layout, states[0].layout, | |||||
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); | |||||
TensorND gates_tensor{workspace.raw_ptr, gates}; | |||||
_megdnn_workspace new_workspace = { | |||||
workspace.raw_ptr + gates.span().dist_byte(), | |||||
workspace.size - gates.span().dist_byte()}; | |||||
opr->exec( | |||||
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], | |||||
states_new[0], states_new[1], gates_tensor, new_workspace); | |||||
} | |||||
template <> | |||||
size_t cell_opr_get_workspace_in_bytes<LSTMCellForward>( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { | |||||
TensorLayout cx = hx; | |||||
TensorLayout h_new, c_new, gates; | |||||
auto cell_opr = handle->create_operator<LSTMCellForward>(); | |||||
cell_opr->deduce_layout( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates); | |||||
return cell_opr->get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||||
gates) + | |||||
gates.span().dist_byte(); | |||||
} | |||||
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn |
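// Hedged sketch (illustrative, not part of the original source) of the
// workspace split performed in cell_opr_exec above: the gates tensor is carved
// off the front of the workspace and the remainder is handed to the cell
// kernel. Assumes megdnn::Workspace from "megdnn/basic_types.h".
static inline megdnn::Workspace example_remaining_workspace(
        const megdnn::Workspace& ws, size_t gates_bytes) {
    // bytes [0, gates_bytes) hold the gates tensor; the rest is cell scratch
    return megdnn::Workspace(ws.raw_ptr + gates_bytes, ws.size - gates_bytes);
}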
@@ -0,0 +1,38 @@ | |||||
/** | |||||
* \file dnn/src/naive/lstm_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/lstm_cell/opr_impl.h" | |||||
#include "src/common/lstm_cell.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
size_t LSTMCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||||
const TensorLayout& gates) { | |||||
return megdnn::lstm_cell::get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||||
handle()); | |||||
} | |||||
void LSTMCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||||
megdnn::lstm_cell::exec( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||||
workspace, handle()); | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn |
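// Hedged note (not part of the original source): the gates output of the cell
// has layout {batch_size, 4 * hidden_size}, with the four chunks ordered
// (i, f, o, g) as consumed by the LSTM backward in src/naive/rnn/rnn.cpp.
// A small helper mirroring that assumption (megdnn::TensorShape is from
// "megdnn/basic_types.h"):
static inline megdnn::TensorShape example_lstm_gates_shape(
        size_t batch_size, size_t hidden_size) {
    return {batch_size, 4 * hidden_size};
}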
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* \file dnn/src/naive/lstm_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/naive/rnn_cell/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class LSTMCellImpl : public LSTMCell { | |||||
public: | |||||
using LSTMCell::LSTMCell; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, | |||||
const TensorLayout& c_new, const TensorLayout& gates) override; | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/funcs.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#ifndef MEGDNN_NAIVE_RNN_FUNCS_H | |||||
#define MEGDNN_NAIVE_RNN_FUNCS_H | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
template <typename CellOpr> | |||||
void cell_opr_exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle); | |||||
template <typename CellOpr> | |||||
size_t cell_opr_get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle); | |||||
template <typename CellOpr> | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& flatten_weights, | |||||
size_t hidden_size, | |||||
size_t D, // num_directions | |||||
Handle* handle); | |||||
template <class Cell, typename CellOpr> | |||||
void exec_internal( | |||||
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states, | |||||
TensorNDArray& states_new, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out reserve_space, size_t num_layers, | |||||
size_t D, // D is num_directions | |||||
Handle* handle, _megdnn_workspace workspace); | |||||
template <class Cell> | |||||
size_t get_cells( | |||||
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, | |||||
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_workspace workspace); | |||||
template <class Cell> | |||||
size_t get_inputs_for_exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, | |||||
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells, | |||||
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, | |||||
std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||||
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace); | |||||
template <class Cell> | |||||
void backward_exec_internal( | |||||
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size, | |||||
bool bias, param::RNNCell::NonlineMode nonlineMode, | |||||
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, | |||||
const std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||||
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, | |||||
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, | |||||
_megdnn_workspace workspace); | |||||
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
#include "funcs.tpp" | |||||
#endif  // MEGDNN_NAIVE_RNN_FUNCS_H |
@@ -0,0 +1,449 @@ | |||||
/** | |||||
 * \file dnn/src/naive/rnn/funcs.tpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "funcs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
template <typename CellOpr> | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& flatten_weights, | |||||
size_t hidden_size, | |||||
size_t D, // num_directions | |||||
Handle* handle) { | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
size_t input_size = input.shape[2]; | |||||
size_t gate_hidden_size = flatten_weights.shape[0]; | |||||
// concat workspace | |||||
TensorLayout direction_output_layout{ | |||||
TensorShape{seq_len, batch_size, hidden_size}, input.dtype}; | |||||
TensorLayout output_layout{{seq_len, batch_size, D * hidden_size}, input.dtype}; | |||||
TensorLayoutArray layer_layouts; | |||||
for (size_t i = 0; i < D; ++i) | |||||
layer_layouts.push_back(direction_output_layout); | |||||
auto concat_opr = handle->create_operator<ConcatForward>(); | |||||
concat_opr->param().axis = -1; | |||||
size_t concat_workspace = | |||||
concat_opr->get_workspace_in_bytes(layer_layouts, output_layout); | |||||
// cell workspace | |||||
TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype}; | |||||
TensorLayout weight_hh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout hx{{batch_size, hidden_size}, input.dtype}; | |||||
size_t cell_workspace = cell_opr_get_workspace_in_bytes<CellOpr>( | |||||
input, weight_ih, weight_hh, bias, bias, hx, handle); | |||||
return std::max(cell_workspace, concat_workspace); | |||||
} | |||||
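// Note on the function above: the concat scratch and the cell scratch are
// never live at the same time, so the returned size is their maximum rather
// than their sum.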
template <class Cell, typename CellOpr> | |||||
void exec_internal( | |||||
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states, | |||||
TensorNDArray& states_new, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out reserve_space, size_t num_layers, | |||||
size_t D, // D is num_directions | |||||
Handle* handle, _megdnn_workspace workspace) { | |||||
size_t seq_len = input.layout.shape[0]; | |||||
size_t batch_size = input.layout.shape[1]; | |||||
size_t input_size = input.layout.shape[2]; | |||||
size_t hidden_size = cells[0].weight_hh.layout.shape[1]; | |||||
TensorLayout cell_output_layout{ | |||||
TensorShape{batch_size, hidden_size}, states[0].layout.dtype}; | |||||
TensorLayout cell_first_input_layout{ | |||||
TensorShape{batch_size, input_size}, input.layout.dtype}; | |||||
TensorLayout cell_input_layout{ | |||||
TensorShape{batch_size, D * hidden_size}, input.layout.dtype}; | |||||
TensorLayout direction_output_layout{ | |||||
TensorShape{seq_len, batch_size, hidden_size}, output.layout.dtype}; | |||||
TensorND tmp_output{workspace.raw_ptr, output.layout}; | |||||
_megdnn_workspace new_workspace{ | |||||
workspace.raw_ptr + tmp_output.layout.span().dist_byte(), | |||||
workspace.size - tmp_output.layout.span().dist_byte()}; | |||||
auto cell_opr = handle->create_operator<CellOpr>(); | |||||
auto copy_opr = handle->create_operator<TypeCvtForward>(); | |||||
// copy states to states_new | |||||
for (size_t i = 0; i < states.size(); ++i) | |||||
copy_opr->exec(states[i], states_new[i]); | |||||
void* reserve_ptr = reserve_space.raw_ptr(); | |||||
// layer 1 | |||||
for (size_t d = 0; d < D; ++d) { | |||||
size_t cell_idx = d; | |||||
auto& cell = cells[cell_idx]; | |||||
TensorNDArray cur_states; | |||||
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte(); | |||||
for (size_t i = 0; i < states.size(); ++i) { | |||||
cur_states.push_back(TensorND{ | |||||
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset, | |||||
cell_output_layout}); | |||||
} | |||||
for (size_t i = 0; i < seq_len; ++i) { | |||||
size_t step = d == 0 ? i : seq_len - 1 - i; | |||||
TensorND step_input{ | |||||
static_cast<uint8_t*>(input.raw_ptr()) + | |||||
step * cell_first_input_layout.span().dist_byte(), | |||||
cell_first_input_layout}; | |||||
TensorND step_output{ | |||||
static_cast<uint8_t*>(output.raw_ptr()) + | |||||
(step * D + d) * cell_output_layout.span().dist_byte(), | |||||
cell_output_layout}; | |||||
// temporary states of each step (use reserve space) | |||||
TensorNDArray tmp_states; | |||||
for (size_t s = 0; s < cur_states.size(); ++s) { | |||||
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout}); | |||||
size_t size_in_bytes = cur_states[s].layout.span().dist_byte(); | |||||
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes; | |||||
} | |||||
cell_opr_exec<CellOpr>( | |||||
step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih, | |||||
cell.bias_hh, cur_states, tmp_states, new_workspace, handle); | |||||
// copy states to cur_states | |||||
for (size_t s = 0; s < tmp_states.size(); ++s) { | |||||
copy_opr->exec(tmp_states[s], cur_states[s]); | |||||
} | |||||
// copy h to step output | |||||
copy_opr->exec(cur_states[0], step_output); | |||||
} | |||||
} | |||||
for (size_t layer = 1; layer < num_layers; ++layer) { | |||||
for (size_t d = 0; d < D; ++d) { | |||||
size_t cell_idx = layer * D + d; | |||||
auto& cell = cells[cell_idx]; | |||||
TensorNDArray cur_states; | |||||
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte(); | |||||
for (size_t i = 0; i < states.size(); ++i) { | |||||
cur_states.push_back(TensorND{ | |||||
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset, | |||||
cell_output_layout}); | |||||
} | |||||
for (size_t i = 0; i < seq_len; ++i) { | |||||
size_t step = d == 0 ? i : seq_len - 1 - i; | |||||
TensorND step_input{ | |||||
static_cast<uint8_t*>(output.raw_ptr()) + | |||||
step * cell_input_layout.span().dist_byte(), | |||||
cell_input_layout}; | |||||
TensorND step_output{ | |||||
static_cast<uint8_t*>(tmp_output.raw_ptr()) + | |||||
(step * D + d) * cell_output_layout.span().dist_byte(), | |||||
cell_output_layout}; | |||||
// temporary states of each step (use reserve space) | |||||
TensorNDArray tmp_states; | |||||
for (size_t s = 0; s < cur_states.size(); ++s) { | |||||
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout}); | |||||
size_t size_in_bytes = cur_states[s].layout.span().dist_byte(); | |||||
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes; | |||||
} | |||||
            cell_opr_exec<CellOpr>( | |||||
                    step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih, | |||||
                    cell.bias_hh, cur_states, tmp_states, new_workspace, handle); | |||||
// copy states to cur_states | |||||
for (size_t s = 0; s < tmp_states.size(); ++s) { | |||||
copy_opr->exec(tmp_states[s], cur_states[s]); | |||||
} | |||||
// copy h to step_output | |||||
copy_opr->exec(cur_states[0], step_output); | |||||
} | |||||
} | |||||
// copy layer output to output | |||||
copy_opr->exec(tmp_output, output); | |||||
} | |||||
// output: [d0, d1, d0, d1 ...] | |||||
} | |||||
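// Hedged helper (illustrative only) mirroring the indexing used above: for
// sequence position `step` and direction `d`, the per-step cell output lands
// at this byte offset inside the [seq_len, batch_size, D * hidden_size]
// output buffer, which yields the interleaved [d0, d1, d0, d1, ...] layout
// noted above.
static inline size_t example_step_output_offset(
        size_t step, size_t d, size_t D, size_t batch_size, size_t hidden_size,
        size_t dtype_bytes) {
    size_t cell_bytes = batch_size * hidden_size * dtype_bytes;
    return (step * D + d) * cell_bytes;
}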
template <class Cell> | |||||
size_t get_cells( | |||||
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, | |||||
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_workspace workspace) { | |||||
cells.reserve(D * num_layers); | |||||
void* weight_ptr = flatten_weights.raw_ptr(); | |||||
for (size_t layer = 0; layer < num_layers; ++layer) { | |||||
for (size_t d = 0; d < D; ++d) { | |||||
size_t cell_input_size = D * hidden_size; | |||||
if (layer == 0) | |||||
cell_input_size = input_size; | |||||
Cell cell( | |||||
weight_ptr, hidden_size, cell_input_size, bias, | |||||
flatten_weights.layout.dtype, workspace); | |||||
weight_ptr = | |||||
static_cast<uint8_t*>(weight_ptr) + cell.weight_size_in_bytes(); | |||||
cells.push_back(cell); | |||||
} | |||||
} | |||||
// return used workspace | |||||
return cells[0].workspace_size_in_bytes(); | |||||
} | |||||
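// Hedged illustration (not part of the original source) of the
// flatten_weights layout that get_cells walks: for each (layer, direction)
// cell, weight_ih [num_chunks * hidden, cell_input] comes first, then
// weight_hh [num_chunks * hidden, hidden], then bias_ih and bias_hh when bias
// is enabled. Layer 0 uses input_size as cell_input; deeper layers use
// D * hidden_size. num_chunks is 1 for vanilla RNN and 4 for LSTM.
static inline size_t example_cell_weight_bytes(
        size_t hidden_size, size_t cell_input_size, size_t num_chunks,
        bool bias, size_t dtype_bytes) {
    size_t gate_hidden = num_chunks * hidden_size;
    size_t n = gate_hidden * cell_input_size + gate_hidden * hidden_size;
    if (bias)
        n += 2 * gate_hidden;
    return n * dtype_bytes;
}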
template <class Cell> | |||||
size_t get_inputs_for_exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, | |||||
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells, | |||||
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, | |||||
std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||||
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace) { | |||||
// return used workspace size | |||||
layer_inputs.push_back(x); | |||||
size_t seq_len = x.layout.shape[0]; | |||||
size_t batch_size = x.layout.shape[1]; | |||||
size_t num_states = cells[0].num_states(); | |||||
TensorLayout cell_output_layout{{batch_size, hidden_size}, y.layout.dtype}; | |||||
TensorLayout direction_output_layout{ | |||||
{seq_len, batch_size, hidden_size}, y.layout.dtype}; | |||||
void* workspace_ptr = workspace.raw_ptr; | |||||
    // extract intermediate states from the reserve space | |||||
for (int layer = 0; layer < num_layers; ++layer) { | |||||
TensorND layer_output{workspace_ptr, y.layout}; | |||||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||||
layer_output.layout.span().dist_byte(); | |||||
for (int d = 0; d < D; ++d) { | |||||
cell_seq_states.push_back(std::vector<TensorNDArray>()); | |||||
            // the reverse direction is stored in reversed sequence order | |||||
for (int i = 0; i < seq_len; ++i) { | |||||
size_t step = i; | |||||
if (d == 1) | |||||
step = seq_len - i - 1; | |||||
size_t offset = ((layer * D + d) * seq_len + step) * | |||||
cell_output_layout.span().dist_byte() * num_states; | |||||
TensorNDArray cur_states; | |||||
for (int s = 0; s < num_states; ++s) { | |||||
TensorND h{ | |||||
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset + | |||||
s * cell_output_layout.span().dist_byte(), | |||||
cell_output_layout}; | |||||
cur_states.push_back(h); | |||||
} | |||||
TensorND hy{ | |||||
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset, | |||||
                    cell_output_layout};  // the first stored state (h) is also | |||||
                                          // the step output | |||||
cell_seq_states[cell_seq_states.size() - 1].push_back(cur_states); | |||||
// output | |||||
            offset = (i * D + d) * cell_output_layout.span().dist_byte(); | |||||
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr()) + offset, hy.raw_ptr(), | |||||
hy.layout.span().dist_byte()); | |||||
} | |||||
} | |||||
layer_outputs.push_back(layer_output); | |||||
if (layer != num_layers - 1) | |||||
layer_inputs.push_back(layer_output); | |||||
} | |||||
return static_cast<uint8_t*>(workspace_ptr) - | |||||
static_cast<uint8_t*>((void*)workspace.raw_ptr); | |||||
} | |||||
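// Hedged mirror (illustrative only) of the reserve-space addressing used
// above: the states of the cell at (layer, d) for sequence position `step`
// start at this byte offset, with num_states consecutive snapshots per step
// (1 for vanilla RNN: h; 2 for LSTM: h then c).
static inline size_t example_reserve_offset(
        size_t layer, size_t d, size_t D, size_t step, size_t seq_len,
        size_t state_bytes, size_t num_states) {
    return ((layer * D + d) * seq_len + step) * state_bytes * num_states;
}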
template <class Cell> | |||||
void backward_exec_internal( | |||||
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size, | |||||
bool bias, param::RNNCell::NonlineMode nonlineMode, | |||||
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, | |||||
const std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||||
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, | |||||
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, | |||||
_megdnn_workspace workspace) { | |||||
    /* | |||||
    layer_inputs: the input of each layer. Element 0 has shape | |||||
    [seq_len, batch_size, input_size]; all later elements have shape | |||||
    [seq_len, batch_size, D * hidden_size]. | |||||
    layer_outputs: the output of each layer, in sequence order. The outputs of | |||||
    the cells at (layer, d) for every direction d are interleaved in | |||||
    layer_outputs[layer], whose shape is [seq_len, batch_size, D * hidden_size]. | |||||
    cell_seq_states: the states of every cell at every step. The states of the | |||||
    cell at (layer, d) at sequence step `step` are | |||||
    cell_seq_states[layer * D + d][step]. | |||||
    */ | |||||
size_t seq_len = layer_inputs[0].layout.shape[0]; | |||||
size_t batch_size = layer_inputs[0].layout.shape[1]; | |||||
DType dtype = layer_inputs[0].layout.dtype; | |||||
size_t cell_y_size = | |||||
layer_outputs[0].layout.shape[2] / D; // should all be the same | |||||
size_t hidden_size = cell_y_size; | |||||
TensorLayout cell_y_layout = {{batch_size, cell_y_size}, dtype}; | |||||
void* workspace_ptr = workspace.raw_ptr; | |||||
TensorND layer_output_grad{ | |||||
workspace_ptr, {{seq_len, batch_size, D * hidden_size}, dtype}}; | |||||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||||
layer_output_grad.layout.span().dist_byte(); | |||||
memcpy(layer_output_grad.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); | |||||
TensorNDArray direction_dx_arr; // for layer 1 to layer num_layers-1 | |||||
for (int i = 0; i < D; ++i) { | |||||
TensorLayout direction_dx_layout{{seq_len, batch_size, hidden_size}, dtype}; | |||||
direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); | |||||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||||
direction_dx_layout.span().dist_byte(); | |||||
} | |||||
TensorNDArray L0_direction_dx_arr; | |||||
for (int i = 0; i < D; ++i) { | |||||
TensorLayout direction_dx_layout{{seq_len, batch_size, input_size}, dtype}; | |||||
L0_direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); | |||||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||||
direction_dx_layout.span().dist_byte(); | |||||
} | |||||
// cell states for each layer and each direction | |||||
std::vector<TensorNDArray> dstates_arr; | |||||
for (int layer = 0; layer < num_layers; ++layer) { | |||||
for (int d = 0; d < D; ++d) { | |||||
TensorNDArray cell_states; | |||||
cell_states.reserve(dstates.size()); | |||||
for (int i = 0; i < dstates.size(); ++i) { | |||||
size_t offset = (layer * D + d) * cell_y_layout.span().dist_byte(); | |||||
TensorND dhx_cell{ | |||||
static_cast<uint8_t*>(dstates[i].raw_ptr()) + offset, | |||||
cell_y_layout}; | |||||
memcpy(dhx_cell.raw_ptr(), static_cast<uint8_t*>(dhy[i].raw_ptr()) + offset, | |||||
cell_y_layout.span().dist_byte()); | |||||
cell_states.emplace_back(dhx_cell); | |||||
} | |||||
dstates_arr.push_back(cell_states); | |||||
} | |||||
} | |||||
// init gradient on weight to zero | |||||
memset(dw.raw_ptr(), 0, dw.layout.span().dist_byte()); | |||||
// use cells to contain gradient | |||||
std::vector<Cell> cell_grads; | |||||
size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) - | |||||
static_cast<uint8_t*>((void*)(workspace.raw_ptr)); | |||||
workspace_ptr = | |||||
static_cast<uint8_t*>(workspace_ptr) + | |||||
get_cells( | |||||
D, num_layers, input_size, hidden_size, bias, cell_grads, dw, | |||||
Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size)); | |||||
auto add_opr = handle->create_operator<ElemwiseForward>(); | |||||
add_opr->param().mode = Elemwise::Mode::ADD; | |||||
auto copy_opr = handle->create_operator<TypeCvtForward>(); | |||||
// initialize dx to zero | |||||
memset(dx.raw_ptr(), 0, dx.layout.span().dist_byte()); | |||||
// calculate grads | |||||
for (int layer = num_layers - 1; layer >= 0; --layer) { | |||||
for (int d = D - 1; d >= 0; --d) { | |||||
Cell& cell = cells[layer * D + d]; | |||||
Cell& cell_grad = cell_grads[layer * D + d]; | |||||
size_t input_size = layer_inputs[layer].layout.shape[2]; | |||||
const TensorND& x_arr = layer_inputs[layer]; | |||||
const TensorND& y_arr = layer_outputs[layer]; | |||||
TensorLayout x_layout = {{batch_size, input_size}, dtype}; | |||||
// tmp tensors | |||||
void* tmp_workspace_ptr = workspace_ptr; | |||||
TensorND dwi_tmp{tmp_workspace_ptr, cell_grad.weight_ih.layout}; | |||||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||||
dwi_tmp.layout.span().dist_byte(); | |||||
TensorND dwh_tmp{tmp_workspace_ptr, cell_grad.weight_hh.layout}; | |||||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||||
dwh_tmp.layout.span().dist_byte(); | |||||
TensorND dbias_tmp{tmp_workspace_ptr, cell_grad.bias_ih.layout}; | |||||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||||
dbias_tmp.layout.span().dist_byte(); | |||||
size_t used_workspace_size = | |||||
static_cast<uint8_t*>(tmp_workspace_ptr) - | |||||
static_cast<uint8_t*>((void*)(workspace.raw_ptr)); | |||||
for (int i = 0; i < seq_len; ++i) { | |||||
                // iterate backward in time: `step` is the sequence index, so | |||||
                // the forward direction (d == 0) runs from seq_len - 1 down to 0 | |||||
size_t step = i; | |||||
if (d == 0) | |||||
step = seq_len - i - 1; | |||||
TensorND x{ | |||||
static_cast<uint8_t*>(x_arr.raw_ptr()) + | |||||
step * x_layout.span().dist_byte(), | |||||
x_layout}, | |||||
y{static_cast<uint8_t*>(y_arr.raw_ptr()) + | |||||
(step * D + d) * cell_y_layout.span().dist_byte(), | |||||
cell_y_layout}; | |||||
const TensorNDArray& cell_states = cell_seq_states[layer * D + d][step]; | |||||
TensorNDArray& dstates_new = dstates_arr[layer * D + d]; | |||||
// dy should be d_output + d_hidden | |||||
TensorND dy_t{ | |||||
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) + | |||||
(step * D + d) * cell_y_layout.span().dist_byte(), | |||||
cell_y_layout}; | |||||
add_opr->exec({dstates_new[0], dy_t}, dy_t); | |||||
// dx for layer 0 has a different size | |||||
TensorND dx_t; | |||||
if (layer == 0) | |||||
dx_t = {static_cast<uint8_t*>(L0_direction_dx_arr[d].raw_ptr()) + | |||||
step * x_layout.span().dist_byte(), | |||||
x_layout}; | |||||
else | |||||
dx_t = {static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) + | |||||
step * x_layout.span().dist_byte(), | |||||
x_layout}; | |||||
TensorNDArray douts = {dy_t}; | |||||
for (int s = 1; s < dstates_new.size(); ++s) | |||||
douts.push_back(dstates_new[s]); | |||||
cell.backward( | |||||
handle, nonlineMode, x, cell_states, y, douts, dx_t, | |||||
dstates_new, dwi_tmp, dwh_tmp, dbias_tmp, | |||||
Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size)); | |||||
// add step gradient to overall gradient | |||||
add_opr->exec({dwi_tmp, cell_grad.weight_ih}, cell_grad.weight_ih); | |||||
add_opr->exec({dwh_tmp, cell_grad.weight_hh}, cell_grad.weight_hh); | |||||
add_opr->exec({dbias_tmp, cell_grad.bias_ih}, cell_grad.bias_ih); | |||||
add_opr->exec({dbias_tmp, cell_grad.bias_hh}, cell_grad.bias_hh); | |||||
} | |||||
} | |||||
        // gather the per-direction input gradients: layer 0 accumulates into | |||||
        // dx, deeper layers into layer_output_grad for the layer below | |||||
if (layer == 0) { | |||||
for (int i = 0; i < D; ++i) | |||||
add_opr->exec({L0_direction_dx_arr[i], dx}, dx); | |||||
} else { | |||||
if (D == 1) | |||||
copy_opr->exec(direction_dx_arr[0], layer_output_grad); | |||||
else { // D == 2, arrange as [(d0, d1), (d0, d1), ...] | |||||
for (size_t t = 0; t < seq_len; ++t) { | |||||
size_t offset = t * D * cell_y_layout.span().dist_byte(); | |||||
for (size_t d = 0; d < D; ++d) { | |||||
TensorND src{ | |||||
static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) + | |||||
offset, | |||||
cell_y_layout}; | |||||
TensorND dst{ | |||||
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) + | |||||
offset + d * cell_y_layout.span().dist_byte(), | |||||
cell_y_layout}; | |||||
copy_opr->exec(src, dst); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
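// Hedged sketch (not part of the original source) of the per-step
// accumulation pattern used above: each step's weight gradient is computed
// into a scratch tensor and then summed into the running gradient, which was
// zero-initialized before the loops, i.e. dw += dw_step.
static inline void example_accumulate_grad(
        float* dw, const float* dw_step, size_t n) {
    for (size_t i = 0; i < n; ++i)
        dw[i] += dw_step[i];
}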
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,196 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/rnn/opr_impl.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs/base.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/common/rnn.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/naive/matrix_mul/opr_impl.h" | |||||
#include "src/naive/rnn/funcs.h" | |||||
#include "src/naive/rnn/rnn.h" | |||||
#include <cstring> | |||||
namespace megdnn { | |||||
namespace naive { | |||||
using rnn::RNNCellWeightWrapper; | |||||
void RNNImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace) { | |||||
auto _param = param(); | |||||
size_t D = _param.bidirectional ? 2 : 1; | |||||
size_t num_layers = _param.num_layers; | |||||
size_t input_size = input.layout.shape[2]; | |||||
std::vector<RNNCellWeightWrapper> cells; | |||||
size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>( | |||||
D, num_layers, input_size, _param.hidden_size, _param.bias, cells, | |||||
flatten_weights, workspace); | |||||
Workspace new_workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
TensorNDArray states, states_new; | |||||
states.push_back(hx); | |||||
states_new.push_back(hy); | |||||
rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>( | |||||
cells, input, states, states_new, output, reserve_space, num_layers, D, | |||||
this->handle(), new_workspace); | |||||
} | |||||
size_t RNNImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& reserve_space) { | |||||
size_t workspace_size = rnn::get_workspace_in_bytes<RNNCellForward>( | |||||
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, | |||||
this->handle()); | |||||
if (!param().bias) { // use fake bias (all 0) | |||||
TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype}; | |||||
workspace_size += bias_layout.span().dist_byte(); | |||||
} | |||||
workspace_size += output.span().dist_byte(); | |||||
return workspace_size; | |||||
} | |||||
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||||
size_t num_layers = param().num_layers; | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t seq_len = input.shape[0]; | |||||
size_t batch_size = input.shape[1]; | |||||
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; | |||||
return num_layers * D * seq_len * state_layout.span().dist_byte(); | |||||
} | |||||
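// Hedged comparison (not part of the original source): the vanilla RNN keeps
// a single state snapshot (h) per cell per step, so its reserve space is half
// of the LSTM case, which also stores c. With seq_len = 10, batch_size = 32,
// hidden_size = 64, num_layers = 2, D = 2 and float32 this is
// 2 * 2 * 10 * 32 * 64 * 4 = 327680 bytes.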
void RNNBackwardImpl::exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, | |||||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||||
_megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||||
TensorNDArray layer_inputs; | |||||
TensorNDArray layer_outputs; | |||||
std::vector<std::vector<TensorNDArray>> cell_seq_states; | |||||
size_t num_layers = param().num_layers; | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t input_size = x.layout.shape[2]; | |||||
size_t hidden_size = param().hidden_size; | |||||
size_t used_workspace_size = 0; | |||||
// get cells | |||||
std::vector<RNNCellWeightWrapper> cells; | |||||
used_workspace_size += rnn::get_cells( | |||||
D, num_layers, input_size, hidden_size, param().bias, cells, | |||||
flatten_weights, workspace); | |||||
// nonlinear mode | |||||
param::RNNCell::NonlineMode nonlineMode; | |||||
using ModeRNN = param::RNN::NonlineMode; | |||||
using ModeRNNCell = param::RNNCell::NonlineMode; | |||||
    switch (param().nonlineMode) { | |||||
        case ModeRNN::RELU: | |||||
            nonlineMode = ModeRNNCell::RELU; | |||||
            break; | |||||
        case ModeRNN::TANH: | |||||
            nonlineMode = ModeRNNCell::TANH; | |||||
            break; | |||||
        default: | |||||
            megdnn_throw("unsupported nonlinear mode for RNN backward"); | |||||
    } | |||||
// get formatted inputs | |||||
Workspace new_workspace = Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
used_workspace_size += rnn::get_inputs_for_exec<RNNCellWeightWrapper>( | |||||
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, | |||||
layer_outputs, cell_seq_states, nonlineMode, new_workspace); | |||||
// dhy arr, dhx arr | |||||
TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx}; | |||||
// exec | |||||
new_workspace = Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
rnn::backward_exec_internal<RNNCellWeightWrapper>( | |||||
cells, D, num_layers, input_size, param().bias, nonlineMode, layer_inputs, | |||||
layer_outputs, cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, | |||||
this->handle(), new_workspace); | |||||
} | |||||
size_t RNNBackwardImpl::get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { | |||||
size_t D = param().bidirectional ? 2 : 1; | |||||
size_t num_layers = param().num_layers; | |||||
size_t hidden_size = param().hidden_size; | |||||
size_t gate_hidden_size = hidden_size; | |||||
size_t max_input_size = std::max(x.shape[2], D * hidden_size); | |||||
size_t workspace_size = RNNCellWeightWrapper::backward_workspace_size_in_bytes( | |||||
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); | |||||
if (!param().bias) { // use fake bias (all 0) | |||||
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; | |||||
workspace_size += bias_layout.span().dist_byte() * | |||||
2; // times 2 because another bias is allocated in | |||||
// backward_exec_internal | |||||
} | |||||
workspace_size += num_layers * y.span().dist_byte(); | |||||
// add back exec workspace size | |||||
workspace_size += y.span().dist_byte() * 2; | |||||
workspace_size += x.span().dist_byte() * 2; | |||||
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; | |||||
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||||
workspace_size += wih.span().dist_byte(); | |||||
workspace_size += whh.span().dist_byte(); | |||||
workspace_size += bias.span().dist_byte(); | |||||
return workspace_size; | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class RNNImpl : public RNN { | |||||
public: | |||||
using RNN::RNN; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||||
_megdnn_workspace workspace); | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& reserve_space); | |||||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||||
}; | |||||
class RNNBackwardImpl : public RNNBackward { | |||||
public: | |||||
using RNNBackward::RNNBackward; | |||||
virtual void exec( | |||||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, | |||||
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, | |||||
_megdnn_workspace workspace); | |||||
virtual size_t get_workspace_in_bytes( | |||||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||||
const TensorLayout& dy, const TensorLayout& dhy, | |||||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,285 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/rnn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/rnn/rnn.h" | |||||
#include <cmath> | |||||
#include <cstring> | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
CellWeightsWrapperBase::CellWeightsWrapperBase( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, | |||||
bool has_bias, DType dtype, _megdnn_workspace workspace) { | |||||
// weight_ih: [gate_hidden_size, input_size] | |||||
// weight_hh: [gate_hidden_size, hidden_size] | |||||
// bias_ih: [gate_hidden_size] | |||||
// bias_hh: [gate_hidden_size] | |||||
size_t gate_hidden_size = num_chunks * hidden_size; | |||||
TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype}; | |||||
TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype}; | |||||
TensorLayout bias_layout{{gate_hidden_size}, dtype}; | |||||
this->_weight_size = 0; | |||||
this->weight_ih = TensorND(weight_ptr, weight_ih_layout); | |||||
this->_weight_size += weight_ih_layout.span().dist_byte(); | |||||
this->weight_hh = TensorND( | |||||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, weight_hh_layout); | |||||
this->_weight_size += weight_hh_layout.span().dist_byte(); | |||||
if (has_bias) { | |||||
this->bias_ih = TensorND( | |||||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout); | |||||
this->_weight_size += bias_layout.span().dist_byte(); | |||||
this->bias_hh = TensorND( | |||||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout); | |||||
this->_weight_size += bias_layout.span().dist_byte(); | |||||
this->_workspace_size = 0; | |||||
} else { | |||||
this->bias_ih = TensorND(workspace.raw_ptr, bias_layout); | |||||
this->bias_hh = TensorND(workspace.raw_ptr, bias_layout); | |||||
memset(workspace.raw_ptr, 0, bias_layout.span().dist_byte()); | |||||
this->_workspace_size = bias_layout.span().dist_byte(); | |||||
} | |||||
} | |||||
size_t CellWeightsWrapperBase::weight_size_in_bytes() const { | |||||
return this->_weight_size; | |||||
} | |||||
size_t CellWeightsWrapperBase::workspace_size_in_bytes() const { | |||||
return this->_workspace_size; | |||||
} | |||||
size_t CellWeightsWrapperBase::num_states() const { | |||||
return 1; | |||||
} | |||||
void CellWeightsWrapperBase::backward( | |||||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, _megdnn_tensor_in x, | |||||
const TensorNDArray& states, _megdnn_tensor_in y, const TensorNDArray& douts, | |||||
_megdnn_tensor_out dx, TensorNDArray& dstates, _megdnn_tensor_out dwi, | |||||
_megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) const { | |||||
auto dy = douts[0]; | |||||
using NonlineMode = param::RNNCell::NonlineMode; | |||||
using Mode = Elemwise::Mode; | |||||
auto elemwise_opr = handle->create_operator<ElemwiseForward>(); | |||||
TensorND tmp = {workspace.raw_ptr, dy.layout}; | |||||
auto new_workspace = Workspace( | |||||
workspace.raw_ptr + tmp.layout.span().dist_byte(), | |||||
workspace.size - tmp.layout.span().dist_byte()); | |||||
switch (nonlineMode) { | |||||
case (NonlineMode::IDENTITY): | |||||
memcpy(tmp.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); | |||||
break; | |||||
case (NonlineMode::TANH): | |||||
elemwise_opr->param().mode = Mode::TANH_GRAD; | |||||
elemwise_opr->exec({y, dy}, tmp); | |||||
break; | |||||
case (NonlineMode::RELU): | |||||
elemwise_opr->param().mode = Mode::SWITCH_GT0; | |||||
elemwise_opr->exec({y, dy}, tmp); | |||||
break; | |||||
} | |||||
auto matrixmul_opr = handle->create_operator<MatrixMulForward>(); | |||||
matrixmul_opr->param().transposeA = false; | |||||
matrixmul_opr->param().transposeB = false; | |||||
// dx | |||||
matrixmul_opr->exec(tmp, this->weight_ih, dx, new_workspace); | |||||
// dhx | |||||
matrixmul_opr->exec(tmp, this->weight_hh, dstates[0], new_workspace); | |||||
// dwi | |||||
matrixmul_opr->param().transposeA = true; | |||||
matrixmul_opr->exec(tmp, x, dwi, new_workspace); | |||||
// dwh | |||||
matrixmul_opr->exec(tmp, states[0], dwh, new_workspace); | |||||
// dbias | |||||
auto sum_opr = handle->create_operator<ReduceForward>(); | |||||
sum_opr->param().mode = ReduceForward::Mode::SUM; | |||||
sum_opr->param().axis = 0; | |||||
TensorND dbias_expanded = { | |||||
dbias.raw_ptr(), {{1, dbias.layout.shape[0]}, dbias.layout.dtype}}; | |||||
sum_opr->exec(tmp, dbias_expanded, new_workspace); | |||||
} | |||||
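// Hedged recap (not part of the original source) of the math implemented
// above, for a cell computing y = act(x @ weight_ih^T + hx @ weight_hh^T + b):
//   dpre  = dy * act'(y)          (elementwise, stored in `tmp`)
//   dx    = dpre @ weight_ih
//   dhx   = dpre @ weight_hh
//   dwi   = dpre^T @ x            (transposeA = true)
//   dwh   = dpre^T @ hx
//   dbias = sum of dpre over the batch axis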
size_t CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
size_t num_chunks, DType dtype) { | |||||
size_t gate_hidden_size = hidden_size * num_chunks; | |||||
TensorLayout tmp = {{batch_size, gate_hidden_size}, dtype}; | |||||
TensorLayout bias_expanded = {{1, gate_hidden_size}, dtype}; | |||||
TensorLayout wih = {{gate_hidden_size, input_size}, dtype}; | |||||
TensorLayout whh = {{gate_hidden_size, hidden_size}, dtype}; | |||||
TensorLayout x = {{batch_size, input_size}, dtype}; | |||||
TensorLayout hx = {{batch_size, hidden_size}, dtype}; | |||||
size_t workspace_size = 0; | |||||
auto matrixmul_opr = handle->create_operator<MatrixMulForward>(); | |||||
matrixmul_opr->param().transposeA = false; | |||||
matrixmul_opr->param().transposeB = false; | |||||
// dx | |||||
workspace_size = std::max( | |||||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, wih, x)); | |||||
// dhx | |||||
workspace_size = std::max( | |||||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, whh, hx)); | |||||
// dwi | |||||
matrixmul_opr->param().transposeA = true; | |||||
workspace_size = std::max( | |||||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, x, wih)); | |||||
// dwh | |||||
workspace_size = std::max( | |||||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, hx, whh)); | |||||
// dbias | |||||
auto sum_opr = handle->create_operator<ReduceForward>(); | |||||
sum_opr->param().mode = ReduceForward::Mode::SUM; | |||||
sum_opr->param().axis = 0; | |||||
workspace_size = std::max( | |||||
workspace_size, sum_opr->get_workspace_in_bytes(tmp, bias_expanded)); | |||||
workspace_size += tmp.span().dist_byte(); | |||||
return workspace_size; | |||||
} | |||||
RNNCellWeightWrapper::RNNCellWeightWrapper( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype, _megdnn_workspace workspace) | |||||
: CellWeightsWrapperBase( | |||||
weight_ptr, hidden_size, input_size, 1, has_bias, dtype, workspace) {} | |||||
size_t RNNCellWeightWrapper::backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
DType dtype) { | |||||
return CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||||
handle, batch_size, hidden_size, input_size, 1, dtype); | |||||
} | |||||
LSTMCellWeightWrapper::LSTMCellWeightWrapper( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype, _megdnn_workspace workspace) | |||||
: CellWeightsWrapperBase( | |||||
weight_ptr, hidden_size, input_size, 4, has_bias, dtype, workspace) {} | |||||
size_t LSTMCellWeightWrapper::num_states() const { | |||||
return 2; | |||||
} | |||||
size_t LSTMCellWeightWrapper::backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
DType dtype) { | |||||
// get gates size | |||||
size_t gate_hidden_size = 4 * hidden_size; | |||||
auto lstm_opr = handle->create_operator<LSTMCellForward>(); | |||||
TensorLayout x = {{batch_size, input_size}, dtype}; | |||||
TensorLayout weight_ih = {{gate_hidden_size, input_size}, dtype}; | |||||
TensorLayout weight_hh = {{gate_hidden_size, hidden_size}, dtype}; | |||||
TensorLayout bias = {{gate_hidden_size}, dtype}; | |||||
TensorLayout h = {{batch_size, hidden_size}, dtype}; | |||||
TensorLayout gates, h_new, c_new; | |||||
lstm_opr->deduce_layout( | |||||
x, weight_ih, bias, h, weight_hh, bias, h, h_new, c_new, gates); | |||||
return CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||||
handle, batch_size, hidden_size, input_size, 4, dtype) + | |||||
gates.span().dist_byte() * 2 + c_new.span().dist_byte(); | |||||
} | |||||
void LSTMCellWeightWrapper::backward( | |||||
Handle* handle, | |||||
param::RNNCell::NonlineMode nonlineMode, // nonlineMode must be identity | |||||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) const { | |||||
size_t used_workspace_size = 0; | |||||
// get gates | |||||
auto lstm_opr = handle->create_operator<LSTMCellForward>(); | |||||
TensorLayout gates, h_new, c_new; | |||||
lstm_opr->deduce_layout( | |||||
x.layout, weight_ih.layout, bias_ih.layout, states[0].layout, | |||||
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); | |||||
TensorND gates_tensor{workspace.raw_ptr, gates}; | |||||
used_workspace_size += gates.span().dist_byte(); | |||||
TensorND gates_grad{workspace.raw_ptr + used_workspace_size, gates}; | |||||
used_workspace_size += gates.span().dist_byte(); | |||||
TensorND tanh_cy{workspace.raw_ptr + used_workspace_size, y.layout}; | |||||
used_workspace_size += tanh_cy.layout.span().dist_byte(); | |||||
Workspace new_workspace = Workspace( | |||||
workspace.raw_ptr + used_workspace_size, | |||||
workspace.size - used_workspace_size); | |||||
    // temporarily use dstates to store hy, cy | |||||
    // only the gates and cy are needed here; dstates is overwritten with the | |||||
    // real gradients below | |||||
lstm_opr->exec( | |||||
x, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], dstates[0], | |||||
dstates[1], gates_tensor, | |||||
new_workspace); // no information left in the workspace | |||||
// i, f, o, g | |||||
TensorLayout single_gate = {{gates.shape[0], gates.shape[1] / 4}, gates.dtype}; | |||||
TensorND i, f, o, g, i_grad, f_grad, o_grad, | |||||
            g_grad;  // the *_grad tensors hold gradients w.r.t. the pre-activation gates | |||||
i = {gates_tensor.raw_ptr(), single_gate}; | |||||
f = {static_cast<uint8_t*>(gates_tensor.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
o = {static_cast<uint8_t*>(f.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
g = {static_cast<uint8_t*>(o.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
i_grad = {gates_grad.raw_ptr(), single_gate}; | |||||
f_grad = { | |||||
static_cast<uint8_t*>(i_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
o_grad = { | |||||
static_cast<uint8_t*>(f_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
g_grad = { | |||||
static_cast<uint8_t*>(o_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||||
single_gate}; | |||||
// activation | |||||
auto elem_opr = handle->create_operator<ElemwiseForward>(); | |||||
elem_opr->param().mode = Elemwise::Mode::SIGMOID; | |||||
elem_opr->exec({i}, i); | |||||
elem_opr->exec({f}, f); | |||||
elem_opr->exec({o}, o); | |||||
elem_opr->param().mode = Elemwise::Mode::TANH; | |||||
elem_opr->exec({g}, g); | |||||
elem_opr->exec({dstates[1]}, tanh_cy); | |||||
auto mul_opr = handle->create_operator<ElemwiseForward>(); | |||||
mul_opr->param().mode = Elemwise::Mode::MUL; | |||||
// use dstates[0] as tmp tensor to store dhy * tanh_cy | |||||
mul_opr->exec({douts[0], tanh_cy}, dstates[0]); | |||||
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; | |||||
elem_opr->exec({o, dstates[0]}, o_grad); // grad of gate o | |||||
// use dstates[0] as tmp tensor to store dhy * o | |||||
mul_opr->exec({douts[0], o}, dstates[0]); | |||||
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; | |||||
elem_opr->exec({tanh_cy, dstates[0]}, dstates[1]); // grad of cy from hy | |||||
elem_opr->param().mode = Elemwise::Mode::ADD; | |||||
elem_opr->exec({douts[1], dstates[1]}, dstates[1]); // true grad of cy | |||||
// use dstates[0] as tmp tensor to store dcy * cx | |||||
mul_opr->exec({dstates[1], states[1]}, dstates[0]); | |||||
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; | |||||
elem_opr->exec({f, dstates[0]}, f_grad); // grad of gate f | |||||
// use dstates[0] as tmp tensor to store dcy * g | |||||
mul_opr->exec({dstates[1], g}, dstates[0]); | |||||
elem_opr->exec({i, dstates[0]}, i_grad); // grad of gate i | |||||
// use dstates[0] as tmp tensor to store dcy * i | |||||
mul_opr->exec({dstates[1], i}, dstates[0]); | |||||
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; | |||||
elem_opr->exec({g, dstates[0]}, g_grad); // grad of gate g | |||||
    // grad of cx | |||||
mul_opr->exec({dstates[1], f}, dstates[1]); | |||||
TensorNDArray base_dstates = {dstates[0]}; | |||||
CellWeightsWrapperBase::backward( | |||||
handle, nonlineMode, x, {states[0]}, gates_tensor, {gates_grad}, dx, | |||||
base_dstates, dwi, dwh, dbias, new_workspace); | |||||
} | |||||
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn |
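For readability, the elementwise sequence in `LSTMCellWeightWrapper::backward` above implements the standard LSTM gate gradients; written out (notation ours: each δ is the gradient w.r.t. a pre-activation gate, ⊙ the elementwise product):

$$
\begin{aligned}
\delta o &= \mathrm{d}h_y \odot \tanh(c_y) \odot o(1-o) \\
\delta c &= \mathrm{d}c_y + \mathrm{d}h_y \odot o \odot \bigl(1-\tanh^2(c_y)\bigr) \\
\delta f &= \delta c \odot c_x \odot f(1-f) \\
\delta i &= \delta c \odot g \odot i(1-i) \\
\delta g &= \delta c \odot i \odot (1-g^2) \\
\mathrm{d}c_x &= \delta c \odot f
\end{aligned}
$$

The four δ's form the concatenated gate gradient that is then handed to the base-class backward (with identity nonlinearity) for the input, weight, and bias terms.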
@@ -0,0 +1,73 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/rnn.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "megdnn/oprs/general.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
class CellWeightsWrapperBase { | |||||
private: | |||||
size_t _weight_size, _workspace_size; | |||||
public: | |||||
TensorND weight_ih, weight_hh, bias_ih, bias_hh; | |||||
// if no bias, will create dummy bias tensor from workspace | |||||
CellWeightsWrapperBase( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, | |||||
bool has_bias, DType dtype, _megdnn_workspace workspace); | |||||
size_t weight_size_in_bytes() const; | |||||
size_t workspace_size_in_bytes() const; | |||||
static size_t backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
size_t num_chunks, DType dtype); | |||||
virtual void backward( | |||||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, | |||||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) const; | |||||
virtual size_t num_states() const; | |||||
}; | |||||
class RNNCellWeightWrapper : public CellWeightsWrapperBase { | |||||
public: | |||||
RNNCellWeightWrapper( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype, _megdnn_workspace workspace); | |||||
static size_t backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
DType dtype); | |||||
}; | |||||
class LSTMCellWeightWrapper : public CellWeightsWrapperBase { | |||||
public: | |||||
LSTMCellWeightWrapper( | |||||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype, _megdnn_workspace workspace); | |||||
static size_t backward_workspace_size_in_bytes( | |||||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||||
DType dtype); | |||||
size_t num_states() const override; | |||||
void backward( | |||||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, | |||||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||||
_megdnn_workspace workspace) const override; | |||||
}; | |||||
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn/template_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/rnn/funcs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
namespace rnn { | |||||
template <> | |||||
void cell_opr_exec<RNNCellForward>( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { | |||||
auto opr = handle->create_operator<RNNCellForward>(); | |||||
opr->exec( | |||||
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states_new[0], | |||||
workspace); | |||||
} | |||||
template <> | |||||
size_t cell_opr_get_workspace_in_bytes<RNNCellForward>( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { | |||||
auto cell_opr = handle->create_operator<RNNCellForward>(); | |||||
return cell_opr->get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, hx); | |||||
} | |||||
} // namespace rnn | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/naive/rnn_cell/opr_impl.h" | |||||
#include "src/common/rnn_cell.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
size_t RNNCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) { | |||||
return megdnn::rnn_cell::get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle()); | |||||
} | |||||
void RNNCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
megdnn::rnn_cell::exec( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace, | |||||
param().nonlineMode, handle()); | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn |
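For reference, the naive RNNCell computes `dst = act(input @ W_ih.T + b_ih + hx @ W_hh.T + b_hh)`. A minimal NumPy sketch of that equation (a reference model, not the implementation; names ours):

```python
import numpy as np

def rnn_cell_ref(x, w_ih, b_ih, hx, w_hh, b_hh, act=np.tanh):
    # dst = act(x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh)
    return act(x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh)

batch, input_size, hidden = 2, 3, 4
x = np.random.randn(batch, input_size).astype(np.float32)
hx = np.random.randn(batch, hidden).astype(np.float32)
w_ih = np.random.randn(hidden, input_size).astype(np.float32)
w_hh = np.random.randn(hidden, hidden).astype(np.float32)
b_ih = b_hh = np.zeros(hidden, np.float32)
print(rnn_cell_ref(x, w_ih, b_ih, hx, w_hh, b_hh).shape)  # (2, 4)
```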
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/naive/rnn_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class RNNCellImpl : public RNNCell { | |||||
public: | |||||
using RNNCell::RNNCell; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) override; | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn |
@@ -77,6 +77,15 @@ struct DeduceLayoutProxy<Opr, 6, false> { | |||||
}; | }; | ||||
template <typename Opr> | template <typename Opr> | ||||
struct DeduceLayoutProxy<Opr, 6, true> { | |||||
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||||
megdnn_assert(layouts.size() == 6); | |||||
opr->deduce_layout( | |||||
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); | |||||
} | |||||
}; | |||||
template <typename Opr> | |||||
struct DeduceLayoutProxy<Opr, 7, false> { | struct DeduceLayoutProxy<Opr, 7, false> { | ||||
static void deduce_layout(Opr*, TensorLayoutArray&) {} | static void deduce_layout(Opr*, TensorLayoutArray&) {} | ||||
}; | }; | ||||
@@ -0,0 +1,51 @@ | |||||
/** | |||||
* \file dnn/test/common/rnn.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
namespace rnn { | |||||
struct TestArg { | |||||
param::RNN param; | |||||
TensorShape input, hx, flatten_weights; | |||||
TestArg(param::RNN param, TensorShape input, TensorShape hx, | |||||
TensorShape flatten_weights) | |||||
: param(param), input(input), hx(hx), flatten_weights(flatten_weights) {} | |||||
}; | |||||
inline std::vector<TestArg> get_args() { | |||||
std::vector<TestArg> args; | |||||
size_t batch_size = 2; | |||||
size_t input_size = 3; | |||||
size_t hidden_size = 2; | |||||
size_t seq_len = 2; | |||||
size_t gate_hidden_size = hidden_size; | |||||
param::RNN param; | |||||
param.num_layers = 1; | |||||
param.bidirectional = false; | |||||
param.bias = false; | |||||
param.hidden_size = hidden_size; | |||||
param.nonlineMode = param::RNN::NonlineMode::RELU; | |||||
args.emplace_back( | |||||
param, TensorShape{seq_len, batch_size, input_size}, | |||||
TensorShape{batch_size, hidden_size}, | |||||
TensorShape{gate_hidden_size, input_size + hidden_size}); | |||||
return args; | |||||
} | |||||
} // namespace rnn | |||||
} // namespace test | |||||
} // namespace megdnn |
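The `flatten_weights` shape above comes from packing W_ih (gate_hidden × input) and W_hh (gate_hidden × hidden) side by side; a quick sanity check (the packing order is our assumption, consistent with the handmade test below):

```python
input_size, hidden_size, num_chunks = 3, 2, 1   # a vanilla RNN has one gate chunk
gate_hidden = num_chunks * hidden_size
assert (gate_hidden, input_size + hidden_size) == (2, 5)
```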
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* \file dnn/test/naive/rnn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/common/rnn.h" | |||||
#include "megdnn/dtype.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/naive/fixture.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
/*TEST_F(NAIVE, RNN) { | |||||
std::vector<rnn::TestArg> args = rnn::get_args(); | |||||
Checker<RNN> checker(handle()); | |||||
for (auto&& arg : args) { | |||||
checker.set_param(arg.param) | |||||
.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_dtype(3, dtype::Float32()) | |||||
.set_dtype(4, dtype::Float32()) | |||||
.set_dtype(5, dtype::Float32()) | |||||
.execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}}); | |||||
} | |||||
}*/ | |||||
TEST_F(NAIVE, RNN_HAND_MADE) { | |||||
Checker<RNN> checker(handle(), false); | |||||
size_t batch_size = 2; | |||||
size_t input_size = 3; | |||||
size_t hidden_size = 2; | |||||
size_t seq_len = 2; | |||||
size_t gate_hidden_size = hidden_size; | |||||
RNN::Param param; | |||||
param.num_layers = 1; | |||||
param.bidirectional = false; | |||||
param.bias = false; | |||||
param.hidden_size = hidden_size; | |||||
param.nonlineMode = param::RNN::NonlineMode::RELU; | |||||
checker.set_param(param).exect( | |||||
Testcase{ | |||||
TensorValue( | |||||
{seq_len, batch_size, input_size}, dtype::Float32(), | |||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input | |||||
TensorValue( | |||||
{batch_size, hidden_size}, dtype::Float32(), | |||||
{2, 1, 3, 5}), // hx | |||||
TensorValue( | |||||
{gate_hidden_size, input_size + hidden_size}, | |||||
dtype::Float32(), | |||||
{3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights | |||||
{}, | |||||
{}, | |||||
{}}, | |||||
Testcase{ | |||||
{}, | |||||
{}, | |||||
{}, | |||||
TensorValue( | |||||
{seq_len, batch_size, hidden_size}, dtype::Float32(), | |||||
{39, 39, 90, 84, 300, 216, 546, 366}), // output | |||||
TensorValue( | |||||
{batch_size, hidden_size}, dtype::Float32(), | |||||
{21, 11, 42, 20}), // hy | |||||
TensorValue( | |||||
{1, 2, 2, 2}, dtype::Float32(), | |||||
{2, 1, 3, 5, 21, 11, 42, 20}) // reserve space | |||||
}); | |||||
} | |||||
} // namespace test | |||||
} // namespace megdnn |
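The first-step values in the handmade case can be reproduced with the plain recurrence `h1 = relu(x @ W_ih.T + h @ W_hh.T)`. Splitting `flatten_weights` into a 2×3 W_ih followed by a 2×2 W_hh is our assumption from the shapes; it matches `output[0]`:

```python
import numpy as np

x0 = np.array([[1, 2, 3], [4, 5, 6]], np.float32)   # input at t = 0
h0 = np.array([[2, 1], [3, 5]], np.float32)         # hx
w = np.array([3, 6, 1, 3, 2, 7, 9, 3, 5, 1], np.float32)
w_ih, w_hh = w[:6].reshape(2, 3), w[6:].reshape(2, 2)
h1 = np.maximum(x0 @ w_ih.T + h0 @ w_hh.T, 0)       # RELU nonlinearity
print(h1)                                           # [[39. 39.] [90. 84.]]
```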
@@ -36,5 +36,6 @@ from .padding import Pad | |||||
from .pixel_shuffle import PixelShuffle | from .pixel_shuffle import PixelShuffle | ||||
from .pooling import AvgPool2d, MaxPool2d | from .pooling import AvgPool2d, MaxPool2d | ||||
from .quant_dequant import DequantStub, QuantStub | from .quant_dequant import DequantStub, QuantStub | ||||
from .rnn import LSTM, RNN, LSTMCell, RNNCell | |||||
from .sequential import Sequential | from .sequential import Sequential | ||||
from .sliding_window import SlidingWindow, SlidingWindowTranspose | from .sliding_window import SlidingWindow, SlidingWindowTranspose |
@@ -0,0 +1,396 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import math | |||||
import numbers | |||||
from abc import abstractmethod | |||||
from typing import Optional, Tuple | |||||
import numpy as np | |||||
from ..core._imperative_rt.core2 import apply | |||||
from ..core.ops import builtin | |||||
from ..device import is_cuda_available | |||||
from ..functional import concat, expand_dims, repeat, stack, zeros | |||||
from ..tensor import Parameter, Tensor | |||||
from . import init | |||||
from .module import Module | |||||
class RNNCellBase(Module): | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
hidden_size: int, | |||||
bias: bool, | |||||
num_chunks: int, | |||||
device=None, | |||||
dtype=None, | |||||
) -> None: | |||||
# num_chunks indicates the number of gates | |||||
super(RNNCellBase, self).__init__() | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.bias = bias | |||||
# initialize weights | |||||
common_kwargs = {"device": device, "dtype": dtype} | |||||
self.gate_hidden_size = num_chunks * hidden_size | |||||
self.weight_ih = Parameter( | |||||
np.random.uniform(size=(self.gate_hidden_size, input_size)).astype( | |||||
np.float32 | |||||
), | |||||
**common_kwargs, | |||||
) | |||||
self.weight_hh = Parameter( | |||||
np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype( | |||||
np.float32 | |||||
), | |||||
**common_kwargs, | |||||
) | |||||
if bias: | |||||
self.bias_ih = Parameter( | |||||
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), | |||||
**common_kwargs, | |||||
) | |||||
self.bias_hh = Parameter( | |||||
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), | |||||
**common_kwargs, | |||||
) | |||||
else: | |||||
self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs) | |||||
self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs) | |||||
self.reset_parameters() | |||||
    # if bias is False, bias_ih and bias_hh stay zero-valued tensors | |||||
def get_op(self): | |||||
return builtin.RNNCell() | |||||
def reset_parameters(self) -> None: | |||||
stdv = 1.0 / math.sqrt(self.hidden_size) | |||||
for weight in self.parameters(): | |||||
init.uniform_(weight, -stdv, stdv) | |||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: | |||||
if hx is None: | |||||
hx = zeros( | |||||
shape=(input.shape[0], self.gate_hidden_size), | |||||
dtype=input.dtype, | |||||
device=input.device, | |||||
) | |||||
op = self.get_op() | |||||
return apply( | |||||
op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh | |||||
)[0] | |||||
# return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh) | |||||
class RNNCell(RNNCellBase): | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
hidden_size: int, | |||||
bias: bool = True, | |||||
nonlinearity: str = "tanh", | |||||
device=None, | |||||
dtype=None, | |||||
) -> None: | |||||
self.nonlinearity = nonlinearity | |||||
super(RNNCell, self).__init__( | |||||
input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype | |||||
) | |||||
# self.activate = tanh if nonlinearity == "tanh" else relu | |||||
def get_op(self): | |||||
return builtin.RNNCell(nonlineMode=self.nonlinearity) | |||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: | |||||
return super().forward(input, hx) | |||||
class LSTMCell(RNNCellBase): | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
hidden_size: int, | |||||
bias: bool = True, | |||||
device=None, | |||||
dtype=None, | |||||
) -> None: | |||||
super(LSTMCell, self).__init__( | |||||
input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype | |||||
) | |||||
def get_op(self): | |||||
return builtin.LSTMCell() | |||||
def forward( | |||||
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None | |||||
) -> Tuple[Tensor, Tensor]: | |||||
# hx: (h, c) | |||||
if hx is None: | |||||
h = zeros( | |||||
shape=(input.shape[0], self.hidden_size), | |||||
dtype=input.dtype, | |||||
device=input.device, | |||||
) | |||||
c = zeros( | |||||
shape=(input.shape[0], self.hidden_size), | |||||
dtype=input.dtype, | |||||
device=input.device, | |||||
) | |||||
else: | |||||
h, c = hx | |||||
op = self.get_op() | |||||
return apply( | |||||
op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c | |||||
)[:2] | |||||
def is_gpu(device: str) -> bool: | |||||
if "xpux" in device and is_cuda_available(): | |||||
return True | |||||
if "gpu" in device: | |||||
return True | |||||
return False | |||||
class RNNBase(Module): | |||||
def __init__( | |||||
self, | |||||
input_size: int, | |||||
hidden_size: int, | |||||
num_layers: int = 1, | |||||
bias: bool = True, | |||||
batch_first: bool = False, | |||||
dropout: float = 0, | |||||
bidirectional: bool = False, | |||||
proj_size: int = 0, | |||||
device=None, | |||||
dtype=None, | |||||
) -> None: | |||||
super(RNNBase, self).__init__() | |||||
# self.mode = mode | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.bias = bias | |||||
self.batch_first = batch_first | |||||
self.dropout = float(dropout) | |||||
self.bidirectional = bidirectional | |||||
self.num_directions = 2 if self.bidirectional else 1 | |||||
self.proj_size = proj_size | |||||
# check validity of dropout | |||||
if ( | |||||
not isinstance(dropout, numbers.Number) | |||||
or not 0 <= dropout <= 1 | |||||
or isinstance(dropout, bool) | |||||
): | |||||
raise ValueError( | |||||
"Dropout should be a float in [0, 1], which indicates the probability " | |||||
"of an element to be zero" | |||||
) | |||||
if proj_size < 0: | |||||
raise ValueError( | |||||
"proj_size should be a positive integer or zero to disable projections" | |||||
) | |||||
elif proj_size >= hidden_size: | |||||
raise ValueError("proj_size has to be smaller than hidden_size") | |||||
self.cells = [] | |||||
for layer in range(self.num_layers): | |||||
self.cells.append([]) | |||||
for _ in range(self.num_directions): | |||||
self.cells[layer].append(self.create_cell(layer, device, dtype)) | |||||
# parameters have been initialized during the creation of the cells | |||||
        # allocate the single flattened weight tensor consumed by the fused ops | |||||
self._flatten_parameters(device, dtype, self.cells) | |||||
def _flatten_parameters(self, device, dtype, cells): | |||||
gate_hidden_size = cells[0][0].gate_hidden_size | |||||
size_dim1 = 0 | |||||
for layer in range(self.num_layers): | |||||
for direction in range(self.num_directions): | |||||
size_dim1 += cells[layer][direction].weight_ih.shape[1] | |||||
size_dim1 += cells[layer][direction].weight_hh.shape[1] | |||||
# if self.bias: | |||||
# size_dim1 += 2 * self.num_directions * self.num_layers | |||||
size_dim1 += 2 * self.num_directions * self.num_layers | |||||
self._flatten_weights = Parameter( | |||||
np.zeros((gate_hidden_size, size_dim1), dtype=np.float32) | |||||
) | |||||
self.reset_parameters() | |||||
# TODO: if no bias, set the bias to zero | |||||
def reset_parameters(self) -> None: | |||||
stdv = 1.0 / math.sqrt(self.hidden_size) | |||||
for weight in self.parameters(): | |||||
init.uniform_(weight, -stdv, stdv) | |||||
@abstractmethod | |||||
def create_cell(self, layer, device, dtype): | |||||
raise NotImplementedError("Cell not implemented !") | |||||
@abstractmethod | |||||
def init_hidden(self): | |||||
raise NotImplementedError("init_hidden not implemented !") | |||||
@abstractmethod | |||||
def get_output_from_hidden(self, hx): | |||||
raise NotImplementedError("get_output_from_hidden not implemented !") | |||||
@abstractmethod | |||||
def apply_op(self, input, hx): | |||||
raise NotImplementedError("apply_op not implemented !") | |||||
def _apply_fn_to_hx(self, hx, fn): | |||||
return fn(hx) | |||||
def _stack_h_n(self, h_n): | |||||
return stack(h_n, axis=0) | |||||
def forward(self, input: Tensor, hx=None): | |||||
if self.batch_first: | |||||
batch_size = input.shape[0] | |||||
input = input.transpose((1, 0, 2)) # [seq_len, batch_size, dim] | |||||
else: | |||||
batch_size = input.shape[1] | |||||
if hx is None: | |||||
hx = self.init_hidden(batch_size, input.device, input.dtype) | |||||
        # always take the fused-operator path for now; the per-step Python | |||||
        # fallback below is kept for reference and is currently unreachable | |||||
        if is_gpu(str(input.device)) or True: | |||||
            output, h = self.apply_op(input, hx) | |||||
            if self.batch_first: | |||||
                output = output.transpose((1, 0, 2)) | |||||
            return output, h | |||||
order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)] | |||||
h_n = [] | |||||
for layer in range(self.num_layers): | |||||
layer_outputs = [] | |||||
for direction in range(self.num_directions): | |||||
direction_outputs = [None for _ in range(input.shape[0])] | |||||
cell = self.cells[layer][direction] | |||||
hidden = self._apply_fn_to_hx( | |||||
hx, lambda x: x[layer * self.num_directions + direction] | |||||
) | |||||
for step in range(*(order_settings[direction])): | |||||
hidden = cell(input[step], hidden) # [batch_size, hidden_size] | |||||
direction_outputs[step] = self.get_output_from_hidden(hidden) | |||||
direction_output = stack( | |||||
direction_outputs, axis=0 | |||||
) # [seq_len, batch_size, hidden_size] | |||||
layer_outputs.append(direction_output) | |||||
h_n.append(hidden) | |||||
layer_output = concat( | |||||
layer_outputs, axis=-1 | |||||
) # [seq_len, batch_size, D*hidden_size] | |||||
input = layer_output | |||||
if self.batch_first: | |||||
layer_output = layer_output.transpose((1, 0, 2)) | |||||
return layer_output, self._stack_h_n(h_n) | |||||
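In the fallback loop, `order_settings` makes direction 1 walk the sequence backwards; a tiny illustration of the two `range()` expansions:

```python
seq_len = 4
order_settings = [(0, seq_len), (seq_len - 1, -1, -1)]
print(list(range(*order_settings[0])))  # [0, 1, 2, 3]  forward pass
print(list(range(*order_settings[1])))  # [3, 2, 1, 0]  reverse pass
```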
class RNN(RNNBase): | |||||
def __init__(self, *args, **kwargs) -> None: | |||||
self.nonlinearity = kwargs.pop("nonlinearity", "tanh") | |||||
super(RNN, self).__init__(*args, **kwargs) | |||||
def create_cell(self, layer, device, dtype): | |||||
if layer == 0: | |||||
input_size = self.input_size | |||||
else: | |||||
input_size = self.num_directions * self.hidden_size | |||||
return RNNCell( | |||||
input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype | |||||
) | |||||
def init_hidden(self, batch_size, device, dtype): | |||||
hidden_shape = ( | |||||
self.num_directions * self.num_layers, | |||||
batch_size, | |||||
self.hidden_size, | |||||
) | |||||
return zeros(shape=hidden_shape, dtype=dtype, device=device) | |||||
def get_output_from_hidden(self, hx): | |||||
return hx | |||||
def apply_op(self, input, hx): | |||||
op = builtin.RNN( | |||||
num_layers=self.num_layers, | |||||
bidirectional=self.bidirectional, | |||||
bias=self.bias, | |||||
hidden_size=self.hidden_size, | |||||
proj_size=self.proj_size, | |||||
dropout=self.dropout, | |||||
nonlineMode=self.nonlinearity, | |||||
) | |||||
        output, h = apply(op, input, hx, self._flatten_weights)[:2] | |||||
        # couple the two results with zero-valued terms so each stays | |||||
        # reachable in the autodiff graph and gradients flow through both | |||||
        output = output + h.sum() * 0 | |||||
        h = h + output.sum() * 0 | |||||
        return output, h | |||||
class LSTM(RNNBase): | |||||
def __init__(self, *args, **kwargs) -> None: | |||||
super(LSTM, self).__init__(*args, **kwargs) | |||||
def create_cell(self, layer, device, dtype): | |||||
if layer == 0: | |||||
input_size = self.input_size | |||||
else: | |||||
input_size = self.num_directions * self.hidden_size | |||||
return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype) | |||||
def init_hidden(self, batch_size, device, dtype): | |||||
hidden_shape = ( | |||||
self.num_directions * self.num_layers, | |||||
batch_size, | |||||
self.hidden_size, | |||||
) | |||||
h = zeros(shape=hidden_shape, dtype=dtype, device=device) | |||||
c = zeros(shape=hidden_shape, dtype=dtype, device=device) | |||||
return (h, c) | |||||
def get_output_from_hidden(self, hx): | |||||
return hx[0] | |||||
def apply_op(self, input, hx): | |||||
op = builtin.LSTM( | |||||
num_layers=self.num_layers, | |||||
bidirectional=self.bidirectional, | |||||
bias=self.bias, | |||||
hidden_size=self.hidden_size, | |||||
proj_size=self.proj_size, | |||||
dropout=self.dropout, | |||||
) | |||||
        output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3] | |||||
        # couple the three results with zero-valued terms so each stays | |||||
        # reachable in the autodiff graph and gradients flow through all of them | |||||
        placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0] | |||||
        output = output + placeholders[1] + placeholders[2] | |||||
        h = h + placeholders[0] + placeholders[2] | |||||
        c = c + placeholders[0] + placeholders[1] | |||||
        return output, (h, c) | |||||
def _apply_fn_to_hx(self, hx, fn): | |||||
return (fn(hx[0]), fn(hx[1])) | |||||
def _stack_h_n(self, h_n): | |||||
h = [tup[0] for tup in h_n] | |||||
c = [tup[1] for tup in h_n] | |||||
return (stack(h, axis=0), stack(c, axis=0)) |
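A minimal usage sketch for the module-level API, consistent with the tests that follow (shapes are illustrative):

```python
import megengine as mge
from megengine.module import LSTM

rnn = LSTM(input_size=10, hidden_size=20, num_layers=1, batch_first=True)
x = mge.random.normal(size=(3, 6, 10))     # [batch, seq_len, input_size]
output, (h_n, c_n) = rnn(x)
print(output.shape, h_n.shape, c_n.shape)  # (3, 6, 20) (1, 3, 20) (1, 3, 20)
```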
@@ -0,0 +1,181 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.functional as F | |||||
from megengine.module import LSTM, RNN, LSTMCell, RNNCell | |||||
def assert_tuple_equal(src, ref): | |||||
assert len(src) == len(ref) | |||||
for i, j in zip(src, ref): | |||||
assert i == j | |||||
@pytest.mark.parametrize( | |||||
"batch_size, input_size, hidden_size, init_hidden", | |||||
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)], | |||||
) | |||||
def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden): | |||||
rnn_cell = RNNCell(input_size, hidden_size) | |||||
x = mge.random.normal(size=(batch_size, input_size)) | |||||
if init_hidden: | |||||
h = F.zeros(shape=(batch_size, hidden_size)) | |||||
else: | |||||
h = None | |||||
h_new = rnn_cell(x, h) | |||||
assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) | |||||
# is batch_size == 0 tolerated? it would break slice operations like xx[:, ...] | |||||
@pytest.mark.parametrize( | |||||
"batch_size, input_size, hidden_size, init_hidden", | |||||
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)], | |||||
) | |||||
def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden): | |||||
rnn_cell = LSTMCell(input_size, hidden_size) | |||||
x = mge.random.normal(size=(batch_size, input_size)) | |||||
if init_hidden: | |||||
h = F.zeros(shape=(batch_size, hidden_size)) | |||||
hx = (h, h) | |||||
else: | |||||
hx = None | |||||
h_new, c_new = rnn_cell(x, hx) | |||||
assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) | |||||
assert_tuple_equal(c_new.shape, (batch_size, hidden_size)) | |||||
@pytest.mark.parametrize( | |||||
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", | |||||
[ | |||||
(3, 6, 10, 20, 2, False, False, True), | |||||
pytest.param( | |||||
3, | |||||
3, | |||||
10, | |||||
10, | |||||
1, | |||||
True, | |||||
True, | |||||
False, | |||||
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), | |||||
), | |||||
], | |||||
) | |||||
# (0, 1, 1, 1, 1, False, True, False)]) | |||||
def test_rnn( | |||||
batch_size, | |||||
seq_len, | |||||
input_size, | |||||
hidden_size, | |||||
num_layers, | |||||
bidirectional, | |||||
init_hidden, | |||||
batch_first, | |||||
): | |||||
rnn = RNN( | |||||
input_size, | |||||
hidden_size, | |||||
batch_first=batch_first, | |||||
num_layers=num_layers, | |||||
bidirectional=bidirectional, | |||||
) | |||||
if batch_first: | |||||
x_shape = (batch_size, seq_len, input_size) | |||||
else: | |||||
x_shape = (seq_len, batch_size, input_size) | |||||
x = mge.random.normal(size=x_shape) | |||||
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size | |||||
if init_hidden: | |||||
h = mge.random.normal(size=(batch_size, total_hidden_size)) | |||||
else: | |||||
h = None | |||||
output, h_n = rnn(x, h) | |||||
num_directions = 2 if bidirectional else 1 | |||||
if batch_first: | |||||
assert_tuple_equal( | |||||
output.shape, (batch_size, seq_len, num_directions * hidden_size) | |||||
) | |||||
else: | |||||
assert_tuple_equal( | |||||
output.shape, (seq_len, batch_size, num_directions * hidden_size) | |||||
) | |||||
assert_tuple_equal( | |||||
h_n.shape, (num_directions * num_layers, batch_size, hidden_size) | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", | |||||
[ | |||||
(3, 10, 20, 20, 1, False, False, True), | |||||
pytest.param( | |||||
3, | |||||
3, | |||||
10, | |||||
10, | |||||
1, | |||||
True, | |||||
True, | |||||
False, | |||||
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), | |||||
), | |||||
], | |||||
) | |||||
# (0, 1, 1, 1, 1, False, True, False)]) | |||||
def test_lstm( | |||||
batch_size, | |||||
seq_len, | |||||
input_size, | |||||
hidden_size, | |||||
num_layers, | |||||
bidirectional, | |||||
init_hidden, | |||||
batch_first, | |||||
): | |||||
rnn = LSTM( | |||||
input_size, | |||||
hidden_size, | |||||
batch_first=batch_first, | |||||
num_layers=num_layers, | |||||
bidirectional=bidirectional, | |||||
) | |||||
if batch_first: | |||||
x_shape = (batch_size, seq_len, input_size) | |||||
else: | |||||
x_shape = (seq_len, batch_size, input_size) | |||||
x = mge.random.normal(size=x_shape) | |||||
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size | |||||
if init_hidden: | |||||
h = mge.random.normal(size=(batch_size, total_hidden_size)) | |||||
h = (h, h) | |||||
else: | |||||
h = None | |||||
output, h_n = rnn(x, h) | |||||
num_directions = 2 if bidirectional else 1 | |||||
if batch_first: | |||||
assert_tuple_equal( | |||||
output.shape, (batch_size, seq_len, num_directions * hidden_size) | |||||
) | |||||
else: | |||||
assert_tuple_equal( | |||||
output.shape, (seq_len, batch_size, num_directions * hidden_size) | |||||
) | |||||
assert_tuple_equal( | |||||
h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size) | |||||
) | |||||
assert_tuple_equal( | |||||
h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size) | |||||
) | |||||
if __name__ == "__main__": | |||||
test_lstm(5, 10, 10, 20, 1, False, False, True) |
@@ -0,0 +1,68 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/rnn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/rnn.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "../op_trait.h" | |||||
namespace mgb::imperative { | |||||
namespace rnn_cell { | |||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const RNNCell&>(def); | |||||
mgb_assert(inputs.size() == 6); | |||||
return opr::RNNCell::make( | |||||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], | |||||
op.param()); | |||||
} | |||||
OP_TRAIT_REG(RNNCell, RNNCell).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace rnn_cell | |||||
namespace lstm_cell { | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const LSTMCell&>(def); | |||||
mgb_assert(inputs.size() == 7); | |||||
auto* opr = opr::LSTMCell::make( | |||||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], | |||||
inputs[5], inputs[6], op.param()) | |||||
.node() | |||||
->owner_opr(); | |||||
return {opr->output(0), opr->output(1), opr->output(2)}; | |||||
} | |||||
OP_TRAIT_REG(LSTMCell, LSTMCell).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace lstm_cell | |||||
namespace rnn { | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const RNN&>(def); | |||||
mgb_assert(inputs.size() == 3); | |||||
auto* opr = opr::RNN::make(inputs[0], inputs[1], inputs[2], op.param()) | |||||
.node() | |||||
->owner_opr(); | |||||
return {opr->output(0), opr->output(1), opr->output(2)}; | |||||
} | |||||
OP_TRAIT_REG(RNN, RNN).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace rnn | |||||
namespace lstm { | |||||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const LSTM&>(def); | |||||
mgb_assert(inputs.size() == 4); | |||||
auto* opr = opr::LSTM::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param()) | |||||
.node() | |||||
->owner_opr(); | |||||
return {opr->output(0), opr->output(1), opr->output(2), opr->output(3)}; | |||||
} | |||||
OP_TRAIT_REG(LSTM, LSTM).apply_on_var_node(apply_on_var_node).fallback(); | |||||
} // namespace lstm | |||||
} // namespace mgb::imperative |
@@ -432,6 +432,13 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | def LRN: MgbHashableOp<"LRN", [LRNParam]>; | ||||
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; | ||||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | |||||
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | |||||
def RNN: MgbHashableOp<"RNN", [RNNParam]>; | |||||
def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>; | |||||
def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> { | def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> { | ||||
let extraArguments = (ins | let extraArguments = (ins | ||||
@@ -21,6 +21,7 @@ | |||||
#include "megbrain/opr/dnn/lrn.h" | #include "megbrain/opr/dnn/lrn.h" | ||||
#include "megbrain/opr/dnn/lsq.h" | #include "megbrain/opr/dnn/lsq.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/rnn.h" | |||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
#include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | #include "megbrain/opr/dnn/sliding_window_transpose.h" | ||||
@@ -293,6 +294,36 @@ struct OprMaker<opr::LSQBackward, 5> { | |||||
->owner_opr(); | ->owner_opr(); | ||||
} | } | ||||
}; | }; | ||||
template <> | |||||
struct OprMaker<opr::RNNBackward, 7> { | |||||
using Param = opr::RNNBackward::Param; | |||||
static cg::OperatorNodeBase* make( | |||||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
return opr::RNNBackward::make( | |||||
i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} | |||||
}; | |||||
template <> | |||||
struct OprMaker<opr::LSTMBackward, 9> { | |||||
using Param = opr::LSTMBackward::Param; | |||||
static cg::OperatorNodeBase* make( | |||||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||||
const OperatorNodeConfig& config) { | |||||
MGB_MARK_USED_VAR(graph); | |||||
return opr::LSTMBackward::make( | |||||
i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param, | |||||
config)[0] | |||||
.node() | |||||
->owner_opr(); | |||||
} | |||||
}; | |||||
template <> | template <> | ||||
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | ||||
: public GeneralOprLoadDumpImpl< | : public GeneralOprLoadDumpImpl< | ||||
@@ -685,6 +716,10 @@ MGB_SEREG_OPR(LSQ, 4); | |||||
MGB_SEREG_OPR(LSQBackward, 5); | MGB_SEREG_OPR(LSQBackward, 5); | ||||
MGB_SEREG_OPR(LayerNorm, 0); | MGB_SEREG_OPR(LayerNorm, 0); | ||||
MGB_SEREG_OPR(LayerNormBackward, 0); | MGB_SEREG_OPR(LayerNormBackward, 0); | ||||
MGB_SEREG_OPR(RNNForward, 3); | |||||
MGB_SEREG_OPR(RNNBackward, 7); | |||||
MGB_SEREG_OPR(LSTMForward, 4); | |||||
MGB_SEREG_OPR(LSTMBackward, 9); | |||||
} // namespace opr | } // namespace opr | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -0,0 +1,323 @@ | |||||
/** | |||||
* \file src/opr/impl/dnn/rnn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/rnn.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/basic_arith_wrapper.h" | |||||
#include "megbrain/opr/blas.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/opr/utility.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
/* ================= RNNCell ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNCellForward); | |||||
RNNCellForward::RNNCellForward( | |||||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||||
VarNode* weight_hh, VarNode* bias_hh, const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super{input->owner_graph(), | |||||
config, | |||||
"rnn_cell", | |||||
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh}} { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh}); | |||||
} | |||||
SymbolVar RNNCellForward::make( | |||||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||||
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
return input.insert_single_output_opr<RNNCellForward>( | |||||
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), | |||||
bias_hh.node(), param, config); | |||||
} | |||||
#if MGB_ENABLE_GRAD | |||||
VarNode* rnnCellBackward( | |||||
const SymbolVar& input, const SymbolVar& weight_ih, const SymbolVar& hx, | |||||
const SymbolVar& weight_hh, const SymbolVar& out, | |||||
RNNCell::NonlineMode nonlineMode, size_t wrt_idx, const SymbolVar& og) { | |||||
SymbolVar tmp; | |||||
// activation | |||||
using NonlineMode = RNNCell::NonlineMode; | |||||
using Mode = Elemwise::Mode; | |||||
switch (nonlineMode) { | |||||
case NonlineMode::IDENTITY: | |||||
tmp = og; | |||||
break; | |||||
case NonlineMode::TANH: | |||||
tmp = Elemwise::make({out, og}, Mode::TANH_GRAD); | |||||
break; | |||||
case NonlineMode::RELU: | |||||
tmp = Elemwise::make({out, og}, Mode::SWITCH_GT0); | |||||
break; | |||||
default: | |||||
mgb_throw(GraphError, "Activation method not supported"); | |||||
} | |||||
// now grad is in tmp | |||||
if (wrt_idx == 2 || wrt_idx == 5) | |||||
return tmp.node(); // bias | |||||
SymbolVar result; | |||||
// A * Bt = C, A' = C' * B, B' = C't * A | |||||
if (wrt_idx == 0) { // input | |||||
result = MatrixMul::make( | |||||
tmp, weight_ih, | |||||
{false, false}); // transpose a false, transpose b false | |||||
} else if (wrt_idx == 1) { // weight_ih | |||||
result = MatrixMul::make(tmp, input, {true, false}); | |||||
} else if (wrt_idx == 3) { // hx | |||||
result = MatrixMul::make(tmp, weight_hh, {false, false}); | |||||
} else if (wrt_idx == 4) { // weight_hh | |||||
result = MatrixMul::make(tmp, hx, {true, false}); | |||||
} | |||||
return result.node(); | |||||
} | |||||
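The `wrt_idx` branches apply the usual identities for \(y = \sigma(x W_{ih}^{\top} + b_{ih} + h_x W_{hh}^{\top} + b_{hh})\), with \(G\) the gradient w.r.t. the pre-activation (notation ours):

$$
\frac{\partial L}{\partial x} = G\,W_{ih},\qquad
\frac{\partial L}{\partial W_{ih}} = G^{\top} x,\qquad
\frac{\partial L}{\partial h_x} = G\,W_{hh},\qquad
\frac{\partial L}{\partial W_{hh}} = G^{\top} h_x
$$

For the bias indices the code returns \(G\) itself; the textbook bias gradient additionally sums over the batch dimension, \(\sum_b G_b\).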
MGB_IMPL_OPR_GRAD(RNNCell) { | |||||
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), | |||||
weight_hh(opr.input(4)); | |||||
SymbolVar out(opr.output(0)), og{out_grad.at(0)}; | |||||
return rnnCellBackward( | |||||
input, weight_ih, hx, weight_hh, out, opr.param().nonlineMode, wrt_idx, og); | |||||
} | |||||
#endif | |||||
/* ================= LSTMCell ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMCell); | |||||
LSTMCellForward::LSTMCellForward( | |||||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||||
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super{input->owner_graph(), | |||||
config, | |||||
"lstm_cell", | |||||
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}} { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}); | |||||
} | |||||
SymbolVar LSTMCellForward::make( | |||||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||||
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
return input.insert_single_output_opr<LSTMCellForward>( | |||||
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), | |||||
bias_hh.node(), cx.node(), param, config); | |||||
} | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LSTMCell) { | |||||
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), | |||||
weight_hh(opr.input(4)), cx(opr.input(6)); | |||||
SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)), | |||||
h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp; | |||||
    size_t ghs = gates.shape()[1] / 4;  // per-gate width, i.e. hidden_size | |||||
SymbolVarArray gates_array = Split::make( | |||||
gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs})); | |||||
mgb_assert(gates_array.size() == 4); | |||||
using Mode = Elemwise::Mode; | |||||
const SymbolVar &i(Elemwise::make({gates_array.at(0)}, Mode::SIGMOID)), | |||||
f(Elemwise::make({gates_array.at(1)}, Mode::SIGMOID)), | |||||
o(Elemwise::make({gates_array.at(2)}, Mode::SIGMOID)), | |||||
g(Elemwise::make({gates_array.at(3)}, Mode::TANH)); | |||||
SymbolVar i_grad, f_grad, o_grad, g_grad; | |||||
SymbolVar tanh_c_out = Elemwise::make({c_out}, Mode::TANH); | |||||
o_grad = Elemwise::make({o, h_og * tanh_c_out}, Mode::SIGMOID_GRAD); | |||||
c_og = c_og + Elemwise::make({tanh_c_out, h_og * o}, Mode::TANH_GRAD); | |||||
f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD); | |||||
i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD); | |||||
g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD); | |||||
SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1}); | |||||
SymbolVar result; | |||||
if (wrt_idx < 6) { | |||||
using NonlineMode = RNNCell::NonlineMode; | |||||
result = rnnCellBackward( | |||||
input, weight_ih, hx, weight_hh, gates, NonlineMode::IDENTITY, wrt_idx, | |||||
rnn_cell_grad); | |||||
} else { // cx | |||||
result = c_og * f; | |||||
} | |||||
return result.node(); | |||||
} | |||||
#endif | |||||
/* ================= RNN ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNN); | |||||
MEGDNN_OPR_INIT3(RNNForward, "rnn_fwd"); | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(RNN) { | |||||
mgb_assert( | |||||
opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING, | |||||
"RNN could only take grad in training mode"); | |||||
SymbolVarArray grads = RNNBackward::make( | |||||
opr.input(0), opr.output(0), opr.input(1), out_grad.at(0), out_grad.at(1), | |||||
opr.input(2), opr.output(2), opr.param()); | |||||
// return grads.at(wrt_idx).node(); // input, hx, weights | |||||
VarNodeArray ret(opr.input().size(), nullptr); | |||||
for (size_t i = 0; i < ret.size(); ++i) { | |||||
ret[i] = grads[i].node(); | |||||
} | |||||
return ret; | |||||
} | |||||
#endif | |||||
/* ================= RNNBackward ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward); | |||||
RNNBackward::RNNBackward( | |||||
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy, | |||||
VarNode* flatten_weights, VarNode* reserve_space, const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super({x->owner_graph(), | |||||
config, | |||||
"rnn_bwd", | |||||
{x, y, hx, dy, dhy, flatten_weights, reserve_space}}, | |||||
0, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space}); | |||||
} | |||||
SymbolVarArray RNNBackward::make( | |||||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy, | |||||
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
auto&& out = x.node()->owner_graph() | |||||
->insert_opr(std::make_unique<RNNBackward>( | |||||
x.node(), y.node(), hx.node(), dy.node(), dhy.node(), | |||||
flatten_weights.node(), reserve_space.node(), param, | |||||
config)) | |||||
->output(); | |||||
SymbolVarArray ret(out.size()); | |||||
for (size_t i = 0; i < ret.size(); ++i) { | |||||
ret[i] = out[i]; | |||||
} | |||||
return ret; | |||||
} | |||||
RNNBackward::Super::NodeProp* RNNBackward::do_make_node_prop() const { | |||||
auto ret = Super::do_make_node_prop(); | |||||
ret->add_dep_type_existing_var(input(6), NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
return ret; | |||||
} | |||||
void RNNBackward::init_output_static_infer_desc() { | |||||
using namespace cg::static_infer; | |||||
auto&& mgr = owner_graph()->static_infer_manager(); | |||||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); | |||||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(5))); | |||||
this->init_output_static_infer_desc_workspace( | |||||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::RNNBackward>::val); | |||||
} | |||||
void RNNBackward::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | |||||
output(1)->dtype(input(2)->dtype()); | |||||
output(2)->dtype(input(5)->dtype()); | |||||
} | |||||
/* ================= LSTM ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTM); | |||||
LSTMForward::LSTMForward( | |||||
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights, | |||||
const Param& param, const OperatorNodeConfig& config) | |||||
: Super{input->owner_graph(), | |||||
config, | |||||
"lstm", | |||||
{input, hx, cx, flatten_weights}} { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({input, hx, cx, flatten_weights}); | |||||
} | |||||
SymbolVar LSTMForward::make( | |||||
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights, | |||||
const Param& param, const OperatorNodeConfig& config) { | |||||
return input.insert_single_output_opr<LSTMForward>( | |||||
input.node(), hx.node(), cx.node(), flatten_weights.node(), param, config); | |||||
} | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(LSTM) { | |||||
SymbolVarArray grads = LSTMBackward::make( | |||||
opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0), | |||||
out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3), opr.param()); | |||||
    return grads.at(wrt_idx).node();  // input, hx, cx, weights | |||||
} | |||||
#endif | |||||
/* ================= LSTMBackward ================= */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMBackward); | |||||
LSTMBackward::LSTMBackward( | |||||
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy, | |||||
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space, | |||||
const Param& param, const OperatorNodeConfig& config) | |||||
: Super({x->owner_graph(), | |||||
config, | |||||
"lstm_bwd", | |||||
{x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}}, | |||||
1, true) { | |||||
init_megdnn_opr(*this, param); | |||||
add_input({x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}); | |||||
} | |||||
SymbolVarArray LSTMBackward::make( | |||||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy, | |||||
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights, | |||||
SymbolVar reserve_space, const Param& param, const OperatorNodeConfig& config) { | |||||
auto&& out = x.node()->owner_graph() | |||||
->insert_opr(std::make_unique<LSTMBackward>( | |||||
x.node(), y.node(), hx.node(), cx.node(), dy.node(), | |||||
dhy.node(), dcy.node(), flatten_weights.node(), | |||||
reserve_space.node(), param, config)) | |||||
->output(); | |||||
SymbolVarArray ret(out.size()); | |||||
for (size_t i = 0; i < ret.size(); ++i) { | |||||
ret[i] = out[i]; | |||||
} | |||||
return ret; | |||||
} | |||||
LSTMBackward::Super::NodeProp* LSTMBackward::do_make_node_prop() const { | |||||
auto ret = Super::do_make_node_prop(); | |||||
ret->add_dep_type_existing_var( | |||||
input(8), // reserve space | |||||
NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
return ret; | |||||
} | |||||
void LSTMBackward::init_output_static_infer_desc() { | |||||
using namespace cg::static_infer; | |||||
auto&& mgr = owner_graph()->static_infer_manager(); | |||||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); | |||||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3))); | |||||
mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(7))); | |||||
this->init_output_static_infer_desc_workspace( | |||||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSTMBackward>::val); | |||||
} | |||||
void LSTMBackward::init_output_dtype() { | |||||
output(0)->dtype(input(0)->dtype()); | |||||
output(1)->dtype(input(2)->dtype()); | |||||
output(2)->dtype(input(3)->dtype()); | |||||
output(3)->dtype(input(7)->dtype()); | |||||
} |
@@ -170,6 +170,16 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
@@ -180,11 +190,34 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 6
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
        _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 7
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
        _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 9
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
        _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \
        _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
} // anonymous namespace

/* ======================= MegDNNOprWrapperFwd ======================= */
@@ -0,0 +1,120 @@
/**
 * \file src/opr/include/megbrain/opr/dnn/rnn.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {

MGB_DEFINE_OPR_CLASS(
        RNNCellForward, intl::MegDNNOprWrapperFwd<megdnn::RNNCellForward>) // {
public:
    using NonlineMode = Param::NonlineMode;

    RNNCellForward(
            VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
            VarNode* weight_hh, VarNode* bias_hh, const Param& param,
            const OperatorNodeConfig& config);
    static SymbolVar make(
            SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
            SymbolVar weight_hh, SymbolVar bias_hh, const Param& param = {},
            const OperatorNodeConfig& config = {});
};
using RNNCell = RNNCellForward;

MGB_DEFINE_OPR_CLASS(
        LSTMCellForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMCellForward>) // {
public:
    LSTMCellForward(
            VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
            VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
            const OperatorNodeConfig& config);
    static SymbolVar make(
            SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
            SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx,
            const Param& param = {}, const OperatorNodeConfig& config = {});
};
using LSTMCell = LSTMCellForward;

MGB_DEFINE_OPR_CLASS(RNNForward, intl::MegDNNOprWrapperFwd<megdnn::RNNForward>) // {
    /*private:
    SymbolVarArray weight_ih_arr;  // 1d, idx: direction * num_layers + layer
    SymbolVarArray weight_hh_arr;
    SymbolVarArray bias_arr;
    */
public:
    RNNForward(
            VarNode* input, VarNode* hx, VarNode* flatten_weights, const Param& param,
            const OperatorNodeConfig& config);
    static SymbolVar make(
            SymbolVar input, SymbolVar hx, SymbolVar flatten_weights,
            const Param& param = {}, const OperatorNodeConfig& config = {});
};
using RNN = RNNForward;

MGB_DEFINE_OPR_CLASS(
        RNNBackward, intl::MegDNNOprWrapperBwd<megdnn::RNNBackward>) // {
public:
    RNNBackward(
            VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
            VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
            const OperatorNodeConfig& config);
    static SymbolVarArray make(
            SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
            SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param = {},
            const OperatorNodeConfig& config = {});
    Super::NodeProp* do_make_node_prop() const override;

private:
    void init_output_static_infer_desc() override;
    void init_output_dtype() override;
};

MGB_DEFINE_OPR_CLASS(
        LSTMForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMForward>) // {
public:
    LSTMForward(
            VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
            const Param& param, const OperatorNodeConfig& config);
    static SymbolVar make(
            SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
            const Param& param = {}, const OperatorNodeConfig& config = {});
};
using LSTM = LSTMForward;

MGB_DEFINE_OPR_CLASS(
        LSTMBackward, intl::MegDNNOprWrapperBwd<megdnn::LSTMBackward>) // {
public:
    LSTMBackward(
            VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
            VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
            const Param& param, const OperatorNodeConfig& config);
    static SymbolVarArray make(
            SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
            SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
            SymbolVar reserve_space, const Param& param = {},
            const OperatorNodeConfig& config = {});
    Super::NodeProp* do_make_node_prop() const override;

private:
    void init_output_static_infer_desc() override;
    void init_output_dtype() override;
};

} // namespace opr
} // namespace mgb
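// Taken together, this header exposes single-step cells (RNNCell, LSTMCell)
// and whole-sequence oprs (RNN, LSTM) plus their backward counterparts. A
// minimal graph-construction sketch (hypothetical SymbolVar names; the param
// fields follow the corresponding megdnn params and are not spelled out here):
//
//     auto hy_cell = opr::RNNCell::make(
//             x, weight_ih, bias_ih, hx, weight_hh, bias_hh, rnn_cell_param);
//     auto y = opr::LSTM::make(input, hx, cx, flatten_weights, lstm_param);
//     // y is output(0); hy, cy and the reserve_space later consumed by
//     // LSTMBackward are the remaining outputs of y.node()->owner_opr().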
@@ -118,6 +118,9 @@ union OperatorParam {
    param.CheckNonFinite = 84,
    param.LayerNorm = 85,
    param.Dropout = 86,
    param.RNNCell = 87,
    param.RNN = 88,
    param.LSTM = 89,
}

table Operator {
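// FlatBuffers resolves union members by these numeric tags, so the new
// entries are appended after the last existing id (Dropout = 86); reordering
// or reusing ids would break deserialization of previously saved models.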