@@ -1,7 +1,3 @@ | |||
# Build | |||
build/ | |||
output/ | |||
# Cache | |||
__pycache__/ | |||
.ccls-cache/ | |||
@@ -11,5 +7,8 @@ __pycache__/ | |||
.vs/ | |||
.idea/ | |||
# CMake | |||
compile_commands.json | |||
# Make and Build Settings | |||
build/ | |||
output/ | |||
compile_commands.json | |||
imperative/python/megengine/core/*.so |
@@ -1936,6 +1936,198 @@ protected: | |||
const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
}; | |||
class RNNCellForward : public OperatorBase { | |||
DEF_OPR_PARAM(RNNCell); | |||
DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
static void deduce_layout( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
using RNNCell = RNNCellForward; | |||
class LSTMCellForward : public OperatorBase { | |||
// DEF_OPR_PARAM(LSTMCell); | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new, | |||
TensorLayout& gates); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, | |||
const TensorLayout& c_new, const TensorLayout& gates) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, | |||
const TensorLayout& c_new, const TensorLayout& gates, | |||
size_t workspace_in_bytes); | |||
}; | |||
using LSTMCell = LSTMCellForward; | |||
class RNNForward : public OperatorBase { | |||
DEF_OPR_PARAM(RNN); | |||
DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, | |||
TensorLayout& reserve_space); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space) = 0; | |||
virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space, | |||
size_t workspace_in_bytes); | |||
}; | |||
using RNN = RNNForward; | |||
class RNNBackward : public OperatorBase { | |||
DEF_OPR_PARAM(RNN); | |||
DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, | |||
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, | |||
const TensorLayout& dw) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw, | |||
size_t workspace_in_bytes); | |||
}; | |||
class LSTMForward : public OperatorBase { | |||
DEF_OPR_PARAM(LSTM); | |||
DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, | |||
TensorLayout& cy, TensorLayout& reserve_space); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space) = 0; | |||
virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space, size_t workspace_in_bytes); | |||
}; | |||
using LSTM = LSTMForward; | |||
class LSTMBackward : public OperatorBase { | |||
DEF_OPR_PARAM(LSTM); | |||
DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4); | |||
public: | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, | |||
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx, | |||
TensorLayout& dcx, TensorLayout& dw); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, | |||
const TensorLayout& dw) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, | |||
size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
@@ -1203,3 +1203,28 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] | |||
) | |||
) | |||
(pdef('RNNCell'). | |||
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2') | |||
) | |||
(pdef('RNN'). | |||
add_fields('uint32', 'num_layers', '1'). | |||
add_fields('bool', 'bidirectional', 'false'). | |||
add_fields('bool', 'bias', 'true'). | |||
add_fields('uint32', 'hidden_size', '128'). | |||
add_fields('uint32', 'proj_size', '0'). | |||
add_fields('float32', 'dropout', '0.f'). | |||
add_enum_alias('NonlineMode', 'RNNCell'). | |||
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | |||
) | |||
(pdef('LSTM'). | |||
add_fields('uint32', 'num_layers', '1'). | |||
add_fields('bool', 'bidirectional', 'false'). | |||
add_fields('bool', 'bias', 'true'). | |||
add_fields('uint32', 'hidden_size', '128'). | |||
add_fields('uint32', 'proj_size', '0'). | |||
add_fields('float32', 'dropout', '0.f'). | |||
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode') | |||
) |
@@ -92,8 +92,7 @@ std::unique_ptr<Handle> Handle::make( | |||
} | |||
MIDOUT_END(); | |||
} | |||
else if (platform == megcorePlatformROCM) { | |||
} else if (platform == megcorePlatformROCM) { | |||
#if MEGDNN_WITH_ROCM | |||
return make_rocm_handle(computing_handle); | |||
#else | |||
@@ -111,8 +110,7 @@ std::unique_ptr<Handle> Handle::make( | |||
#else | |||
return nullptr; | |||
#endif | |||
} | |||
else { | |||
} else { | |||
// CUDA | |||
megdnn_throw_if( | |||
platform != megcorePlatformCUDA, megdnn_error, | |||
@@ -209,7 +209,13 @@ private: | |||
cb(LSQBackward) \ | |||
cb(Fill) \ | |||
cb(PaddingForward) \ | |||
cb(PaddingBackward) | |||
cb(PaddingBackward) \ | |||
cb(RNNCell) \ | |||
cb(LSTMCell) \ | |||
cb(RNN) \ | |||
cb(RNNBackward) \ | |||
cb(LSTM) \ | |||
cb(LSTMBackward) | |||
// clang-format on | |||
/*! | |||
@@ -0,0 +1,98 @@ | |||
/** | |||
* \file dnn/src/common/lstm.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
// #include "src/cuda/lstm/utils.h" | |||
namespace megdnn { | |||
/*size_t get_reserve_size(Handle* handle, megdnn::LSTMForward::Param& param, const | |||
TensorLayout& input) { #if CUDNN_MAJOR >= 6 auto holder = | |||
megdnn::cuda::lstm::get_RNNDescHolder_v6(handle, param, input); return | |||
holder.reserveSpace_size; # else return 0; #endif | |||
}*/ | |||
void LSTM::deduce_layout( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, | |||
TensorLayout& cy, TensorLayout& reserve_space) { | |||
// input: [seq_len, batch_size, input_size] | |||
// hx: [D * num_layers, batch_size, hidden_size] | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t hidden_size = hx.shape[2]; | |||
output = TensorLayout( | |||
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype); | |||
hy = TensorLayout(hx); | |||
cy = TensorLayout(cx); | |||
// reserve_space = {{get_reserve_size(this->handle(), param(), input)}, | |||
// dtype::Byte()}; | |||
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()}; | |||
} | |||
void LSTM::check_exec( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space, size_t workspace_in_bytes) { | |||
auto errmsg = [&]() { | |||
std::string msg; | |||
msg.append("input="); | |||
msg.append(input.to_string()); | |||
msg.append(", hx="); | |||
msg.append(hx.to_string()); | |||
msg.append(", cx="); | |||
msg.append(cx.to_string()); | |||
msg.append(", hidden_size="); | |||
msg.append(std::to_string(param().hidden_size)); | |||
msg.append(", num_layers="); | |||
msg.append(std::to_string(param().num_layers)); | |||
msg.append(", bidirectional="); | |||
msg.append(std::to_string(param().bidirectional)); | |||
return msg; | |||
}; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t num_layers = param().num_layers; | |||
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str()); | |||
ASSERT_BRIEF(hx.ndim == 3) | |||
ASSERT_BRIEF(hx.shape[0] == D * num_layers) | |||
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size | |||
ASSERT_BRIEF(hx.shape[2] == param().hidden_size) | |||
ASSERT_BRIEF(cx.ndim == 3) | |||
ASSERT_BRIEF(cx.shape[0] == D * num_layers) | |||
ASSERT_BRIEF(cx.shape[1] == input.shape[1]) // batch_size | |||
ASSERT_BRIEF(cx.shape[2] == param().hidden_size) | |||
#undef ASSERT_BRIEF | |||
} | |||
void LSTMBackward::deduce_layout( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx, | |||
TensorLayout& dcx, TensorLayout& dw) { | |||
dx = x; | |||
dhx = hx; | |||
dcx = cx; | |||
dw = flatten_weights; | |||
} | |||
void LSTMBackward::check_exec( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, | |||
size_t workspace_in_bytes) {} | |||
} // namespace megdnn |
@@ -0,0 +1,136 @@ | |||
/** | |||
* \file dnn/src/common/lstm_cell.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/lstm_cell.h" | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void LSTMCell::deduce_layout( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new, | |||
TensorLayout& gates) { | |||
// size_t batch_size = hx.shape[0]; | |||
// size_t hidden_size = hx.shape[1]; | |||
h_new = TensorLayout(hx, hx.dtype); | |||
c_new = TensorLayout(cx, cx.dtype); | |||
auto opr = handle()->create_operator<RNNCellForward>(); | |||
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; | |||
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates); | |||
} | |||
void LSTMCell::check_exec( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
const TensorLayout& gates, size_t workspace_in_bytes) { | |||
TensorLayout h_new_expected, c_new_expected, gates_expected; | |||
deduce_layout( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected, | |||
c_new_expected, gates_expected); | |||
megdnn_assert_eq_layout(h_new_expected, h_new); | |||
megdnn_assert_eq_layout(c_new_expected, c_new); | |||
megdnn_assert_eq_layout(gates_expected, gates); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
namespace megdnn { | |||
namespace lstm_cell { | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
const TensorLayout& gates, Handle* handle) { | |||
TensorLayout tmp_layout; | |||
auto opr = handle->create_operator<RNNCellForward>(); | |||
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; | |||
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout); | |||
size_t rnn_cell_need = opr->get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates); | |||
size_t lstm_cell_need = tmp_layout.span().dist_byte(); | |||
return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need; | |||
} | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) { | |||
auto opr = handle->create_operator<RNNCellForward>(); | |||
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY; | |||
/*TensorLayout tmp_layout; | |||
opr->deduce_layout(input.layout, weight_ih.layout, | |||
hx.layout, weight_hh.layout, | |||
bias.layout, tmp_layout); | |||
auto workspace_ptr = workspace.raw_ptr; | |||
// TensorND tmp{static_cast<void*>(workspace.raw_ptr), tmp_layout}; | |||
TensorND tmp{workspace_ptr, tmp_layout}; | |||
auto new_workspace = Workspace{workspace_ptr + tmp.layout.span().dist_byte(), | |||
workspace.size - | |||
tmp.layout.span().dist_byte()};*/ | |||
// opr->exec(input, weight_ih, hx, weight_hh, bias, tmp, new_workspace); | |||
opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace); | |||
// activation | |||
// size_t batch_size = tmp.layout.shape[0]; | |||
size_t batch_size = hx.layout.shape[0]; | |||
size_t hidden_size = hx.layout.shape[1]; | |||
// sigmoid: i f o | |||
// TensorLayout gates_ifo_layout{TensorShape({batch_size, hidden_size * 3}), | |||
// tmp.layout.dtype}; | |||
TensorND tmp{static_cast<void*>(workspace.raw_ptr), gates.layout}; | |||
TensorLayout gates_ifo_layout{ | |||
TensorShape({batch_size, hidden_size * 3}), gates.layout.dtype}; | |||
TensorND gates_ifo_origin{gates.raw_ptr(), gates_ifo_layout}; | |||
TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout}; | |||
auto sigmoid = handle->create_operator<ElemwiseForward>(); | |||
sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID; | |||
sigmoid->exec({gates_ifo_origin}, gates_ifo); | |||
// tanh: g | |||
TensorLayout g_layout{TensorShape({batch_size, hidden_size}), gates.layout.dtype}; | |||
TensorND g_origin{ | |||
static_cast<char*>(gates.raw_ptr()) + gates_ifo_layout.span().dist_byte(), | |||
g_layout}; | |||
TensorND g{ | |||
static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(), | |||
g_layout}; | |||
auto tanh = handle->create_operator<ElemwiseForward>(); | |||
tanh->param().mode = Elemwise::Param::Mode::TANH; | |||
tanh->exec({g_origin}, g); | |||
// extract i f o | |||
TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout}; | |||
TensorND f{ | |||
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout}; | |||
TensorND o{ | |||
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte() * 2, | |||
g_layout}; | |||
// calculate new cell state | |||
auto elewise_mul_add = handle->create_operator<ElemwiseForward>(); | |||
elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4; | |||
elewise_mul_add->exec({f, cx, i, g}, c_new); | |||
// calculate new hidden state | |||
tanh->exec({c_new}, h_new); | |||
auto elewise_mul = handle->create_operator<ElemwiseForward>(); | |||
elewise_mul->param().mode = Elemwise::Param::Mode::MUL; | |||
elewise_mul->exec({o, h_new}, h_new); | |||
} | |||
} // namespace lstm_cell | |||
} // namespace megdnn |
@@ -0,0 +1,32 @@ | |||
/** | |||
* \file dnn/src/common/lstm_cell.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "megdnn/oprs/base.h" | |||
namespace megdnn { | |||
namespace lstm_cell { | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
const TensorLayout& gates, Handle* handle); | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle); | |||
} // namespace lstm_cell | |||
} // namespace megdnn |
@@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true); | |||
DEF(LSQForward, 5, true, true); | |||
DEF(LSQBackward, 7, true, false); | |||
DEF(Fill, 1, true, false); | |||
DEF(RNNCellForward, 6, true, true); | |||
DEF(RNNForward, 6, true, true); | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -402,9 +402,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { | |||
} | |||
if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 && | |||
( | |||
handle()->type() != Handle::HandleType::NAIVE && | |||
handle()->type() != Handle::HandleType::X86)) { | |||
(handle()->type() != Handle::HandleType::NAIVE && | |||
handle()->type() != Handle::HandleType::X86)) { | |||
megdnn_throw( | |||
"Dump with Image2DPack4TensorFormat is not available on CUDA compnode, " | |||
"try export CUDA_VISIBLE_DEVICES=\'\'"); | |||
@@ -0,0 +1,82 @@ | |||
/** | |||
* \file dnn/src/common/rnn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/rnn.h" | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void RNN::deduce_layout( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy, | |||
TensorLayout& reserve_space) { | |||
// input: [seq_len, batch_size, input_size] | |||
// hx: [D * num_layers, batch_size, hidden_size] | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t hidden_size = hx.shape[2]; | |||
output = TensorLayout( | |||
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype); | |||
hy = TensorLayout(hx); | |||
// reserve_space = {{get_reserve_size(this->handle(), param(), input)}, | |||
// dtype::Byte()}; | |||
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()}; | |||
} | |||
void RNN::check_exec( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space, | |||
size_t workspace_in_bytes) { | |||
auto errmsg = [&]() { | |||
std::string msg; | |||
msg.append("input="); | |||
msg.append(input.to_string()); | |||
msg.append(", hx="); | |||
msg.append(hx.to_string()); | |||
msg.append(", hidden_size="); | |||
msg.append(std::to_string(param().hidden_size)); | |||
msg.append(", num_layers="); | |||
msg.append(std::to_string(param().num_layers)); | |||
msg.append(", bidirectional="); | |||
msg.append(std::to_string(param().bidirectional)); | |||
return msg; | |||
}; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t num_layers = param().num_layers; | |||
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str()); | |||
ASSERT_BRIEF(hx.ndim == 3) | |||
ASSERT_BRIEF(hx.shape[0] == D * num_layers) | |||
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size | |||
ASSERT_BRIEF(hx.shape[2] == param().hidden_size) | |||
#undef ASSERT_BRIEF | |||
} | |||
void RNNBackward::deduce_layout( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) { | |||
dx = x; | |||
dhx = hx; | |||
dw = flatten_weights; | |||
} | |||
void RNNBackward::check_exec( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw, | |||
size_t workspace_in_bytes) {} | |||
} // namespace megdnn |
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/common/rnn.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/oprs/general.h" | |||
namespace megdnn { | |||
namespace rnn { | |||
using Param = param::RNN; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayoutArray& weight_ih, | |||
const TensorLayoutArray& states, const TensorLayoutArray& weight_hh, | |||
const TensorLayoutArray& bias, const TensorLayout& output, | |||
const TensorLayoutArray& states_new, const Param& param, Handle* handle); | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_in const TensorNDArray& weight_ih, | |||
_megdnn_in const TensorNDArray& states, | |||
_megdnn_in const TensorNDArray& weight_hh, _megdnn_in const TensorNDArray& bias, | |||
_megdnn_tensor_out output, _megdnn_out const TensorNDArray& states_new, | |||
_megdnn_workspace workspace, const Param& param, Handle* handle); | |||
} // namespace rnn | |||
} // namespace megdnn |
@@ -0,0 +1,108 @@ | |||
/** | |||
* \file dnn/src/common/rnn_cell.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/rnn_cell.h" | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void RNNCell::deduce_layout( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) { | |||
// megdnn_assert(hx.ndim == 2); | |||
size_t batch_size = hx.shape[0]; | |||
// size_t hidden_size = weight_hh.shape[1]; | |||
size_t gate_hidden_size = weight_ih.shape[0]; | |||
// size_t input_size = weight_ih.shape[1]; | |||
// megdnn_assert(input.shape[1] == input_size); | |||
// megdnn_assert(hx.shape[1] == hidden_size); | |||
// megdnn_assert_eq_dtype(input, hx); | |||
dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype); | |||
} | |||
void RNNCell::check_exec( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst, size_t workspace_in_bytes) { | |||
TensorLayout dst_expected; | |||
megdnn_assert_eq_dtype(input, dst); | |||
megdnn_assert_eq_dtype(hx, dst); | |||
deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected); | |||
megdnn_assert_eq_layout(dst_expected, dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
namespace megdnn { | |||
namespace rnn_cell { | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst, Handle* handle) { | |||
auto opr = handle->create_operator<MatrixMulForward>(); | |||
opr->param().transposeB = true; | |||
return dst.span().dist_byte() + opr->get_workspace_in_bytes(hx, weight_hh, dst); | |||
} | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace, | |||
param::RNNCell::NonlineMode nonline_mode, Handle* handle) { | |||
TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout}; | |||
_megdnn_workspace new_workspace = { | |||
workspace.raw_ptr + dst.layout.span().dist_byte(), | |||
workspace.size - dst.layout.span().dist_byte()}; | |||
auto opr = handle->create_operator<MatrixMulForward>(); | |||
opr->param().transposeB = true; | |||
opr->exec(input, weight_ih, tmp, new_workspace); | |||
opr->exec(hx, weight_hh, dst, new_workspace); | |||
// if (this->param().bias) add_bias(dst, tmp, bias, dst); | |||
// if (this->param().bias) { | |||
auto add_opr = handle->create_operator<ElemwiseForward>(); | |||
add_opr->param().mode = Elemwise::Param::Mode::ADD; | |||
add_opr->exec({dst, tmp}, dst); | |||
add_opr->exec({dst, bias_ih}, dst); | |||
add_opr->exec({dst, bias_hh}, dst); | |||
// } | |||
// activation | |||
using NonlineMode = param::RNNCell::NonlineMode; | |||
switch (nonline_mode) { | |||
#define cb(_mode) \ | |||
case NonlineMode::_mode: { \ | |||
auto nonlinear = handle->create_operator<ElemwiseForward>(); \ | |||
nonlinear->param().mode = Elemwise::Param::Mode::_mode; \ | |||
nonlinear->exec({dst}, dst); \ | |||
break; \ | |||
} | |||
cb(RELU); | |||
cb(TANH); | |||
#undef cb | |||
case NonlineMode::IDENTITY: | |||
break; | |||
default: | |||
megdnn_assert(false); | |||
} | |||
} | |||
} // namespace rnn_cell | |||
} // namespace megdnn |
@@ -0,0 +1,31 @@ | |||
/** | |||
* \file dnn/src/common/rnn_cell.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/oprs/general.h" | |||
namespace megdnn { | |||
namespace rnn_cell { | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst, Handle* handle); | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace, | |||
param::RNNCell::NonlineMode nonline_mode, Handle* handle); | |||
} // namespace rnn_cell | |||
} // namespace megdnn |
@@ -160,6 +160,29 @@ void TensorDesc::set( | |||
} | |||
} | |||
void TensorDesc::set_nd(const TensorLayout& layout, int pad) { | |||
int nbDims = layout.ndim < pad ? pad : layout.ndim; | |||
int dimA[nbDims], strideA[nbDims]; | |||
for (size_t i = 0; i < layout.ndim; ++i) { | |||
dimA[i] = layout.shape[i]; | |||
// strideA[i] = layout.stride[i]; | |||
} | |||
for (size_t i = layout.ndim; i < nbDims; ++i) { | |||
dimA[i] = 1; // unused | |||
// strideA[i] = 1; | |||
} | |||
// stride | |||
for (size_t i = 0; i < nbDims; ++i) { | |||
strideA[i] = 1; | |||
for (size_t j = i + 1; j < nbDims; ++j) { | |||
strideA[i] *= dimA[j]; | |||
} | |||
} | |||
cudnn_check(cudnnSetTensorNdDescriptor( | |||
desc, to_cudnn_dtype(layout.dtype), nbDims, dimA, strideA)); | |||
} | |||
std::string TensorDesc::to_string() { | |||
cudnnDataType_t data_type; | |||
int n; | |||
@@ -433,6 +456,97 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||
desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT)); | |||
} | |||
DropoutDesc::DropoutDesc() { | |||
cudnn_check(cudnnCreateDropoutDescriptor(&desc)); | |||
} | |||
DropoutDesc::~DropoutDesc() { | |||
cudnn_check(cudnnDestroyDropoutDescriptor(desc)); | |||
} | |||
void DropoutDesc::set(float dropout, Handle* handle, TensorND& state) { | |||
cudnn_check(cudnnSetDropoutDescriptor( | |||
desc, cudnn_handle(handle), dropout, state.raw_ptr(), | |||
state.layout.span().dist_byte(), 0 // seed | |||
)); | |||
} | |||
void DropoutDesc::set_no_dropout(Handle* handle) { | |||
cudnn_check( | |||
cudnnSetDropoutDescriptor(desc, cudnn_handle(handle), 0, nullptr, 0, 0)); | |||
} | |||
RNNDesc::RNNDesc() { | |||
cudnn_check(cudnnCreateRNNDescriptor(&desc)); | |||
} | |||
RNNDesc::~RNNDesc() { | |||
cudnn_check(cudnnDestroyRNNDescriptor(desc)); | |||
} | |||
void RNNDesc::set( | |||
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, | |||
bool bidirectional, bool bias, const megdnn::DType dtype, cudnnRNNMode_t mode, | |||
DropoutDesc& dropout_desc, Handle* handle) { | |||
cudnnRNNMode_t rnn_mode = mode; | |||
cudnnRNNBiasMode_t bias_mode = bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; | |||
cudnnDirectionMode_t dir_mode = | |||
bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; | |||
cudnnDataType_t math_prec; | |||
// math precision | |||
if (dtype.enumv() == DTypeEnum::Float16) | |||
math_prec = CUDNN_DATA_HALF; | |||
else | |||
math_prec = CUDNN_DATA_FLOAT; | |||
#if false // CUDNN_MAJOR >= 8 | |||
cudnn_check(cudnnSetRNNDescriptor_v8( | |||
desc, CUDNN_RNN_ALGO_STANDARD, mode, bias_mode, dir_mode, | |||
CUDNN_LINEAR_INPUT, to_cudnn_dtype(dtype), math_prec, CUDNN_DEFAULT_MATH, | |||
input_size, hidden_size, proj_size, num_layers, dropout_desc.desc, | |||
CUDNN_RNN_PADDED_IO_DISABLED)); | |||
#else | |||
cudnn_check(cudnnSetRNNDescriptor_v6( | |||
cudnn_handle(handle), desc, hidden_size, num_layers, dropout_desc.desc, | |||
CUDNN_LINEAR_INPUT, dir_mode, mode, CUDNN_RNN_ALGO_STANDARD, math_prec)); | |||
#endif | |||
} | |||
RNNDataDesc::RNNDataDesc() { | |||
cudnn_check(cudnnCreateRNNDataDescriptor(&desc)); | |||
} | |||
RNNDataDesc::~RNNDataDesc() { | |||
cudnn_check(cudnnDestroyRNNDataDescriptor(desc)); | |||
} | |||
void RNNDataDesc::set( | |||
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, | |||
DType dtype) { | |||
// for now, all tensor are padded in python | |||
// int seqLengthArray[batchSize]; | |||
// for (int i = 0; i < batchSize; ++i) seqLengthArray[i] = maxSeqLength; | |||
cudnn_check(cudnnSetRNNDataDescriptor( | |||
desc, to_cudnn_dtype(dtype), CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED, | |||
maxSeqLength, batchSize, vectorSize, devSeqLengths, nullptr)); | |||
} | |||
RNNWeightFilterDesc::RNNWeightFilterDesc() { | |||
cudnn_check(cudnnCreateFilterDescriptor(&desc)); | |||
} | |||
RNNWeightFilterDesc::~RNNWeightFilterDesc() { | |||
cudnn_check(cudnnDestroyFilterDescriptor(desc)); | |||
} | |||
void RNNWeightFilterDesc::set(const TensorLayout& flatten_weights) { | |||
int weight_elem_num = flatten_weights.total_nr_elems(); | |||
int dimW[] = {weight_elem_num, 1, 1}; | |||
cudnn_check(cudnnSetFilterNdDescriptor( | |||
desc, to_cudnn_dtype(flatten_weights.dtype), CUDNN_TENSOR_NCHW, 3, dimW)); | |||
} | |||
////////////////////////// CudnnAlgoPack ////////////////////////// | |||
#define V1(v) #v | |||
@@ -30,6 +30,7 @@ public: | |||
void set( | |||
const TensorLayout& layout, | |||
const param::Convolution::Format = param::Convolution::Format::NCHW); | |||
void set_nd(const TensorLayout& layout, int pad = 3); // at least 3 dimensions | |||
std::string to_string(); | |||
~TensorDesc(); | |||
cudnnTensorDescriptor_t desc; | |||
@@ -121,6 +122,44 @@ public: | |||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> conv3d_fwd_algos(); | |||
}; | |||
class DropoutDesc { | |||
public: | |||
DropoutDesc(); | |||
void set(float dropout, Handle* handle, TensorND& state); | |||
void set_no_dropout(Handle* handle); | |||
~DropoutDesc(); | |||
cudnnDropoutDescriptor_t desc; | |||
}; | |||
class RNNDesc { | |||
public: | |||
RNNDesc(); | |||
void set( | |||
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers, | |||
bool bidirectional, bool bias, const megdnn::DType dtype, | |||
cudnnRNNMode_t mode, DropoutDesc& dropout_desc, Handle* handle); | |||
~RNNDesc(); | |||
cudnnRNNDescriptor_t desc; | |||
}; | |||
class RNNDataDesc { | |||
public: | |||
RNNDataDesc(); | |||
void set( | |||
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths, | |||
DType dtype); | |||
~RNNDataDesc(); | |||
cudnnRNNDataDescriptor_t desc; | |||
}; | |||
class RNNWeightFilterDesc { | |||
public: | |||
RNNWeightFilterDesc(); | |||
void set(const TensorLayout& flatten_weights); | |||
~RNNWeightFilterDesc(); | |||
cudnnFilterDescriptor_t desc; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -50,6 +50,8 @@ | |||
#include "src/cuda/local_share/opr_impl.h" | |||
#include "src/cuda/lrn/opr_impl.h" | |||
#include "src/cuda/lsq/opr_impl.h" | |||
#include "src/cuda/lstm/opr_impl.h" | |||
#include "src/cuda/lstm_cell/opr_impl.h" | |||
#include "src/cuda/mask_conv/opr_impl.h" | |||
#include "src/cuda/matrix_inverse/opr_impl.h" | |||
#include "src/cuda/matrix_mul/opr_impl.h" | |||
@@ -66,6 +68,8 @@ | |||
#include "src/cuda/repeat/opr_impl.h" | |||
#include "src/cuda/resize/opr_impl.h" | |||
#include "src/cuda/rng/opr_impl.h" | |||
#include "src/cuda/rnn/opr_impl.h" | |||
#include "src/cuda/rnn_cell/opr_impl.h" | |||
#include "src/cuda/roi_align/opr_impl.h" | |||
#include "src/cuda/roi_copy/opr_impl.h" | |||
#include "src/cuda/roi_pooling/opr_impl.h" | |||
@@ -0,0 +1,112 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/lstm/opr_impl.h" | |||
#include "src/cuda/lstm/utils.h" | |||
#include "src/cuda/utils.h" | |||
#include <cudnn.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
void LSTMImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace) { | |||
Handle* handle = this->handle(); | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
lstm::get_RNNDescHolder_v6(this->handle(), param(), input.layout); | |||
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); | |||
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); | |||
RNNWeightFilterDesc w_desc; | |||
w_desc.set(flatten_weights.layout); | |||
if (param().fwd_mode == param::LSTM::FwdMode::TRAINING) { | |||
cudnn_check(cudnnRNNForwardTraining( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||
hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc, | |||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, | |||
reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||
} else { | |||
cudnn_check(cudnnRNNForwardInference( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc, | |||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size)); | |||
} | |||
} | |||
size_t LSTMImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space) { | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
lstm::get_RNNDescHolder_v6(this->handle(), param(), input); | |||
return desc_holder.workspace_size; | |||
} | |||
size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
lstm::get_RNNDescHolder_v6(this->handle(), param(), input); | |||
return desc_holder.reserveSpace_size; | |||
} | |||
void LSTMBackwardImpl::exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||
Handle* handle = this->handle(); | |||
size_t seq_len = x.layout.shape[0]; | |||
auto desc_holder = lstm::get_RNNDescHolder_v6(handle, param(), x.layout); | |||
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data(); | |||
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data(); | |||
RNNWeightFilterDesc w_desc; | |||
w_desc.set(flatten_weights.layout); | |||
cudnn_check(cudnnRNNBackwardData( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr, | |||
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc, | |||
dhy.raw_ptr(), desc_holder.cy_desc.desc, dcy.raw_ptr(), w_desc.desc, | |||
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), | |||
desc_holder.cx_desc.desc, cx.raw_ptr(), x_desc_arr_ptr, dx.raw_ptr(), | |||
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, | |||
dcx.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, | |||
reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||
cudnn_check(cudnnRNNBackwardWeights( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, | |||
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, | |||
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, | |||
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||
} | |||
size_t LSTMBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { | |||
auto desc_holder = lstm::get_RNNDescHolder_v6(this->handle(), param(), x); | |||
return desc_holder.workspace_size; | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,56 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class LSTMImpl : public LSTM { | |||
public: | |||
using LSTM::LSTM; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space); | |||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||
}; | |||
class LSTMBackwardImpl : public LSTMBackward { | |||
public: | |||
using LSTMBackward::LSTMBackward; | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, | |||
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,39 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm/utils.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/lstm/utils.h" | |||
#include "src/cuda/utils.h" | |||
#include <cudnn.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace lstm { | |||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input) { | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t input_size = input.shape[2]; | |||
cudnnRNNMode_t mode = CUDNN_LSTM; | |||
using FwdMode = param::LSTM::FwdMode; | |||
RNNForwardDescHolder_v6 desc_holder( | |||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||
input.dtype, mode); | |||
return desc_holder; | |||
} | |||
} // namespace lstm | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,23 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm/utils.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/cudnn_wrapper.h" | |||
#include "src/cuda/rnn/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace lstm { | |||
using megdnn::cuda::rnn::RNNForwardDescHolder_v6; | |||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input); | |||
} // namespace lstm | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,42 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm_cell/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/lstm_cell/opr_impl.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "src/common/lstm_cell.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
size_t LSTMCellImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
const TensorLayout& gates) { | |||
return megdnn::lstm_cell::get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||
handle()); | |||
} | |||
void LSTMCellImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||
megdnn::lstm_cell::exec( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||
workspace, handle()); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/src/cuda/lstm_cell/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/cuda/rnn_cell/opr_impl.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class LSTMCellImpl : public LSTMCell { | |||
public: | |||
using LSTMCell::LSTMCell; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, | |||
const TensorLayout& c_new, const TensorLayout& gates) override; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,170 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/rnn/opr_impl.h" | |||
#include "src/common/rnn.h" | |||
#include "src/cuda/utils.h" | |||
//#include <cstring> | |||
#include <cudnn.h> | |||
#include <cstdlib> | |||
#include <iostream> | |||
namespace megdnn { | |||
namespace cuda { | |||
using namespace std; | |||
void RNNImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace) { | |||
Handle* handle = this->handle(); | |||
#if false // CUDNN_MAJOR >= 8 | |||
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input.layout); | |||
void* workspace_ptr = workspace.raw_ptr; | |||
void* reserveSpace_ptr = static_cast<uint8_t*>(workspace_ptr) + desc_holder.workspace_size; | |||
cudnn_check(cudnnRNNForward( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.fwdMode, desc_holder.devSeqLengths, | |||
desc_holder.x_desc.desc, input.raw_ptr(), desc_holder.y_desc.desc, output.raw_ptr(), | |||
desc_holder.h_desc.desc, hx.raw_ptr(), hy.raw_ptr(), | |||
desc_holder.h_desc.desc, nullptr, nullptr, | |||
desc_holder.weight_size, flatten_weights.raw_ptr(), desc_holder.workspace_size, workspace_ptr, | |||
desc_holder.reserveSpace_size, reserveSpace_ptr | |||
)); | |||
#else | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input.layout); | |||
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs); | |||
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs); | |||
RNNWeightFilterDesc w_desc; | |||
w_desc.set(flatten_weights.layout); | |||
if (param().fwd_mode == param::RNN::FwdMode::TRAINING) { | |||
cudnn_check(cudnnRNNForwardTraining( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||
hx.raw_ptr(), desc_holder.cx_desc.desc, NULL, w_desc.desc, | |||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, NULL, | |||
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(), | |||
desc_holder.reserveSpace_size)); | |||
} else { | |||
cudnn_check(cudnnRNNForwardInference( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len, | |||
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc, | |||
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc, | |||
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(), | |||
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, | |||
nullptr, workspace.raw_ptr, desc_holder.workspace_size)); | |||
} | |||
#endif | |||
} | |||
size_t RNNImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space) { | |||
#if false // CUDNN_MAJOR >= 8 | |||
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input); | |||
#else | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input); | |||
#endif | |||
return desc_holder.workspace_size; | |||
} | |||
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||
rnn::RNNForwardDescHolder_v6 desc_holder = | |||
rnn::get_RNNDescHolder_v6(this->handle(), param(), input); | |||
return desc_holder.reserveSpace_size; | |||
} | |||
/*rnn::RNNForwardDescHolder RNNImpl::get_desc_holder(const TensorLayout& input) { | |||
Handle* handle = this->handle(); | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t input_size = input.shape[2]; | |||
auto _param = param(); | |||
cudnnRNNMode_t mode; | |||
using NonlineMode = param::RNN::NonlineMode; | |||
switch (_param.nonlineMode) { | |||
case NonlineMode::RELU: | |||
mode = CUDNN_RNN_RELU; | |||
break; | |||
case NonlineMode::TANH: | |||
mode = CUDNN_RNN_TANH; | |||
break; | |||
} | |||
cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING; | |||
using FwdMode = param::RNN::FwdMode; | |||
switch (_param.fwd_mode) { | |||
case FwdMode::TRAINING: | |||
fwdMode = CUDNN_FWD_MODE_TRAINING; | |||
break; | |||
case FwdMode::INFERENCE: | |||
fwdMode = CUDNN_FWD_MODE_INFERENCE; | |||
break; | |||
} | |||
rnn::RNNForwardDescHolder desc_holder( | |||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||
input.dtype, mode, fwdMode); | |||
return desc_holder; | |||
}*/ | |||
void RNNBackwardImpl::exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||
_megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||
Handle* handle = this->handle(); | |||
size_t seq_len = x.layout.shape[0]; | |||
auto desc_holder = rnn::get_RNNDescHolder_v6(handle, param(), x.layout); | |||
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data(); | |||
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data(); | |||
RNNWeightFilterDesc w_desc; | |||
w_desc.set(flatten_weights.layout); | |||
cudnn_check(cudnnRNNBackwardData( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr, | |||
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc, | |||
dhy.raw_ptr(), desc_holder.cy_desc.desc, NULL, w_desc.desc, | |||
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), | |||
desc_holder.cx_desc.desc, NULL, x_desc_arr_ptr, dx.raw_ptr(), | |||
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, NULL, | |||
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(), | |||
desc_holder.reserveSpace_size)); | |||
cudnn_check(cudnnRNNBackwardWeights( | |||
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr, | |||
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr, | |||
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc, | |||
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size)); | |||
} | |||
size_t RNNBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { | |||
auto desc_holder = rnn::get_RNNDescHolder_v6(this->handle(), param(), x); | |||
return desc_holder.workspace_size; | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,57 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
#include "src/cuda/rnn/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class RNNImpl : public RNN { | |||
public: | |||
using RNN::RNN; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space); | |||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||
// private: | |||
// rnn::RNNForwardDescHolder get_desc_holder(const TensorLayout& input); | |||
}; | |||
class RNNBackwardImpl : public RNNBackward { | |||
public: | |||
using RNNBackward::RNNBackward; | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, | |||
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,138 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn/utils.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/rnn/utils.h" | |||
#include "src/cuda/utils.h" | |||
#include <cudnn.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace rnn { | |||
/*RNNForwardDescHolder::RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t | |||
batch_size, size_t hidden_size, size_t input_size, size_t proj_size, size_t num_layers, | |||
bool bidirectional, bool bias, DType dtype, cudnnRNNMode_t _mode, cudnnForwardMode_t | |||
_fwdMode) : mode(_mode), fwdMode(_fwdMode) | |||
{ | |||
size_t D = bidirectional ? 2 : 1; | |||
// TODO: set dropout to 0 in inference mode | |||
dropout_desc.set_no_dropout(handle); | |||
// seq len is unified (not packed) | |||
// cuda_check(cudaMalloc((void**)&devSeqLengths, sizeof(int32_t) * batch_size)); | |||
devSeqLengths = (int32_t*)malloc(sizeof(int32_t) * batch_size); | |||
for (size_t i = 0; i < batch_size; ++i) devSeqLengths[i] = seq_len; | |||
// proj size should be smaller than hidden size according to cudnn api | |||
// otherwise it is disabled | |||
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : | |||
proj_size; rnn_desc.set( input_size, hidden_size, proj_size, num_layers, bidirectional, | |||
bias, dtype, mode, dropout_desc, handle | |||
); | |||
x_desc.set(batch_size, input_size, seq_len, devSeqLengths, dtype); | |||
y_desc.set(batch_size, D * proj_size, seq_len, | |||
devSeqLengths, dtype); | |||
h_desc.set_nd(TensorLayout(TensorShape{D * num_layers, batch_size, proj_size}, | |||
dtype)); | |||
cudnn_check(cudnnGetRNNWeightSpaceSize(cudnn_handle(handle), rnn_desc.desc, | |||
&weight_size)); | |||
cudnn_check(cudnnGetRNNTempSpaceSizes( | |||
cudnn_handle(handle), rnn_desc.desc, fwdMode, x_desc.desc, | |||
&workspace_size, &reserveSpace_size | |||
)); | |||
} | |||
RNNForwardDescHolder::~RNNForwardDescHolder() { | |||
// cuda_check(cudaFree(devSeqLengths)); | |||
free(devSeqLengths); | |||
}*/ | |||
RNNForwardDescHolder_v6::RNNForwardDescHolder_v6( | |||
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size, | |||
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||
bool bias, DType dtype, cudnnRNNMode_t _mode) | |||
: mode(_mode), seq_len(seq_len) { | |||
size_t D = bidirectional ? 2 : 1; | |||
// TODO: set dropout to 0 in inference mode | |||
dropout_desc.set_no_dropout(handle); | |||
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : proj_size; | |||
rnn_desc.set( | |||
input_size, hidden_size, proj_size, num_layers, bidirectional, bias, dtype, | |||
mode, dropout_desc, handle); | |||
x_descs.resize(seq_len); | |||
y_descs.resize(seq_len); | |||
for (size_t i = 0; i < seq_len; ++i) { | |||
x_descs[i].set_nd(TensorLayout(TensorShape{batch_size, input_size}, dtype), 3); | |||
y_descs[i].set_nd( | |||
TensorLayout(TensorShape{batch_size, D * hidden_size}, dtype), 3); | |||
} | |||
#define SET_H(_var) \ | |||
_var.set_nd(TensorLayout( \ | |||
TensorShape{D * num_layers, batch_size, hidden_size}, dtype)); | |||
SET_H(hx_desc) | |||
SET_H(cx_desc) | |||
SET_H(hy_desc) | |||
SET_H(cy_desc) | |||
#undef SET_H | |||
std::vector<cudnnTensorDescriptor_t> x_desc_arr = get_descs(x_descs); | |||
cudnn_check(cudnnGetRNNWorkspaceSize( | |||
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), | |||
&workspace_size)); | |||
cudnn_check(cudnnGetRNNTrainingReserveSize( | |||
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(), | |||
&reserveSpace_size)); | |||
} | |||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input) { | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t input_size = input.shape[2]; | |||
cudnnRNNMode_t mode; | |||
using NonlineMode = param::RNN::NonlineMode; | |||
switch (_param.nonlineMode) { | |||
case NonlineMode::RELU: | |||
mode = CUDNN_RNN_RELU; | |||
break; | |||
case NonlineMode::TANH: | |||
mode = CUDNN_RNN_TANH; | |||
break; | |||
} | |||
RNNForwardDescHolder_v6 desc_holder( | |||
handle, seq_len, batch_size, _param.hidden_size, input_size, | |||
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias, | |||
input.dtype, mode); | |||
return desc_holder; | |||
} | |||
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs) { | |||
std::vector<cudnnTensorDescriptor_t> r; | |||
r.reserve(descs.size()); | |||
for (auto& desc : descs) { | |||
r.emplace_back(desc.desc); | |||
} | |||
return r; | |||
} | |||
} // namespace rnn | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,56 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn/utils.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/cuda/cudnn_wrapper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace rnn { | |||
// v8, not for now | |||
/*struct RNNForwardDescHolder { | |||
int32_t* devSeqLengths; | |||
cudnnRNNMode_t mode; | |||
cudnnForwardMode_t fwdMode; | |||
RNNDesc rnn_desc; | |||
DropoutDesc dropout_desc; | |||
RNNDataDesc x_desc, y_desc; | |||
TensorDesc h_desc; | |||
size_t weight_size, workspace_size, reserveSpace_size; | |||
RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t batch_size, size_t | |||
hidden_size, size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||
bool bias, DType dtype, | |||
cudnnRNNMode_t _mode, cudnnForwardMode_t _fwdMode); ~RNNForwardDescHolder(); | |||
};*/ | |||
struct RNNForwardDescHolder_v6 { | |||
cudnnRNNMode_t mode; | |||
RNNDesc rnn_desc; | |||
int seq_len; | |||
DropoutDesc dropout_desc; | |||
std::vector<TensorDesc> x_descs, y_descs; | |||
TensorDesc hx_desc, cx_desc, hy_desc, cy_desc; | |||
size_t workspace_size, reserveSpace_size; | |||
RNNForwardDescHolder_v6( | |||
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size, | |||
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional, | |||
bool bias, DType dtype, cudnnRNNMode_t _mode); | |||
}; | |||
RNNForwardDescHolder_v6 get_RNNDescHolder_v6( | |||
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input); | |||
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs); | |||
} // namespace rnn | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn_cell/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/cuda/rnn_cell/opr_impl.h" | |||
#include "src/common/rnn_cell.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
size_t RNNCellImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst) { | |||
return megdnn::rnn_cell::get_workspace_in_bytes( | |||
input, weight_ih, bias_hh, hx, weight_hh, bias_hh, dst, handle()); | |||
} | |||
void RNNCellImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
megdnn::rnn_cell::exec( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace, | |||
param().nonlineMode, handle()); | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file dnn/src/cuda/rnn_cell/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class RNNCellImpl : public RNNCell { | |||
public: | |||
using RNNCell::RNNCell; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst) override; | |||
/* | |||
private: | |||
void add_bias(_megdnn_tensor_in A, | |||
_megdnn_tensor_in B, | |||
_megdnn_tensor_in bias, | |||
_megdnn_tensor_out C); | |||
*/ | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn |
@@ -52,6 +52,8 @@ | |||
#include "src/naive/local_share/opr_impl.h" | |||
#include "src/naive/lrn/opr_impl.h" | |||
#include "src/naive/lsq/opr_impl.h" | |||
#include "src/naive/lstm/opr_impl.h" | |||
#include "src/naive/lstm_cell/opr_impl.h" | |||
#include "src/naive/mask_conv/opr_impl.h" | |||
#include "src/naive/matrix_inverse/opr_impl.h" | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
@@ -68,6 +70,8 @@ | |||
#include "src/naive/repeat/opr_impl.h" | |||
#include "src/naive/resize/opr_impl.h" | |||
#include "src/naive/rng/opr_impl.h" | |||
#include "src/naive/rnn/opr_impl.h" | |||
#include "src/naive/rnn_cell/opr_impl.h" | |||
#include "src/naive/roi_align/opr_impl.h" | |||
#include "src/naive/roi_copy/opr_impl.h" | |||
#include "src/naive/roi_pooling/opr_impl.h" | |||
@@ -0,0 +1,146 @@ | |||
/** | |||
* \file dnn/src/naive/lstm/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/lstm/opr_impl.h" | |||
#include "src/naive/rnn/funcs.h" | |||
#include "src/naive/rnn/rnn.h" | |||
namespace megdnn { | |||
namespace naive { | |||
using rnn::LSTMCellWeightWrapper; | |||
void LSTMImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace) { | |||
auto _param = param(); | |||
size_t D = _param.bidirectional ? 2 : 1; | |||
size_t num_layers = _param.num_layers; | |||
size_t input_size = input.layout.shape[2]; | |||
std::vector<LSTMCellWeightWrapper> cells; | |||
size_t used_workspace_size = rnn::get_cells<LSTMCellWeightWrapper>( | |||
D, num_layers, input_size, _param.hidden_size, _param.bias, cells, | |||
flatten_weights, workspace); | |||
Workspace new_workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
TensorNDArray states = {hx, cx}, states_new = {hy, cy}; | |||
rnn::exec_internal<LSTMCellWeightWrapper, LSTMCellForward>( | |||
cells, input, states, states_new, output, reserve_space, num_layers, D, | |||
this->handle(), new_workspace); | |||
} | |||
size_t LSTMImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space) { | |||
size_t workspace_size = rnn::get_workspace_in_bytes<LSTMCellForward>( | |||
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, | |||
this->handle()); | |||
if (!param().bias) { // use fake bias (all 0) | |||
TensorLayout bias_layout = {{param().hidden_size * 4}, flatten_weights.dtype}; | |||
workspace_size += bias_layout.span().dist_byte(); | |||
} | |||
workspace_size += output.span().dist_byte(); | |||
return workspace_size; | |||
} | |||
size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||
size_t num_layers = param().num_layers; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; | |||
// 2 for hidden state and cell state | |||
return 2 * num_layers * D * seq_len * state_layout.span().dist_byte(); | |||
} | |||
void LSTMBackwardImpl::exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||
TensorNDArray layer_inputs; | |||
TensorNDArray layer_outputs; | |||
std::vector<std::vector<TensorNDArray>> cell_seq_states; | |||
size_t num_layers = param().num_layers; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t input_size = x.layout.shape[2]; | |||
size_t hidden_size = param().hidden_size; | |||
size_t used_workspace_size = 0; | |||
// get cells | |||
std::vector<LSTMCellWeightWrapper> cells; | |||
used_workspace_size += rnn::get_cells<LSTMCellWeightWrapper>( | |||
D, num_layers, input_size, hidden_size, param().bias, cells, | |||
flatten_weights, workspace); | |||
// get formatted inputs | |||
Workspace new_workspace = Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
used_workspace_size += rnn::get_inputs_for_exec<LSTMCellWeightWrapper>( | |||
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, | |||
layer_outputs, cell_seq_states, param::RNNCell::NonlineMode::IDENTITY, | |||
new_workspace); | |||
// dhy arr, dhx arr | |||
TensorNDArray dhy_arr = {dhy, dcy}, dhx_arr = {dhx, dcx}; | |||
// exec | |||
new_workspace = Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
rnn::backward_exec_internal<LSTMCellWeightWrapper>( | |||
cells, D, num_layers, input_size, param().bias, | |||
param::RNNCell::NonlineMode::IDENTITY, layer_inputs, layer_outputs, | |||
cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, this->handle(), | |||
new_workspace); | |||
} | |||
size_t LSTMBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) { | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t num_layers = param().num_layers; | |||
size_t hidden_size = param().hidden_size; | |||
size_t gate_hidden_size = hidden_size * 4; | |||
size_t max_input_size = std::max(x.shape[2], D * hidden_size); | |||
size_t workspace_size = LSTMCellWeightWrapper::backward_workspace_size_in_bytes( | |||
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); | |||
if (!param().bias) { // use fake bias (all 0) | |||
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; | |||
workspace_size += bias_layout.span().dist_byte() * | |||
2; // times 2 because another bias is allocated in | |||
// backward_exec_internal | |||
} | |||
workspace_size += num_layers * y.span().dist_byte(); | |||
// add back exec workspace size | |||
workspace_size += y.span().dist_byte() * 2; | |||
workspace_size += x.span().dist_byte() * 2; | |||
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; | |||
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||
workspace_size += wih.span().dist_byte(); | |||
workspace_size += whh.span().dist_byte(); | |||
workspace_size += bias.span().dist_byte(); | |||
return workspace_size; | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,56 @@ | |||
/** | |||
* \file dnn/src/naive/lstm/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class LSTMImpl : public LSTM { | |||
public: | |||
using LSTM::LSTM; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& cy, | |||
const TensorLayout& reserve_space); | |||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||
}; | |||
class LSTMBackwardImpl : public LSTMBackward { | |||
public: | |||
using LSTMBackward::LSTMBackward; | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, | |||
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& dcy, const TensorLayout& flatten_weights, | |||
const TensorLayout& reserve_space, const TensorLayout& dx, | |||
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw); | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,55 @@ | |||
/** | |||
* \file dnn/src/naive/lstm/template_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/rnn/funcs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
template <> | |||
void cell_opr_exec<LSTMCellForward>( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { | |||
auto opr = handle->create_operator<LSTMCellForward>(); | |||
TensorLayout gates, h_new, c_new; | |||
opr->deduce_layout( | |||
input.layout, weight_ih.layout, bias_ih.layout, states[0].layout, | |||
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); | |||
TensorND gates_tensor{workspace.raw_ptr, gates}; | |||
_megdnn_workspace new_workspace = { | |||
workspace.raw_ptr + gates.span().dist_byte(), | |||
workspace.size - gates.span().dist_byte()}; | |||
opr->exec( | |||
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], | |||
states_new[0], states_new[1], gates_tensor, new_workspace); | |||
} | |||
template <> | |||
size_t cell_opr_get_workspace_in_bytes<LSTMCellForward>( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { | |||
TensorLayout cx = hx; | |||
TensorLayout h_new, c_new, gates; | |||
auto cell_opr = handle->create_operator<LSTMCellForward>(); | |||
cell_opr->deduce_layout( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates); | |||
return cell_opr->get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||
gates) + | |||
gates.span().dist_byte(); | |||
} | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,38 @@ | |||
/** | |||
* \file dnn/src/naive/lstm_cell/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/lstm_cell/opr_impl.h" | |||
#include "src/common/lstm_cell.h" | |||
namespace megdnn { | |||
namespace naive { | |||
size_t LSTMCellImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||
const TensorLayout& gates) { | |||
return megdnn::lstm_cell::get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||
handle()); | |||
} | |||
void LSTMCellImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||
megdnn::lstm_cell::exec( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates, | |||
workspace, handle()); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,36 @@ | |||
/** | |||
* \file dnn/src/naive/lstm_cell/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/naive/rnn_cell/opr_impl.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class LSTMCellImpl : public LSTMCell { | |||
public: | |||
using LSTMCell::LSTMCell; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||
_megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& cx, const TensorLayout& h_new, | |||
const TensorLayout& c_new, const TensorLayout& gates) override; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -72,9 +72,7 @@ void RelayoutForwardImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) | |||
void RelayoutForwardImpl::check_cpu_handle(Handle* handle) { | |||
megdnn_assert( | |||
!handle || | |||
handle == this->handle() | |||
|| is_cpu_handle(handle), | |||
!handle || handle == this->handle() || is_cpu_handle(handle), | |||
"relayout from non-CPU to CPU not supported"); | |||
} | |||
@@ -0,0 +1,75 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/funcs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#ifndef _RNN_H | |||
#define _RNN_H | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
template <typename CellOpr> | |||
void cell_opr_exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle); | |||
template <typename CellOpr> | |||
size_t cell_opr_get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle); | |||
template <typename CellOpr> | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& flatten_weights, | |||
size_t hidden_size, | |||
size_t D, // num_directions | |||
Handle* handle); | |||
template <class Cell, typename CellOpr> | |||
void exec_internal( | |||
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states, | |||
TensorNDArray& states_new, _megdnn_tensor_out output, | |||
_megdnn_tensor_out reserve_space, size_t num_layers, | |||
size_t D, // D is num_directions | |||
Handle* handle, _megdnn_workspace workspace); | |||
template <class Cell> | |||
size_t get_cells( | |||
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, | |||
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights, | |||
_megdnn_workspace workspace); | |||
template <class Cell> | |||
size_t get_inputs_for_exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, | |||
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells, | |||
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, | |||
std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace); | |||
template <class Cell> | |||
void backward_exec_internal( | |||
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size, | |||
bool bias, param::RNNCell::NonlineMode nonlineMode, | |||
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, | |||
const std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, | |||
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, | |||
_megdnn_workspace workspace); | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn | |||
#include "funcs.tpp" | |||
#endif |
@@ -0,0 +1,449 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/funcs.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "funcs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
template <typename CellOpr> | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& flatten_weights, | |||
size_t hidden_size, | |||
size_t D, // num_directions | |||
Handle* handle) { | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
size_t input_size = input.shape[2]; | |||
size_t gate_hidden_size = flatten_weights.shape[0]; | |||
// concat workspace | |||
TensorLayout direction_output_layout{ | |||
TensorShape{seq_len, batch_size, hidden_size}, input.dtype}; | |||
TensorLayout output_layout{{seq_len, batch_size, D * hidden_size}, input.dtype}; | |||
TensorLayoutArray layer_layouts; | |||
for (size_t i = 0; i < D; ++i) | |||
layer_layouts.push_back(direction_output_layout); | |||
auto concat_opr = handle->create_operator<ConcatForward>(); | |||
concat_opr->param().axis = -1; | |||
size_t concat_workspace = | |||
concat_opr->get_workspace_in_bytes(layer_layouts, output_layout); | |||
// cell workspace | |||
TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype}; | |||
TensorLayout weight_hh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||
TensorLayout hx{{batch_size, hidden_size}, input.dtype}; | |||
size_t cell_workspace = cell_opr_get_workspace_in_bytes<CellOpr>( | |||
input, weight_ih, weight_hh, bias, bias, hx, handle); | |||
return std::max(cell_workspace, concat_workspace); | |||
} | |||
template <class Cell, typename CellOpr> | |||
void exec_internal( | |||
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states, | |||
TensorNDArray& states_new, _megdnn_tensor_out output, | |||
_megdnn_tensor_out reserve_space, size_t num_layers, | |||
size_t D, // D is num_directions | |||
Handle* handle, _megdnn_workspace workspace) { | |||
size_t seq_len = input.layout.shape[0]; | |||
size_t batch_size = input.layout.shape[1]; | |||
size_t input_size = input.layout.shape[2]; | |||
size_t hidden_size = cells[0].weight_hh.layout.shape[1]; | |||
TensorLayout cell_output_layout{ | |||
TensorShape{batch_size, hidden_size}, states[0].layout.dtype}; | |||
TensorLayout cell_first_input_layout{ | |||
TensorShape{batch_size, input_size}, input.layout.dtype}; | |||
TensorLayout cell_input_layout{ | |||
TensorShape{batch_size, D * hidden_size}, input.layout.dtype}; | |||
TensorLayout direction_output_layout{ | |||
TensorShape{seq_len, batch_size, hidden_size}, output.layout.dtype}; | |||
TensorND tmp_output{workspace.raw_ptr, output.layout}; | |||
_megdnn_workspace new_workspace{ | |||
workspace.raw_ptr + tmp_output.layout.span().dist_byte(), | |||
workspace.size - tmp_output.layout.span().dist_byte()}; | |||
auto cell_opr = handle->create_operator<CellOpr>(); | |||
auto copy_opr = handle->create_operator<TypeCvtForward>(); | |||
// copy states to states_new | |||
for (size_t i = 0; i < states.size(); ++i) | |||
copy_opr->exec(states[i], states_new[i]); | |||
void* reserve_ptr = reserve_space.raw_ptr(); | |||
// layer 1 | |||
// TensorNDArray layer_outputs; | |||
for (size_t d = 0; d < D; ++d) { | |||
size_t cell_idx = d; | |||
auto& cell = cells[cell_idx]; | |||
TensorNDArray cur_states; | |||
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte(); | |||
for (size_t i = 0; i < states.size(); ++i) { | |||
cur_states.push_back(TensorND{ | |||
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset, | |||
cell_output_layout}); | |||
} | |||
// TensorND direction_output_tensor{output.raw_ptr + d * | |||
// direction_output_layout.span().dist_byte(), | |||
// direction_output_layout}; | |||
for (size_t i = 0; i < seq_len; ++i) { | |||
size_t step = d == 0 ? i : seq_len - 1 - i; | |||
TensorND step_input{ | |||
static_cast<uint8_t*>(input.raw_ptr()) + | |||
step * cell_first_input_layout.span().dist_byte(), | |||
cell_first_input_layout}; | |||
TensorND step_output{ | |||
static_cast<uint8_t*>(output.raw_ptr()) + | |||
(step * D + d) * cell_output_layout.span().dist_byte(), | |||
cell_output_layout}; | |||
// temporary states of each step (use reserve space) | |||
TensorNDArray tmp_states; | |||
for (size_t s = 0; s < cur_states.size(); ++s) { | |||
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout}); | |||
size_t size_in_bytes = cur_states[s].layout.span().dist_byte(); | |||
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes; | |||
} | |||
cell_opr_exec<CellOpr>( | |||
step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih, | |||
cell.bias_hh, cur_states, tmp_states, new_workspace, handle); | |||
// copy states to cur_states | |||
for (size_t s = 0; s < tmp_states.size(); ++s) { | |||
copy_opr->exec(tmp_states[s], cur_states[s]); | |||
} | |||
// copy h to step output | |||
copy_opr->exec(cur_states[0], step_output); | |||
} | |||
} | |||
for (size_t layer = 1; layer < num_layers; ++layer) { | |||
// TensorNDArray layer_outputs; | |||
for (size_t d = 0; d < D; ++d) { | |||
size_t cell_idx = layer * D + d; | |||
auto& cell = cells[cell_idx]; | |||
TensorNDArray cur_states; | |||
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte(); | |||
for (size_t i = 0; i < states.size(); ++i) { | |||
cur_states.push_back(TensorND{ | |||
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset, | |||
cell_output_layout}); | |||
} | |||
// TensorND direction_output_tensor{output.raw_ptr + d * | |||
// direction_output_layout.span().dist_byte(), | |||
// direction_output_layout}; | |||
for (size_t i = 0; i < seq_len; ++i) { | |||
size_t step = d == 0 ? i : seq_len - 1 - i; | |||
TensorND step_input{ | |||
static_cast<uint8_t*>(output.raw_ptr()) + | |||
step * cell_input_layout.span().dist_byte(), | |||
cell_input_layout}; | |||
TensorND step_output{ | |||
static_cast<uint8_t*>(tmp_output.raw_ptr()) + | |||
(step * D + d) * cell_output_layout.span().dist_byte(), | |||
cell_output_layout}; | |||
// temporary states of each step (use reserve space) | |||
TensorNDArray tmp_states; | |||
for (size_t s = 0; s < cur_states.size(); ++s) { | |||
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout}); | |||
size_t size_in_bytes = cur_states[s].layout.span().dist_byte(); | |||
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes; | |||
} | |||
cell_opr_exec<CellOpr>( | |||
step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih, | |||
cell.bias_hh, cur_states, cur_states, new_workspace, handle); | |||
// copy states to cur_states | |||
for (size_t s = 0; s < tmp_states.size(); ++s) { | |||
copy_opr->exec(tmp_states[s], cur_states[s]); | |||
} | |||
// copy h to step_output | |||
copy_opr->exec(cur_states[0], step_output); | |||
} | |||
} | |||
// copy layer output to output | |||
copy_opr->exec(tmp_output, output); | |||
} | |||
// output: [d0, d1, d0, d1 ...] | |||
} | |||
template <class Cell> | |||
size_t get_cells( | |||
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias, | |||
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights, | |||
_megdnn_workspace workspace) { | |||
cells.reserve(D * num_layers); | |||
void* weight_ptr = flatten_weights.raw_ptr(); | |||
for (size_t layer = 0; layer < num_layers; ++layer) { | |||
for (size_t d = 0; d < D; ++d) { | |||
size_t cell_input_size = D * hidden_size; | |||
if (layer == 0) | |||
cell_input_size = input_size; | |||
Cell cell( | |||
weight_ptr, hidden_size, cell_input_size, bias, | |||
flatten_weights.layout.dtype, workspace); | |||
weight_ptr = | |||
static_cast<uint8_t*>(weight_ptr) + cell.weight_size_in_bytes(); | |||
cells.push_back(cell); | |||
} | |||
} | |||
// return used workspace | |||
return cells[0].workspace_size_in_bytes(); | |||
} | |||
template <class Cell> | |||
size_t get_inputs_for_exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space, | |||
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells, | |||
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs, | |||
std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace) { | |||
// return used workspace size | |||
layer_inputs.push_back(x); | |||
size_t seq_len = x.layout.shape[0]; | |||
size_t batch_size = x.layout.shape[1]; | |||
size_t num_states = cells[0].num_states(); | |||
TensorLayout cell_output_layout{{batch_size, hidden_size}, y.layout.dtype}; | |||
TensorLayout direction_output_layout{ | |||
{seq_len, batch_size, hidden_size}, y.layout.dtype}; | |||
void* workspace_ptr = workspace.raw_ptr; | |||
// extract intermedia states from reserve space | |||
for (int layer = 0; layer < num_layers; ++layer) { | |||
TensorND layer_output{workspace_ptr, y.layout}; | |||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
layer_output.layout.span().dist_byte(); | |||
for (int d = 0; d < D; ++d) { | |||
cell_seq_states.push_back(std::vector<TensorNDArray>()); | |||
// reverse direction is stored with reversed order of sequence order | |||
for (int i = 0; i < seq_len; ++i) { | |||
size_t step = i; | |||
if (d == 1) | |||
step = seq_len - i - 1; | |||
size_t offset = ((layer * D + d) * seq_len + step) * | |||
cell_output_layout.span().dist_byte() * num_states; | |||
TensorNDArray cur_states; | |||
for (int s = 0; s < num_states; ++s) { | |||
TensorND h{ | |||
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset + | |||
s * cell_output_layout.span().dist_byte(), | |||
cell_output_layout}; | |||
cur_states.push_back(h); | |||
} | |||
TensorND hy{ | |||
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset, | |||
cell_output_layout}; // the first hidden state is the output | |||
// states | |||
cell_seq_states[cell_seq_states.size() - 1].push_back(cur_states); | |||
// output | |||
offset = i * D * cell_output_layout.span().dist_byte(); | |||
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr()) + offset, hy.raw_ptr(), | |||
hy.layout.span().dist_byte()); | |||
} | |||
} | |||
layer_outputs.push_back(layer_output); | |||
if (layer != num_layers - 1) | |||
layer_inputs.push_back(layer_output); | |||
} | |||
return static_cast<uint8_t*>(workspace_ptr) - | |||
static_cast<uint8_t*>((void*)workspace.raw_ptr); | |||
} | |||
template <class Cell> | |||
// using Cell = RNNCellWeightWrapper; | |||
void backward_exec_internal( | |||
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size, | |||
bool bias, param::RNNCell::NonlineMode nonlineMode, | |||
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs, | |||
const std::vector<std::vector<TensorNDArray>>& cell_seq_states, | |||
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx, | |||
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle, | |||
_megdnn_workspace workspace) { | |||
/* | |||
layer_inputs: array of input of each layer, element 0: [seq_len, batch_size, | |||
input_size], element others: [seq_len, batch_size, D * hidden_size] | |||
layer_outputs: array of outputs of each rnn. To access outputs of the cell at | |||
(layer, d), use layer_outputs[layer]. The shape is [seq_len, batch_size, | |||
output_size(D*hidden_size)] (in sequence order) cell_seq_states: arrray of states | |||
of each cell at each step. To access the states of the cell at (layer, d) at | |||
sequence step (step), use cell_seq_states[layer*D + d][step] | |||
*/ | |||
size_t seq_len = layer_inputs[0].layout.shape[0]; | |||
size_t batch_size = layer_inputs[0].layout.shape[1]; | |||
DType dtype = layer_inputs[0].layout.dtype; | |||
size_t cell_y_size = | |||
layer_outputs[0].layout.shape[2] / D; // should all be the same | |||
size_t hidden_size = cell_y_size; | |||
TensorLayout cell_y_layout = {{batch_size, cell_y_size}, dtype}; | |||
void* workspace_ptr = workspace.raw_ptr; | |||
TensorND layer_output_grad{ | |||
workspace_ptr, {{seq_len, batch_size, D * hidden_size}, dtype}}; | |||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
layer_output_grad.layout.span().dist_byte(); | |||
memcpy(layer_output_grad.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); | |||
TensorNDArray direction_dx_arr; // for layer 1 to layer num_layers-1 | |||
for (int i = 0; i < D; ++i) { | |||
TensorLayout direction_dx_layout{{seq_len, batch_size, hidden_size}, dtype}; | |||
direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); | |||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
direction_dx_layout.span().dist_byte(); | |||
} | |||
TensorNDArray L0_direction_dx_arr; | |||
for (int i = 0; i < D; ++i) { | |||
TensorLayout direction_dx_layout{{seq_len, batch_size, input_size}, dtype}; | |||
L0_direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout)); | |||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
direction_dx_layout.span().dist_byte(); | |||
} | |||
// cell states for each layer and each direction | |||
std::vector<TensorNDArray> dstates_arr; | |||
for (int layer = 0; layer < num_layers; ++layer) { | |||
for (int d = 0; d < D; ++d) { | |||
TensorNDArray cell_states; | |||
cell_states.reserve(dstates.size()); | |||
for (int i = 0; i < dstates.size(); ++i) { | |||
size_t offset = (layer * D + d) * cell_y_layout.span().dist_byte(); | |||
TensorND dhx_cell{ | |||
static_cast<uint8_t*>(dstates[i].raw_ptr()) + offset, | |||
cell_y_layout}; | |||
memcpy(dhx_cell.raw_ptr(), static_cast<uint8_t*>(dhy[i].raw_ptr()) + offset, | |||
cell_y_layout.span().dist_byte()); | |||
cell_states.emplace_back(dhx_cell); | |||
} | |||
dstates_arr.push_back(cell_states); | |||
} | |||
} | |||
// init gradient on weight to zero | |||
memset(dw.raw_ptr(), 0, dw.layout.span().dist_byte()); | |||
// use cells to contain gradient | |||
std::vector<Cell> cell_grads; | |||
size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) - | |||
static_cast<uint8_t*>((void*)(workspace.raw_ptr)); | |||
workspace_ptr = | |||
static_cast<uint8_t*>(workspace_ptr) + | |||
get_cells( | |||
D, num_layers, input_size, hidden_size, bias, cell_grads, dw, | |||
Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size)); | |||
auto add_opr = handle->create_operator<ElemwiseForward>(); | |||
add_opr->param().mode = Elemwise::Mode::ADD; | |||
auto copy_opr = handle->create_operator<TypeCvtForward>(); | |||
// initialize dx to zero | |||
memset(dx.raw_ptr(), 0, dx.layout.span().dist_byte()); | |||
// calculate grads | |||
for (int layer = num_layers - 1; layer >= 0; --layer) { | |||
for (int d = D - 1; d >= 0; --d) { | |||
Cell& cell = cells[layer * D + d]; | |||
Cell& cell_grad = cell_grads[layer * D + d]; | |||
size_t input_size = layer_inputs[layer].layout.shape[2]; | |||
const TensorND& x_arr = layer_inputs[layer]; | |||
const TensorND& y_arr = layer_outputs[layer]; | |||
TensorLayout x_layout = {{batch_size, input_size}, dtype}; | |||
// tmp tensors | |||
void* tmp_workspace_ptr = workspace_ptr; | |||
TensorND dwi_tmp{tmp_workspace_ptr, cell_grad.weight_ih.layout}; | |||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||
dwi_tmp.layout.span().dist_byte(); | |||
TensorND dwh_tmp{tmp_workspace_ptr, cell_grad.weight_hh.layout}; | |||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||
dwh_tmp.layout.span().dist_byte(); | |||
TensorND dbias_tmp{tmp_workspace_ptr, cell_grad.bias_ih.layout}; | |||
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) + | |||
dbias_tmp.layout.span().dist_byte(); | |||
size_t used_workspace_size = | |||
static_cast<uint8_t*>(tmp_workspace_ptr) - | |||
static_cast<uint8_t*>((void*)(workspace.raw_ptr)); | |||
for (int i = 0; i < seq_len; ++i) { | |||
// reverse time step (not seq step). Here step means seq step | |||
size_t step = i; | |||
if (d == 0) | |||
step = seq_len - i - 1; | |||
TensorND x{ | |||
static_cast<uint8_t*>(x_arr.raw_ptr()) + | |||
step * x_layout.span().dist_byte(), | |||
x_layout}, | |||
y{static_cast<uint8_t*>(y_arr.raw_ptr()) + | |||
(step * D + d) * cell_y_layout.span().dist_byte(), | |||
cell_y_layout}; | |||
const TensorNDArray& cell_states = cell_seq_states[layer * D + d][step]; | |||
TensorNDArray& dstates_new = dstates_arr[layer * D + d]; | |||
// dy should be d_output + d_hidden | |||
TensorND dy_t{ | |||
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) + | |||
(step * D + d) * cell_y_layout.span().dist_byte(), | |||
cell_y_layout}; | |||
add_opr->exec({dstates_new[0], dy_t}, dy_t); | |||
// dx for layer 0 has a different size | |||
TensorND dx_t; | |||
if (layer == 0) | |||
dx_t = {static_cast<uint8_t*>(L0_direction_dx_arr[d].raw_ptr()) + | |||
step * x_layout.span().dist_byte(), | |||
x_layout}; | |||
else | |||
dx_t = {static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) + | |||
step * x_layout.span().dist_byte(), | |||
x_layout}; | |||
TensorNDArray douts = {dy_t}; | |||
for (int s = 1; s < dstates_new.size(); ++s) | |||
douts.push_back(dstates_new[s]); | |||
cell.backward( | |||
handle, nonlineMode, x, cell_states, y, douts, dx_t, | |||
dstates_new, dwi_tmp, dwh_tmp, dbias_tmp, | |||
Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size)); | |||
// add step gradient to overall gradient | |||
add_opr->exec({dwi_tmp, cell_grad.weight_ih}, cell_grad.weight_ih); | |||
add_opr->exec({dwh_tmp, cell_grad.weight_hh}, cell_grad.weight_hh); | |||
add_opr->exec({dbias_tmp, cell_grad.bias_ih}, cell_grad.bias_ih); | |||
add_opr->exec({dbias_tmp, cell_grad.bias_hh}, cell_grad.bias_hh); | |||
} | |||
} | |||
// add gradient of different directions to layer_output_grad. | |||
// Layer 0 to dx | |||
if (layer == 0) { | |||
for (int i = 0; i < D; ++i) | |||
add_opr->exec({L0_direction_dx_arr[i], dx}, dx); | |||
} else { | |||
if (D == 1) | |||
copy_opr->exec(direction_dx_arr[0], layer_output_grad); | |||
else { // D == 2, arrange as [(d0, d1), (d0, d1), ...] | |||
for (size_t t = 0; t < seq_len; ++t) { | |||
size_t offset = t * D * cell_y_layout.span().dist_byte(); | |||
for (size_t d = 0; d < D; ++d) { | |||
TensorND src{ | |||
static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) + | |||
offset, | |||
cell_y_layout}; | |||
TensorND dst{ | |||
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) + | |||
offset + d * cell_y_layout.span().dist_byte(), | |||
cell_y_layout}; | |||
copy_opr->exec(src, dst); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,196 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/rnn/opr_impl.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/oprs/general.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/common/rnn.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
#include "src/naive/rnn/funcs.h" | |||
#include "src/naive/rnn/rnn.h" | |||
#include <cstring> | |||
namespace megdnn { | |||
namespace naive { | |||
using rnn::RNNCellWeightWrapper; | |||
void RNNImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace) { | |||
auto _param = param(); | |||
size_t D = _param.bidirectional ? 2 : 1; | |||
size_t num_layers = _param.num_layers; | |||
size_t input_size = input.layout.shape[2]; | |||
std::vector<RNNCellWeightWrapper> cells; | |||
size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>( | |||
D, num_layers, input_size, _param.hidden_size, _param.bias, cells, | |||
flatten_weights, workspace); | |||
Workspace new_workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
TensorNDArray states, states_new; | |||
states.push_back(hx); | |||
states_new.push_back(hy); | |||
rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>( | |||
cells, input, states, states_new, output, reserve_space, num_layers, D, | |||
this->handle(), new_workspace); | |||
} | |||
size_t RNNImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space) { | |||
size_t workspace_size = rnn::get_workspace_in_bytes<RNNCellForward>( | |||
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1, | |||
this->handle()); | |||
if (!param().bias) { // use fake bias (all 0) | |||
TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype}; | |||
workspace_size += bias_layout.span().dist_byte(); | |||
} | |||
workspace_size += output.span().dist_byte(); | |||
return workspace_size; | |||
} | |||
size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) { | |||
size_t num_layers = param().num_layers; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t seq_len = input.shape[0]; | |||
size_t batch_size = input.shape[1]; | |||
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype}; | |||
return num_layers * D * seq_len * state_layout.span().dist_byte(); | |||
} | |||
void RNNBackwardImpl::exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights, | |||
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx, | |||
_megdnn_tensor_out dw, _megdnn_workspace workspace) { | |||
TensorNDArray layer_inputs; | |||
// layer_inputs.push_back(x); | |||
TensorNDArray layer_outputs; | |||
std::vector<std::vector<TensorNDArray>> cell_seq_states; | |||
size_t num_layers = param().num_layers; | |||
size_t D = param().bidirectional ? 2 : 1; | |||
// size_t seq_len = x.layout.shape[0]; | |||
// size_t batch_size = x.layout.shape[1]; | |||
size_t input_size = x.layout.shape[2]; | |||
size_t hidden_size = param().hidden_size; | |||
size_t used_workspace_size = 0; | |||
// get cells | |||
std::vector<RNNCellWeightWrapper> cells; | |||
// workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
used_workspace_size += rnn::get_cells( | |||
D, num_layers, input_size, hidden_size, param().bias, cells, | |||
flatten_weights, workspace); | |||
// extract intermedia states from reserve space | |||
/*for (int layer = 0; layer < num_layers; ++layer) { | |||
TensorND layer_output{workspace_ptr, y.layout}; | |||
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) + | |||
layer_output.layout.span().dist_byte(); for (int d = 0; d < D; ++d) { | |||
cell_seq_states.push_back(std::vector<TensorNDArray>()); | |||
// reverse direction is stored with reversed order of sequence order | |||
for (int i = 0; i < seq_len; ++i) { | |||
size_t step = i; | |||
if (d == 1) step = seq_len - i - 1; | |||
size_t offset = ((layer * D + d) * seq_len + step) * | |||
cell_output_layout.span().dist_byte(); TensorND | |||
hy{static_cast<uint8_t*>(reserve_space.raw_ptr) + offset, cell_output_layout}; | |||
// states | |||
cell_seq_states[cell_seq_states.size() - 1].push_back({hy}); | |||
// output | |||
offset = i * D * cell_output_layout.span().dist_byte(); | |||
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr) + offset, | |||
hy.raw_ptr, hy.layout.span().dist_byte()); | |||
} | |||
} | |||
cell_seq_outputs.push_back(layer_output); | |||
if (layer != num_layers - 1) layer_inputs.push_back(layer_output); | |||
}*/ | |||
// nonlinear mode | |||
param::RNNCell::NonlineMode nonlineMode; | |||
using ModeRNN = param::RNN::NonlineMode; | |||
using ModeRNNCell = param::RNNCell::NonlineMode; | |||
switch (param().nonlineMode) { | |||
case ModeRNN::RELU: | |||
nonlineMode = ModeRNNCell::RELU; | |||
break; | |||
case ModeRNN::TANH: | |||
nonlineMode = ModeRNNCell::TANH; | |||
break; | |||
} | |||
// get formatted inputs | |||
Workspace new_workspace = Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
used_workspace_size += rnn::get_inputs_for_exec<RNNCellWeightWrapper>( | |||
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs, | |||
layer_outputs, cell_seq_states, nonlineMode, new_workspace); | |||
// dhy arr, dhx arr | |||
TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx}; | |||
// exec | |||
/*size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) - | |||
static_cast<uint8_t*>((void*)workspace.raw_ptr);*/ | |||
new_workspace = Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
rnn::backward_exec_internal<RNNCellWeightWrapper>( | |||
cells, D, num_layers, input_size, param().bias, nonlineMode, layer_inputs, | |||
layer_outputs, cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, | |||
this->handle(), new_workspace); | |||
} | |||
size_t RNNBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) { | |||
size_t D = param().bidirectional ? 2 : 1; | |||
size_t num_layers = param().num_layers; | |||
size_t hidden_size = param().hidden_size; | |||
size_t gate_hidden_size = hidden_size; | |||
size_t max_input_size = std::max(x.shape[2], D * hidden_size); | |||
size_t workspace_size = RNNCellWeightWrapper::backward_workspace_size_in_bytes( | |||
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype); | |||
if (!param().bias) { // use fake bias (all 0) | |||
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype}; | |||
workspace_size += bias_layout.span().dist_byte() * | |||
2; // times 2 because another bias is allocated in | |||
// backward_exec_internal | |||
} | |||
workspace_size += num_layers * y.span().dist_byte(); | |||
// add back exec workspace size | |||
workspace_size += y.span().dist_byte() * 2; | |||
workspace_size += x.span().dist_byte() * 2; | |||
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype}; | |||
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype}; | |||
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype}; | |||
workspace_size += wih.span().dist_byte(); | |||
workspace_size += whh.span().dist_byte(); | |||
workspace_size += bias.span().dist_byte(); | |||
return workspace_size; | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,53 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class RNNImpl : public RNN { | |||
public: | |||
using RNN::RNN; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space, | |||
_megdnn_workspace workspace); | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& hx, | |||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||
const TensorLayout& hy, const TensorLayout& reserve_space); | |||
size_t get_reserve_size_in_bytes(const TensorLayout& input); | |||
}; | |||
class RNNBackwardImpl : public RNNBackward { | |||
public: | |||
using RNNBackward::RNNBackward; | |||
virtual void exec( | |||
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, | |||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space, | |||
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw, | |||
_megdnn_workspace workspace); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx, | |||
const TensorLayout& dy, const TensorLayout& dhy, | |||
const TensorLayout& flatten_weights, const TensorLayout& reserve_space, | |||
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw); | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,285 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/rnn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/rnn/rnn.h" | |||
#include <cmath> | |||
#include <cstring> | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
CellWeightsWrapperBase::CellWeightsWrapperBase( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, | |||
bool has_bias, DType dtype, _megdnn_workspace workspace) { | |||
// weight_ih: [gate_hidden_size, input_size] | |||
// weight_hh: [gate_hidden_size, hidden_size] | |||
// bias_ih: [gate_hidden_size] | |||
// bias_hh: [gate_hidden_size] | |||
size_t gate_hidden_size = num_chunks * hidden_size; | |||
TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype}; | |||
TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype}; | |||
TensorLayout bias_layout{{gate_hidden_size}, dtype}; | |||
this->_weight_size = 0; | |||
this->weight_ih = TensorND(weight_ptr, weight_ih_layout); | |||
this->_weight_size += weight_ih_layout.span().dist_byte(); | |||
this->weight_hh = TensorND( | |||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, weight_hh_layout); | |||
this->_weight_size += weight_hh_layout.span().dist_byte(); | |||
if (has_bias) { | |||
this->bias_ih = TensorND( | |||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout); | |||
this->_weight_size += bias_layout.span().dist_byte(); | |||
this->bias_hh = TensorND( | |||
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout); | |||
this->_weight_size += bias_layout.span().dist_byte(); | |||
this->_workspace_size = 0; | |||
} else { | |||
this->bias_ih = TensorND(workspace.raw_ptr, bias_layout); | |||
this->bias_hh = TensorND(workspace.raw_ptr, bias_layout); | |||
memset(workspace.raw_ptr, 0, bias_layout.span().dist_byte()); | |||
this->_workspace_size = bias_layout.span().dist_byte(); | |||
} | |||
} | |||
size_t CellWeightsWrapperBase::weight_size_in_bytes() const { | |||
return this->_weight_size; | |||
} | |||
size_t CellWeightsWrapperBase::workspace_size_in_bytes() const { | |||
return this->_workspace_size; | |||
} | |||
size_t CellWeightsWrapperBase::num_states() const { | |||
return 1; | |||
} | |||
void CellWeightsWrapperBase::backward( | |||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, _megdnn_tensor_in x, | |||
const TensorNDArray& states, _megdnn_tensor_in y, const TensorNDArray& douts, | |||
_megdnn_tensor_out dx, TensorNDArray& dstates, _megdnn_tensor_out dwi, | |||
_megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) const { | |||
auto dy = douts[0]; | |||
using NonlineMode = param::RNNCell::NonlineMode; | |||
using Mode = Elemwise::Mode; | |||
auto elemwise_opr = handle->create_operator<ElemwiseForward>(); | |||
TensorND tmp = {workspace.raw_ptr, dy.layout}; | |||
auto new_workspace = Workspace( | |||
workspace.raw_ptr + tmp.layout.span().dist_byte(), | |||
workspace.size - tmp.layout.span().dist_byte()); | |||
switch (nonlineMode) { | |||
case (NonlineMode::IDENTITY): | |||
memcpy(tmp.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte()); | |||
break; | |||
case (NonlineMode::TANH): | |||
elemwise_opr->param().mode = Mode::TANH_GRAD; | |||
elemwise_opr->exec({y, dy}, tmp); | |||
break; | |||
case (NonlineMode::RELU): | |||
elemwise_opr->param().mode = Mode::SWITCH_GT0; | |||
elemwise_opr->exec({y, dy}, tmp); | |||
break; | |||
} | |||
auto matrixmul_opr = handle->create_operator<MatrixMulForward>(); | |||
matrixmul_opr->param().transposeA = false; | |||
matrixmul_opr->param().transposeB = false; | |||
// dx | |||
matrixmul_opr->exec(tmp, this->weight_ih, dx, new_workspace); | |||
// dhx | |||
matrixmul_opr->exec(tmp, this->weight_hh, dstates[0], new_workspace); | |||
// dwi | |||
matrixmul_opr->param().transposeA = true; | |||
matrixmul_opr->exec(tmp, x, dwi, new_workspace); | |||
// dwh | |||
matrixmul_opr->exec(tmp, states[0], dwh, new_workspace); | |||
// dbias | |||
auto sum_opr = handle->create_operator<ReduceForward>(); | |||
sum_opr->param().mode = ReduceForward::Mode::SUM; | |||
sum_opr->param().axis = 0; | |||
TensorND dbias_expanded = { | |||
dbias.raw_ptr(), {{1, dbias.layout.shape[0]}, dbias.layout.dtype}}; | |||
sum_opr->exec(tmp, dbias_expanded, new_workspace); | |||
} | |||
size_t CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
size_t num_chunks, DType dtype) { | |||
size_t gate_hidden_size = hidden_size * num_chunks; | |||
TensorLayout tmp = {{batch_size, gate_hidden_size}, dtype}; | |||
TensorLayout bias_expanded = {{1, gate_hidden_size}, dtype}; | |||
TensorLayout wih = {{gate_hidden_size, input_size}, dtype}; | |||
TensorLayout whh = {{gate_hidden_size, hidden_size}, dtype}; | |||
TensorLayout x = {{batch_size, input_size}, dtype}; | |||
TensorLayout hx = {{batch_size, hidden_size}, dtype}; | |||
size_t workspace_size = 0; | |||
auto matrixmul_opr = handle->create_operator<MatrixMulForward>(); | |||
matrixmul_opr->param().transposeA = false; | |||
matrixmul_opr->param().transposeB = false; | |||
// dx | |||
workspace_size = std::max( | |||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, wih, x)); | |||
// dhx | |||
workspace_size = std::max( | |||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, whh, hx)); | |||
// dwi | |||
matrixmul_opr->param().transposeA = true; | |||
workspace_size = std::max( | |||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, x, wih)); | |||
// dwh | |||
workspace_size = std::max( | |||
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, hx, whh)); | |||
// dbias | |||
auto sum_opr = handle->create_operator<ReduceForward>(); | |||
sum_opr->param().mode = ReduceForward::Mode::SUM; | |||
sum_opr->param().axis = 0; | |||
workspace_size = std::max( | |||
workspace_size, sum_opr->get_workspace_in_bytes(tmp, bias_expanded)); | |||
workspace_size += tmp.span().dist_byte(); | |||
return workspace_size; | |||
} | |||
RNNCellWeightWrapper::RNNCellWeightWrapper( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
DType dtype, _megdnn_workspace workspace) | |||
: CellWeightsWrapperBase( | |||
weight_ptr, hidden_size, input_size, 1, has_bias, dtype, workspace) {} | |||
size_t RNNCellWeightWrapper::backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
DType dtype) { | |||
return CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||
handle, batch_size, hidden_size, input_size, 1, dtype); | |||
} | |||
LSTMCellWeightWrapper::LSTMCellWeightWrapper( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
DType dtype, _megdnn_workspace workspace) | |||
: CellWeightsWrapperBase( | |||
weight_ptr, hidden_size, input_size, 4, has_bias, dtype, workspace) {} | |||
size_t LSTMCellWeightWrapper::num_states() const { | |||
return 2; | |||
} | |||
size_t LSTMCellWeightWrapper::backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
DType dtype) { | |||
// get gates size | |||
size_t gate_hidden_size = 4 * hidden_size; | |||
auto lstm_opr = handle->create_operator<LSTMCellForward>(); | |||
TensorLayout x = {{batch_size, input_size}, dtype}; | |||
TensorLayout weight_ih = {{gate_hidden_size, input_size}, dtype}; | |||
TensorLayout weight_hh = {{gate_hidden_size, hidden_size}, dtype}; | |||
TensorLayout bias = {{gate_hidden_size}, dtype}; | |||
TensorLayout h = {{batch_size, hidden_size}, dtype}; | |||
TensorLayout gates, h_new, c_new; | |||
lstm_opr->deduce_layout( | |||
x, weight_ih, bias, h, weight_hh, bias, h, h_new, c_new, gates); | |||
return CellWeightsWrapperBase::backward_workspace_size_in_bytes( | |||
handle, batch_size, hidden_size, input_size, 4, dtype) + | |||
gates.span().dist_byte() * 2 + c_new.span().dist_byte(); | |||
} | |||
void LSTMCellWeightWrapper::backward( | |||
Handle* handle, | |||
param::RNNCell::NonlineMode nonlineMode, // nonlineMode must be identity | |||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) const { | |||
size_t used_workspace_size = 0; | |||
// get gates | |||
auto lstm_opr = handle->create_operator<LSTMCellForward>(); | |||
TensorLayout gates, h_new, c_new; | |||
lstm_opr->deduce_layout( | |||
x.layout, weight_ih.layout, bias_ih.layout, states[0].layout, | |||
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates); | |||
TensorND gates_tensor{workspace.raw_ptr, gates}; | |||
used_workspace_size += gates.span().dist_byte(); | |||
TensorND gates_grad{workspace.raw_ptr + used_workspace_size, gates}; | |||
used_workspace_size += gates.span().dist_byte(); | |||
TensorND tanh_cy{workspace.raw_ptr + used_workspace_size, y.layout}; | |||
used_workspace_size += tanh_cy.layout.span().dist_byte(); | |||
Workspace new_workspace = Workspace( | |||
workspace.raw_ptr + used_workspace_size, | |||
workspace.size - used_workspace_size); | |||
// temporarily use dstates to store hy, cy | |||
// only gates and cy needed, other output will be cleared afterwards | |||
lstm_opr->exec( | |||
x, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], dstates[0], | |||
dstates[1], gates_tensor, | |||
new_workspace); // no information left in the workspace | |||
// i, f, o, g | |||
TensorLayout single_gate = {{gates.shape[0], gates.shape[1] / 4}, gates.dtype}; | |||
TensorND i, f, o, g, i_grad, f_grad, o_grad, | |||
g_grad; // grad refers to the grad of gates before activation | |||
i = {gates_tensor.raw_ptr(), single_gate}; | |||
f = {static_cast<uint8_t*>(gates_tensor.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
o = {static_cast<uint8_t*>(f.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
g = {static_cast<uint8_t*>(o.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
i_grad = {gates_grad.raw_ptr(), single_gate}; | |||
f_grad = { | |||
static_cast<uint8_t*>(i_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
o_grad = { | |||
static_cast<uint8_t*>(f_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
g_grad = { | |||
static_cast<uint8_t*>(o_grad.raw_ptr()) + single_gate.span().dist_byte(), | |||
single_gate}; | |||
// activation | |||
auto elem_opr = handle->create_operator<ElemwiseForward>(); | |||
elem_opr->param().mode = Elemwise::Mode::SIGMOID; | |||
elem_opr->exec({i}, i); | |||
elem_opr->exec({f}, f); | |||
elem_opr->exec({o}, o); | |||
elem_opr->param().mode = Elemwise::Mode::TANH; | |||
elem_opr->exec({g}, g); | |||
elem_opr->exec({dstates[1]}, tanh_cy); | |||
auto mul_opr = handle->create_operator<ElemwiseForward>(); | |||
mul_opr->param().mode = Elemwise::Mode::MUL; | |||
// use dstates[0] as tmp tensor to store dhy * tanh_cy | |||
mul_opr->exec({douts[0], tanh_cy}, dstates[0]); | |||
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; | |||
elem_opr->exec({o, dstates[0]}, o_grad); // grad of gate o | |||
// use dstates[0] as tmp tensor to store dhy * o | |||
mul_opr->exec({douts[0], o}, dstates[0]); | |||
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; | |||
elem_opr->exec({tanh_cy, dstates[0]}, dstates[1]); // grad of cy from hy | |||
elem_opr->param().mode = Elemwise::Mode::ADD; | |||
elem_opr->exec({douts[1], dstates[1]}, dstates[1]); // true grad of cy | |||
// use dstates[0] as tmp tensor to store dcy * cx | |||
mul_opr->exec({dstates[1], states[1]}, dstates[0]); | |||
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD; | |||
elem_opr->exec({f, dstates[0]}, f_grad); // grad of gate f | |||
// use dstates[0] as tmp tensor to store dcy * g | |||
mul_opr->exec({dstates[1], g}, dstates[0]); | |||
elem_opr->exec({i, dstates[0]}, i_grad); // grad of gate i | |||
// use dstates[0] as tmp tensor to store dcy * i | |||
mul_opr->exec({dstates[1], i}, dstates[0]); | |||
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD; | |||
elem_opr->exec({g, dstates[0]}, g_grad); // grad of gate g | |||
// grad if cx | |||
mul_opr->exec({dstates[1], f}, dstates[1]); | |||
TensorNDArray base_dstates = {dstates[0]}; | |||
CellWeightsWrapperBase::backward( | |||
handle, nonlineMode, x, {states[0]}, gates_tensor, {gates_grad}, dx, | |||
base_dstates, dwi, dwh, dbias, new_workspace); | |||
} | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,73 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/rnn.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "megdnn/oprs/general.h" | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
class CellWeightsWrapperBase { | |||
private: | |||
size_t _weight_size, _workspace_size; | |||
public: | |||
TensorND weight_ih, weight_hh, bias_ih, bias_hh; | |||
// if no bias, will create dummy bias tensor from workspace | |||
CellWeightsWrapperBase( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks, | |||
bool has_bias, DType dtype, _megdnn_workspace workspace); | |||
size_t weight_size_in_bytes() const; | |||
size_t workspace_size_in_bytes() const; | |||
static size_t backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
size_t num_chunks, DType dtype); | |||
virtual void backward( | |||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, | |||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) const; | |||
virtual size_t num_states() const; | |||
}; | |||
class RNNCellWeightWrapper : public CellWeightsWrapperBase { | |||
public: | |||
RNNCellWeightWrapper( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
DType dtype, _megdnn_workspace workspace); | |||
static size_t backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
DType dtype); | |||
}; | |||
class LSTMCellWeightWrapper : public CellWeightsWrapperBase { | |||
public: | |||
LSTMCellWeightWrapper( | |||
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||
DType dtype, _megdnn_workspace workspace); | |||
static size_t backward_workspace_size_in_bytes( | |||
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size, | |||
DType dtype); | |||
size_t num_states() const override; | |||
void backward( | |||
Handle* handle, param::RNNCell::NonlineMode nonlineMode, | |||
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y, | |||
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates, | |||
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias, | |||
_megdnn_workspace workspace) const override; | |||
}; | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,41 @@ | |||
/** | |||
* \file dnn/src/naive/rnn/template_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/rnn/funcs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
namespace rnn { | |||
template <> | |||
void cell_opr_exec<RNNCellForward>( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in bias_hh, const TensorNDArray& states, | |||
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) { | |||
auto opr = handle->create_operator<RNNCellForward>(); | |||
opr->exec( | |||
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states_new[0], | |||
workspace); | |||
} | |||
template <> | |||
size_t cell_opr_get_workspace_in_bytes<RNNCellForward>( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_ih, | |||
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) { | |||
auto cell_opr = handle->create_operator<RNNCellForward>(); | |||
return cell_opr->get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, hx); | |||
} | |||
} // namespace rnn | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,34 @@ | |||
/** | |||
* \file dnn/src/naive/rnn_cell/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/rnn_cell/opr_impl.h" | |||
#include "src/common/rnn_cell.h" | |||
namespace megdnn { | |||
namespace naive { | |||
size_t RNNCellImpl::get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst) { | |||
return megdnn::rnn_cell::get_workspace_in_bytes( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle()); | |||
} | |||
void RNNCellImpl::exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
megdnn::rnn_cell::exec( | |||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace, | |||
param().nonlineMode, handle()); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/naive/rnn_cell/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class RNNCellImpl : public RNNCell { | |||
public: | |||
using RNNCell::RNNCell; | |||
void exec( | |||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& input, const TensorLayout& weight_ih, | |||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||
const TensorLayout& dst) override; | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn |
@@ -68,6 +68,15 @@ struct DeduceLayoutProxy<Opr, 6, false> { | |||
}; | |||
template <typename Opr> | |||
struct DeduceLayoutProxy<Opr, 6, true> { | |||
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { | |||
megdnn_assert(layouts.size() == 6); | |||
opr->deduce_layout( | |||
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); | |||
} | |||
}; | |||
template <typename Opr> | |||
struct DeduceLayoutProxy<Opr, 7, false> { | |||
static void deduce_layout(Opr*, TensorLayoutArray&) {} | |||
}; | |||
@@ -248,7 +248,6 @@ DEF_TEST(fuse_mul_add4) { | |||
} | |||
DEF_TEST(rmulh) { | |||
using Mode = ElemwiseForward::Param::Mode; | |||
Checker<ElemwiseForward> checker(handle); | |||
@@ -315,7 +314,7 @@ DEF_TEST(unary1) { | |||
// case float | |||
UniformFloatRNG rng(1e-2, 6e1); | |||
checker.set_rng(0, &rng); | |||
checker.set_epsilon(1e-5); | |||
checker.set_epsilon(1e-5); | |||
checker.set_dtype(0, dtype::Float32()); | |||
BUILD_UNARY_TEST_CASE_FLOAT | |||
} | |||
@@ -900,7 +899,7 @@ DEF_TEST(unary_negative_stride) { | |||
UniformFloatRNG rng(1e-2, 6e1); | |||
checker.set_rng(0, &rng); | |||
checker.set_epsilon(1e-5); | |||
checker.set_epsilon(1e-5); | |||
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT; | |||
} | |||
@@ -0,0 +1,51 @@ | |||
/** | |||
* \file dnn/test/common/rnn.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/opr_param_defs.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace rnn { | |||
struct TestArg { | |||
param::RNN param; | |||
TensorShape input, hx, flatten_weights; | |||
TestArg(param::RNN param, TensorShape input, TensorShape hx, | |||
TensorShape flatten_weights) | |||
: param(param), input(input), hx(hx), flatten_weights(flatten_weights) {} | |||
}; | |||
inline std::vector<TestArg> get_args() { | |||
std::vector<TestArg> args; | |||
size_t batch_size = 2; | |||
size_t input_size = 3; | |||
size_t hidden_size = 2; | |||
size_t seq_len = 2; | |||
size_t gate_hidden_size = hidden_size; | |||
param::RNN param; | |||
param.num_layers = 1; | |||
param.bidirectional = false; | |||
param.bias = false; | |||
param.hidden_size = hidden_size; | |||
param.nonlineMode = param::RNN::NonlineMode::RELU; | |||
args.emplace_back( | |||
param, TensorShape{seq_len, batch_size, input_size}, | |||
TensorShape{batch_size, hidden_size}, | |||
TensorShape{gate_hidden_size, input_size + hidden_size}); | |||
return args; | |||
} | |||
} // namespace rnn | |||
} // namespace test | |||
} // namespace megdnn |
@@ -0,0 +1,80 @@ | |||
/** | |||
* \file dnn/test/naive/rnn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "test/common/rnn.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
#include "test/naive/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
/*TEST_F(NAIVE, RNN) { | |||
std::vector<rnn::TestArg> args = rnn::get_args(); | |||
Checker<RNN> checker(handle()); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Float32()) | |||
.set_dtype(3, dtype::Float32()) | |||
.set_dtype(4, dtype::Float32()) | |||
.set_dtype(5, dtype::Float32()) | |||
.execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}}); | |||
} | |||
}*/ | |||
TEST_F(NAIVE, RNN_HAND_MADE) { | |||
Checker<RNN> checker(handle(), false); | |||
size_t batch_size = 2; | |||
size_t input_size = 3; | |||
size_t hidden_size = 2; | |||
size_t seq_len = 2; | |||
size_t gate_hidden_size = hidden_size; | |||
RNN::Param param; | |||
param.num_layers = 1; | |||
param.bidirectional = false; | |||
param.bias = false; | |||
param.hidden_size = hidden_size; | |||
param.nonlineMode = param::RNN::NonlineMode::RELU; | |||
checker.set_param(param).exect( | |||
Testcase{ | |||
TensorValue( | |||
{seq_len, batch_size, input_size}, dtype::Float32(), | |||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input | |||
TensorValue( | |||
{batch_size, hidden_size}, dtype::Float32(), | |||
{2, 1, 3, 5}), // hx | |||
TensorValue( | |||
{gate_hidden_size, input_size + hidden_size}, | |||
dtype::Float32(), | |||
{3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights | |||
{}, | |||
{}, | |||
{}}, | |||
Testcase{ | |||
{}, | |||
{}, | |||
{}, | |||
TensorValue( | |||
{seq_len, batch_size, hidden_size}, dtype::Float32(), | |||
{39, 39, 90, 84, 300, 216, 546, 366}), // output | |||
TensorValue( | |||
{batch_size, hidden_size}, dtype::Float32(), | |||
{21, 11, 42, 20}), // hy | |||
TensorValue( | |||
{1, 2, 2, 2}, dtype::Float32(), | |||
{2, 1, 3, 5, 21, 11, 42, 20}) // reserve space | |||
}); | |||
} | |||
} // namespace test | |||
} // namespace megdnn |
@@ -36,5 +36,6 @@ from .padding import Pad | |||
from .pixel_shuffle import PixelShuffle | |||
from .pooling import AvgPool2d, MaxPool2d | |||
from .quant_dequant import DequantStub, QuantStub | |||
from .rnn import LSTM, RNN, LSTMCell, RNNCell | |||
from .sequential import Sequential | |||
from .sliding_window import SlidingWindow, SlidingWindowTranspose |
@@ -0,0 +1,396 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import math | |||
import numbers | |||
from abc import abstractmethod | |||
from typing import Optional, Tuple | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.ops import builtin | |||
from ..device import is_cuda_available | |||
from ..functional import concat, expand_dims, repeat, stack, zeros | |||
from ..functional.nn import concat | |||
from ..tensor import Parameter, Tensor | |||
from . import init | |||
from .module import Module | |||
class RNNCellBase(Module): | |||
def __init__( | |||
self, | |||
input_size: int, | |||
hidden_size: int, | |||
bias: bool, | |||
num_chunks: int, | |||
device=None, | |||
dtype=None, | |||
) -> None: | |||
# num_chunks indicates the number of gates | |||
super(RNNCellBase, self).__init__() | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.bias = bias | |||
# initialize weights | |||
common_kwargs = {"device": device, "dtype": dtype} | |||
self.gate_hidden_size = num_chunks * hidden_size | |||
self.weight_ih = Parameter( | |||
np.random.uniform(size=(self.gate_hidden_size, input_size)).astype( | |||
np.float32 | |||
), | |||
**common_kwargs, | |||
) | |||
self.weight_hh = Parameter( | |||
np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype( | |||
np.float32 | |||
), | |||
**common_kwargs, | |||
) | |||
if bias: | |||
self.bias_ih = Parameter( | |||
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), | |||
**common_kwargs, | |||
) | |||
self.bias_hh = Parameter( | |||
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32), | |||
**common_kwargs, | |||
) | |||
else: | |||
self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs) | |||
self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs) | |||
self.reset_parameters() | |||
# if bias is False self.bias will remain zero | |||
def get_op(self): | |||
return builtin.RNNCell() | |||
def reset_parameters(self) -> None: | |||
stdv = 1.0 / math.sqrt(self.hidden_size) | |||
for weight in self.parameters(): | |||
init.uniform_(weight, -stdv, stdv) | |||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: | |||
if hx is None: | |||
hx = zeros( | |||
shape=(input.shape[0], self.gate_hidden_size), | |||
dtype=input.dtype, | |||
device=input.device, | |||
) | |||
op = self.get_op() | |||
return apply( | |||
op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh | |||
)[0] | |||
# return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh) | |||
class RNNCell(RNNCellBase): | |||
def __init__( | |||
self, | |||
input_size: int, | |||
hidden_size: int, | |||
bias: bool = True, | |||
nonlinearity: str = "tanh", | |||
device=None, | |||
dtype=None, | |||
) -> None: | |||
self.nonlinearity = nonlinearity | |||
super(RNNCell, self).__init__( | |||
input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype | |||
) | |||
# self.activate = tanh if nonlinearity == "tanh" else relu | |||
def get_op(self): | |||
return builtin.RNNCell(nonlineMode=self.nonlinearity) | |||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: | |||
return super().forward(input, hx) | |||
class LSTMCell(RNNCellBase): | |||
def __init__( | |||
self, | |||
input_size: int, | |||
hidden_size: int, | |||
bias: bool = True, | |||
device=None, | |||
dtype=None, | |||
) -> None: | |||
super(LSTMCell, self).__init__( | |||
input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype | |||
) | |||
def get_op(self): | |||
return builtin.LSTMCell() | |||
def forward( | |||
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None | |||
) -> Tuple[Tensor, Tensor]: | |||
# hx: (h, c) | |||
if hx is None: | |||
h = zeros( | |||
shape=(input.shape[0], self.hidden_size), | |||
dtype=input.dtype, | |||
device=input.device, | |||
) | |||
c = zeros( | |||
shape=(input.shape[0], self.hidden_size), | |||
dtype=input.dtype, | |||
device=input.device, | |||
) | |||
else: | |||
h, c = hx | |||
op = self.get_op() | |||
return apply( | |||
op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c | |||
)[:2] | |||
def is_gpu(device: str) -> bool: | |||
if "xpux" in device and is_cuda_available(): | |||
return True | |||
if "gpu" in device: | |||
return True | |||
return False | |||
class RNNBase(Module): | |||
def __init__( | |||
self, | |||
input_size: int, | |||
hidden_size: int, | |||
num_layers: int = 1, | |||
bias: bool = True, | |||
batch_first: bool = False, | |||
dropout: float = 0, | |||
bidirectional: bool = False, | |||
proj_size: int = 0, | |||
device=None, | |||
dtype=None, | |||
) -> None: | |||
super(RNNBase, self).__init__() | |||
# self.mode = mode | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.bias = bias | |||
self.batch_first = batch_first | |||
self.dropout = float(dropout) | |||
self.bidirectional = bidirectional | |||
self.num_directions = 2 if self.bidirectional else 1 | |||
self.proj_size = proj_size | |||
# check validity of dropout | |||
if ( | |||
not isinstance(dropout, numbers.Number) | |||
or not 0 <= dropout <= 1 | |||
or isinstance(dropout, bool) | |||
): | |||
raise ValueError( | |||
"Dropout should be a float in [0, 1], which indicates the probability " | |||
"of an element to be zero" | |||
) | |||
if proj_size < 0: | |||
raise ValueError( | |||
"proj_size should be a positive integer or zero to disable projections" | |||
) | |||
elif proj_size >= hidden_size: | |||
raise ValueError("proj_size has to be smaller than hidden_size") | |||
self.cells = [] | |||
for layer in range(self.num_layers): | |||
self.cells.append([]) | |||
for _ in range(self.num_directions): | |||
self.cells[layer].append(self.create_cell(layer, device, dtype)) | |||
# parameters have been initialized during the creation of the cells | |||
# if flatten, then delete cells | |||
self._flatten_parameters(device, dtype, self.cells) | |||
def _flatten_parameters(self, device, dtype, cells): | |||
gate_hidden_size = cells[0][0].gate_hidden_size | |||
size_dim1 = 0 | |||
for layer in range(self.num_layers): | |||
for direction in range(self.num_directions): | |||
size_dim1 += cells[layer][direction].weight_ih.shape[1] | |||
size_dim1 += cells[layer][direction].weight_hh.shape[1] | |||
# if self.bias: | |||
# size_dim1 += 2 * self.num_directions * self.num_layers | |||
size_dim1 += 2 * self.num_directions * self.num_layers | |||
self._flatten_weights = Parameter( | |||
np.zeros((gate_hidden_size, size_dim1), dtype=np.float32) | |||
) | |||
self.reset_parameters() | |||
# TODO: if no bias, set the bias to zero | |||
def reset_parameters(self) -> None: | |||
stdv = 1.0 / math.sqrt(self.hidden_size) | |||
for weight in self.parameters(): | |||
init.uniform_(weight, -stdv, stdv) | |||
@abstractmethod | |||
def create_cell(self, layer, device, dtype): | |||
raise NotImplementedError("Cell not implemented !") | |||
@abstractmethod | |||
def init_hidden(self): | |||
raise NotImplementedError("init_hidden not implemented !") | |||
@abstractmethod | |||
def get_output_from_hidden(self, hx): | |||
raise NotImplementedError("get_output_from_hidden not implemented !") | |||
@abstractmethod | |||
def apply_op(self, input, hx): | |||
raise NotImplementedError("apply_op not implemented !") | |||
def _apply_fn_to_hx(self, hx, fn): | |||
return fn(hx) | |||
def _stack_h_n(self, h_n): | |||
return stack(h_n, axis=0) | |||
def forward(self, input: Tensor, hx=None): | |||
if self.batch_first: | |||
batch_size = input.shape[0] | |||
input = input.transpose((1, 0, 2)) # [seq_len, batch_size, dim] | |||
else: | |||
batch_size = input.shape[1] | |||
if hx is None: | |||
hx = self.init_hidden(batch_size, input.device, input.dtype) | |||
output, h = self.apply_op(input, hx) | |||
if self.batch_first: | |||
output = output.transpose((1, 0, 2)) | |||
return output, h | |||
if is_gpu(str(input.device)) or True: | |||
# return output, h_n | |||
output, h = self.apply_op(input, hx) | |||
if self.batch_first: | |||
output = output.transpose((1, 0, 2)) | |||
return output, h | |||
order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)] | |||
h_n = [] | |||
for layer in range(self.num_layers): | |||
layer_outputs = [] | |||
for direction in range(self.num_directions): | |||
direction_outputs = [None for _ in range(input.shape[0])] | |||
cell = self.cells[layer][direction] | |||
hidden = self._apply_fn_to_hx( | |||
hx, lambda x: x[layer * self.num_directions + direction] | |||
) | |||
for step in range(*(order_settings[direction])): | |||
hidden = cell(input[step], hidden) # [batch_size, hidden_size] | |||
direction_outputs[step] = self.get_output_from_hidden(hidden) | |||
direction_output = stack( | |||
direction_outputs, axis=0 | |||
) # [seq_len, batch_size, hidden_size] | |||
layer_outputs.append(direction_output) | |||
h_n.append(hidden) | |||
layer_output = concat( | |||
layer_outputs, axis=-1 | |||
) # [seq_len, batch_size, D*hidden_size] | |||
input = layer_output | |||
if self.batch_first: | |||
layer_output = layer_output.transpose((1, 0, 2)) | |||
return layer_output, self._stack_h_n(h_n) | |||
class RNN(RNNBase): | |||
def __init__(self, *args, **kwargs) -> None: | |||
self.nonlinearity = kwargs.pop("nonlinearity", "tanh") | |||
super(RNN, self).__init__(*args, **kwargs) | |||
def create_cell(self, layer, device, dtype): | |||
if layer == 0: | |||
input_size = self.input_size | |||
else: | |||
input_size = self.num_directions * self.hidden_size | |||
return RNNCell( | |||
input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype | |||
) | |||
def init_hidden(self, batch_size, device, dtype): | |||
hidden_shape = ( | |||
self.num_directions * self.num_layers, | |||
batch_size, | |||
self.hidden_size, | |||
) | |||
return zeros(shape=hidden_shape, dtype=dtype, device=device) | |||
def get_output_from_hidden(self, hx): | |||
return hx | |||
def apply_op(self, input, hx): | |||
op = builtin.RNN( | |||
num_layers=self.num_layers, | |||
bidirectional=self.bidirectional, | |||
bias=self.bias, | |||
hidden_size=self.hidden_size, | |||
proj_size=self.proj_size, | |||
dropout=self.dropout, | |||
nonlineMode=self.nonlinearity, | |||
) | |||
output, h = apply(op, input, hx, self._flatten_weights)[:2] | |||
output = output + h.sum() * 0 | |||
h = h + output.sum() * 0 | |||
return output, h | |||
class LSTM(RNNBase): | |||
def __init__(self, *args, **kwargs) -> None: | |||
super(LSTM, self).__init__(*args, **kwargs) | |||
def create_cell(self, layer, device, dtype): | |||
if layer == 0: | |||
input_size = self.input_size | |||
else: | |||
input_size = self.num_directions * self.hidden_size | |||
return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype) | |||
def init_hidden(self, batch_size, device, dtype): | |||
hidden_shape = ( | |||
self.num_directions * self.num_layers, | |||
batch_size, | |||
self.hidden_size, | |||
) | |||
h = zeros(shape=hidden_shape, dtype=dtype, device=device) | |||
c = zeros(shape=hidden_shape, dtype=dtype, device=device) | |||
return (h, c) | |||
def get_output_from_hidden(self, hx): | |||
return hx[0] | |||
def apply_op(self, input, hx): | |||
op = builtin.LSTM( | |||
num_layers=self.num_layers, | |||
bidirectional=self.bidirectional, | |||
bias=self.bias, | |||
hidden_size=self.hidden_size, | |||
proj_size=self.proj_size, | |||
dropout=self.dropout, | |||
) | |||
output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3] | |||
placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0] | |||
output = output + placeholders[1] + placeholders[2] | |||
h = h + placeholders[0] + placeholders[2] | |||
c = c + placeholders[0] + placeholders[1] | |||
return output, (h, c) | |||
def _apply_fn_to_hx(self, hx, fn): | |||
return (fn(hx[0]), fn(hx[1])) | |||
def _stack_h_n(self, h_n): | |||
h = [tup[0] for tup in h_n] | |||
c = [tup[1] for tup in h_n] | |||
return (stack(h, axis=0), stack(c, axis=0)) |
@@ -0,0 +1,181 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine.module import LSTM, RNN, LSTMCell, RNNCell | |||
def assert_tuple_equal(src, ref): | |||
assert len(src) == len(ref) | |||
for i, j in zip(src, ref): | |||
assert i == j | |||
@pytest.mark.parametrize( | |||
"batch_size, input_size, hidden_size, init_hidden", | |||
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False), (0, 10, 20, True)], | |||
) | |||
def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden): | |||
rnn_cell = RNNCell(input_size, hidden_size) | |||
x = mge.random.normal(size=(batch_size, input_size)) | |||
if init_hidden: | |||
h = F.zeros(shape=(batch_size, hidden_size)) | |||
else: | |||
h = None | |||
h_new = rnn_cell(x, h) | |||
assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) | |||
# is batch_size == 0 tolerated ? it will cause error in slice operation xx[:, ...] | |||
@pytest.mark.parametrize( | |||
"batch_size, input_size, hidden_size, init_hidden", | |||
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)], | |||
) | |||
def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden): | |||
rnn_cell = LSTMCell(input_size, hidden_size) | |||
x = mge.random.normal(size=(batch_size, input_size)) | |||
if init_hidden: | |||
h = F.zeros(shape=(batch_size, hidden_size)) | |||
hx = (h, h) | |||
else: | |||
hx = None | |||
h_new, c_new = rnn_cell(x, hx) | |||
assert_tuple_equal(h_new.shape, (batch_size, hidden_size)) | |||
assert_tuple_equal(c_new.shape, (batch_size, hidden_size)) | |||
@pytest.mark.parametrize( | |||
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", | |||
[ | |||
(3, 6, 10, 20, 2, False, False, True), | |||
pytest.param( | |||
3, | |||
3, | |||
10, | |||
10, | |||
1, | |||
True, | |||
True, | |||
False, | |||
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), | |||
), | |||
], | |||
) | |||
# (0, 1, 1, 1, 1, False, True, False)]) | |||
def test_rnn( | |||
batch_size, | |||
seq_len, | |||
input_size, | |||
hidden_size, | |||
num_layers, | |||
bidirectional, | |||
init_hidden, | |||
batch_first, | |||
): | |||
rnn = RNN( | |||
input_size, | |||
hidden_size, | |||
batch_first=batch_first, | |||
num_layers=num_layers, | |||
bidirectional=bidirectional, | |||
) | |||
if batch_first: | |||
x_shape = (batch_size, seq_len, input_size) | |||
else: | |||
x_shape = (seq_len, batch_size, input_size) | |||
x = mge.random.normal(size=x_shape) | |||
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size | |||
if init_hidden: | |||
h = mge.random.normal(size=(batch_size, total_hidden_size)) | |||
else: | |||
h = None | |||
output, h_n = rnn(x, h) | |||
num_directions = 2 if bidirectional else 1 | |||
if batch_first: | |||
assert_tuple_equal( | |||
output.shape, (batch_size, seq_len, num_directions * hidden_size) | |||
) | |||
else: | |||
assert_tuple_equal( | |||
output.shape, (seq_len, batch_size, num_directions * hidden_size) | |||
) | |||
assert_tuple_equal( | |||
h_n.shape, (num_directions * num_layers, batch_size, hidden_size) | |||
) | |||
@pytest.mark.parametrize( | |||
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first", | |||
[ | |||
(3, 10, 20, 20, 1, False, False, True), | |||
pytest.param( | |||
3, | |||
3, | |||
10, | |||
10, | |||
1, | |||
True, | |||
True, | |||
False, | |||
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"), | |||
), | |||
], | |||
) | |||
# (0, 1, 1, 1, 1, False, True, False)]) | |||
def test_lstm( | |||
batch_size, | |||
seq_len, | |||
input_size, | |||
hidden_size, | |||
num_layers, | |||
bidirectional, | |||
init_hidden, | |||
batch_first, | |||
): | |||
rnn = LSTM( | |||
input_size, | |||
hidden_size, | |||
batch_first=batch_first, | |||
num_layers=num_layers, | |||
bidirectional=bidirectional, | |||
) | |||
if batch_first: | |||
x_shape = (batch_size, seq_len, input_size) | |||
else: | |||
x_shape = (seq_len, batch_size, input_size) | |||
x = mge.random.normal(size=x_shape) | |||
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size | |||
if init_hidden: | |||
h = mge.random.normal(size=(batch_size, total_hidden_size)) | |||
h = (h, h) | |||
else: | |||
h = None | |||
output, h_n = rnn(x, h) | |||
num_directions = 2 if bidirectional else 1 | |||
if batch_first: | |||
assert_tuple_equal( | |||
output.shape, (batch_size, seq_len, num_directions * hidden_size) | |||
) | |||
else: | |||
assert_tuple_equal( | |||
output.shape, (seq_len, batch_size, num_directions * hidden_size) | |||
) | |||
assert_tuple_equal( | |||
h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size) | |||
) | |||
assert_tuple_equal( | |||
h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size) | |||
) | |||
if __name__ == "__main__": | |||
test_lstm(5, 10, 10, 20, 1, False, False, True) |
@@ -0,0 +1,68 @@ | |||
/** | |||
* \file imperative/src/impl/ops/rnn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/opr/dnn/rnn.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace rnn_cell { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const RNNCell&>(def); | |||
mgb_assert(inputs.size() == 6); | |||
return opr::RNNCell::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5], | |||
op.param()); | |||
} | |||
OP_TRAIT_REG(RNNCell, RNNCell).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace rnn_cell | |||
namespace lstm_cell { | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LSTMCell&>(def); | |||
mgb_assert(inputs.size() == 7); | |||
auto* opr = opr::LSTMCell::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], | |||
inputs[5], inputs[6], op.param()) | |||
.node() | |||
->owner_opr(); | |||
return {opr->output(0), opr->output(1), opr->output(2)}; | |||
} | |||
OP_TRAIT_REG(LSTMCell, LSTMCell).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace lstm_cell | |||
namespace rnn { | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const RNN&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
auto* opr = opr::RNN::make(inputs[0], inputs[1], inputs[2], op.param()) | |||
.node() | |||
->owner_opr(); | |||
return {opr->output(0), opr->output(1), opr->output(2)}; | |||
} | |||
OP_TRAIT_REG(RNN, RNN).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace rnn | |||
namespace lstm { | |||
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LSTM&>(def); | |||
mgb_assert(inputs.size() == 4); | |||
auto* opr = opr::LSTM::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param()) | |||
.node() | |||
->owner_opr(); | |||
return {opr->output(0), opr->output(1), opr->output(2), opr->output(3)}; | |||
} | |||
OP_TRAIT_REG(LSTM, LSTM).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace lstm | |||
} // namespace mgb::imperative |
@@ -431,4 +431,12 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>; | |||
def LRN: MgbHashableOp<"LRN", [LRNParam]>; | |||
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; | |||
def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>; | |||
def RNN: MgbHashableOp<"RNN", [RNNParam]>; | |||
def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>; | |||
#endif // MGB_OPS |
@@ -20,6 +20,7 @@ | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/lsq.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/rnn.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/dnn/sliding_window_transpose.h" | |||
@@ -292,6 +293,36 @@ struct OprMaker<opr::LSQBackward, 5> { | |||
->owner_opr(); | |||
} | |||
}; | |||
template <> | |||
struct OprMaker<opr::RNNBackward, 7> { | |||
using Param = opr::RNNBackward::Param; | |||
static cg::OperatorNodeBase* make( | |||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
return opr::RNNBackward::make( | |||
i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
}; | |||
template <> | |||
struct OprMaker<opr::LSTMBackward, 9> { | |||
using Param = opr::LSTMBackward::Param; | |||
static cg::OperatorNodeBase* make( | |||
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, | |||
const OperatorNodeConfig& config) { | |||
MGB_MARK_USED_VAR(graph); | |||
return opr::LSTMBackward::make( | |||
i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param, | |||
config)[0] | |||
.node() | |||
->owner_opr(); | |||
} | |||
}; | |||
template <> | |||
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | |||
: public GeneralOprLoadDumpImpl< | |||
@@ -641,6 +672,10 @@ MGB_SEREG_OPR(TQT, 2); | |||
MGB_SEREG_OPR(TQTBackward, 3); | |||
MGB_SEREG_OPR(LSQ, 4); | |||
MGB_SEREG_OPR(LSQBackward, 5); | |||
MGB_SEREG_OPR(RNNForward, 3); | |||
MGB_SEREG_OPR(RNNBackward, 7); | |||
MGB_SEREG_OPR(LSTMForward, 4); | |||
MGB_SEREG_OPR(LSTMBackward, 9); | |||
} // namespace opr | |||
} // namespace mgb | |||
@@ -0,0 +1,323 @@ | |||
/** | |||
* \file src/opr/impl/dnn/rnn.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/opr/dnn/rnn.h" | |||
#include "../internal/megdnn_opr_wrapper.inl" | |||
#include "megbrain/graph/grad_impl.h" | |||
#include "megbrain/opr/basic_arith_wrapper.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
using namespace mgb; | |||
using namespace opr; | |||
/* ================= RNNCell ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNCellForward); | |||
RNNCellForward::RNNCellForward( | |||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||
VarNode* weight_hh, VarNode* bias_hh, const Param& param, | |||
const OperatorNodeConfig& config) | |||
: Super{input->owner_graph(), | |||
config, | |||
"rnn_cell", | |||
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh}); | |||
} | |||
SymbolVar RNNCellForward::make( | |||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param, | |||
const OperatorNodeConfig& config) { | |||
return input.insert_single_output_opr<RNNCellForward>( | |||
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), | |||
bias_hh.node(), param, config); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
VarNode* rnnCellBackward( | |||
const SymbolVar& input, const SymbolVar& weight_ih, const SymbolVar& hx, | |||
const SymbolVar& weight_hh, const SymbolVar& out, | |||
RNNCell::NonlineMode nonlineMode, size_t wrt_idx, const SymbolVar& og) { | |||
SymbolVar tmp; | |||
// activation | |||
using NonlineMode = RNNCell::NonlineMode; | |||
using Mode = Elemwise::Mode; | |||
switch (nonlineMode) { | |||
case NonlineMode::IDENTITY: | |||
tmp = og; | |||
break; | |||
case NonlineMode::TANH: | |||
tmp = Elemwise::make({out, og}, Mode::TANH_GRAD); | |||
break; | |||
case NonlineMode::RELU: | |||
tmp = Elemwise::make({out, og}, Mode::SWITCH_GT0); | |||
break; | |||
default: | |||
mgb_throw(GraphError, "Activation method not supported"); | |||
} | |||
// now grad is in tmp | |||
if (wrt_idx == 2 || wrt_idx == 5) | |||
return tmp.node(); // bias | |||
SymbolVar result; | |||
// A * Bt = C, A' = C' * B, B' = C't * A | |||
if (wrt_idx == 0) { // input | |||
result = MatrixMul::make( | |||
tmp, weight_ih, | |||
{false, false}); // transpose a false, transpose b false | |||
} else if (wrt_idx == 1) { // weight_ih | |||
result = MatrixMul::make(tmp, input, {true, false}); | |||
} else if (wrt_idx == 3) { // hx | |||
result = MatrixMul::make(tmp, weight_hh, {false, false}); | |||
} else if (wrt_idx == 4) { // weight_hh | |||
result = MatrixMul::make(tmp, hx, {true, false}); | |||
} | |||
return result.node(); | |||
} | |||
MGB_IMPL_OPR_GRAD(RNNCell) { | |||
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), | |||
weight_hh(opr.input(4)); | |||
SymbolVar out(opr.output(0)), og{out_grad.at(0)}; | |||
return rnnCellBackward( | |||
input, weight_ih, hx, weight_hh, out, opr.param().nonlineMode, wrt_idx, og); | |||
} | |||
#endif | |||
/* ================= LSTMCell ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMCell); | |||
LSTMCellForward::LSTMCellForward( | |||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param, | |||
const OperatorNodeConfig& config) | |||
: Super{input->owner_graph(), | |||
config, | |||
"lstm_cell", | |||
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}); | |||
} | |||
SymbolVar LSTMCellForward::make( | |||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, const Param& param, | |||
const OperatorNodeConfig& config) { | |||
return input.insert_single_output_opr<LSTMCellForward>( | |||
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(), | |||
bias_hh.node(), cx.node(), param, config); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(LSTMCell) { | |||
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)), | |||
weight_hh(opr.input(4)), cx(opr.input(6)); | |||
SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)), | |||
h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp; | |||
size_t ghs = gates.shape()[1] / 4; // gate_hidden_size | |||
SymbolVarArray gates_array = Split::make( | |||
gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs})); | |||
mgb_assert(gates_array.size() == 4); | |||
using Mode = Elemwise::Mode; | |||
const SymbolVar &i(Elemwise::make({gates_array.at(0)}, Mode::SIGMOID)), | |||
f(Elemwise::make({gates_array.at(1)}, Mode::SIGMOID)), | |||
o(Elemwise::make({gates_array.at(2)}, Mode::SIGMOID)), | |||
g(Elemwise::make({gates_array.at(3)}, Mode::TANH)); | |||
SymbolVar i_grad, f_grad, o_grad, g_grad; | |||
SymbolVar tanh_c_out = Elemwise::make({c_out}, Mode::TANH); | |||
o_grad = Elemwise::make({o, h_og * tanh_c_out}, Mode::SIGMOID_GRAD); | |||
c_og = c_og + Elemwise::make({tanh_c_out, h_og * o}, Mode::TANH_GRAD); | |||
f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD); | |||
i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD); | |||
g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD); | |||
SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1}); | |||
SymbolVar result; | |||
if (wrt_idx < 6) { | |||
using NonlineMode = RNNCell::NonlineMode; | |||
result = rnnCellBackward( | |||
input, weight_ih, hx, weight_hh, gates, NonlineMode::IDENTITY, wrt_idx, | |||
rnn_cell_grad); | |||
} else { // cx | |||
result = c_og * f; | |||
} | |||
return result.node(); | |||
} | |||
#endif | |||
/* ================= RNN ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNN); | |||
MEGDNN_OPR_INIT3(RNNForward, "rnn_fwd"); | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(RNN) { | |||
mgb_assert( | |||
opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING, | |||
"RNN could only take grad in training mode"); | |||
SymbolVarArray grads = RNNBackward::make( | |||
opr.input(0), opr.output(0), opr.input(1), out_grad.at(0), out_grad.at(1), | |||
opr.input(2), opr.output(2), opr.param()); | |||
// return grads.at(wrt_idx).node(); // input, hx, weights | |||
VarNodeArray ret(opr.input().size(), nullptr); | |||
for (size_t i = 0; i < ret.size(); ++i) { | |||
ret[i] = grads[i].node(); | |||
} | |||
return ret; | |||
} | |||
#endif | |||
/* ================= RNNBackward ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward); | |||
RNNBackward::RNNBackward( | |||
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy, | |||
VarNode* flatten_weights, VarNode* reserve_space, const Param& param, | |||
const OperatorNodeConfig& config) | |||
: Super({x->owner_graph(), | |||
config, | |||
"rnn_bwd", | |||
{x, y, hx, dy, dhy, flatten_weights, reserve_space}}, | |||
0, true) { | |||
init_megdnn_opr(*this, param); | |||
add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space}); | |||
} | |||
SymbolVarArray RNNBackward::make( | |||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy, | |||
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param, | |||
const OperatorNodeConfig& config) { | |||
auto&& out = x.node()->owner_graph() | |||
->insert_opr(std::make_unique<RNNBackward>( | |||
x.node(), y.node(), hx.node(), dy.node(), dhy.node(), | |||
flatten_weights.node(), reserve_space.node(), param, | |||
config)) | |||
->output(); | |||
SymbolVarArray ret(out.size()); | |||
for (size_t i = 0; i < ret.size(); ++i) { | |||
ret[i] = out[i]; | |||
} | |||
return ret; | |||
} | |||
RNNBackward::Super::NodeProp* RNNBackward::do_make_node_prop() const { | |||
auto ret = Super::do_make_node_prop(); | |||
ret->add_dep_type_existing_var(input(6), NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||
return ret; | |||
} | |||
void RNNBackward::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto&& mgr = owner_graph()->static_infer_manager(); | |||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); | |||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(5))); | |||
this->init_output_static_infer_desc_workspace( | |||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::RNNBackward>::val); | |||
} | |||
void RNNBackward::init_output_dtype() { | |||
output(0)->dtype(input(0)->dtype()); | |||
output(1)->dtype(input(2)->dtype()); | |||
output(2)->dtype(input(5)->dtype()); | |||
} | |||
/* ================= LSTM ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTM); | |||
LSTMForward::LSTMForward( | |||
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights, | |||
const Param& param, const OperatorNodeConfig& config) | |||
: Super{input->owner_graph(), | |||
config, | |||
"lstm", | |||
{input, hx, cx, flatten_weights}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({input, hx, cx, flatten_weights}); | |||
} | |||
SymbolVar LSTMForward::make( | |||
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights, | |||
const Param& param, const OperatorNodeConfig& config) { | |||
return input.insert_single_output_opr<LSTMForward>( | |||
input.node(), hx.node(), cx.node(), flatten_weights.node(), param, config); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(LSTM) { | |||
SymbolVarArray grads = LSTMBackward::make( | |||
opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0), | |||
out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3), opr.param()); | |||
SymbolVar res; | |||
return grads.at(wrt_idx).node(); // input, hx, cx, weights | |||
} | |||
#endif | |||
/* ================= LSTMBackward ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMBackward); | |||
LSTMBackward::LSTMBackward( | |||
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy, | |||
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space, | |||
const Param& param, const OperatorNodeConfig& config) | |||
: Super({x->owner_graph(), | |||
config, | |||
"lstm_bwd", | |||
{x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}}, | |||
1, true) { | |||
init_megdnn_opr(*this, param); | |||
add_input({x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}); | |||
} | |||
SymbolVarArray LSTMBackward::make( | |||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy, | |||
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights, | |||
SymbolVar reserve_space, const Param& param, const OperatorNodeConfig& config) { | |||
auto&& out = x.node()->owner_graph() | |||
->insert_opr(std::make_unique<LSTMBackward>( | |||
x.node(), y.node(), hx.node(), cx.node(), dy.node(), | |||
dhy.node(), dcy.node(), flatten_weights.node(), | |||
reserve_space.node(), param, config)) | |||
->output(); | |||
SymbolVarArray ret(out.size()); | |||
for (size_t i = 0; i < ret.size(); ++i) { | |||
ret[i] = out[i]; | |||
} | |||
return ret; | |||
} | |||
LSTMBackward::Super::NodeProp* LSTMBackward::do_make_node_prop() const { | |||
auto ret = Super::do_make_node_prop(); | |||
ret->add_dep_type_existing_var( | |||
input(8), // reserve space | |||
NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||
return ret; | |||
} | |||
void LSTMBackward::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto&& mgr = owner_graph()->static_infer_manager(); | |||
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0))); | |||
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); | |||
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3))); | |||
mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(7))); | |||
this->init_output_static_infer_desc_workspace( | |||
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSTMBackward>::val); | |||
} | |||
void LSTMBackward::init_output_dtype() { | |||
output(0)->dtype(input(0)->dtype()); | |||
output(1)->dtype(input(2)->dtype()); | |||
output(2)->dtype(input(3)->dtype()); | |||
output(3)->dtype(input(7)->dtype()); | |||
} |
@@ -170,6 +170,16 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 4 | |||
#define _NR_OUTPUTS 4 | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 5 | |||
#define _NR_OUTPUTS 1 | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 5 | |||
#define _NR_OUTPUTS 2 | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1) | |||
@@ -180,11 +190,34 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 6 | |||
#define _NR_OUTPUTS 1 | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 6 | |||
#define _NR_OUTPUTS 2 | |||
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 6 | |||
#define _NR_OUTPUTS 3 | |||
#define _FOREACH_IO(_i, _o) \ | |||
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 7 | |||
#define _NR_OUTPUTS 3 | |||
#define _FOREACH_IO(_i, _o) \ | |||
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
#define _NR_INPUTS 9 | |||
#define _NR_OUTPUTS 4 | |||
#define _FOREACH_IO(_i, _o) \ | |||
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \ | |||
_o(2), _o(3) | |||
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" | |||
} // anonymous namespace | |||
/* ======================= MegDNNOprWrapperFwd ======================= */ | |||
@@ -0,0 +1,120 @@ | |||
/** | |||
* \file src/opr/include/megbrain/opr/dnn/rnn.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
#if MGB_CUDA | |||
#include "../../../../impl/nvof/denseflownvidia.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#endif | |||
#include "megdnn/oprs.h" | |||
namespace mgb { | |||
namespace opr { | |||
MGB_DEFINE_OPR_CLASS( | |||
RNNCellForward, intl::MegDNNOprWrapperFwd<megdnn::RNNCellForward>) // { | |||
public: | |||
using NonlineMode = Param::NonlineMode; | |||
RNNCellForward( | |||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||
VarNode* weight_hh, VarNode* bias_hh, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar make( | |||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
}; | |||
using RNNCell = RNNCellForward; | |||
MGB_DEFINE_OPR_CLASS( | |||
LSTMCellForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMCellForward>) // { | |||
public: | |||
LSTMCellForward( | |||
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx, | |||
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar make( | |||
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx, | |||
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, | |||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||
}; | |||
using LSTMCell = LSTMCellForward; | |||
MGB_DEFINE_OPR_CLASS(RNNForward, intl::MegDNNOprWrapperFwd<megdnn::RNNForward>) // { | |||
/*private: | |||
SymbolVarArray weight_ih_arr; // 1d, idx: direction * num_layers + layer | |||
SymbolVarArray weight_hh_arr; | |||
SymbolVarArray bias_arr; | |||
*/ | |||
public: | |||
RNNForward( | |||
VarNode* input, VarNode* hx, VarNode* flatten_weights, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVar make( | |||
SymbolVar input, SymbolVar hx, SymbolVar flatten_weights, | |||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||
}; | |||
using RNN = RNNForward; | |||
MGB_DEFINE_OPR_CLASS( | |||
RNNBackward, intl::MegDNNOprWrapperBwd<megdnn::RNNBackward>) // { | |||
public: | |||
RNNBackward( | |||
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy, | |||
VarNode* flatten_weights, VarNode* reserve_space, const Param& param, | |||
const OperatorNodeConfig& config); | |||
static SymbolVarArray make( | |||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy, | |||
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
Super::NodeProp* do_make_node_prop() const override; | |||
private: | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS( | |||
LSTMForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMForward>) // { | |||
public: | |||
LSTMForward( | |||
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights, | |||
const Param& param, const OperatorNodeConfig& config); | |||
static SymbolVar make( | |||
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights, | |||
const Param& param = {}, const OperatorNodeConfig& config = {}); | |||
}; | |||
using LSTM = LSTMForward; | |||
MGB_DEFINE_OPR_CLASS( | |||
LSTMBackward, intl::MegDNNOprWrapperBwd<megdnn::LSTMBackward>) // { | |||
public: | |||
LSTMBackward( | |||
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy, | |||
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space, | |||
const Param& param, const OperatorNodeConfig& config); | |||
static SymbolVarArray make( | |||
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy, | |||
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights, | |||
SymbolVar reserve_space, const Param& param = {}, | |||
const OperatorNodeConfig& config = {}); | |||
Super::NodeProp* do_make_node_prop() const override; | |||
private: | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
}; | |||
} // namespace opr | |||
} // namespace mgb |
@@ -116,6 +116,9 @@ union OperatorParam { | |||
param.Padding = 82, | |||
param.ShuffleRNG = 83, | |||
param.CheckNonFinite = 84, | |||
param.RNNCell = 85, | |||
param.RNN = 86, | |||
param.LSTM = 87, | |||
} | |||
table Operator { | |||