- #include "megdnn/oprs.h"
- #include "src/common/utils.h"
-
- namespace megdnn {
-
- void LSTM::deduce_layout(
- const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
- const TensorLayout& /*flatten_weights*/, TensorLayout& output, TensorLayout& hy,
- TensorLayout& cy, TensorLayout& reserve_space) {
- // input: [seq_len, batch_size, input_size]
- // hx: [D * num_layers, batch_size, hidden_size]
- size_t seq_len = input.shape[0];
- size_t batch_size = input.shape[1];
- size_t D = param().bidirectional ? 2 : 1;
- size_t hidden_size = hx.shape[2];
- output = TensorLayout(
- TensorShape{seq_len, batch_size, D * hidden_size}, input.dtype);
- hy = TensorLayout(hx);
- cy = TensorLayout(cx);
- reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
- }
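
// Shape example (hypothetical values, not from the source): with seq_len = 10,
// batch_size = 32, input_size = 64, hidden_size = 128 and bidirectional = true
// (D = 2), deduce_layout gives output {10, 32, 256}, while hy/cy keep the
// {D * num_layers, 32, 128} layouts of hx/cx.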

void LSTM::check_exec(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& cy,
        const TensorLayout& /*reserve_space*/, size_t /*workspace_in_bytes*/) {
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", output=");
        msg.append(output.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", cx=");
        msg.append(cx.to_string());
        msg.append(", hy=");
        msg.append(hy.to_string());
        msg.append(", cy=");
        msg.append(cy.to_string());
        msg.append(", flatten_weights=");
        msg.append(flatten_weights.to_string());
        msg.append(", hidden_size=");
        msg.append(std::to_string(param().hidden_size));
        msg.append(", num_layers=");
        msg.append(std::to_string(param().num_layers));
        msg.append(", bidirectional=");
        msg.append(std::to_string(param().bidirectional));
        return msg;
    };
    size_t D = param().bidirectional ? 2 : 1;
    size_t b = param().bias ? 1 : 0;
    size_t num_layers = param().num_layers;
    size_t input_size = input.shape[2];
    size_t gate_hidden_size = 4 * param().hidden_size;
    // Per-direction weight shapes:
    //   layer 0:     weight_ih_l[k][_reverse]: (4 * hidden_size, input_size)
    //                weight_hh_l[k][_reverse]: (4 * hidden_size, hidden_size)
    //   layers >= 1: weight_ih_l[k][_reverse]: (4 * hidden_size, D * hidden_size)
    //                weight_hh_l[k][_reverse]: (4 * hidden_size, hidden_size)
    // Biases add 2 * D * num_layers columns (bias_ih and bias_hh per layer
    // and direction, each of width 1).
    // size_dim1 = columns of layer 0 + columns of layers 1..num_layers-1 + bias columns
    size_t size_dim1 = D * (input_size + param().hidden_size) +
                       (num_layers - 1) * D * ((D + 1) * param().hidden_size) +
                       b * 2 * D * num_layers;
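    // Worked example (hypothetical values, not from the source): with
    // input_size = 64, hidden_size = 128, num_layers = 2, bidirectional
    // (D = 2) and bias enabled (b = 1):
    //   size_dim1 = 2 * (64 + 128)       // layer 0: ih + hh columns per direction
    //             + 1 * 2 * (3 * 128)    // layer 1: ih sees D * hidden_size inputs
    //             + 1 * 2 * 2 * 2        // bias columns, width 1 each
    //             = 384 + 768 + 8 = 1160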
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

    ASSERT_BRIEF(input.ndim == 3)
    ASSERT_BRIEF(output.ndim == 3)
    ASSERT_BRIEF(flatten_weights.ndim == 2)
    ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size)
    ASSERT_BRIEF(flatten_weights.shape[1] == size_dim1)
    ASSERT_BRIEF(output.shape[0] == input.shape[0])
    ASSERT_BRIEF(output.shape[1] == input.shape[1])
    ASSERT_BRIEF(output.shape[2] == D * param().hidden_size)
    ASSERT_BRIEF(hx.ndim == 3)
    ASSERT_BRIEF(hx.shape[0] == D * num_layers)
    ASSERT_BRIEF(hx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hx.shape[2] == param().hidden_size)
    ASSERT_BRIEF(cx.ndim == 3)
    ASSERT_BRIEF(cx.shape[0] == D * num_layers)
    ASSERT_BRIEF(cx.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(cx.shape[2] == param().hidden_size)
    ASSERT_BRIEF(hy.ndim == 3)
    ASSERT_BRIEF(hy.shape[0] == D * num_layers)
    ASSERT_BRIEF(hy.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(hy.shape[2] == param().hidden_size)
    ASSERT_BRIEF(cy.ndim == 3)
    ASSERT_BRIEF(cy.shape[0] == D * num_layers)
    ASSERT_BRIEF(cy.shape[1] == input.shape[1])  // batch_size
    ASSERT_BRIEF(cy.shape[2] == param().hidden_size)
#undef ASSERT_BRIEF
}
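
// Consistency example (hypothetical values, not from the source): with
// hidden_size = 128, num_layers = 2, bidirectional and bias enabled, the
// checks above accept input {10, 32, 64}, output {10, 32, 256},
// hx/cx/hy/cy {4, 32, 128} and flatten_weights {512, 1160}
// (gate_hidden_size = 4 * 128 = 512, size_dim1 = 1160 as in the worked
// example inside check_exec).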

void LSTMBackward::deduce_layout(
        const TensorLayout& x, const TensorLayout& /*y*/, const TensorLayout& hx,
        const TensorLayout& cx, const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
        const TensorLayout& /*dcy*/, const TensorLayout& flatten_weights,
        const TensorLayout& /*reserve_space*/, TensorLayout& dx, TensorLayout& dhx,
        TensorLayout& dcx, TensorLayout& dw) {
    // Each gradient shares the layout of its forward counterpart.
    dx = x;
    dhx = hx;
    dcx = cx;
    dw = flatten_weights;
}
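
// Example (hypothetical values, not from the source): for x {10, 32, 64} and
// hx/cx {4, 32, 128}, deduce_layout yields dx {10, 32, 64}, dhx/dcx
// {4, 32, 128}, and dw with the same layout as flatten_weights.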

// TODO: add shape checks for the backward pass.
void LSTMBackward::check_exec(
        const TensorLayout& /*x*/, const TensorLayout& /*y*/,
        const TensorLayout& /*hx*/, const TensorLayout& /*cx*/,
        const TensorLayout& /*dy*/, const TensorLayout& /*dhy*/,
        const TensorLayout& /*dcy*/, const TensorLayout& /*flatten_weights*/,
        const TensorLayout& /*reserve_space*/, const TensorLayout& /*dx*/,
        const TensorLayout& /*dhx*/, const TensorLayout& /*dcx*/,
        const TensorLayout& /*dw*/, size_t /*workspace_in_bytes*/) {}

}  // namespace megdnn