@@ -92,6 +92,12 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
fill_dwconv_algos(); | |||
all_algos.push_back(&int8_chwn4_dotprod); | |||
all_algos.push_back(&fallback_nchw_qs8); | |||
fill_ptx_algos(); | |||
for (auto&& algo : algo_ptx_conv2d_u4_s4) { | |||
all_algos.push_back(&algo); | |||
} | |||
for (size_t i = all_algo_size; i < all_algos.size(); ++i) { | |||
non_cudnn_algos.push_back(all_algos[i]); | |||
} | |||
@@ -364,6 +370,15 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { | |||
int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); | |||
} | |||
void ConvBiasForwardImpl::AlgoPack::fill_ptx_algos() { | |||
algo_ptx_conv2d_u4_s4.emplace_back( | |||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 256, 256}); | |||
algo_ptx_conv2d_u4_s4.emplace_back( | |||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{128, 128, 128}); | |||
algo_ptx_conv2d_u4_s4.emplace_back( | |||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm{256, 64, 128}); | |||
} | |||
ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( | |||
cudnnConvolutionFwdAlgo_t algo) { | |||
for (auto&& i : cudnn_convs) { | |||
@@ -78,6 +78,7 @@ public: | |||
CUDA_SIMPLE_INT1, | |||
CUDA_CUDNN_CONV_V8, | |||
CUDA_CUDNN_CONVBIAS_V8, | |||
CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4, | |||
}; | |||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | |||
@@ -1203,6 +1204,45 @@ private: | |||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
}; | |||
class ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm final | |||
: public AlgoBase { | |||
public: | |||
AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm( | |||
unsigned int tile_nhw, unsigned int tile_oc, unsigned int threads) | |||
: m_tile_nhw{tile_nhw}, m_tile_oc{tile_oc}, m_threads{threads} { | |||
m_name = ConvBias::algo_name<ConvBias::DirectParam>( | |||
ssprintf( | |||
"PTX_UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%uX%u_%u", m_tile_nhw, | |||
m_tile_oc, m_threads), | |||
ConvBias::DirectParam{}); | |||
} | |||
bool is_available(const SizeArgs& args) const override; | |||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
void exec(const ExecArgs& args) const override; | |||
const char* name() const override { return m_name.c_str(); } | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; | |||
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | |||
const SizeArgs& args) const override; | |||
void exec_preprocess(const ExecArgs& args) const override; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PTX_NCHW64_IMMA_UINT4_INT4) | |||
std::string param() const override { | |||
std::string ret; | |||
serialize_write_pod(m_tile_nhw, ret); | |||
serialize_write_pod(m_tile_oc, ret); | |||
serialize_write_pod(m_threads, ret); | |||
return ret; | |||
} | |||
private: | |||
std::string kernel_key(const SizeArgs& args) const; | |||
unsigned int m_tile_nhw, m_tile_oc, m_threads; | |||
std::string m_name; | |||
void reorder_filter_bias( | |||
const ExecArgs& args, void* reduce_filter, void* reordered_filter, | |||
void* reordered_bias) const; | |||
}; | |||
class ConvBiasForwardImpl::AlgoPack : NonCopyableObj { | |||
private: | |||
AlgoBase::Mapper m_all_algos_map; | |||
@@ -1251,6 +1291,7 @@ public: | |||
AlgoCUDNNConvV8 cudnn_conv_v8; | |||
AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8; | |||
#endif | |||
std::vector<AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm> algo_ptx_conv2d_u4_s4; | |||
AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
@@ -1265,6 +1306,7 @@ private: | |||
void fill_cudnn_algos(); | |||
void fill_dp4a_algos(); | |||
void fill_dwconv_algos(); | |||
void fill_ptx_algos(); | |||
}; | |||
} // namespace cuda | |||
@@ -72,6 +72,7 @@ public: | |||
class AlgoCUDNNConvV8; | |||
class AlgoCUDNNConvBiasActivationV8; | |||
#endif | |||
class AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm; | |||
class AlgoPack; | |||
@@ -0,0 +1,156 @@ | |||
#include "src/cuda/conv_bias/ptx_helper.cuh" | |||
#include "src/cuda/integer_subbyte_utils.cuh" | |||
#include "src/cuda/query_blocksize.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace ptx; | |||
namespace { | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
__device__ __forceinline__ void reorder_imma_filter_func( | |||
int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, | |||
uint32_t FW, uint32_t lane) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
uint32_t elements = lane * elements_per_lane; | |||
uint32_t row = elements / (IC * FH * FW); | |||
uint32_t col = elements - row * IC * FH * FW; | |||
uint32_t sec = row / 4; | |||
uint32_t res = col & (interleaved - 1); | |||
uint32_t sec_sec = row & 3; | |||
uint32_t sec_res = (row & 15) / 4; | |||
uint32_t crosswise_offset = ((sec_sec >> 1) * 2 * interleaved) + | |||
(((sec_sec & 1) ^ (sec_res >> 1)) * interleaved); | |||
uint32_t residue_offset = | |||
((res / elements_per_lane) ^ (sec_res & 1)) * elements_per_lane; | |||
uint32_t dst_offset = | |||
(sec / 2) * 8 * FH * FW * IC + (col / interleaved) * (8 * interleaved) + | |||
(sec & 1) * (4 * interleaved) + crosswise_offset + residue_offset; | |||
static constexpr uint32_t instruction_shape_col = 8; | |||
// 4 threads per Quad | |||
static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; | |||
// 4 threads per Quad | |||
static constexpr uint32_t reordered_elements_per_thread = interleaved / 4; | |||
uint32_t elem_in_interleaved = row % interleaved; | |||
uint32_t elem_in_interleaved_pack = elem_in_interleaved / elements_per_thread; | |||
int elem_new = (row / interleaved * interleaved + | |||
elem_in_interleaved_pack % 4 * reordered_elements_per_thread + | |||
elem_in_interleaved_pack / 4 * elements_per_thread + | |||
elem_in_interleaved % elements_per_thread) * | |||
(IC * FH * FW) + | |||
col; | |||
*(reinterpret_cast<int4*>(dst + (dst_offset * size_bits / 8))) = | |||
*(reinterpret_cast<const int4*>(src + (elem_new * size_bits / 8))); | |||
} | |||
template <uint32_t interleaved> | |||
__device__ __forceinline__ void reorder_imma_bias_func( | |||
float* __restrict__ dst, float src_value, uint32_t OC, uint32_t lane) { | |||
dst[lane] = src_value; | |||
} | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
__global__ void reorder_imma_filter_bias_kernel( | |||
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias, | |||
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias, | |||
float bias_scale, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane; | |||
const uint32_t size2 = OC; | |||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (lane < size1) { | |||
reorder_imma_filter_func<size_bits, interleaved>( | |||
dst_filter, src_filter, OC, IC, FH, FW, lane); | |||
} else if (lane < size1 + size2) { | |||
lane = lane - size1; | |||
float src_bias_value = src_bias[lane] * bias_scale; | |||
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane); | |||
} | |||
} | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
__global__ void reorder_imma_filter_bias_fusion_zero_point_kernel( | |||
int8_t* __restrict__ dst_filter, float* __restrict__ dst_bias, | |||
const int8_t* __restrict__ src_filter, const int32_t* __restrict__ src_bias, | |||
float bias_scale, const int32_t* reduce_filter, float zero_point, uint32_t OC, | |||
uint32_t IC, uint32_t FH, uint32_t FW) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
const uint32_t size1 = OC * IC * FH * FW / elements_per_lane; | |||
const uint32_t size2 = OC; | |||
uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x; | |||
if (lane < size1) { | |||
reorder_imma_filter_func<size_bits, interleaved>( | |||
dst_filter, src_filter, OC, IC, FH, FW, lane); | |||
} else if (lane < size1 + size2) { | |||
lane = lane - size1; | |||
// fusion bias and zero_point | |||
// zero_point = zero_point * src_scale * filter_scale | |||
float src_bias_value = | |||
src_bias[lane] * bias_scale - reduce_filter[lane] * zero_point; | |||
reorder_imma_bias_func<interleaved>(dst_bias, src_bias_value, OC, lane); | |||
} | |||
} | |||
} // namespace | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void megdnn::cuda::ptx::reorder_imma_filter_bias( | |||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, | |||
uint32_t FH, uint32_t FW, cudaStream_t stream) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
reorder_imma_filter_bias_kernel<size_bits, interleaved>)); | |||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC; | |||
nr_threads = std::min(nr_threads, vthreads); | |||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
reorder_imma_filter_bias_kernel<size_bits, interleaved> | |||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||
dst_filter, dst_bias, src_filter, src_bias, bias_scale, OC, IC, FH, | |||
FW); | |||
after_kernel_launch(); | |||
} | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point( | |||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, | |||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, | |||
cudaStream_t stream) { | |||
static constexpr uint32_t elements_per_lane = 128 / size_bits; | |||
uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast<const void*>( | |||
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved>)); | |||
uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane) + OC; | |||
nr_threads = std::min(nr_threads, vthreads); | |||
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); | |||
reorder_imma_filter_bias_fusion_zero_point_kernel<size_bits, interleaved> | |||
<<<nr_blocks, nr_threads, 0, stream>>>( | |||
dst_filter, dst_bias, src_filter, src_bias, bias_scale, | |||
reduce_filter, zero_point, OC, IC, FH, FW); | |||
after_kernel_launch(); | |||
} | |||
#define INST(_size_bits, _interleaved) \ | |||
template void \ | |||
megdnn::cuda::ptx::reorder_imma_filter_bias<_size_bits, _interleaved>( \ | |||
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \ | |||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, \ | |||
uint32_t FH, uint32_t FW, cudaStream_t stream); | |||
INST(8, 32) | |||
INST(4, 64) | |||
#undef INST | |||
#define INST(_size_bits, _interleaved) \ | |||
template void megdnn::cuda::ptx::reorder_imma_filter_bias_fusion_zero_point< \ | |||
_size_bits, _interleaved>( \ | |||
int8_t * dst_filter, float* dst_bias, const int8_t* src_filter, \ | |||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, \ | |||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, \ | |||
cudaStream_t stream); | |||
INST(4, 64) | |||
#undef INST | |||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,120 @@ | |||
#pragma once | |||
#include "src/cuda/int_fastdiv.cuh" | |||
#include "src/cuda/utils.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace ptx { | |||
struct Conv2dInt8Param { | |||
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow; | |||
uint32_t ibs, ics, ihs; | |||
uint32_t obs, ocs, ohs; | |||
uint32_t icfhfw; | |||
uint32_t nhw; | |||
uint32_t oc32; | |||
Uint32Fastdiv div_ohow; | |||
Uint32Fastdiv div_ow; | |||
Conv2dInt8Param( | |||
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw, | |||
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc, | |||
uint32_t oh, uint32_t ow, uint32_t interleaved) | |||
: n(n), | |||
ic(ic), | |||
ih(ih), | |||
iw(iw), | |||
fh(fh), | |||
fw(fw), | |||
sh(sh), | |||
sw(sw), | |||
ph(ph), | |||
pw(pw), | |||
oc(oc), | |||
oh(oh), | |||
ow(ow) { | |||
ibs = ic * ih * iw; | |||
ics = ih * iw * interleaved; | |||
ihs = iw * interleaved; | |||
obs = oc * oh * ow; | |||
ocs = oh * ow * interleaved; | |||
ohs = ow * interleaved; | |||
icfhfw = ic * fh * fw; | |||
div_ohow = oh * ow; | |||
div_ow = ow; | |||
nhw = n * oh * ow; | |||
// used for dp4a kernel, reduce usage of register file | |||
oc32 = oc * 32; | |||
} | |||
}; | |||
struct Conv2dInt4Param { | |||
uint32_t n, ic, ih, iw, fh, fw, sh, sw, ph, pw, oc, oh, ow; | |||
uint32_t ibs, ics, ihs; | |||
uint32_t obs, ocs, ohs; | |||
uint32_t icfhfw; | |||
uint32_t nhw; | |||
Uint32Fastdiv div_ohow; | |||
Uint32Fastdiv div_ow; | |||
Conv2dInt4Param( | |||
uint32_t n, uint32_t ic, uint32_t ih, uint32_t iw, uint32_t fh, uint32_t fw, | |||
uint32_t sh, uint32_t sw, uint32_t ph, uint32_t pw, uint32_t oc, | |||
uint32_t oh, uint32_t ow, uint32_t interleaved = 64) | |||
: n(n), | |||
ic(ic), | |||
ih(ih), | |||
iw(iw), | |||
fh(fh), | |||
fw(fw), | |||
sh(sh), | |||
sw(sw), | |||
ph(ph), | |||
pw(pw), | |||
oc(oc), | |||
oh(oh), | |||
ow(ow) { | |||
constexpr uint32_t size_bits = 4; | |||
// all stride size in bytes | |||
ibs = ic * ih * iw * size_bits / 8; | |||
ics = ih * iw * interleaved * size_bits / 8; | |||
ihs = iw * interleaved * size_bits / 8; | |||
obs = oc * oh * ow * size_bits / 8; | |||
ocs = oh * ow * interleaved * size_bits / 8; | |||
ohs = ow * interleaved * size_bits / 8; | |||
icfhfw = ic * fh * fw; | |||
nhw = n * oh * ow; | |||
div_ohow = oh * ow; | |||
div_ow = ow; | |||
} | |||
}; | |||
struct Conv2dConstantOffsetParam { | |||
int32_t begin; | |||
int32_t size; | |||
int32_t max; | |||
int32_t rewind; | |||
}; | |||
#define CONSTANT_BUFFER_SIZE 848 | |||
struct Conv2dConstantOffset { | |||
Conv2dConstantOffsetParam c_offset_param; | |||
int c_offset[CONSTANT_BUFFER_SIZE]; | |||
}; | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void reorder_imma_filter_bias( | |||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||
const int32_t* src_bias, float bias_scale, uint32_t OC, uint32_t IC, | |||
uint32_t FH, uint32_t FW, cudaStream_t stream); | |||
template <uint32_t size_bits, uint32_t interleaved> | |||
void reorder_imma_filter_bias_fusion_zero_point( | |||
int8_t* dst_filter, float* dst_bias, const int8_t* src_filter, | |||
const int32_t* src_bias, float bias_scale, const int32_t* reduce_filter, | |||
float zero_point, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, | |||
cudaStream_t stream); | |||
} // namespace ptx | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cuda.doxygen |
@@ -0,0 +1,341 @@ | |||
/** | |||
* \file dnn/src/cuda/conv_bias/ptx_implicit_gemm_uint4_int4_nchw64_imma.cpp | |||
*/ | |||
#include "./algo.h" | |||
#include "src/common/conv_bias.h" | |||
#include "src/cuda/conv_bias/ptx_helper.cuh" | |||
#include "src/cuda/conv_bias/reduce_filter.cuh" | |||
#include "src/cuda/ptx/uint4_int4/kern.cuh" | |||
#include "src/cuda/ptx_loader.h" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace ptx; | |||
namespace { | |||
// all stride are in bytes | |||
void compute_conv2d_offset( | |||
size_t fh, size_t fw, size_t ics, size_t ihs, | |||
Conv2dConstantOffset& constant_offset) { | |||
constexpr int interleaved = 64; | |||
constexpr int size_bits = 4; | |||
constexpr int threablock_k = 128; | |||
constexpr int inc_step = threablock_k / interleaved; | |||
size_t i = 0; | |||
int* s32 = &(constant_offset.c_offset[0]); | |||
for (; i < inc_step; i++) { | |||
int c = i / (fh * fw); | |||
int khkw = i % (fh * fw); | |||
int kh = khkw / fw; | |||
int kw = khkw % fw; | |||
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||
s8[0] = kh; | |||
s8[1] = kw; | |||
s8[2] = -kh; | |||
s8[3] = -kw; | |||
} | |||
for (; i < (inc_step + fh * fw * inc_step); i++) { | |||
int c = i / (fh * fw); | |||
int khkw = i % (fh * fw); | |||
int kh = khkw / fw; | |||
int kw = khkw % fw; | |||
s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||
int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1])); | |||
s8[0] = kh; | |||
s8[1] = kw; | |||
s8[2] = -kh; | |||
s8[3] = -kw; | |||
int i_ = i - inc_step; | |||
c = i_ / (fh * fw); | |||
khkw = i_ % (fh * fw); | |||
kh = khkw / fw; | |||
kw = khkw % fw; | |||
s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8; | |||
} | |||
} | |||
}; // namespace | |||
std::string ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::kernel_key( | |||
const SizeArgs& args) const { | |||
std::string kernel_key; | |||
using NonlineMode = Param::NonlineMode; | |||
auto&& param = args.opr->param(); | |||
if (args.z_layout->ndim > 0) { | |||
kernel_key = ssprintf( | |||
"%s_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_%ux%u", | |||
current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||
} else { | |||
kernel_key = ssprintf( | |||
"%s_conv_bias_uint4_int4_imma8832_ldg16_%ux%u", | |||
current_device_arch_name(), m_tile_nhw, m_tile_oc); | |||
} | |||
megdnn_assert( | |||
param.nonlineMode == NonlineMode::RELU || | |||
param.nonlineMode == NonlineMode::IDENTITY); | |||
kernel_key += "_relu"; | |||
return kernel_key; | |||
} | |||
bool ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::is_available( | |||
const SizeArgs& args) const { | |||
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { | |||
return false; | |||
} | |||
if (args.bias_layout->ndim <= 0) | |||
return false; | |||
using Param = param::ConvBias; | |||
using Format = Param::Format; | |||
using Sparse = Param::Sparse; | |||
using Mode = Param::Mode; | |||
using NonlineMode = Param::NonlineMode; | |||
bool available = true; | |||
auto&& param = args.opr->param(); | |||
auto&& fm = args.filter_meta; | |||
if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) | |||
return false; | |||
if (param.format != Format::NCHW64) | |||
return false; | |||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||
// TODO support group conv | |||
available &= param.sparse == Sparse::DENSE; | |||
// mode must be cross correlation | |||
available &= param.mode == Mode::CROSS_CORRELATION; | |||
// nonlineMode must be RELU or IDENTITY | |||
available &= | |||
(param.nonlineMode == NonlineMode::RELU || | |||
param.nonlineMode == NonlineMode::IDENTITY); | |||
// check data type | |||
auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, | |||
bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; | |||
available &= | |||
(src_dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||
filter_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
bias_dtype.enumv() == DTypeEnum::QuantizedS32 && | |||
dst_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||
// TODO: support dialtion | |||
available &= dh == 1 && dw == 1; | |||
// ensure precomputed offsets are positive integers | |||
available &= hi >= fh && wi >= fw; | |||
// only support sm_86 or later, platform should have tensorcore int4 | |||
// support | |||
available &= | |||
(is_compute_capability_equalto(8, 0) || | |||
is_compute_capability_equalto(8, 6)); | |||
// param buffer size is 4K, use 3K to store precomputed offset | |||
size_t kMaxFilterPixels = CONSTANT_BUFFER_SIZE / (2 * 128 / 64) - 1; | |||
available &= fh * fw <= kMaxFilterPixels; | |||
return available; | |||
} | |||
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||
get_workspace_in_bytes(const SizeArgs& args) const { | |||
if (args.preprocessed_filter == nullptr) { | |||
size_t OC = args.filter_layout->operator[](0), | |||
IC = args.filter_layout->operator[](1) * 64, | |||
FH = args.filter_layout->operator[](2), | |||
FW = args.filter_layout->operator[](3); | |||
size_t ws_size_reduce_filter = OC * sizeof(int32_t); | |||
// for reduce filter | |||
{ | |||
size_t A = OC, B = IC * FH * FW / 8, C = 1; | |||
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||
} | |||
return args.filter_layout->span().dist_byte() + | |||
args.bias_layout->span().dist_byte() + ws_size_reduce_filter; | |||
} | |||
return 0_z; | |||
} | |||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec( | |||
const ExecArgs& args) const { | |||
using Format = Param::Format; | |||
auto&& param = args.opr->param(); | |||
auto&& fm = args.filter_meta; | |||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||
auto&& stream = cuda_stream(args.opr->handle()); | |||
constexpr int interleaved = 64; | |||
void* bias_ptr = nullptr; | |||
void* filter_ptr = nullptr; | |||
if (args.preprocessed_filter) { | |||
megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | |||
filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||
bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||
} else { | |||
// reorder filter and bias | |||
filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | |||
bias_ptr = reinterpret_cast<void*>( | |||
args.workspace.raw_ptr + args.filter_layout->span().dist_byte()); | |||
void* reduce_filter_ptr = reinterpret_cast<void*>( | |||
args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + | |||
args.bias_layout->span().dist_byte()); | |||
reorder_filter_bias(args, reduce_filter_ptr, filter_ptr, bias_ptr); | |||
} | |||
uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh, u32_fw = fw, | |||
u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw, u32_co = co, | |||
u32_ho = ho, u32_wo = wo; | |||
Conv2dInt4Param kern_param( | |||
u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw, u32_sh, u32_sw, u32_ph, | |||
u32_pw, u32_co, u32_ho, u32_wo, interleaved); | |||
Conv2dConstantOffset kern_coffset; | |||
compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset); | |||
// begin is not need | |||
kern_coffset.c_offset_param.begin = param_buffer_start_address(); | |||
kern_coffset.c_offset_param.size = 4 * (1 + fh * fw); | |||
kern_coffset.c_offset_param.max = 4 * fh * fw; | |||
kern_coffset.c_offset_param.rewind = 4 * (1 - fh * fw); | |||
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale; | |||
uint32_t src_zero_point = | |||
(uint32_t)(args.src_layout->dtype.param<dtype::Quantized4Asymm>() | |||
.zero_point); | |||
uint32_t pk_src_zero_point = 0; | |||
for (int i = 0; i < 8; i++) { | |||
pk_src_zero_point <<= 4; | |||
pk_src_zero_point |= (src_zero_point & 0xF); | |||
} | |||
float dst_zero_point = | |||
(float)(args.dst_layout->dtype.param<dtype::Quantized4Asymm>().zero_point); | |||
float alpha = src_scale * filter_scale / dst_scale, beta = 1.f; | |||
unsigned int tx = m_threads, ty = 1; | |||
unsigned int gridx = | |||
div_ceil<unsigned int>(static_cast<unsigned int>(n * ho * wo), m_tile_nhw); | |||
unsigned int gridy = | |||
div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc); | |||
void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr()); | |||
void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr()); | |||
using NonlineMode = Param::NonlineMode; | |||
auto kern_key = kernel_key(args); | |||
auto&& kernel = PTXKernelLoader::instance().get_kernel(kern_key); | |||
if (args.z_layout->ndim > 0) { | |||
void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr()); | |||
auto z_param = args.z_layout->dtype.param<dtype::Quantized4Asymm>(); | |||
int32_t z_zero_point = (int32_t)z_param.zero_point; | |||
float z_scale = z_param.scale; | |||
float gamma = z_scale / dst_scale; | |||
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr, | |||
&dst_ptr, &alpha, &beta, &gamma}; | |||
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) + | |||
sizeof(bias_ptr) + sizeof(z_ptr) + | |||
sizeof(dst_ptr) + sizeof(alpha) + | |||
sizeof(beta) + sizeof(gamma); | |||
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point); | |||
params.push_back(&pk_src_zero_point); | |||
kern_coffset.c_offset_param.begin += sizeof(z_zero_point); | |||
params.push_back(&z_zero_point); | |||
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point); | |||
params.push_back(&dst_zero_point); | |||
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||
params.push_back(&relu); | |||
kern_coffset.c_offset_param.begin += sizeof(relu); | |||
params.push_back(&kern_param); | |||
kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param); | |||
params.push_back(&kern_coffset); | |||
dim3 grid(gridx, gridy, 1); | |||
dim3 block(tx, ty, 1); | |||
kernel(grid, block, stream, params.data()); | |||
} else { | |||
std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, | |||
&dst_ptr, &alpha, &beta}; | |||
kern_coffset.c_offset_param.begin += sizeof(src_ptr) + sizeof(filter_ptr) + | |||
sizeof(bias_ptr) + sizeof(dst_ptr) + | |||
sizeof(alpha) + sizeof(beta); | |||
kern_coffset.c_offset_param.begin += sizeof(pk_src_zero_point); | |||
params.push_back(&pk_src_zero_point); | |||
kern_coffset.c_offset_param.begin += sizeof(dst_zero_point); | |||
params.push_back(&dst_zero_point); | |||
uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0; | |||
params.push_back(&relu); | |||
kern_coffset.c_offset_param.begin += sizeof(relu); | |||
params.push_back(&kern_param); | |||
kern_coffset.c_offset_param.begin += sizeof(kern_param); | |||
kern_coffset.c_offset_param.begin += sizeof(kern_coffset.c_offset_param); | |||
params.push_back(&kern_coffset); | |||
dim3 grid(gridx, gridy, 1); | |||
dim3 block(tx, ty, 1); | |||
kernel(grid, block, stream, params.data()); | |||
} | |||
after_kernel_launch(); | |||
} | |||
size_t ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||
get_preprocess_workspace_in_bytes(const SizeArgs& args) const { | |||
size_t OC = args.filter_layout->operator[](0), | |||
IC = args.filter_layout->operator[](1) * 64, | |||
FH = args.filter_layout->operator[](2), | |||
FW = args.filter_layout->operator[](3); | |||
size_t ws_size_reduce_filter = OC * sizeof(int32_t); | |||
// for reduce filter | |||
{ | |||
size_t A = OC, B = IC * FH * FW / 8, C = 1; | |||
ws_size_reduce_filter += do_dispatch_reduce_workspace_in_bytes(A, B, C); | |||
} | |||
return ws_size_reduce_filter; | |||
} | |||
SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm:: | |||
deduce_preprocessed_filter_layout(const SizeArgs& args) const { | |||
return {args.filter_layout->collapse_contiguous(), | |||
args.bias_layout->collapse_contiguous()}; | |||
} | |||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::reorder_filter_bias( | |||
const ExecArgs& args, void* reduce_filter, void* reordered_filter, | |||
void* reordered_bias) const { | |||
using Format = Param::Format; | |||
auto&& param = args.opr->param(); | |||
auto&& fm = args.filter_meta; | |||
UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | |||
auto&& stream = cuda_stream(args.opr->handle()); | |||
float src_scale = args.src_layout->dtype.param<dtype::Quantized4Asymm>().scale, | |||
filter_scale = args.filter_layout->dtype.param<dtype::QuantizedS4>().scale, | |||
bias_scale = args.bias_layout->dtype.param<dtype::QuantizedS32>().scale, | |||
dst_scale = args.dst_layout->dtype.param<dtype::Quantized4Asymm>().scale; | |||
float scaled_src_zero_point = | |||
args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point * | |||
src_scale * filter_scale / dst_scale; | |||
// NCHW64 reduce CHW64 | |||
do_dispatch_reduce_with_scale_filter_4bit<true>( | |||
reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), 1, co, | |||
ci * fh * fw / 8, static_cast<int32_t*>(reduce_filter), stream); | |||
reorder_imma_filter_bias_fusion_zero_point<4, 64>( | |||
reinterpret_cast<int8_t*>(reordered_filter), | |||
reinterpret_cast<float*>(reordered_bias), | |||
reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), | |||
args.bias_tensor->compatible_ptr<int32_t>(), bias_scale / dst_scale, | |||
static_cast<int32_t*>(reduce_filter), scaled_src_zero_point, co, ci, fh, fw, | |||
stream); | |||
} | |||
void ConvBiasForwardImpl::AlgoPTXUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | |||
const ExecArgs& args) const { | |||
reorder_filter_bias( | |||
args, args.workspace.raw_ptr, | |||
args.preprocessed_filter->tensors[0].raw_ptr(), | |||
args.preprocessed_filter->tensors[1].raw_ptr()); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,49 @@ | |||
/** | |||
* \file dnn/src/cuda/ptx_loader.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/ptx_loader.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
// ******************* PTXKernelLoader ********************* | |||
const std::unordered_map<std::string, PTXKernelLoader::kernel> PTXKernelLoader::KERNEL_MAP = | |||
{{"ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_256x64_relu}, | |||
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x128_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldgsts16_128x128_relu}, | |||
{"ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_imma8832_ldg16_128x256_relu}, | |||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_256x64_relu}, | |||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x128_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldgsts16_128x128_relu}, | |||
{"ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu", | |||
ptx::run_ampere_conv_bias_uint4_int4_fuse_z_imma8832_ldg16_128x256_relu}}; | |||
PTXKernelLoader& PTXKernelLoader::instance() { | |||
static PTXKernelLoader ins; | |||
return ins; | |||
} | |||
const PTXKernelLoader::kernel PTXKernelLoader::get_kernel( | |||
const std::string& kernel_name) { | |||
decltype(KERNEL_MAP.begin()) kernel_iter; | |||
kernel_iter = KERNEL_MAP.find(kernel_name); | |||
megdnn_throw_if( | |||
kernel_iter == KERNEL_MAP.end(), megdnn_error, | |||
ssprintf("kernel name %s not found in KERNEL_MAP", kernel_name.c_str()) | |||
.c_str()); | |||
return kernel_iter->second; | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file dnn/src/cuda/ptx_loader.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <mutex> | |||
#include <unordered_map> | |||
#include "src/cuda/ptx/uint4_int4/kern.cuh" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class PTXKernelLoader { | |||
private: | |||
PTXKernelLoader() = default; | |||
using kernel = std::function<void(const dim3, const dim3, cudaStream_t, void**)>; | |||
public: | |||
PTXKernelLoader(const PTXKernelLoader&) = delete; | |||
const PTXKernelLoader& operator=(const PTXKernelLoader&) = delete; | |||
static PTXKernelLoader& instance(); | |||
const kernel get_kernel(const std::string& kernel_name); | |||
static const std::unordered_map<std::string, kernel> KERNEL_MAP; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |