@@ -6,6 +6,7 @@ dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary | |||||
dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary | dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary | ||||
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | dnn/src/cuda/sass/prebuilt/map_defs.cpp binary | ||||
dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary | dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary | ||||
dnn/src/cuda/elemwise_multi_type/kimpl/* binary | |||||
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text | tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text | ||||
imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text | imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text | ||||
ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text | ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text | ||||
@@ -382,6 +382,9 @@ struct TensorLayout : public TensorShape { | |||||
//! get lowest and highest offset reachable from this layout | //! get lowest and highest offset reachable from this layout | ||||
Span span() const; | Span span() const; | ||||
//! total number of access bytes | |||||
size_t access_bytes() const; | |||||
}; | }; | ||||
/** | /** | ||||
@@ -308,6 +308,8 @@ class dt_qulowbit { | |||||
return _; | return _; | ||||
} | } | ||||
MEGDNN_DEVICE uint8_t as_storage() const { return _; } | |||||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} | MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} | ||||
#ifdef MEGDNN_CC_HOST | #ifdef MEGDNN_CC_HOST | ||||
explicit operator uint8_t() { return _; } | explicit operator uint8_t() { return _; } | ||||
@@ -332,6 +334,8 @@ class dt_qlowbit { | |||||
return _; | return _; | ||||
} | } | ||||
MEGDNN_DEVICE int8_t as_storage() const { return _; } | |||||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} | MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} | ||||
#ifdef MEGDNN_CC_HOST | #ifdef MEGDNN_CC_HOST | ||||
explicit operator int8_t() { return _; } | explicit operator int8_t() { return _; } | ||||
@@ -1,6 +1,10 @@ | |||||
# As cuda currently do not support quint8, so we just ignore it. | # As cuda currently do not support quint8, so we just ignore it. | ||||
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | ||||
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32')] | |||||
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32'), | |||||
('dt_qint4', 'dt_qint32'), ('dt_quint4', 'dt_qint32')] | |||||
SUPPORT_DTYPES_Q4 = [('dt_qint4', 'dt_qint4'), ('dt_quint4', 'dt_quint4')] | |||||
SUPPORT_QINT32_DTYPES_Q4 = [('dt_qint32', 'dt_qint4'), ('dt_qint32', 'dt_quint4')] | |||||
MODES = { | MODES = { | ||||
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | 1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | ||||
@@ -16,6 +20,15 @@ MODES = { | |||||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | 3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | ||||
} | } | ||||
QINT4_MODES = { | |||||
1: ['RELU', 'ABS', 'NEGATE', 'CEIL', 'FLOOR', 'SIGMOID', | |||||
'TANH', 'FAST_TANH', 'ROUND', 'H_SWISH'], | |||||
2: ['ADD', 'MAX', 'MIN', 'MUL', 'SUB', 'SWITCH_GT0', | |||||
'LT', 'LEQ', 'EQ', 'FUSE_ADD_RELU', 'FUSE_ADD_TANH', | |||||
'FUSE_ADD_SIGMOID', 'FUSE_ADD_H_SWISH'], | |||||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||||
} | |||||
QINT32_MODES = { | QINT32_MODES = { | ||||
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | 1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | ||||
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | 2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | ||||
@@ -212,7 +212,7 @@ TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | |||||
TensorLayout::TensorLayout(const TensorShape& shape, | TensorLayout::TensorLayout(const TensorShape& shape, | ||||
const std::vector<ptrdiff_t>& stride, DType dtype) | const std::vector<ptrdiff_t>& stride, DType dtype) | ||||
: TensorLayout(shape, stride, dtype, DefaultTensorFormat::make()) {} | |||||
: TensorLayout(shape, stride, dtype, Format(dtype)) {} | |||||
TensorLayout::TensorLayout(const TensorShape& shape, | TensorLayout::TensorLayout(const TensorShape& shape, | ||||
const std::vector<ptrdiff_t>& stride, DType dtype, | const std::vector<ptrdiff_t>& stride, DType dtype, | ||||
@@ -412,6 +412,27 @@ TensorLayout::Span TensorLayout::span() const { | |||||
return format.impl()->span_spec(*this); | return format.impl()->span_spec(*this); | ||||
} | } | ||||
size_t TensorLayout::access_bytes() const { | |||||
megdnn_assert(dtype.valid()); | |||||
auto contig = collapse_contiguous(); | |||||
size_t ret = 0; | |||||
if (dtype.is_low_bit()) { | |||||
ret = 1; | |||||
int align_size_in_elements = 8 / dtype.low_bit(); | |||||
for (size_t i = 0; i < contig.ndim; ++i) { | |||||
if (contig.stride[i] == 1) { | |||||
ret *= round_up((int)contig.shape[i], align_size_in_elements); | |||||
} else { | |||||
ret *= contig.shape[i]; | |||||
} | |||||
} | |||||
ret /= align_size_in_elements; | |||||
} else { | |||||
ret = dtype.size(total_nr_elems()); | |||||
} | |||||
return ret; | |||||
} | |||||
TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | ||||
megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, | megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, | ||||
"broadcast involves empty tensor"); | "broadcast involves empty tensor"); | ||||
@@ -236,33 +236,66 @@ INST(dt_qint8); | |||||
INST(dt_quint8); | INST(dt_quint8); | ||||
#undef dt_ibyte | #undef dt_ibyte | ||||
template <int ndim> | |||||
void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | |||||
const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | |||||
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); | |||||
for (size_t i = 0; i < rv.layout.ndim; ++i) { | |||||
m_stride[i] = rv.layout.stride[i]; | |||||
m_shape[i] = rv.layout.shape[i]; | |||||
if (i + 1 < rv.layout.ndim) { | |||||
m_shape_highdim[i] = rv.layout.shape[i + 1]; | |||||
if (rv.layout.stride[i + 1] == 1) | |||||
m_align_shape_highdim[i] = | |||||
(uint32_t)round_up((int)rv.layout.shape[i + 1], 2); | |||||
else | |||||
m_align_shape_highdim[i] = rv.layout.shape[i + 1]; | |||||
} | |||||
} | |||||
for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) { | |||||
m_shape_highdim[i] = 1; | |||||
m_align_shape_highdim[i] = 1; | |||||
} | |||||
for (size_t i = rv.layout.ndim; i < ndim; ++i) { | |||||
m_stride[i] = 0; | |||||
m_shape[i] = 1; | |||||
} | |||||
m_is_physical_contiguous = rv.layout.is_physical_contiguous(); | |||||
} | |||||
#define ndim_cb(_ndim) \ | |||||
template class ParamElemVisitor4bitBase<_ndim, BCAST_OTHER>; | |||||
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) | |||||
#undef ndim_cb | |||||
} // namespace elemwise_intl | } // namespace elemwise_intl | ||||
void elemwise_intl::get_launch_spec(const void* kern, size_t size, | void elemwise_intl::get_launch_spec(const void* kern, size_t size, | ||||
int* grid_size, int* block_size) { | int* grid_size, int* block_size) { | ||||
safe_size_in_kern(size); | |||||
auto config = query_launch_config_for_kernel(kern); | |||||
*block_size = config.block_size; | |||||
int a = size / (config.block_size * 2), | |||||
b = (size - 1) / (config.block_size * 3) + 1; | |||||
if (current_device_prop().major <= 3) { | |||||
// for Kepler, less blocks (more work per thread) is faster | |||||
*grid_size = b; | |||||
} else { | |||||
*grid_size = std::max(a, b); | |||||
safe_size_in_kern(size); | |||||
auto config = query_launch_config_for_kernel(kern); | |||||
*block_size = config.block_size; | |||||
int a = size / (config.block_size * 2), | |||||
b = (size - 1) / (config.block_size * 3) + 1; | |||||
if (current_device_prop().major <= 3) { | |||||
// for Kepler, less blocks (more work per thread) is faster | |||||
*grid_size = b; | |||||
} else { | |||||
*grid_size = std::max(a, b); | |||||
} | |||||
if (!*grid_size) { | |||||
*block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024); | |||||
*grid_size = std::max<int>(size / *block_size, 1); | |||||
} | |||||
// because we unroll 3 times in the kernel | |||||
megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >= | |||||
size); | |||||
} | } | ||||
if (!*grid_size) { | |||||
*block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024); | |||||
*grid_size = std::max<int>(size / *block_size, 1); | |||||
} | |||||
// because we unroll 3 times in the kernel | |||||
megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >= size); | |||||
} | |||||
void elemwise_intl::on_bad_ndim(int ndim) { | |||||
megdnn_throw(ssprintf("invalid ndim: %d", ndim)); | |||||
MEGDNN_MARK_USED_VAR(ndim); | |||||
} | |||||
void elemwise_intl::on_bad_ndim(int ndim) { | |||||
megdnn_throw(ssprintf("invalid ndim: %d", ndim)); | |||||
MEGDNN_MARK_USED_VAR(ndim); | |||||
} | |||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -115,6 +115,34 @@ INST(dt_qint32, int4); | |||||
#undef as_raw | #undef as_raw | ||||
#undef INST | #undef INST | ||||
struct int4bx2 { | |||||
int8_t x; | |||||
}; | |||||
struct uint4bx2 { | |||||
uint8_t x; | |||||
}; | |||||
#define INST(_ctype, _Storage, _vect_type) \ | |||||
template <> \ | |||||
class VectTypeTrait<_ctype> { \ | |||||
public: \ | |||||
using Storage = _Storage; \ | |||||
static const Storage kMask = 0xf; \ | |||||
static const Storage kBits = 4; \ | |||||
using vect_type = _vect_type; \ | |||||
static const size_t packed_size = 2; \ | |||||
static __device__ __forceinline__ vect_type make_vector(Storage x, \ | |||||
Storage y) { \ | |||||
vect_type t; \ | |||||
t.x = (x & kMask) | (y << kBits); \ | |||||
return t; \ | |||||
} \ | |||||
} | |||||
INST(dt_qint4, int8_t, int4bx2); | |||||
INST(dt_quint4, uint8_t, uint4bx2); | |||||
#undef INST | |||||
/*! | /*! | ||||
* \brief visitor to access an elemeent in a tensor at given logic index | * \brief visitor to access an elemeent in a tensor at given logic index | ||||
* \tparam ctype plain element ctype (i.e. ctype in DTypeTrait) | * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait) | ||||
@@ -217,6 +245,7 @@ template <int ndim, typename ctype> | |||||
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> | class ParamElemVisitor<ndim, ctype, BCAST_OTHER> | ||||
: public ParamVisitorBase<ndim, ctype, BCAST_OTHER> { | : public ParamVisitorBase<ndim, ctype, BCAST_OTHER> { | ||||
public: | public: | ||||
using CType = ctype; | |||||
PARAM_ELEM_VISITOR_COMMON_HOST | PARAM_ELEM_VISITOR_COMMON_HOST | ||||
void host_init(const TensorND& rv, int grid_size, int block_size) { | void host_init(const TensorND& rv, int grid_size, int block_size) { | ||||
@@ -500,6 +529,177 @@ public: | |||||
#endif | #endif | ||||
}; | }; | ||||
template <int ndim, BcastType brd_type> | |||||
class ParamElemVisitor4bitBase; | |||||
template <int ndim> | |||||
class ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
using Storage = int8_t; | |||||
protected: | |||||
Storage* __restrict m_ptr; | |||||
int m_stride[ndim]; | |||||
int m_shape[ndim]; | |||||
bool m_is_physical_contiguous; | |||||
//! m_shape_highdim[i] = original_shape[i + 1] | |||||
#ifdef _MSC_VER | |||||
Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1]; | |||||
Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1]; | |||||
#else | |||||
Uint32Fastdiv m_shape_highdim[ndim]; | |||||
Uint32Fastdiv m_align_shape_highdim[ndim]; | |||||
#endif | |||||
public: | |||||
static const Storage kMask = 0xf; | |||||
static const Storage kBits = 4; | |||||
static const int NDIM = ndim; | |||||
void host_init(const TensorND& rv, int grid_size, int block_size); | |||||
#if MEGDNN_CC_CUDA | |||||
devfunc void thread_init(uint32_t) {} | |||||
devfunc void next() {} | |||||
devfunc void get_shape_from_access(uint32_t access_idx, | |||||
int (&shape_idx)[ndim]) { | |||||
#pragma unroll | |||||
for (int i = ndim - 1; i >= 1; --i) { | |||||
Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1]; | |||||
uint32_t access_idx_div = access_idx / align_shp; | |||||
shape_idx[i] = access_idx - access_idx_div * align_shp.divisor(); | |||||
access_idx = access_idx_div; | |||||
} | |||||
shape_idx[0] = access_idx; | |||||
} | |||||
devfunc int offset(uint32_t idx) { | |||||
int offset = 0; | |||||
#pragma unroll | |||||
for (int i = ndim - 1; i >= 1; --i) { | |||||
Uint32Fastdiv& shp = m_shape_highdim[i - 1]; | |||||
uint32_t idx_div = idx / shp; | |||||
offset += (idx - idx_div * shp.divisor()) * m_stride[i]; | |||||
idx = idx_div; | |||||
} | |||||
offset += idx * m_stride[0]; | |||||
return offset; | |||||
} | |||||
devfunc int idx(uint32_t access_idx) { | |||||
int idx = 0; | |||||
if (m_is_physical_contiguous) { | |||||
idx = access_idx; | |||||
} else { | |||||
int shape_idx[ndim]; | |||||
bool valid = true; | |||||
get_shape_from_access(access_idx, shape_idx); | |||||
#pragma unroll | |||||
for (int i = 0; i < ndim; ++i) { | |||||
valid &= (shape_idx[i] < m_shape[i]); | |||||
} | |||||
#pragma unroll | |||||
for (int i = 0; i < ndim - 1; ++i) { | |||||
idx = (idx + shape_idx[i]) * m_shape[i + 1]; | |||||
} | |||||
idx = valid ? idx + shape_idx[ndim - 1] : -1; | |||||
} | |||||
return idx; | |||||
} | |||||
devfunc Storage* ptr() { return m_ptr; } | |||||
#endif | |||||
}; | |||||
template <int ndim> | |||||
class ParamElemVisitor<ndim, dt_qint4, BCAST_OTHER> | |||||
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
using CType = dt_qint4; | |||||
using Storage = int8_t; | |||||
public: | |||||
static const int packed_size = 1; | |||||
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; | |||||
void host_init(const TensorND& rv, int grid_size, int block_size) { | |||||
Super::host_init(rv, grid_size, block_size); | |||||
} | |||||
#if MEGDNN_CC_CUDA | |||||
// cannot be l-value, only support read | |||||
devfunc dt_qint4 at(uint32_t idx) { | |||||
int offset_ = Super::offset(idx); | |||||
int vec_idx = offset_ >> 1; | |||||
int lane_idx = offset_ & 0x1; | |||||
Storage item = Storage(unpack_integer_4bits<true>( | |||||
*(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4)); | |||||
dt_qint4 result(item); | |||||
return result; | |||||
} | |||||
#endif | |||||
}; | |||||
template <int ndim> | |||||
class ParamElemVisitor<ndim, dt_quint4, BCAST_OTHER> | |||||
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { | |||||
using CType = dt_quint4; | |||||
using Storage = uint8_t; | |||||
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; | |||||
public: | |||||
static const int packed_size = 1; | |||||
void host_init(const TensorND& rv, int grid_size, int block_size) { | |||||
Super::host_init(rv, grid_size, block_size); | |||||
} | |||||
#if MEGDNN_CC_CUDA | |||||
// cannot be l-value, only support read | |||||
devfunc dt_quint4 at(uint32_t idx) { | |||||
int offset_ = Super::offset(idx); | |||||
int vec_idx = offset_ >> 1; | |||||
int lane_idx = offset_ & 0x1; | |||||
Storage item = Storage(unpack_integer_4bits<false>( | |||||
*(Storage*)&Super::m_ptr[vec_idx], lane_idx * 4)); | |||||
dt_quint4 result(item); | |||||
return result; | |||||
} | |||||
#endif | |||||
}; | |||||
#if MEGDNN_CC_CUDA | |||||
#define DEVICE_WRAPPER(x) x | |||||
#else | |||||
#define DEVICE_WRAPPER(x) | |||||
#endif | |||||
#define INST_DT_IBYTE(ctype) \ | |||||
template <int ndim> \ | |||||
class ParamVectVisitor<ndim, ctype, BCAST_OTHER> \ | |||||
: public ParamElemVisitor4bitBase<ndim, BCAST_OTHER> { \ | |||||
public: \ | |||||
using Super = ParamElemVisitor4bitBase<ndim, BCAST_OTHER>; \ | |||||
void host_init(const TensorND& rv, int grid_size, int block_size) { \ | |||||
Super::host_init(rv, grid_size, block_size); \ | |||||
} \ | |||||
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ | |||||
static const int packed_size = VectTypeTrait<ctype>::packed_size; \ | |||||
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t access_idx) { \ | |||||
return *(rwtype*)(&Super::m_ptr[access_idx]); \ | |||||
}) \ | |||||
}; | |||||
INST_DT_IBYTE(dt_qint4); | |||||
INST_DT_IBYTE(dt_quint4); | |||||
#undef DEVICE_WRAPPER | |||||
#undef INST_DT_IBYTE | |||||
/* f}}} */ | /* f}}} */ | ||||
#if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
@@ -507,7 +707,8 @@ public: | |||||
/* f{{{ user operator callers */ | /* f{{{ user operator callers */ | ||||
/* | /* | ||||
* OpCaller is used to invoke user operator with loaded element arguments. | |||||
* OpCaller is used to invoke user operator with loaded element | |||||
* arguments. | |||||
* | * | ||||
* device interface: | * device interface: | ||||
* void thread_init(uint32_t idx); | * void thread_init(uint32_t idx); | ||||
@@ -518,8 +719,8 @@ public: | |||||
*/ | */ | ||||
/*! | /*! | ||||
* \brief call user op directly without visiting any params (i.e. arity == | |||||
* 0) | |||||
* \brief call user op directly without visiting any params (i.e. arity | |||||
* == 0) | |||||
*/ | */ | ||||
template <class Op> | template <class Op> | ||||
struct OpCallerNull { | struct OpCallerNull { | ||||
@@ -1151,6 +1352,20 @@ public: | |||||
} | } | ||||
}; | }; | ||||
#define INST_DT_TYPE(ctype) \ | |||||
template <class Op> \ | |||||
class UserOpInvoker<Op, ctype, 2> \ | |||||
: public UserOpInvokerToSameNdim<Op, ctype, 2> { \ | |||||
public: \ | |||||
UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, \ | |||||
const Op& op) \ | |||||
: UserOpInvokerToSameNdim<Op, ctype, 2>(param, stream, op) {} \ | |||||
} | |||||
INST_DT_TYPE(dt_qint4); | |||||
INST_DT_TYPE(dt_quint4); | |||||
#undef INST_DT_TYPE | |||||
#define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \ | #define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \ | ||||
_stride) \ | _stride) \ | ||||
DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ | DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ | ||||
@@ -1404,7 +1619,6 @@ void run_elemwise(const ElemwiseOpParamN<arity>& param, cudaStream_t stream, | |||||
#define INST_RUN_ELEMWISE(Op, ctype, arity) \ | #define INST_RUN_ELEMWISE(Op, ctype, arity) \ | ||||
template void run_elemwise<Op, ctype, arity>( \ | template void run_elemwise<Op, ctype, arity>( \ | ||||
const ElemwiseOpParamN<arity>&, cudaStream_t, const Op&) | const ElemwiseOpParamN<arity>&, cudaStream_t, const Op&) | ||||
#endif | #endif | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -0,0 +1,256 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise_helper_q4.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/cuda/elemwise_helper.cuh" | |||||
/* | |||||
* please note that all arithmetics on GPU are 32-bit for best performance; this | |||||
* limits max possible size | |||||
*/ | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
template <typename ctype> | |||||
struct IsNotTypeQ4 { | |||||
static constexpr bool value = !(std::is_same<ctype, dt_qint4>::value || | |||||
std::is_same<ctype, dt_quint4>::value); | |||||
}; | |||||
template <typename ctype> | |||||
struct IsTypeQ4 { | |||||
static constexpr bool value = (std::is_same<ctype, dt_qint4>::value || | |||||
std::is_same<ctype, dt_quint4>::value); | |||||
}; | |||||
//! internals for element-wise | |||||
namespace elemwise_intl { | |||||
#define devfunc __device__ __forceinline__ | |||||
#if MEGDNN_CC_CUDA | |||||
/*! | |||||
* \brief call an operator whose each param are promted to the same ndim and | |||||
* brdcast_mask | |||||
* \tparam PVis ParamElemVisitor class | |||||
*/ | |||||
template <class Op, int arity, class PVisSrc, class PVisDst, bool BetweenQ4> | |||||
struct OpCallerToQ4; | |||||
//! specialization for arity == 1 | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, false> { | |||||
Op op; | |||||
PVisSrc par_src[1]; | |||||
PVisDst par_dst[1]; | |||||
using src_ctype = typename PVisSrc::CType; | |||||
devfunc void on(uint32_t access_idx) { | |||||
int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
src_ctype src0 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
src_ctype src1 = (idx1 >= 0) ? par_src[0].at(idx1) : (src_ctype)0; | |||||
op(access_idx, src0, src1); | |||||
} | |||||
}; | |||||
//! specialization for arity == 2 | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, false> { | |||||
Op op; | |||||
PVisSrc par_src[2]; | |||||
PVisDst par_dst[1]; | |||||
using src_ctype = typename PVisSrc::CType; | |||||
devfunc void on(uint32_t access_idx) { | |||||
int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0; | |||||
src_ctype src01 = (idx0 >= 0) ? par_src[0].at(idx1) : (src_ctype)0; | |||||
src_ctype src11 = (idx0 >= 0) ? par_src[1].at(idx1) : (src_ctype)0; | |||||
op(access_idx, src00, src10, src01, src11); | |||||
} | |||||
}; | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, false> { | |||||
Op op; | |||||
PVisSrc par_src[3]; | |||||
PVisDst par_dst[1]; | |||||
using src_ctype = typename PVisSrc::CType; | |||||
devfunc void on(uint32_t access_idx) { | |||||
int32_t idx0 = par_dst[0].idx(access_idx * 2); | |||||
int32_t idx1 = par_dst[0].idx(access_idx * 2 + 1); | |||||
src_ctype src00 = (idx0 >= 0) ? par_src[0].at(idx0) : (src_ctype)0; | |||||
src_ctype src10 = (idx0 >= 0) ? par_src[1].at(idx0) : (src_ctype)0; | |||||
src_ctype src20 = (idx0 >= 0) ? par_src[2].at(idx0) : (src_ctype)0; | |||||
src_ctype src01 = (idx0 >= 0) ? par_src[0].at(idx1) : (src_ctype)0; | |||||
src_ctype src11 = (idx0 >= 0) ? par_src[1].at(idx1) : (src_ctype)0; | |||||
src_ctype src21 = (idx0 >= 0) ? par_src[2].at(idx1) : (src_ctype)0; | |||||
op(access_idx, src00, src10, src20, src01, src11, src21); | |||||
} | |||||
}; | |||||
//! specialization for arity == 1 | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 1, PVisSrc, PVisDst, true> { | |||||
Op op; | |||||
PVisSrc par_src[1]; | |||||
PVisDst par_dst[1]; | |||||
devfunc void on(uint32_t access_idx) { | |||||
op(access_idx, par_src[0].at(access_idx)); | |||||
} | |||||
}; | |||||
//! specialization for arity == 2 | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 2, PVisSrc, PVisDst, true> { | |||||
Op op; | |||||
PVisSrc par_src[2]; | |||||
PVisDst par_dst[1]; | |||||
devfunc void on(uint32_t access_idx) { | |||||
op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx)); | |||||
} | |||||
}; | |||||
template <class Op, class PVisSrc, class PVisDst> | |||||
struct OpCallerToQ4<Op, 3, PVisSrc, PVisDst, true> { | |||||
Op op; | |||||
PVisSrc par_src[3]; | |||||
PVisDst par_dst[1]; | |||||
devfunc void on(uint32_t access_idx) { | |||||
op(access_idx, par_src[0].at(access_idx), par_src[1].at(access_idx), | |||||
par_src[2].at(access_idx)); | |||||
} | |||||
}; | |||||
/* f}}} */ | |||||
template <class OpCaller> | |||||
__global__ void cuda_kern_q4(OpCaller op_caller, uint32_t size) { | |||||
uint32_t access_idx = blockIdx.x * blockDim.x + threadIdx.x, | |||||
delta = blockDim.x * gridDim.x; | |||||
if (access_idx < size) { | |||||
op_caller.on(access_idx); | |||||
access_idx += delta; | |||||
if (access_idx < size) { | |||||
op_caller.on(access_idx); | |||||
access_idx += delta; | |||||
if (access_idx < size) { | |||||
op_caller.on(access_idx); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/* f{{{ UserOpInvoker specializations */ | |||||
//! run op by promoting all params to same ndim | |||||
template <class Op, typename src_ctype, typename dst_ctype, int arity, | |||||
bool BetweenQ4> | |||||
class UserOpInvokerQ4 { | |||||
const ElemwiseOpParamN<arity>& m_src_param; | |||||
const ElemwiseOpParamN<1>& m_dst_param; | |||||
cudaStream_t m_stream; | |||||
const Op& m_op; | |||||
void dispatch0() { | |||||
switch (m_dst_param.max_ndim) { | |||||
#define cb(ndim) \ | |||||
case ndim: \ | |||||
return dispatch1<ndim>(); | |||||
MEGDNN_FOREACH_TENSOR_NDIM(cb) | |||||
#undef cb | |||||
} | |||||
on_bad_ndim(m_dst_param.max_ndim); | |||||
} | |||||
template <int ndim> | |||||
void dispatch1() { | |||||
using PVisSrc = typename std::conditional< | |||||
BetweenQ4, ParamVectVisitor<ndim, src_ctype, BCAST_OTHER>, | |||||
ParamElemVisitor<ndim, src_ctype, BCAST_OTHER>>::type; | |||||
typedef OpCallerToQ4<Op, arity, PVisSrc, | |||||
ParamVectVisitor<ndim, dst_ctype, BCAST_OTHER>, | |||||
BetweenQ4> | |||||
Caller; | |||||
size_t size = m_dst_param[0].layout.access_bytes(); | |||||
int grid_size, block_size; | |||||
void (*fptr)(Caller, uint32_t) = cuda_kern_q4<Caller>; | |||||
get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size, | |||||
&block_size); | |||||
Caller caller; | |||||
caller.op = m_op; | |||||
for (int i = 0; i < arity; ++i) | |||||
caller.par_src[i].host_init(m_src_param[i], grid_size, block_size); | |||||
caller.par_dst[0].host_init(m_dst_param[0], grid_size, block_size); | |||||
(*fptr)<<<grid_size, block_size, 0, m_stream>>>(caller, size); | |||||
after_kernel_launch(); | |||||
} | |||||
public: | |||||
UserOpInvokerQ4(const ElemwiseOpParamN<arity>& src_param, | |||||
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
const Op& op) | |||||
: m_src_param(src_param), | |||||
m_dst_param(dst_param), | |||||
m_stream(stream), | |||||
m_op(op) { | |||||
dispatch0(); | |||||
} | |||||
}; | |||||
#endif | |||||
/* f}}} */ | |||||
#undef devfunc | |||||
} // namespace elemwise_intl | |||||
template <class Op, typename src_ctype, typename dst_ctype, int arity> | |||||
void run_elemwise(const ElemwiseOpParamN<arity>& src_param, | |||||
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
const Op& op = Op()); | |||||
#if MEGDNN_CC_CUDA | |||||
template <class Op, typename src_ctype, typename dst_ctype, int arity> | |||||
void run_elemwise(const ElemwiseOpParamN<arity>& src_param, | |||||
const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, | |||||
const Op& op) { | |||||
src_param.assert_initialized(); | |||||
dst_param.assert_initialized(); | |||||
// TODO: Maybe 2bit? | |||||
megdnn_assert(dst_param[0].layout.dtype.is_low_bit()); | |||||
megdnn_assert(dst_param[0].layout.is_contiguous()); | |||||
elemwise_intl::UserOpInvokerQ4<Op, src_ctype, dst_ctype, arity, | |||||
IsTypeQ4<src_ctype>::value>( | |||||
src_param, dst_param, stream, op); | |||||
} | |||||
#define INST_RUN_ELEMWISE_LOWBIT(Op, src_ctype, dst_ctype, arity) \ | |||||
template void run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
const ElemwiseOpParamN<arity>&, const ElemwiseOpParamN<1>&, \ | |||||
cudaStream_t, const Op&) | |||||
#endif | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#ifndef KERN_IMPL_MODE | |||||
#error "KERN_IMPL_MODE, KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE must be defined" | |||||
#endif | |||||
#include "src/cuda/elemwise_multi_type/kern_ops.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
#define cb(_m) \ | |||||
typedef ElemwiseKern<megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, \ | |||||
float> \ | |||||
KernImpl; \ | |||||
typedef kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ | |||||
Op; \ | |||||
INST_RUN_ELEMWISE_LOWBIT(Op, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, \ | |||||
KERN_IMPL_ARITY); | |||||
KERN_IMPL_MODE(cb) | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -6,11 +6,13 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/cuda/elemwise_helper.cuh" | #include "src/cuda/elemwise_helper.cuh" | ||||
#include "src/cuda/elemwise_helper_q4.cuh" | |||||
#include "src/cuda/elemwise_multi_type/kern.cuh" | #include "src/cuda/elemwise_multi_type/kern.cuh" | ||||
#include "src/cuda/utils.cuh" | #include "src/cuda/utils.cuh" | ||||
@@ -127,10 +129,10 @@ struct QuantizedMultiTypeOp; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
1, ctype_src, ctype_dst, KernImpl, | 1, ctype_src, ctype_dst, KernImpl, | ||||
typename std::enable_if< | |||||
std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
ctype_dst* dst; | ctype_dst* dst; | ||||
CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
CudaDTypeParam<ctype_src> param_a; | CudaDTypeParam<ctype_src> param_a; | ||||
@@ -173,10 +175,10 @@ struct QuantizedMultiTypeOp< | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
2, ctype_src, ctype_dst, KernImpl, | 2, ctype_src, ctype_dst, KernImpl, | ||||
typename std::enable_if< | |||||
std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
ctype_dst* dst; | ctype_dst* dst; | ||||
CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
CudaDTypeParam<ctype_src> param_a, param_b; | CudaDTypeParam<ctype_src> param_a, param_b; | ||||
@@ -224,10 +226,10 @@ struct QuantizedMultiTypeOp< | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | template <typename ctype_src, typename ctype_dst, typename KernImpl> | ||||
struct QuantizedMultiTypeOp< | struct QuantizedMultiTypeOp< | ||||
3, ctype_src, ctype_dst, KernImpl, | 3, ctype_src, ctype_dst, KernImpl, | ||||
typename std::enable_if< | |||||
std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value>::type> { | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
ctype_dst* dst; | ctype_dst* dst; | ||||
CudaDTypeParam<ctype_dst> dst_param; | CudaDTypeParam<ctype_dst> dst_param; | ||||
CudaDTypeParam<ctype_src> param_a, param_b, param_c; | CudaDTypeParam<ctype_src> param_a, param_b, param_c; | ||||
@@ -277,6 +279,367 @@ struct QuantizedMultiTypeOp< | |||||
#endif | #endif | ||||
}; | }; | ||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
1, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
ctype_dst* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ ctype_dst apply(ctype_src v1) { | |||||
float fv1 = param_a.dequantize(v1); | |||||
float rv = KernImpl::apply(fv1); | |||||
return dst_param.quantize(rv); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a) { | |||||
dst[idx] = dst_param.quantize(KernImpl::apply(param_a.dequantize(a))); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
2, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
IsNotTypeQ4<ctype_dst>::value>::type> { | |||||
ctype_dst* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a, param_b; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
ctype_dst* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
param_b = src_params[1]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2) { | |||||
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
float rv = KernImpl::apply(fv1, fv2); | |||||
return dst_param.quantize(rv); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, | |||||
ctype_src b) { | |||||
dst[idx] = dst_param.quantize( | |||||
KernImpl::apply(param_a.dequantize(a), param_b.dequantize(b))); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
1, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using src_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a; | |||||
static constexpr bool src_signedness = | |||||
std::is_same<ctype_src, dt_qint4>::value; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
src_vect_type; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(src_storage v1) { | |||||
float fv1 = param_a.dequantize(v1); | |||||
float rv = KernImpl::apply(fv1); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) { | |||||
dst_storage x = apply( | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0))); | |||||
dst_storage y = apply( | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4))); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
1, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(ctype_src v1) { | |||||
float fv1 = param_a.dequantize(v1); | |||||
float rv = KernImpl::apply(fv1); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
ctype_src a_y) { | |||||
dst_storage x = apply(a_x), y = apply(a_y); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
2, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using src_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a, param_b; | |||||
static constexpr bool src_signedness = | |||||
std::is_same<ctype_src, dt_qint4>::value; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
src_vect_type; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
param_b = src_params[1]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(src_storage v1, | |||||
src_storage v2) { | |||||
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
float rv = KernImpl::apply(fv1, fv2); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, | |||||
src_vect_type b) { | |||||
src_storage a_x = | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0)); | |||||
src_storage a_y = | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4)); | |||||
src_storage b_x = | |||||
src_storage(unpack_integer_4bits<src_signedness>(b.x, 0)); | |||||
src_storage b_y = | |||||
src_storage(unpack_integer_4bits<src_signedness>(b.x, 4)); | |||||
dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
2, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a, param_b; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
param_b = src_params[1]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2) { | |||||
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); | |||||
float rv = KernImpl::apply(fv1, fv2); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
ctype_src b_x, ctype_src a_y, | |||||
ctype_src b_y) { | |||||
dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
3, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<IsTypeQ4<ctype_src>::value && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using src_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_src>::Storage; | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a, param_b, param_c; | |||||
static constexpr bool src_signedness = | |||||
std::is_same<ctype_src, dt_qint4>::value; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_src>::vect_type | |||||
src_vect_type; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
param_b = src_params[1]; | |||||
param_c = src_params[2]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(src_storage v1, src_storage v2, | |||||
src_storage v3) { | |||||
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), | |||||
fv3 = param_c.dequantize(v3); | |||||
float rv = KernImpl::apply(fv1, fv2, fv3); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, | |||||
src_vect_type b, | |||||
src_vect_type c) { | |||||
src_storage a_x = | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 0)); | |||||
src_storage a_y = | |||||
src_storage(unpack_integer_4bits<src_signedness>(a.x, 4)); | |||||
src_storage b_x = | |||||
src_storage(unpack_integer_4bits<src_signedness>(b.x, 0)); | |||||
src_storage b_y = | |||||
src_storage(unpack_integer_4bits<src_signedness>(b.x, 4)); | |||||
src_storage c_x = | |||||
src_storage(unpack_integer_4bits<src_signedness>(c.x, 0)); | |||||
src_storage c_y = | |||||
src_storage(unpack_integer_4bits<src_signedness>(c.x, 4)); | |||||
dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
template <typename ctype_src, typename ctype_dst, typename KernImpl> | |||||
struct QuantizedMultiTypeOp< | |||||
3, ctype_src, ctype_dst, KernImpl, | |||||
typename std::enable_if<(std::is_same<ctype_src, dt_qint8>::value || | |||||
std::is_same<ctype_src, dt_qint32>::value || | |||||
std::is_same<ctype_src, dt_quint8>::value) && | |||||
IsTypeQ4<ctype_dst>::value>::type> { | |||||
using dst_storage = | |||||
typename elemwise_intl::VectTypeTrait<ctype_dst>::Storage; | |||||
dst_storage* dst; | |||||
CudaDTypeParam<ctype_dst> dst_param; | |||||
CudaDTypeParam<ctype_src> param_a, param_b, param_c; | |||||
typedef typename elemwise_intl::VectTypeTrait<ctype_dst>::vect_type | |||||
dst_vect_type; | |||||
#if !MEGDNN_CC_CUDA | |||||
QuantizedMultiTypeOp( | |||||
const SmallVector<CudaDTypeParam<ctype_src>>& src_params, | |||||
dst_storage* dst, const CudaDTypeParam<ctype_dst>& dst_param) | |||||
: dst{dst}, dst_param{dst_param} { | |||||
param_a = src_params[0]; | |||||
param_b = src_params[1]; | |||||
param_c = src_params[2]; | |||||
} | |||||
#endif | |||||
#if MEGDNN_CC_CUDA | |||||
__device__ __forceinline__ dst_storage apply(ctype_src v1, ctype_src v2, | |||||
ctype_src v3) { | |||||
float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), | |||||
fv3 = param_c.dequantize(v3); | |||||
float rv = KernImpl::apply(fv1, fv2, fv3); | |||||
return dst_param.quantize(rv).as_storage(); | |||||
} | |||||
__device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, | |||||
ctype_src b_x, ctype_src c_x, | |||||
ctype_src a_y, ctype_src b_y, | |||||
ctype_src c_y) { | |||||
dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y); | |||||
*(dst_vect_type*)(&dst[idx]) = | |||||
elemwise_intl::VectTypeTrait<ctype_dst>::make_vector(x, y); | |||||
} | |||||
#endif | |||||
}; | |||||
} // namespace kern_ops_quantized | } // namespace kern_ops_quantized | ||||
} // namespace cuda | } // namespace cuda | ||||
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||||
#define KERN_IMPL_ARITY 3 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||||
#define KERN_IMPL_ARITY 2 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint32 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_qint4 | |||||
#define KERN_IMPL_DTYPE dt_qint4 | |||||
#include "../kern_impl_q4.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_qint32 | |||||
#include "../kern_impl.inl" |
@@ -0,0 +1,6 @@ | |||||
// generated by gen_elemwise_multi_type_kern_impls.py | |||||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||||
#define KERN_IMPL_ARITY 1 | |||||
#define KERN_IMPL_STYPE dt_quint4 | |||||
#define KERN_IMPL_DTYPE dt_quint4 | |||||
#include "../kern_impl_q4.inl" |