
feat(mgb/opr): add cell-level rnn/lstm and sequence-level rnn/lstm

tags/v1.8.0
kxz@thumt102-1 Zhenhuan Chen 3 years ago
parent
commit
8f48da7ffe
59 changed files with 4601 additions and 20 deletions
  1. .gitignore (+5, -6)
  2. dnn/include/megdnn/oprs/nn.h (+192, -0)
  3. dnn/scripts/opr_param_defs.py (+25, -0)
  4. dnn/src/common/handle.cpp (+2, -4)
  5. dnn/src/common/handle_impl.h (+7, -1)
  6. dnn/src/common/lstm.cpp (+98, -0)
  7. dnn/src/common/lstm_cell.cpp (+136, -0)
  8. dnn/src/common/lstm_cell.h (+32, -0)
  9. dnn/src/common/opr_trait.h (+2, -0)
  10. dnn/src/common/relayout_format.cpp (+2, -3)
  11. dnn/src/common/rnn.cpp (+82, -0)
  12. dnn/src/common/rnn.h (+33, -0)
  13. dnn/src/common/rnn_cell.cpp (+108, -0)
  14. dnn/src/common/rnn_cell.h (+31, -0)
  15. dnn/src/cuda/cudnn_wrapper.cpp (+114, -0)
  16. dnn/src/cuda/cudnn_wrapper.h (+39, -0)
  17. dnn/src/cuda/handle_create.cpp (+4, -0)
  18. dnn/src/cuda/lstm/opr_impl.cpp (+112, -0)
  19. dnn/src/cuda/lstm/opr_impl.h (+56, -0)
  20. dnn/src/cuda/lstm/utils.cpp (+39, -0)
  21. dnn/src/cuda/lstm/utils.h (+23, -0)
  22. dnn/src/cuda/lstm_cell/opr_impl.cpp (+42, -0)
  23. dnn/src/cuda/lstm_cell/opr_impl.h (+36, -0)
  24. dnn/src/cuda/rnn/opr_impl.cpp (+170, -0)
  25. dnn/src/cuda/rnn/opr_impl.h (+57, -0)
  26. dnn/src/cuda/rnn/utils.cpp (+138, -0)
  27. dnn/src/cuda/rnn/utils.h (+56, -0)
  28. dnn/src/cuda/rnn_cell/opr_impl.cpp (+35, -0)
  29. dnn/src/cuda/rnn_cell/opr_impl.h (+40, -0)
  30. dnn/src/naive/handle.cpp (+4, -0)
  31. dnn/src/naive/lstm/opr_impl.cpp (+146, -0)
  32. dnn/src/naive/lstm/opr_impl.h (+56, -0)
  33. dnn/src/naive/lstm/template_impl.cpp (+55, -0)
  34. dnn/src/naive/lstm_cell/opr_impl.cpp (+38, -0)
  35. dnn/src/naive/lstm_cell/opr_impl.h (+36, -0)
  36. dnn/src/naive/relayout/opr_impl.cpp (+1, -3)
  37. dnn/src/naive/rnn/funcs.h (+75, -0)
  38. dnn/src/naive/rnn/funcs.tpp (+449, -0)
  39. dnn/src/naive/rnn/opr_impl.cpp (+196, -0)
  40. dnn/src/naive/rnn/opr_impl.h (+53, -0)
  41. dnn/src/naive/rnn/rnn.cpp (+285, -0)
  42. dnn/src/naive/rnn/rnn.h (+73, -0)
  43. dnn/src/naive/rnn/template_impl.cpp (+41, -0)
  44. dnn/src/naive/rnn_cell/opr_impl.cpp (+34, -0)
  45. dnn/src/naive/rnn_cell/opr_impl.h (+33, -0)
  46. dnn/test/common/deduce_layout_proxy.h (+9, -0)
  47. dnn/test/common/elemwise.cpp (+2, -3)
  48. dnn/test/common/rnn.h (+51, -0)
  49. dnn/test/naive/rnn.cpp (+80, -0)
  50. imperative/python/megengine/module/__init__.py (+1, -0)
  51. imperative/python/megengine/module/rnn.py (+396, -0)
  52. imperative/python/test/unit/module/test_rnn.py (+181, -0)
  53. imperative/src/impl/ops/rnn.cpp (+68, -0)
  54. src/core/include/megbrain/ir/ops.td (+8, -0)
  55. src/opr/impl/dnn/dnn.sereg.h (+35, -0)
  56. src/opr/impl/dnn/rnn.cpp (+323, -0)
  57. src/opr/impl/internal/megdnn_opr_wrapper.inl (+33, -0)
  58. src/opr/include/megbrain/opr/dnn/rnn.h (+120, -0)
  59. src/serialization/impl/schema.fbs (+3, -0)

+ 5
- 6
.gitignore

@@ -1,7 +1,3 @@
# Build
build/
output/

# Cache
__pycache__/
.ccls-cache/
@@ -11,5 +7,8 @@ __pycache__/
.vs/
.idea/

# CMake
compile_commands.json
# Make and Build Settings
build/
output/
compile_commands.json
imperative/python/megengine/core/*.so

+ 192
- 0
dnn/include/megdnn/oprs/nn.h

@@ -1936,6 +1936,198 @@ protected:
const TensorLayout& grad_s, size_t workspace_in_bytes);
};

class RNNCellForward : public OperatorBase {
DEF_OPR_PARAM(RNNCell);
DEF_OPR_IMPL(RNNCellForward, OperatorBase, 6, 1);

public:
virtual void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
static void deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
TensorLayout& dst);
virtual size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) = 0;

protected:
void check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, size_t workspace_in_bytes);
};
using RNNCell = RNNCellForward;

class LSTMCellForward : public OperatorBase {
// DEF_OPR_PARAM(LSTMCell);
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(LSTMCellForward, OperatorBase, 7, 3);

public:
virtual void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
TensorLayout& gates);
virtual size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) = 0;

protected:
void check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates,
size_t workspace_in_bytes);
};
using LSTMCell = LSTMCellForward;

class RNNForward : public OperatorBase {
DEF_OPR_PARAM(RNN);
DEF_OPR_IMPL(RNNForward, OperatorBase, 3, 3);

public:
virtual void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
TensorLayout& reserve_space);
virtual size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space) = 0;
virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
void check_exec(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space,
size_t workspace_in_bytes);
};
using RNN = RNNForward;

class RNNBackward : public OperatorBase {
DEF_OPR_PARAM(RNN);
DEF_OPR_IMPL(RNNBackward, OperatorBase, 7, 3);

public:
virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw);
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx,
const TensorLayout& dw) = 0;

protected:
void check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
size_t workspace_in_bytes);
};

class LSTMForward : public OperatorBase {
DEF_OPR_PARAM(LSTM);
DEF_OPR_IMPL(LSTMForward, OperatorBase, 4, 4);

public:
virtual void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
TensorLayout& cy, TensorLayout& reserve_space);
virtual size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) = 0;
virtual size_t get_reserve_size_in_bytes(const TensorLayout& input) = 0;

protected:
void check_exec(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space, size_t workspace_in_bytes);
};
using LSTM = LSTMForward;

class LSTMBackward : public OperatorBase {
DEF_OPR_PARAM(LSTM);
DEF_OPR_IMPL(LSTMBackward, OperatorBase, 9, 4);

public:
virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
TensorLayout& dcx, TensorLayout& dw);
virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx,
const TensorLayout& dw) = 0;

protected:
void check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"



+ 25
- 0
dnn/scripts/opr_param_defs.py

@@ -1203,3 +1203,28 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
)
)

(pdef('RNNCell').
add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2')
)

(pdef('RNN').
add_fields('uint32', 'num_layers', '1').
add_fields('bool', 'bidirectional', 'false').
add_fields('bool', 'bias', 'true').
add_fields('uint32', 'hidden_size', '128').
add_fields('uint32', 'proj_size', '0').
add_fields('float32', 'dropout', '0.f').
add_enum_alias('NonlineMode', 'RNNCell').
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
)

(pdef('LSTM').
add_fields('uint32', 'num_layers', '1').
add_fields('bool', 'bidirectional', 'false').
add_fields('bool', 'bias', 'true').
add_fields('uint32', 'hidden_size', '128').
add_fields('uint32', 'proj_size', '0').
add_fields('float32', 'dropout', '0.f').
add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
)

+ 2
- 4
dnn/src/common/handle.cpp

@@ -92,8 +92,7 @@ std::unique_ptr<Handle> Handle::make(
}
MIDOUT_END();

}
else if (platform == megcorePlatformROCM) {
} else if (platform == megcorePlatformROCM) {
#if MEGDNN_WITH_ROCM
return make_rocm_handle(computing_handle);
#else
@@ -111,8 +110,7 @@ std::unique_ptr<Handle> Handle::make(
#else
return nullptr;
#endif
}
else {
} else {
// CUDA
megdnn_throw_if(
platform != megcorePlatformCUDA, megdnn_error,


+ 7
- 1
dnn/src/common/handle_impl.h

@@ -209,7 +209,13 @@ private:
cb(LSQBackward) \
cb(Fill) \
cb(PaddingForward) \
cb(PaddingBackward)
cb(PaddingBackward) \
cb(RNNCell) \
cb(LSTMCell) \
cb(RNN) \
cb(RNNBackward) \
cb(LSTM) \
cb(LSTMBackward)
// clang-format on

/*!


+ 98
- 0
dnn/src/common/lstm.cpp

@@ -0,0 +1,98 @@
/**
* \file dnn/src/common/lstm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
// #include "src/cuda/lstm/utils.h"

namespace megdnn {

/*size_t get_reserve_size(Handle* handle, megdnn::LSTMForward::Param& param, const
TensorLayout& input) { #if CUDNN_MAJOR >= 6 auto holder =
megdnn::cuda::lstm::get_RNNDescHolder_v6(handle, param, input); return
holder.reserveSpace_size; # else return 0; #endif
}*/

void LSTM::deduce_layout(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
TensorLayout& cy, TensorLayout& reserve_space) {
// input: [seq_len, batch_size, input_size]
// hx: [D * num_layers, batch_size, hidden_size]
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t D = param().bidirectional ? 2 : 1;
size_t hidden_size = hx.shape[2];
output = TensorLayout(
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
hy = TensorLayout(hx);
cy = TensorLayout(cx);
// reserve_space = {{get_reserve_size(this->handle(), param(), input)},
// dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
}

void LSTM::check_exec(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space, size_t workspace_in_bytes) {
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", cx=");
msg.append(cx.to_string());
msg.append(", hidden_size=");
msg.append(std::to_string(param().hidden_size));
msg.append(", num_layers=");
msg.append(std::to_string(param().num_layers));
msg.append(", bidirectional=");
msg.append(std::to_string(param().bidirectional));
return msg;
};
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

ASSERT_BRIEF(hx.ndim == 3)
ASSERT_BRIEF(hx.shape[0] == D * num_layers)
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
ASSERT_BRIEF(cx.ndim == 3)
ASSERT_BRIEF(cx.shape[0] == D * num_layers)
ASSERT_BRIEF(cx.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(cx.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}

void LSTMBackward::deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, TensorLayout& dx, TensorLayout& dhx,
TensorLayout& dcx, TensorLayout& dw) {
dx = x;
dhx = hx;
dcx = cx;
dw = flatten_weights;
}

void LSTMBackward::check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
size_t workspace_in_bytes) {}

} // namespace megdnn
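Editor's note: the layout rules enforced by LSTM::deduce_layout above (and by the essentially identical RNN::deduce_layout later in this commit) can be summarized in a short sketch. This is an illustrative Python check, not part of the commit; the helper name is made up, and reserve_space is omitted since its size comes from get_reserve_size_in_bytes().

# Illustrative sketch of the shape contract deduced/checked above.
def seq_rnn_shapes(seq_len, batch_size, input_size, hidden_size,
                   num_layers=1, bidirectional=False):
    D = 2 if bidirectional else 1
    x = (seq_len, batch_size, input_size)             # input
    hx = (D * num_layers, batch_size, hidden_size)    # initial hidden (and cell) state
    output = (seq_len, batch_size, D * hidden_size)   # per-step outputs
    hy = hx                                           # final state mirrors hx/cx
    return x, hx, output, hy

# e.g. a 2-layer bidirectional LSTM with hidden_size=128 on a (10, 4, 32) input
# gives output (10, 4, 256) and hy/cy of shape (4, 4, 128).
print(seq_rnn_shapes(10, 4, 32, 128, num_layers=2, bidirectional=True))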

+ 136
- 0
dnn/src/common/lstm_cell.cpp

@@ -0,0 +1,136 @@
/**
* \file dnn/src/common/lstm_cell.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/lstm_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void LSTMCell::deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
TensorLayout& gates) {
// size_t batch_size = hx.shape[0];
// size_t hidden_size = hx.shape[1];
h_new = TensorLayout(hx, hx.dtype);
c_new = TensorLayout(cx, cx.dtype);
auto opr = handle()->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
}

void LSTMCell::check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, size_t workspace_in_bytes) {
TensorLayout h_new_expected, c_new_expected, gates_expected;
deduce_layout(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
c_new_expected, gates_expected);
megdnn_assert_eq_layout(h_new_expected, h_new);
megdnn_assert_eq_layout(c_new_expected, c_new);
megdnn_assert_eq_layout(gates_expected, gates);

auto required_workspace_in_bytes = get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn

namespace megdnn {
namespace lstm_cell {

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, Handle* handle) {
TensorLayout tmp_layout;
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
size_t rnn_cell_need = opr->get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
size_t lstm_cell_need = tmp_layout.span().dist_byte();
return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
/*TensorLayout tmp_layout;
opr->deduce_layout(input.layout, weight_ih.layout,
hx.layout, weight_hh.layout,
bias.layout, tmp_layout);
auto workspace_ptr = workspace.raw_ptr;
// TensorND tmp{static_cast<void*>(workspace.raw_ptr), tmp_layout};
TensorND tmp{workspace_ptr, tmp_layout};
auto new_workspace = Workspace{workspace_ptr + tmp.layout.span().dist_byte(),
workspace.size -
tmp.layout.span().dist_byte()};*/
// opr->exec(input, weight_ih, hx, weight_hh, bias, tmp, new_workspace);
opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);
// activation
// size_t batch_size = tmp.layout.shape[0];
size_t batch_size = hx.layout.shape[0];
size_t hidden_size = hx.layout.shape[1];
// sigmoid: i f o
// TensorLayout gates_ifo_layout{TensorShape({batch_size, hidden_size * 3}),
// tmp.layout.dtype};
TensorND tmp{static_cast<void*>(workspace.raw_ptr), gates.layout};
TensorLayout gates_ifo_layout{
TensorShape({batch_size, hidden_size * 3}), gates.layout.dtype};
TensorND gates_ifo_origin{gates.raw_ptr(), gates_ifo_layout};
TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
auto sigmoid = handle->create_operator<ElemwiseForward>();
sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
sigmoid->exec({gates_ifo_origin}, gates_ifo);
// tanh: g
TensorLayout g_layout{TensorShape({batch_size, hidden_size}), gates.layout.dtype};
TensorND g_origin{
static_cast<char*>(gates.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
g_layout};
TensorND g{
static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
g_layout};
auto tanh = handle->create_operator<ElemwiseForward>();
tanh->param().mode = Elemwise::Param::Mode::TANH;
tanh->exec({g_origin}, g);
// extract i f o
TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout};
TensorND f{
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout};
TensorND o{
static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte() * 2,
g_layout};
// calculate new cell state
auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
elewise_mul_add->exec({f, cx, i, g}, c_new);
// calculate new hidden state
tanh->exec({c_new}, h_new);
auto elewise_mul = handle->create_operator<ElemwiseForward>();
elewise_mul->param().mode = Elemwise::Param::Mode::MUL;
elewise_mul->exec({o, h_new}, h_new);
}

} // namespace lstm_cell
} // namespace megdnn
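Editor's note: lstm_cell::exec above composes the standard LSTM cell out of the RNNCell operator (with IDENTITY nonlinearity) plus Elemwise SIGMOID, TANH, FUSE_MUL_ADD4 and MUL. Below is a minimal NumPy sketch of the same math, assuming the gates tensor is packed as [i | f | o | g] along the hidden axis as the byte offsets in the code suggest; it is illustrative only and not part of the commit.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell(x, hx, cx, w_ih, b_ih, w_hh, b_hh):
    hidden = hx.shape[1]
    # RNNCell with IDENTITY nonlinearity: two matmuls with transposed weights plus both biases
    gates = x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh
    i, f, o = [sigmoid(gates[:, k * hidden:(k + 1) * hidden]) for k in range(3)]
    g = np.tanh(gates[:, 3 * hidden:])
    c_new = f * cx + i * g          # Elemwise FUSE_MUL_ADD4
    h_new = o * np.tanh(c_new)      # TANH followed by MUL
    return h_new, c_new, gates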

+ 32
- 0
dnn/src/common/lstm_cell.h

@@ -0,0 +1,32 @@
/**
* \file dnn/src/common/lstm_cell.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs/base.h"

namespace megdnn {
namespace lstm_cell {

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, Handle* handle);

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle);

} // namespace lstm_cell
} // namespace megdnn

+ 2
- 0
dnn/src/common/opr_trait.h

@@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
DEF(RNNCellForward, 6, true, true);
DEF(RNNForward, 6, true, true);
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 2
- 3
dnn/src/common/relayout_format.cpp

@@ -402,9 +402,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
}

if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 &&
(
handle()->type() != Handle::HandleType::NAIVE &&
handle()->type() != Handle::HandleType::X86)) {
(handle()->type() != Handle::HandleType::NAIVE &&
handle()->type() != Handle::HandleType::X86)) {
megdnn_throw(
"Dump with Image2DPack4TensorFormat is not available on CUDA compnode, "
"try export CUDA_VISIBLE_DEVICES=\'\'");


+ 82
- 0
dnn/src/common/rnn.cpp

@@ -0,0 +1,82 @@
/**
* \file dnn/src/common/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/rnn.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void RNN::deduce_layout(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, TensorLayout& output, TensorLayout& hy,
TensorLayout& reserve_space) {
// input: [seq_len, batch_size, input_size]
// hx: [D * num_layers, batch_size, hidden_size]
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t D = param().bidirectional ? 2 : 1;
size_t hidden_size = hx.shape[2];
output = TensorLayout(
TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
hy = TensorLayout(hx);
// reserve_space = {{get_reserve_size(this->handle(), param(), input)},
// dtype::Byte()};
reserve_space = {{get_reserve_size_in_bytes(input)}, dtype::Byte()};
}

void RNN::check_exec(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space,
size_t workspace_in_bytes) {
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", hidden_size=");
msg.append(std::to_string(param().hidden_size));
msg.append(", num_layers=");
msg.append(std::to_string(param().num_layers));
msg.append(", bidirectional=");
msg.append(std::to_string(param().bidirectional));
return msg;
};
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

ASSERT_BRIEF(hx.ndim == 3)
ASSERT_BRIEF(hx.shape[0] == D * num_layers)
ASSERT_BRIEF(hx.shape[1] == input.shape[1]) // batch_size
ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}

void RNNBackward::deduce_layout(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
TensorLayout& dx, TensorLayout& dhx, TensorLayout& dw) {
dx = x;
dhx = hx;
dw = flatten_weights;
}

void RNNBackward::check_exec(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw,
size_t workspace_in_bytes) {}

} // namespace megdnn

+ 33
- 0
dnn/src/common/rnn.h

@@ -0,0 +1,33 @@
/**
* \file dnn/src/common/rnn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/general.h"

namespace megdnn {
namespace rnn {
using Param = param::RNN;

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayoutArray& weight_ih,
const TensorLayoutArray& states, const TensorLayoutArray& weight_hh,
const TensorLayoutArray& bias, const TensorLayout& output,
const TensorLayoutArray& states_new, const Param& param, Handle* handle);

void exec(
_megdnn_tensor_in input, _megdnn_in const TensorNDArray& weight_ih,
_megdnn_in const TensorNDArray& states,
_megdnn_in const TensorNDArray& weight_hh, _megdnn_in const TensorNDArray& bias,
_megdnn_tensor_out output, _megdnn_out const TensorNDArray& states_new,
_megdnn_workspace workspace, const Param& param, Handle* handle);
} // namespace rnn
} // namespace megdnn
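Editor's note: the rnn::exec helper declared here runs the cell over the whole sequence and the layer stack; its definition is not shown in this excerpt. A simplified, unidirectional NumPy sketch of that sequence-level loop, assuming one (W_ih, W_hh, b_ih, b_hh) set per layer and a tanh cell (illustrative only, not the committed implementation):

import numpy as np

def rnn_forward(x, h0, weights_ih, weights_hh, biases_ih, biases_hh):
    # x: (seq_len, batch, input_size); h0: (num_layers, batch, hidden_size)
    num_layers = h0.shape[0]
    h = [h0[l] for l in range(num_layers)]
    outputs = []
    for t in range(x.shape[0]):
        inp = x[t]
        for l in range(num_layers):
            h[l] = np.tanh(inp @ weights_ih[l].T + biases_ih[l]
                           + h[l] @ weights_hh[l].T + biases_hh[l])
            inp = h[l]                       # output of layer l feeds layer l+1
        outputs.append(inp)
    # output: (seq_len, batch, hidden_size); hy: (num_layers, batch, hidden_size)
    return np.stack(outputs), np.stack(h)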

+ 108
- 0
dnn/src/common/rnn_cell.cpp

@@ -0,0 +1,108 @@
/**
* \file dnn/src/common/rnn_cell.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/rnn_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void RNNCell::deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh, TensorLayout& dst) {
// megdnn_assert(hx.ndim == 2);
size_t batch_size = hx.shape[0];
// size_t hidden_size = weight_hh.shape[1];
size_t gate_hidden_size = weight_ih.shape[0];
// size_t input_size = weight_ih.shape[1];
// megdnn_assert(input.shape[1] == input_size);
// megdnn_assert(hx.shape[1] == hidden_size);
// megdnn_assert_eq_dtype(input, hx);

dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}

void RNNCell::check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, size_t workspace_in_bytes) {
TensorLayout dst_expected;
megdnn_assert_eq_dtype(input, dst);
megdnn_assert_eq_dtype(hx, dst);
deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);

auto required_workspace_in_bytes = get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn

namespace megdnn {
namespace rnn_cell {

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, Handle* handle) {
auto opr = handle->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
return dst.span().dist_byte() + opr->get_workspace_in_bytes(hx, weight_hh, dst);
}

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace,
param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
_megdnn_workspace new_workspace = {
workspace.raw_ptr + dst.layout.span().dist_byte(),
workspace.size - dst.layout.span().dist_byte()};
auto opr = handle->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
opr->exec(input, weight_ih, tmp, new_workspace);
opr->exec(hx, weight_hh, dst, new_workspace);
// if (this->param().bias) add_bias(dst, tmp, bias, dst);
// if (this->param().bias) {
auto add_opr = handle->create_operator<ElemwiseForward>();
add_opr->param().mode = Elemwise::Param::Mode::ADD;
add_opr->exec({dst, tmp}, dst);
add_opr->exec({dst, bias_ih}, dst);
add_opr->exec({dst, bias_hh}, dst);
// }

// activation
using NonlineMode = param::RNNCell::NonlineMode;

switch (nonline_mode) {
#define cb(_mode) \
case NonlineMode::_mode: { \
auto nonlinear = handle->create_operator<ElemwiseForward>(); \
nonlinear->param().mode = Elemwise::Param::Mode::_mode; \
nonlinear->exec({dst}, dst); \
break; \
}
cb(RELU);
cb(TANH);
#undef cb
case NonlineMode::IDENTITY:
break;
default:
megdnn_assert(false);
}
}

} // namespace rnn_cell
} // namespace megdnn
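Editor's note: rnn_cell::exec above computes dst = act(x @ W_ih^T + b_ih + hx @ W_hh^T + b_hh), where act is IDENTITY, RELU or TANH, reusing the workspace to hold the first matmul result before the Elemwise adds. A minimal NumPy sketch of the same computation (illustrative, not part of the commit):

import numpy as np

def rnn_cell(x, hx, w_ih, b_ih, w_hh, b_hh, nonlinearity="TANH"):
    # two matmuls with transposed weights, then both biases are added in
    pre = x @ w_ih.T + hx @ w_hh.T + b_ih + b_hh
    if nonlinearity == "RELU":
        return np.maximum(pre, 0.0)
    if nonlinearity == "TANH":
        return np.tanh(pre)
    return pre  # IDENTITY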

+ 31
- 0
dnn/src/common/rnn_cell.h

@@ -0,0 +1,31 @@
/**
* \file dnn/src/common/rnn_cell.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/general.h"

namespace megdnn {
namespace rnn_cell {

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, Handle* handle);

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace,
param::RNNCell::NonlineMode nonline_mode, Handle* handle);

} // namespace rnn_cell
} // namespace megdnn

+ 114
- 0
dnn/src/cuda/cudnn_wrapper.cpp

@@ -160,6 +160,29 @@ void TensorDesc::set(
}
}

void TensorDesc::set_nd(const TensorLayout& layout, int pad) {
int nbDims = layout.ndim < pad ? pad : layout.ndim;
int dimA[nbDims], strideA[nbDims];

for (size_t i = 0; i < layout.ndim; ++i) {
dimA[i] = layout.shape[i];
// strideA[i] = layout.stride[i];
}
for (size_t i = layout.ndim; i < nbDims; ++i) {
dimA[i] = 1; // unused
// strideA[i] = 1;
}
// stride
for (size_t i = 0; i < nbDims; ++i) {
strideA[i] = 1;
for (size_t j = i + 1; j < nbDims; ++j) {
strideA[i] *= dimA[j];
}
}
cudnn_check(cudnnSetTensorNdDescriptor(
desc, to_cudnn_dtype(layout.dtype), nbDims, dimA, strideA));
}

std::string TensorDesc::to_string() {
cudnnDataType_t data_type;
int n;
@@ -433,6 +456,97 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) {
desc, 3, padA, filterStrideA, dilationA, mode, CUDNN_DATA_FLOAT));
}

DropoutDesc::DropoutDesc() {
cudnn_check(cudnnCreateDropoutDescriptor(&desc));
}

DropoutDesc::~DropoutDesc() {
cudnn_check(cudnnDestroyDropoutDescriptor(desc));
}

void DropoutDesc::set(float dropout, Handle* handle, TensorND& state) {
cudnn_check(cudnnSetDropoutDescriptor(
desc, cudnn_handle(handle), dropout, state.raw_ptr(),
state.layout.span().dist_byte(), 0 // seed
));
}

void DropoutDesc::set_no_dropout(Handle* handle) {
cudnn_check(
cudnnSetDropoutDescriptor(desc, cudnn_handle(handle), 0, nullptr, 0, 0));
}

RNNDesc::RNNDesc() {
cudnn_check(cudnnCreateRNNDescriptor(&desc));
}

RNNDesc::~RNNDesc() {
cudnn_check(cudnnDestroyRNNDescriptor(desc));
}

void RNNDesc::set(
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, const megdnn::DType dtype, cudnnRNNMode_t mode,
DropoutDesc& dropout_desc, Handle* handle) {
cudnnRNNMode_t rnn_mode = mode;
cudnnRNNBiasMode_t bias_mode = bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
cudnnDirectionMode_t dir_mode =
bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
cudnnDataType_t math_prec;

// math precision
if (dtype.enumv() == DTypeEnum::Float16)
math_prec = CUDNN_DATA_HALF;
else
math_prec = CUDNN_DATA_FLOAT;

#if false // CUDNN_MAJOR >= 8
cudnn_check(cudnnSetRNNDescriptor_v8(
desc, CUDNN_RNN_ALGO_STANDARD, mode, bias_mode, dir_mode,
CUDNN_LINEAR_INPUT, to_cudnn_dtype(dtype), math_prec, CUDNN_DEFAULT_MATH,
input_size, hidden_size, proj_size, num_layers, dropout_desc.desc,
CUDNN_RNN_PADDED_IO_DISABLED));
#else
cudnn_check(cudnnSetRNNDescriptor_v6(
cudnn_handle(handle), desc, hidden_size, num_layers, dropout_desc.desc,
CUDNN_LINEAR_INPUT, dir_mode, mode, CUDNN_RNN_ALGO_STANDARD, math_prec));
#endif
}

RNNDataDesc::RNNDataDesc() {
cudnn_check(cudnnCreateRNNDataDescriptor(&desc));
}

RNNDataDesc::~RNNDataDesc() {
cudnn_check(cudnnDestroyRNNDataDescriptor(desc));
}

void RNNDataDesc::set(
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths,
DType dtype) {
// for now, all tensors are padded in python
// int seqLengthArray[batchSize];
// for (int i = 0; i < batchSize; ++i) seqLengthArray[i] = maxSeqLength;
cudnn_check(cudnnSetRNNDataDescriptor(
desc, to_cudnn_dtype(dtype), CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
maxSeqLength, batchSize, vectorSize, devSeqLengths, nullptr));
}

RNNWeightFilterDesc::RNNWeightFilterDesc() {
cudnn_check(cudnnCreateFilterDescriptor(&desc));
}

RNNWeightFilterDesc::~RNNWeightFilterDesc() {
cudnn_check(cudnnDestroyFilterDescriptor(desc));
}

void RNNWeightFilterDesc::set(const TensorLayout& flatten_weights) {
int weight_elem_num = flatten_weights.total_nr_elems();
int dimW[] = {weight_elem_num, 1, 1};
cudnn_check(cudnnSetFilterNdDescriptor(
desc, to_cudnn_dtype(flatten_weights.dtype), CUDNN_TENSOR_NCHW, 3, dimW));
}

////////////////////////// CudnnAlgoPack //////////////////////////

#define V1(v) #v
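Editor's note: TensorDesc::set_nd, added near the top of this file's diff, pads the layout to at least `pad` dimensions and recomputes dense, row-major strides instead of reusing the layout's own strides. A small Python sketch of that stride computation (illustrative, not part of the commit; the real code feeds the result to cudnnSetTensorNdDescriptor):

def dense_strides(shape, pad=3):
    dims = list(shape) + [1] * max(0, pad - len(shape))   # pad trailing dims with 1
    strides = [1] * len(dims)
    for i in range(len(dims) - 2, -1, -1):
        strides[i] = strides[i + 1] * dims[i + 1]
    return dims, strides

# dense_strides((4, 3, 5)) -> ([4, 3, 5], [15, 5, 1])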


+ 39
- 0
dnn/src/cuda/cudnn_wrapper.h

@@ -30,6 +30,7 @@ public:
void set(
const TensorLayout& layout,
const param::Convolution::Format = param::Convolution::Format::NCHW);
void set_nd(const TensorLayout& layout, int pad = 3); // at least 3 dimensions
std::string to_string();
~TensorDesc();
cudnnTensorDescriptor_t desc;
@@ -121,6 +122,44 @@ public:
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, Attr> conv3d_fwd_algos();
};

class DropoutDesc {
public:
DropoutDesc();
void set(float dropout, Handle* handle, TensorND& state);
void set_no_dropout(Handle* handle);
~DropoutDesc();
cudnnDropoutDescriptor_t desc;
};

class RNNDesc {
public:
RNNDesc();
void set(
size_t input_size, size_t hidden_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, const megdnn::DType dtype,
cudnnRNNMode_t mode, DropoutDesc& dropout_desc, Handle* handle);
~RNNDesc();
cudnnRNNDescriptor_t desc;
};

class RNNDataDesc {
public:
RNNDataDesc();
void set(
int batchSize, int vectorSize, int maxSeqLength, const int* devSeqLengths,
DType dtype);
~RNNDataDesc();
cudnnRNNDataDescriptor_t desc;
};

class RNNWeightFilterDesc {
public:
RNNWeightFilterDesc();
void set(const TensorLayout& flatten_weights);
~RNNWeightFilterDesc();
cudnnFilterDescriptor_t desc;
};

} // namespace cuda
} // namespace megdnn



+ 4
- 0
dnn/src/cuda/handle_create.cpp

@@ -50,6 +50,8 @@
#include "src/cuda/local_share/opr_impl.h"
#include "src/cuda/lrn/opr_impl.h"
#include "src/cuda/lsq/opr_impl.h"
#include "src/cuda/lstm/opr_impl.h"
#include "src/cuda/lstm_cell/opr_impl.h"
#include "src/cuda/mask_conv/opr_impl.h"
#include "src/cuda/matrix_inverse/opr_impl.h"
#include "src/cuda/matrix_mul/opr_impl.h"
@@ -66,6 +68,8 @@
#include "src/cuda/repeat/opr_impl.h"
#include "src/cuda/resize/opr_impl.h"
#include "src/cuda/rng/opr_impl.h"
#include "src/cuda/rnn/opr_impl.h"
#include "src/cuda/rnn_cell/opr_impl.h"
#include "src/cuda/roi_align/opr_impl.h"
#include "src/cuda/roi_copy/opr_impl.h"
#include "src/cuda/roi_pooling/opr_impl.h"


+ 112
- 0
dnn/src/cuda/lstm/opr_impl.cpp

@@ -0,0 +1,112 @@
/**
* \file dnn/src/cuda/lstm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm/opr_impl.h"
#include "src/cuda/lstm/utils.h"
#include "src/cuda/utils.h"

#include <cudnn.h>

namespace megdnn {
namespace cuda {

void LSTMImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
Handle* handle = this->handle();

rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input.layout);
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);

if (param().fwd_mode == param::LSTM::FwdMode::TRAINING) {
cudnn_check(cudnnRNNForwardTraining(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, cx.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size,
reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
} else {
cudnn_check(cudnnRNNForwardInference(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
cy.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size));
}
}

size_t LSTMImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) {
rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.workspace_size;
}

size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
rnn::RNNForwardDescHolder_v6 desc_holder =
lstm::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.reserveSpace_size;
}

void LSTMBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) {
Handle* handle = this->handle();
size_t seq_len = x.layout.shape[0];
auto desc_holder = lstm::get_RNNDescHolder_v6(handle, param(), x.layout);
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data();
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data();
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);

cudnn_check(cudnnRNNBackwardData(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr,
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc,
dhy.raw_ptr(), desc_holder.cy_desc.desc, dcy.raw_ptr(), w_desc.desc,
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(),
desc_holder.cx_desc.desc, cx.raw_ptr(), x_desc_arr_ptr, dx.raw_ptr(),
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc,
dcx.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size,
reserve_space.raw_ptr(), desc_holder.reserveSpace_size));

cudnn_check(cudnnRNNBackwardWeights(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr,
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr,
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc,
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
}

size_t LSTMBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) {
auto desc_holder = lstm::get_RNNDescHolder_v6(this->handle(), param(), x);
return desc_holder.workspace_size;
}

} // namespace cuda
} // namespace megdnn

+ 56
- 0
dnn/src/cuda/lstm/opr_impl.h

@@ -0,0 +1,56 @@
/**
* \file dnn/src/cuda/lstm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace cuda {

class LSTMImpl : public LSTM {
public:
using LSTM::LSTM;

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace);

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
};

class LSTMBackwardImpl : public LSTMBackward {
public:
using LSTMBackward::LSTMBackward;

virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);

virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw);
};

} // namespace cuda
} // namespace megdnn

+ 39
- 0
dnn/src/cuda/lstm/utils.cpp

@@ -0,0 +1,39 @@
/**
* \file dnn/src/cuda/lstm/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm/utils.h"
#include "src/cuda/utils.h"

#include <cudnn.h>

namespace megdnn {
namespace cuda {
namespace lstm {

RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input) {
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];

cudnnRNNMode_t mode = CUDNN_LSTM;

using FwdMode = param::LSTM::FwdMode;

RNNForwardDescHolder_v6 desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode);
return desc_holder;
}

} // namespace lstm
} // namespace cuda
} // namespace megdnn

+ 23
- 0
dnn/src/cuda/lstm/utils.h

@@ -0,0 +1,23 @@
/**
* \file dnn/src/cuda/lstm/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/rnn/utils.h"

namespace megdnn {
namespace cuda {
namespace lstm {
using megdnn::cuda::rnn::RNNForwardDescHolder_v6;
RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::LSTMForward::Param& _param, const TensorLayout& input);
} // namespace lstm
} // namespace cuda
} // namespace megdnn

+ 42
- 0
dnn/src/cuda/lstm_cell/opr_impl.cpp

@@ -0,0 +1,42 @@
/**
* \file dnn/src/cuda/lstm_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/lstm_cell/opr_impl.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/base.h"
#include "src/common/lstm_cell.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"

namespace megdnn {
namespace cuda {
size_t LSTMCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates) {
return megdnn::lstm_cell::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
handle());
}

void LSTMCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) {
megdnn::lstm_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
workspace, handle());
}
} // namespace cuda

} // namespace megdnn

+ 36
- 0
dnn/src/cuda/lstm_cell/opr_impl.h

@@ -0,0 +1,36 @@
/**
* \file dnn/src/cuda/lstm_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/rnn_cell/opr_impl.h"

namespace megdnn {
namespace cuda {

class LSTMCellImpl : public LSTMCell {
public:
using LSTMCell::LSTMCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) override;
};

} // namespace cuda
} // namespace megdnn

+ 170
- 0
dnn/src/cuda/rnn/opr_impl.cpp

@@ -0,0 +1,170 @@
/**
* \file dnn/src/cuda/rnn/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn/opr_impl.h"
#include "src/common/rnn.h"
#include "src/cuda/utils.h"

//#include <cstring>
#include <cudnn.h>
#include <cstdlib>
#include <iostream>

namespace megdnn {
namespace cuda {

using namespace std;

void RNNImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
Handle* handle = this->handle();

#if false // CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input.layout);

void* workspace_ptr = workspace.raw_ptr;
void* reserveSpace_ptr = static_cast<uint8_t*>(workspace_ptr) + desc_holder.workspace_size;

cudnn_check(cudnnRNNForward(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.fwdMode, desc_holder.devSeqLengths,
desc_holder.x_desc.desc, input.raw_ptr(), desc_holder.y_desc.desc, output.raw_ptr(),
desc_holder.h_desc.desc, hx.raw_ptr(), hy.raw_ptr(),
desc_holder.h_desc.desc, nullptr, nullptr,
desc_holder.weight_size, flatten_weights.raw_ptr(), desc_holder.workspace_size, workspace_ptr,
desc_holder.reserveSpace_size, reserveSpace_ptr
));
#else
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input.layout);
auto x_desc_arr = rnn::get_descs(desc_holder.x_descs);
auto y_desc_arr = rnn::get_descs(desc_holder.y_descs);
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);

if (param().fwd_mode == param::RNN::FwdMode::TRAINING) {
cudnn_check(cudnnRNNForwardTraining(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, NULL, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc, NULL,
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(),
desc_holder.reserveSpace_size));
} else {
cudnn_check(cudnnRNNForwardInference(
cudnn_handle(handle), desc_holder.rnn_desc.desc, desc_holder.seq_len,
x_desc_arr.data(), input.raw_ptr(), desc_holder.hx_desc.desc,
hx.raw_ptr(), desc_holder.cx_desc.desc, nullptr, w_desc.desc,
flatten_weights.raw_ptr(), y_desc_arr.data(), output.raw_ptr(),
desc_holder.hy_desc.desc, hy.raw_ptr(), desc_holder.cy_desc.desc,
nullptr, workspace.raw_ptr, desc_holder.workspace_size));
}
#endif
}

size_t RNNImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space) {
#if false // CUDNN_MAJOR >= 8
rnn::RNNForwardDescHolder desc_holder = this->get_desc_holder(input);
#else
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input);
#endif
return desc_holder.workspace_size;
}

size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
rnn::RNNForwardDescHolder_v6 desc_holder =
rnn::get_RNNDescHolder_v6(this->handle(), param(), input);
return desc_holder.reserveSpace_size;
}

/*rnn::RNNForwardDescHolder RNNImpl::get_desc_holder(const TensorLayout& input) {
Handle* handle = this->handle();
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
auto _param = param();

cudnnRNNMode_t mode;
using NonlineMode = param::RNN::NonlineMode;
switch (_param.nonlineMode) {
case NonlineMode::RELU:
mode = CUDNN_RNN_RELU;
break;
case NonlineMode::TANH:
mode = CUDNN_RNN_TANH;
break;
}

cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
using FwdMode = param::RNN::FwdMode;
switch (_param.fwd_mode) {
case FwdMode::TRAINING:
fwdMode = CUDNN_FWD_MODE_TRAINING;
break;
case FwdMode::INFERENCE:
fwdMode = CUDNN_FWD_MODE_INFERENCE;
break;
}

rnn::RNNForwardDescHolder desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode, fwdMode);
return desc_holder;
}*/

void RNNBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dw, _megdnn_workspace workspace) {
Handle* handle = this->handle();
size_t seq_len = x.layout.shape[0];
auto desc_holder = rnn::get_RNNDescHolder_v6(handle, param(), x.layout);
auto x_desc_arr_ptr = rnn::get_descs(desc_holder.x_descs).data();
auto y_desc_arr_ptr = rnn::get_descs(desc_holder.y_descs).data();
RNNWeightFilterDesc w_desc;
w_desc.set(flatten_weights.layout);

cudnn_check(cudnnRNNBackwardData(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, y_desc_arr_ptr,
y.raw_ptr(), y_desc_arr_ptr, dy.raw_ptr(), desc_holder.hy_desc.desc,
dhy.raw_ptr(), desc_holder.cy_desc.desc, NULL, w_desc.desc,
flatten_weights.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(),
desc_holder.cx_desc.desc, NULL, x_desc_arr_ptr, dx.raw_ptr(),
desc_holder.hx_desc.desc, dhx.raw_ptr(), desc_holder.cx_desc.desc, NULL,
workspace.raw_ptr, desc_holder.workspace_size, reserve_space.raw_ptr(),
desc_holder.reserveSpace_size));

cudnn_check(cudnnRNNBackwardWeights(
cudnn_handle(handle), desc_holder.rnn_desc.desc, seq_len, x_desc_arr_ptr,
x.raw_ptr(), desc_holder.hx_desc.desc, hx.raw_ptr(), y_desc_arr_ptr,
y.raw_ptr(), workspace.raw_ptr, desc_holder.workspace_size, w_desc.desc,
dw.raw_ptr(), reserve_space.raw_ptr(), desc_holder.reserveSpace_size));
}

size_t RNNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) {
auto desc_holder = rnn::get_RNNDescHolder_v6(this->handle(), param(), x);
return desc_holder.workspace_size;
}

} // namespace cuda
} // namespace megdnn

+ 57
- 0
dnn/src/cuda/rnn/opr_impl.h View File

@@ -0,0 +1,57 @@
/**
* \file dnn/src/cuda/rnn/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/rnn/utils.h"

namespace megdnn {
namespace cuda {

class RNNImpl : public RNN {
public:
using RNN::RNN;

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace);

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
// private:
// rnn::RNNForwardDescHolder get_desc_holder(const TensorLayout& input);
};

class RNNBackwardImpl : public RNNBackward {
public:
using RNNBackward::RNNBackward;

virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);

virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw);
};

} // namespace cuda
} // namespace megdnn

+ 138
- 0
dnn/src/cuda/rnn/utils.cpp View File

@@ -0,0 +1,138 @@
/**
* \file dnn/src/cuda/rnn/utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn/utils.h"
#include "src/cuda/utils.h"

#include <cudnn.h>

namespace megdnn {
namespace cuda {
namespace rnn {
/*RNNForwardDescHolder::RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t
batch_size, size_t hidden_size, size_t input_size, size_t proj_size, size_t num_layers,
bool bidirectional, bool bias, DType dtype, cudnnRNNMode_t _mode, cudnnForwardMode_t
_fwdMode) : mode(_mode), fwdMode(_fwdMode)
{
size_t D = bidirectional ? 2 : 1;

// TODO: set dropout to 0 in inference mode
dropout_desc.set_no_dropout(handle);

// seq len is unified (not packed)
// cuda_check(cudaMalloc((void**)&devSeqLengths, sizeof(int32_t) * batch_size));
devSeqLengths = (int32_t*)malloc(sizeof(int32_t) * batch_size);
for (size_t i = 0; i < batch_size; ++i) devSeqLengths[i] = seq_len;

// proj size should be smaller than hidden size according to cudnn api
// otherwise it is disabled
proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size :
proj_size; rnn_desc.set( input_size, hidden_size, proj_size, num_layers, bidirectional,
bias, dtype, mode, dropout_desc, handle
);

x_desc.set(batch_size, input_size, seq_len, devSeqLengths, dtype);
y_desc.set(batch_size, D * proj_size, seq_len,
devSeqLengths, dtype);
h_desc.set_nd(TensorLayout(TensorShape{D * num_layers, batch_size, proj_size},
dtype));

cudnn_check(cudnnGetRNNWeightSpaceSize(cudnn_handle(handle), rnn_desc.desc,
&weight_size));

cudnn_check(cudnnGetRNNTempSpaceSizes(
cudnn_handle(handle), rnn_desc.desc, fwdMode, x_desc.desc,
&workspace_size, &reserveSpace_size
));
}

RNNForwardDescHolder::~RNNForwardDescHolder() {
// cuda_check(cudaFree(devSeqLengths));
free(devSeqLengths);
}*/

RNNForwardDescHolder_v6::RNNForwardDescHolder_v6(
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size,
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype, cudnnRNNMode_t _mode)
: mode(_mode), seq_len(seq_len) {
size_t D = bidirectional ? 2 : 1;

// TODO: set dropout to 0 in inference mode
dropout_desc.set_no_dropout(handle);

proj_size = (proj_size > hidden_size || proj_size == 0) ? hidden_size : proj_size;
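    // cuDNN requires proj_size <= hidden_size; a value of 0 or one larger than
    // hidden_size falls back to hidden_size, i.e. projection is effectively disabled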
rnn_desc.set(
input_size, hidden_size, proj_size, num_layers, bidirectional, bias, dtype,
mode, dropout_desc, handle);

x_descs.resize(seq_len);
y_descs.resize(seq_len);
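    // the pre-v8 cudnnRNN{Forward,Backward}* entry points take an array with one
    // (at least 3-D) tensor descriptor per time step, hence seq_len descriptors of
    // shape {batch_size, input_size} for x and {batch_size, D * hidden_size} for y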
for (size_t i = 0; i < seq_len; ++i) {
x_descs[i].set_nd(TensorLayout(TensorShape{batch_size, input_size}, dtype), 3);
y_descs[i].set_nd(
TensorLayout(TensorShape{batch_size, D * hidden_size}, dtype), 3);
}

#define SET_H(_var) \
_var.set_nd(TensorLayout( \
TensorShape{D * num_layers, batch_size, hidden_size}, dtype));

SET_H(hx_desc)
SET_H(cx_desc)
SET_H(hy_desc)
SET_H(cy_desc)
#undef SET_H

std::vector<cudnnTensorDescriptor_t> x_desc_arr = get_descs(x_descs);
cudnn_check(cudnnGetRNNWorkspaceSize(
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(),
&workspace_size));

cudnn_check(cudnnGetRNNTrainingReserveSize(
cudnn_handle(handle), rnn_desc.desc, seq_len, x_desc_arr.data(),
&reserveSpace_size));
}

RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input) {
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];

cudnnRNNMode_t mode;
using NonlineMode = param::RNN::NonlineMode;
switch (_param.nonlineMode) {
case NonlineMode::RELU:
mode = CUDNN_RNN_RELU;
break;
case NonlineMode::TANH:
mode = CUDNN_RNN_TANH;
break;
}

RNNForwardDescHolder_v6 desc_holder(
handle, seq_len, batch_size, _param.hidden_size, input_size,
_param.proj_size, _param.num_layers, _param.bidirectional, _param.bias,
input.dtype, mode);
return desc_holder;
}

std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs) {
std::vector<cudnnTensorDescriptor_t> r;
r.reserve(descs.size());
for (auto& desc : descs) {
r.emplace_back(desc.desc);
}
return r;
}
} // namespace rnn
} // namespace cuda
} // namespace megdnn

+ 56
- 0
dnn/src/cuda/rnn/utils.h View File

@@ -0,0 +1,56 @@
/**
* \file dnn/src/cuda/rnn/utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
namespace cuda {
namespace rnn {
// cuDNN v8 RNN API; not used for now
/*struct RNNForwardDescHolder {

int32_t* devSeqLengths;
cudnnRNNMode_t mode;
cudnnForwardMode_t fwdMode;
RNNDesc rnn_desc;
DropoutDesc dropout_desc;
RNNDataDesc x_desc, y_desc;
TensorDesc h_desc;
size_t weight_size, workspace_size, reserveSpace_size;

RNNForwardDescHolder(Handle* handle, size_t seq_len, size_t batch_size, size_t
hidden_size, size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype,
cudnnRNNMode_t _mode, cudnnForwardMode_t _fwdMode); ~RNNForwardDescHolder();
};*/

struct RNNForwardDescHolder_v6 {
cudnnRNNMode_t mode;
RNNDesc rnn_desc;
int seq_len;
DropoutDesc dropout_desc;
std::vector<TensorDesc> x_descs, y_descs;
TensorDesc hx_desc, cx_desc, hy_desc, cy_desc;

size_t workspace_size, reserveSpace_size;

RNNForwardDescHolder_v6(
Handle* handle, size_t seq_len, size_t batch_size, size_t hidden_size,
size_t input_size, size_t proj_size, size_t num_layers, bool bidirectional,
bool bias, DType dtype, cudnnRNNMode_t _mode);
};

RNNForwardDescHolder_v6 get_RNNDescHolder_v6(
Handle* handle, megdnn::RNNForward::Param& _param, const TensorLayout& input);
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDesc>& descs);
} // namespace rnn
} // namespace cuda
} // namespace megdnn

+ 35
- 0
dnn/src/cuda/rnn_cell/opr_impl.cpp View File

@@ -0,0 +1,35 @@
/**
* \file dnn/src/cuda/rnn_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/rnn_cell/opr_impl.h"
#include "src/common/rnn_cell.h"

namespace megdnn {
namespace cuda {
size_t RNNCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) {
return megdnn::rnn_cell::get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle());
}

void RNNCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
megdnn::rnn_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
param().nonlineMode, handle());
}
} // namespace cuda

} // namespace megdnn
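Both RNNCellImpl methods above forward to the shared helpers declared in src/common/rnn_cell.h, which implement the standard single-step recurrence h_new = act(x * W_ih^T + b_ih + h * W_hh^T + b_hh). For orientation only, a minimal standalone sketch of that step in plain C++ (tanh activation, row-major [batch, feature] buffers); this is an illustration, not the megdnn code path:

// Illustrative only: one step of a vanilla RNN cell with tanh activation.
// Shapes: x [batch, input], h [batch, hidden], W_ih [hidden, input],
// W_hh [hidden, hidden], b_ih/b_hh [hidden]; all buffers row-major.
#include <cmath>
#include <vector>

std::vector<float> rnn_cell_step(
        const std::vector<float>& x, const std::vector<float>& h,
        const std::vector<float>& W_ih, const std::vector<float>& W_hh,
        const std::vector<float>& b_ih, const std::vector<float>& b_hh,
        size_t batch, size_t input_size, size_t hidden_size) {
    std::vector<float> h_new(batch * hidden_size);
    for (size_t b = 0; b < batch; ++b) {
        for (size_t o = 0; o < hidden_size; ++o) {
            float acc = b_ih[o] + b_hh[o];
            for (size_t i = 0; i < input_size; ++i)
                acc += x[b * input_size + i] * W_ih[o * input_size + i];
            for (size_t j = 0; j < hidden_size; ++j)
                acc += h[b * hidden_size + j] * W_hh[o * hidden_size + j];
            h_new[b * hidden_size + o] = std::tanh(acc);
        }
    }
    return h_new;
}

In the operator itself the activation is chosen by param().nonlineMode (RELU or TANH), which is forwarded to megdnn::rnn_cell::exec above.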

+ 40
- 0
dnn/src/cuda/rnn_cell/opr_impl.h View File

@@ -0,0 +1,40 @@
/**
* \file dnn/src/cuda/rnn_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace cuda {

class RNNCellImpl : public RNNCell {
public:
using RNNCell::RNNCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) override;
/*
private:
void add_bias(_megdnn_tensor_in A,
_megdnn_tensor_in B,
_megdnn_tensor_in bias,
_megdnn_tensor_out C);
*/
};

} // namespace cuda
} // namespace megdnn

+ 4
- 0
dnn/src/naive/handle.cpp View File

@@ -52,6 +52,8 @@
#include "src/naive/local_share/opr_impl.h"
#include "src/naive/lrn/opr_impl.h"
#include "src/naive/lsq/opr_impl.h"
#include "src/naive/lstm/opr_impl.h"
#include "src/naive/lstm_cell/opr_impl.h"
#include "src/naive/mask_conv/opr_impl.h"
#include "src/naive/matrix_inverse/opr_impl.h"
#include "src/naive/matrix_mul/opr_impl.h"
@@ -68,6 +70,8 @@
#include "src/naive/repeat/opr_impl.h"
#include "src/naive/resize/opr_impl.h"
#include "src/naive/rng/opr_impl.h"
#include "src/naive/rnn/opr_impl.h"
#include "src/naive/rnn_cell/opr_impl.h"
#include "src/naive/roi_align/opr_impl.h"
#include "src/naive/roi_copy/opr_impl.h"
#include "src/naive/roi_pooling/opr_impl.h"


+ 146
- 0
dnn/src/naive/lstm/opr_impl.cpp View File

@@ -0,0 +1,146 @@
/**
* \file dnn/src/naive/lstm/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/lstm/opr_impl.h"
#include "src/naive/rnn/funcs.h"
#include "src/naive/rnn/rnn.h"

namespace megdnn {
namespace naive {
using rnn::LSTMCellWeightWrapper;

void LSTMImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<LSTMCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<LSTMCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);

Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states = {hx, cx}, states_new = {hy, cy};
rnn::exec_internal<LSTMCellWeightWrapper, LSTMCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
this->handle(), new_workspace);
}

size_t LSTMImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space) {
size_t workspace_size = rnn::get_workspace_in_bytes<LSTMCellForward>(
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1,
this->handle());
if (!param().bias) { // use fake bias (all 0)
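        // gate_hidden_size = 4 * hidden_size: one bias chunk per LSTM gate (i, f, g, o)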
TensorLayout bias_layout = {{param().hidden_size * 4}, flatten_weights.dtype};
workspace_size += bias_layout.span().dist_byte();
}
workspace_size += output.span().dist_byte();
return workspace_size;
}

size_t LSTMImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype};
// 2 for hidden state and cell state
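    // e.g. (illustrative numbers only) seq_len = 10, batch_size = 4,
    // hidden_size = 16, float32, num_layers = 2, D = 1:
    // 2 * 2 * 1 * 10 * (4 * 16 * 4 bytes) = 10240 bytes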
return 2 * num_layers * D * seq_len * state_layout.span().dist_byte();
}

void LSTMBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dcx, _megdnn_tensor_out dw, _megdnn_workspace workspace) {
TensorNDArray layer_inputs;
TensorNDArray layer_outputs;
std::vector<std::vector<TensorNDArray>> cell_seq_states;
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
size_t input_size = x.layout.shape[2];
size_t hidden_size = param().hidden_size;
size_t used_workspace_size = 0;

// get cells
std::vector<LSTMCellWeightWrapper> cells;
used_workspace_size += rnn::get_cells<LSTMCellWeightWrapper>(
D, num_layers, input_size, hidden_size, param().bias, cells,
flatten_weights, workspace);

// get formatted inputs
Workspace new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
used_workspace_size += rnn::get_inputs_for_exec<LSTMCellWeightWrapper>(
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs,
layer_outputs, cell_seq_states, param::RNNCell::NonlineMode::IDENTITY,
new_workspace);

// dhy arr, dhx arr
TensorNDArray dhy_arr = {dhy, dcy}, dhx_arr = {dhx, dcx};

// exec
new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
rnn::backward_exec_internal<LSTMCellWeightWrapper>(
cells, D, num_layers, input_size, param().bias,
param::RNNCell::NonlineMode::IDENTITY, layer_inputs, layer_outputs,
cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw, this->handle(),
new_workspace);
}

size_t LSTMBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw) {
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
size_t hidden_size = param().hidden_size;
size_t gate_hidden_size = hidden_size * 4;
size_t max_input_size = std::max(x.shape[2], D * hidden_size);

size_t workspace_size = LSTMCellWeightWrapper::backward_workspace_size_in_bytes(
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype);
if (!param().bias) { // use fake bias (all 0)
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype};
workspace_size += bias_layout.span().dist_byte() *
2; // times 2 because another bias is allocated in
// backward_exec_internal
}
workspace_size += num_layers * y.span().dist_byte();
// add back exec workspace size
workspace_size += y.span().dist_byte() * 2;
workspace_size += x.span().dist_byte() * 2;
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype};
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype};
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype};
workspace_size += wih.span().dist_byte();
workspace_size += whh.span().dist_byte();
workspace_size += bias.span().dist_byte();
return workspace_size;
}
} // namespace naive

} // namespace megdnn

+ 56
- 0
dnn/src/naive/lstm/opr_impl.h View File

@@ -0,0 +1,56 @@
/**
* \file dnn/src/naive/lstm/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class LSTMImpl : public LSTM {
public:
using LSTM::LSTM;

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out cy,
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace);

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& cy,
const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
};

class LSTMBackwardImpl : public LSTMBackward {
public:
using LSTMBackward::LSTMBackward;

virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in cx, _megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in dcy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx,
_megdnn_tensor_out dhx, _megdnn_tensor_out dcx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);

virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& cx, const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& dcy, const TensorLayout& flatten_weights,
const TensorLayout& reserve_space, const TensorLayout& dx,
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw);
};

} // namespace naive
} // namespace megdnn

+ 55
- 0
dnn/src/naive/lstm/template_impl.cpp View File

@@ -0,0 +1,55 @@
/**
* \file dnn/src/naive/lstm/template_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/rnn/funcs.h"

namespace megdnn {
namespace naive {
namespace rnn {

template <>
void cell_opr_exec<LSTMCellForward>(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) {
auto opr = handle->create_operator<LSTMCellForward>();
TensorLayout gates, h_new, c_new;
opr->deduce_layout(
input.layout, weight_ih.layout, bias_ih.layout, states[0].layout,
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates);
TensorND gates_tensor{workspace.raw_ptr, gates};
_megdnn_workspace new_workspace = {
workspace.raw_ptr + gates.span().dist_byte(),
workspace.size - gates.span().dist_byte()};
opr->exec(
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1],
states_new[0], states_new[1], gates_tensor, new_workspace);
}

template <>
size_t cell_opr_get_workspace_in_bytes<LSTMCellForward>(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& weight_hh, const TensorLayout& bias_ih,
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) {
TensorLayout cx = hx;
TensorLayout h_new, c_new, gates;
auto cell_opr = handle->create_operator<LSTMCellForward>();
cell_opr->deduce_layout(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
return cell_opr->get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new,
gates) +
gates.span().dist_byte();
}

} // namespace rnn
} // namespace naive
} // namespace megdnn

+ 38
- 0
dnn/src/naive/lstm_cell/opr_impl.cpp View File

@@ -0,0 +1,38 @@
/**
* \file dnn/src/naive/lstm_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/lstm_cell/opr_impl.h"
#include "src/common/lstm_cell.h"

namespace megdnn {
namespace naive {
size_t LSTMCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates) {
return megdnn::lstm_cell::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
handle());
}

void LSTMCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) {
megdnn::lstm_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates,
workspace, handle());
}
} // namespace naive

} // namespace megdnn

+ 36
- 0
dnn/src/naive/lstm_cell/opr_impl.h View File

@@ -0,0 +1,36 @@
/**
* \file dnn/src/naive/lstm_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/naive/rnn_cell/opr_impl.h"

namespace megdnn {
namespace naive {

class LSTMCellImpl : public LSTMCell {
public:
using LSTMCell::LSTMCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
_megdnn_tensor_out gates, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new,
const TensorLayout& c_new, const TensorLayout& gates) override;
};

} // namespace naive
} // namespace megdnn

+ 1
- 3
dnn/src/naive/relayout/opr_impl.cpp View File

@@ -72,9 +72,7 @@ void RelayoutForwardImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst)

void RelayoutForwardImpl::check_cpu_handle(Handle* handle) {
megdnn_assert(
!handle ||
handle == this->handle()
|| is_cpu_handle(handle),
!handle || handle == this->handle() || is_cpu_handle(handle),
"relayout from non-CPU to CPU not supported");
}



+ 75
- 0
dnn/src/naive/rnn/funcs.h View File

@@ -0,0 +1,75 @@
/**
* \file dnn/src/naive/rnn/funcs.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#ifndef _RNN_H
#define _RNN_H
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
namespace rnn {

template <typename CellOpr>
void cell_opr_exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle);

template <typename CellOpr>
size_t cell_opr_get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& weight_hh, const TensorLayout& bias_ih,
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle);

template <typename CellOpr>
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& flatten_weights,
size_t hidden_size,
size_t D, // num_directions
Handle* handle);

template <class Cell, typename CellOpr>
void exec_internal(
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_tensor_out output,
_megdnn_tensor_out reserve_space, size_t num_layers,
size_t D, // D is num_directions
Handle* handle, _megdnn_workspace workspace);

template <class Cell>
size_t get_cells(
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias,
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights,
_megdnn_workspace workspace);

template <class Cell>
size_t get_inputs_for_exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space,
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells,
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs,
std::vector<std::vector<TensorNDArray>>& cell_seq_states,
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace);

template <class Cell>
void backward_exec_internal(
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size,
bool bias, param::RNNCell::NonlineMode nonlineMode,
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs,
const std::vector<std::vector<TensorNDArray>>& cell_seq_states,
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx,
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle,
_megdnn_workspace workspace);

} // namespace rnn
} // namespace naive
} // namespace megdnn

#include "funcs.tpp"
#endif

+ 449
- 0
dnn/src/naive/rnn/funcs.tpp View File

@@ -0,0 +1,449 @@
/**
* \file dnn/src/naive/rnn/funcs.tpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "funcs.h"

namespace megdnn {
namespace naive {
namespace rnn {

template <typename CellOpr>
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& flatten_weights,
size_t hidden_size,
size_t D, // num_directions
Handle* handle) {
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
size_t input_size = input.shape[2];
size_t gate_hidden_size = flatten_weights.shape[0];
// concat workspace
TensorLayout direction_output_layout{
TensorShape{seq_len, batch_size, hidden_size}, input.dtype};
TensorLayout output_layout{{seq_len, batch_size, D * hidden_size}, input.dtype};
TensorLayoutArray layer_layouts;
for (size_t i = 0; i < D; ++i)
layer_layouts.push_back(direction_output_layout);
auto concat_opr = handle->create_operator<ConcatForward>();
concat_opr->param().axis = -1;
size_t concat_workspace =
concat_opr->get_workspace_in_bytes(layer_layouts, output_layout);
// cell workspace
TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype};
TensorLayout weight_hh{{gate_hidden_size, hidden_size}, flatten_weights.dtype};
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype};
TensorLayout hx{{batch_size, hidden_size}, input.dtype};
size_t cell_workspace = cell_opr_get_workspace_in_bytes<CellOpr>(
input, weight_ih, weight_hh, bias, bias, hx, handle);
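    // only the larger of the two scratch sizes is reserved; this assumes the
    // concat and per-step cell computations never need scratch memory at the
    // same time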

return std::max(cell_workspace, concat_workspace);
}

template <class Cell, typename CellOpr>
void exec_internal(
std::vector<Cell>& cells, _megdnn_tensor_in input, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_tensor_out output,
_megdnn_tensor_out reserve_space, size_t num_layers,
size_t D, // D is num_directions
Handle* handle, _megdnn_workspace workspace) {
size_t seq_len = input.layout.shape[0];
size_t batch_size = input.layout.shape[1];
size_t input_size = input.layout.shape[2];
size_t hidden_size = cells[0].weight_hh.layout.shape[1];
TensorLayout cell_output_layout{
TensorShape{batch_size, hidden_size}, states[0].layout.dtype};
TensorLayout cell_first_input_layout{
TensorShape{batch_size, input_size}, input.layout.dtype};
TensorLayout cell_input_layout{
TensorShape{batch_size, D * hidden_size}, input.layout.dtype};
TensorLayout direction_output_layout{
TensorShape{seq_len, batch_size, hidden_size}, output.layout.dtype};
TensorND tmp_output{workspace.raw_ptr, output.layout};
_megdnn_workspace new_workspace{
workspace.raw_ptr + tmp_output.layout.span().dist_byte(),
workspace.size - tmp_output.layout.span().dist_byte()};

auto cell_opr = handle->create_operator<CellOpr>();
auto copy_opr = handle->create_operator<TypeCvtForward>();

// copy states to states_new
for (size_t i = 0; i < states.size(); ++i)
copy_opr->exec(states[i], states_new[i]);
void* reserve_ptr = reserve_space.raw_ptr();
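    // per-step cell states are appended to the reserve space below in
    // (layer, direction, visiting order) order; the backward pass
    // (get_inputs_for_exec) reads them back with matching offsets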

    // layer 0: the first layer consumes the original input
// TensorNDArray layer_outputs;
for (size_t d = 0; d < D; ++d) {
size_t cell_idx = d;
auto& cell = cells[cell_idx];

TensorNDArray cur_states;
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte();
for (size_t i = 0; i < states.size(); ++i) {
cur_states.push_back(TensorND{
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset,
cell_output_layout});
}
// TensorND direction_output_tensor{output.raw_ptr + d *
// direction_output_layout.span().dist_byte(),
// direction_output_layout};
for (size_t i = 0; i < seq_len; ++i) {
size_t step = d == 0 ? i : seq_len - 1 - i;
TensorND step_input{
static_cast<uint8_t*>(input.raw_ptr()) +
step * cell_first_input_layout.span().dist_byte(),
cell_first_input_layout};
TensorND step_output{
static_cast<uint8_t*>(output.raw_ptr()) +
(step * D + d) * cell_output_layout.span().dist_byte(),
cell_output_layout};
// temporary states of each step (use reserve space)
TensorNDArray tmp_states;
for (size_t s = 0; s < cur_states.size(); ++s) {
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes;
}
cell_opr_exec<CellOpr>(
step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
cell.bias_hh, cur_states, tmp_states, new_workspace, handle);
// copy states to cur_states
for (size_t s = 0; s < tmp_states.size(); ++s) {
copy_opr->exec(tmp_states[s], cur_states[s]);
}
// copy h to step output
copy_opr->exec(cur_states[0], step_output);
}
}

for (size_t layer = 1; layer < num_layers; ++layer) {
// TensorNDArray layer_outputs;

for (size_t d = 0; d < D; ++d) {
size_t cell_idx = layer * D + d;
auto& cell = cells[cell_idx];

TensorNDArray cur_states;
size_t states_offset = cell_idx * cell_output_layout.span().dist_byte();
for (size_t i = 0; i < states.size(); ++i) {
cur_states.push_back(TensorND{
static_cast<uint8_t*>(states_new[i].raw_ptr()) + states_offset,
cell_output_layout});
}
// TensorND direction_output_tensor{output.raw_ptr + d *
// direction_output_layout.span().dist_byte(),
// direction_output_layout};

for (size_t i = 0; i < seq_len; ++i) {
size_t step = d == 0 ? i : seq_len - 1 - i;
TensorND step_input{
static_cast<uint8_t*>(output.raw_ptr()) +
step * cell_input_layout.span().dist_byte(),
cell_input_layout};
TensorND step_output{
static_cast<uint8_t*>(tmp_output.raw_ptr()) +
(step * D + d) * cell_output_layout.span().dist_byte(),
cell_output_layout};
// temporary states of each step (use reserve space)
TensorNDArray tmp_states;
for (size_t s = 0; s < cur_states.size(); ++s) {
tmp_states.push_back(TensorND{reserve_ptr, cur_states[s].layout});
size_t size_in_bytes = cur_states[s].layout.span().dist_byte();
reserve_ptr = static_cast<uint8_t*>(reserve_ptr) + size_in_bytes;
}
cell_opr_exec<CellOpr>(
step_input, cell.weight_ih, cell.weight_hh, cell.bias_ih,
                        cell.bias_hh, cur_states, tmp_states, new_workspace, handle);
// copy states to cur_states
for (size_t s = 0; s < tmp_states.size(); ++s) {
copy_opr->exec(tmp_states[s], cur_states[s]);
}
// copy h to step_output
copy_opr->exec(cur_states[0], step_output);
}
}
// copy layer output to output
copy_opr->exec(tmp_output, output);
}
// output: [d0, d1, d0, d1 ...]
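    // i.e. for D == 2 each time step stores a contiguous [batch_size, hidden_size]
    // block for the forward direction followed by one for the reverse direction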
}

template <class Cell>
size_t get_cells(
size_t D, size_t num_layers, size_t input_size, size_t hidden_size, bool bias,
std::vector<Cell>& cells, _megdnn_tensor_in flatten_weights,
_megdnn_workspace workspace) {
cells.reserve(D * num_layers);
void* weight_ptr = flatten_weights.raw_ptr();
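    // flatten_weights packs the cells in (layer, direction) order; each cell
    // holds weight_ih, weight_hh and, when bias is enabled, bias_ih and bias_hh
    // as contiguous slices (see CellWeightsWrapperBase)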
for (size_t layer = 0; layer < num_layers; ++layer) {
for (size_t d = 0; d < D; ++d) {
size_t cell_input_size = D * hidden_size;
if (layer == 0)
cell_input_size = input_size;
Cell cell(
weight_ptr, hidden_size, cell_input_size, bias,
flatten_weights.layout.dtype, workspace);
weight_ptr =
static_cast<uint8_t*>(weight_ptr) + cell.weight_size_in_bytes();
cells.push_back(cell);
}
}
// return used workspace
return cells[0].workspace_size_in_bytes();
}

template <class Cell>
size_t get_inputs_for_exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in reserve_space,
size_t num_layers, size_t D, size_t hidden_size, const std::vector<Cell>& cells,
TensorNDArray& layer_inputs, TensorNDArray& layer_outputs,
std::vector<std::vector<TensorNDArray>>& cell_seq_states,
param::RNNCell::NonlineMode nonlineMode, _megdnn_workspace workspace) {
// return used workspace size

layer_inputs.push_back(x);
size_t seq_len = x.layout.shape[0];
size_t batch_size = x.layout.shape[1];
size_t num_states = cells[0].num_states();
TensorLayout cell_output_layout{{batch_size, hidden_size}, y.layout.dtype};
TensorLayout direction_output_layout{
{seq_len, batch_size, hidden_size}, y.layout.dtype};
void* workspace_ptr = workspace.raw_ptr;

    // extract intermediate states from the reserve space
for (int layer = 0; layer < num_layers; ++layer) {
TensorND layer_output{workspace_ptr, y.layout};
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
layer_output.layout.span().dist_byte();
for (int d = 0; d < D; ++d) {
cell_seq_states.push_back(std::vector<TensorNDArray>());
            // the reverse direction is stored in reverse sequence order
for (int i = 0; i < seq_len; ++i) {
size_t step = i;
if (d == 1)
step = seq_len - i - 1;
size_t offset = ((layer * D + d) * seq_len + step) *
cell_output_layout.span().dist_byte() * num_states;
TensorNDArray cur_states;
for (int s = 0; s < num_states; ++s) {
TensorND h{
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset +
s * cell_output_layout.span().dist_byte(),
cell_output_layout};
cur_states.push_back(h);
}
TensorND hy{
static_cast<uint8_t*>(reserve_space.raw_ptr()) + offset,
                        cell_output_layout};  // the first stored state is the hidden
                                              // state, which is also the step output
cell_seq_states[cell_seq_states.size() - 1].push_back(cur_states);
// output
offset = i * D * cell_output_layout.span().dist_byte();
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr()) + offset, hy.raw_ptr(),
hy.layout.span().dist_byte());
}
}
layer_outputs.push_back(layer_output);
if (layer != num_layers - 1)
layer_inputs.push_back(layer_output);
}
return static_cast<uint8_t*>(workspace_ptr) -
static_cast<uint8_t*>((void*)workspace.raw_ptr);
}

template <class Cell>
// using Cell = RNNCellWeightWrapper;
void backward_exec_internal(
std::vector<Cell>& cells, size_t D, size_t num_layers, size_t input_size,
bool bias, param::RNNCell::NonlineMode nonlineMode,
const TensorNDArray& layer_inputs, const TensorNDArray& layer_outputs,
const std::vector<std::vector<TensorNDArray>>& cell_seq_states,
_megdnn_tensor_in dy, const TensorNDArray& dhy, _megdnn_tensor_out dx,
TensorNDArray& dstates, _megdnn_tensor_out dw, Handle* handle,
_megdnn_workspace workspace) {
    /*
    layer_inputs: array of the input to each layer; element 0 has shape
        [seq_len, batch_size, input_size], the rest [seq_len, batch_size, D * hidden_size]
    layer_outputs: array of the output of each layer, in sequence order. To access
        the outputs of the cells at (layer, d), use layer_outputs[layer]; the shape
        is [seq_len, batch_size, D * hidden_size]
    cell_seq_states: array of the states of each cell at each time step. To access
        the states of the cell at (layer, d) at sequence step `step`, use
        cell_seq_states[layer * D + d][step]
    */
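    // e.g. with num_layers = 2 and D = 2 there are 4 cells; the reverse-direction
    // cell of layer 1 is index 1 * 2 + 1 = 3, and its states at sequence step t
    // are cell_seq_states[3][t]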
size_t seq_len = layer_inputs[0].layout.shape[0];
size_t batch_size = layer_inputs[0].layout.shape[1];
DType dtype = layer_inputs[0].layout.dtype;
size_t cell_y_size =
layer_outputs[0].layout.shape[2] / D; // should all be the same
size_t hidden_size = cell_y_size;
TensorLayout cell_y_layout = {{batch_size, cell_y_size}, dtype};
void* workspace_ptr = workspace.raw_ptr;

TensorND layer_output_grad{
workspace_ptr, {{seq_len, batch_size, D * hidden_size}, dtype}};
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
layer_output_grad.layout.span().dist_byte();
memcpy(layer_output_grad.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte());
TensorNDArray direction_dx_arr; // for layer 1 to layer num_layers-1
for (int i = 0; i < D; ++i) {
TensorLayout direction_dx_layout{{seq_len, batch_size, hidden_size}, dtype};
direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout));
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
direction_dx_layout.span().dist_byte();
}
TensorNDArray L0_direction_dx_arr;
for (int i = 0; i < D; ++i) {
TensorLayout direction_dx_layout{{seq_len, batch_size, input_size}, dtype};
L0_direction_dx_arr.push_back(TensorND(workspace_ptr, direction_dx_layout));
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
direction_dx_layout.span().dist_byte();
}
// cell states for each layer and each direction
std::vector<TensorNDArray> dstates_arr;
for (int layer = 0; layer < num_layers; ++layer) {
for (int d = 0; d < D; ++d) {
TensorNDArray cell_states;
cell_states.reserve(dstates.size());
for (int i = 0; i < dstates.size(); ++i) {
size_t offset = (layer * D + d) * cell_y_layout.span().dist_byte();
TensorND dhx_cell{
static_cast<uint8_t*>(dstates[i].raw_ptr()) + offset,
cell_y_layout};
memcpy(dhx_cell.raw_ptr(), static_cast<uint8_t*>(dhy[i].raw_ptr()) + offset,
cell_y_layout.span().dist_byte());
cell_states.emplace_back(dhx_cell);
}
dstates_arr.push_back(cell_states);
}
}

// init gradient on weight to zero
memset(dw.raw_ptr(), 0, dw.layout.span().dist_byte());
    // wrap dw with cell wrappers so each cell's gradient slice can be addressed
std::vector<Cell> cell_grads;
size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) -
static_cast<uint8_t*>((void*)(workspace.raw_ptr));
workspace_ptr =
static_cast<uint8_t*>(workspace_ptr) +
get_cells(
D, num_layers, input_size, hidden_size, bias, cell_grads, dw,
Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size));

auto add_opr = handle->create_operator<ElemwiseForward>();
add_opr->param().mode = Elemwise::Mode::ADD;
auto copy_opr = handle->create_operator<TypeCvtForward>();

// initialize dx to zero
memset(dx.raw_ptr(), 0, dx.layout.span().dist_byte());

// calculate grads
for (int layer = num_layers - 1; layer >= 0; --layer) {
for (int d = D - 1; d >= 0; --d) {
Cell& cell = cells[layer * D + d];
Cell& cell_grad = cell_grads[layer * D + d];
size_t input_size = layer_inputs[layer].layout.shape[2];
const TensorND& x_arr = layer_inputs[layer];
const TensorND& y_arr = layer_outputs[layer];
TensorLayout x_layout = {{batch_size, input_size}, dtype};

// tmp tensors
void* tmp_workspace_ptr = workspace_ptr;
TensorND dwi_tmp{tmp_workspace_ptr, cell_grad.weight_ih.layout};
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
dwi_tmp.layout.span().dist_byte();
TensorND dwh_tmp{tmp_workspace_ptr, cell_grad.weight_hh.layout};
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
dwh_tmp.layout.span().dist_byte();
TensorND dbias_tmp{tmp_workspace_ptr, cell_grad.bias_ih.layout};
tmp_workspace_ptr = static_cast<uint8_t*>(tmp_workspace_ptr) +
dbias_tmp.layout.span().dist_byte();
size_t used_workspace_size =
static_cast<uint8_t*>(tmp_workspace_ptr) -
static_cast<uint8_t*>((void*)(workspace.raw_ptr));

for (int i = 0; i < seq_len; ++i) {
                // walk the sequence in reverse of the order the forward pass
                // visited it; `step` is the sequence position
size_t step = i;
if (d == 0)
step = seq_len - i - 1;
TensorND x{
static_cast<uint8_t*>(x_arr.raw_ptr()) +
step * x_layout.span().dist_byte(),
x_layout},
y{static_cast<uint8_t*>(y_arr.raw_ptr()) +
(step * D + d) * cell_y_layout.span().dist_byte(),
cell_y_layout};
const TensorNDArray& cell_states = cell_seq_states[layer * D + d][step];
TensorNDArray& dstates_new = dstates_arr[layer * D + d];
// dy should be d_output + d_hidden
TensorND dy_t{
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) +
(step * D + d) * cell_y_layout.span().dist_byte(),
cell_y_layout};
add_opr->exec({dstates_new[0], dy_t}, dy_t);
// dx for layer 0 has a different size
TensorND dx_t;
if (layer == 0)
dx_t = {static_cast<uint8_t*>(L0_direction_dx_arr[d].raw_ptr()) +
step * x_layout.span().dist_byte(),
x_layout};
else
dx_t = {static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) +
step * x_layout.span().dist_byte(),
x_layout};
TensorNDArray douts = {dy_t};
for (int s = 1; s < dstates_new.size(); ++s)
douts.push_back(dstates_new[s]);
cell.backward(
handle, nonlineMode, x, cell_states, y, douts, dx_t,
dstates_new, dwi_tmp, dwh_tmp, dbias_tmp,
Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size));
// add step gradient to overall gradient
add_opr->exec({dwi_tmp, cell_grad.weight_ih}, cell_grad.weight_ih);
add_opr->exec({dwh_tmp, cell_grad.weight_hh}, cell_grad.weight_hh);
add_opr->exec({dbias_tmp, cell_grad.bias_ih}, cell_grad.bias_ih);
add_opr->exec({dbias_tmp, cell_grad.bias_hh}, cell_grad.bias_hh);
}
}
        // accumulate the per-direction input gradients: for layer 0 they go into
        // dx, for the other layers into layer_output_grad (the dy of the layer below)
if (layer == 0) {
for (int i = 0; i < D; ++i)
add_opr->exec({L0_direction_dx_arr[i], dx}, dx);
} else {
if (D == 1)
copy_opr->exec(direction_dx_arr[0], layer_output_grad);
else { // D == 2, arrange as [(d0, d1), (d0, d1), ...]
for (size_t t = 0; t < seq_len; ++t) {
size_t offset = t * D * cell_y_layout.span().dist_byte();
for (size_t d = 0; d < D; ++d) {
TensorND src{
static_cast<uint8_t*>(direction_dx_arr[d].raw_ptr()) +
offset,
cell_y_layout};
TensorND dst{
static_cast<uint8_t*>(layer_output_grad.raw_ptr()) +
offset + d * cell_y_layout.span().dist_byte(),
cell_y_layout};
copy_opr->exec(src, dst);
}
}
}
}
}
}

} // namespace rnn
} // namespace naive
} // namespace megdnn

+ 196
- 0
dnn/src/naive/rnn/opr_impl.cpp View File

@@ -0,0 +1,196 @@
/**
* \file dnn/src/naive/rnn/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/rnn/opr_impl.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/general.h"
#include "src/common/opr_delegate.h"
#include "src/common/rnn.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include "src/naive/rnn/funcs.h"
#include "src/naive/rnn/rnn.h"

#include <cstring>

namespace megdnn {
namespace naive {

using rnn::RNNCellWeightWrapper;

void RNNImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace) {
auto _param = param();
size_t D = _param.bidirectional ? 2 : 1;
size_t num_layers = _param.num_layers;
size_t input_size = input.layout.shape[2];
std::vector<RNNCellWeightWrapper> cells;
size_t used_workspace_size = rnn::get_cells<RNNCellWeightWrapper>(
D, num_layers, input_size, _param.hidden_size, _param.bias, cells,
flatten_weights, workspace);

Workspace new_workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
TensorNDArray states, states_new;
states.push_back(hx);
states_new.push_back(hy);
rnn::exec_internal<RNNCellWeightWrapper, RNNCellForward>(
cells, input, states, states_new, output, reserve_space, num_layers, D,
this->handle(), new_workspace);
}

size_t RNNImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space) {
size_t workspace_size = rnn::get_workspace_in_bytes<RNNCellForward>(
input, flatten_weights, param().hidden_size, param().bidirectional ? 2 : 1,
this->handle());
if (!param().bias) { // use fake bias (all 0)
TensorLayout bias_layout = {{param().hidden_size}, flatten_weights.dtype};
workspace_size += bias_layout.span().dist_byte();
}
workspace_size += output.span().dist_byte();
return workspace_size;
}

size_t RNNImpl::get_reserve_size_in_bytes(const TensorLayout& input) {
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
size_t seq_len = input.shape[0];
size_t batch_size = input.shape[1];
TensorLayout state_layout{{batch_size, param().hidden_size}, input.dtype};
return num_layers * D * seq_len * state_layout.span().dist_byte();
}

void RNNBackwardImpl::exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy, _megdnn_tensor_in flatten_weights,
_megdnn_tensor_in reserve_space, _megdnn_tensor_out dx, _megdnn_tensor_out dhx,
_megdnn_tensor_out dw, _megdnn_workspace workspace) {
TensorNDArray layer_inputs;
// layer_inputs.push_back(x);
TensorNDArray layer_outputs;
std::vector<std::vector<TensorNDArray>> cell_seq_states;
size_t num_layers = param().num_layers;
size_t D = param().bidirectional ? 2 : 1;
// size_t seq_len = x.layout.shape[0];
// size_t batch_size = x.layout.shape[1];
size_t input_size = x.layout.shape[2];
size_t hidden_size = param().hidden_size;
size_t used_workspace_size = 0;

// get cells
std::vector<RNNCellWeightWrapper> cells;
// workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
used_workspace_size += rnn::get_cells(
D, num_layers, input_size, hidden_size, param().bias, cells,
flatten_weights, workspace);

    // extract intermediate states from the reserve space
/*for (int layer = 0; layer < num_layers; ++layer) {
TensorND layer_output{workspace_ptr, y.layout};
workspace_ptr = static_cast<uint8_t*>(workspace_ptr) +
layer_output.layout.span().dist_byte(); for (int d = 0; d < D; ++d) {
cell_seq_states.push_back(std::vector<TensorNDArray>());
// reverse direction is stored with reversed order of sequence order
for (int i = 0; i < seq_len; ++i) {
size_t step = i;
if (d == 1) step = seq_len - i - 1;
size_t offset = ((layer * D + d) * seq_len + step) *
cell_output_layout.span().dist_byte(); TensorND
hy{static_cast<uint8_t*>(reserve_space.raw_ptr) + offset, cell_output_layout};
// states
cell_seq_states[cell_seq_states.size() - 1].push_back({hy});
// output
offset = i * D * cell_output_layout.span().dist_byte();
memcpy(static_cast<uint8_t*>(layer_output.raw_ptr) + offset,
hy.raw_ptr, hy.layout.span().dist_byte());
}
}
cell_seq_outputs.push_back(layer_output);
if (layer != num_layers - 1) layer_inputs.push_back(layer_output);
}*/
// nonlinear mode
param::RNNCell::NonlineMode nonlineMode;
using ModeRNN = param::RNN::NonlineMode;
using ModeRNNCell = param::RNNCell::NonlineMode;
switch (param().nonlineMode) {
case ModeRNN::RELU:
nonlineMode = ModeRNNCell::RELU;
break;
case ModeRNN::TANH:
nonlineMode = ModeRNNCell::TANH;
break;
}

// get formatted inputs
Workspace new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
used_workspace_size += rnn::get_inputs_for_exec<RNNCellWeightWrapper>(
x, y, reserve_space, num_layers, D, hidden_size, cells, layer_inputs,
layer_outputs, cell_seq_states, nonlineMode, new_workspace);

// dhy arr, dhx arr
TensorNDArray dhy_arr = {dhy}, dhx_arr = {dhx};

// exec
/*size_t used_workspace_size = static_cast<uint8_t*>(workspace_ptr) -
static_cast<uint8_t*>((void*)workspace.raw_ptr);*/
new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
rnn::backward_exec_internal<RNNCellWeightWrapper>(
cells, D, num_layers, input_size, param().bias, nonlineMode, layer_inputs,
layer_outputs, cell_seq_states, dy, dhy_arr, dx, dhx_arr, dw,
this->handle(), new_workspace);
}

size_t RNNBackwardImpl::get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw) {
size_t D = param().bidirectional ? 2 : 1;
size_t num_layers = param().num_layers;
size_t hidden_size = param().hidden_size;
size_t gate_hidden_size = hidden_size;
size_t max_input_size = std::max(x.shape[2], D * hidden_size);

size_t workspace_size = RNNCellWeightWrapper::backward_workspace_size_in_bytes(
this->handle(), x.shape[1], param().hidden_size, max_input_size, x.dtype);
if (!param().bias) { // use fake bias (all 0)
TensorLayout bias_layout = {{gate_hidden_size}, flatten_weights.dtype};
workspace_size += bias_layout.span().dist_byte() *
2; // times 2 because another bias is allocated in
// backward_exec_internal
}
workspace_size += num_layers * y.span().dist_byte();
// add back exec workspace size
workspace_size += y.span().dist_byte() * 2;
workspace_size += x.span().dist_byte() * 2;
TensorLayout wih{{gate_hidden_size, max_input_size}, flatten_weights.dtype};
TensorLayout whh{{gate_hidden_size, hidden_size}, flatten_weights.dtype};
TensorLayout bias{{gate_hidden_size}, flatten_weights.dtype};
workspace_size += wih.span().dist_byte();
workspace_size += whh.span().dist_byte();
workspace_size += bias.span().dist_byte();
return workspace_size;
}
} // namespace naive

} // namespace megdnn

+ 53
- 0
dnn/src/naive/rnn/opr_impl.h View File

@@ -0,0 +1,53 @@
/**
* \file dnn/src/naive/rnn/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class RNNImpl : public RNN {
public:
using RNN::RNN;

void exec(
_megdnn_tensor_in input, _megdnn_tensor_in hx,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output,
_megdnn_tensor_out hy, _megdnn_tensor_out reserve_space,
_megdnn_workspace workspace);

size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& hx,
const TensorLayout& flatten_weights, const TensorLayout& output,
const TensorLayout& hy, const TensorLayout& reserve_space);
size_t get_reserve_size_in_bytes(const TensorLayout& input);
};

class RNNBackwardImpl : public RNNBackward {
public:
using RNNBackward::RNNBackward;

virtual void exec(
_megdnn_tensor_in x, _megdnn_tensor_in y, _megdnn_tensor_in hx,
_megdnn_tensor_in dy, _megdnn_tensor_in dhy,
_megdnn_tensor_in flatten_weights, _megdnn_tensor_in reserve_space,
_megdnn_tensor_out dx, _megdnn_tensor_out dhx, _megdnn_tensor_out dw,
_megdnn_workspace workspace);

virtual size_t get_workspace_in_bytes(
const TensorLayout& x, const TensorLayout& y, const TensorLayout& hx,
const TensorLayout& dy, const TensorLayout& dhy,
const TensorLayout& flatten_weights, const TensorLayout& reserve_space,
const TensorLayout& dx, const TensorLayout& dhx, const TensorLayout& dw);
};

} // namespace naive
} // namespace megdnn

+ 285
- 0
dnn/src/naive/rnn/rnn.cpp View File

@@ -0,0 +1,285 @@
/**
* \file dnn/src/naive/rnn/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/rnn/rnn.h"

#include <cmath>
#include <cstring>

namespace megdnn {
namespace naive {
namespace rnn {

CellWeightsWrapperBase::CellWeightsWrapperBase(
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks,
bool has_bias, DType dtype, _megdnn_workspace workspace) {
// weight_ih: [gate_hidden_size, input_size]
// weight_hh: [gate_hidden_size, hidden_size]
// bias_ih: [gate_hidden_size]
// bias_hh: [gate_hidden_size]
size_t gate_hidden_size = num_chunks * hidden_size;
TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype};
TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype};
TensorLayout bias_layout{{gate_hidden_size}, dtype};
this->_weight_size = 0;
this->weight_ih = TensorND(weight_ptr, weight_ih_layout);
this->_weight_size += weight_ih_layout.span().dist_byte();
this->weight_hh = TensorND(
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, weight_hh_layout);
this->_weight_size += weight_hh_layout.span().dist_byte();
if (has_bias) {
this->bias_ih = TensorND(
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout);
this->_weight_size += bias_layout.span().dist_byte();
this->bias_hh = TensorND(
static_cast<uint8_t*>(weight_ptr) + this->_weight_size, bias_layout);
this->_weight_size += bias_layout.span().dist_byte();
this->_workspace_size = 0;
} else {
this->bias_ih = TensorND(workspace.raw_ptr, bias_layout);
this->bias_hh = TensorND(workspace.raw_ptr, bias_layout);
memset(workspace.raw_ptr, 0, bias_layout.span().dist_byte());
this->_workspace_size = bias_layout.span().dist_byte();
}
}

size_t CellWeightsWrapperBase::weight_size_in_bytes() const {
return this->_weight_size;
}

size_t CellWeightsWrapperBase::workspace_size_in_bytes() const {
return this->_workspace_size;
}

size_t CellWeightsWrapperBase::num_states() const {
return 1;
}

void CellWeightsWrapperBase::backward(
Handle* handle, param::RNNCell::NonlineMode nonlineMode, _megdnn_tensor_in x,
const TensorNDArray& states, _megdnn_tensor_in y, const TensorNDArray& douts,
_megdnn_tensor_out dx, TensorNDArray& dstates, _megdnn_tensor_out dwi,
_megdnn_tensor_out dwh, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) const {
auto dy = douts[0];
using NonlineMode = param::RNNCell::NonlineMode;
using Mode = Elemwise::Mode;
auto elemwise_opr = handle->create_operator<ElemwiseForward>();
TensorND tmp = {workspace.raw_ptr, dy.layout};
auto new_workspace = Workspace(
workspace.raw_ptr + tmp.layout.span().dist_byte(),
workspace.size - tmp.layout.span().dist_byte());
switch (nonlineMode) {
case (NonlineMode::IDENTITY):
memcpy(tmp.raw_ptr(), dy.raw_ptr(), dy.layout.span().dist_byte());
break;
case (NonlineMode::TANH):
elemwise_opr->param().mode = Mode::TANH_GRAD;
elemwise_opr->exec({y, dy}, tmp);
break;
case (NonlineMode::RELU):
elemwise_opr->param().mode = Mode::SWITCH_GT0;
elemwise_opr->exec({y, dy}, tmp);
break;
}
auto matrixmul_opr = handle->create_operator<MatrixMulForward>();
matrixmul_opr->param().transposeA = false;
matrixmul_opr->param().transposeB = false;
// dx
matrixmul_opr->exec(tmp, this->weight_ih, dx, new_workspace);
// dhx
matrixmul_opr->exec(tmp, this->weight_hh, dstates[0], new_workspace);
// dwi
matrixmul_opr->param().transposeA = true;
matrixmul_opr->exec(tmp, x, dwi, new_workspace);
// dwh
matrixmul_opr->exec(tmp, states[0], dwh, new_workspace);
// dbias
auto sum_opr = handle->create_operator<ReduceForward>();
sum_opr->param().mode = ReduceForward::Mode::SUM;
sum_opr->param().axis = 0;
TensorND dbias_expanded = {
dbias.raw_ptr(), {{1, dbias.layout.shape[0]}, dbias.layout.dtype}};
sum_opr->exec(tmp, dbias_expanded, new_workspace);
}

size_t CellWeightsWrapperBase::backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
size_t num_chunks, DType dtype) {
size_t gate_hidden_size = hidden_size * num_chunks;
TensorLayout tmp = {{batch_size, gate_hidden_size}, dtype};
TensorLayout bias_expanded = {{1, gate_hidden_size}, dtype};
TensorLayout wih = {{gate_hidden_size, input_size}, dtype};
TensorLayout whh = {{gate_hidden_size, hidden_size}, dtype};
TensorLayout x = {{batch_size, input_size}, dtype};
TensorLayout hx = {{batch_size, hidden_size}, dtype};
size_t workspace_size = 0;
auto matrixmul_opr = handle->create_operator<MatrixMulForward>();
matrixmul_opr->param().transposeA = false;
matrixmul_opr->param().transposeB = false;
// dx
workspace_size = std::max(
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, wih, x));
// dhx
workspace_size = std::max(
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, whh, hx));
// dwi
matrixmul_opr->param().transposeA = true;
workspace_size = std::max(
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, x, wih));
// dwh
workspace_size = std::max(
workspace_size, matrixmul_opr->get_workspace_in_bytes(tmp, hx, whh));
// dbias
auto sum_opr = handle->create_operator<ReduceForward>();
sum_opr->param().mode = ReduceForward::Mode::SUM;
sum_opr->param().axis = 0;
workspace_size = std::max(
workspace_size, sum_opr->get_workspace_in_bytes(tmp, bias_expanded));
workspace_size += tmp.span().dist_byte();
return workspace_size;
}

RNNCellWeightWrapper::RNNCellWeightWrapper(
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype, _megdnn_workspace workspace)
: CellWeightsWrapperBase(
weight_ptr, hidden_size, input_size, 1, has_bias, dtype, workspace) {}

size_t RNNCellWeightWrapper::backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
DType dtype) {
return CellWeightsWrapperBase::backward_workspace_size_in_bytes(
handle, batch_size, hidden_size, input_size, 1, dtype);
}

LSTMCellWeightWrapper::LSTMCellWeightWrapper(
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype, _megdnn_workspace workspace)
: CellWeightsWrapperBase(
weight_ptr, hidden_size, input_size, 4, has_bias, dtype, workspace) {}

size_t LSTMCellWeightWrapper::num_states() const {
return 2;
}

size_t LSTMCellWeightWrapper::backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
DType dtype) {
// get gates size
size_t gate_hidden_size = 4 * hidden_size;
auto lstm_opr = handle->create_operator<LSTMCellForward>();
TensorLayout x = {{batch_size, input_size}, dtype};
TensorLayout weight_ih = {{gate_hidden_size, input_size}, dtype};
TensorLayout weight_hh = {{gate_hidden_size, hidden_size}, dtype};
TensorLayout bias = {{gate_hidden_size}, dtype};
TensorLayout h = {{batch_size, hidden_size}, dtype};
TensorLayout gates, h_new, c_new;
lstm_opr->deduce_layout(
x, weight_ih, bias, h, weight_hh, bias, h, h_new, c_new, gates);
return CellWeightsWrapperBase::backward_workspace_size_in_bytes(
handle, batch_size, hidden_size, input_size, 4, dtype) +
gates.span().dist_byte() * 2 + c_new.span().dist_byte();
}

void LSTMCellWeightWrapper::backward(
Handle* handle,
param::RNNCell::NonlineMode nonlineMode, // nonlineMode must be identity
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y,
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates,
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) const {
size_t used_workspace_size = 0;
// get gates
auto lstm_opr = handle->create_operator<LSTMCellForward>();
TensorLayout gates, h_new, c_new;
lstm_opr->deduce_layout(
x.layout, weight_ih.layout, bias_ih.layout, states[0].layout,
weight_hh.layout, bias_hh.layout, states[1].layout, h_new, c_new, gates);
TensorND gates_tensor{workspace.raw_ptr, gates};
used_workspace_size += gates.span().dist_byte();
TensorND gates_grad{workspace.raw_ptr + used_workspace_size, gates};
used_workspace_size += gates.span().dist_byte();
TensorND tanh_cy{workspace.raw_ptr + used_workspace_size, y.layout};
used_workspace_size += tanh_cy.layout.span().dist_byte();
Workspace new_workspace = Workspace(
workspace.raw_ptr + used_workspace_size,
workspace.size - used_workspace_size);
// temporarily use dstates to store hy, cy
// only gates and cy are needed; the other outputs are overwritten afterwards
lstm_opr->exec(
x, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states[1], dstates[0],
dstates[1], gates_tensor,
new_workspace); // no information left in the workspace
// i, f, o, g
TensorLayout single_gate = {{gates.shape[0], gates.shape[1] / 4}, gates.dtype};
TensorND i, f, o, g, i_grad, f_grad, o_grad,
g_grad; // grad refers to the grad of gates before activation
i = {gates_tensor.raw_ptr(), single_gate};
f = {static_cast<uint8_t*>(gates_tensor.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
o = {static_cast<uint8_t*>(f.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
g = {static_cast<uint8_t*>(o.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
i_grad = {gates_grad.raw_ptr(), single_gate};
f_grad = {
static_cast<uint8_t*>(i_grad.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
o_grad = {
static_cast<uint8_t*>(f_grad.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
g_grad = {
static_cast<uint8_t*>(o_grad.raw_ptr()) + single_gate.span().dist_byte(),
single_gate};
// activation
auto elem_opr = handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = Elemwise::Mode::SIGMOID;
elem_opr->exec({i}, i);
elem_opr->exec({f}, f);
elem_opr->exec({o}, o);
elem_opr->param().mode = Elemwise::Mode::TANH;
elem_opr->exec({g}, g);
elem_opr->exec({dstates[1]}, tanh_cy);
auto mul_opr = handle->create_operator<ElemwiseForward>();
mul_opr->param().mode = Elemwise::Mode::MUL;
// use dstates[0] as tmp tensor to store dhy * tanh_cy
mul_opr->exec({douts[0], tanh_cy}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD;
elem_opr->exec({o, dstates[0]}, o_grad); // grad of gate o
// use dstates[0] as tmp tensor to store dhy * o
mul_opr->exec({douts[0], o}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD;
elem_opr->exec({tanh_cy, dstates[0]}, dstates[1]); // grad of cy from hy
elem_opr->param().mode = Elemwise::Mode::ADD;
elem_opr->exec({douts[1], dstates[1]}, dstates[1]); // true grad of cy
// use dstates[0] as tmp tensor to store dcy * cx
mul_opr->exec({dstates[1], states[1]}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::SIGMOID_GRAD;
elem_opr->exec({f, dstates[0]}, f_grad); // grad of gate f
// use dstates[0] as tmp tensor to store dcy * g
mul_opr->exec({dstates[1], g}, dstates[0]);
elem_opr->exec({i, dstates[0]}, i_grad); // grad of gate i
// use dstates[0] as tmp tensor to store dcy * i
mul_opr->exec({dstates[1], i}, dstates[0]);
elem_opr->param().mode = Elemwise::Mode::TANH_GRAD;
elem_opr->exec({g, dstates[0]}, g_grad); // grad of gate g

// grad of cx
mul_opr->exec({dstates[1], f}, dstates[1]);
TensorNDArray base_dstates = {dstates[0]};
CellWeightsWrapperBase::backward(
handle, nonlineMode, x, {states[0]}, gates_tensor, {gates_grad}, dx,
base_dstates, dwi, dwh, dbias, new_workspace);
}

} // namespace rnn
} // namespace naive
} // namespace megdnn
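
For reference, the elementwise sequence in CellWeightsWrapperBase::backward and LSTMCellWeightWrapper::backward above implements the usual cell gradients. The following LaTeX sketch is notation for this note only (d_pre is dy after applying the nonlinearity's gradient, i.e. the `tmp` tensor; i, f, o, g are the post-activation gates; c_x, c_y are the previous and new cell states):

\begin{aligned}
dx &= d_{\mathrm{pre}} W_{ih}, \quad dh_x = d_{\mathrm{pre}} W_{hh}, \quad
dW_{ih} = d_{\mathrm{pre}}^{\top} x, \quad dW_{hh} = d_{\mathrm{pre}}^{\top} h_x, \quad
db = \textstyle\sum_{\mathrm{batch}} d_{\mathrm{pre}}
\end{aligned}

\begin{aligned}
do   &= dh_y \odot \tanh(c_y) \odot o \odot (1-o), \\
dc_y &= dc_y^{\mathrm{in}} + dh_y \odot o \odot \bigl(1-\tanh^2(c_y)\bigr), \\
df   &= dc_y \odot c_x \odot f \odot (1-f), \\
di   &= dc_y \odot g \odot i \odot (1-i), \\
dg   &= dc_y \odot i \odot \bigl(1-g^2\bigr), \\
dc_x &= dc_y \odot f.
\end{aligned}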

+ 73
- 0
dnn/src/naive/rnn/rnn.h View File

@@ -0,0 +1,73 @@
/**
* \file dnn/src/naive/rnn/rnn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"

namespace megdnn {
namespace naive {
namespace rnn {

class CellWeightsWrapperBase {
private:
size_t _weight_size, _workspace_size;

public:
TensorND weight_ih, weight_hh, bias_ih, bias_hh;
// if no bias, will create dummy bias tensor from workspace
CellWeightsWrapperBase(
void* weight_ptr, size_t hidden_size, size_t input_size, size_t num_chunks,
bool has_bias, DType dtype, _megdnn_workspace workspace);
size_t weight_size_in_bytes() const;
size_t workspace_size_in_bytes() const;
static size_t backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
size_t num_chunks, DType dtype);
virtual void backward(
Handle* handle, param::RNNCell::NonlineMode nonlineMode,
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y,
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates,
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) const;
virtual size_t num_states() const;
};

class RNNCellWeightWrapper : public CellWeightsWrapperBase {
public:
RNNCellWeightWrapper(
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype, _megdnn_workspace workspace);

static size_t backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
DType dtype);
};

class LSTMCellWeightWrapper : public CellWeightsWrapperBase {
public:
LSTMCellWeightWrapper(
void* weight_ptr, size_t hidden_size, size_t input_size, bool has_bias,
DType dtype, _megdnn_workspace workspace);
static size_t backward_workspace_size_in_bytes(
Handle* handle, size_t batch_size, size_t hidden_size, size_t input_size,
DType dtype);
size_t num_states() const override;
void backward(
Handle* handle, param::RNNCell::NonlineMode nonlineMode,
_megdnn_tensor_in x, const TensorNDArray& states, _megdnn_tensor_in y,
const TensorNDArray& douts, _megdnn_tensor_out dx, TensorNDArray& dstates,
_megdnn_tensor_out dwi, _megdnn_tensor_out dwh, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) const override;
};

} // namespace rnn
} // namespace naive
} // namespace megdnn
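
As a quick reference for how CellWeightsWrapperBase slices per-cell tensors out of one flat weight buffer, here is a minimal NumPy sketch; the helper name split_cell_weights is ours, not part of the commit. It mirrors the offsets used by the constructor: weight_ih, then weight_hh, then bias_ih/bias_hh when bias is enabled, with zero-filled dummies otherwise.

import numpy as np

def split_cell_weights(flat, hidden_size, input_size, num_chunks, has_bias):
    # Layout assumed by CellWeightsWrapperBase:
    # weight_ih [gate_hidden, input], weight_hh [gate_hidden, hidden],
    # then bias_ih and bias_hh [gate_hidden] when has_bias is true.
    gate_hidden = num_chunks * hidden_size
    off = 0
    w_ih = flat[off:off + gate_hidden * input_size].reshape(gate_hidden, input_size)
    off += gate_hidden * input_size
    w_hh = flat[off:off + gate_hidden * hidden_size].reshape(gate_hidden, hidden_size)
    off += gate_hidden * hidden_size
    if has_bias:
        b_ih = flat[off:off + gate_hidden]
        off += gate_hidden
        b_hh = flat[off:off + gate_hidden]
    else:
        # no bias in the buffer: the wrapper substitutes zero-filled dummies
        b_ih = b_hh = np.zeros(gate_hidden, dtype=flat.dtype)
    return w_ih, w_hh, b_ih, b_hh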

+ 41
- 0
dnn/src/naive/rnn/template_impl.cpp View File

@@ -0,0 +1,41 @@
/**
* \file dnn/src/naive/rnn/template_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/rnn/funcs.h"

namespace megdnn {
namespace naive {
namespace rnn {

template <>
void cell_opr_exec<RNNCellForward>(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in bias_hh, const TensorNDArray& states,
TensorNDArray& states_new, _megdnn_workspace workspace, Handle* handle) {
auto opr = handle->create_operator<RNNCellForward>();
opr->exec(
input, weight_ih, bias_ih, states[0], weight_hh, bias_hh, states_new[0],
workspace);
}

template <>
size_t cell_opr_get_workspace_in_bytes<RNNCellForward>(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& weight_hh, const TensorLayout& bias_ih,
const TensorLayout& bias_hh, const TensorLayout& hx, Handle* handle) {
auto cell_opr = handle->create_operator<RNNCellForward>();
return cell_opr->get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, hx);
}

} // namespace rnn
} // namespace naive
} // namespace megdnn

+ 34
- 0
dnn/src/naive/rnn_cell/opr_impl.cpp View File

@@ -0,0 +1,34 @@
/**
* \file dnn/src/naive/rnn_cell/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/rnn_cell/opr_impl.h"
#include "src/common/rnn_cell.h"

namespace megdnn {
namespace naive {
size_t RNNCellImpl::get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) {
return megdnn::rnn_cell::get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, handle());
}

void RNNCellImpl::exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
megdnn::rnn_cell::exec(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst, workspace,
param().nonlineMode, handle());
}
} // namespace naive
} // namespace megdnn

+ 33
- 0
dnn/src/naive/rnn_cell/opr_impl.h View File

@@ -0,0 +1,33 @@
/**
* \file dnn/src/naive/rnn_cell/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class RNNCellImpl : public RNNCell {
public:
using RNNCell::RNNCell;
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst) override;
};

} // namespace naive
} // namespace megdnn

+ 9
- 0
dnn/test/common/deduce_layout_proxy.h View File

@@ -68,6 +68,15 @@ struct DeduceLayoutProxy<Opr, 6, false> {
};

template <typename Opr>
struct DeduceLayoutProxy<Opr, 6, true> {
static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 6);
opr->deduce_layout(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
}
};

template <typename Opr>
struct DeduceLayoutProxy<Opr, 7, false> {
static void deduce_layout(Opr*, TensorLayoutArray&) {}
};


+ 2
- 3
dnn/test/common/elemwise.cpp View File

@@ -248,7 +248,6 @@ DEF_TEST(fuse_mul_add4) {
}

DEF_TEST(rmulh) {

using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle);

@@ -315,7 +314,7 @@ DEF_TEST(unary1) {
// case float
UniformFloatRNG rng(1e-2, 6e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
BUILD_UNARY_TEST_CASE_FLOAT
}
@@ -900,7 +899,7 @@ DEF_TEST(unary_negative_stride) {

UniformFloatRNG rng(1e-2, 6e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_epsilon(1e-5);
BUILD_UNARY_NEGATIVE_STRIDE_TEST_CASE_FLOAT;
}



+ 51
- 0
dnn/test/common/rnn.h View File

@@ -0,0 +1,51 @@
/**
* \file dnn/test/common/rnn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>

#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"

namespace megdnn {
namespace test {
namespace rnn {
struct TestArg {
param::RNN param;
TensorShape input, hx, flatten_weights;
TestArg(param::RNN param, TensorShape input, TensorShape hx,
TensorShape flatten_weights)
: param(param), input(input), hx(hx), flatten_weights(flatten_weights) {}
};

inline std::vector<TestArg> get_args() {
std::vector<TestArg> args;
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
size_t seq_len = 2;
size_t gate_hidden_size = hidden_size;
param::RNN param;
param.num_layers = 1;
param.bidirectional = false;
param.bias = false;
param.hidden_size = hidden_size;
param.nonlineMode = param::RNN::NonlineMode::RELU;

args.emplace_back(
param, TensorShape{seq_len, batch_size, input_size},
TensorShape{batch_size, hidden_size},
TensorShape{gate_hidden_size, input_size + hidden_size});
return args;
}

} // namespace rnn
} // namespace test
} // namespace megdnn

+ 80
- 0
dnn/test/naive/rnn.cpp View File

@@ -0,0 +1,80 @@
/**
* \file dnn/test/naive/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/common/rnn.h"
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

/*TEST_F(NAIVE, RNN) {
std::vector<rnn::TestArg> args = rnn::get_args();
Checker<RNN> checker(handle());
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_dtype(5, dtype::Float32())
.execs({arg.input, arg.hx, arg.flatten_weights, {}, {}, {}});
}
}*/

TEST_F(NAIVE, RNN_HAND_MADE) {
Checker<RNN> checker(handle(), false);
size_t batch_size = 2;
size_t input_size = 3;
size_t hidden_size = 2;
size_t seq_len = 2;
size_t gate_hidden_size = hidden_size;
RNN::Param param;
param.num_layers = 1;
param.bidirectional = false;
param.bias = false;
param.hidden_size = hidden_size;
param.nonlineMode = param::RNN::NonlineMode::RELU;
checker.set_param(param).exect(
Testcase{
TensorValue(
{seq_len, batch_size, input_size}, dtype::Float32(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), // input
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{2, 1, 3, 5}), // hx
TensorValue(
{gate_hidden_size, input_size + hidden_size},
dtype::Float32(),
{3, 6, 1, 3, 2, 7, 9, 3, 5, 1}), // weights
{},
{},
{}},
Testcase{
{},
{},
{},
TensorValue(
{seq_len, batch_size, hidden_size}, dtype::Float32(),
{39, 39, 90, 84, 300, 216, 546, 366}), // output
TensorValue(
{batch_size, hidden_size}, dtype::Float32(),
{21, 11, 42, 20}), // hy
TensorValue(
{1, 2, 2, 2}, dtype::Float32(),
{2, 1, 3, 5, 21, 11, 42, 20}) // reserve space
});
}

} // namespace test
} // namespace megdnn
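
The hand-made expectations above can be checked with a short NumPy sketch. It assumes the flattened weights pack weight_ih row-major followed by weight_hh (as in the naive cell wrapper) and reproduces the first time step of the expected output:

import numpy as np

flat = np.array([3, 6, 1, 3, 2, 7, 9, 3, 5, 1], dtype=np.float32)
w_ih = flat[:6].reshape(2, 3)  # [gate_hidden_size, input_size]
w_hh = flat[6:].reshape(2, 2)  # [gate_hidden_size, hidden_size]
x0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)  # input at the first step
hx = np.array([[2, 1], [3, 5]], dtype=np.float32)        # initial hidden state
h1 = np.maximum(x0 @ w_ih.T + hx @ w_hh.T, 0)            # RELU nonlinearity
print(h1)  # [[39. 39.] [90. 84.]] -- the first time step of the expected output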

+ 1
- 0
imperative/python/megengine/module/__init__.py View File

@@ -36,5 +36,6 @@ from .padding import Pad
from .pixel_shuffle import PixelShuffle
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .rnn import LSTM, RNN, LSTMCell, RNNCell
from .sequential import Sequential
from .sliding_window import SlidingWindow, SlidingWindowTranspose

+ 396
- 0
imperative/python/megengine/module/rnn.py View File

@@ -0,0 +1,396 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import numbers
from abc import abstractmethod
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..device import is_cuda_available
from ..functional import concat, expand_dims, repeat, stack, zeros
from ..functional.nn import concat
from ..tensor import Parameter, Tensor
from . import init
from .module import Module


class RNNCellBase(Module):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool,
num_chunks: int,
device=None,
dtype=None,
) -> None:
# num_chunks indicates the number of gates
super(RNNCellBase, self).__init__()

self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias

# initialize weights
common_kwargs = {"device": device, "dtype": dtype}
self.gate_hidden_size = num_chunks * hidden_size
self.weight_ih = Parameter(
np.random.uniform(size=(self.gate_hidden_size, input_size)).astype(
np.float32
),
**common_kwargs,
)
self.weight_hh = Parameter(
np.random.uniform(size=(self.gate_hidden_size, hidden_size)).astype(
np.float32
),
**common_kwargs,
)
if bias:
self.bias_ih = Parameter(
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
**common_kwargs,
)
self.bias_hh = Parameter(
np.random.uniform(size=(self.gate_hidden_size)).astype(np.float32),
**common_kwargs,
)
else:
self.bias_ih = zeros(shape=(self.gate_hidden_size), **common_kwargs)
self.bias_hh = zeros(shape=(self.gate_hidden_size), **common_kwargs)
self.reset_parameters()
# if bias is False self.bias will remain zero

def get_op(self):
return builtin.RNNCell()

def reset_parameters(self) -> None:
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
init.uniform_(weight, -stdv, stdv)

def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
if hx is None:
hx = zeros(
shape=(input.shape[0], self.gate_hidden_size),
dtype=input.dtype,
device=input.device,
)
op = self.get_op()
return apply(
op, input, self.weight_ih, self.bias_ih, hx, self.weight_hh, self.bias_hh
)[0]
# return linear(input, self.weight_ih, self.bias_ih) + linear(hx, self.weight_hh, self.bias_hh)


class RNNCell(RNNCellBase):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
nonlinearity: str = "tanh",
device=None,
dtype=None,
) -> None:
self.nonlinearity = nonlinearity
super(RNNCell, self).__init__(
input_size, hidden_size, bias, num_chunks=1, device=device, dtype=dtype
)
# self.activate = tanh if nonlinearity == "tanh" else relu

def get_op(self):
return builtin.RNNCell(nonlineMode=self.nonlinearity)

def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
return super().forward(input, hx)


class LSTMCell(RNNCellBase):
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super(LSTMCell, self).__init__(
input_size, hidden_size, bias, num_chunks=4, device=device, dtype=dtype
)

def get_op(self):
return builtin.LSTMCell()

def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
# hx: (h, c)
if hx is None:
h = zeros(
shape=(input.shape[0], self.hidden_size),
dtype=input.dtype,
device=input.device,
)
c = zeros(
shape=(input.shape[0], self.hidden_size),
dtype=input.dtype,
device=input.device,
)
else:
h, c = hx
op = self.get_op()
return apply(
op, input, self.weight_ih, self.bias_ih, h, self.weight_hh, self.bias_hh, c
)[:2]


def is_gpu(device: str) -> bool:
if "xpux" in device and is_cuda_available():
return True
if "gpu" in device:
return True
return False


class RNNBase(Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
bias: bool = True,
batch_first: bool = False,
dropout: float = 0,
bidirectional: bool = False,
proj_size: int = 0,
device=None,
dtype=None,
) -> None:
super(RNNBase, self).__init__()
# self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = float(dropout)
self.bidirectional = bidirectional
self.num_directions = 2 if self.bidirectional else 1
self.proj_size = proj_size

# check validity of dropout
if (
not isinstance(dropout, numbers.Number)
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
raise ValueError(
"Dropout should be a float in [0, 1], which indicates the probability "
"of an element to be zero"
)

if proj_size < 0:
raise ValueError(
"proj_size should be a positive integer or zero to disable projections"
)
elif proj_size >= hidden_size:
raise ValueError("proj_size has to be smaller than hidden_size")

self.cells = []
for layer in range(self.num_layers):
self.cells.append([])
for _ in range(self.num_directions):
self.cells[layer].append(self.create_cell(layer, device, dtype))
# parameters have been initialized during the creation of the cells
# if flatten, then delete cells
self._flatten_parameters(device, dtype, self.cells)

def _flatten_parameters(self, device, dtype, cells):
gate_hidden_size = cells[0][0].gate_hidden_size
size_dim1 = 0
for layer in range(self.num_layers):
for direction in range(self.num_directions):
size_dim1 += cells[layer][direction].weight_ih.shape[1]
size_dim1 += cells[layer][direction].weight_hh.shape[1]
# if self.bias:
# size_dim1 += 2 * self.num_directions * self.num_layers
size_dim1 += 2 * self.num_directions * self.num_layers
self._flatten_weights = Parameter(
np.zeros((gate_hidden_size, size_dim1), dtype=np.float32)
)
self.reset_parameters()
# TODO: if no bias, set the bias to zero

def reset_parameters(self) -> None:
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
init.uniform_(weight, -stdv, stdv)

@abstractmethod
def create_cell(self, layer, device, dtype):
raise NotImplementedError("Cell not implemented !")

@abstractmethod
def init_hidden(self):
raise NotImplementedError("init_hidden not implemented !")

@abstractmethod
def get_output_from_hidden(self, hx):
raise NotImplementedError("get_output_from_hidden not implemented !")

@abstractmethod
def apply_op(self, input, hx):
raise NotImplementedError("apply_op not implemented !")

def _apply_fn_to_hx(self, hx, fn):
return fn(hx)

def _stack_h_n(self, h_n):
return stack(h_n, axis=0)

def forward(self, input: Tensor, hx=None):
if self.batch_first:
batch_size = input.shape[0]
input = input.transpose((1, 0, 2)) # [seq_len, batch_size, dim]
else:
batch_size = input.shape[1]
if hx is None:
hx = self.init_hidden(batch_size, input.device, input.dtype)

output, h = self.apply_op(input, hx)
if self.batch_first:
output = output.transpose((1, 0, 2))
return output, h

if is_gpu(str(input.device)) or True:
# return output, h_n
output, h = self.apply_op(input, hx)
if self.batch_first:
output = output.transpose((1, 0, 2))
return output, h

order_settings = [(0, input.shape[0]), (input.shape[0] - 1, -1, -1)]
h_n = []
for layer in range(self.num_layers):
layer_outputs = []
for direction in range(self.num_directions):
direction_outputs = [None for _ in range(input.shape[0])]
cell = self.cells[layer][direction]
hidden = self._apply_fn_to_hx(
hx, lambda x: x[layer * self.num_directions + direction]
)
for step in range(*(order_settings[direction])):
hidden = cell(input[step], hidden) # [batch_size, hidden_size]
direction_outputs[step] = self.get_output_from_hidden(hidden)
direction_output = stack(
direction_outputs, axis=0
) # [seq_len, batch_size, hidden_size]
layer_outputs.append(direction_output)
h_n.append(hidden)
layer_output = concat(
layer_outputs, axis=-1
) # [seq_len, batch_size, D*hidden_size]
input = layer_output
if self.batch_first:
layer_output = layer_output.transpose((1, 0, 2))
return layer_output, self._stack_h_n(h_n)


class RNN(RNNBase):
def __init__(self, *args, **kwargs) -> None:
self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
super(RNN, self).__init__(*args, **kwargs)

def create_cell(self, layer, device, dtype):
if layer == 0:
input_size = self.input_size
else:
input_size = self.num_directions * self.hidden_size
return RNNCell(
input_size, self.hidden_size, self.bias, self.nonlinearity, device, dtype
)

def init_hidden(self, batch_size, device, dtype):
hidden_shape = (
self.num_directions * self.num_layers,
batch_size,
self.hidden_size,
)
return zeros(shape=hidden_shape, dtype=dtype, device=device)

def get_output_from_hidden(self, hx):
return hx

def apply_op(self, input, hx):
op = builtin.RNN(
num_layers=self.num_layers,
bidirectional=self.bidirectional,
bias=self.bias,
hidden_size=self.hidden_size,
proj_size=self.proj_size,
dropout=self.dropout,
nonlineMode=self.nonlinearity,
)
output, h = apply(op, input, hx, self._flatten_weights)[:2]
output = output + h.sum() * 0
h = h + output.sum() * 0
return output, h


class LSTM(RNNBase):
def __init__(self, *args, **kwargs) -> None:
super(LSTM, self).__init__(*args, **kwargs)

def create_cell(self, layer, device, dtype):
if layer == 0:
input_size = self.input_size
else:
input_size = self.num_directions * self.hidden_size
return LSTMCell(input_size, self.hidden_size, self.bias, device, dtype)

def init_hidden(self, batch_size, device, dtype):
hidden_shape = (
self.num_directions * self.num_layers,
batch_size,
self.hidden_size,
)
h = zeros(shape=hidden_shape, dtype=dtype, device=device)
c = zeros(shape=hidden_shape, dtype=dtype, device=device)
return (h, c)

def get_output_from_hidden(self, hx):
return hx[0]

def apply_op(self, input, hx):
op = builtin.LSTM(
num_layers=self.num_layers,
bidirectional=self.bidirectional,
bias=self.bias,
hidden_size=self.hidden_size,
proj_size=self.proj_size,
dropout=self.dropout,
)
output, h, c = apply(op, input, hx[0], hx[1], self._flatten_weights)[:3]
placeholders = [output.sum() * 0, h.sum() * 0, c.sum() * 0]
output = output + placeholders[1] + placeholders[2]
h = h + placeholders[0] + placeholders[2]
c = c + placeholders[0] + placeholders[1]
return output, (h, c)

def _apply_fn_to_hx(self, hx, fn):
return (fn(hx[0]), fn(hx[1]))

def _stack_h_n(self, h_n):
h = [tup[0] for tup in h_n]
c = [tup[1] for tup in h_n]
return (stack(h, axis=0), stack(c, axis=0))
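
A minimal usage sketch for the new modules (shapes mirror the unit tests in the next file; the snippet itself is illustrative and not part of the commit):

import megengine as mge
from megengine.module import LSTM, RNN

rnn = RNN(input_size=10, hidden_size=20, num_layers=2, batch_first=True)
x = mge.random.normal(size=(3, 6, 10))  # [batch, seq_len, input_size]
output, h_n = rnn(x)                    # output: (3, 6, 20), h_n: (2, 3, 20)

lstm = LSTM(input_size=10, hidden_size=20, batch_first=True)
output, (h_n, c_n) = lstm(x)            # output: (3, 6, 20); h_n, c_n: (1, 3, 20)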

+ 181
- 0
imperative/python/test/unit/module/test_rnn.py View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine.module import LSTM, RNN, LSTMCell, RNNCell


def assert_tuple_equal(src, ref):
assert len(src) == len(ref)
for i, j in zip(src, ref):
assert i == j


@pytest.mark.parametrize(
"batch_size, input_size, hidden_size, init_hidden",
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False), (0, 10, 20, True)],
)
def test_rnn_cell(batch_size, input_size, hidden_size, init_hidden):
rnn_cell = RNNCell(input_size, hidden_size)
x = mge.random.normal(size=(batch_size, input_size))
if init_hidden:
h = F.zeros(shape=(batch_size, hidden_size))
else:
h = None
h_new = rnn_cell(x, h)
assert_tuple_equal(h_new.shape, (batch_size, hidden_size))


# is batch_size == 0 tolerated? it would cause an error in slice operations like xx[:, ...]
@pytest.mark.parametrize(
"batch_size, input_size, hidden_size, init_hidden",
[(3, 10, 20, True), (3, 10, 20, False), (1, 10, 20, False)],
)
def test_lstm_cell(batch_size, input_size, hidden_size, init_hidden):
rnn_cell = LSTMCell(input_size, hidden_size)
x = mge.random.normal(size=(batch_size, input_size))
if init_hidden:
h = F.zeros(shape=(batch_size, hidden_size))
hx = (h, h)
else:
hx = None
h_new, c_new = rnn_cell(x, hx)
assert_tuple_equal(h_new.shape, (batch_size, hidden_size))
assert_tuple_equal(c_new.shape, (batch_size, hidden_size))


@pytest.mark.parametrize(
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
[
(3, 6, 10, 20, 2, False, False, True),
pytest.param(
3,
3,
10,
10,
1,
True,
True,
False,
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
),
],
)
# (0, 1, 1, 1, 1, False, True, False)])
def test_rnn(
batch_size,
seq_len,
input_size,
hidden_size,
num_layers,
bidirectional,
init_hidden,
batch_first,
):
rnn = RNN(
input_size,
hidden_size,
batch_first=batch_first,
num_layers=num_layers,
bidirectional=bidirectional,
)
if batch_first:
x_shape = (batch_size, seq_len, input_size)
else:
x_shape = (seq_len, batch_size, input_size)
x = mge.random.normal(size=x_shape)
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
if init_hidden:
h = mge.random.normal(size=(batch_size, total_hidden_size))
else:
h = None
output, h_n = rnn(x, h)
num_directions = 2 if bidirectional else 1
if batch_first:
assert_tuple_equal(
output.shape, (batch_size, seq_len, num_directions * hidden_size)
)
else:
assert_tuple_equal(
output.shape, (seq_len, batch_size, num_directions * hidden_size)
)
assert_tuple_equal(
h_n.shape, (num_directions * num_layers, batch_size, hidden_size)
)


@pytest.mark.parametrize(
"batch_size, seq_len, input_size, hidden_size, num_layers, bidirectional, init_hidden, batch_first",
[
(3, 10, 20, 20, 1, False, False, True),
pytest.param(
3,
3,
10,
10,
1,
True,
True,
False,
marks=pytest.mark.skip(reason="bidirectional will cause cuda oom"),
),
],
)
# (0, 1, 1, 1, 1, False, True, False)])
def test_lstm(
batch_size,
seq_len,
input_size,
hidden_size,
num_layers,
bidirectional,
init_hidden,
batch_first,
):
rnn = LSTM(
input_size,
hidden_size,
batch_first=batch_first,
num_layers=num_layers,
bidirectional=bidirectional,
)
if batch_first:
x_shape = (batch_size, seq_len, input_size)
else:
x_shape = (seq_len, batch_size, input_size)
x = mge.random.normal(size=x_shape)
total_hidden_size = num_layers * (2 if bidirectional else 1) * hidden_size
if init_hidden:
h = mge.random.normal(size=(batch_size, total_hidden_size))
h = (h, h)
else:
h = None
output, h_n = rnn(x, h)
num_directions = 2 if bidirectional else 1
if batch_first:
assert_tuple_equal(
output.shape, (batch_size, seq_len, num_directions * hidden_size)
)
else:
assert_tuple_equal(
output.shape, (seq_len, batch_size, num_directions * hidden_size)
)
assert_tuple_equal(
h_n[0].shape, (num_directions * num_layers, batch_size, hidden_size)
)
assert_tuple_equal(
h_n[1].shape, (num_directions * num_layers, batch_size, hidden_size)
)


if __name__ == "__main__":
test_lstm(5, 10, 10, 20, 1, False, False, True)

+ 68
- 0
imperative/src/impl/ops/rnn.cpp View File

@@ -0,0 +1,68 @@
/**
* \file imperative/src/impl/ops/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "megbrain/opr/dnn/rnn.h"
#include "megbrain/imperative/ops/autogen.h"

#include "../op_trait.h"

namespace mgb::imperative {

namespace rnn_cell {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const RNNCell&>(def);
mgb_assert(inputs.size() == 6);
return opr::RNNCell::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
op.param());
}
OP_TRAIT_REG(RNNCell, RNNCell).apply_on_var_node(apply_on_var_node).fallback();
} // namespace rnn_cell

namespace lstm_cell {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LSTMCell&>(def);
mgb_assert(inputs.size() == 7);
auto* opr = opr::LSTMCell::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4],
inputs[5], inputs[6], op.param())
.node()
->owner_opr();
return {opr->output(0), opr->output(1), opr->output(2)};
}
OP_TRAIT_REG(LSTMCell, LSTMCell).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lstm_cell

namespace rnn {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const RNN&>(def);
mgb_assert(inputs.size() == 3);
auto* opr = opr::RNN::make(inputs[0], inputs[1], inputs[2], op.param())
.node()
->owner_opr();
return {opr->output(0), opr->output(1), opr->output(2)};
}
OP_TRAIT_REG(RNN, RNN).apply_on_var_node(apply_on_var_node).fallback();
} // namespace rnn

namespace lstm {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const LSTM&>(def);
mgb_assert(inputs.size() == 4);
auto* opr = opr::LSTM::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param())
.node()
->owner_opr();
return {opr->output(0), opr->output(1), opr->output(2), opr->output(3)};
}
OP_TRAIT_REG(LSTM, LSTM).apply_on_var_node(apply_on_var_node).fallback();
} // namespace lstm

} // namespace mgb::imperative

+ 8
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -431,4 +431,12 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;

def LRN: MgbHashableOp<"LRN", [LRNParam]>;

def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>;

def LSTMCell: MgbHashableOp<"LSTMCell", [EmptyParam]>;

def RNN: MgbHashableOp<"RNN", [RNNParam]>;

def LSTM: MgbHashableOp<"LSTM", [LSTMParam]>;

#endif // MGB_OPS

+ 35
- 0
src/opr/impl/dnn/dnn.sereg.h View File

@@ -20,6 +20,7 @@
#include "megbrain/opr/dnn/lrn.h"
#include "megbrain/opr/dnn/lsq.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/rnn.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
@@ -292,6 +293,36 @@ struct OprMaker<opr::LSQBackward, 5> {
->owner_opr();
}
};

template <>
struct OprMaker<opr::RNNBackward, 7> {
using Param = opr::RNNBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::RNNBackward::make(
i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config)[0]
.node()
->owner_opr();
}
};

template <>
struct OprMaker<opr::LSTMBackward, 9> {
using Param = opr::LSTMBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
return opr::LSTMBackward::make(
i[0], i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], param,
config)[0]
.node()
->owner_opr();
}
};

template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
: public GeneralOprLoadDumpImpl<
@@ -641,6 +672,10 @@ MGB_SEREG_OPR(TQT, 2);
MGB_SEREG_OPR(TQTBackward, 3);
MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
MGB_SEREG_OPR(RNNForward, 3);
MGB_SEREG_OPR(RNNBackward, 7);
MGB_SEREG_OPR(LSTMForward, 4);
MGB_SEREG_OPR(LSTMBackward, 9);
} // namespace opr

} // namespace mgb


+ 323
- 0
src/opr/impl/dnn/rnn.cpp View File

@@ -0,0 +1,323 @@
/**
* \file src/opr/impl/dnn/rnn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/dnn/rnn.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

using namespace mgb;
using namespace opr;

/* ================= RNNCell ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNCellForward);
RNNCellForward::RNNCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, const Param& param,
const OperatorNodeConfig& config)
: Super{input->owner_graph(),
config,
"rnn_cell",
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh}} {
init_megdnn_opr(*this, param);
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh});
}

SymbolVar RNNCellForward::make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param,
const OperatorNodeConfig& config) {
return input.insert_single_output_opr<RNNCellForward>(
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(),
bias_hh.node(), param, config);
}

#if MGB_ENABLE_GRAD

VarNode* rnnCellBackward(
const SymbolVar& input, const SymbolVar& weight_ih, const SymbolVar& hx,
const SymbolVar& weight_hh, const SymbolVar& out,
RNNCell::NonlineMode nonlineMode, size_t wrt_idx, const SymbolVar& og) {
SymbolVar tmp;
// activation
using NonlineMode = RNNCell::NonlineMode;
using Mode = Elemwise::Mode;
switch (nonlineMode) {
case NonlineMode::IDENTITY:
tmp = og;
break;
case NonlineMode::TANH:
tmp = Elemwise::make({out, og}, Mode::TANH_GRAD);
break;
case NonlineMode::RELU:
tmp = Elemwise::make({out, og}, Mode::SWITCH_GT0);
break;
default:
mgb_throw(GraphError, "Activation method not supported");
}
// now grad is in tmp
if (wrt_idx == 2 || wrt_idx == 5)
return tmp.node(); // bias

SymbolVar result;
// A * Bt = C, A' = C' * B, B' = C't * A
if (wrt_idx == 0) { // input
result = MatrixMul::make(
tmp, weight_ih,
{false, false}); // transpose a false, transpose b false
} else if (wrt_idx == 1) { // weight_ih
result = MatrixMul::make(tmp, input, {true, false});
} else if (wrt_idx == 3) { // hx
result = MatrixMul::make(tmp, weight_hh, {false, false});
} else if (wrt_idx == 4) { // weight_hh
result = MatrixMul::make(tmp, hx, {true, false});
}
return result.node();
}

MGB_IMPL_OPR_GRAD(RNNCell) {
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)),
weight_hh(opr.input(4));
SymbolVar out(opr.output(0)), og{out_grad.at(0)};
return rnnCellBackward(
input, weight_ih, hx, weight_hh, out, opr.param().nonlineMode, wrt_idx, og);
}
#endif

/* ================= LSTMCell ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMCell);
LSTMCellForward::LSTMCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
const OperatorNodeConfig& config)
: Super{input->owner_graph(),
config,
"lstm_cell",
{input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx}} {
init_megdnn_opr(*this, param);
add_input({input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx});
}

SymbolVar LSTMCellForward::make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx, const Param& param,
const OperatorNodeConfig& config) {
return input.insert_single_output_opr<LSTMCellForward>(
input.node(), weight_ih.node(), bias_ih.node(), hx.node(), weight_hh.node(),
bias_hh.node(), cx.node(), param, config);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LSTMCell) {
SymbolVar input(opr.input(0)), weight_ih(opr.input(1)), hx(opr.input(3)),
weight_hh(opr.input(4)), cx(opr.input(6));
SymbolVar h_out(opr.output(0)), c_out(opr.output(1)), gates(opr.output(2)),
h_og{out_grad.at(0)}, c_og{out_grad.at(1)}, tmp;
size_t ghs = gates.shape()[1] / 4; // gate_hidden_size
SymbolVarArray gates_array = Split::make(
gates, Split::Options::make_partition(gates, 1, {ghs, ghs, ghs, ghs}));
mgb_assert(gates_array.size() == 4);
using Mode = Elemwise::Mode;
const SymbolVar &i(Elemwise::make({gates_array.at(0)}, Mode::SIGMOID)),
f(Elemwise::make({gates_array.at(1)}, Mode::SIGMOID)),
o(Elemwise::make({gates_array.at(2)}, Mode::SIGMOID)),
g(Elemwise::make({gates_array.at(3)}, Mode::TANH));
SymbolVar i_grad, f_grad, o_grad, g_grad;

SymbolVar tanh_c_out = Elemwise::make({c_out}, Mode::TANH);
o_grad = Elemwise::make({o, h_og * tanh_c_out}, Mode::SIGMOID_GRAD);
c_og = c_og + Elemwise::make({tanh_c_out, h_og * o}, Mode::TANH_GRAD);
f_grad = Elemwise::make({f, c_og * cx}, Mode::SIGMOID_GRAD);
i_grad = Elemwise::make({i, c_og * g}, Mode::SIGMOID_GRAD);
g_grad = Elemwise::make({g, c_og * i}, Mode::TANH_GRAD);
SymbolVar rnn_cell_grad = Concat::make({i_grad, f_grad, o_grad, g_grad}, {-1});

SymbolVar result;
if (wrt_idx < 6) {
using NonlineMode = RNNCell::NonlineMode;
result = rnnCellBackward(
input, weight_ih, hx, weight_hh, gates, NonlineMode::IDENTITY, wrt_idx,
rnn_cell_grad);
} else { // cx
result = c_og * f;
}
return result.node();
}
#endif

/* ================= RNN ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNN);
MEGDNN_OPR_INIT3(RNNForward, "rnn_fwd");

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(RNN) {
mgb_assert(
opr.param().fwd_mode == RNN::Param::FwdMode::TRAINING,
"RNN could only take grad in training mode");
SymbolVarArray grads = RNNBackward::make(
opr.input(0), opr.output(0), opr.input(1), out_grad.at(0), out_grad.at(1),
opr.input(2), opr.output(2), opr.param());
// return grads.at(wrt_idx).node(); // input, hx, weights
VarNodeArray ret(opr.input().size(), nullptr);
for (size_t i = 0; i < ret.size(); ++i) {
ret[i] = grads[i].node();
}
return ret;
}
#endif

/* ================= RNNBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNNBackward);

RNNBackward::RNNBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
const OperatorNodeConfig& config)
: Super({x->owner_graph(),
config,
"rnn_bwd",
{x, y, hx, dy, dhy, flatten_weights, reserve_space}},
0, true) {
init_megdnn_opr(*this, param);
add_input({x, y, hx, dy, dhy, flatten_weights, reserve_space});
}

SymbolVarArray RNNBackward::make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param,
const OperatorNodeConfig& config) {
auto&& out = x.node()->owner_graph()
->insert_opr(std::make_unique<RNNBackward>(
x.node(), y.node(), hx.node(), dy.node(), dhy.node(),
flatten_weights.node(), reserve_space.node(), param,
config))
->output();
SymbolVarArray ret(out.size());
for (size_t i = 0; i < ret.size(); ++i) {
ret[i] = out[i];
}
return ret;
}

RNNBackward::Super::NodeProp* RNNBackward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(6), NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}

void RNNBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();

mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(5)));
this->init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::RNNBackward>::val);
}

void RNNBackward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(5)->dtype());
}

/* ================= LSTM ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTM);
LSTMForward::LSTMForward(
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
const Param& param, const OperatorNodeConfig& config)
: Super{input->owner_graph(),
config,
"lstm",
{input, hx, cx, flatten_weights}} {
init_megdnn_opr(*this, param);
add_input({input, hx, cx, flatten_weights});
}

SymbolVar LSTMForward::make(
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
const Param& param, const OperatorNodeConfig& config) {
return input.insert_single_output_opr<LSTMForward>(
input.node(), hx.node(), cx.node(), flatten_weights.node(), param, config);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LSTM) {
SymbolVarArray grads = LSTMBackward::make(
opr.input(0), opr.output(0), opr.input(1), opr.input(2), out_grad.at(0),
out_grad.at(1), out_grad.at(2), opr.input(3), opr.output(3), opr.param());
SymbolVar res;
return grads.at(wrt_idx).node(); // input, hx, cx, weights
}
#endif

/* ================= LSTMBackward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSTMBackward);
LSTMBackward::LSTMBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
const Param& param, const OperatorNodeConfig& config)
: Super({x->owner_graph(),
config,
"lstm_bwd",
{x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space}},
1, true) {
init_megdnn_opr(*this, param);
add_input({x, y, hx, cx, dy, dhy, dcy, flatten_weights, reserve_space});
}

SymbolVarArray LSTMBackward::make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
SymbolVar reserve_space, const Param& param, const OperatorNodeConfig& config) {
auto&& out = x.node()->owner_graph()
->insert_opr(std::make_unique<LSTMBackward>(
x.node(), y.node(), hx.node(), cx.node(), dy.node(),
dhy.node(), dcy.node(), flatten_weights.node(),
reserve_space.node(), param, config))
->output();
SymbolVarArray ret(out.size());
for (size_t i = 0; i < ret.size(); ++i) {
ret[i] = out[i];
}
return ret;
}

LSTMBackward::Super::NodeProp* LSTMBackward::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(
input(8), // reserve space
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}

void LSTMBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();

mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(3)));
mgr.register_shape_infer(output(3), ShapeInferDesc::make_identity(input(7)));
this->init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSTMBackward>::val);
}

void LSTMBackward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(3)->dtype());
output(3)->dtype(input(7)->dtype());
}

+ 33
- 0
src/opr/impl/internal/megdnn_opr_wrapper.inl View File

@@ -170,6 +170,16 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 4
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0), _o(1), _o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 5
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 5
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
@@ -180,11 +190,34 @@ using MegDNNOprMethInvoker = _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPU
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 6
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 6
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 6
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 7
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"

#define _NR_INPUTS 9
#define _NR_OUTPUTS 4
#define _FOREACH_IO(_i, _o) \
_i(0), _i(1), _i(2), _i(3), _i(4), _i(5), _i(6), _i(7), _i(8), _o(0), _o(1), \
_o(2), _o(3)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
} // anonymous namespace

/* ======================= MegDNNOprWrapperFwd ======================= */


+ 120
- 0
src/opr/include/megbrain/opr/dnn/rnn.h View File

@@ -0,0 +1,120 @@
/**
* \file src/opr/include/megbrain/opr/dnn/rnn.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#if MGB_CUDA
#include "../../../../impl/nvof/denseflownvidia.h"
#include "megbrain/opr/param_defs.h"
#endif
#include "megdnn/oprs.h"

namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(
RNNCellForward, intl::MegDNNOprWrapperFwd<megdnn::RNNCellForward>) // {
public:
using NonlineMode = Param::NonlineMode;

RNNCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, const Param& param = {},
const OperatorNodeConfig& config = {});
};
using RNNCell = RNNCellForward;

MGB_DEFINE_OPR_CLASS(
LSTMCellForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMCellForward>) // {
public:
LSTMCellForward(
VarNode* input, VarNode* weight_ih, VarNode* bias_ih, VarNode* hx,
VarNode* weight_hh, VarNode* bias_hh, VarNode* cx, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
SymbolVar input, SymbolVar weight_ih, SymbolVar bias_ih, SymbolVar hx,
SymbolVar weight_hh, SymbolVar bias_hh, SymbolVar cx,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
using LSTMCell = LSTMCellForward;

MGB_DEFINE_OPR_CLASS(RNNForward, intl::MegDNNOprWrapperFwd<megdnn::RNNForward>) // {
/*private:
SymbolVarArray weight_ih_arr; // 1d, idx: direction * num_layers + layer
SymbolVarArray weight_hh_arr;
SymbolVarArray bias_arr;
*/

public:
RNNForward(
VarNode* input, VarNode* hx, VarNode* flatten_weights, const Param& param,
const OperatorNodeConfig& config);
static SymbolVar make(
SymbolVar input, SymbolVar hx, SymbolVar flatten_weights,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
using RNN = RNNForward;

MGB_DEFINE_OPR_CLASS(
RNNBackward, intl::MegDNNOprWrapperBwd<megdnn::RNNBackward>) // {
public:
RNNBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* dy, VarNode* dhy,
VarNode* flatten_weights, VarNode* reserve_space, const Param& param,
const OperatorNodeConfig& config);
static SymbolVarArray make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar dy, SymbolVar dhy,
SymbolVar flatten_weights, SymbolVar reserve_space, const Param& param = {},
const OperatorNodeConfig& config = {});
Super::NodeProp* do_make_node_prop() const override;

private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
};

MGB_DEFINE_OPR_CLASS(
LSTMForward, intl::MegDNNOprWrapperFwd<megdnn::LSTMForward>) // {
public:
LSTMForward(
VarNode* input, VarNode* hx, VarNode* cx, VarNode* flatten_weights,
const Param& param, const OperatorNodeConfig& config);
static SymbolVar make(
SymbolVar input, SymbolVar hx, SymbolVar cx, SymbolVar flatten_weights,
const Param& param = {}, const OperatorNodeConfig& config = {});
};
using LSTM = LSTMForward;

MGB_DEFINE_OPR_CLASS(
LSTMBackward, intl::MegDNNOprWrapperBwd<megdnn::LSTMBackward>) // {
public:
LSTMBackward(
VarNode* x, VarNode* y, VarNode* hx, VarNode* cx, VarNode* dy, VarNode* dhy,
VarNode* dcy, VarNode* flatten_weights, VarNode* reserve_space,
const Param& param, const OperatorNodeConfig& config);
static SymbolVarArray make(
SymbolVar x, SymbolVar y, SymbolVar hx, SymbolVar cx, SymbolVar dy,
SymbolVar dhy, SymbolVar dcy, SymbolVar flatten_weights,
SymbolVar reserve_space, const Param& param = {},
const OperatorNodeConfig& config = {});
Super::NodeProp* do_make_node_prop() const override;

private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
};

} // namespace opr
} // namespace mgb

+ 3
- 0
src/serialization/impl/schema.fbs View File

@@ -116,6 +116,9 @@ union OperatorParam {
param.Padding = 82,
param.ShuffleRNG = 83,
param.CheckNonFinite = 84,
param.RNNCell = 85,
param.RNN = 86,
param.LSTM = 87,
}

table Operator {

