
lstm.cpp

#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void LSTM::deduce_layout(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& /*flatten_weights*/, TensorLayout& output, TensorLayout& hy,
        TensorLayout& cy, TensorLayout& reserve_space) {
    // input: [seq_len, batch_size, input_size]
    // hx: [D * num_layers, batch_size, hidden_size]
    size_t seq_len = input.shape[0];
    size_t batch_size = input.shape[1];
    size_t D = param().bidirectional ? 2 : 1;
    size_t hidden_size = hx.shape[2];
    output = TensorLayout(
            TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
    hy = TensorLayout(hx);
    cy = TensorLayout(cx);
    reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
}
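
// Illustrative example (numbers are assumptions, not from the code): with
// seq_len=5, batch_size=2, input_size=10, hidden_size=20, num_layers=1 and
// bidirectional=true (so D=2), deduce_layout produces:
//   output:        {5, 2, 40}  -- [seq_len, batch_size, D * hidden_size]
//   hy, cy:        {2, 2, 20}  -- copies of the hx and cx layouts
//   reserve_space: a 1-D layout sized by get_reserve_size_in_bytes(input)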

void LSTM::check_exec(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& cy,
        const TensorLayout& /*reserve_space*/, size_t /*workspace_in_bytes*/) {
    // Builds a readable dump of all layouts and parameters for the
    // assertion messages below.
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", output=");
        msg.append(output.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", cx=");
        msg.append(cx.to_string());
        msg.append(", hy=");
        msg.append(hy.to_string());
        msg.append(", cy=");
        msg.append(cy.to_string());
        msg.append(", flatten_weights=");
        msg.append(flatten_weights.to_string());
        msg.append(", hidden_size=");
        msg.append(std::to_string(param().hidden_size));
        msg.append(", num_layers=");
        msg.append(std::to_string(param().num_layers));
        msg.append(", bidirectional=");
        msg.append(std::to_string(param().bidirectional));
        return msg;
    };
    size_t D = param().bidirectional ? 2 : 1;
    size_t b = param().bias ? 1 : 0;
    size_t num_layers = param().num_layers;
    size_t input_size = input.shape[2];
    size_t gate_hidden_size = 4 * param().hidden_size;
    // First layer:
    //   weight_ih_l[k][_reverse].shape = (4*hidden_size, input_size)
    //   weight_hh_l[k][_reverse].shape = (4*hidden_size, hidden_size)
    // Other layers:
    //   weight_ih_l[k][_reverse].shape = (4*hidden_size, num_directions * hidden_size)
    //   weight_hh_l[k][_reverse].shape = (4*hidden_size, hidden_size)
    // Bias terms: 2 * num_directions * num_layers
    // size_dim1 = D * first layer + (num_layers - 1) * other layers + bias
    size_t size_dim1 = D * (input_size + param().hidden_size) +
                       (num_layers - 1) * D * ((D + 1) * param().hidden_size) +
                       b * 2 * D * num_layers;
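    // Worked example (assumed numbers, for illustration only): with
    // input_size=10, hidden_size=20, num_layers=2, bidirectional (D=2) and
    // bias enabled (b=1):
    //   first layer:      2 * (10 + 20)     = 60
    //   remaining layer:  1 * 2 * (3 * 20)  = 120
    //   biases:           1 * 2 * 2 * 2     = 8
    // so size_dim1 = 188, and flatten_weights is expected to have shape
    // (gate_hidden_size, size_dim1) = (80, 188).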
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
    ASSERT_BRIEF(input.ndim == 3)
    ASSERT_BRIEF(output.ndim == 3)
    ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size)
    ASSERT_BRIEF(flatten_weights.shape[1] == size_dim1)
    ASSERT_BRIEF(output.shape[0] == input.shape[0])
    ASSERT_BRIEF(output.shape[1] == input.shape[1])
    ASSERT_BRIEF(output.shape[2] == D * param().hidden_size)
    ASSERT_BRIEF(hx.ndim == 3)
    ASSERT_BRIEF(hx.shape[0] == D * num_layers)
    ASSERT_BRIEF(hx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
    ASSERT_BRIEF(cx.ndim == 3)
    ASSERT_BRIEF(cx.shape[0] == D * num_layers)
    ASSERT_BRIEF(cx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(cx.shape[2] == param().hidden_size)
    ASSERT_BRIEF(hy.ndim == 3)
    ASSERT_BRIEF(hy.shape[0] == D * num_layers)
    ASSERT_BRIEF(hy.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hy.shape[2] == param().hidden_size)
    ASSERT_BRIEF(cy.ndim == 3)
    ASSERT_BRIEF(cy.shape[0] == D * num_layers)
    ASSERT_BRIEF(cy.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(cy.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}

void LSTMBackward::deduce_layout(
        const TensorLayout& x, const TensorLayout& /*y*/, const TensorLayout& hx,
        const TensorLayout& cx, const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
        const TensorLayout& /*dcy*/, const TensorLayout& flatten_weights,
        const TensorLayout& /*reserve_space*/, TensorLayout& dx, TensorLayout& dhx,
        TensorLayout& dcx, TensorLayout& dw) {
    // Each gradient tensor shares the layout of its forward counterpart.
    dx = x;
    dhx = hx;
    dcx = cx;
    dw = flatten_weights;
}

// TODO: add shape check of BWD
void LSTMBackward::check_exec(
        const TensorLayout& /*x*/, const TensorLayout& /*y*/,
        const TensorLayout& /*hx*/, const TensorLayout& /*cx*/,
        const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
        const TensorLayout& /*dcy*/, const TensorLayout& /*flatten_weights*/,
        const TensorLayout& /*reserve_space*/, const TensorLayout& /*dx*/,
        const TensorLayout& /*dhx*/, const TensorLayout& /*dcx*/,
        const TensorLayout& /*dw*/, size_t /*workspace_in_bytes*/) {}

}  // namespace megdnn
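
For context, here is a minimal usage sketch of the forward layout deduction
above, driven through the MegDNN operator interface. It is not part of
lstm.cpp; it assumes `handle` is a valid megdnn::Handle* obtained elsewhere,
and the helper name, shapes, and parameter values are illustrative.

#include "megdnn/oprs.h"

// Hypothetical helper; assumes `handle` is a valid megdnn::Handle*.
void deduce_lstm_layouts(megdnn::Handle* handle) {
    using namespace megdnn;
    auto opr = handle->create_operator<LSTM>();
    opr->param().hidden_size = 20;
    opr->param().num_layers = 1;
    opr->param().bidirectional = true;  // D = 2
    opr->param().bias = true;

    // input: [seq_len=5, batch_size=2, input_size=10]
    TensorLayout input{TensorShape{5, 2, 10}, dtype::Float32()};
    // hx, cx: [D * num_layers, batch_size, hidden_size] = [2, 2, 20]
    TensorLayout hx{TensorShape{2, 2, 20}, dtype::Float32()};
    TensorLayout cx = hx;
    TensorLayout flatten_weights, output, hy, cy, reserve_space;

    // flatten_weights is ignored by deduce_layout (see its definition above).
    opr->deduce_layout(input, hx, cx, flatten_weights, output, hy, cy, reserve_space);
    // Expected: output = {5, 2, 40}; hy and cy mirror hx and cx.
}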