GitOrigin-RevId: b9bb7352bc
tags/v1.8.0
@@ -30,18 +30,18 @@ struct TanhOpBase : UnaryOpBase<src_ctype, dst_ctype> {
 template <typename src_ctype, typename dst_type = src_ctype>
 struct TanhOp;
-#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \
+#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \
     template <> \
     struct TanhOp<_ctype> : TanhOpBase<_ctype> { \
         using TanhOpBase::TanhOpBase; \
         using TanhOpBase::operator(); \
         constexpr static size_t SIMD_WIDTH = _simd_width; \
-        void operator()(const _neon_type& src, _ctype* dst) const { \
+        void operator()(const _neon_type2& src, _ctype* dst) const { \
             auto vitem = operator()(src); \
             vst1q_##_func_suffix(dst, vitem.val[0]); \
             vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
         } \
-        _neon_type operator()(const _neon_type& src) const { \
+        _neon_type2 operator()(const _neon_type2& src) const { \
             auto one_val = vdupq_n_##_func_suffix(1.f); \
             auto two_val = vdupq_n_##_func_suffix(2.f); \
             auto val1 = src.val[0]; \
@@ -62,10 +62,23 @@ struct TanhOp;
             val2 = vsubq_##_func_suffix(one_val, val2); \
             return {{val1, val2}}; \
         } \
+        _neon_type operator()(const _neon_type& src) const { \
+            auto one_val = vdupq_n_##_func_suffix(1.f); \
+            auto two_val = vdupq_n_##_func_suffix(2.f); \
+            auto val1 = src; \
+            val1 = vmulq_##_func_suffix(two_val, val1); \
+            val1 = exp_ps_##_func_suffix(val1); \
+            val1 = vaddq_##_func_suffix(one_val, val1); \
+            auto rval1 = vrecpeq_##_func_suffix(val1); \
+            rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), rval1); \
+            val1 = vmulq_##_func_suffix(two_val, rval1); \
+            val1 = vsubq_##_func_suffix(one_val, val1); \
+            return val1; \
+        } \
     };
-OP(dt_float32, float32x4x2_t, f32, 4)
+OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-OP(__fp16, float16x8x2_t, f16, 8)
+OP(__fp16, float16x8_t, float16x8x2_t, f16, 8)
 #endif
 #undef OP
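For reference, the new single-vector overload added above evaluates tanh through the identity tanh(x) = 1 - 2 / (exp(2x) + 1), with vrecpeq/vrecpsq providing the reciprocal plus one Newton-Raphson refinement. A minimal scalar sketch of that formula (illustrative only; tanh_ref is a made-up helper, not part of the patch):

#include <cmath>
#include <cstdio>

// scalar mirror of the lane-wise NEON sequence: 2x -> exp -> +1 -> reciprocal -> 1 - 2*r
static float tanh_ref(float x) {
    float e = std::exp(2.f * x);   // exp_ps_##_func_suffix(2 * x)
    float r = 1.f / (1.f + e);     // vrecpeq + one vrecpsq refinement step
    return 1.f - 2.f * r;          // vsubq(one, vmulq(two, r))
}

int main() {
    const float xs[] = {-2.f, -0.5f, 0.f, 0.5f, 2.f};
    for (float x : xs)
        std::printf("x=% .2f  tanh_ref=% .6f  std::tanh=% .6f\n", x, tanh_ref(x), std::tanh(x));
    return 0;
}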
@@ -19,9 +19,12 @@
 #include "src/arm_common/elemwise/opr_impl.h"
 #include "src/arm_common/elemwise_multi_type/opr_impl.h"
 #include "src/arm_common/local/opr_impl.h"
+#include "src/arm_common/lstm/opr_impl.h"
+#include "src/arm_common/lstm_cell/opr_impl.h"
 #include "src/arm_common/pooling/opr_impl.h"
 #include "src/arm_common/reduce/opr_impl.h"
 #include "src/arm_common/resize/opr_impl.h"
+#include "src/arm_common/rnn_cell/opr_impl.h"
 #include "src/arm_common/separable_conv/opr_impl.h"
 #include "src/arm_common/separable_filter/opr_impl.h"
 #include "src/arm_common/type_cvt/opr_impl.h"
@@ -50,6 +53,9 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Reduce)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias)
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell)
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wpragmas"
@@ -0,0 +1,107 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm/lstm_utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "./lstm_utils.h" | |||||
#include "src/arm_common/lstm/opr_impl.h" | |||||
#include "src/arm_common/lstm_cell/cell_kernel.h" | |||||
#include "src/arm_common/lstm_cell/opr_impl.h" | |||||
#include "src/naive/handle.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
LstmCellWeight::LstmCellWeight( | |||||
RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype) { | |||||
// weight_ih: [gate_hidden_size, input_size] | |||||
// weight_hh: [gate_hidden_size, hidden_size] | |||||
// bias_ih: [gate_hidden_size] | |||||
// bias_hh: [gate_hidden_size] | |||||
size_t gate_hidden_size = 4 * hidden_size; | |||||
TensorLayout weight_ih_layout{{gate_hidden_size, input_size}, dtype}; | |||||
TensorLayout weight_hh_layout{{gate_hidden_size, hidden_size}, dtype}; | |||||
TensorLayout bias_layout{{gate_hidden_size}, dtype}; | |||||
m_weight_size = 0; | |||||
m_weight_ih = TensorND(weight_ih_layout, weight_ptr); | |||||
m_weight_size += weight_ih_layout.span().dist_byte(); | |||||
weight_ptr += weight_ih_layout.span().dist_byte(); | |||||
m_weight_hh = TensorND(weight_hh_layout, weight_ptr); | |||||
m_weight_size += weight_hh_layout.span().dist_byte(); | |||||
weight_ptr += weight_hh_layout.span().dist_byte(); | |||||
if (has_bias) { | |||||
m_bias_ih = TensorND(bias_layout, weight_ptr); | |||||
m_weight_size += bias_layout.span().dist_byte(); | |||||
weight_ptr += bias_layout.span().dist_byte(); | |||||
m_bias_hh = TensorND(bias_layout, weight_ptr); | |||||
m_weight_size += bias_layout.span().dist_byte(); | |||||
} | |||||
} | |||||
LstmStates::LstmStates( | |||||
const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size, | |||||
DType dtype) { | |||||
auto& h_ptr = ptr[0]; | |||||
auto& c_ptr = ptr[1]; | |||||
TensorLayout layout{{batch_size, hidden_size}, dtype}; | |||||
m_h = TensorND(layout, h_ptr); | |||||
m_c = TensorND(layout, c_ptr); | |||||
m_memory_size = layout.span().dist_byte(); | |||||
} | |||||
TensorNDArray megdnn::arm_common::split_tensor( | |||||
_megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout) { | |||||
megdnn_assert( | |||||
tensor.layout.span().dist_byte() == nr_tensor * layout.span().dist_byte()); | |||||
TensorNDArray tensors; | |||||
auto ptr = tensor.get_ref_ptr(); | |||||
for (size_t i = 0; i < nr_tensor; i++) { | |||||
tensors.push_back(TensorND(layout, ptr)); | |||||
ptr += layout.span().dist_byte(); | |||||
} | |||||
return tensors; | |||||
} | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
template <> | |||||
void cell_opr_compute<LSTMCell, LstmStates>( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, const LstmStates& state_in, LstmStates& state_out, | |||||
Workspace cell_workspace, Handle* handle) { | |||||
auto opr = handle->create_operator<LSTMCellForward>(); | |||||
TensorLayout gates, h_new, c_new; | |||||
opr->deduce_layout( | |||||
input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout, | |||||
weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates); | |||||
auto workspace_bundle = LstmCellCompute::get_workspace_bundle( | |||||
input.layout, weight_ih.layout, bias_ih.layout, state_in.m_h.layout, | |||||
weight_hh.layout, bias_hh.layout, state_in.m_c.layout, h_new, c_new, gates); | |||||
workspace_bundle.set(cell_workspace.raw_ptr); | |||||
TensorND gates_tensor{workspace_bundle.get(0), gates}; | |||||
_megdnn_workspace new_workspace = { | |||||
static_cast<dt_byte*>(workspace_bundle.get(1)), | |||||
workspace_bundle.get_size(1)}; | |||||
LstmCellCompute::run( | |||||
input, weight_ih, bias_ih, state_in.m_h, weight_hh, bias_hh, state_in.m_c, | |||||
state_out.m_h, state_out.m_c, gates_tensor, new_workspace, handle); | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
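The per-cell weight size accumulated by the LstmCellWeight constructor above reduces to a simple formula; a standalone sketch (lstm_cell_weight_bytes is an illustrative helper, not part of the patch):

#include <cstddef>
#include <cstdio>

// bytes per cell in the flattened weight buffer, matching the constructor:
// weight_ih [4*hidden, input] + weight_hh [4*hidden, hidden] (+ bias_ih + bias_hh)
static size_t lstm_cell_weight_bytes(
        size_t hidden_size, size_t input_size, bool has_bias, size_t dtype_bytes) {
    size_t gate_hidden = 4 * hidden_size;
    size_t bytes = gate_hidden * input_size * dtype_bytes     // weight_ih
                 + gate_hidden * hidden_size * dtype_bytes;   // weight_hh
    if (has_bias)
        bytes += 2 * gate_hidden * dtype_bytes;               // bias_ih + bias_hh
    return bytes;
}

int main() {
    // e.g. one float32 cell with hidden_size = 4, input_size = 8 and bias
    std::printf("%zu bytes per cell\n", lstm_cell_weight_bytes(4, 8, true, 4));
    return 0;
}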
@@ -0,0 +1,259 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm/lstm_utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/arm_common/lstm_cell/cell_kernel.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/naive/lstm/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
template <class CellOp, class States> | |||||
void cell_opr_compute( | |||||
_megdnn_tensor_in step_input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, const States& state_in, States& state_out, | |||||
Workspace cell_workspace, Handle* handle); | |||||
struct LstmCellWeight { | |||||
size_t m_weight_size = 0; | |||||
TensorND m_weight_ih, m_weight_hh, m_bias_ih, m_bias_hh; | |||||
// if there is no bias, a dummy bias tensor will be created from the workspace
LstmCellWeight( | |||||
RefPtr weight_ptr, size_t hidden_size, size_t input_size, bool has_bias, | |||||
DType dtype); | |||||
}; | |||||
struct LstmStates { | |||||
static size_t nr_states() { return 2; } | |||||
size_t m_memory_size; | |||||
TensorND m_h, m_c; | |||||
LstmStates( | |||||
const SmallVector<RefPtr> ptr, size_t hidden_size, size_t batch_size, | |||||
DType dtype); | |||||
}; | |||||
TensorNDArray split_tensor( | |||||
_megdnn_tensor_in tensor, size_t nr_tensor, const TensorLayout& layout); | |||||
template <class CellWeight> | |||||
SmallVector<CellWeight> get_all_cells( | |||||
size_t dir_size, size_t num_layers, size_t input_size, size_t hidden_size, | |||||
bool bias, _megdnn_tensor_in flatten_weights) { | |||||
SmallVector<CellWeight> cell_weights; | |||||
cell_weights.reserve(dir_size * num_layers); | |||||
auto weight_ptr = flatten_weights.get_ref_ptr(); | |||||
for (size_t layer = 0; layer < num_layers; ++layer) { | |||||
for (size_t d = 0; d < dir_size; ++d) { | |||||
size_t cell_input_size = layer == 0 ? input_size : dir_size * hidden_size; | |||||
CellWeight cell_weight( | |||||
weight_ptr, hidden_size, cell_input_size, bias, | |||||
flatten_weights.layout.dtype); | |||||
weight_ptr += cell_weight.m_weight_size; | |||||
cell_weights.push_back(cell_weight); | |||||
} | |||||
} | |||||
return cell_weights; | |||||
} | |||||
template <class States> | |||||
SmallVector<States> get_all_status( | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in cx, size_t hidden_size, | |||||
size_t batch_size, size_t num_layers, size_t dir_size, DType dtype) { | |||||
SmallVector<States> states; | |||||
auto hx_ptr = hx.get_ref_ptr(); | |||||
auto cx_ptr = cx.get_ref_ptr(); | |||||
for (size_t layer = 0; layer < num_layers * dir_size; ++layer) { | |||||
States state({hx_ptr, cx_ptr}, hidden_size, batch_size, dtype); | |||||
hx_ptr += state.m_memory_size; | |||||
cx_ptr += state.m_memory_size; | |||||
states.push_back(state); | |||||
} | |||||
return states; | |||||
} | |||||
template <class Cell, typename CellOpr, class States> | |||||
void exec_kernel( | |||||
SmallVector<Cell>& cells, const TensorNDArray& inputs, | |||||
const SmallVector<States>& states_in, SmallVector<States>& states_out, | |||||
TensorNDArray& outputs, size_t num_layers, size_t dir_size, Handle* handle, | |||||
WorkspaceBundle workspace_bundle) { | |||||
megdnn_assert(cells.size() == num_layers * dir_size); | |||||
megdnn_assert( | |||||
states_in.size() == states_out.size() && | |||||
states_in.size() == num_layers * dir_size); | |||||
megdnn_assert(outputs.size() == inputs.size()); | |||||
//! four fixed workspaces plus one temporary workspace per state
megdnn_assert(workspace_bundle.nr_workspace() == 4 + States::nr_states()); | |||||
size_t seq_len = inputs.size(); | |||||
size_t batch_size = inputs[0].layout.shape[0]; | |||||
size_t input_size = inputs[0].layout.shape[1]; | |||||
size_t hidden_size = cells[0].m_weight_hh.layout.shape[1]; | |||||
TensorLayout batch_output_layout{ | |||||
{hidden_size}, outputs[0].layout.dtype}; // output hy | |||||
TensorLayout cell_output_layout{ | |||||
{batch_size, hidden_size}, outputs[0].layout.dtype}; // output hy | |||||
TensorLayout seq_output_layout{ | |||||
{batch_size, dir_size * hidden_size}, outputs[0].layout.dtype}; | |||||
TensorLayout cell_first_input_layout{ | |||||
{batch_size, input_size}, inputs[0].layout.dtype}; // input | |||||
TensorLayout cell_input_layout{ | |||||
{batch_size, dir_size * hidden_size}, inputs[0].layout.dtype}; | |||||
TensorLayout tmp_output_layout{ | |||||
{seq_len, batch_size, dir_size * hidden_size}, outputs[0].layout.dtype}; | |||||
//! workspace get | |||||
Workspace cell_workspace( | |||||
static_cast<dt_byte*>(workspace_bundle.get(0)), | |||||
workspace_bundle.get_size(0) + workspace_bundle.get_size(1)); | |||||
auto&& tmp_inputs_1 = split_tensor( | |||||
TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len, | |||||
cell_input_layout); | |||||
auto&& tmp_outputs_1 = split_tensor( | |||||
TensorND{workspace_bundle.get(2), tmp_output_layout}, seq_len, | |||||
seq_output_layout); | |||||
auto&& tmp_inputs_2 = split_tensor( | |||||
TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len, | |||||
cell_input_layout); | |||||
auto&& tmp_outputs_2 = split_tensor( | |||||
TensorND{workspace_bundle.get(3), tmp_output_layout}, seq_len, | |||||
seq_output_layout); | |||||
using IoPair = std::pair<TensorNDArray, TensorNDArray>; | |||||
IoPair io_pair1 = {tmp_inputs_1, tmp_outputs_2}; | |||||
IoPair io_pair2 = {tmp_inputs_2, tmp_outputs_1}; | |||||
SmallVector<IoPair> io_pairs = {io_pair1, io_pair2}; | |||||
SmallVector<RefPtr> ptr; | |||||
for (size_t index = 0; index < States::nr_states(); index++) { | |||||
ptr.push_back(workspace_bundle.get(4 + index)); | |||||
} | |||||
auto&& tmp_state = States(ptr, hidden_size, batch_size, outputs[0].layout.dtype); | |||||
for (size_t layer = 0; layer < num_layers; layer++) { | |||||
auto layer_inputs = io_pairs[layer % 2].first; | |||||
auto layer_outputs = io_pairs[layer % 2].second; | |||||
//! if this is the last layer, write directly to the output tensors
if (num_layers - 1 == layer) { | |||||
layer_outputs = outputs; | |||||
} | |||||
if (0 == layer) { | |||||
layer_inputs = inputs; | |||||
} | |||||
for (size_t d = 0; d < dir_size; ++d) { | |||||
size_t cell_idx = layer * dir_size + d; | |||||
auto& cell = cells[cell_idx]; | |||||
auto& state_in_origin = states_in[cell_idx]; | |||||
auto& state_out_origin = states_out[cell_idx]; | |||||
auto state_in = state_in_origin; | |||||
auto state_out = tmp_state; | |||||
for (size_t i = 0; i < seq_len; ++i) { | |||||
size_t step = d == 0 ? i : seq_len - 1 - i; | |||||
auto& step_input = layer_inputs[step]; | |||||
auto& step_output = layer_outputs[step]; | |||||
if (i == seq_len - 1) { | |||||
state_out = state_out_origin; | |||||
} | |||||
//! task 1 | |||||
//! the cell operator dispatches its tasks internally, so no task is dispatched here
cell_opr_compute<CellOpr, LstmStates>( | |||||
step_input, cell.m_weight_ih, cell.m_weight_hh, cell.m_bias_ih, | |||||
cell.m_bias_hh, state_in, state_out, cell_workspace, handle); | |||||
//! task 2 | |||||
//! copy the output to contiguous memory
auto copy_to_output = [=]() { | |||||
//! if dir_size > 1 and batch_size > 1, reorder the rows into the output
size_t stride = batch_output_layout.span().dist_byte(); | |||||
if (dir_size > 1 && batch_size > 1) { | |||||
int8_t* source = static_cast<int8_t*>(state_out.m_h.raw_ptr()); | |||||
int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) + | |||||
d * stride; | |||||
for (size_t b = 0; b < batch_size; b++) { | |||||
memcpy(dst, source, stride); | |||||
source += stride; | |||||
dst += dir_size * stride; | |||||
} | |||||
} else { | |||||
void* source = state_out.m_h.raw_ptr(); | |||||
int8_t* dst = static_cast<int8_t*>(step_output.raw_ptr()) + | |||||
d * stride; | |||||
memcpy(dst, source, state_out.m_h.layout.span().dist_byte()); | |||||
} | |||||
}; | |||||
MEGDNN_DISPATCH_CPU_KERN( | |||||
static_cast<naive::HandleImpl*>(handle), copy_to_output()); | |||||
//! state_in and state_out are read and written in place
if (0 == i) { | |||||
state_in = tmp_state; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <typename CellOpr> | |||||
WorkspaceBundle get_workspace_bundle( | |||||
const TensorLayout& input, const TensorLayout& output, | |||||
const TensorLayout& flatten_weights, size_t hidden_size, size_t dir_size, | |||||
size_t states_size) { | |||||
size_t batch_size = input.shape[1]; | |||||
size_t input_size = input.shape[2]; | |||||
size_t gate_hidden_size = flatten_weights.shape[0]; | |||||
// cell workspace | |||||
TensorLayout weight_ih{{gate_hidden_size, input_size}, flatten_weights.dtype}; | |||||
TensorLayout weight_hh{ | |||||
{gate_hidden_size, dir_size * hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout bias{{1, gate_hidden_size}, flatten_weights.dtype}; | |||||
TensorLayout hx{{batch_size, dir_size * hidden_size}, input.dtype}; | |||||
auto cell_opr = inplace_cpu_handle()->create_operator<CellOpr>(); | |||||
TensorLayout h_new, c_new, gates; | |||||
cell_opr->deduce_layout( | |||||
input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates); | |||||
SmallVector<size_t> workspaces; | |||||
//! the cell opr compute workspace | |||||
size_t cell_opr_workspace = cell_opr->get_workspace_in_bytes( | |||||
input, weight_ih, bias, hx, weight_hh, bias, hx, h_new, c_new, gates); | |||||
workspaces.push_back(gates.span().dist_byte()); | |||||
workspaces.push_back(cell_opr_workspace); | |||||
//! two temporary output buffers (ping-pong between layers)
size_t tmp_output_workspace = output.span().dist_byte(); | |||||
workspaces.push_back(tmp_output_workspace); | |||||
workspaces.push_back(tmp_output_workspace); | |||||
//! tmp states memory | |||||
size_t tmp_state_workspace = hx.span().dist_byte(); | |||||
for (size_t i = 0; i < states_size; i++) { | |||||
workspaces.push_back(tmp_state_workspace); | |||||
} | |||||
return {nullptr, workspaces}; | |||||
} | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
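exec_kernel above alternates two temporary sequence buffers between layers (io_pairs), with layer 0 reading the real inputs and the last layer writing the real outputs. A small sketch of that ping-pong schedule (buffer names are illustrative, not from the patch):

#include <cstddef>
#include <cstdio>

int main() {
    const size_t num_layers = 4;
    const char* buffers[2] = {"workspace[2]", "workspace[3]"};
    for (size_t layer = 0; layer < num_layers; ++layer) {
        // io_pairs[layer % 2]: read from one temporary buffer, write to the other
        const char* in = (layer == 0) ? "real input" : buffers[layer % 2];
        const char* out =
                (layer == num_layers - 1) ? "real output" : buffers[(layer + 1) % 2];
        std::printf("layer %zu: read %s -> write %s\n", layer, in, out);
    }
    return 0;
}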
@@ -0,0 +1,83 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/lstm/opr_impl.h" | |||||
#include "./lstm_utils.h" | |||||
#include "src/arm_common/lstm_cell/opr_impl.h" | |||||
#include "src/naive/handle.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_lstm) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
void LSTMImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, _megdnn_tensor_out, | |||||
_megdnn_workspace workspace) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(0)) { | |||||
size_t dir_size = param().bidirectional ? 2 : 1; | |||||
size_t num_layers = param().num_layers; | |||||
size_t hidden_size = param().hidden_size; | |||||
size_t seq_len = input.layout.shape[0]; | |||||
size_t batch_size = input.layout.shape[1]; | |||||
size_t input_size = input.layout.shape[2]; | |||||
//! to support input pointer changes in record mode, this task should be
//! dispatched to the device
auto&& cell_weights = get_all_cells<LstmCellWeight>( | |||||
dir_size, num_layers, input_size, hidden_size, param().bias, | |||||
flatten_weights); | |||||
auto&& cell_states_in = get_all_status<LstmStates>( | |||||
hx, cx, hidden_size, batch_size, num_layers, dir_size, hx.layout.dtype); | |||||
auto&& cell_states_out = get_all_status<LstmStates>( | |||||
hy, cy, hidden_size, batch_size, num_layers, dir_size, hy.layout.dtype); | |||||
auto&& inputs = split_tensor( | |||||
input, seq_len, | |||||
TensorLayout{{batch_size, input_size}, input.layout.dtype}); | |||||
auto&& outputs = split_tensor( | |||||
output, seq_len, | |||||
TensorLayout{ | |||||
{batch_size, dir_size * hidden_size}, output.layout.dtype}); | |||||
auto workspace_bundle = get_workspace_bundle<LSTMCell>( | |||||
input.layout, output.layout, flatten_weights.layout, hidden_size, | |||||
dir_size, LstmStates::nr_states()); | |||||
workspace_bundle.set(workspace.raw_ptr); | |||||
exec_kernel<LstmCellWeight, LSTMCell, LstmStates>( | |||||
cell_weights, inputs, cell_states_in, cell_states_out, outputs, | |||||
num_layers, dir_size, handle(), workspace_bundle); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
size_t LSTMImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_lstm, midout_iv(1)) { | |||||
size_t dir_size = param().bidirectional ? 2 : 1; | |||||
size_t hidden_size = param().hidden_size; | |||||
auto bundle = get_workspace_bundle<LSTMCell>( | |||||
input, output, flatten_weights, hidden_size, dir_size, | |||||
LstmStates::nr_states()); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,43 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/lstm/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
class LSTMImpl : public naive::LSTMImpl { | |||||
public: | |||||
using naive::LSTMImpl::LSTMImpl; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in hx, _megdnn_tensor_in cx, | |||||
_megdnn_tensor_in flatten_weights, _megdnn_tensor_out output, | |||||
_megdnn_tensor_out hy, _megdnn_tensor_out cy, | |||||
_megdnn_tensor_out reserve_space, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx, | |||||
const TensorLayout& flatten_weights, const TensorLayout& output, | |||||
const TensorLayout& hy, const TensorLayout& cy, | |||||
const TensorLayout& reserve_space) override; | |||||
//! arm_common only stores the output tensor; the other tensors are only
//! used when computing gradients, so they are ignored here
size_t get_reserve_size_in_bytes(const TensorLayout&) override { return 1; } | |||||
}; | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,273 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm_cell/cell_kernel.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "./cell_kernel.h" | |||||
#include "src/arm_common/lstm_cell/opr_impl.h" | |||||
#include "src/common/lstm_cell.h" | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/sigmoid.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/tanh.h" | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
template <class Op, bool bias> | |||||
struct ElemwiseCompute { | |||||
static Op op; | |||||
static inline float32x4x2_t compute_8( | |||||
float* dst, float* tmp, float* ih, float* hh) { | |||||
float32x4_t dst0 = vld1q_f32(dst); | |||||
float32x4_t dst1 = vld1q_f32(dst + 4); | |||||
float32x4_t tmp0 = vld1q_f32(tmp); | |||||
float32x4_t tmp1 = vld1q_f32(tmp + 4); | |||||
auto mid0 = vaddq_f32(dst0, tmp0); | |||||
auto mid1 = vaddq_f32(dst1, tmp1); | |||||
float32x4_t out0, out1; | |||||
if (bias) { | |||||
float32x4_t ih0 = vld1q_f32(ih); | |||||
float32x4_t ih1 = vld1q_f32(ih + 4); | |||||
float32x4_t hh0 = vld1q_f32(hh); | |||||
float32x4_t hh1 = vld1q_f32(hh + 4); | |||||
auto midd0 = vaddq_f32(ih0, hh0); | |||||
auto midd1 = vaddq_f32(ih1, hh1); | |||||
out0 = vaddq_f32(mid0, midd0); | |||||
out1 = vaddq_f32(mid1, midd1); | |||||
} else { | |||||
out0 = mid0; | |||||
out1 = mid1; | |||||
} | |||||
return {{op(out0), op(out1)}}; | |||||
} | |||||
static inline float32x4_t compute_4(float* dst, float* tmp, float* ih, float* hh) { | |||||
float32x4_t dst0 = vld1q_f32(dst); | |||||
float32x4_t tmp0 = vld1q_f32(tmp); | |||||
auto mid0 = vaddq_f32(dst0, tmp0); | |||||
float32x4_t out0; | |||||
if (bias) { | |||||
float32x4_t ih0 = vld1q_f32(ih); | |||||
float32x4_t hh0 = vld1q_f32(hh); | |||||
auto midd0 = vaddq_f32(ih0, hh0); | |||||
out0 = vaddq_f32(mid0, midd0); | |||||
} else { | |||||
out0 = mid0; | |||||
} | |||||
return op(out0); | |||||
} | |||||
static inline float compute_1(float* dst, float* tmp, float* ih, float* hh) { | |||||
float out; | |||||
if (bias) { | |||||
out = dst[0] + tmp[0] + ih[0] + hh[0]; | |||||
} else { | |||||
out = dst[0] + tmp[0]; | |||||
} | |||||
return op(out); | |||||
} | |||||
}; | |||||
template <class Op, bool bias> | |||||
Op ElemwiseCompute<Op, bias>::op = Op(); | |||||
template <bool bias> | |||||
void rnn_cell_elemwise_compute( | |||||
_megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, _megdnn_tensor_in cx, _megdnn_tensor_out h_new, | |||||
_megdnn_tensor_out c_new) { | |||||
size_t batch = dst.layout[0]; | |||||
size_t batch_length = dst.layout.total_nr_elems() / batch; | |||||
size_t base_length = batch_length / 4; | |||||
float *ih_ptr_ = nullptr, *hh_ptr_ = nullptr; | |||||
float* dst_ptr_ = dst.ptr<float>(); | |||||
float* tmp_ptr_ = tmp.ptr<float>(); | |||||
if (bias) { | |||||
ih_ptr_ = bias_ih.ptr<float>(); | |||||
hh_ptr_ = bias_hh.ptr<float>(); | |||||
} | |||||
float* cx_ptr_ = cx.ptr<float>(); | |||||
float* h_new_ptr_ = h_new.ptr<float>(); | |||||
float* c_new_ptr_ = c_new.ptr<float>(); | |||||
ElemwiseCompute<SigmoidOp<dt_float32>, bias> sigmoid_compute; | |||||
ElemwiseCompute<TanhOp<dt_float32>, bias> tanh_compute; | |||||
TanhOp<dt_float32> tanh_op; | |||||
for (size_t b = 0; b < batch; b++) { | |||||
float* dst_ptr = dst_ptr_ + b * batch_length; | |||||
float* tmp_ptr = tmp_ptr_ + b * batch_length; | |||||
float* ih_ptr = ih_ptr_; | |||||
float* hh_ptr = hh_ptr_; | |||||
float* cx_ptr = cx_ptr_ + b * base_length; | |||||
float* h_new_ptr = h_new_ptr_ + b * base_length; | |||||
float* c_new_ptr = c_new_ptr_ + b * base_length; | |||||
size_t index = 0; | |||||
for (; index + 7 < base_length; index += 8) { | |||||
auto out_i = sigmoid_compute.compute_8(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||||
auto out_f = sigmoid_compute.compute_8( | |||||
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||||
hh_ptr + base_length); | |||||
auto out_g = tanh_compute.compute_8( | |||||
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||||
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||||
auto out_o = sigmoid_compute.compute_8( | |||||
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||||
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||||
float32x4_t cx_0 = vld1q_f32(cx_ptr); | |||||
float32x4_t cx_1 = vld1q_f32(cx_ptr + 4); | |||||
//! f * cx + i * g | |||||
auto c_new_0 = vaddq_f32( | |||||
vmulq_f32(out_f.val[0], cx_0), | |||||
vmulq_f32(out_i.val[0], out_g.val[0])); | |||||
auto c_new_1 = vaddq_f32( | |||||
vmulq_f32(out_f.val[1], cx_1), | |||||
vmulq_f32(out_i.val[1], out_g.val[1])); | |||||
vst1q_f32(c_new_ptr, c_new_0); | |||||
vst1q_f32(c_new_ptr + 4, c_new_1); | |||||
auto h_new_0 = vmulq_f32(tanh_op(c_new_0), out_o.val[0]); | |||||
auto h_new_1 = vmulq_f32(tanh_op(c_new_1), out_o.val[1]); | |||||
vst1q_f32(h_new_ptr, h_new_0); | |||||
vst1q_f32(h_new_ptr + 4, h_new_1); | |||||
dst_ptr += 8; | |||||
tmp_ptr += 8; | |||||
ih_ptr += 8; | |||||
hh_ptr += 8; | |||||
cx_ptr += 8; | |||||
c_new_ptr += 8; | |||||
h_new_ptr += 8; | |||||
} | |||||
for (; index + 3 < base_length; index += 4) { | |||||
auto out_i = sigmoid_compute.compute_4(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||||
auto out_f = sigmoid_compute.compute_4( | |||||
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||||
hh_ptr + base_length); | |||||
auto out_g = tanh_compute.compute_4( | |||||
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||||
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||||
auto out_o = sigmoid_compute.compute_4( | |||||
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||||
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||||
float32x4_t cx_v = vld1q_f32(cx_ptr); | |||||
//! f * cx + i * g | |||||
auto c_new = vaddq_f32(vmulq_f32(out_f, cx_v), vmulq_f32(out_i, out_g)); | |||||
vst1q_f32(c_new_ptr, c_new); | |||||
auto h_new = vmulq_f32(tanh_op(c_new), out_o); | |||||
vst1q_f32(h_new_ptr, h_new); | |||||
dst_ptr += 4; | |||||
tmp_ptr += 4; | |||||
ih_ptr += 4; | |||||
hh_ptr += 4; | |||||
cx_ptr += 4; | |||||
c_new_ptr += 4; | |||||
h_new_ptr += 4; | |||||
} | |||||
for (; index < base_length; index++) { | |||||
auto out_i = sigmoid_compute.compute_1(dst_ptr, tmp_ptr, ih_ptr, hh_ptr); | |||||
auto out_f = sigmoid_compute.compute_1( | |||||
dst_ptr + base_length, tmp_ptr + base_length, ih_ptr + base_length, | |||||
hh_ptr + base_length); | |||||
auto out_g = tanh_compute.compute_1( | |||||
dst_ptr + 2 * base_length, tmp_ptr + 2 * base_length, | |||||
ih_ptr + 2 * base_length, hh_ptr + 2 * base_length); | |||||
auto out_o = sigmoid_compute.compute_1( | |||||
dst_ptr + 3 * base_length, tmp_ptr + 3 * base_length, | |||||
ih_ptr + 3 * base_length, hh_ptr + 3 * base_length); | |||||
c_new_ptr[0] = out_f * cx_ptr[0] + out_i * out_g; | |||||
h_new_ptr[0] = tanh_op(c_new_ptr[0]) * out_o; | |||||
dst_ptr += 1; | |||||
tmp_ptr += 1; | |||||
ih_ptr += 1; | |||||
hh_ptr += 1; | |||||
cx_ptr += 1; | |||||
c_new_ptr += 1; | |||||
h_new_ptr += 1; | |||||
} | |||||
} | |||||
} | |||||
} // namespace | |||||
void LstmCellCompute::run( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) { | |||||
auto bundle = get_workspace_bundle( | |||||
input.layout, weight_ih.layout, bias_ih.layout, hx.layout, weight_hh.layout, | |||||
bias_hh.layout, cx.layout, h_new.layout, c_new.layout, gates.layout); | |||||
bundle.set(workspace.raw_ptr); | |||||
TensorND tmp{static_cast<void*>(bundle.get(0)), gates.layout}; | |||||
auto matmul_workspace = | |||||
megdnn::Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)}; | |||||
auto opr = handle->create_operator<MatrixMul>(); | |||||
opr->param().transposeB = true; | |||||
//! the opr dispatches its compute task to the device, so record-mode
//! performance is not affected
opr->exec(input, weight_ih, tmp, matmul_workspace); | |||||
opr->exec(hx, weight_hh, gates, matmul_workspace); | |||||
//! the optimized post-compute: nonlinear(tmp + dst + bias_ih + bias_hh)
if (bias_ih.layout.ndim != 0 && bias_hh.layout.ndim != 0) {
MEGDNN_DISPATCH_CPU_KERN( | |||||
static_cast<naive::HandleImpl*>(handle), | |||||
rnn_cell_elemwise_compute<true>( | |||||
gates, tmp, bias_ih, bias_hh, cx, h_new, c_new)); | |||||
} else { | |||||
megdnn_assert(bias_ih.layout.ndim == 0 && bias_hh.layout.ndim == 0);
MEGDNN_DISPATCH_CPU_KERN( | |||||
static_cast<naive::HandleImpl*>(handle), | |||||
rnn_cell_elemwise_compute<false>( | |||||
gates, tmp, bias_ih, bias_hh, cx, h_new, c_new)); | |||||
} | |||||
} | |||||
WorkspaceBundle LstmCellCompute::get_workspace_bundle( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&, | |||||
const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout& gates) { | |||||
auto opr = inplace_cpu_handle()->create_operator<MatrixMul>(); | |||||
opr->param().transposeB = true; | |||||
size_t matmul_workspace = std::max( | |||||
opr->get_workspace_in_bytes(input, weight_ih, gates), | |||||
opr->get_workspace_in_bytes(hx, weight_hh, gates)); | |||||
return WorkspaceBundle{nullptr, {gates.span().dist_byte(), matmul_workspace}}; | |||||
} | |||||
bool LstmCellCompute::is_optimized( | |||||
const TensorLayout& input, const TensorLayout&, const TensorLayout& bias_ih, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout& bias_hh, | |||||
const TensorLayout&, const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout& gates) { | |||||
if (input.dtype.enumv() == DTypeEnum::Float32 && gates[1] == bias_ih[1] && | |||||
bias_ih[0] == 1 && bias_ih.eq_layout(bias_hh)) { | |||||
return true; | |||||
} else { | |||||
return false; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
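The fused kernel above applies the standard LSTM gate equations to the four contiguous blocks [i | f | g | o] of the gate buffer. A scalar reference of the same math (a sketch assuming the two matmul results and biases are already summed into gates; not part of the patch):

#include <cmath>
#include <cstddef>
#include <cstdio>

static float sigmoid_ref(float x) { return 1.f / (1.f + std::exp(-x)); }

// gates holds 4 * hidden values laid out as [i | f | g | o] for one batch row
static void lstm_cell_ref(
        const float* gates, const float* cx, float* h_new, float* c_new, size_t hidden) {
    for (size_t k = 0; k < hidden; ++k) {
        float i = sigmoid_ref(gates[k]);
        float f = sigmoid_ref(gates[hidden + k]);
        float g = std::tanh(gates[2 * hidden + k]);
        float o = sigmoid_ref(gates[3 * hidden + k]);
        c_new[k] = f * cx[k] + i * g;        // f * cx + i * g
        h_new[k] = std::tanh(c_new[k]) * o;  // tanh(c_new) * o
    }
}

int main() {
    float gates[8] = {0.1f, -0.2f, 0.3f, 0.4f, 0.5f, -0.6f, 0.7f, 0.8f};
    float cx[2] = {0.5f, -0.5f}, h[2], c[2];
    lstm_cell_ref(gates, cx, h, c, 2);
    std::printf("h = {%f, %f}  c = {%f, %f}\n", h[0], h[1], c[0], c[1]);
    return 0;
}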
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm_cell/cell_kernel.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/naive/lstm_cell/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
struct LstmCellCompute { | |||||
static WorkspaceBundle get_workspace_bundle( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, | |||||
const TensorLayout& c_new, const TensorLayout& gates); | |||||
static bool is_optimized( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, | |||||
const TensorLayout& c_new, const TensorLayout& gates); | |||||
static void run( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle); | |||||
}; | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/lstm_cell/opr_impl.h" | |||||
#include "src/common/lstm_cell.h" | |||||
#include "src/naive/handle.h" | |||||
#include "./cell_kernel.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_lstm_cell) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
void LSTMCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) { | |||||
//! only float32 with a bias of shape {1, xx} takes the optimized path
MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(0)) { | |||||
if (!LstmCellCompute::is_optimized( | |||||
input.layout, weight_ih.layout, bias_ih.layout, hx.layout, | |||||
weight_hh.layout, bias_hh.layout, cx.layout, h_new.layout, | |||||
c_new.layout, gates.layout)) { | |||||
naive::LSTMCellImpl::exec( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||||
gates, workspace); | |||||
} else { | |||||
LstmCellCompute::run( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||||
gates, workspace, handle()); | |||||
} | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
size_t LSTMCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new, | |||||
const TensorLayout& gates) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_lstm_cell, midout_iv(1)) { | |||||
if (!LstmCellCompute::is_optimized( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||||
gates)) { | |||||
return naive::LSTMCellImpl::get_workspace_in_bytes( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, | |||||
gates); | |||||
} else { | |||||
return LstmCellCompute::get_workspace_bundle( | |||||
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, | |||||
c_new, gates) | |||||
.total_size_in_bytes(); | |||||
} | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/lstm_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/lstm_cell/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
class LSTMCellImpl : public naive::LSTMCellImpl { | |||||
public: | |||||
using naive::LSTMCellImpl::LSTMCellImpl; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new, | |||||
_megdnn_tensor_out gates, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& cx, const TensorLayout& h_new, | |||||
const TensorLayout& c_new, const TensorLayout& gates) override; | |||||
}; | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,218 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/rnn_cell/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/arm_common/rnn_cell/opr_impl.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/none.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/relu.h" | |||||
#include "src/arm_common/elemwise_helper/kimpl/tanh.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_arm_common_rnn_cell) | |||||
using namespace megdnn; | |||||
using namespace arm_common; | |||||
namespace { | |||||
ElemwiseForward* get_elemwise_opr() { | |||||
static CpuOprDelegationStorage<1> storage; | |||||
return storage.get<ElemwiseForward>(); | |||||
} | |||||
template <typename Op> | |||||
void elemwise_compute( | |||||
float* dst_ptr, float* tmp_ptr, float* ih_ptr, float* hh_ptr, size_t batch, | |||||
size_t length) { | |||||
constexpr size_t SIMD_8 = 8;
constexpr size_t SIMD_4 = 4;
Op op; | |||||
for (size_t b = 0; b < batch; b++) { | |||||
float* dst = dst_ptr + b * length; | |||||
float* tmp = tmp_ptr + b * length; | |||||
float* ih = ih_ptr; | |||||
float* hh = hh_ptr; | |||||
size_t index = 0; | |||||
for (; index + SIMD_8 - 1 < length; index += SIMD_8) { | |||||
float32x4_t dst0 = vld1q_f32(dst); | |||||
float32x4_t dst1 = vld1q_f32(dst + 4); | |||||
float32x4_t tmp0 = vld1q_f32(tmp); | |||||
float32x4_t tmp1 = vld1q_f32(tmp + 4); | |||||
float32x4_t ih0 = vld1q_f32(ih); | |||||
float32x4_t ih1 = vld1q_f32(ih + 4); | |||||
float32x4_t hh0 = vld1q_f32(hh); | |||||
float32x4_t hh1 = vld1q_f32(hh + 4); | |||||
auto mid0 = vaddq_f32(dst0, tmp0); | |||||
auto mid1 = vaddq_f32(dst1, tmp1); | |||||
auto midd0 = vaddq_f32(ih0, hh0); | |||||
auto midd1 = vaddq_f32(ih1, hh1); | |||||
auto out0 = vaddq_f32(mid0, midd0); | |||||
auto out1 = vaddq_f32(mid1, midd1); | |||||
vst1q_f32(dst, op(out0)); | |||||
vst1q_f32(dst + 4, op(out1)); | |||||
dst += SIMD_8; | |||||
tmp += SIMD_8; | |||||
ih += SIMD_8; | |||||
hh += SIMD_8; | |||||
} | |||||
for (; index + SIMD_4 - 1 < length; index += SIMD_4) { | |||||
float32x4_t dst0 = vld1q_f32(dst); | |||||
float32x4_t tmp0 = vld1q_f32(tmp); | |||||
float32x4_t ih0 = vld1q_f32(ih); | |||||
float32x4_t hh0 = vld1q_f32(hh); | |||||
auto mid0 = vaddq_f32(dst0, tmp0); | |||||
auto midd0 = vaddq_f32(ih0, hh0); | |||||
auto out0 = vaddq_f32(mid0, midd0); | |||||
vst1q_f32(dst, op(out0)); | |||||
dst += SIMD_4; | |||||
tmp += SIMD_4; | |||||
ih += SIMD_4; | |||||
hh += SIMD_4; | |||||
} | |||||
for (; index < length; index++) { | |||||
auto out = dst[0] + tmp[0] + ih[0] + hh[0]; | |||||
dst[0] = op(out); | |||||
dst++; | |||||
tmp++; | |||||
ih++; | |||||
hh++; | |||||
} | |||||
} | |||||
} | |||||
void rnn_cell_post_compute( | |||||
_megdnn_tensor_out dst, _megdnn_tensor_in tmp, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in bias_hh, param::RNNCell::NonlineMode nonline_mode, | |||||
Handle* handle) { | |||||
using NonlineMode = param::RNNCell::NonlineMode; | |||||
megdnn_assert( | |||||
nonline_mode == NonlineMode::RELU || nonline_mode == NonlineMode::TANH || | |||||
nonline_mode == NonlineMode::IDENTITY, | |||||
"Now arm only support nonlinear mode Relu, TANH, IDENTITY."); | |||||
if (dst.layout.dtype.enumv() == DTypeEnum::Float32 && | |||||
dst.layout[1] == bias_ih.layout[1] && bias_ih.layout[0] == 1 && | |||||
bias_ih.layout.eq_layout(bias_hh.layout)) { | |||||
auto run = [=]() { | |||||
size_t batch = dst.layout[0]; | |||||
size_t length = bias_ih.layout.total_nr_elems(); | |||||
float* dst_ptr = dst.ptr<float>(); | |||||
float* tmp_ptr = tmp.ptr<float>(); | |||||
float* ih_ptr = bias_ih.ptr<float>(); | |||||
float* hh_ptr = bias_hh.ptr<float>(); | |||||
if (nonline_mode == NonlineMode::RELU) { | |||||
elemwise_compute<ReluOp<dt_float32>>( | |||||
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||||
} else if (nonline_mode == NonlineMode::TANH) { | |||||
elemwise_compute<TanhOp<dt_float32>>( | |||||
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||||
} else { | |||||
elemwise_compute<NoneOp<dt_float32>>( | |||||
dst_ptr, tmp_ptr, ih_ptr, hh_ptr, batch, length); | |||||
} | |||||
}; | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run()); | |||||
} else { | |||||
//! this opr must be created by inplace handle | |||||
auto elem_opr = get_elemwise_opr(); | |||||
auto run = [=]() { | |||||
elem_opr->param().mode = Elemwise::Param::Mode::ADD; | |||||
elem_opr->exec({dst, tmp}, dst); | |||||
elem_opr->exec({dst, bias_ih}, dst); | |||||
elem_opr->exec({dst, bias_hh}, dst); | |||||
// activation | |||||
switch (nonline_mode) { | |||||
#define cb(_mode) \ | |||||
case NonlineMode::_mode: { \ | |||||
elem_opr->param().mode = Elemwise::Param::Mode::_mode; \ | |||||
elem_opr->exec({dst}, dst); \ | |||||
break; \ | |||||
} | |||||
cb(RELU); | |||||
cb(TANH); | |||||
#undef cb | |||||
case NonlineMode::IDENTITY: | |||||
break; | |||||
default: | |||||
megdnn_throw("unsupport nonlinear mode."); | |||||
} | |||||
}; | |||||
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle), run()); | |||||
} | |||||
} | |||||
} // namespace | |||||
WorkspaceBundle RNNCellImpl::get_workspace_bundle( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, const TensorLayout&, | |||||
const TensorLayout& hx, const TensorLayout& weight_hh, const TensorLayout&, | |||||
const TensorLayout& dst) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(0)) { | |||||
auto opr = handle()->create_operator<MatrixMulForward>(); | |||||
opr->param().transposeB = true; | |||||
auto matmul_workspace = std::max( | |||||
opr->get_workspace_in_bytes(input, weight_ih, dst), | |||||
opr->get_workspace_in_bytes(hx, weight_hh, dst)); | |||||
auto tmp_workspace = dst.span().dist_byte(); | |||||
return WorkspaceBundle{nullptr, {tmp_workspace, matmul_workspace}}; | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
size_t RNNCellImpl::get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) { | |||||
return get_workspace_bundle(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst) | |||||
.total_size_in_bytes(); | |||||
} | |||||
void RNNCellImpl::exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih, | |||||
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_rnn_cell, midout_iv(1)) { | |||||
auto bundle = get_workspace_bundle( | |||||
input.layout, weight_ih.layout, bias_ih.layout, hx.layout, | |||||
weight_hh.layout, bias_hh.layout, dst.layout); | |||||
bundle.set(workspace.raw_ptr); | |||||
auto nonline_mode = param().nonlineMode; | |||||
TensorND tmp{static_cast<void*>(bundle.get(0)), dst.layout}; | |||||
auto new_workspace = | |||||
Workspace{static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)}; | |||||
//! this opr can't be created by inplace handle | |||||
auto opr = handle()->create_operator<MatrixMulForward>(); | |||||
opr->param().transposeB = true; | |||||
//! the opr dispatches its compute task to the device, so record-mode
//! performance is not affected
opr->exec(input, weight_ih, tmp, new_workspace); | |||||
opr->exec(hx, weight_hh, dst, new_workspace); | |||||
//! the optimized post-compute: nonlinear(tmp + dst + bias_ih + bias_hh)
rnn_cell_post_compute(dst, tmp, bias_ih, bias_hh, nonline_mode, handle()); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
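RNNCellImpl::exec above computes dst = nonlinearity(input * weight_ih^T + bias_ih + hx * weight_hh^T + bias_hh), with the two matmuls going through MatrixMulForward (transposeB = true) and the sum plus activation handled by the fused post-compute. A scalar sketch of that formula for one batch row (example values are arbitrary, not from the patch):

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t input_size = 3, hidden_size = 2;
    const float x[3] = {1.f, 2.f, 3.f};                       // input  [1, input_size]
    const float h[2] = {0.5f, -0.5f};                         // hx     [1, hidden_size]
    const float w_ih[2][3] = {{0.1f, 0.2f, 0.3f},             // weight_ih [hidden, input]
                              {-0.1f, -0.2f, -0.3f}};
    const float w_hh[2][2] = {{0.4f, 0.5f}, {-0.4f, -0.5f}};  // weight_hh [hidden, hidden]
    const float b_ih[2] = {0.01f, 0.02f}, b_hh[2] = {0.03f, 0.04f};
    for (size_t o = 0; o < hidden_size; ++o) {
        float acc = b_ih[o] + b_hh[o];
        for (size_t k = 0; k < input_size; ++k)
            acc += x[k] * w_ih[o][k];                         // input @ weight_ih^T
        for (size_t k = 0; k < hidden_size; ++k)
            acc += h[k] * w_hh[o][k];                         // hx @ weight_hh^T
        std::printf("dst[%zu] = %f\n", o, std::tanh(acc));    // NonlineMode::TANH
    }
    return 0;
}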
@@ -0,0 +1,43 @@ | |||||
/** | |||||
* \file dnn/src/arm_common/rnn_cell/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/common/opr_delegate.h" | |||||
#include "src/naive/rnn_cell/opr_impl.h" | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
class RNNCellImpl : public naive::RNNCellImpl { | |||||
public: | |||||
using naive::RNNCellImpl::RNNCellImpl; | |||||
void exec( | |||||
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, | |||||
_megdnn_tensor_in bias_ih, _megdnn_tensor_in hx, | |||||
_megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst) override; | |||||
private: | |||||
WorkspaceBundle get_workspace_bundle( | |||||
const TensorLayout& input, const TensorLayout& weight_ih, | |||||
const TensorLayout& bias_ih, const TensorLayout& hx, | |||||
const TensorLayout& weight_hh, const TensorLayout& bias_hh, | |||||
const TensorLayout& dst); | |||||
}; | |||||
} // namespace arm_common | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -31,4 +31,6 @@ public:
 };
 } // namespace naive
 } // namespace megdnn
+
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,225 @@ | |||||
/** | |||||
* \file dnn/test/arm_common/lstm.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/arm_common/fixture.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#include "test/common/benchmarker.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/task_record_check.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
namespace { | |||||
//! in arm_common the reserve tensor is not used | |||||
void output_canonizer(const CheckerHelper::TensorValueArray& arr) { | |||||
const TensorND& reserve = arr.back(); | |||||
TensorND& modif_reserve = const_cast<TensorND&>(reserve); | |||||
modif_reserve.layout = TensorLayout(); | |||||
} | |||||
} // namespace | |||||
TEST_F(ARM_COMMON, LSTMCell) { | |||||
Checker<LSTMCell> checker(handle()); | |||||
checker.set_output_canonizer(output_canonizer); | |||||
checker.exec( | |||||
{{1, 10}, | |||||
{40, 10}, | |||||
{1, 40}, | |||||
{1, 10}, | |||||
{40, 10}, | |||||
{1, 40}, | |||||
{1, 10}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
for (size_t batch : {2}) | |||||
for (size_t n : {3, 4, 5, 23, 100}) | |||||
for (size_t out : {3, 6, 25, 100}) { | |||||
checker.exec( | |||||
{{batch, n}, | |||||
{out * 4, n}, | |||||
{1, out * 4}, | |||||
{batch, out}, | |||||
{out * 4, out}, | |||||
{1, out * 4}, | |||||
{batch, out}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
checker.exec( | |||||
{{batch, n}, | |||||
{out * 4, n}, | |||||
{batch, out * 4}, | |||||
{batch, out}, | |||||
{out * 4, out}, | |||||
{batch, out * 4}, | |||||
{batch, out}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
} | |||||
TEST_F(ARM_COMMON, LSTMCellRecord) { | |||||
TaskRecordChecker<LSTMCell> checker(0); | |||||
checker.exec( | |||||
{{1, 10}, | |||||
{40, 10}, | |||||
{1, 40}, | |||||
{1, 10}, | |||||
{40, 10}, | |||||
{1, 40}, | |||||
{1, 10}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
namespace { | |||||
void test_lstm(bool bias, bool direction, Handle* handle) { | |||||
Checker<LSTM> checker(handle, true); | |||||
//! because lstm uses tanh and exp in its computation, the error can exceed
//! 1e-3 after more iterations
checker.set_epsilon(1e-2); | |||||
checker.set_output_canonizer(output_canonizer); | |||||
for (size_t input_size : {2, 8, 13}) | |||||
for (size_t hidden_size : {1, 4, 17}) { | |||||
size_t dir_size = direction == false ? 1 : 2; | |||||
LSTM::Param param; | |||||
param.bidirectional = direction; | |||||
size_t gate_hidden_size = 4 * hidden_size; | |||||
param.bias = bias; | |||||
param.hidden_size = hidden_size; | |||||
for (size_t seq_len : {1, 3, 5}) | |||||
for (size_t batch_size : {1, 2, 4}) | |||||
for (size_t number_layer : {1, 2, 4, 5, 8}) { | |||||
size_t flatten_size = 0; | |||||
for (size_t layer = 0; layer < number_layer; layer++) { | |||||
for (size_t dir = 0; dir < dir_size; dir++) { | |||||
flatten_size += layer == 0 | |||||
? input_size | |||||
: dir_size * hidden_size; // ih | |||||
flatten_size += hidden_size; // hh | |||||
} | |||||
} | |||||
if (bias) { | |||||
flatten_size += 2 * dir_size * number_layer; | |||||
} | |||||
param.num_layers = number_layer; | |||||
checker.set_param(param).exec( | |||||
{{seq_len, batch_size, input_size}, // input | |||||
{number_layer * dir_size, batch_size, | |||||
hidden_size}, // hx | |||||
{number_layer * dir_size, batch_size, | |||||
hidden_size}, // hy | |||||
{gate_hidden_size, flatten_size}, // flat weight | |||||
{}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
} | |||||
} | |||||
} // namespace | |||||
TEST_F(ARM_COMMON, LSTM_FORWARD_NO_BIAS_NO_DIRECTION) {
test_lstm(false, false, handle()); | |||||
} | |||||
TEST_F(ARM_COMMON, LSTM_FORWARD_BIAS_NO_DIRECTION) {
test_lstm(true, false, handle()); | |||||
} | |||||
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_NO_BIAS) { | |||||
test_lstm(false, true, handle()); | |||||
} | |||||
TEST_F(ARM_COMMON, LSTM_FORWARD_DIRECTION_BIAS) { | |||||
test_lstm(true, true, handle()); | |||||
} | |||||
TEST_F(ARM_COMMON, LSTM_FORWARD_RECORD) { | |||||
TaskRecordChecker<LSTM> checker(0); | |||||
size_t input_size = 2; | |||||
size_t hidden_size = 2; | |||||
size_t gate_hidden_size = 4 * hidden_size; | |||||
LSTM::Param param; | |||||
param.bidirectional = false; | |||||
param.bias = false; | |||||
param.hidden_size = hidden_size; | |||||
// checker.set_output_canonizer(output_canonizer); | |||||
for (size_t seq_len : {1, 3, 5}) | |||||
for (size_t batch_size : {1, 2, 4}) | |||||
for (size_t number_layer : {1, 2, 4, 5, 8}) { | |||||
param.num_layers = number_layer; | |||||
checker.set_param(param).exec( | |||||
{{seq_len, batch_size, input_size}, // input | |||||
{number_layer, batch_size, hidden_size}, // hx | |||||
{number_layer, batch_size, hidden_size}, // hy | |||||
{number_layer, gate_hidden_size, | |||||
input_size + hidden_size}, // flat weight | |||||
{}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
} | |||||
#if MEGDNN_WITH_BENCHMARK | |||||
TEST_F(ARM_COMMON, BENCHMARK_LSTM_FORWARD) { | |||||
Benchmarker<LSTM> optimized_bench(handle()); | |||||
constexpr size_t RUNS = 20; | |||||
auto run = [&](size_t hidden_size, size_t input_size) { | |||||
optimized_bench.set_times(20).set_display(true); | |||||
size_t gate_hidden_size = 4 * hidden_size; | |||||
for (bool direction : {false, true}) { | |||||
LSTM::Param param; | |||||
param.hidden_size = hidden_size; | |||||
param.bidirectional = direction; | |||||
param.bias = false; | |||||
size_t dir_size = direction == false ? 1 : 2; | |||||
for (size_t seq_len : {1, 5, 8}) | |||||
for (size_t batch_size : {1, 8, 16}) | |||||
for (size_t number_layer : {1}) { | |||||
param.num_layers = number_layer; | |||||
size_t flatten_size = 0; | |||||
for (size_t layer = 0; layer < number_layer; layer++) { | |||||
for (size_t dir = 0; dir < dir_size; dir++) { | |||||
flatten_size += layer == 0 | |||||
? input_size | |||||
: dir_size * hidden_size; // ih | |||||
flatten_size += hidden_size; // hh | |||||
} | |||||
} | |||||
optimized_bench.set_param(param).exec( | |||||
{{seq_len, batch_size, input_size}, // input | |||||
{number_layer * dir_size, batch_size, | |||||
hidden_size}, // hx | |||||
{number_layer * dir_size, batch_size, | |||||
hidden_size}, // hy | |||||
{gate_hidden_size, flatten_size}, // flat weight | |||||
{}, | |||||
{}, | |||||
{}, | |||||
{}}); | |||||
} | |||||
} | |||||
}; | |||||
run(512, 256); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,67 @@ | |||||
/** | |||||
* \file dnn/test/arm_common/rnn.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/arm_common/fixture.h" | |||||
#include "megdnn/oprs.h" | |||||
#include "test/common/benchmarker.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/task_record_check.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
TEST_F(ARM_COMMON, RNNCell) { | |||||
Checker<RNNCell> checker(handle()); | |||||
using NonlineMode = param::RNNCell::NonlineMode; | |||||
param::RNNCell param; | |||||
for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) | |||||
for (size_t batch : {1, 4}) | |||||
for (size_t n : {3, 4, 5, 23, 100}) | |||||
for (size_t h : {5, 23, 100}) | |||||
for (size_t out : {3, 6, 25, 100}) { | |||||
param.nonlineMode = mode; | |||||
checker.set_param(param); | |||||
checker.exec( | |||||
{{batch, n}, | |||||
{out, n}, | |||||
{1, out}, | |||||
{batch, h}, | |||||
{out, h}, | |||||
{1, out}, | |||||
{}}); | |||||
checker.exec( | |||||
{{batch, n}, | |||||
{out, n}, | |||||
{batch, out}, | |||||
{batch, h}, | |||||
{out, h}, | |||||
{batch, out}, | |||||
{}}); | |||||
} | |||||
} | |||||
TEST_F(ARM_COMMON, RNNCellRecord) { | |||||
TaskRecordChecker<RNNCell> checker(0); | |||||
using NonlineMode = param::RNNCell::NonlineMode; | |||||
param::RNNCell param; | |||||
for (auto mode : {NonlineMode::IDENTITY, NonlineMode::RELU, NonlineMode::TANH}) { | |||||
param.nonlineMode = mode; | |||||
checker.set_param(param); | |||||
checker.exec({{1, 100}, {10, 100}, {1, 10}, {1, 100}, {10, 100}, {1, 10}, {}}); | |||||
checker.exec({{1, 34}, {15, 34}, {1, 15}, {1, 34}, {15, 34}, {1, 15}, {}}); | |||||
checker.exec({{1, 73}, {25, 73}, {1, 25}, {1, 73}, {25, 73}, {1, 25}, {}}); | |||||
} | |||||
} | |||||
#if MEGDNN_WITH_BENCHMARK | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |