|
|
@@ -362,96 +362,111 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, |
|
|
|
InnerBlockSize get_inner_block_size() const override; \ |
|
|
|
size_t get_packA_type_size() const override; |
|
|
|
|
|
|
|
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_PACKA( \ |
|
|
|
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ |
|
|
|
_packa_type) \ |
|
|
|
\ |
|
|
|
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \ |
|
|
|
const KernSizeParam&) const { \ |
|
|
|
auto kern = [](const MatrixMulImpl::KernParam& kern_param, \ |
|
|
|
const void* packed_a, const void* packed_b) { \ |
|
|
|
MIDOUT_BEGIN(_midout_name, midout_iv(_mid_index)) { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDC = kern_param.LDC; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
auto Cptr = kern_param.C<_c_type>(); \ |
|
|
|
\ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.execute_naked(Cptr, LDC, packed_a, packed_b); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
}; \ |
|
|
|
return kern; \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
void MatrixMulImpl::_algo_name::pack_A(const KernParam& kern_param, \ |
|
|
|
void* out, size_t index, \ |
|
|
|
size_t stride) const { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
\ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDA = kern_param.LDA; \ |
|
|
|
const auto Aptr = kern_param.A<_i_type>(); \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
size_t start_index = index * stride; \ |
|
|
|
size_t end_index = start_index + stride; \ |
|
|
|
end_index = std::min(end_index, M); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.pack_A(reinterpret_cast<_packa_type*>(out), Aptr, LDA, \ |
|
|
|
start_index, end_index); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
void MatrixMulImpl::_algo_name::pack_B(const KernParam& kern_param, \ |
|
|
|
void* out, const size_t x0, \ |
|
|
|
size_t xmax) const { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
\ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDB = kern_param.LDB; \ |
|
|
|
const auto Bptr = kern_param.B<_i_type>(); \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.pack_B(reinterpret_cast<_i_type*>(out), Bptr, LDB, x0, xmax); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
WorkspaceBundle MatrixMulImpl::_algo_name::get_bundle( \ |
|
|
|
const KernSizeParam& kern_size_param) const { \ |
|
|
|
auto M = kern_size_param.M, N = kern_size_param.N, \ |
|
|
|
K = kern_size_param.K; \ |
|
|
|
auto trA = kern_size_param.trA, trB = kern_size_param.trB; \ |
|
|
|
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, \ |
|
|
|
C_type = kern_size_param.C_type; \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
return megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.get_bundle(); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
MatrixMulImpl::_algo_name::InnerBlockSize \ |
|
|
|
MatrixMulImpl::_algo_name::get_inner_block_size() const { \ |
|
|
|
return {_strategy::KERNEL_H, _strategy::KERNEL_W, \ |
|
|
|
_strategy::UNROLL_K}; \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \ |
|
|
|
return sizeof(_packa_type); \ |
|
|
|
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ |
|
|
|
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ |
|
|
|
_packa_type) \ |
|
|
|
\ |
|
|
|
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \ |
|
|
|
const KernSizeParam&) const { \ |
|
|
|
auto kern = [](const MatrixMulImpl::KernParam& kern_param, \ |
|
|
|
const void* packed_a, const void* packed_b) { \ |
|
|
|
MIDOUT_BEGIN(_midout_name, midout_iv(_mid_index), \ |
|
|
|
midout_iv("get_kern_naked"_hash)) { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDC = kern_param.LDC; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
auto Cptr = kern_param.C<_c_type>(); \ |
|
|
|
\ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.execute_naked(Cptr, LDC, packed_a, packed_b); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
}; \ |
|
|
|
return kern; \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
void MatrixMulImpl::_algo_name::pack_A(const KernParam& kern_param, \ |
|
|
|
void* out, size_t index, \ |
|
|
|
size_t stride) const { \ |
|
|
|
MIDOUT_BEGIN(_midout_name, midout_iv(_mid_index), \ |
|
|
|
midout_iv("pack_A"_hash)) { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
\ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDA = kern_param.LDA; \ |
|
|
|
const auto Aptr = kern_param.A<_i_type>(); \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
size_t start_index = index * stride; \ |
|
|
|
size_t end_index = start_index + stride; \ |
|
|
|
end_index = std::min(end_index, M); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.pack_A(reinterpret_cast<_packa_type*>(out), Aptr, LDA, \ |
|
|
|
start_index, end_index); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
void MatrixMulImpl::_algo_name::pack_B(const KernParam& kern_param, \ |
|
|
|
void* out, const size_t x0, \ |
|
|
|
size_t xmax) const { \ |
|
|
|
MIDOUT_BEGIN(_midout_name, midout_iv(_mid_index), \ |
|
|
|
midout_iv("pack_B"_hash)) { \ |
|
|
|
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; \ |
|
|
|
auto A_type = kern_param.A_type, B_type = kern_param.B_type, \ |
|
|
|
C_type = kern_param.C_type; \ |
|
|
|
\ |
|
|
|
auto trA = kern_param.trA, trB = kern_param.trB; \ |
|
|
|
auto LDB = kern_param.LDB; \ |
|
|
|
const auto Bptr = kern_param.B<_i_type>(); \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, trB, \ |
|
|
|
strategy) \ |
|
|
|
.pack_B(reinterpret_cast<_i_type*>(out), Bptr, LDB, x0, \ |
|
|
|
xmax); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
WorkspaceBundle MatrixMulImpl::_algo_name::get_bundle( \ |
|
|
|
const KernSizeParam& kern_size_param) const { \ |
|
|
|
MIDOUT_BEGIN(_midout_name, midout_iv(_mid_index), \ |
|
|
|
midout_iv("get_bundle"_hash)) { \ |
|
|
|
auto M = kern_size_param.M, N = kern_size_param.N, \ |
|
|
|
K = kern_size_param.K; \ |
|
|
|
auto trA = kern_size_param.trA, trB = kern_size_param.trB; \ |
|
|
|
auto A_type = kern_size_param.A_type, \ |
|
|
|
B_type = kern_size_param.B_type, \ |
|
|
|
C_type = kern_size_param.C_type; \ |
|
|
|
_strategy strategy(M, N, K, A_type, B_type, C_type); \ |
|
|
|
return megdnn::matmul::GemmInterleaved<_strategy>(M, N, K, trA, \ |
|
|
|
trB, strategy) \ |
|
|
|
.get_bundle(); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
MatrixMulImpl::_algo_name::InnerBlockSize \ |
|
|
|
MatrixMulImpl::_algo_name::get_inner_block_size() const { \ |
|
|
|
return {_strategy::KERNEL_H, _strategy::KERNEL_W, \ |
|
|
|
_strategy::UNROLL_K}; \ |
|
|
|
} \ |
|
|
|
\ |
|
|
|
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \ |
|
|
|
return sizeof(_packa_type); \ |
|
|
|
} |
|
|
|
|
|
|
|
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ |
|
|
|
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \ |
|
|
|
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_PACKA(_algo_name, _midout_name, \ |
|
|
|
_mid_index, _strategy, _i_type, \ |
|
|
|
_c_type, _i_type) |
|
|
|
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ |
|
|
|
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \ |
|
|
|
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name, \ |
|
|
|
_mid_index, _strategy, \ |
|
|
|
_i_type, _c_type, _i_type) |
|
|
|
} // namespace matmul |
|
|
|
} // namespace megdnn |
|
|
|
|
|
|
|