@@ -92,6 +92,12 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||||
fill_dwconv_algos(); | fill_dwconv_algos(); | ||||
all_algos.push_back(&int8_chwn4_dotprod); | all_algos.push_back(&int8_chwn4_dotprod); | ||||
all_algos.push_back(&fallback_nchw_qs8); | all_algos.push_back(&fallback_nchw_qs8); | ||||
fill_ptx_algos(); | |||||
for (auto&& algo : algo_ptx_conv2d_u4_s4) { | |||||
all_algos.push_back(&algo); | |||||
} | |||||
for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | ||||
non_cudnn_algos.push_back(all_algos[i]); | non_cudnn_algos.push_back(all_algos[i]); | ||||
} | } | ||||
@@ -364,6 +370,15 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); | int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); | ||||
} | } | ||||
void ConvBiasForwardImpl::AlgoPack::fill_ptx_algos() { | |||||
algo_ptx_conv2d_u4_s4.emplace_back( | |||||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 256, 256}); | |||||
algo_ptx_conv2d_u4_s4.emplace_back( | |||||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 128, 128}); | |||||
algo_ptx_conv2d_u4_s4.emplace_back( | |||||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{256, 64, 128}); | |||||
} | |||||
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | ||||
cudnnConvolutionFwdAlgo_t algo) { | cudnnConvolutionFwdAlgo_t algo) { | ||||
for (auto&& i : cudnn_convs) { | for (auto&& i : cudnn_convs) { | ||||
@@ -78,6 +78,7 @@ public: | |||||
CUDA_SIMPLE_INT1, | CUDA_SIMPLE_INT1, | ||||
CUDA_CUDNN_CONV_V8, | CUDA_CUDNN_CONV_V8, | ||||
CUDA_CUDNN_CONVBIAS_V8, | CUDA_CUDNN_CONVBIAS_V8, | ||||
CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4, | |||||
}; | }; | ||||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
@@ -1203,6 +1204,45 @@ private: | |||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
}; | }; | ||||
class ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm final | |||||
: public AlgoBase { | |||||
public: | |||||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm( | |||||
unsigned int tile_nhw, unsigned int tile_oc, unsigned int threads) | |||||
: m_tile_nhw{tile_nhw}, m_tile_oc{tile_oc}, m_threads{threads} { | |||||
m_name = ConvBias::algo_name<ConvBias::DirectParam>( | |||||
ssprintf( | |||||
"PTX_UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%uX%u_%u", m_tile_nhw, | |||||
m_tile_oc, m_threads), | |||||
ConvBias::DirectParam{}); | |||||
} | |||||
bool is_available(const SizeArgs& args) const override; | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||||
void exec(const ExecArgs& args) const override; | |||||
const char* name() const override { return m_name.c_str(); } | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; | |||||
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||||
const SizeArgs& args) const override; | |||||
void exec_preprocess(const ExecArgs& args) const override; | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4) | |||||
std::string param() const override { | |||||
std::string ret; | |||||
serialize_write_pod(m_tile_nhw, ret); | |||||
serialize_write_pod(m_tile_oc, ret); | |||||
serialize_write_pod(m_threads, ret); | |||||
return ret; | |||||
} | |||||
private: | |||||
std::string kernel_key(const SizeArgs& args) const; | |||||
unsigned int m_tile_nhw, m_tile_oc, m_threads; | |||||
std::string m_name; | |||||
void reorder_filter_bias( | |||||
const ExecArgs& args, void* reduce_filter, void* reordered_filter, | |||||
void* reordered_bias) const; | |||||
}; | |||||
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | ||||
private: | private: | ||||
AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
@@ -1251,6 +1291,7 @@ public: | |||||
AlgoCUDNNConvV8 cudnn_conv_v8; | AlgoCUDNNConvV8 cudnn_conv_v8; | ||||
AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8; | AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8; | ||||
#endif | #endif | ||||
std::vector<AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm> algo_ptx_conv2d_u4_s4; | |||||
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo); | AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo); | ||||
@@ -1265,6 +1306,7 @@ private: | |||||
void fill_cudnn_algos(); | void fill_cudnn_algos(); | ||||
void fill_dp4a_algos(); | void fill_dp4a_algos(); | ||||
void fill_dwconv_algos(); | void fill_dwconv_algos(); | ||||
void fill_ptx_algos(); | |||||
}; | }; | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -72,6 +72,7 @@ public: | |||||
class AlgoCUDNNConvV8; | class AlgoCUDNNConvV8; | ||||
class AlgoCUDNNConvBiasActivationV8; | class AlgoCUDNNConvBiasActivationV8; | ||||
#endif | #endif | ||||
class AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm; | |||||
class AlgoPack; | class AlgoPack; | ||||
@@ -0,0 +1,156 @@ | |||||
#include "src/cuda/conv_bias/ptx_helper.cuh" | |||||
#include "src/cuda/integer_subbyte_utils.cuh" | |||||
#include "src/cuda/query_blocksize.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace ptx; | |||||
namespace { | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
__device__ __forceinline__ void reorder_imma_filter_func( | |||||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||||
uint32_t FW, uint32_t lane) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
uint32_t elements = lane * elements_per_lane; | |||||
uint32_t row = elements / (IC * FH * FW); | |||||
uint32_t col = elements - row * IC * FH * FW; | |||||
uint32_t sec = row / 4; | |||||
uint32_t res = col & (interleaved - 1); | |||||
uint32_t sec_sec = row & 3; | |||||
uint32_t sec_res = (row & 15) / 4; | |||||
uint32_t crosswise_offset = ((sec_sec >> 1) * 2 * interleaved) + | |||||
(((sec_sec & 1) ^ (sec_res >> 1)) * interleaved); | |||||
uint32_t residue_offset = | |||||
((res / elements_per_lane) ^ (sec_res & 1)) * elements_per_lane; | |||||
uint32_t dst_offset = | |||||
(sec / 2) * 8 * FH * FW * IC + (col / interleaved) * (8 * interleaved) + | |||||
(sec & 1) * (4 * interleaved) + crosswise_offset + residue_offset; | |||||
static constexpr uint32_t instruction_shape_col = 8; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||||
// 4 threads per Quad | |||||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||||
uint32_t elem_in_interleaved = row % interleaved; | |||||
uint32_t elem_in_interleaved_pack = elem_in_interleaved / elements_per_thread; | |||||
int elem_new = (row / interleaved * interleaved + | |||||
elem_in_interleaved_pack % 4 * reordered_elements_per_thread + | |||||
elem_in_interleaved_pack / 4 * elements_per_thread + | |||||
elem_in_interleaved % elements_per_thread) * | |||||
(IC * FH * FW) + | |||||
col; | |||||
*(reinterpret_cast<int4*>(dst + (dst_offset * size_bits / 8))) = | |||||
*(reinterpret_cast<const int4*>(src + (elem_new * size_bits / 8))); | |||||
} | |||||
template <uint32_t interleaved> | |||||
__device__ __forceinline__ void reorder_imma_bias_func( | |||||
float* __restrict__ dst, float src_value, uint32_t OC, uint32_t lane) { | |||||
dst[lane] = src_value; | |||||
} | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
__global__ void reorder_imma_filter_bias_kernel( | |||||
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias, | |||||
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias, | |||||
float bias_scale, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane; | |||||
const uint32_t size2 = OC; | |||||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
if (lane < size1) { | |||||
reorder_imma_filter_func<size_bits, interleaved>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, lane); | |||||
} else if (lane < size1 + size2) { | |||||
lane = lane - size1; | |||||
float src_bias_value = src_bias[lane] * bias_scale; | |||||
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane); | |||||
} | |||||
} | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
__global__ void reorder_imma_filter_bias_fusion_zero_point_kernel( | |||||
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias, | |||||
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias, | |||||
float bias_scale, const int32_t* reduce_filter, float zero_point, uint32_t OC, | |||||
uint32_t IC, uint32_t FH, uint32_t FW) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane; | |||||
const uint32_t size2 = OC; | |||||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||||
if (lane < size1) { | |||||
reorder_imma_filter_func<size_bits, interleaved>( | |||||
dst_filter, src_filter, OC, IC, FH, FW, lane); | |||||
} else if (lane < size1 + size2) { | |||||
lane = lane - size1; | |||||
// fusion bias and zero_point | |||||
// zero_point = zero_point * src_scale * filter_scale | |||||
float src_bias_value = | |||||
src_bias[lane] * bias_scale - reduce_filter[lane] * zero_point; | |||||
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane); | |||||
} | |||||
} | |||||
} // namespace | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void megdnn::cuda::ptx::reorder_imma_filter_bias( | |||||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, | |||||
uint32_t FH, uint32_t FW, cudaStream_t stream) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
reorder_imma_filter_bias_kernel<size_bits, interleaved>)); | |||||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC; | |||||
nr_threads = std::min(nr_threads, vthreads); | |||||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
reorder_imma_filter_bias_kernel<size_bits, interleaved> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||||
dst_filter, dst_bias, src_filter, src_bias, bias_scale, OC, IC, FH, | |||||
FW); | |||||
after_kernel_launch(); | |||||
} | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point( | |||||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, | |||||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, | |||||
cudaStream_t stream) { | |||||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||||
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||||
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved>)); | |||||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC; | |||||
nr_threads = std::min(nr_threads, vthreads); | |||||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||||
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved> | |||||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||||
dst_filter, dst_bias, src_filter, src_bias, bias_scale, | |||||
reduce_filter, zero_point, OC, IC, FH, FW); | |||||
after_kernel_launch(); | |||||
} | |||||
#define INST(_size_bits, _interleaved) \ | |||||
template void \ | |||||
megdnn::cuda::ptx::reorder_imma_filter_bias<_size_bits, _interleaved>( \ | |||||
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \ | |||||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, \ | |||||
uint32_t FH, uint32_t FW, cudaStream_t stream); | |||||
INST(8, 32) | |||||
INST(4, 64) | |||||
#undef INST | |||||
#define INST(_size_bits, _interleaved) \ | |||||
template void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point< \ | |||||
_size_bits, _interleaved>( \ | |||||
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \ | |||||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, \ | |||||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, \ | |||||
cudaStream_t stream); | |||||
INST(4, 64) | |||||
#undef INST | |||||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,120 @@ | |||||
#pragma once | |||||
#include "src/cuda/int_fastdiv.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace ptx { | |||||
struct Conv2dInt8Param { | |||||
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow; | |||||
uint32_t ibs, ics, ihs; | |||||
uint32_t obs, ocs, ohs; | |||||
uint32_t icfhfw; | |||||
uint32_t nhw; | |||||
uint32_t oc32; | |||||
Uint32Fastdiv div_ohow; | |||||
Uint32Fastdiv div_ow; | |||||
Conv2dInt8Param( | |||||
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw, | |||||
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc, | |||||
uint32_t oh, uint32_t ow, uint32_t interleaved) | |||||
: n(n), | |||||
ic(ic), | |||||
ih(ih), | |||||
iw(iw), | |||||
fh(fh), | |||||
fw(fw), | |||||
sh(sh), | |||||
sw(sw), | |||||
ph(ph), | |||||
pw(pw), | |||||
oc(oc), | |||||
oh(oh), | |||||
ow(ow) { | |||||
ibs = ic * ih * iw; | |||||
ics = ih * iw * interleaved; | |||||
ihs = iw * interleaved; | |||||
obs = oc * oh * ow; | |||||
ocs = oh * ow * interleaved; | |||||
ohs = ow * interleaved; | |||||
icfhfw = ic * fh * fw; | |||||
div_ohow = oh * ow; | |||||
div_ow = ow; | |||||
nhw = n * oh * ow; | |||||
// used for dp4a kernel, reduce usage of register file | |||||
oc32 = oc * 32; | |||||
} | |||||
}; | |||||
struct Conv2dInt4Param { | |||||
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow; | |||||
uint32_t ibs, ics, ihs; | |||||
uint32_t obs, ocs, ohs; | |||||
uint32_t icfhfw; | |||||
uint32_t nhw; | |||||
Uint32Fastdiv div_ohow; | |||||
Uint32Fastdiv div_ow; | |||||
Conv2dInt4Param( | |||||
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw, | |||||
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc, | |||||
uint32_t oh, uint32_t ow, uint32_t interleaved = 64) | |||||
: n(n), | |||||
ic(ic), | |||||
ih(ih), | |||||
iw(iw), | |||||
fh(fh), | |||||
fw(fw), | |||||
sh(sh), | |||||
sw(sw), | |||||
ph(ph), | |||||
pw(pw), | |||||
oc(oc), | |||||
oh(oh), | |||||
ow(ow) { | |||||
constexpr uint32_t size_bits = 4; | |||||
// all stride size in bytes | |||||
ibs = ic * ih * iw * size_bits / 8; | |||||
ics = ih * iw * interleaved * size_bits / 8; | |||||
ihs = iw * interleaved * size_bits / 8; | |||||
obs = oc * oh * ow * size_bits / 8; | |||||
ocs = oh * ow * interleaved * size_bits / 8; | |||||
ohs = ow * interleaved * size_bits / 8; | |||||
icfhfw = ic * fh * fw; | |||||
nhw = n * oh * ow; | |||||
div_ohow = oh * ow; | |||||
div_ow = ow; | |||||
} | |||||
}; | |||||
struct Conv2dConstantOffsetParam { | |||||
int32_t begin; | |||||
int32_t size; | |||||
int32_t max; | |||||
int32_t rewind; | |||||
}; | |||||
#define CONSTANT_BUFFER_SIZE 848 | |||||
struct Conv2dConstantOffset { | |||||
Conv2dConstantOffsetParam c_offset_param; | |||||
int c_offset[CONSTANT_BUFFER_SIZE]; | |||||
}; | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void reorder_imma_filter_bias( | |||||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, | |||||
uint32_t FH, uint32_t FW, cudaStream_t stream); | |||||
template <uint32_t size_bits, uint32_t interleaved> | |||||
void reorder_imma_filter_bias_fusion_zero_point( | |||||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, | |||||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, | |||||
cudaStream_t stream); | |||||
} // namespace ptx | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,341 @@ | |||||
/** | |||||
* \file dnn/src/cuda/conv_bias/ptx_implicit_gemm_uint4_int4_nchw64_imma.cpp | |||||
*/ | |||||
#include "./algo.h" | |||||
#include "src/common/conv_bias.h" | |||||
#include "src/cuda/conv_bias/ptx_helper.cuh" | |||||
#include "src/cuda/conv_bias/reduce_filter.cuh" | |||||
#include "src/cuda/ptx/uint4_int4/kern.cuh" | |||||
#include "src/cuda/ptx_loader.h" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace ptx; | |||||
namespace { | |||||
// all stride are in bytes | |||||
void compute_conv2d_offset( | |||||
size_t fh, size_t fw, size_t ics, size_t ihs, | |||||
Conv2dConstantOffset& constant_offset) { | |||||
constexpr int interleaved = 64; | |||||
constexpr int size_bits = 4; | |||||
constexpr int threablock_k = 128; | |||||
constexpr int inc_step = threablock_k / interleaved; | |||||
size_t i = 0; | |||||
int* s32 = &(constant_offset.c_offset[0]); | |||||
for (; i < inc_step; i++) { | |||||
int c = i / (fh * fw); | |||||
int khkw = i % (fh * fw); | |||||
int kh = khkw / fw; | |||||
int kw = khkw % fw; | |||||
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||||
s8[0] = kh; | |||||
s8[1] = kw; | |||||
s8[2] = -kh; | |||||
s8[3] = -kw; | |||||
} | |||||
for (; i < (inc_step + fh * fw * inc_step); i++) { | |||||
int c = i / (fh * fw); | |||||
int khkw = i % (fh * fw); | |||||
int kh = khkw / fw; | |||||
int kw = khkw % fw; | |||||
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||||
s8[0] = kh; | |||||
s8[1] = kw; | |||||
s8[2] = -kh; | |||||
s8[3] = -kw; | |||||
int i_ = i - inc_step; | |||||
c = i_ / (fh * fw); | |||||
khkw = i_ % (fh * fw); | |||||
kh = khkw / fw; | |||||
kw = khkw % fw; | |||||
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||||
} | |||||
} | |||||
}; // namespace | |||||
std::string ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::kernel_key( | |||||
const SizeArgs& args) const { | |||||
std::string kernel_key; | |||||
using NonlineMode = Param::NonlineMode; | |||||
auto&& param = args.opr->param(); | |||||
if (args.z_layout->ndim > 0) { | |||||
kernel_key = ssprintf( | |||||
"%s_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_%ux%u", | |||||
current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||||
} else { | |||||
kernel_key = ssprintf( | |||||
"%s_conv_bias_uint4_int4_imma8832_ldg16_%ux%u", | |||||
current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||||
} | |||||
megdnn_assert( | |||||
param.nonlineMode == NonlineMode::RELU || | |||||
param.nonlineMode == NonlineMode::IDENTITY); | |||||
kernel_key += "_relu"; | |||||
return kernel_key; | |||||
} | |||||
bool ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::is_available( | |||||
const SizeArgs& args) const { | |||||
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { | |||||
return false; | |||||
} | |||||
if (args.bias_layout->ndim <= 0) | |||||
return false; | |||||
using Param = param::ConvBias; | |||||
using Format = Param::Format; | |||||
using Sparse = Param::Sparse; | |||||
using Mode = Param::Mode; | |||||
using NonlineMode = Param::NonlineMode; | |||||
bool available = true; | |||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||||
return false; | |||||
if (param.format != Format::NCHW64) | |||||
return false; | |||||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||||
// TODO support group conv | |||||
available &= param.sparse == Sparse::DENSE; | |||||
// mode must be cross correlation | |||||
available &= param.mode == Mode::CROSS_CORRELATION; | |||||
// nonlineMode must be RELU or IDENTITY | |||||
available &= | |||||
(param.nonlineMode == NonlineMode::RELU || | |||||
param.nonlineMode == NonlineMode::IDENTITY); | |||||
// check data type | |||||
auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, | |||||
bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; | |||||
available &= | |||||
(src_dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||||
filter_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | |||||
dst_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||||
// TODO: support dialtion | |||||
available &= dh == 1 && dw == 1; | |||||
// ensure precomputed offsets are positive integers | |||||
available &= hi >= fh && wi >= fw; | |||||
// only support sm_86 or later, platform should have tensorcore int4 | |||||
// support | |||||
available &= | |||||
(is_compute_capability_equalto(8, 0) || | |||||
is_compute_capability_equalto(8, 6)); | |||||
// param buffer size is 4K, use 3K to store precomputed offset | |||||
size_t kMaxFilterPixels = CONSTANT_BUFFER_SIZE / (2 * 128 / 64) - 1; | |||||
available &= fh * fw <= kMaxFilterPixels; | |||||
return available; | |||||
} | |||||
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||||
get_workspace_in_bytes(const SizeArgs& args) const { | |||||
if (args.preprocessed_filter == nullptr) { | |||||
size_t OC = args.filter_layout->operator[](0), | |||||
IC = args.filter_layout->operator[](1) * 64, | |||||
FH = args.filter_layout->operator[](2), | |||||
FW = args.filter_layout->operator[](3); | |||||
size_t ws_size_reduce_filter = OC * sizeof(int32_t); | |||||
// for reduce filter | |||||
{ | |||||
size_t A = OC, B = IC * FH * FW / 8, C = 1; | |||||
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||||
} | |||||
return args.filter_layout->span().dist_byte() + | |||||
args.bias_layout->span().dist_byte() + ws_size_reduce_filter; | |||||
} | |||||
return 0_z; | |||||
} | |||||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec( | |||||
const ExecArgs& args) const { | |||||
using Format = Param::Format; | |||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
constexpr int interleaved = 64; | |||||
void* bias_ptr = nullptr; | |||||
void* filter_ptr = nullptr; | |||||
if (args.preprocessed_filter) { | |||||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | |||||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
} else { | |||||
// reorder filter and bias | |||||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | |||||
bias_ptr = reinterpret_cast<void*>( | |||||
args.workspace.raw_ptr + args.filter_layout->span().dist_byte()); | |||||
void* reduce_filter_ptr = reinterpret_cast<void*>( | |||||
args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + | |||||
args.bias_layout->span().dist_byte()); | |||||
reorder_filter_bias(args, reduce_filter_ptr, filter_ptr, bias_ptr); | |||||
} | |||||
uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh, u32_fw = fw, | |||||
u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw, u32_co = co, | |||||
u32_ho = ho, u32_wo = wo; | |||||
Conv2dInt4Param kern_param( | |||||
u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw, u32_sh, u32_sw, u32_ph, | |||||
u32_pw, u32_co, u32_ho, u32_wo, interleaved); | |||||
Conv2dConstantOffset kern_coffset; | |||||
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset); | |||||
// begin is not need | |||||
kern_coffset.c_offset_param.begin = param_buffer_start_address(); | |||||
kern_coffset.c_offset_param.size = 4 * (1 + fh * fw); | |||||
kern_coffset.c_offset_param.max = 4 * fh * fw; | |||||
kern_coffset.c_offset_param.rewind = 4 * (1 - fh * fw); | |||||
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||||
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||||
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale; | |||||
uint32_t src_zero_point = | |||||
(uint32_t)(args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||||
.zero_point); | |||||
uint32_t pk_src_zero_point = 0; | |||||
for (int i = 0; i < 8; i++) { | |||||
pk_src_zero_point <<= 4; | |||||
pk_src_zero_point |= (src_zero_point & 0xF); | |||||
} | |||||
float dst_zero_point = | |||||
(float)(args.dst_layout->dtype.param<dtype::Quantized4Asymm>().zero_point); | |||||
float alpha = src_scale * filter_scale / dst_scale, beta = 1.f; | |||||
unsigned int tx = m_threads, ty = 1; | |||||
unsigned int gridx = | |||||
div_ceil<unsigned int>(static_cast<unsigned int>(n * ho * wo), m_tile_nhw); | |||||
unsigned int gridy = | |||||
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc); | |||||
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr()); | |||||
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr()); | |||||
using NonlineMode = Param::NonlineMode; | |||||
auto kern_key = kernel_key(args); | |||||
auto&& kernel = PTXKernelLoader::instance().get_kernel(kern_key); | |||||
if (args.z_layout->ndim > 0) { | |||||
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr()); | |||||
auto z_param = args.z_layout->dtype.param<dtype::Quantized4Asymm>(); | |||||
int32_t z_zero_point = (int32_t)z_param.zero_point; | |||||
float z_scale = z_param.scale; | |||||
float gamma = z_scale / dst_scale; | |||||
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr, | |||||
&dst_ptr, &alpha, &beta, &gamma}; | |||||
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) + | |||||
sizeof(bias_ptr) + sizeof(z_ptr) + | |||||
sizeof(dst_ptr) + sizeof(alpha) + | |||||
sizeof(beta) + sizeof(gamma); | |||||
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point); | |||||
params.push_back(&pk_src_zero_point); | |||||
kern_coffset.c_offset_param.begin += sizeof(z_zero_point); | |||||
params.push_back(&z_zero_point); | |||||
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point); | |||||
params.push_back(&dst_zero_point); | |||||
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||||
params.push_back(&relu); | |||||
kern_coffset.c_offset_param.begin += sizeof(relu); | |||||
params.push_back(&kern_param); | |||||
kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||||
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param); | |||||
params.push_back(&kern_coffset); | |||||
dim3 grid(gridx, gridy, 1); | |||||
dim3 block(tx, ty, 1); | |||||
kernel(grid, block, stream, params.data()); | |||||
} else { | |||||
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, | |||||
&dst_ptr, &alpha, &beta}; | |||||
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) + | |||||
sizeof(bias_ptr) + sizeof(dst_ptr) + | |||||
sizeof(alpha) + sizeof(beta); | |||||
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point); | |||||
params.push_back(&pk_src_zero_point); | |||||
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point); | |||||
params.push_back(&dst_zero_point); | |||||
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||||
params.push_back(&relu); | |||||
kern_coffset.c_offset_param.begin += sizeof(relu); | |||||
params.push_back(&kern_param); | |||||
kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||||
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param); | |||||
params.push_back(&kern_coffset); | |||||
dim3 grid(gridx, gridy, 1); | |||||
dim3 block(tx, ty, 1); | |||||
kernel(grid, block, stream, params.data()); | |||||
} | |||||
after_kernel_launch(); | |||||
} | |||||
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||||
get_preprocess_workspace_in_bytes(const SizeArgs& args) const { | |||||
size_t OC = args.filter_layout->operator[](0), | |||||
IC = args.filter_layout->operator[](1) * 64, | |||||
FH = args.filter_layout->operator[](2), | |||||
FW = args.filter_layout->operator[](3); | |||||
size_t ws_size_reduce_filter = OC * sizeof(int32_t); | |||||
// for reduce filter | |||||
{ | |||||
size_t A = OC, B = IC * FH * FW / 8, C = 1; | |||||
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||||
} | |||||
return ws_size_reduce_filter; | |||||
} | |||||
SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||||
deduce_preprocessed_filter_layout(const SizeArgs& args) const { | |||||
return {args.filter_layout->collapse_contiguous(), | |||||
args.bias_layout->collapse_contiguous()}; | |||||
} | |||||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::reorder_filter_bias( | |||||
const ExecArgs& args, void* reduce_filter, void* reordered_filter, | |||||
void* reordered_bias) const { | |||||
using Format = Param::Format; | |||||
auto&& param = args.opr->param(); | |||||
auto&& fm = args.filter_meta; | |||||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||||
auto&& stream = cuda_stream(args.opr->handle()); | |||||
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||||
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale, | |||||
bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||||
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||||
float scaled_src_zero_point = | |||||
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point * | |||||
src_scale * filter_scale / dst_scale; | |||||
// NCHW64 reduce CHW64 | |||||
do_dispatch_reduce_with_scale_filter_4bit<true>( | |||||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), 1, co, | |||||
ci * fh * fw / 8, static_cast<int32_t*>(reduce_filter), stream); | |||||
reorder_imma_filter_bias_fusion_zero_point<4, 64>( | |||||
reinterpret_cast<int8_t*>(reordered_filter), | |||||
reinterpret_cast<float*>(reordered_bias), | |||||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), | |||||
args.bias_tensor->compatible_ptr<int32_t>(), bias_scale / dst_scale, | |||||
static_cast<int32_t*>(reduce_filter), scaled_src_zero_point, co, ci, fh, fw, | |||||
stream); | |||||
} | |||||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | |||||
const ExecArgs& args) const { | |||||
reorder_filter_bias( | |||||
args, args.workspace.raw_ptr, | |||||
args.preprocessed_filter->tensors[0].raw_ptr(), | |||||
args.preprocessed_filter->tensors[1].raw_ptr()); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* \file dnn/src/cuda/ptx_loader.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/ptx_loader.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
// ******************* PTXKernelLoader ********************* | |||||
const std::unordered_map<std::string, PTXKernelLoader::kernel> PTXKernelLoader::KERNEL_MAP = | |||||
{{"ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu}, | |||||
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x128_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu}, | |||||
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu}, | |||||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu}, | |||||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x128_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu}, | |||||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu", | |||||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu}}; | |||||
PTXKernelLoader& PTXKernelLoader::instance() { | |||||
static PTXKernelLoader ins; | |||||
return ins; | |||||
} | |||||
const PTXKernelLoader::kernel PTXKernelLoader::get_kernel( | |||||
const std::string& kernel_name) { | |||||
decltype(KERNEL_MAP.begin()) kernel_iter; | |||||
kernel_iter = KERNEL_MAP.find(kernel_name); | |||||
megdnn_throw_if( | |||||
kernel_iter == KERNEL_MAP.end(), megdnn_error, | |||||
ssprintf("kernel name %s not found in KERNEL_MAP", kernel_name.c_str()) | |||||
.c_str()); | |||||
return kernel_iter->second; | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/src/cuda/ptx_loader.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <mutex> | |||||
#include <unordered_map> | |||||
#include "src/cuda/ptx/uint4_int4/kern.cuh" | |||||
#include "src/cuda/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class PTXKernelLoader { | |||||
private: | |||||
PTXKernelLoader() = default; | |||||
using kernel = std::function<void(const dim3, const dim3, cudaStream_t, void**)>; | |||||
public: | |||||
PTXKernelLoader(const PTXKernelLoader&) = delete; | |||||
const PTXKernelLoader& operator=(const PTXKernelLoader&) = delete; | |||||
static PTXKernelLoader& instance(); | |||||
const kernel get_kernel(const std::string& kernel_name); | |||||
static const std::unordered_map<std::string, kernel> KERNEL_MAP; | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |