From 8f48da7ffe03baa037aada4c059309952da053a9 Mon Sep 17 00:00:00 2001
From: "kxz@thumt102-1" <15068701650@163.com>
Date: Sat, 9 Oct 2021 22:02:45 +0800
Subject: [PATCH] feat(mgb/opr): add cell level rnn/lstm and sequence level
 rnn/lstm

---
 .gitignore                                      |  11 +-
 dnn/include/megdnn/oprs/nn.h                    | 192 +++++++++++
 dnn/scripts/opr_param_defs.py                   |  25 ++
 dnn/src/common/handle.cpp                       |   6 +-
 dnn/src/common/handle_impl.h                    |   8 +-
 dnn/src/common/lstm.cpp                         |  98 ++++++
 dnn/src/common/lstm_cell.cpp                    | 136 ++++++++
 dnn/src/common/lstm_cell.h                      |  32 ++
 dnn/src/common/opr_trait.h                      |   2 +
 dnn/src/common/relayout_format.cpp              |   5 +-
 dnn/src/common/rnn.cpp                          |  82 +++++
 dnn/src/common/rnn.h                            |  33 ++
 dnn/src/common/rnn_cell.cpp                     | 108 ++++++
 dnn/src/common/rnn_cell.h                       |  31 ++
 dnn/src/cuda/cudnn_wrapper.cpp                  | 114 +++++++
 dnn/src/cuda/cudnn_wrapper.h                    |  39 +++
 dnn/src/cuda/handle_create.cpp                  |   4 +
 dnn/src/cuda/lstm/opr_impl.cpp                  | 112 ++++++
 dnn/src/cuda/lstm/opr_impl.h                    |  56 +++
 dnn/src/cuda/lstm/utils.cpp                     |  39 +++
 dnn/src/cuda/lstm/utils.h                       |  23 ++
 dnn/src/cuda/lstm_cell/opr_impl.cpp             |  42 +++
 dnn/src/cuda/lstm_cell/opr_impl.h               |  36 ++
 dnn/src/cuda/rnn/opr_impl.cpp                   | 170 ++++++++++
 dnn/src/cuda/rnn/opr_impl.h                     |  57 ++++
 dnn/src/cuda/rnn/utils.cpp                      | 138 ++++++++
 dnn/src/cuda/rnn/utils.h                        |  56 +++
 dnn/src/cuda/rnn_cell/opr_impl.cpp              |  35 ++
 dnn/src/cuda/rnn_cell/opr_impl.h                |  40 +++
 dnn/src/naive/handle.cpp                        |   4 +
 dnn/src/naive/lstm/opr_impl.cpp                 | 146 ++++++++
 dnn/src/naive/lstm/opr_impl.h                   |  56 +++
 dnn/src/naive/lstm/template_impl.cpp            |  55 +++
 dnn/src/naive/lstm_cell/opr_impl.cpp            |  38 +++
 dnn/src/naive/lstm_cell/opr_impl.h              |  36 ++
 dnn/src/naive/relayout/opr_impl.cpp             |   4 +-
 dnn/src/naive/rnn/funcs.h                       |  75 +++++
 dnn/src/naive/rnn/funcs.tpp                     | 449 +++++++++++++++++++++++++
 dnn/src/naive/rnn/opr_impl.cpp                  | 196 +++++++++++
 dnn/src/naive/rnn/opr_impl.h                    |  53 +++
 dnn/src/naive/rnn/rnn.cpp                       | 285 ++++++++++++++++
 dnn/src/naive/rnn/rnn.h                         |  73 ++++
 dnn/src/naive/rnn/template_impl.cpp             |  41 +++
 dnn/src/naive/rnn_cell/opr_impl.cpp             |  34 ++
 dnn/src/naive/rnn_cell/opr_impl.h               |  33 ++
 dnn/test/common/deduce_layout_proxy.h           |   9 +
 dnn/test/common/elemwise.cpp                    |   5 +-
 dnn/test/common/rnn.h                           |  51 +++
 dnn/test/naive/rnn.cpp                          |  80 +++++
 imperative/python/megengine/module/__init__.py  |   1 +
 imperative/python/megengine/module/rnn.py       | 396 ++++++++++++++++++++++
 imperative/python/test/unit/module/test_rnn.py  | 181 ++++++++++
 imperative/src/impl/ops/rnn.cpp                 |  68 ++++
 src/core/include/megbrain/ir/ops.td             |   8 +
 src/opr/impl/dnn/dnn.sereg.h                    |  35 ++
 src/opr/impl/dnn/rnn.cpp                        | 323 ++++++++++++++++++
 src/opr/impl/internal/megdnn_opr_wrapper.inl    |  33 ++
 src/opr/include/megbrain/opr/dnn/rnn.h          | 120 +++++++
 src/serialization/impl/schema.fbs               |   3 +
 59 files changed, 4601 insertions(+), 20 deletions(-)
 create mode 100644 dnn/src/common/lstm.cpp
 create mode 100644 dnn/src/common/lstm_cell.cpp
 create mode 100644 dnn/src/common/lstm_cell.h
 create mode 100644 dnn/src/common/rnn.cpp
 create mode 100644 dnn/src/common/rnn.h
 create mode 100644 dnn/src/common/rnn_cell.cpp
 create mode 100644 dnn/src/common/rnn_cell.h
 create mode 100644 dnn/src/cuda/lstm/opr_impl.cpp
 create mode 100644 dnn/src/cuda/lstm/opr_impl.h
 create mode 100644 dnn/src/cuda/lstm/utils.cpp
 create mode 100644 dnn/src/cuda/lstm/utils.h
 create mode 100644 dnn/src/cuda/lstm_cell/opr_impl.cpp
 create mode 100644 dnn/src/cuda/lstm_cell/opr_impl.h
 create mode 100644 dnn/src/cuda/rnn/opr_impl.cpp
 create mode 100644 dnn/src/cuda/rnn/opr_impl.h
 create mode 100644 dnn/src/cuda/rnn/utils.cpp
 create mode 100644 dnn/src/cuda/rnn/utils.h
 create mode
100644 dnn/src/cuda/rnn_cell/opr_impl.cpp create mode 100644 dnn/src/cuda/rnn_cell/opr_impl.h create mode 100644 dnn/src/naive/lstm/opr_impl.cpp create mode 100644 dnn/src/naive/lstm/opr_impl.h create mode 100644 dnn/src/naive/lstm/template_impl.cpp create mode 100644 dnn/src/naive/lstm_cell/opr_impl.cpp create mode 100644 dnn/src/naive/lstm_cell/opr_impl.h create mode 100644 dnn/src/naive/rnn/funcs.h create mode 100644 dnn/src/naive/rnn/funcs.tpp create mode 100644 dnn/src/naive/rnn/opr_impl.cpp create mode 100644 dnn/src/naive/rnn/opr_impl.h create mode 100644 dnn/src/naive/rnn/rnn.cpp create mode 100644 dnn/src/naive/rnn/rnn.h create mode 100644 dnn/src/naive/rnn/template_impl.cpp create mode 100644 dnn/src/naive/rnn_cell/opr_impl.cpp create mode 100644 dnn/src/naive/rnn_cell/opr_impl.h create mode 100644 dnn/test/common/rnn.h create mode 100644 dnn/test/naive/rnn.cpp create mode 100644 imperative/python/megengine/module/rnn.py create mode 100644 imperative/python/test/unit/module/test_rnn.py create mode 100644 imperative/src/impl/ops/rnn.cpp create mode 100644 src/opr/impl/dnn/rnn.cpp create mode 100644 src/opr/include/megbrain/opr/dnn/rnn.h diff --git a/.gitignore b/.gitignore index 262be6cb..86689858 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,3 @@ -# Build -build/ -output/ - # Cache __pycache__/ .ccls-cache/ @@ -11,5 +7,8 @@ __pycache__/ .vs/ .idea/ -# CMake -compile_commands.json \ No newline at end of file +# Make and Build Settings +build/ +output/ +compile_commands.json +imperative/python/megengine/core/*.so diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index e5ea399c..5afc4b33 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -1936,6 +1936,198 @@ protected: const TensorLayout& grad_s, size_t workspace_in_bytes); }; +class RNNCellForward : public OperatorBase { + DEF_OPR_PARAM(RNNCell); + DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1); + +public: + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + static void deduce_layout( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst, size_t workspace_in_bytes); +}; +using RNNCell = RNNCellForward; + +class LSTMCellForward : public OperatorBase { + // DEF_OPR_PARAM(LSTMCell); + DEF_OPR_PARAM(Empty); + DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3); + +public: + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& 
weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new, + TensorLayout& gates); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, + const TensorLayout& c_new, const TensorLayout& gates) = 0; + +protected: + void check_exec( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, + const TensorLayout& c_new, const TensorLayout& gates, + size_t workspace_in_bytes); +}; +using LSTMCell = LSTMCellForward; + +class RNNForward : public OperatorBase { + DEF_OPR_PARAM(RNN); + DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3); + +public: + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, + TensorLayout& reserve_space); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space) = 0; + virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0; + +protected: + void check_exec( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space, + size_t workspace_in_bytes); +}; +using RNN = RNNForward; + +class RNNBackward : public OperatorBase { + DEF_OPR_PARAM(RNN); + DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3); + +public: + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, + _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw); + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, + const TensorLayout& dw) = 0; + +protected: + void check_exec( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw, + size_t workspace_in_bytes); +}; + +class LSTMForward : public OperatorBase { + DEF_OPR_PARAM(LSTM); + DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4); + 
+public: + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out cy, + _megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, + TensorLayout& cy, TensorLayout& reserve_space); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space) = 0; + virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0; + +protected: + void check_exec( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space, size_t workspace_in_bytes); +}; +using LSTM = LSTMForward; + +class LSTMBackward : public OperatorBase { + DEF_OPR_PARAM(LSTM); + DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4); + +public: + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, + _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx, + TensorLayout& dcx, TensorLayout& dw); + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, + const TensorLayout& dw) = 0; + +protected: + void check_exec( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, + size_t workspace_in_bytes); +}; } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 8eba9199..fe10f4b0 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1203,3 +1203,28 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] ) ) + +(pdef('RNNCell'). + add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2') + ) + +(pdef('RNN'). + add_fields('uint32', 'num_layers', '1'). + add_fields('bool', 'bidirectional', 'false'). + add_fields('bool', 'bias', 'true'). + add_fields('uint32', 'hidden_size', '128'). 
+ add_fields('uint32', 'proj_size', '0'). + add_fields('float32', 'dropout', '0.f'). + add_enum_alias('NonlineMode', 'RNNCell'). + add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') + ) + +(pdef('LSTM'). + add_fields('uint32', 'num_layers', '1'). + add_fields('bool', 'bidirectional', 'false'). + add_fields('bool', 'bias', 'true'). + add_fields('uint32', 'hidden_size', '128'). + add_fields('uint32', 'proj_size', '0'). + add_fields('float32', 'dropout', '0.f'). + add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') + ) diff --git a/dnn/src/common/handle.cpp b/dnn/src/common/handle.cpp index 17f4718c..bc3b68a6 100644 --- a/dnn/src/common/handle.cpp +++ b/dnn/src/common/handle.cpp @@ -92,8 +92,7 @@ std::unique_ptr Handle::make( } MIDOUT_END(); - } - else if (platform == megcorePlatformROCM) { + } else if (platform == megcorePlatformROCM) { #if MEGDNN_WITH_ROCM return make_rocm_handle(computing_handle); #else @@ -111,8 +110,7 @@ std::unique_ptr Handle::make( #else return nullptr; #endif - } - else { + } else { // CUDA megdnn_throw_if( platform != megcorePlatformCUDA, megdnn_error, diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 7c3e01a1..5124f65b 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -209,7 +209,13 @@ private: cb(LSQBackward) \ cb(Fill) \ cb(PaddingForward) \ - cb(PaddingBackward) + cb(PaddingBackward) \ + cb(RNNCell) \ + cb(LSTMCell) \ + cb(RNN) \ + cb(RNNBackward) \ + cb(LSTM) \ + cb(LSTMBackward) // clang-format on /*! diff --git a/dnn/src/common/lstm.cpp b/dnn/src/common/lstm.cpp new file mode 100644 index 00000000..4275a582 --- /dev/null +++ b/dnn/src/common/lstm.cpp @@ -0,0 +1,98 @@ +/** + * \file dnn/src/common/lstm.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "megdnn/oprs.h" +#include "src/common/utils.h" +// #include "src/cuda/lstm/utils.h" + +namespace megdnn { + +/*size_t get_reserve_size(Handle* handle, megdnn::LSTMForward::Param& param, const +TensorLayout& input) { #if CUDNN_MAJOR >= 6 auto holder = +megdnn::cuda::lstm::get_RNNDescHolder_v6(handle, param, input); return +holder.reserveSpace_size; # else return 0; #endif +}*/ + +void LSTM::deduce_layout( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, + TensorLayout& cy, TensorLayout& reserve_space) { + // input: [seq_len, batch_size, input_size] + // hx: [D * num_layers, batch_size, hidden_size] + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t D = param().bidirectional ? 
2 : 1; + size_t hidden_size = hx.shape[2]; + output = TensorLayout( + TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype); + hy = TensorLayout(hx); + cy = TensorLayout(cx); + // reserve_space = {{get_reserve_size(this->handle(), param(), input)}, + // dtype::Byte()}; + reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()}; +} + +void LSTM::check_exec( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space, size_t workspace_in_bytes) { + auto errmsg = [&]() { + std::string msg; + msg.append("input="); + msg.append(input.to_string()); + msg.append(", hx="); + msg.append(hx.to_string()); + msg.append(", cx="); + msg.append(cx.to_string()); + msg.append(", hidden_size="); + msg.append(std::to_string(param().hidden_size)); + msg.append(", num_layers="); + msg.append(std::to_string(param().num_layers)); + msg.append(", bidirectional="); + msg.append(std::to_string(param().bidirectional)); + return msg; + }; + size_t D = param().bidirectional ? 2 : 1; + size_t num_layers = param().num_layers; +#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str()); + + ASSERT_BRIEF(hx.ndim == 3) + ASSERT_BRIEF(hx.shape[0] == D * num_layers) + ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size + ASSERT_BRIEF(hx.shape[2] == param().hidden_size) + ASSERT_BRIEF(cx.ndim == 3) + ASSERT_BRIEF(cx.shape[0] == D * num_layers) + ASSERT_BRIEF(cx.shape[1] == input.shape[1]) // batch_size + ASSERT_BRIEF(cx.shape[2] == param().hidden_size) +#undef ASSERT_BRIEF +} + +void LSTMBackward::deduce_layout( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx, + TensorLayout& dcx, TensorLayout& dw) { + dx = x; + dhx = hx; + dcx = cx; + dw = flatten_weights; +} + +void LSTMBackward::check_exec( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, + size_t workspace_in_bytes) {} + +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/lstm_cell.cpp b/dnn/src/common/lstm_cell.cpp new file mode 100644 index 00000000..35abdf4e --- /dev/null +++ b/dnn/src/common/lstm_cell.cpp @@ -0,0 +1,136 @@ +/** + * \file dnn/src/common/lstm_cell.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/common/lstm_cell.h" +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void LSTMCell::deduce_layout( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new, + TensorLayout& gates) { + // size_t batch_size = hx.shape[0]; + // size_t hidden_size = hx.shape[1]; + h_new = TensorLayout(hx, hx.dtype); + c_new = TensorLayout(cx, cx.dtype); + auto opr = handle()->create_operator(); + opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; + opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates); +} + +void LSTMCell::check_exec( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, + const TensorLayout& gates, size_t workspace_in_bytes) { + TensorLayout h_new_expected, c_new_expected, gates_expected; + deduce_layout( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected, + c_new_expected, gates_expected); + megdnn_assert_eq_layout(h_new_expected, h_new); + megdnn_assert_eq_layout(c_new_expected, c_new); + megdnn_assert_eq_layout(gates_expected, gates); + + auto required_workspace_in_bytes = get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +namespace megdnn { +namespace lstm_cell { + +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, + const TensorLayout& gates, Handle* handle) { + TensorLayout tmp_layout; + auto opr = handle->create_operator(); + opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; + opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout); + size_t rnn_cell_need = opr->get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates); + size_t lstm_cell_need = tmp_layout.span().dist_byte(); + return rnn_cell_need > lstm_cell_need ? 
rnn_cell_need : lstm_cell_need; +} + +void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) { + auto opr = handle->create_operator(); + opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; + /*TensorLayout tmp_layout; + opr->deduce_layout(input.layout, weight_ih.layout, + hx.layout, weight_hh.layout, + bias.layout, tmp_layout); + auto workspace_ptr = workspace.raw_ptr; + // TensorND tmp{static_cast(workspace.raw_ptr), tmp_layout}; + TensorND tmp{workspace_ptr, tmp_layout}; + auto new_workspace = Workspace{workspace_ptr + tmp.layout.span().dist_byte(), + workspace.size - + tmp.layout.span().dist_byte()};*/ + // opr->exec(input, weight_ih, hx, weight_hh, bias, tmp, new_workspace); + opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace); + // activation + // size_t batch_size = tmp.layout.shape[0]; + size_t batch_size = hx.layout.shape[0]; + size_t hidden_size = hx.layout.shape[1]; + // sigmoid: i f o + // TensorLayout gates_ifo_layout{TensorShape({batch_size, hidden_size * 3}), + // tmp.layout.dtype}; + TensorND tmp{static_cast(workspace.raw_ptr), gates.layout}; + TensorLayout gates_ifo_layout{ + TensorShape({batch_size, hidden_size * 3}), gates.layout.dtype}; + TensorND gates_ifo_origin{gates.raw_ptr(), gates_ifo_layout}; + TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout}; + auto sigmoid = handle->create_operator(); + sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID; + sigmoid->exec({gates_ifo_origin}, gates_ifo); + // tanh: g + TensorLayout g_layout{TensorShape({batch_size, hidden_size}), gates.layout.dtype}; + TensorND g_origin{ + static_cast(gates.raw_ptr()) + gates_ifo_layout.span().dist_byte(), + g_layout}; + TensorND g{ + static_cast(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(), + g_layout}; + auto tanh = handle->create_operator(); + tanh->param().mode = Elemwise::Param::Mode::TANH; + tanh->exec({g_origin}, g); + // extract i f o + TensorND i{static_cast(tmp.raw_ptr()), g_layout}; + TensorND f{ + static_cast(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout}; + TensorND o{ + static_cast(tmp.raw_ptr()) + g_layout.span().dist_byte() * 2, + g_layout}; + // calculate new cell state + auto elewise_mul_add = handle->create_operator(); + elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4; + elewise_mul_add->exec({f, cx, i, g}, c_new); + // calculate new hidden state + tanh->exec({c_new}, h_new); + auto elewise_mul = handle->create_operator(); + elewise_mul->param().mode = Elemwise::Param::Mode::MUL; + elewise_mul->exec({o, h_new}, h_new); +} + +} // namespace lstm_cell +} // namespace megdnn diff --git a/dnn/src/common/lstm_cell.h b/dnn/src/common/lstm_cell.h new file mode 100644 index 00000000..472fb3f2 --- /dev/null +++ b/dnn/src/common/lstm_cell.h @@ -0,0 +1,32 @@ +/** + * \file dnn/src/common/lstm_cell.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "megdnn/oprs/base.h" + +namespace megdnn { +namespace lstm_cell { + +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, + const TensorLayout& gates, Handle* handle); + +void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle); + +} // namespace lstm_cell +} // namespace megdnn diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 8999b736..32528b8e 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true); DEF(LSQForward, 5, true, true); DEF(LSQBackward, 7, true, false); DEF(Fill, 1, true, false); +DEF(RNNCellForward, 6, true, true); +DEF(RNNForward, 6, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index 56a91c39..48760fc0 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -402,9 +402,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { } if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 && - ( - handle()->type() != Handle::HandleType::NAIVE && - handle()->type() != Handle::HandleType::X86)) { + (handle()->type() != Handle::HandleType::NAIVE && + handle()->type() != Handle::HandleType::X86)) { megdnn_throw( "Dump with Image2DPack4TensorFormat is not available on CUDA compnode, " "try export CUDA_VISIBLE_DEVICES=\'\'"); diff --git a/dnn/src/common/rnn.cpp b/dnn/src/common/rnn.cpp new file mode 100644 index 00000000..2cf3b2f0 --- /dev/null +++ b/dnn/src/common/rnn.cpp @@ -0,0 +1,82 @@ +/** + * \file dnn/src/common/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/common/rnn.h" +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void RNN::deduce_layout( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, + TensorLayout& reserve_space) { + // input: [seq_len, batch_size, input_size] + // hx: [D * num_layers, batch_size, hidden_size] + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t D = param().bidirectional ? 
2 : 1; + size_t hidden_size = hx.shape[2]; + output = TensorLayout( + TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype); + hy = TensorLayout(hx); + // reserve_space = {{get_reserve_size(this->handle(), param(), input)}, + // dtype::Byte()}; + reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()}; +} + +void RNN::check_exec( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space, + size_t workspace_in_bytes) { + auto errmsg = [&]() { + std::string msg; + msg.append("input="); + msg.append(input.to_string()); + msg.append(", hx="); + msg.append(hx.to_string()); + msg.append(", hidden_size="); + msg.append(std::to_string(param().hidden_size)); + msg.append(", num_layers="); + msg.append(std::to_string(param().num_layers)); + msg.append(", bidirectional="); + msg.append(std::to_string(param().bidirectional)); + return msg; + }; + size_t D = param().bidirectional ? 2 : 1; + size_t num_layers = param().num_layers; +#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str()); + + ASSERT_BRIEF(hx.ndim == 3) + ASSERT_BRIEF(hx.shape[0] == D * num_layers) + ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size + ASSERT_BRIEF(hx.shape[2] == param().hidden_size) +#undef ASSERT_BRIEF +} + +void RNNBackward::deduce_layout( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) { + dx = x; + dhx = hx; + dw = flatten_weights; +} + +void RNNBackward::check_exec( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw, + size_t workspace_in_bytes) {} + +} // namespace megdnn diff --git a/dnn/src/common/rnn.h b/dnn/src/common/rnn.h new file mode 100644 index 00000000..e98651b4 --- /dev/null +++ b/dnn/src/common/rnn.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/common/rnn.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/opr_param_defs.h" +#include "megdnn/oprs/base.h" +#include "megdnn/oprs/general.h" + +namespace megdnn { +namespace rnn { +using Param = param::RNN; + +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayoutArray& weight_ih, + const TensorLayoutArray& states, const TensorLayoutArray& weight_hh, + const TensorLayoutArray& bias, const TensorLayout& output, + const TensorLayoutArray& states_new, const Param& param, Handle* handle); + +void exec( + _megdnn_tensor_in input, _megdnn_in const TensorNDArray& weight_ih, + _megdnn_in const TensorNDArray& states, + _megdnn_in const TensorNDArray& weight_hh, _megdnn_in const TensorNDArray& bias, + _megdnn_tensor_out output, _megdnn_out const TensorNDArray& states_new, + _megdnn_workspace workspace, const Param& param, Handle* handle); +} // namespace rnn +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/rnn_cell.cpp b/dnn/src/common/rnn_cell.cpp new file mode 100644 index 00000000..0cdbc259 --- /dev/null +++ b/dnn/src/common/rnn_cell.cpp @@ -0,0 +1,108 @@ +/** + * \file dnn/src/common/rnn_cell.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/common/rnn_cell.h" +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void RNNCell::deduce_layout( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) { + // megdnn_assert(hx.ndim == 2); + size_t batch_size = hx.shape[0]; + // size_t hidden_size = weight_hh.shape[1]; + size_t gate_hidden_size = weight_ih.shape[0]; + // size_t input_size = weight_ih.shape[1]; + // megdnn_assert(input.shape[1] == input_size); + // megdnn_assert(hx.shape[1] == hidden_size); + // megdnn_assert_eq_dtype(input, hx); + + dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype); +} + +void RNNCell::check_exec( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst, size_t workspace_in_bytes) { + TensorLayout dst_expected; + megdnn_assert_eq_dtype(input, dst); + megdnn_assert_eq_dtype(hx, dst); + deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected); + megdnn_assert_eq_layout(dst_expected, dst); + + auto required_workspace_in_bytes = get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +namespace megdnn { +namespace rnn_cell { + +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst, Handle* handle) { + auto opr = handle->create_operator(); + opr->param().transposeB = true; + return dst.span().dist_byte() + opr->get_workspace_in_bytes(hx, weight_hh, dst); +} + +void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, 
+ _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace, + param::RNNCell::NonlineMode nonline_mode, Handle* handle) { + TensorND tmp{static_cast(workspace.raw_ptr), dst.layout}; + _megdnn_workspace new_workspace = { + workspace.raw_ptr + dst.layout.span().dist_byte(), + workspace.size - dst.layout.span().dist_byte()}; + auto opr = handle->create_operator(); + opr->param().transposeB = true; + opr->exec(input, weight_ih, tmp, new_workspace); + opr->exec(hx, weight_hh, dst, new_workspace); + // if (this->param().bias) add_bias(dst, tmp, bias, dst); + // if (this->param().bias) { + auto add_opr = handle->create_operator(); + add_opr->param().mode = Elemwise::Param::Mode::ADD; + add_opr->exec({dst, tmp}, dst); + add_opr->exec({dst, bias_ih}, dst); + add_opr->exec({dst, bias_hh}, dst); + // } + + // activation + using NonlineMode = param::RNNCell::NonlineMode; + + switch (nonline_mode) { +#define cb(_mode) \ + case NonlineMode::_mode: { \ + auto nonlinear = handle->create_operator(); \ + nonlinear->param().mode = Elemwise::Param::Mode::_mode; \ + nonlinear->exec({dst}, dst); \ + break; \ + } + cb(RELU); + cb(TANH); +#undef cb + case NonlineMode::IDENTITY: + break; + default: + megdnn_assert(false); + } +} + +} // namespace rnn_cell +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/rnn_cell.h b/dnn/src/common/rnn_cell.h new file mode 100644 index 00000000..3906da1d --- /dev/null +++ b/dnn/src/common/rnn_cell.h @@ -0,0 +1,31 @@ +/** + * \file dnn/src/common/rnn_cell.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs/base.h" +#include "megdnn/oprs/general.h" + +namespace megdnn { +namespace rnn_cell { + +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst, Handle* handle); + +void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace, + param::RNNCell::NonlineMode nonline_mode, Handle* handle); + +} // namespace rnn_cell +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp index f8080afb..16d73ef6 100644 --- a/dnn/src/cuda/cudnn_wrapper.cpp +++ b/dnn/src/cuda/cudnn_wrapper.cpp @@ -160,6 +160,29 @@ void TensorDesc::set( } } +void TensorDesc::set_nd(const TensorLayout& layout, int pad) { + int nbDims = layout.ndim < pad ? 
pad : layout.ndim; + int dimA[nbDims], strideA[nbDims]; + + for (size_t i = 0; i < layout.ndim; ++i) { + dimA[i] = layout.shape[i]; + // strideA[i] = layout.stride[i]; + } + for (size_t i = layout.ndim; i < nbDims; ++i) { + dimA[i] = 1; // unused + // strideA[i] = 1; + } + // stride + for (size_t i = 0; i < nbDims; ++i) { + strideA[i] = 1; + for (size_t j = i + 1; j < nbDims; ++j) { + strideA[i] *= dimA[j]; + } + } + cudnn_check(cudnnSetTensorNdDescriptor( + desc, to_cudnn_dtype(layout.dtype), nbDims, dimA, strideA)); +} + std::string TensorDesc::to_string() { cudnnDataType_t data_type; int n; @@ -433,6 +456,97 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); } +DropoutDesc::DropoutDesc() { + cudnn_check(cudnnCreateDropoutDescriptor(&desc)); +} + +DropoutDesc::~DropoutDesc() { + cudnn_check(cudnnDestroyDropoutDescriptor(desc)); +} + +void DropoutDesc::set(float dropout, Handle* handle, TensorND& state) { + cudnn_check(cudnnSetDropoutDescriptor( + desc, cudnn_handle(handle), dropout, state.raw_ptr(), + state.layout.span().dist_byte(), 0 // seed + )); +} + +void DropoutDesc::set_no_dropout(Handle* handle) { + cudnn_check( + cudnnSetDropoutDescriptor(desc, cudnn_handle(handle), 0, nullptr, 0, 0)); +} + +RNNDesc::RNNDesc() { + cudnn_check(cudnnCreateRNNDescriptor(&desc)); +} + +RNNDesc::~RNNDesc() { + cudnn_check(cudnnDestroyRNNDescriptor(desc)); +} + +void RNNDesc::set( + size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, + bool bidirectional, bool bias, const megdnn::DType dtype, cudnnRNNMode_t mode, + DropoutDesc& dropout_desc, Handle* handle) { + cudnnRNNMode_t rnn_mode = mode; + cudnnRNNBiasMode_t bias_mode = bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + cudnnDirectionMode_t dir_mode = + bidirectional ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnDataType_t math_prec; + + // math precision + if (dtype.enumv() == DTypeEnum::Float16) + math_prec = CUDNN_DATA_HALF; + else + math_prec = CUDNN_DATA_FLOAT; + +#if false // CUDNN_MAJOR >= 8 + cudnn_check(cudnnSetRNNDescriptor_v8( + desc, CUDNN_RNN_ALGO_STANDARD, mode, bias_mode, dir_mode, + CUDNN_LINEAR_INPUT, to_cudnn_dtype(dtype), math_prec, CUDNN_DEFAULT_MATH, + input_size, hidden_size, proj_size, num_layers, dropout_desc.desc, + CUDNN_RNN_PADDED_IO_DISABLED)); +#else + cudnn_check(cudnnSetRNNDescriptor_v6( + cudnn_handle(handle), desc, hidden_size, num_layers, dropout_desc.desc, + CUDNN_LINEAR_INPUT, dir_mode, mode, CUDNN_RNN_ALGO_STANDARD, math_prec)); +#endif +} + +RNNDataDesc::RNNDataDesc() { + cudnn_check(cudnnCreateRNNDataDescriptor(&desc)); +} + +RNNDataDesc::~RNNDataDesc() { + cudnn_check(cudnnDestroyRNNDataDescriptor(desc)); +} + +void RNNDataDesc::set( + int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, + DType dtype) { + // for now, all tensor are padded in python + // int seqLengthArray[batchSize]; + // for (int i = 0; i < batchSize; ++i) seqLengthArray[i] = maxSeqLength; + cudnn_check(cudnnSetRNNDataDescriptor( + desc, to_cudnn_dtype(dtype), CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, + maxSeqLength, batchSize, vectorSize, devSeqLengths, nullptr)); +} + +RNNWeightFilterDesc::RNNWeightFilterDesc() { + cudnn_check(cudnnCreateFilterDescriptor(&desc)); +} + +RNNWeightFilterDesc::~RNNWeightFilterDesc() { + cudnn_check(cudnnDestroyFilterDescriptor(desc)); +} + +void RNNWeightFilterDesc::set(const TensorLayout& flatten_weights) { + int weight_elem_num = flatten_weights.total_nr_elems(); + int dimW[] = {weight_elem_num, 1, 1}; + cudnn_check(cudnnSetFilterNdDescriptor( + desc, to_cudnn_dtype(flatten_weights.dtype), CUDNN_TENSOR_NCHW, 3, dimW)); +} + ////////////////////////// CudnnAlgoPack ////////////////////////// #define V1(v) #v diff --git a/dnn/src/cuda/cudnn_wrapper.h b/dnn/src/cuda/cudnn_wrapper.h index ef0ab5ab..5198e9b5 100644 --- a/dnn/src/cuda/cudnn_wrapper.h +++ b/dnn/src/cuda/cudnn_wrapper.h @@ -30,6 +30,7 @@ public: void set( const TensorLayout& layout, const param::Convolution::Format = param::Convolution::Format::NCHW); + void set_nd(const TensorLayout& layout, int pad = 3); // at least 3 dimensions std::string to_string(); ~TensorDesc(); cudnnTensorDescriptor_t desc; @@ -121,6 +122,44 @@ public: static const std::unordered_map conv3d_fwd_algos(); }; +class DropoutDesc { +public: + DropoutDesc(); + void set(float dropout, Handle* handle, TensorND& state); + void set_no_dropout(Handle* handle); + ~DropoutDesc(); + cudnnDropoutDescriptor_t desc; +}; + +class RNNDesc { +public: + RNNDesc(); + void set( + size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, + bool bidirectional, bool bias, const megdnn::DType dtype, + cudnnRNNMode_t mode, DropoutDesc& dropout_desc, Handle* handle); + ~RNNDesc(); + cudnnRNNDescriptor_t desc; +}; + +class RNNDataDesc { +public: + RNNDataDesc(); + void set( + int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, + DType dtype); + ~RNNDataDesc(); + cudnnRNNDataDescriptor_t desc; +}; + +class RNNWeightFilterDesc { +public: + RNNWeightFilterDesc(); + void set(const TensorLayout& flatten_weights); + ~RNNWeightFilterDesc(); + cudnnFilterDescriptor_t desc; +}; + } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 03858f5f..59afb22e 100644 --- 
a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -50,6 +50,8 @@ #include "src/cuda/local_share/opr_impl.h" #include "src/cuda/lrn/opr_impl.h" #include "src/cuda/lsq/opr_impl.h" +#include "src/cuda/lstm/opr_impl.h" +#include "src/cuda/lstm_cell/opr_impl.h" #include "src/cuda/mask_conv/opr_impl.h" #include "src/cuda/matrix_inverse/opr_impl.h" #include "src/cuda/matrix_mul/opr_impl.h" @@ -66,6 +68,8 @@ #include "src/cuda/repeat/opr_impl.h" #include "src/cuda/resize/opr_impl.h" #include "src/cuda/rng/opr_impl.h" +#include "src/cuda/rnn/opr_impl.h" +#include "src/cuda/rnn_cell/opr_impl.h" #include "src/cuda/roi_align/opr_impl.h" #include "src/cuda/roi_copy/opr_impl.h" #include "src/cuda/roi_pooling/opr_impl.h" diff --git a/dnn/src/cuda/lstm/opr_impl.cpp b/dnn/src/cuda/lstm/opr_impl.cpp new file mode 100644 index 00000000..6dfc9ef0 --- /dev/null +++ b/dnn/src/cuda/lstm/opr_impl.cpp @@ -0,0 +1,112 @@ +/** + * \file dnn/src/cuda/lstm/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/lstm/opr_impl.h" +#include "src/cuda/lstm/utils.h" +#include "src/cuda/utils.h" + +#include + +namespace megdnn { +namespace cuda { + +void LSTMImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace) { + Handle* handle = this->handle(); + + rnn::RNNForwardDescHolder_v6 desc_holder = + lstm::get_RNNDescHolder_v6(this->handle(), param(), input.layout); + auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); + auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); + RNNWeightFilterDesc w_desc; + w_desc.set(flatten_weights.layout); + + if (param().fwd_mode == param::LSTM::FwdMode::TRAINING) { + cudnn_check(cudnnRNNForwardTraining( + cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, + x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, + hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc, + flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), + desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, + cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, + reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); + } else { + cudnn_check(cudnnRNNForwardInference( + cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, + x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, + hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc, + flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), + desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, + cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size)); + } +} + +size_t LSTMImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space) { + rnn::RNNForwardDescHolder_v6 desc_holder = + lstm::get_RNNDescHolder_v6(this->handle(), param(), input); + return desc_holder.workspace_size; +} + 
+size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { + rnn::RNNForwardDescHolder_v6 desc_holder = + lstm::get_RNNDescHolder_v6(this->handle(), param(), input); + return desc_holder.reserveSpace_size; +} + +void LSTMBackwardImpl::exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, + _megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { + Handle* handle = this->handle(); + size_t seq_len = x.layout.shape[0]; + auto desc_holder = lstm::get_RNNDescHolder_v6(handle, param(), x.layout); + auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data(); + auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data(); + RNNWeightFilterDesc w_desc; + w_desc.set(flatten_weights.layout); + + cudnn_check(cudnnRNNBackwardData( + cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr, + y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc, + dhy.raw_ptr(), desc_holder.cy_desc.desc, dcy.raw_ptr(), w_desc.desc, + flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), + desc_holder.cx_desc.desc, cx.raw_ptr(), x_desc_arr_ptr, dx.raw_ptr(), + desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, + dcx.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, + reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); + + cudnn_check(cudnnRNNBackwardWeights( + cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, + x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, + y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, + dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); +} + +size_t LSTMBackwardImpl::get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { + auto desc_holder = lstm::get_RNNDescHolder_v6(this->handle(), param(), x); + return desc_holder.workspace_size; +} + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/lstm/opr_impl.h b/dnn/src/cuda/lstm/opr_impl.h new file mode 100644 index 00000000..a01032f7 --- /dev/null +++ b/dnn/src/cuda/lstm/opr_impl.h @@ -0,0 +1,56 @@ +/** + * \file dnn/src/cuda/lstm/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class LSTMImpl : public LSTM { +public: + using LSTM::LSTM; + + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out cy, + _megdnn_tensor_out reserve_space, _megdnn_workspace workspace); + + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space); + size_t get_reserve_size_in_bytes(const TensorLayout& input); +}; + +class LSTMBackwardImpl : public LSTMBackward { +public: + using LSTMBackward::LSTMBackward; + + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, + _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, + _megdnn_workspace workspace); + + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); +}; + +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/lstm/utils.cpp b/dnn/src/cuda/lstm/utils.cpp new file mode 100644 index 00000000..7fe1ab0e --- /dev/null +++ b/dnn/src/cuda/lstm/utils.cpp @@ -0,0 +1,39 @@ +/** + * \file dnn/src/cuda/lstm/utils.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/lstm/utils.h" +#include "src/cuda/utils.h" + +#include + +namespace megdnn { +namespace cuda { +namespace lstm { + +RNNForwardDescHolder_v6 get_RNNDescHolder_v6( + Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input) { + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t input_size = input.shape[2]; + + cudnnRNNMode_t mode = CUDNN_LSTM; + + using FwdMode = param::LSTM::FwdMode; + + RNNForwardDescHolder_v6 desc_holder( + handle, seq_len, batch_size, _param.hidden_size, input_size, + _param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, + input.dtype, mode); + return desc_holder; +} + +} // namespace lstm +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/lstm/utils.h b/dnn/src/cuda/lstm/utils.h new file mode 100644 index 00000000..2623c6fd --- /dev/null +++ b/dnn/src/cuda/lstm/utils.h @@ -0,0 +1,23 @@ +/** + * \file dnn/src/cuda/lstm/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/rnn/utils.h" + +namespace megdnn { +namespace cuda { +namespace lstm { +using megdnn::cuda::rnn::RNNForwardDescHolder_v6; +RNNForwardDescHolder_v6 get_RNNDescHolder_v6( + Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input); +} // namespace lstm +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/lstm_cell/opr_impl.cpp b/dnn/src/cuda/lstm_cell/opr_impl.cpp new file mode 100644 index 00000000..b01c5de0 --- /dev/null +++ b/dnn/src/cuda/lstm_cell/opr_impl.cpp @@ -0,0 +1,42 @@ +/** + * \file dnn/src/cuda/lstm_cell/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/lstm_cell/opr_impl.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs/base.h" +#include "src/common/lstm_cell.h" +#include "src/common/opr_delegate.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace cuda { +size_t LSTMCellImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, + const TensorLayout& gates) { + return megdnn::lstm_cell::get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, + handle()); +} + +void LSTMCellImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace) { + megdnn::lstm_cell::exec( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, + workspace, handle()); +} +} // namespace cuda + +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/lstm_cell/opr_impl.h b/dnn/src/cuda/lstm_cell/opr_impl.h new file mode 100644 index 00000000..f8578162 --- /dev/null +++ b/dnn/src/cuda/lstm_cell/opr_impl.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/cuda/lstm_cell/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" +#include "src/cuda/rnn_cell/opr_impl.h" + +namespace megdnn { +namespace cuda { + +class LSTMCellImpl : public LSTMCell { +public: + using LSTMCell::LSTMCell; + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, + const TensorLayout& c_new, const TensorLayout& gates) override; +}; + +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/rnn/opr_impl.cpp b/dnn/src/cuda/rnn/opr_impl.cpp new file mode 100644 index 00000000..937345a2 --- /dev/null +++ b/dnn/src/cuda/rnn/opr_impl.cpp @@ -0,0 +1,170 @@ +/** + * \file dnn/src/cuda/rnn/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/cuda/rnn/opr_impl.h" +#include "src/common/rnn.h" +#include "src/cuda/utils.h" + +//#include +#include +#include +#include + +namespace megdnn { +namespace cuda { + +using namespace std; + +void RNNImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace) { + Handle* handle = this->handle(); + +#if false // CUDNN_MAJOR >= 8 + rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input.layout); + + void* workspace_ptr = workspace.raw_ptr; + void* reserveSpace_ptr = static_cast(workspace_ptr) + desc_holder.workspace_size; + + cudnn_check(cudnnRNNForward( + cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.fwdMode, desc_holder.devSeqLengths, + desc_holder.x_desc.desc, input.raw_ptr(), desc_holder.y_desc.desc, output.raw_ptr(), + desc_holder.h_desc.desc, hx.raw_ptr(), hy.raw_ptr(), + desc_holder.h_desc.desc, nullptr, nullptr, + desc_holder.weight_size, flatten_weights.raw_ptr(), desc_holder.workspace_size, workspace_ptr, + desc_holder.reserveSpace_size, reserveSpace_ptr + )); +#else + rnn::RNNForwardDescHolder_v6 desc_holder = + rnn::get_RNNDescHolder_v6(this->handle(), param(), input.layout); + auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); + auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); + RNNWeightFilterDesc w_desc; + w_desc.set(flatten_weights.layout); + + if (param().fwd_mode == param::RNN::FwdMode::TRAINING) { + cudnn_check(cudnnRNNForwardTraining( + cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, + x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, + hx.raw_ptr(), desc_holder.cx_desc.desc, NULL, w_desc.desc, + flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), + desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, NULL, + workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(), + 
desc_holder.reserveSpace_size)); + } else { + cudnn_check(cudnnRNNForwardInference( + cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, + x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, + hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc, + flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), + desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, + nullptr, workspace.raw_ptr, desc_holder.workspace_size)); + } +#endif +} + +size_t RNNImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space) { +#if false // CUDNN_MAJOR >= 8 + rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input); +#else + rnn::RNNForwardDescHolder_v6 desc_holder = + rnn::get_RNNDescHolder_v6(this->handle(), param(), input); +#endif + return desc_holder.workspace_size; +} + +size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { + rnn::RNNForwardDescHolder_v6 desc_holder = + rnn::get_RNNDescHolder_v6(this->handle(), param(), input); + return desc_holder.reserveSpace_size; +} + +/*rnn::RNNForwardDescHolder RNNImpl::get_desc_holder(const TensorLayout& input) { + Handle* handle = this->handle(); + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t input_size = input.shape[2]; + auto _param = param(); + + cudnnRNNMode_t mode; + using NonlineMode = param::RNN::NonlineMode; + switch (_param.nonlineMode) { + case NonlineMode::RELU: + mode = CUDNN_RNN_RELU; + break; + case NonlineMode::TANH: + mode = CUDNN_RNN_TANH; + break; + } + + cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING; + using FwdMode = param::RNN::FwdMode; + switch (_param.fwd_mode) { + case FwdMode::TRAINING: + fwdMode = CUDNN_FWD_MODE_TRAINING; + break; + case FwdMode::INFERENCE: + fwdMode = CUDNN_FWD_MODE_INFERENCE; + break; + } + + rnn::RNNForwardDescHolder desc_holder( + handle, seq_len, batch_size, _param.hidden_size, input_size, + _param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, + input.dtype, mode, fwdMode); + return desc_holder; +}*/ + +void RNNBackwardImpl::exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, + _megdnn_tensor_out dw, _megdnn_workspace workspace) { + Handle* handle = this->handle(); + size_t seq_len = x.layout.shape[0]; + auto desc_holder = rnn::get_RNNDescHolder_v6(handle, param(), x.layout); + auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data(); + auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data(); + RNNWeightFilterDesc w_desc; + w_desc.set(flatten_weights.layout); + + cudnn_check(cudnnRNNBackwardData( + cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr, + y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc, + dhy.raw_ptr(), desc_holder.cy_desc.desc, NULL, w_desc.desc, + flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), + desc_holder.cx_desc.desc, NULL, x_desc_arr_ptr, dx.raw_ptr(), + desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, NULL, + workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(), + desc_holder.reserveSpace_size)); + + cudnn_check(cudnnRNNBackwardWeights( + cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, 
+ x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, + y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, + dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); +} + +size_t RNNBackwardImpl::get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { + auto desc_holder = rnn::get_RNNDescHolder_v6(this->handle(), param(), x); + return desc_holder.workspace_size; +} + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/rnn/opr_impl.h b/dnn/src/cuda/rnn/opr_impl.h new file mode 100644 index 00000000..701e7876 --- /dev/null +++ b/dnn/src/cuda/rnn/opr_impl.h @@ -0,0 +1,57 @@ +/** + * \file dnn/src/cuda/rnn/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/rnn/utils.h" + +namespace megdnn { +namespace cuda { + +class RNNImpl : public RNN { +public: + using RNN::RNN; + + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace); + + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space); + size_t get_reserve_size_in_bytes(const TensorLayout& input); + // private: + // rnn::RNNForwardDescHolder get_desc_holder(const TensorLayout& input); +}; + +class RNNBackwardImpl : public RNNBackward { +public: + using RNNBackward::RNNBackward; + + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, + _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, + _megdnn_workspace workspace); + + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/rnn/utils.cpp b/dnn/src/cuda/rnn/utils.cpp new file mode 100644 index 00000000..9e4ff825 --- /dev/null +++ b/dnn/src/cuda/rnn/utils.cpp @@ -0,0 +1,138 @@ +/** + * \file dnn/src/cuda/rnn/utils.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/cuda/rnn/utils.h" +#include "src/cuda/utils.h" + +#include + +namespace megdnn { +namespace cuda { +namespace rnn { +/*RNNForwardDescHolder::RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t +batch_size, size_t hidden_size, size_t input_size, size_t proj_size, size_t num_layers, +bool bidirectional, bool bias, DType dtype, cudnnRNNMode_t _mode, cudnnForwardMode_t +_fwdMode) : mode(_mode), fwdMode(_fwdMode) +{ + size_t D = bidirectional ? 2 : 1; + + // TODO: set dropout to 0 in inference mode + dropout_desc.set_no_dropout(handle); + + // seq len is unified (not packed) + // cuda_check(cudaMalloc((void**)&devSeqLengths, sizeof(int32_t) * batch_size)); + devSeqLengths = (int32_t*)malloc(sizeof(int32_t) * batch_size); + for (size_t i = 0; i < batch_size; ++i) devSeqLengths[i] = seq_len; + + // proj size should be smaller than hidden size according to cudnn api + // otherwise it is disabled + proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : +proj_size; rnn_desc.set( input_size, hidden_size, proj_size, num_layers, bidirectional, +bias, dtype, mode, dropout_desc, handle + ); + + x_desc.set(batch_size, input_size, seq_len, devSeqLengths, dtype); + y_desc.set(batch_size, D * proj_size, seq_len, + devSeqLengths, dtype); + h_desc.set_nd(TensorLayout(TensorShape{D * num_layers, batch_size, proj_size}, +dtype)); + + cudnn_check(cudnnGetRNNWeightSpaceSize(cudnn_handle(handle), rnn_desc.desc, +&weight_size)); + + cudnn_check(cudnnGetRNNTempSpaceSizes( + cudnn_handle(handle), rnn_desc.desc, fwdMode, x_desc.desc, +&workspace_size, &reserveSpace_size + )); +} + +RNNForwardDescHolder::~RNNForwardDescHolder() { + // cuda_check(cudaFree(devSeqLengths)); + free(devSeqLengths); +}*/ + +RNNForwardDescHolder_v6::RNNForwardDescHolder_v6( + Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size, + size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, + bool bias, DType dtype, cudnnRNNMode_t _mode) + : mode(_mode), seq_len(seq_len) { + size_t D = bidirectional ? 2 : 1; + + // TODO: set dropout to 0 in inference mode + dropout_desc.set_no_dropout(handle); + + proj_size = (proj_size > hidden_size || proj_size == 0) ? 
hidden_size : proj_size; + rnn_desc.set( + input_size, hidden_size, proj_size, num_layers, bidirectional, bias, dtype, + mode, dropout_desc, handle); + + x_descs.resize(seq_len); + y_descs.resize(seq_len); + for (size_t i = 0; i < seq_len; ++i) { + x_descs[i].set_nd(TensorLayout(TensorShape{batch_size, input_size}, dtype), 3); + y_descs[i].set_nd( + TensorLayout(TensorShape{batch_size, D * hidden_size}, dtype), 3); + } + +#define SET_H(_var) \ + _var.set_nd(TensorLayout( \ + TensorShape{D * num_layers, batch_size, hidden_size}, dtype)); + + SET_H(hx_desc) + SET_H(cx_desc) + SET_H(hy_desc) + SET_H(cy_desc) +#undef SET_H + + std::vector x_desc_arr = get_descs(x_descs); + cudnn_check(cudnnGetRNNWorkspaceSize( + cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), + &workspace_size)); + + cudnn_check(cudnnGetRNNTrainingReserveSize( + cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), + &reserveSpace_size)); +} + +RNNForwardDescHolder_v6 get_RNNDescHolder_v6( + Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input) { + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t input_size = input.shape[2]; + + cudnnRNNMode_t mode; + using NonlineMode = param::RNN::NonlineMode; + switch (_param.nonlineMode) { + case NonlineMode::RELU: + mode = CUDNN_RNN_RELU; + break; + case NonlineMode::TANH: + mode = CUDNN_RNN_TANH; + break; + } + + RNNForwardDescHolder_v6 desc_holder( + handle, seq_len, batch_size, _param.hidden_size, input_size, + _param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, + input.dtype, mode); + return desc_holder; +} + +std::vector get_descs(const std::vector& descs) { + std::vector r; + r.reserve(descs.size()); + for (auto& desc : descs) { + r.emplace_back(desc.desc); + } + return r; +} +} // namespace rnn +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/rnn/utils.h b/dnn/src/cuda/rnn/utils.h new file mode 100644 index 00000000..d2a652dc --- /dev/null +++ b/dnn/src/cuda/rnn/utils.h @@ -0,0 +1,56 @@ +/** + * \file dnn/src/cuda/rnn/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+#pragma once
+#include "src/cuda/cudnn_wrapper.h"
+
+namespace megdnn {
+namespace cuda {
+namespace rnn {
+// v8, not for now
+/*struct RNNForwardDescHolder {
+
+    int32_t* devSeqLengths;
+    cudnnRNNMode_t mode;
+    cudnnForwardMode_t fwdMode;
+    RNNDesc rnn_desc;
+    DropoutDesc dropout_desc;
+    RNNDataDesc x_desc, y_desc;
+    TensorDesc h_desc;
+    size_t weight_size, workspace_size, reserveSpace_size;
+
+    RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t batch_size, size_t
+hidden_size, size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
+                         bool bias, DType dtype,
+cudnnRNNMode_t _mode, cudnnForwardMode_t _fwdMode); ~RNNForwardDescHolder();
+};*/
+
+struct RNNForwardDescHolder_v6 {
+    cudnnRNNMode_t mode;
+    RNNDesc rnn_desc;
+    int seq_len;
+    DropoutDesc dropout_desc;
+    std::vector<TensorDesc> x_descs, y_descs;
+    TensorDesc hx_desc, cx_desc, hy_desc, cy_desc;
+
+    size_t workspace_size, reserveSpace_size;
+
+    RNNForwardDescHolder_v6(
+            Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size,
+            size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
+            bool bias, DType dtype, cudnnRNNMode_t _mode);
+};
+
+RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
+        Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input);
+std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs);
+} // namespace rnn
+} // namespace cuda
+} // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/cuda/rnn_cell/opr_impl.cpp b/dnn/src/cuda/rnn_cell/opr_impl.cpp
new file mode 100644
index 00000000..4fb5f2ed
--- /dev/null
+++ b/dnn/src/cuda/rnn_cell/opr_impl.cpp
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/src/cuda/rnn_cell/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "src/cuda/rnn_cell/opr_impl.h"
+#include "src/common/rnn_cell.h"
+
+namespace megdnn {
+namespace cuda {
+size_t RNNCellImpl::get_workspace_in_bytes(
+        const TensorLayout& input, const TensorLayout& weight_ih,
+        const TensorLayout& bias_ih, const TensorLayout& hx,
+        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
+        const TensorLayout& dst) {
+    return megdnn::rnn_cell::get_workspace_in_bytes(
+            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle());
+}
+
+void RNNCellImpl::exec(
+        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
+        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
+        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    megdnn::rnn_cell::exec(
+            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
+            param().nonlineMode, handle());
+}
+} // namespace cuda
+
+} // namespace megdnn
\ No newline at end of file
diff --git a/dnn/src/cuda/rnn_cell/opr_impl.h b/dnn/src/cuda/rnn_cell/opr_impl.h
new file mode 100644
index 00000000..916d8586
--- /dev/null
+++ b/dnn/src/cuda/rnn_cell/opr_impl.h
@@ -0,0 +1,40 @@
+/**
+ * \file dnn/src/cuda/rnn_cell/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class RNNCellImpl : public RNNCell { +public: + using RNNCell::RNNCell; + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst) override; + /* + private: + void add_bias(_megdnn_tensor_in A, + _megdnn_tensor_in B, + _megdnn_tensor_in bias, + _megdnn_tensor_out C); + */ +}; + +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index e38bfead..2997732b 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -52,6 +52,8 @@ #include "src/naive/local_share/opr_impl.h" #include "src/naive/lrn/opr_impl.h" #include "src/naive/lsq/opr_impl.h" +#include "src/naive/lstm/opr_impl.h" +#include "src/naive/lstm_cell/opr_impl.h" #include "src/naive/mask_conv/opr_impl.h" #include "src/naive/matrix_inverse/opr_impl.h" #include "src/naive/matrix_mul/opr_impl.h" @@ -68,6 +70,8 @@ #include "src/naive/repeat/opr_impl.h" #include "src/naive/resize/opr_impl.h" #include "src/naive/rng/opr_impl.h" +#include "src/naive/rnn/opr_impl.h" +#include "src/naive/rnn_cell/opr_impl.h" #include "src/naive/roi_align/opr_impl.h" #include "src/naive/roi_copy/opr_impl.h" #include "src/naive/roi_pooling/opr_impl.h" diff --git a/dnn/src/naive/lstm/opr_impl.cpp b/dnn/src/naive/lstm/opr_impl.cpp new file mode 100644 index 00000000..73f6ac1d --- /dev/null +++ b/dnn/src/naive/lstm/opr_impl.cpp @@ -0,0 +1,146 @@ +/** + * \file dnn/src/naive/lstm/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/naive/lstm/opr_impl.h" +#include "src/naive/rnn/funcs.h" +#include "src/naive/rnn/rnn.h" + +namespace megdnn { +namespace naive { +using rnn::LSTMCellWeightWrapper; + +void LSTMImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace) { + auto _param = param(); + size_t D = _param.bidirectional ? 
2 : 1; + size_t num_layers = _param.num_layers; + size_t input_size = input.layout.shape[2]; + std::vector cells; + size_t used_workspace_size = rnn::get_cells( + D, num_layers, input_size, _param.hidden_size, _param.bias, cells, + flatten_weights, workspace); + + Workspace new_workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + TensorNDArray states = {hx, cx}, states_new = {hy, cy}; + rnn::exec_internal( + cells, input, states, states_new, output, reserve_space, num_layers, D, + this->handle(), new_workspace); +} + +size_t LSTMImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space) { + size_t workspace_size = rnn::get_workspace_in_bytes( + input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, + this->handle()); + if (!param().bias) { // use fake bias (all 0) + TensorLayout bias_layout = {{param().hidden_size * 4}, flatten_weights.dtype}; + workspace_size += bias_layout.span().dist_byte(); + } + workspace_size += output.span().dist_byte(); + return workspace_size; +} + +size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { + size_t num_layers = param().num_layers; + size_t D = param().bidirectional ? 2 : 1; + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; + // 2 for hidden state and cell state + return 2 * num_layers * D * seq_len * state_layout.span().dist_byte(); +} + +void LSTMBackwardImpl::exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, + _megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { + TensorNDArray layer_inputs; + TensorNDArray layer_outputs; + std::vector> cell_seq_states; + size_t num_layers = param().num_layers; + size_t D = param().bidirectional ? 
2 : 1; + size_t input_size = x.layout.shape[2]; + size_t hidden_size = param().hidden_size; + size_t used_workspace_size = 0; + + // get cells + std::vector cells; + used_workspace_size += rnn::get_cells( + D, num_layers, input_size, hidden_size, param().bias, cells, + flatten_weights, workspace); + + // get formatted inputs + Workspace new_workspace = Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + used_workspace_size += rnn::get_inputs_for_exec( + x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, + layer_outputs, cell_seq_states, param::RNNCell::NonlineMode::IDENTITY, + new_workspace); + + // dhy arr, dhx arr + TensorNDArray dhy_arr = {dhy, dcy}, dhx_arr = {dhx, dcx}; + + // exec + new_workspace = Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + rnn::backward_exec_internal( + cells, D, num_layers, input_size, param().bias, + param::RNNCell::NonlineMode::IDENTITY, layer_inputs, layer_outputs, + cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, this->handle(), + new_workspace); +} + +size_t LSTMBackwardImpl::get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { + size_t D = param().bidirectional ? 2 : 1; + size_t num_layers = param().num_layers; + size_t hidden_size = param().hidden_size; + size_t gate_hidden_size = hidden_size * 4; + size_t max_input_size = std::max(x.shape[2], D * hidden_size); + + size_t workspace_size = LSTMCellWeightWrapper::backward_workspace_size_in_bytes( + this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); + if (!param().bias) { // use fake bias (all 0) + TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; + workspace_size += bias_layout.span().dist_byte() * + 2; // times 2 because another bias is allocated in + // backward_exec_internal + } + workspace_size += num_layers * y.span().dist_byte(); + // add back exec workspace size + workspace_size += y.span().dist_byte() * 2; + workspace_size += x.span().dist_byte() * 2; + TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; + TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; + TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; + workspace_size += wih.span().dist_byte(); + workspace_size += whh.span().dist_byte(); + workspace_size += bias.span().dist_byte(); + return workspace_size; +} +} // namespace naive + +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/lstm/opr_impl.h b/dnn/src/naive/lstm/opr_impl.h new file mode 100644 index 00000000..9ff5aa9d --- /dev/null +++ b/dnn/src/naive/lstm/opr_impl.h @@ -0,0 +1,56 @@ +/** + * \file dnn/src/naive/lstm/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class LSTMImpl : public LSTM { +public: + using LSTM::LSTM; + + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out cy, + _megdnn_tensor_out reserve_space, _megdnn_workspace workspace); + + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& cy, + const TensorLayout& reserve_space); + size_t get_reserve_size_in_bytes(const TensorLayout& input); +}; + +class LSTMBackwardImpl : public LSTMBackward { +public: + using LSTMBackward::LSTMBackward; + + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, + _megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, + _megdnn_workspace workspace); + + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& dcy, const TensorLayout& flatten_weights, + const TensorLayout& reserve_space, const TensorLayout& dx, + const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); +}; + +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/lstm/template_impl.cpp b/dnn/src/naive/lstm/template_impl.cpp new file mode 100644 index 00000000..108ee238 --- /dev/null +++ b/dnn/src/naive/lstm/template_impl.cpp @@ -0,0 +1,55 @@ +/** + * \file dnn/src/naive/lstm/template_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/naive/rnn/funcs.h" + +namespace megdnn { +namespace naive { +namespace rnn { + +template <> +void cell_opr_exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in bias_hh, const TensorNDArray& states, + TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { + auto opr = handle->create_operator(); + TensorLayout gates, h_new, c_new; + opr->deduce_layout( + input.layout, weight_ih.layout, bias_ih.layout, states[0].layout, + weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); + TensorND gates_tensor{workspace.raw_ptr, gates}; + _megdnn_workspace new_workspace = { + workspace.raw_ptr + gates.span().dist_byte(), + workspace.size - gates.span().dist_byte()}; + opr->exec( + input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], + states_new[0], states_new[1], gates_tensor, new_workspace); +} + +template <> +size_t cell_opr_get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& weight_hh, const TensorLayout& bias_ih, + const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { + TensorLayout cx = hx; + TensorLayout h_new, c_new, gates; + auto cell_opr = handle->create_operator(); + cell_opr->deduce_layout( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates); + return cell_opr->get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, + gates) + + gates.span().dist_byte(); +} + +} // namespace rnn +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/lstm_cell/opr_impl.cpp b/dnn/src/naive/lstm_cell/opr_impl.cpp new file mode 100644 index 00000000..10b0dd6b --- /dev/null +++ b/dnn/src/naive/lstm_cell/opr_impl.cpp @@ -0,0 +1,38 @@ +/** + * \file dnn/src/naive/lstm_cell/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/naive/lstm_cell/opr_impl.h" +#include "src/common/lstm_cell.h" + +namespace megdnn { +namespace naive { +size_t LSTMCellImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, + const TensorLayout& gates) { + return megdnn::lstm_cell::get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, + handle()); +} + +void LSTMCellImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace) { + megdnn::lstm_cell::exec( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, + workspace, handle()); +} +} // namespace naive + +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/lstm_cell/opr_impl.h b/dnn/src/naive/lstm_cell/opr_impl.h new file mode 100644 index 00000000..4f56c8eb --- /dev/null +++ b/dnn/src/naive/lstm_cell/opr_impl.h @@ -0,0 +1,36 @@ +/** + * \file dnn/src/naive/lstm_cell/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "src/naive/rnn_cell/opr_impl.h" + +namespace megdnn { +namespace naive { + +class LSTMCellImpl : public LSTMCell { +public: + using LSTMCell::LSTMCell; + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, + _megdnn_tensor_out gates, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& cx, const TensorLayout& h_new, + const TensorLayout& c_new, const TensorLayout& gates) override; +}; + +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/relayout/opr_impl.cpp b/dnn/src/naive/relayout/opr_impl.cpp index cda3cc21..fd62f82c 100644 --- a/dnn/src/naive/relayout/opr_impl.cpp +++ b/dnn/src/naive/relayout/opr_impl.cpp @@ -72,9 +72,7 @@ void RelayoutForwardImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) void RelayoutForwardImpl::check_cpu_handle(Handle* handle) { megdnn_assert( - !handle || - handle == this->handle() - || is_cpu_handle(handle), + !handle || handle == this->handle() || is_cpu_handle(handle), "relayout from non-CPU to CPU not supported"); } diff --git a/dnn/src/naive/rnn/funcs.h b/dnn/src/naive/rnn/funcs.h new file mode 100644 index 00000000..2ea83416 --- /dev/null +++ b/dnn/src/naive/rnn/funcs.h @@ -0,0 +1,75 @@ +/** + * \file dnn/src/naive/rnn/funcs.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + 
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#ifndef _RNN_H +#define _RNN_H +#include "megdnn/oprs.h" +namespace megdnn { +namespace naive { +namespace rnn { + +template +void cell_opr_exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in bias_hh, const TensorNDArray& states, + TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle); + +template +size_t cell_opr_get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& weight_hh, const TensorLayout& bias_ih, + const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle); + +template +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& flatten_weights, + size_t hidden_size, + size_t D, // num_directions + Handle* handle); + +template +void exec_internal( + std::vector& cells, _megdnn_tensor_in input, const TensorNDArray& states, + TensorNDArray& states_new, _megdnn_tensor_out output, + _megdnn_tensor_out reserve_space, size_t num_layers, + size_t D, // D is num_directions + Handle* handle, _megdnn_workspace workspace); + +template +size_t get_cells( + size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, + std::vector& cells, _megdnn_tensor_in flatten_weights, + _megdnn_workspace workspace); + +template +size_t get_inputs_for_exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, + size_t num_layers, size_t D, size_t hidden_size, const std::vector& cells, + TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, + std::vector>& cell_seq_states, + param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace); + +template +void backward_exec_internal( + std::vector& cells, size_t D, size_t num_layers, size_t input_size, + bool bias, param::RNNCell::NonlineMode nonlineMode, + const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, + const std::vector>& cell_seq_states, + _megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, + TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, + _megdnn_workspace workspace); + +} // namespace rnn +} // namespace naive +} // namespace megdnn + +#include "funcs.tpp" +#endif diff --git a/dnn/src/naive/rnn/funcs.tpp b/dnn/src/naive/rnn/funcs.tpp new file mode 100644 index 00000000..60703f21 --- /dev/null +++ b/dnn/src/naive/rnn/funcs.tpp @@ -0,0 +1,449 @@ +/** + * \file dnn/src/naive/rnn/funcs.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "funcs.h" + +namespace megdnn { +namespace naive { +namespace rnn { + +template +size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& flatten_weights, + size_t hidden_size, + size_t D, // num_directions + Handle* handle) { + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + size_t input_size = input.shape[2]; + size_t gate_hidden_size = flatten_weights.shape[0]; + // concat workspace + TensorLayout direction_output_layout{ + TensorShape{seq_len, batch_size, hidden_size}, input.dtype}; + TensorLayout output_layout{{seq_len, batch_size, D * hidden_size}, input.dtype}; + TensorLayoutArray layer_layouts; + for (size_t i = 0; i < D; ++i) + layer_layouts.push_back(direction_output_layout); + auto concat_opr = handle->create_operator(); + concat_opr->param().axis = -1; + size_t concat_workspace = + concat_opr->get_workspace_in_bytes(layer_layouts, output_layout); + // cell workspace + TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype}; + TensorLayout weight_hh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; + TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; + TensorLayout hx{{batch_size, hidden_size}, input.dtype}; + size_t cell_workspace = cell_opr_get_workspace_in_bytes( + input, weight_ih, weight_hh, bias, bias, hx, handle); + + return std::max(cell_workspace, concat_workspace); +} + +template +void exec_internal( + std::vector& cells, _megdnn_tensor_in input, const TensorNDArray& states, + TensorNDArray& states_new, _megdnn_tensor_out output, + _megdnn_tensor_out reserve_space, size_t num_layers, + size_t D, // D is num_directions + Handle* handle, _megdnn_workspace workspace) { + size_t seq_len = input.layout.shape[0]; + size_t batch_size = input.layout.shape[1]; + size_t input_size = input.layout.shape[2]; + size_t hidden_size = cells[0].weight_hh.layout.shape[1]; + TensorLayout cell_output_layout{ + TensorShape{batch_size, hidden_size}, states[0].layout.dtype}; + TensorLayout cell_first_input_layout{ + TensorShape{batch_size, input_size}, input.layout.dtype}; + TensorLayout cell_input_layout{ + TensorShape{batch_size, D * hidden_size}, input.layout.dtype}; + TensorLayout direction_output_layout{ + TensorShape{seq_len, batch_size, hidden_size}, output.layout.dtype}; + TensorND tmp_output{workspace.raw_ptr, output.layout}; + _megdnn_workspace new_workspace{ + workspace.raw_ptr + tmp_output.layout.span().dist_byte(), + workspace.size - tmp_output.layout.span().dist_byte()}; + + auto cell_opr = handle->create_operator(); + auto copy_opr = handle->create_operator(); + + // copy states to states_new + for (size_t i = 0; i < states.size(); ++i) + copy_opr->exec(states[i], states_new[i]); + void* reserve_ptr = reserve_space.raw_ptr(); + + // layer 1 + // TensorNDArray layer_outputs; + for (size_t d = 0; d < D; ++d) { + size_t cell_idx = d; + auto& cell = cells[cell_idx]; + + TensorNDArray cur_states; + size_t states_offset = cell_idx * cell_output_layout.span().dist_byte(); + for (size_t i = 0; i < states.size(); ++i) { + cur_states.push_back(TensorND{ + static_cast(states_new[i].raw_ptr()) + states_offset, + cell_output_layout}); + } + // TensorND direction_output_tensor{output.raw_ptr + d * + // direction_output_layout.span().dist_byte(), + // direction_output_layout}; + for (size_t i = 0; i < seq_len; ++i) { + size_t step = d == 0 ? 
 i : seq_len - 1 - i;
+            TensorND step_input{
+                    static_cast(input.raw_ptr()) +
+                            step * cell_first_input_layout.span().dist_byte(),
+                    cell_first_input_layout};
+            TensorND step_output{
+                    static_cast(output.raw_ptr()) +
+                            (step * D + d) * cell_output_layout.span().dist_byte(),
+                    cell_output_layout};
+            // temporary states of each step (use reserve space)
+            TensorNDArray tmp_states;
+            for (size_t s = 0; s < cur_states.size(); ++s) {
+                tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
+                size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
+                reserve_ptr = static_cast(reserve_ptr) + size_in_bytes;
+            }
+            cell_opr_exec(
+                    step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
+                    cell.bias_hh, cur_states, tmp_states, new_workspace, handle);
+            // copy states to cur_states
+            for (size_t s = 0; s < tmp_states.size(); ++s) {
+                copy_opr->exec(tmp_states[s], cur_states[s]);
+            }
+            // copy h to step output
+            copy_opr->exec(cur_states[0], step_output);
+        }
+    }
+
+    for (size_t layer = 1; layer < num_layers; ++layer) {
+        // TensorNDArray layer_outputs;
+
+        for (size_t d = 0; d < D; ++d) {
+            size_t cell_idx = layer * D + d;
+            auto& cell = cells[cell_idx];
+
+            TensorNDArray cur_states;
+            size_t states_offset = cell_idx * cell_output_layout.span().dist_byte();
+            for (size_t i = 0; i < states.size(); ++i) {
+                cur_states.push_back(TensorND{
+                        static_cast(states_new[i].raw_ptr()) + states_offset,
+                        cell_output_layout});
+            }
+            // TensorND direction_output_tensor{output.raw_ptr + d *
+            // direction_output_layout.span().dist_byte(),
+            // direction_output_layout};
+
+            for (size_t i = 0; i < seq_len; ++i) {
+                size_t step = d == 0 ? i : seq_len - 1 - i;
+                TensorND step_input{
+                        static_cast(output.raw_ptr()) +
+                                step * cell_input_layout.span().dist_byte(),
+                        cell_input_layout};
+                TensorND step_output{
+                        static_cast(tmp_output.raw_ptr()) +
+                                (step * D + d) * cell_output_layout.span().dist_byte(),
+                        cell_output_layout};
+                // temporary states of each step (use reserve space)
+                TensorNDArray tmp_states;
+                for (size_t s = 0; s < cur_states.size(); ++s) {
+                    tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
+                    size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
+                    reserve_ptr = static_cast(reserve_ptr) + size_in_bytes;
+                }
+                cell_opr_exec(
+                        step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
+                        cell.bias_hh, cur_states, tmp_states, new_workspace, handle);
+                // copy states to cur_states
+                for (size_t s = 0; s < tmp_states.size(); ++s) {
+                    copy_opr->exec(tmp_states[s], cur_states[s]);
+                }
+                // copy h to step_output
+                copy_opr->exec(cur_states[0], step_output);
+            }
+        }
+        // copy layer output to output
+        copy_opr->exec(tmp_output, output);
+    }
+    // output: [d0, d1, d0, d1 ...]
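+    // Layout note: when D == 2 the innermost dimension of `output` interleaves
+    // the two directions at every time step, i.e. [h_fwd(t), h_rev(t)] for each
+    // t, while `reserve_space` accumulates the per-step hidden (and cell) states
+    // written through tmp_states above; get_inputs_for_exec() reads these back
+    // when reconstructing the intermediate states for the backward pass.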
+} + +template +size_t get_cells( + size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, + std::vector& cells, _megdnn_tensor_in flatten_weights, + _megdnn_workspace workspace) { + cells.reserve(D * num_layers); + void* weight_ptr = flatten_weights.raw_ptr(); + for (size_t layer = 0; layer < num_layers; ++layer) { + for (size_t d = 0; d < D; ++d) { + size_t cell_input_size = D * hidden_size; + if (layer == 0) + cell_input_size = input_size; + Cell cell( + weight_ptr, hidden_size, cell_input_size, bias, + flatten_weights.layout.dtype, workspace); + weight_ptr = + static_cast(weight_ptr) + cell.weight_size_in_bytes(); + cells.push_back(cell); + } + } + // return used workspace + return cells[0].workspace_size_in_bytes(); +} + +template +size_t get_inputs_for_exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, + size_t num_layers, size_t D, size_t hidden_size, const std::vector& cells, + TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, + std::vector>& cell_seq_states, + param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace) { + // return used workspace size + + layer_inputs.push_back(x); + size_t seq_len = x.layout.shape[0]; + size_t batch_size = x.layout.shape[1]; + size_t num_states = cells[0].num_states(); + TensorLayout cell_output_layout{{batch_size, hidden_size}, y.layout.dtype}; + TensorLayout direction_output_layout{ + {seq_len, batch_size, hidden_size}, y.layout.dtype}; + void* workspace_ptr = workspace.raw_ptr; + + // extract intermedia states from reserve space + for (int layer = 0; layer < num_layers; ++layer) { + TensorND layer_output{workspace_ptr, y.layout}; + workspace_ptr = static_cast(workspace_ptr) + + layer_output.layout.span().dist_byte(); + for (int d = 0; d < D; ++d) { + cell_seq_states.push_back(std::vector()); + // reverse direction is stored with reversed order of sequence order + for (int i = 0; i < seq_len; ++i) { + size_t step = i; + if (d == 1) + step = seq_len - i - 1; + size_t offset = ((layer * D + d) * seq_len + step) * + cell_output_layout.span().dist_byte() * num_states; + TensorNDArray cur_states; + for (int s = 0; s < num_states; ++s) { + TensorND h{ + static_cast(reserve_space.raw_ptr()) + offset + + s * cell_output_layout.span().dist_byte(), + cell_output_layout}; + cur_states.push_back(h); + } + TensorND hy{ + static_cast(reserve_space.raw_ptr()) + offset, + cell_output_layout}; // the first hidden state is the output + // states + cell_seq_states[cell_seq_states.size() - 1].push_back(cur_states); + // output + offset = i * D * cell_output_layout.span().dist_byte(); + memcpy(static_cast(layer_output.raw_ptr()) + offset, hy.raw_ptr(), + hy.layout.span().dist_byte()); + } + } + layer_outputs.push_back(layer_output); + if (layer != num_layers - 1) + layer_inputs.push_back(layer_output); + } + return static_cast(workspace_ptr) - + static_cast((void*)workspace.raw_ptr); +} + +template +// using Cell = RNNCellWeightWrapper; +void backward_exec_internal( + std::vector& cells, size_t D, size_t num_layers, size_t input_size, + bool bias, param::RNNCell::NonlineMode nonlineMode, + const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, + const std::vector>& cell_seq_states, + _megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, + TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, + _megdnn_workspace workspace) { + /* + layer_inputs: array of input of each layer, element 0: [seq_len, batch_size, + input_size], element others: [seq_len, 
batch_size, D * hidden_size] + layer_outputs: array of outputs of each rnn. To access outputs of the cell at + (layer, d), use layer_outputs[layer]. The shape is [seq_len, batch_size, + output_size(D*hidden_size)] (in sequence order) cell_seq_states: arrray of states + of each cell at each step. To access the states of the cell at (layer, d) at + sequence step (step), use cell_seq_states[layer*D + d][step] + */ + size_t seq_len = layer_inputs[0].layout.shape[0]; + size_t batch_size = layer_inputs[0].layout.shape[1]; + DType dtype = layer_inputs[0].layout.dtype; + size_t cell_y_size = + layer_outputs[0].layout.shape[2] / D; // should all be the same + size_t hidden_size = cell_y_size; + TensorLayout cell_y_layout = {{batch_size, cell_y_size}, dtype}; + void* workspace_ptr = workspace.raw_ptr; + + TensorND layer_output_grad{ + workspace_ptr, {{seq_len, batch_size, D * hidden_size}, dtype}}; + workspace_ptr = static_cast(workspace_ptr) + + layer_output_grad.layout.span().dist_byte(); + memcpy(layer_output_grad.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); + TensorNDArray direction_dx_arr; // for layer 1 to layer num_layers-1 + for (int i = 0; i < D; ++i) { + TensorLayout direction_dx_layout{{seq_len, batch_size, hidden_size}, dtype}; + direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); + workspace_ptr = static_cast(workspace_ptr) + + direction_dx_layout.span().dist_byte(); + } + TensorNDArray L0_direction_dx_arr; + for (int i = 0; i < D; ++i) { + TensorLayout direction_dx_layout{{seq_len, batch_size, input_size}, dtype}; + L0_direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); + workspace_ptr = static_cast(workspace_ptr) + + direction_dx_layout.span().dist_byte(); + } + // cell states for each layer and each direction + std::vector dstates_arr; + for (int layer = 0; layer < num_layers; ++layer) { + for (int d = 0; d < D; ++d) { + TensorNDArray cell_states; + cell_states.reserve(dstates.size()); + for (int i = 0; i < dstates.size(); ++i) { + size_t offset = (layer * D + d) * cell_y_layout.span().dist_byte(); + TensorND dhx_cell{ + static_cast(dstates[i].raw_ptr()) + offset, + cell_y_layout}; + memcpy(dhx_cell.raw_ptr(), static_cast(dhy[i].raw_ptr()) + offset, + cell_y_layout.span().dist_byte()); + cell_states.emplace_back(dhx_cell); + } + dstates_arr.push_back(cell_states); + } + } + + // init gradient on weight to zero + memset(dw.raw_ptr(), 0, dw.layout.span().dist_byte()); + // use cells to contain gradient + std::vector cell_grads; + size_t used_workspace_size = static_cast(workspace_ptr) - + static_cast((void*)(workspace.raw_ptr)); + workspace_ptr = + static_cast(workspace_ptr) + + get_cells( + D, num_layers, input_size, hidden_size, bias, cell_grads, dw, + Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size)); + + auto add_opr = handle->create_operator(); + add_opr->param().mode = Elemwise::Mode::ADD; + auto copy_opr = handle->create_operator(); + + // initialize dx to zero + memset(dx.raw_ptr(), 0, dx.layout.span().dist_byte()); + + // calculate grads + for (int layer = num_layers - 1; layer >= 0; --layer) { + for (int d = D - 1; d >= 0; --d) { + Cell& cell = cells[layer * D + d]; + Cell& cell_grad = cell_grads[layer * D + d]; + size_t input_size = layer_inputs[layer].layout.shape[2]; + const TensorND& x_arr = layer_inputs[layer]; + const TensorND& y_arr = layer_outputs[layer]; + TensorLayout x_layout = {{batch_size, input_size}, dtype}; + + // tmp tensors + void* tmp_workspace_ptr = 
workspace_ptr; + TensorND dwi_tmp{tmp_workspace_ptr, cell_grad.weight_ih.layout}; + tmp_workspace_ptr = static_cast(tmp_workspace_ptr) + + dwi_tmp.layout.span().dist_byte(); + TensorND dwh_tmp{tmp_workspace_ptr, cell_grad.weight_hh.layout}; + tmp_workspace_ptr = static_cast(tmp_workspace_ptr) + + dwh_tmp.layout.span().dist_byte(); + TensorND dbias_tmp{tmp_workspace_ptr, cell_grad.bias_ih.layout}; + tmp_workspace_ptr = static_cast(tmp_workspace_ptr) + + dbias_tmp.layout.span().dist_byte(); + size_t used_workspace_size = + static_cast(tmp_workspace_ptr) - + static_cast((void*)(workspace.raw_ptr)); + + for (int i = 0; i < seq_len; ++i) { + // reverse time step (not seq step). Here step means seq step + size_t step = i; + if (d == 0) + step = seq_len - i - 1; + TensorND x{ + static_cast(x_arr.raw_ptr()) + + step * x_layout.span().dist_byte(), + x_layout}, + y{static_cast(y_arr.raw_ptr()) + + (step * D + d) * cell_y_layout.span().dist_byte(), + cell_y_layout}; + const TensorNDArray& cell_states = cell_seq_states[layer * D + d][step]; + TensorNDArray& dstates_new = dstates_arr[layer * D + d]; + // dy should be d_output + d_hidden + TensorND dy_t{ + static_cast(layer_output_grad.raw_ptr()) + + (step * D + d) * cell_y_layout.span().dist_byte(), + cell_y_layout}; + add_opr->exec({dstates_new[0], dy_t}, dy_t); + // dx for layer 0 has a different size + TensorND dx_t; + if (layer == 0) + dx_t = {static_cast(L0_direction_dx_arr[d].raw_ptr()) + + step * x_layout.span().dist_byte(), + x_layout}; + else + dx_t = {static_cast(direction_dx_arr[d].raw_ptr()) + + step * x_layout.span().dist_byte(), + x_layout}; + TensorNDArray douts = {dy_t}; + for (int s = 1; s < dstates_new.size(); ++s) + douts.push_back(dstates_new[s]); + cell.backward( + handle, nonlineMode, x, cell_states, y, douts, dx_t, + dstates_new, dwi_tmp, dwh_tmp, dbias_tmp, + Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size)); + // add step gradient to overall gradient + add_opr->exec({dwi_tmp, cell_grad.weight_ih}, cell_grad.weight_ih); + add_opr->exec({dwh_tmp, cell_grad.weight_hh}, cell_grad.weight_hh); + add_opr->exec({dbias_tmp, cell_grad.bias_ih}, cell_grad.bias_ih); + add_opr->exec({dbias_tmp, cell_grad.bias_hh}, cell_grad.bias_hh); + } + } + // add gradient of different directions to layer_output_grad. + // Layer 0 to dx + if (layer == 0) { + for (int i = 0; i < D; ++i) + add_opr->exec({L0_direction_dx_arr[i], dx}, dx); + } else { + if (D == 1) + copy_opr->exec(direction_dx_arr[0], layer_output_grad); + else { // D == 2, arrange as [(d0, d1), (d0, d1), ...] + for (size_t t = 0; t < seq_len; ++t) { + size_t offset = t * D * cell_y_layout.span().dist_byte(); + for (size_t d = 0; d < D; ++d) { + TensorND src{ + static_cast(direction_dx_arr[d].raw_ptr()) + + offset, + cell_y_layout}; + TensorND dst{ + static_cast(layer_output_grad.raw_ptr()) + + offset + d * cell_y_layout.span().dist_byte(), + cell_y_layout}; + copy_opr->exec(src, dst); + } + } + } + } + } +} + +} // namespace rnn +} // namespace naive +} // namespace megdnn diff --git a/dnn/src/naive/rnn/opr_impl.cpp b/dnn/src/naive/rnn/opr_impl.cpp new file mode 100644 index 00000000..aaa8b4a1 --- /dev/null +++ b/dnn/src/naive/rnn/opr_impl.cpp @@ -0,0 +1,196 @@ +/** + * \file dnn/src/naive/rnn/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/naive/rnn/opr_impl.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs/base.h" +#include "megdnn/oprs/general.h" +#include "src/common/opr_delegate.h" +#include "src/common/rnn.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" +#include "src/naive/matrix_mul/opr_impl.h" +#include "src/naive/rnn/funcs.h" +#include "src/naive/rnn/rnn.h" + +#include + +namespace megdnn { +namespace naive { + +using rnn::RNNCellWeightWrapper; + +void RNNImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace) { + auto _param = param(); + size_t D = _param.bidirectional ? 2 : 1; + size_t num_layers = _param.num_layers; + size_t input_size = input.layout.shape[2]; + std::vector cells; + size_t used_workspace_size = rnn::get_cells( + D, num_layers, input_size, _param.hidden_size, _param.bias, cells, + flatten_weights, workspace); + + Workspace new_workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + TensorNDArray states, states_new; + states.push_back(hx); + states_new.push_back(hy); + rnn::exec_internal( + cells, input, states, states_new, output, reserve_space, num_layers, D, + this->handle(), new_workspace); +} + +size_t RNNImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space) { + size_t workspace_size = rnn::get_workspace_in_bytes( + input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, + this->handle()); + if (!param().bias) { // use fake bias (all 0) + TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype}; + workspace_size += bias_layout.span().dist_byte(); + } + workspace_size += output.span().dist_byte(); + return workspace_size; +} + +size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { + size_t num_layers = param().num_layers; + size_t D = param().bidirectional ? 2 : 1; + size_t seq_len = input.shape[0]; + size_t batch_size = input.shape[1]; + TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; + return num_layers * D * seq_len * state_layout.span().dist_byte(); +} + +void RNNBackwardImpl::exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, + _megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, + _megdnn_tensor_out dw, _megdnn_workspace workspace) { + TensorNDArray layer_inputs; + // layer_inputs.push_back(x); + TensorNDArray layer_outputs; + std::vector> cell_seq_states; + size_t num_layers = param().num_layers; + size_t D = param().bidirectional ? 
2 : 1; + // size_t seq_len = x.layout.shape[0]; + // size_t batch_size = x.layout.shape[1]; + size_t input_size = x.layout.shape[2]; + size_t hidden_size = param().hidden_size; + size_t used_workspace_size = 0; + + // get cells + std::vector cells; + // workspace_ptr = static_cast(workspace_ptr) + + used_workspace_size += rnn::get_cells( + D, num_layers, input_size, hidden_size, param().bias, cells, + flatten_weights, workspace); + + // extract intermedia states from reserve space + /*for (int layer = 0; layer < num_layers; ++layer) { + TensorND layer_output{workspace_ptr, y.layout}; + workspace_ptr = static_cast(workspace_ptr) + + layer_output.layout.span().dist_byte(); for (int d = 0; d < D; ++d) { + cell_seq_states.push_back(std::vector()); + // reverse direction is stored with reversed order of sequence order + for (int i = 0; i < seq_len; ++i) { + size_t step = i; + if (d == 1) step = seq_len - i - 1; + size_t offset = ((layer * D + d) * seq_len + step) * + cell_output_layout.span().dist_byte(); TensorND + hy{static_cast(reserve_space.raw_ptr) + offset, cell_output_layout}; + // states + cell_seq_states[cell_seq_states.size() - 1].push_back({hy}); + // output + offset = i * D * cell_output_layout.span().dist_byte(); + memcpy(static_cast(layer_output.raw_ptr) + offset, + hy.raw_ptr, hy.layout.span().dist_byte()); + } + } + cell_seq_outputs.push_back(layer_output); + if (layer != num_layers - 1) layer_inputs.push_back(layer_output); + }*/ + // nonlinear mode + param::RNNCell::NonlineMode nonlineMode; + using ModeRNN = param::RNN::NonlineMode; + using ModeRNNCell = param::RNNCell::NonlineMode; + switch (param().nonlineMode) { + case ModeRNN::RELU: + nonlineMode = ModeRNNCell::RELU; + break; + case ModeRNN::TANH: + nonlineMode = ModeRNNCell::TANH; + break; + } + + // get formatted inputs + Workspace new_workspace = Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + used_workspace_size += rnn::get_inputs_for_exec( + x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, + layer_outputs, cell_seq_states, nonlineMode, new_workspace); + + // dhy arr, dhx arr + TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx}; + + // exec + /*size_t used_workspace_size = static_cast(workspace_ptr) - + static_cast((void*)workspace.raw_ptr);*/ + new_workspace = Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + rnn::backward_exec_internal( + cells, D, num_layers, input_size, param().bias, nonlineMode, layer_inputs, + layer_outputs, cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, + this->handle(), new_workspace); +} + +size_t RNNBackwardImpl::get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { + size_t D = param().bidirectional ? 
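+    // The total below roughly mirrors what backward_exec_internal touches:
+    // per-cell backward scratch, two fake bias buffers when bias is disabled,
+    // one y-sized buffer per layer for the reconstructed layer outputs,
+    // y/x-sized staging buffers for the per-direction dy/dx copies, and the
+    // per-step dwi/dwh/dbias accumulators.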
2 : 1; + size_t num_layers = param().num_layers; + size_t hidden_size = param().hidden_size; + size_t gate_hidden_size = hidden_size; + size_t max_input_size = std::max(x.shape[2], D * hidden_size); + + size_t workspace_size = RNNCellWeightWrapper::backward_workspace_size_in_bytes( + this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); + if (!param().bias) { // use fake bias (all 0) + TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; + workspace_size += bias_layout.span().dist_byte() * + 2; // times 2 because another bias is allocated in + // backward_exec_internal + } + workspace_size += num_layers * y.span().dist_byte(); + // add back exec workspace size + workspace_size += y.span().dist_byte() * 2; + workspace_size += x.span().dist_byte() * 2; + TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; + TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; + TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; + workspace_size += wih.span().dist_byte(); + workspace_size += whh.span().dist_byte(); + workspace_size += bias.span().dist_byte(); + return workspace_size; +} +} // namespace naive + +} // namespace megdnn diff --git a/dnn/src/naive/rnn/opr_impl.h b/dnn/src/naive/rnn/opr_impl.h new file mode 100644 index 00000000..8ae6809c --- /dev/null +++ b/dnn/src/naive/rnn/opr_impl.h @@ -0,0 +1,53 @@ +/** + * \file dnn/src/naive/rnn/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class RNNImpl : public RNN { +public: + using RNN::RNN; + + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in hx, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, + _megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, + _megdnn_workspace workspace); + + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& hx, + const TensorLayout& flatten_weights, const TensorLayout& output, + const TensorLayout& hy, const TensorLayout& reserve_space); + size_t get_reserve_size_in_bytes(const TensorLayout& input); +}; + +class RNNBackwardImpl : public RNNBackward { +public: + using RNNBackward::RNNBackward; + + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, + _megdnn_tensor_in dy, _megdnn_tensor_in dhy, + _megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, + _megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, + _megdnn_workspace workspace); + + virtual size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, + const TensorLayout& dy, const TensorLayout& dhy, + const TensorLayout& flatten_weights, const TensorLayout& reserve_space, + const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); +}; + +} // namespace naive +} // namespace megdnn diff --git a/dnn/src/naive/rnn/rnn.cpp b/dnn/src/naive/rnn/rnn.cpp new file mode 100644 index 00000000..8ea43496 --- /dev/null +++ b/dnn/src/naive/rnn/rnn.cpp @@ -0,0 +1,285 @@ +/** + * \file dnn/src/naive/rnn/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/naive/rnn/rnn.h" + +#include +#include + +namespace megdnn { +namespace naive { +namespace rnn { + +CellWeightsWrapperBase::CellWeightsWrapperBase( + void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, + bool has_bias, DType dtype, _megdnn_workspace workspace) { + // weight_ih: [gate_hidden_size, input_size] + // weight_hh: [gate_hidden_size, hidden_size] + // bias_ih: [gate_hidden_size] + // bias_hh: [gate_hidden_size] + size_t gate_hidden_size = num_chunks * hidden_size; + TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype}; + TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype}; + TensorLayout bias_layout{{gate_hidden_size}, dtype}; + this->_weight_size = 0; + this->weight_ih = TensorND(weight_ptr, weight_ih_layout); + this->_weight_size += weight_ih_layout.span().dist_byte(); + this->weight_hh = TensorND( + static_cast(weight_ptr) + this->_weight_size, weight_hh_layout); + this->_weight_size += weight_hh_layout.span().dist_byte(); + if (has_bias) { + this->bias_ih = TensorND( + static_cast(weight_ptr) + this->_weight_size, bias_layout); + this->_weight_size += bias_layout.span().dist_byte(); + this->bias_hh = TensorND( + static_cast(weight_ptr) + this->_weight_size, bias_layout); + this->_weight_size += bias_layout.span().dist_byte(); + this->_workspace_size = 0; + } else { + this->bias_ih = TensorND(workspace.raw_ptr, bias_layout); + this->bias_hh = TensorND(workspace.raw_ptr, bias_layout); + memset(workspace.raw_ptr, 0, bias_layout.span().dist_byte()); + this->_workspace_size = bias_layout.span().dist_byte(); + } +} + +size_t CellWeightsWrapperBase::weight_size_in_bytes() const { + return this->_weight_size; +} + +size_t CellWeightsWrapperBase::workspace_size_in_bytes() const { + return this->_workspace_size; +} + +size_t CellWeightsWrapperBase::num_states() const { + return 1; +} + +void CellWeightsWrapperBase::backward( + Handle* handle, param::RNNCell::NonlineMode nonlineMode, _megdnn_tensor_in x, + const TensorNDArray& states, _megdnn_tensor_in y, const TensorNDArray& douts, + _megdnn_tensor_out dx, TensorNDArray& dstates, _megdnn_tensor_out dwi, + _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) const { + auto dy = douts[0]; + using NonlineMode = param::RNNCell::NonlineMode; + using Mode = Elemwise::Mode; + auto elemwise_opr = handle->create_operator(); + TensorND tmp = {workspace.raw_ptr, dy.layout}; + auto new_workspace = Workspace( + workspace.raw_ptr + tmp.layout.span().dist_byte(), + workspace.size - tmp.layout.span().dist_byte()); + switch (nonlineMode) { + case (NonlineMode::IDENTITY): + memcpy(tmp.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); + break; + case (NonlineMode::TANH): + elemwise_opr->param().mode = Mode::TANH_GRAD; + elemwise_opr->exec({y, dy}, tmp); + break; + case (NonlineMode::RELU): + elemwise_opr->param().mode = Mode::SWITCH_GT0; + elemwise_opr->exec({y, dy}, tmp); + break; + } + auto matrixmul_opr = handle->create_operator(); + matrixmul_opr->param().transposeA = false; + matrixmul_opr->param().transposeB = false; + // dx + matrixmul_opr->exec(tmp, this->weight_ih, dx, new_workspace); + // dhx + matrixmul_opr->exec(tmp, this->weight_hh, dstates[0], new_workspace); + // dwi + matrixmul_opr->param().transposeA = true; + matrixmul_opr->exec(tmp, x, dwi, new_workspace); + // dwh + matrixmul_opr->exec(tmp, states[0], dwh, new_workspace); + // dbias + auto sum_opr = handle->create_operator(); + sum_opr->param().mode = 
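+    // dbias: the pre-activation gradient tmp ([batch, num_chunks * hidden])
+    // summed over the batch axis into a [1, gate_hidden_size] view of dbias.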
ReduceForward::Mode::SUM; + sum_opr->param().axis = 0; + TensorND dbias_expanded = { + dbias.raw_ptr(), {{1, dbias.layout.shape[0]}, dbias.layout.dtype}}; + sum_opr->exec(tmp, dbias_expanded, new_workspace); +} + +size_t CellWeightsWrapperBase::backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + size_t num_chunks, DType dtype) { + size_t gate_hidden_size = hidden_size * num_chunks; + TensorLayout tmp = {{batch_size, gate_hidden_size}, dtype}; + TensorLayout bias_expanded = {{1, gate_hidden_size}, dtype}; + TensorLayout wih = {{gate_hidden_size, input_size}, dtype}; + TensorLayout whh = {{gate_hidden_size, hidden_size}, dtype}; + TensorLayout x = {{batch_size, input_size}, dtype}; + TensorLayout hx = {{batch_size, hidden_size}, dtype}; + size_t workspace_size = 0; + auto matrixmul_opr = handle->create_operator(); + matrixmul_opr->param().transposeA = false; + matrixmul_opr->param().transposeB = false; + // dx + workspace_size = std::max( + workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, wih, x)); + // dhx + workspace_size = std::max( + workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, whh, hx)); + // dwi + matrixmul_opr->param().transposeA = true; + workspace_size = std::max( + workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, x, wih)); + // dwh + workspace_size = std::max( + workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, hx, whh)); + // dbias + auto sum_opr = handle->create_operator(); + sum_opr->param().mode = ReduceForward::Mode::SUM; + sum_opr->param().axis = 0; + workspace_size = std::max( + workspace_size, sum_opr->get_workspace_in_bytes(tmp, bias_expanded)); + workspace_size += tmp.span().dist_byte(); + return workspace_size; +} + +RNNCellWeightWrapper::RNNCellWeightWrapper( + void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, + DType dtype, _megdnn_workspace workspace) + : CellWeightsWrapperBase( + weight_ptr, hidden_size, input_size, 1, has_bias, dtype, workspace) {} + +size_t RNNCellWeightWrapper::backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + DType dtype) { + return CellWeightsWrapperBase::backward_workspace_size_in_bytes( + handle, batch_size, hidden_size, input_size, 1, dtype); +} + +LSTMCellWeightWrapper::LSTMCellWeightWrapper( + void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, + DType dtype, _megdnn_workspace workspace) + : CellWeightsWrapperBase( + weight_ptr, hidden_size, input_size, 4, has_bias, dtype, workspace) {} + +size_t LSTMCellWeightWrapper::num_states() const { + return 2; +} + +size_t LSTMCellWeightWrapper::backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + DType dtype) { + // get gates size + size_t gate_hidden_size = 4 * hidden_size; + auto lstm_opr = handle->create_operator(); + TensorLayout x = {{batch_size, input_size}, dtype}; + TensorLayout weight_ih = {{gate_hidden_size, input_size}, dtype}; + TensorLayout weight_hh = {{gate_hidden_size, hidden_size}, dtype}; + TensorLayout bias = {{gate_hidden_size}, dtype}; + TensorLayout h = {{batch_size, hidden_size}, dtype}; + TensorLayout gates, h_new, c_new; + lstm_opr->deduce_layout( + x, weight_ih, bias, h, weight_hh, bias, h, h_new, c_new, gates); + return CellWeightsWrapperBase::backward_workspace_size_in_bytes( + handle, batch_size, hidden_size, input_size, 4, dtype) + + gates.span().dist_byte() * 2 + c_new.span().dist_byte(); +} + +void 
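+// Sketch of the procedure below: (1) re-run the LSTMCell forward opr to
+// recover the pre-activation gates and cy into workspace memory, (2) apply
+// sigmoid/tanh to obtain i, f, o, g, (3) turn dhy/dcy into per-gate
+// pre-activation gradients via SIGMOID_GRAD / TANH_GRAD, (4) dcx = dcy * f,
+// (5) feed the concatenated gate gradients through
+// CellWeightsWrapperBase::backward with the identity nonlinearity to obtain
+// dx, dhx, dwi, dwh and dbias.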
LSTMCellWeightWrapper::backward( + Handle* handle, + param::RNNCell::NonlineMode nonlineMode, // nonlineMode must be identity + _megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, + const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, + _megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) const { + size_t used_workspace_size = 0; + // get gates + auto lstm_opr = handle->create_operator(); + TensorLayout gates, h_new, c_new; + lstm_opr->deduce_layout( + x.layout, weight_ih.layout, bias_ih.layout, states[0].layout, + weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); + TensorND gates_tensor{workspace.raw_ptr, gates}; + used_workspace_size += gates.span().dist_byte(); + TensorND gates_grad{workspace.raw_ptr + used_workspace_size, gates}; + used_workspace_size += gates.span().dist_byte(); + TensorND tanh_cy{workspace.raw_ptr + used_workspace_size, y.layout}; + used_workspace_size += tanh_cy.layout.span().dist_byte(); + Workspace new_workspace = Workspace( + workspace.raw_ptr + used_workspace_size, + workspace.size - used_workspace_size); + // temporarily use dstates to store hy, cy + // only gates and cy needed, other output will be cleared afterwards + lstm_opr->exec( + x, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], dstates[0], + dstates[1], gates_tensor, + new_workspace); // no information left in the workspace + // i, f, o, g + TensorLayout single_gate = {{gates.shape[0], gates.shape[1] / 4}, gates.dtype}; + TensorND i, f, o, g, i_grad, f_grad, o_grad, + g_grad; // grad refers to the grad of gates before activation + i = {gates_tensor.raw_ptr(), single_gate}; + f = {static_cast(gates_tensor.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + o = {static_cast(f.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + g = {static_cast(o.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + i_grad = {gates_grad.raw_ptr(), single_gate}; + f_grad = { + static_cast(i_grad.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + o_grad = { + static_cast(f_grad.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + g_grad = { + static_cast(o_grad.raw_ptr()) + single_gate.span().dist_byte(), + single_gate}; + // activation + auto elem_opr = handle->create_operator(); + elem_opr->param().mode = Elemwise::Mode::SIGMOID; + elem_opr->exec({i}, i); + elem_opr->exec({f}, f); + elem_opr->exec({o}, o); + elem_opr->param().mode = Elemwise::Mode::TANH; + elem_opr->exec({g}, g); + elem_opr->exec({dstates[1]}, tanh_cy); + auto mul_opr = handle->create_operator(); + mul_opr->param().mode = Elemwise::Mode::MUL; + // use dstates[0] as tmp tensor to store dhy * tanh_cy + mul_opr->exec({douts[0], tanh_cy}, dstates[0]); + elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; + elem_opr->exec({o, dstates[0]}, o_grad); // grad of gate o + // use dstates[0] as tmp tensor to store dhy * o + mul_opr->exec({douts[0], o}, dstates[0]); + elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; + elem_opr->exec({tanh_cy, dstates[0]}, dstates[1]); // grad of cy from hy + elem_opr->param().mode = Elemwise::Mode::ADD; + elem_opr->exec({douts[1], dstates[1]}, dstates[1]); // true grad of cy + // use dstates[0] as tmp tensor to store dcy * cx + mul_opr->exec({dstates[1], states[1]}, dstates[0]); + elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; + elem_opr->exec({f, dstates[0]}, f_grad); // grad of gate f + // use dstates[0] as tmp tensor to store dcy * g + 
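+    // Remaining pre-activation gradients, with dcy held in dstates[1]:
+    // di = sigmoid_grad(i, dcy * g), dg = tanh_grad(g, dcy * i), and finally
+    // dcx = dcy * f (written back into dstates[1]).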
mul_opr->exec({dstates[1], g}, dstates[0]); + elem_opr->exec({i, dstates[0]}, i_grad); // grad of gate i + // use dstates[0] as tmp tensor to store dcy * i + mul_opr->exec({dstates[1], i}, dstates[0]); + elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; + elem_opr->exec({g, dstates[0]}, g_grad); // grad of gate g + + // grad if cx + mul_opr->exec({dstates[1], f}, dstates[1]); + TensorNDArray base_dstates = {dstates[0]}; + CellWeightsWrapperBase::backward( + handle, nonlineMode, x, {states[0]}, gates_tensor, {gates_grad}, dx, + base_dstates, dwi, dwh, dbias, new_workspace); +} + +} // namespace rnn +} // namespace naive +} // namespace megdnn diff --git a/dnn/src/naive/rnn/rnn.h b/dnn/src/naive/rnn/rnn.h new file mode 100644 index 00000000..bb5fdf08 --- /dev/null +++ b/dnn/src/naive/rnn/rnn.h @@ -0,0 +1,73 @@ +/** + * \file dnn/src/naive/rnn/rnn.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" +#include "megdnn/oprs/general.h" + +namespace megdnn { +namespace naive { +namespace rnn { + +class CellWeightsWrapperBase { +private: + size_t _weight_size, _workspace_size; + +public: + TensorND weight_ih, weight_hh, bias_ih, bias_hh; + // if no bias, will create dummy bias tensor from workspace + CellWeightsWrapperBase( + void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, + bool has_bias, DType dtype, _megdnn_workspace workspace); + size_t weight_size_in_bytes() const; + size_t workspace_size_in_bytes() const; + static size_t backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + size_t num_chunks, DType dtype); + virtual void backward( + Handle* handle, param::RNNCell::NonlineMode nonlineMode, + _megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, + const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, + _megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) const; + virtual size_t num_states() const; +}; + +class RNNCellWeightWrapper : public CellWeightsWrapperBase { +public: + RNNCellWeightWrapper( + void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, + DType dtype, _megdnn_workspace workspace); + + static size_t backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + DType dtype); +}; + +class LSTMCellWeightWrapper : public CellWeightsWrapperBase { +public: + LSTMCellWeightWrapper( + void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, + DType dtype, _megdnn_workspace workspace); + static size_t backward_workspace_size_in_bytes( + Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, + DType dtype); + size_t num_states() const override; + void backward( + Handle* handle, param::RNNCell::NonlineMode nonlineMode, + _megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, + const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, + _megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) const override; +}; + +} // namespace rnn +} // namespace naive +} // namespace megdnn \ No 
newline at end of file diff --git a/dnn/src/naive/rnn/template_impl.cpp b/dnn/src/naive/rnn/template_impl.cpp new file mode 100644 index 00000000..d3f870ad --- /dev/null +++ b/dnn/src/naive/rnn/template_impl.cpp @@ -0,0 +1,41 @@ +/** + * \file dnn/src/naive/rnn/template_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "src/naive/rnn/funcs.h" + +namespace megdnn { +namespace naive { +namespace rnn { + +template <> +void cell_opr_exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in bias_hh, const TensorNDArray& states, + TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { + auto opr = handle->create_operator(); + opr->exec( + input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states_new[0], + workspace); +} + +template <> +size_t cell_opr_get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& weight_hh, const TensorLayout& bias_ih, + const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { + auto cell_opr = handle->create_operator(); + return cell_opr->get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, hx); +} + +} // namespace rnn +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/rnn_cell/opr_impl.cpp b/dnn/src/naive/rnn_cell/opr_impl.cpp new file mode 100644 index 00000000..8a3d692c --- /dev/null +++ b/dnn/src/naive/rnn_cell/opr_impl.cpp @@ -0,0 +1,34 @@ +/** + * \file dnn/src/naive/rnn_cell/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "src/naive/rnn_cell/opr_impl.h" +#include "src/common/rnn_cell.h" + +namespace megdnn { +namespace naive { +size_t RNNCellImpl::get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst) { + return megdnn::rnn_cell::get_workspace_in_bytes( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle()); +} + +void RNNCellImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, + _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + megdnn::rnn_cell::exec( + input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace, + param().nonlineMode, handle()); +} +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/naive/rnn_cell/opr_impl.h b/dnn/src/naive/rnn_cell/opr_impl.h new file mode 100644 index 00000000..fe73edb5 --- /dev/null +++ b/dnn/src/naive/rnn_cell/opr_impl.h @@ -0,0 +1,33 @@ +/** + * \file dnn/src/naive/rnn_cell/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class RNNCellImpl : public RNNCell { +public: + using RNNCell::RNNCell; + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, + _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, + _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& weight_ih, + const TensorLayout& bias_ih, const TensorLayout& hx, + const TensorLayout& weight_hh, const TensorLayout& bias_hh, + const TensorLayout& dst) override; +}; + +} // namespace naive +} // namespace megdnn \ No newline at end of file diff --git a/dnn/test/common/deduce_layout_proxy.h b/dnn/test/common/deduce_layout_proxy.h index 17afc1dd..be71a2ec 100644 --- a/dnn/test/common/deduce_layout_proxy.h +++ b/dnn/test/common/deduce_layout_proxy.h @@ -68,6 +68,15 @@ struct DeduceLayoutProxy { }; template +struct DeduceLayoutProxy { + static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { + megdnn_assert(layouts.size() == 6); + opr->deduce_layout( + layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); + } +}; + +template struct DeduceLayoutProxy { static void deduce_layout(Opr*, TensorLayoutArray&) {} }; diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 90dfbf86..273b0b25 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -248,7 +248,6 @@ DEF_TEST(fuse_mul_add4) { } DEF_TEST(rmulh) { - using Mode = ElemwiseForward::Param::Mode; Checker checker(handle); @@ -315,7 +314,7 @@ DEF_TEST(unary1) { // case float UniformFloatRNG rng(1e-2, 6e1); checker.set_rng(0, &rng); - checker.set_epsilon(1e-5); + checker.set_epsilon(1e-5); checker.set_dtype(0, dtype::Float32()); BUILD_UNARY_TEST_CASE_FLOAT } @@ -900,7 +899,7 @@ DEF_TEST(unary_negative_stride) { UniformFloatRNG rng(1e-2, 6e1); checker.set_rng(0, 
&rng); - checker.set_epsilon(1e-5); + checker.set_epsilon(1e-5); BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT; } diff --git a/dnn/test/common/rnn.h b/dnn/test/common/rnn.h new file mode 100644 index 00000000..691b2cae --- /dev/null +++ b/dnn/test/common/rnn.h @@ -0,0 +1,51 @@ +/** + * \file dnn/test/common/rnn.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include + +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" + +namespace megdnn { +namespace test { +namespace rnn { +struct TestArg { + param::RNN param; + TensorShape input, hx, flatten_weights; + TestArg(param::RNN param, TensorShape input, TensorShape hx, + TensorShape flatten_weights) + : param(param), input(input), hx(hx), flatten_weights(flatten_weights) {} +}; + +inline std::vector get_args() { + std::vector args; + size_t batch_size = 2; + size_t input_size = 3; + size_t hidden_size = 2; + size_t seq_len = 2; + size_t gate_hidden_size = hidden_size; + param::RNN param; + param.num_layers = 1; + param.bidirectional = false; + param.bias = false; + param.hidden_size = hidden_size; + param.nonlineMode = param::RNN::NonlineMode::RELU; + + args.emplace_back( + param, TensorShape{seq_len, batch_size, input_size}, + TensorShape{batch_size, hidden_size}, + TensorShape{gate_hidden_size, input_size + hidden_size}); + return args; +} + +} // namespace rnn +} // namespace test +} // namespace megdnn \ No newline at end of file diff --git a/dnn/test/naive/rnn.cpp b/dnn/test/naive/rnn.cpp new file mode 100644 index 00000000..11e9162e --- /dev/null +++ b/dnn/test/naive/rnn.cpp @@ -0,0 +1,80 @@ +/** + * \file dnn/test/naive/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/common/rnn.h" +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +/*TEST_F(NAIVE, RNN) { + std::vector args = rnn::get_args(); + Checker checker(handle()); + for (auto&& arg : args) { + checker.set_param(arg.param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(3, dtype::Float32()) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, dtype::Float32()) + .execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}}); + } +}*/ + +TEST_F(NAIVE, RNN_HAND_MADE) { + Checker checker(handle(), false); + size_t batch_size = 2; + size_t input_size = 3; + size_t hidden_size = 2; + size_t seq_len = 2; + size_t gate_hidden_size = hidden_size; + RNN::Param param; + param.num_layers = 1; + param.bidirectional = false; + param.bias = false; + param.hidden_size = hidden_size; + param.nonlineMode = param::RNN::NonlineMode::RELU; + checker.set_param(param).exect( + Testcase{ + TensorValue( + {seq_len, batch_size, input_size}, dtype::Float32(), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input + TensorValue( + {batch_size, hidden_size}, dtype::Float32(), + {2, 1, 3, 5}), // hx + TensorValue( + {gate_hidden_size, input_size + hidden_size}, + dtype::Float32(), + {3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights + {}, + {}, + {}}, + Testcase{ + {}, + {}, + {}, + TensorValue( + {seq_len, batch_size, hidden_size}, dtype::Float32(), + {39, 39, 90, 84, 300, 216, 546, 366}), // output + TensorValue( + {batch_size, hidden_size}, dtype::Float32(), + {21, 11, 42, 20}), // hy + TensorValue( + {1, 2, 2, 2}, dtype::Float32(), + {2, 1, 3, 5, 21, 11, 42, 20}) // reserve space + }); +} + +} // namespace test +} // namespace megdnn diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 73b12474..c700c095 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -36,5 +36,6 @@ from .padding import Pad from .pixel_shuffle import PixelShuffle from .pooling import AvgPool2d, MaxPool2d from .quant_dequant import DequantStub, QuantStub +from .rnn import LSTM, RNN, LSTMCell, RNNCell from .sequential import Sequential from .sliding_window import SlidingWindow, SlidingWindowTranspose diff --git a/imperative/python/megengine/module/rnn.py b/imperative/python/megengine/module/rnn.py new file mode 100644 index 00000000..571e3a19 --- /dev/null +++ b/imperative/python/megengine/module/rnn.py @@ -0,0 +1,396 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import math +import numbers +from abc import abstractmethod +from typing import Optional, Tuple + +import numpy as np + +from ..core._imperative_rt.core2 import apply +from ..core.ops import builtin +from ..device import is_cuda_available +from ..functional import concat, expand_dims, repeat, stack, zeros +from ..functional.nn import concat +from ..tensor import Parameter, Tensor +from . 
import init +from .module import Module + + +class RNNCellBase(Module): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool, + num_chunks: int, + device=None, + dtype=None, + ) -> None: + # num_chunks indicates the number of gates + super(RNNCellBase, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + + # initialize weights + common_kwargs = {"device": device, "dtype": dtype} + self.gate_hidden_size = num_chunks * hidden_size + self.weight_ih = Parameter( + np.random.uniform(size=(self.gate_hidden_size, input_size)).astype( + np.float32 + ), + **common_kwargs, + ) + self.weight_hh = Parameter( + np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype( + np.float32 + ), + **common_kwargs, + ) + if bias: + self.bias_ih = Parameter( + np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), + **common_kwargs, + ) + self.bias_hh = Parameter( + np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), + **common_kwargs, + ) + else: + self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs) + self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs) + self.reset_parameters() + # if bias is False self.bias will remain zero + + def get_op(self): + return builtin.RNNCell() + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + if hx is None: + hx = zeros( + shape=(input.shape[0], self.gate_hidden_size), + dtype=input.dtype, + device=input.device, + ) + op = self.get_op() + return apply( + op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh + )[0] + # return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh) + + +class RNNCell(RNNCellBase): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + nonlinearity: str = "tanh", + device=None, + dtype=None, + ) -> None: + self.nonlinearity = nonlinearity + super(RNNCell, self).__init__( + input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype + ) + # self.activate = tanh if nonlinearity == "tanh" else relu + + def get_op(self): + return builtin.RNNCell(nonlineMode=self.nonlinearity) + + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + return super().forward(input, hx) + + +class LSTMCell(RNNCellBase): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super(LSTMCell, self).__init__( + input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype + ) + + def get_op(self): + return builtin.LSTMCell() + + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: + # hx: (h, c) + if hx is None: + h = zeros( + shape=(input.shape[0], self.hidden_size), + dtype=input.dtype, + device=input.device, + ) + c = zeros( + shape=(input.shape[0], self.hidden_size), + dtype=input.dtype, + device=input.device, + ) + else: + h, c = hx + op = self.get_op() + return apply( + op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c + )[:2] + + +def is_gpu(device: str) -> bool: + if "xpux" in device and is_cuda_available(): + return True + if "gpu" in device: + return True + return False + + +class RNNBase(Module): + def __init__( + self, + input_size: int, + hidden_size: int, + 
num_layers: int = 1, + bias: bool = True, + batch_first: bool = False, + dropout: float = 0, + bidirectional: bool = False, + proj_size: int = 0, + device=None, + dtype=None, + ) -> None: + super(RNNBase, self).__init__() + # self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = float(dropout) + self.bidirectional = bidirectional + self.num_directions = 2 if self.bidirectional else 1 + self.proj_size = proj_size + + # check validity of dropout + if ( + not isinstance(dropout, numbers.Number) + or not 0 <= dropout <= 1 + or isinstance(dropout, bool) + ): + raise ValueError( + "Dropout should be a float in [0, 1], which indicates the probability " + "of an element to be zero" + ) + + if proj_size < 0: + raise ValueError( + "proj_size should be a positive integer or zero to disable projections" + ) + elif proj_size >= hidden_size: + raise ValueError("proj_size has to be smaller than hidden_size") + + self.cells = [] + for layer in range(self.num_layers): + self.cells.append([]) + for _ in range(self.num_directions): + self.cells[layer].append(self.create_cell(layer, device, dtype)) + # parameters have been initialized during the creation of the cells + # if flatten, then delete cells + self._flatten_parameters(device, dtype, self.cells) + + def _flatten_parameters(self, device, dtype, cells): + gate_hidden_size = cells[0][0].gate_hidden_size + size_dim1 = 0 + for layer in range(self.num_layers): + for direction in range(self.num_directions): + size_dim1 += cells[layer][direction].weight_ih.shape[1] + size_dim1 += cells[layer][direction].weight_hh.shape[1] + # if self.bias: + # size_dim1 += 2 * self.num_directions * self.num_layers + size_dim1 += 2 * self.num_directions * self.num_layers + self._flatten_weights = Parameter( + np.zeros((gate_hidden_size, size_dim1), dtype=np.float32) + ) + self.reset_parameters() + # TODO: if no bias, set the bias to zero + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + init.uniform_(weight, -stdv, stdv) + + @abstractmethod + def create_cell(self, layer, device, dtype): + raise NotImplementedError("Cell not implemented !") + + @abstractmethod + def init_hidden(self): + raise NotImplementedError("init_hidden not implemented !") + + @abstractmethod + def get_output_from_hidden(self, hx): + raise NotImplementedError("get_output_from_hidden not implemented !") + + @abstractmethod + def apply_op(self, input, hx): + raise NotImplementedError("apply_op not implemented !") + + def _apply_fn_to_hx(self, hx, fn): + return fn(hx) + + def _stack_h_n(self, h_n): + return stack(h_n, axis=0) + + def forward(self, input: Tensor, hx=None): + if self.batch_first: + batch_size = input.shape[0] + input = input.transpose((1, 0, 2)) # [seq_len, batch_size, dim] + else: + batch_size = input.shape[1] + if hx is None: + hx = self.init_hidden(batch_size, input.device, input.dtype) + + output, h = self.apply_op(input, hx) + if self.batch_first: + output = output.transpose((1, 0, 2)) + return output, h + + if is_gpu(str(input.device)) or True: + # return output, h_n + output, h = self.apply_op(input, hx) + if self.batch_first: + output = output.transpose((1, 0, 2)) + return output, h + + order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)] + h_n = [] + for layer in range(self.num_layers): + layer_outputs = [] + for direction in range(self.num_directions): + direction_outputs = 
[None for _ in range(input.shape[0])] + cell = self.cells[layer][direction] + hidden = self._apply_fn_to_hx( + hx, lambda x: x[layer * self.num_directions + direction] + ) + for step in range(*(order_settings[direction])): + hidden = cell(input[step], hidden) # [batch_size, hidden_size] + direction_outputs[step] = self.get_output_from_hidden(hidden) + direction_output = stack( + direction_outputs, axis=0 + ) # [seq_len, batch_size, hidden_size] + layer_outputs.append(direction_output) + h_n.append(hidden) + layer_output = concat( + layer_outputs, axis=-1 + ) # [seq_len, batch_size, D*hidden_size] + input = layer_output + if self.batch_first: + layer_output = layer_output.transpose((1, 0, 2)) + return layer_output, self._stack_h_n(h_n) + + +class RNN(RNNBase): + def __init__(self, *args, **kwargs) -> None: + self.nonlinearity = kwargs.pop("nonlinearity", "tanh") + super(RNN, self).__init__(*args, **kwargs) + + def create_cell(self, layer, device, dtype): + if layer == 0: + input_size = self.input_size + else: + input_size = self.num_directions * self.hidden_size + return RNNCell( + input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype + ) + + def init_hidden(self, batch_size, device, dtype): + hidden_shape = ( + self.num_directions * self.num_layers, + batch_size, + self.hidden_size, + ) + return zeros(shape=hidden_shape, dtype=dtype, device=device) + + def get_output_from_hidden(self, hx): + return hx + + def apply_op(self, input, hx): + op = builtin.RNN( + num_layers=self.num_layers, + bidirectional=self.bidirectional, + bias=self.bias, + hidden_size=self.hidden_size, + proj_size=self.proj_size, + dropout=self.dropout, + nonlineMode=self.nonlinearity, + ) + output, h = apply(op, input, hx, self._flatten_weights)[:2] + output = output + h.sum() * 0 + h = h + output.sum() * 0 + return output, h + + +class LSTM(RNNBase): + def __init__(self, *args, **kwargs) -> None: + super(LSTM, self).__init__(*args, **kwargs) + + def create_cell(self, layer, device, dtype): + if layer == 0: + input_size = self.input_size + else: + input_size = self.num_directions * self.hidden_size + return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype) + + def init_hidden(self, batch_size, device, dtype): + hidden_shape = ( + self.num_directions * self.num_layers, + batch_size, + self.hidden_size, + ) + h = zeros(shape=hidden_shape, dtype=dtype, device=device) + c = zeros(shape=hidden_shape, dtype=dtype, device=device) + return (h, c) + + def get_output_from_hidden(self, hx): + return hx[0] + + def apply_op(self, input, hx): + op = builtin.LSTM( + num_layers=self.num_layers, + bidirectional=self.bidirectional, + bias=self.bias, + hidden_size=self.hidden_size, + proj_size=self.proj_size, + dropout=self.dropout, + ) + output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3] + placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0] + output = output + placeholders[1] + placeholders[2] + h = h + placeholders[0] + placeholders[2] + c = c + placeholders[0] + placeholders[1] + return output, (h, c) + + def _apply_fn_to_hx(self, hx, fn): + return (fn(hx[0]), fn(hx[1])) + + def _stack_h_n(self, h_n): + h = [tup[0] for tup in h_n] + c = [tup[1] for tup in h_n] + return (stack(h, axis=0), stack(c, axis=0)) diff --git a/imperative/python/test/unit/module/test_rnn.py b/imperative/python/test/unit/module/test_rnn.py new file mode 100644 index 00000000..cf28f372 --- /dev/null +++ b/imperative/python/test/unit/module/test_rnn.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# 
MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +from megengine.module import LSTM, RNN, LSTMCell, RNNCell + + +def assert_tuple_equal(src, ref): + assert len(src) == len(ref) + for i, j in zip(src, ref): + assert i == j + + +@pytest.mark.parametrize( + "batch_size, input_size, hidden_size, init_hidden", + [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False), (0, 10, 20, True)], +) +def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden): + rnn_cell = RNNCell(input_size, hidden_size) + x = mge.random.normal(size=(batch_size, input_size)) + if init_hidden: + h = F.zeros(shape=(batch_size, hidden_size)) + else: + h = None + h_new = rnn_cell(x, h) + assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) + + +# is batch_size == 0 tolerated ? it will cause error in slice operation xx[:, ...] +@pytest.mark.parametrize( + "batch_size, input_size, hidden_size, init_hidden", + [(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)], +) +def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden): + rnn_cell = LSTMCell(input_size, hidden_size) + x = mge.random.normal(size=(batch_size, input_size)) + if init_hidden: + h = F.zeros(shape=(batch_size, hidden_size)) + hx = (h, h) + else: + hx = None + h_new, c_new = rnn_cell(x, hx) + assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) + assert_tuple_equal(c_new.shape, (batch_size, hidden_size)) + + +@pytest.mark.parametrize( + "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", + [ + (3, 6, 10, 20, 2, False, False, True), + pytest.param( + 3, + 3, + 10, + 10, + 1, + True, + True, + False, + marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), + ), + ], +) +# (0, 1, 1, 1, 1, False, True, False)]) +def test_rnn( + batch_size, + seq_len, + input_size, + hidden_size, + num_layers, + bidirectional, + init_hidden, + batch_first, +): + rnn = RNN( + input_size, + hidden_size, + batch_first=batch_first, + num_layers=num_layers, + bidirectional=bidirectional, + ) + if batch_first: + x_shape = (batch_size, seq_len, input_size) + else: + x_shape = (seq_len, batch_size, input_size) + x = mge.random.normal(size=x_shape) + total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size + if init_hidden: + h = mge.random.normal(size=(batch_size, total_hidden_size)) + else: + h = None + output, h_n = rnn(x, h) + num_directions = 2 if bidirectional else 1 + if batch_first: + assert_tuple_equal( + output.shape, (batch_size, seq_len, num_directions * hidden_size) + ) + else: + assert_tuple_equal( + output.shape, (seq_len, batch_size, num_directions * hidden_size) + ) + assert_tuple_equal( + h_n.shape, (num_directions * num_layers, batch_size, hidden_size) + ) + + +@pytest.mark.parametrize( + "batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", + [ + (3, 10, 20, 20, 1, False, False, True), + pytest.param( + 3, + 3, + 10, + 10, + 1, + True, + True, + False, + marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), + ), + ], +) +# (0, 1, 1, 1, 1, False, True, False)]) +def test_lstm( + batch_size, + 
seq_len, + input_size, + hidden_size, + num_layers, + bidirectional, + init_hidden, + batch_first, +): + rnn = LSTM( + input_size, + hidden_size, + batch_first=batch_first, + num_layers=num_layers, + bidirectional=bidirectional, + ) + if batch_first: + x_shape = (batch_size, seq_len, input_size) + else: + x_shape = (seq_len, batch_size, input_size) + x = mge.random.normal(size=x_shape) + total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size + if init_hidden: + h = mge.random.normal(size=(batch_size, total_hidden_size)) + h = (h, h) + else: + h = None + output, h_n = rnn(x, h) + num_directions = 2 if bidirectional else 1 + if batch_first: + assert_tuple_equal( + output.shape, (batch_size, seq_len, num_directions * hidden_size) + ) + else: + assert_tuple_equal( + output.shape, (seq_len, batch_size, num_directions * hidden_size) + ) + assert_tuple_equal( + h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size) + ) + assert_tuple_equal( + h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size) + ) + + +if __name__ == "__main__": + test_lstm(5, 10, 10, 20, 1, False, False, True) diff --git a/imperative/src/impl/ops/rnn.cpp b/imperative/src/impl/ops/rnn.cpp new file mode 100644 index 00000000..b4e9ec80 --- /dev/null +++ b/imperative/src/impl/ops/rnn.cpp @@ -0,0 +1,68 @@ +/** + * \file imperative/src/impl/ops/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/opr/dnn/rnn.h" +#include "megbrain/imperative/ops/autogen.h" + +#include "../op_trait.h" + +namespace mgb::imperative { + +namespace rnn_cell { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 6); + return opr::RNNCell::make( + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], + op.param()); +} +OP_TRAIT_REG(RNNCell, RNNCell).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace rnn_cell + +namespace lstm_cell { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 7); + auto* opr = opr::LSTMCell::make( + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], + inputs[5], inputs[6], op.param()) + .node() + ->owner_opr(); + return {opr->output(0), opr->output(1), opr->output(2)}; +} +OP_TRAIT_REG(LSTMCell, LSTMCell).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace lstm_cell + +namespace rnn { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + auto* opr = opr::RNN::make(inputs[0], inputs[1], inputs[2], op.param()) + .node() + ->owner_opr(); + return {opr->output(0), opr->output(1), opr->output(2)}; +} +OP_TRAIT_REG(RNN, RNN).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace rnn + +namespace lstm { +VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 4); + auto* opr = opr::LSTM::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param()) + .node() + ->owner_opr(); + return {opr->output(0), opr->output(1), opr->output(2), opr->output(3)}; +} 
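+// For reference, the trait mirrors LSTMForward::make: inputs are
+// {input, hx, cx, flatten_weights} and the four outputs are
+// {y, hy, cy, reserve_space}. On the Python side (module/rnn.py in this
+// patch), LSTM.apply_op calls
+//     apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
+// and drops the reserve-space output.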
+OP_TRAIT_REG(LSTM, LSTM).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace lstm + +} // namespace mgb::imperative diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 233c99f3..bb6eb2fa 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -431,4 +431,12 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; def LRN: MgbHashableOp<"LRN", [LRNParam]>; +def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; + +def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; + +def RNN: MgbHashableOp<"RNN", [RNNParam]>; + +def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>; + #endif // MGB_OPS diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 4455bceb..b4729edb 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -20,6 +20,7 @@ #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lsq.h" #include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/dnn/rnn.h" #include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/dnn/roi_pooling.h" #include "megbrain/opr/dnn/sliding_window_transpose.h" @@ -292,6 +293,36 @@ struct OprMaker { ->owner_opr(); } }; + +template <> +struct OprMaker { + using Param = opr::RNNBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::RNNBackward::make( + i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0] + .node() + ->owner_opr(); + } +}; + +template <> +struct OprMaker { + using Param = opr::LSTMBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::LSTMBackward::make( + i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param, + config)[0] + .node() + ->owner_opr(); + } +}; + template <> struct OprLoadDumpImpl : public GeneralOprLoadDumpImpl< @@ -641,6 +672,10 @@ MGB_SEREG_OPR(TQT, 2); MGB_SEREG_OPR(TQTBackward, 3); MGB_SEREG_OPR(LSQ, 4); MGB_SEREG_OPR(LSQBackward, 5); +MGB_SEREG_OPR(RNNForward, 3); +MGB_SEREG_OPR(RNNBackward, 7); +MGB_SEREG_OPR(LSTMForward, 4); +MGB_SEREG_OPR(LSTMBackward, 9); } // namespace opr } // namespace mgb diff --git a/src/opr/impl/dnn/rnn.cpp b/src/opr/impl/dnn/rnn.cpp new file mode 100644 index 00000000..0099cc3a --- /dev/null +++ b/src/opr/impl/dnn/rnn.cpp @@ -0,0 +1,323 @@ +/** + * \file src/opr/impl/dnn/rnn.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "megbrain/opr/dnn/rnn.h" +#include "../internal/megdnn_opr_wrapper.inl" +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/basic_arith_wrapper.h" +#include "megbrain/opr/blas.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" + +using namespace mgb; +using namespace opr; + +/* ================= RNNCell ================= */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNCellForward); +RNNCellForward::RNNCellForward( + VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, + VarNode* weight_hh, VarNode* bias_hh, const Param& param, + const OperatorNodeConfig& config) + : Super{input->owner_graph(), + config, + "rnn_cell", + {input, weight_ih, bias_ih, hx, weight_hh, bias_hh}} { + init_megdnn_opr(*this, param); + add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh}); +} + +SymbolVar RNNCellForward::make( + SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, + SymbolVar weight_hh, SymbolVar bias_hh, const Param& param, + const OperatorNodeConfig& config) { + return input.insert_single_output_opr( + input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), + bias_hh.node(), param, config); +} + +#if MGB_ENABLE_GRAD + +VarNode* rnnCellBackward( + const SymbolVar& input, const SymbolVar& weight_ih, const SymbolVar& hx, + const SymbolVar& weight_hh, const SymbolVar& out, + RNNCell::NonlineMode nonlineMode, size_t wrt_idx, const SymbolVar& og) { + SymbolVar tmp; + // activation + using NonlineMode = RNNCell::NonlineMode; + using Mode = Elemwise::Mode; + switch (nonlineMode) { + case NonlineMode::IDENTITY: + tmp = og; + break; + case NonlineMode::TANH: + tmp = Elemwise::make({out, og}, Mode::TANH_GRAD); + break; + case NonlineMode::RELU: + tmp = Elemwise::make({out, og}, Mode::SWITCH_GT0); + break; + default: + mgb_throw(GraphError, "Activation method not supported"); + } + // now grad is in tmp + if (wrt_idx == 2 || wrt_idx == 5) + return tmp.node(); // bias + + SymbolVar result; + // A * Bt = C, A' = C' * B, B' = C't * A + if (wrt_idx == 0) { // input + result = MatrixMul::make( + tmp, weight_ih, + {false, false}); // transpose a false, transpose b false + } else if (wrt_idx == 1) { // weight_ih + result = MatrixMul::make(tmp, input, {true, false}); + } else if (wrt_idx == 3) { // hx + result = MatrixMul::make(tmp, weight_hh, {false, false}); + } else if (wrt_idx == 4) { // weight_hh + result = MatrixMul::make(tmp, hx, {true, false}); + } + return result.node(); +} + +MGB_IMPL_OPR_GRAD(RNNCell) { + SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), + weight_hh(opr.input(4)); + SymbolVar out(opr.output(0)), og{out_grad.at(0)}; + return rnnCellBackward( + input, weight_ih, hx, weight_hh, out, opr.param().nonlineMode, wrt_idx, og); +} +#endif + +/* ================= LSTMCell ================= */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMCell); +LSTMCellForward::LSTMCellForward( + VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, + VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param, + const OperatorNodeConfig& config) + : Super{input->owner_graph(), + config, + "lstm_cell", + {input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}} { + init_megdnn_opr(*this, param); + add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}); +} + +SymbolVar LSTMCellForward::make( + SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, + SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, 
const Param& param, + const OperatorNodeConfig& config) { + return input.insert_single_output_opr( + input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), + bias_hh.node(), cx.node(), param, config); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(LSTMCell) { + SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), + weight_hh(opr.input(4)), cx(opr.input(6)); + SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)), + h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp; + size_t ghs = gates.shape()[1] / 4; // gate_hidden_size + SymbolVarArray gates_array = Split::make( + gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs})); + mgb_assert(gates_array.size() == 4); + using Mode = Elemwise::Mode; + const SymbolVar &i(Elemwise::make({gates_array.at(0)}, Mode::SIGMOID)), + f(Elemwise::make({gates_array.at(1)}, Mode::SIGMOID)), + o(Elemwise::make({gates_array.at(2)}, Mode::SIGMOID)), + g(Elemwise::make({gates_array.at(3)}, Mode::TANH)); + SymbolVar i_grad, f_grad, o_grad, g_grad; + + SymbolVar tanh_c_out = Elemwise::make({c_out}, Mode::TANH); + o_grad = Elemwise::make({o, h_og * tanh_c_out}, Mode::SIGMOID_GRAD); + c_og = c_og + Elemwise::make({tanh_c_out, h_og * o}, Mode::TANH_GRAD); + f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD); + i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD); + g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD); + SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1}); + + SymbolVar result; + if (wrt_idx < 6) { + using NonlineMode = RNNCell::NonlineMode; + result = rnnCellBackward( + input, weight_ih, hx, weight_hh, gates, NonlineMode::IDENTITY, wrt_idx, + rnn_cell_grad); + } else { // cx + result = c_og * f; + } + return result.node(); +} +#endif + +/* ================= RNN ================= */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNN); +MEGDNN_OPR_INIT3(RNNForward, "rnn_fwd"); + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(RNN) { + mgb_assert( + opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING, + "RNN could only take grad in training mode"); + SymbolVarArray grads = RNNBackward::make( + opr.input(0), opr.output(0), opr.input(1), out_grad.at(0), out_grad.at(1), + opr.input(2), opr.output(2), opr.param()); + // return grads.at(wrt_idx).node(); // input, hx, weights + VarNodeArray ret(opr.input().size(), nullptr); + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = grads[i].node(); + } + return ret; +} +#endif + +/* ================= RNNBackward ================= */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward); + +RNNBackward::RNNBackward( + VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy, + VarNode* flatten_weights, VarNode* reserve_space, const Param& param, + const OperatorNodeConfig& config) + : Super({x->owner_graph(), + config, + "rnn_bwd", + {x, y, hx, dy, dhy, flatten_weights, reserve_space}}, + 0, true) { + init_megdnn_opr(*this, param); + add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space}); +} + +SymbolVarArray RNNBackward::make( + SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy, + SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param, + const OperatorNodeConfig& config) { + auto&& out = x.node()->owner_graph() + ->insert_opr(std::make_unique( + x.node(), y.node(), hx.node(), dy.node(), dhy.node(), + flatten_weights.node(), reserve_space.node(), param, + config)) + ->output(); + SymbolVarArray ret(out.size()); + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = out[i]; + } + return 
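+    // ret copies every output var of the opr; MGB_IMPL_OPR_GRAD(RNN) above
+    // reads indices 0..2 as the gradients w.r.t. {input, hx, flatten_weights}.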
+MGB_IMPL_OPR_GRAD(RNN) {
+    mgb_assert(
+            opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING,
+            "RNN can only take grad in training mode");
+    SymbolVarArray grads = RNNBackward::make(
+            opr.input(0), opr.output(0), opr.input(1), out_grad.at(0), out_grad.at(1),
+            opr.input(2), opr.output(2), opr.param());
+    // return grads.at(wrt_idx).node(); // input, hx, weights
+    VarNodeArray ret(opr.input().size(), nullptr);
+    for (size_t i = 0; i < ret.size(); ++i) {
+        ret[i] = grads[i].node();
+    }
+    return ret;
+}
+#endif
+
+/* ================= RNNBackward ================= */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward);
+
+RNNBackward::RNNBackward(
+        VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
+        VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({x->owner_graph(),
+                 config,
+                 "rnn_bwd",
+                 {x, y, hx, dy, dhy, flatten_weights, reserve_space}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space});
+}
+
+SymbolVarArray RNNBackward::make(
+        SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
+        SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param,
+        const OperatorNodeConfig& config) {
+    auto&& out = x.node()->owner_graph()
+                         ->insert_opr(std::make_unique<RNNBackward>(
+                                 x.node(), y.node(), hx.node(), dy.node(), dhy.node(),
+                                 flatten_weights.node(), reserve_space.node(), param,
+                                 config))
+                         ->output();
+    SymbolVarArray ret(out.size());
+    for (size_t i = 0; i < ret.size(); ++i) {
+        ret[i] = out[i];
+    }
+    return ret;
+}
+
+RNNBackward::Super::NodeProp* RNNBackward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(6), NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void RNNBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
+    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(5)));
+    this->init_output_static_infer_desc_workspace(
+            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::RNNBackward>::val);
+}
+
+void RNNBackward::init_output_dtype() {
+    output(0)->dtype(input(0)->dtype());
+    output(1)->dtype(input(2)->dtype());
+    output(2)->dtype(input(5)->dtype());
+}
+
+/* ================= LSTM ================= */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTM);
+LSTMForward::LSTMForward(
+        VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super{input->owner_graph(),
+                config,
+                "lstm",
+                {input, hx, cx, flatten_weights}} {
+    init_megdnn_opr(*this, param);
+    add_input({input, hx, cx, flatten_weights});
+}
+
+SymbolVar LSTMForward::make(
+        SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
+        const Param& param, const OperatorNodeConfig& config) {
+    return input.insert_single_output_opr<LSTMForward>(
+            input.node(), hx.node(), cx.node(), flatten_weights.node(), param, config);
+}
+
+#if MGB_ENABLE_GRAD
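+// Gradient of LSTM: LSTMBackward consumes the reserve_space saved by the
+// forward opr (output 3) and yields the gradients for
+// (input, hx, cx, flatten_weights), selected here by wrt_idx.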
+MGB_IMPL_OPR_GRAD(LSTM) {
+    SymbolVarArray grads = LSTMBackward::make(
+            opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0),
+            out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3), opr.param());
+    return grads.at(wrt_idx).node();  // input, hx, cx, weights
+}
+#endif
+
+/* ================= LSTMBackward ================= */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMBackward);
+LSTMBackward::LSTMBackward(
+        VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
+        VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super({x->owner_graph(),
+                 config,
+                 "lstm_bwd",
+                 {x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}},
+                1, true) {
+    init_megdnn_opr(*this, param);
+    add_input({x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space});
+}
+
+SymbolVarArray LSTMBackward::make(
+        SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
+        SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
+        SymbolVar reserve_space, const Param& param, const OperatorNodeConfig& config) {
+    auto&& out = x.node()->owner_graph()
+                         ->insert_opr(std::make_unique<LSTMBackward>(
+                                 x.node(), y.node(), hx.node(), cx.node(), dy.node(),
+                                 dhy.node(), dcy.node(), flatten_weights.node(),
+                                 reserve_space.node(), param, config))
+                         ->output();
+    SymbolVarArray ret(out.size());
+    for (size_t i = 0; i < ret.size(); ++i) {
+        ret[i] = out[i];
+    }
+    return ret;
+}
+
+LSTMBackward::Super::NodeProp* LSTMBackward::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(
+            input(8),  // reserve space
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void LSTMBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
+    mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
+    mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3)));
+    mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(7)));
+    this->init_output_static_infer_desc_workspace(
+            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSTMBackward>::val);
+}
+
+void LSTMBackward::init_output_dtype() {
+    output(0)->dtype(input(0)->dtype());
+    output(1)->dtype(input(2)->dtype());
+    output(2)->dtype(input(3)->dtype());
+    output(3)->dtype(input(7)->dtype());
+}
\ No newline at end of file
diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl
index 23b28f1f..724bf0f7 100644
--- a/src/opr/impl/internal/megdnn_opr_wrapper.inl
+++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl
@@ -170,6 +170,16 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker
+MGB_DEFINE_OPR_CLASS(
+        RNNCellForward, intl::MegDNNOprWrapperFwd<megdnn::RNNCellForward>) // {
+public:
+    using NonlineMode = Param::NonlineMode;
+
+    RNNCellForward(
+            VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
+            VarNode* weight_hh, VarNode* bias_hh, const Param& param,
+            const OperatorNodeConfig& config);
+    static SymbolVar make(
+            SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
+            SymbolVar weight_hh, SymbolVar bias_hh, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+};
+using RNNCell = RNNCellForward;
+
+MGB_DEFINE_OPR_CLASS(
+        LSTMCellForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMCellForward>) // {
+public:
+    LSTMCellForward(
+            VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
+            VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
+            const OperatorNodeConfig& config);
+    static SymbolVar make(
+            SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
+            SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+};
+using LSTMCell = LSTMCellForward;
+
+MGB_DEFINE_OPR_CLASS(
+        RNNForward, intl::MegDNNOprWrapperFwd<megdnn::RNNForward>) // {
+    /*private:
+    SymbolVarArray weight_ih_arr;  // 1d, idx: direction * num_layers + layer
+    SymbolVarArray weight_hh_arr;
+    SymbolVarArray bias_arr;
+    */
+
+public:
+    RNNForward(
+            VarNode* input, VarNode* hx, VarNode* flatten_weights, const Param& param,
+            const OperatorNodeConfig& config);
+    static SymbolVar make(
+            SymbolVar input, SymbolVar hx, SymbolVar flatten_weights,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+};
+using RNN = RNNForward;
+
+MGB_DEFINE_OPR_CLASS(
+        RNNBackward, intl::MegDNNOprWrapperBwd<megdnn::RNNBackward>) // {
+public:
+    RNNBackward(
+            VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
+            VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
+            const OperatorNodeConfig& config);
+    static SymbolVarArray make(
+            SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
+            SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    Super::NodeProp* do_make_node_prop() const override;
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+};
+
+MGB_DEFINE_OPR_CLASS(
+        LSTMForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMForward>) // {
+public:
+    LSTMForward(
+            VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
+            const Param& param, const OperatorNodeConfig& config);
+    static SymbolVar make(
+            SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+};
+using LSTM = LSTMForward;
+
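+// Backward opr paired with LSTMForward; it is consumed by the LSTM gradient
+// in rnn.cpp, whose outputs are used as d_input, d_hx, d_cx and
+// d_flatten_weights.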
+MGB_DEFINE_OPR_CLASS(
+        LSTMBackward, intl::MegDNNOprWrapperBwd<megdnn::LSTMBackward>) // {
+public:
+    LSTMBackward(
+            VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
+            VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
+            const Param& param, const OperatorNodeConfig& config);
+    static SymbolVarArray make(
+            SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
+            SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
+            SymbolVar reserve_space, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    Super::NodeProp* do_make_node_prop() const override;
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+};
+
+} // namespace opr
+} // namespace mgb
\ No newline at end of file
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index f91477e6..27e2a0b0 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -116,6 +116,9 @@ union OperatorParam {
     param.Padding = 82,
     param.ShuffleRNG = 83,
     param.CheckNonFinite = 84,
+    param.RNNCell = 85,
+    param.RNN = 86,
+    param.LSTM = 87,
 }
 
 table Operator {