|
@@ -143,6 +143,9 @@ INST(dt_qint32, int4); |
|
|
* |
|
|
* |
|
|
*/ |
|
|
*/ |
|
|
template <int ndim, typename ctype, BcastType brd_type> |
|
|
template <int ndim, typename ctype, BcastType brd_type> |
|
|
|
|
|
class ParamVisitorBase; |
|
|
|
|
|
|
|
|
|
|
|
template <int ndim, typename ctype, BcastType brd_type> |
|
|
class ParamElemVisitor; |
|
|
class ParamElemVisitor; |
|
|
|
|
|
|
|
|
/*! |
|
|
/*! |
|
@@ -157,6 +160,7 @@ class ParamElemVisitor; |
|
|
* ptr()[offset(idx)] |
|
|
* ptr()[offset(idx)] |
|
|
* |
|
|
* |
|
|
*/ |
|
|
*/ |
|
|
|
|
|
|
|
|
template <int ndim, typename ctype, BcastType brd_type> |
|
|
template <int ndim, typename ctype, BcastType brd_type> |
|
|
class ParamVectVisitor; |
|
|
class ParamVectVisitor; |
|
|
|
|
|
|
|
@@ -169,11 +173,9 @@ class ParamVectVisitor; |
|
|
|
|
|
|
|
|
//! specialization for BCAST_OTHER |
|
|
//! specialization for BCAST_OTHER |
|
|
template <int ndim, typename ctype> |
|
|
template <int ndim, typename ctype> |
|
|
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> { |
|
|
|
|
|
|
|
|
class ParamVisitorBase<ndim, ctype, BCAST_OTHER> { |
|
|
protected: |
|
|
protected: |
|
|
ctype* __restrict m_ptr; |
|
|
ctype* __restrict m_ptr; |
|
|
|
|
|
|
|
|
private: |
|
|
|
|
|
int m_stride[ndim]; |
|
|
int m_stride[ndim]; |
|
|
|
|
|
|
|
|
//! m_shape_highdim[i] = original_shape[i + 1] |
|
|
//! m_shape_highdim[i] = original_shape[i + 1] |
|
@@ -185,10 +187,9 @@ private: |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
static const int NDIM = ndim; |
|
|
static const int NDIM = ndim; |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size, |
|
|
|
|
|
int packed_size); |
|
|
#if MEGDNN_CC_CUDA |
|
|
#if MEGDNN_CC_CUDA |
|
|
devfunc void thread_init(uint32_t) {} |
|
|
devfunc void thread_init(uint32_t) {} |
|
|
|
|
|
|
|
@@ -211,6 +212,18 @@ public: |
|
|
#endif |
|
|
#endif |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <int ndim, typename ctype> |
|
|
|
|
|
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> |
|
|
|
|
|
: public ParamVisitorBase<ndim, ctype, BCAST_OTHER> { |
|
|
|
|
|
public: |
|
|
|
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { |
|
|
|
|
|
ParamVisitorBase<ndim, ctype, BCAST_OTHER>::host_init( |
|
|
|
|
|
rv, grid_size, block_size, packed_size); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
/*! |
|
|
/*! |
|
|
* \brief specialization for ndim == 3 and BCAST_101 |
|
|
* \brief specialization for ndim == 3 and BCAST_101 |
|
|
* (for dimshuffle 'x', 0, 'x') |
|
|
* (for dimshuffle 'x', 0, 'x') |
|
@@ -218,7 +231,7 @@ public: |
|
|
* visit: idx / m_shape2 % m_shape1 |
|
|
* visit: idx / m_shape2 % m_shape1 |
|
|
*/ |
|
|
*/ |
|
|
template <typename ctype> |
|
|
template <typename ctype> |
|
|
class ParamElemVisitor<3, ctype, BCAST_101> { |
|
|
|
|
|
|
|
|
class ParamVisitorBase<3, ctype, BCAST_101> { |
|
|
StridedDivSeq2 m_shape12; |
|
|
StridedDivSeq2 m_shape12; |
|
|
int m_stride1; |
|
|
int m_stride1; |
|
|
|
|
|
|
|
@@ -227,9 +240,9 @@ protected: |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
static const int NDIM = 3; |
|
|
static const int NDIM = 3; |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size); |
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size, |
|
|
|
|
|
int packed_size); |
|
|
|
|
|
|
|
|
#if MEGDNN_CC_CUDA |
|
|
#if MEGDNN_CC_CUDA |
|
|
devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); } |
|
|
devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); } |
|
@@ -242,13 +255,25 @@ public: |
|
|
#endif |
|
|
#endif |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <typename ctype> |
|
|
|
|
|
class ParamElemVisitor<3, ctype, BCAST_101> |
|
|
|
|
|
: public ParamVisitorBase<3, ctype, BCAST_101> { |
|
|
|
|
|
public: |
|
|
|
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { |
|
|
|
|
|
ParamVisitorBase<3, ctype, BCAST_101>::host_init( |
|
|
|
|
|
rv, grid_size, block_size, packed_size); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
/*! |
|
|
/*! |
|
|
* \brief specialization for ndim == 2 and BCAST_10 |
|
|
* \brief specialization for ndim == 2 and BCAST_10 |
|
|
* |
|
|
* |
|
|
* visit: idx % m_shape1 |
|
|
* visit: idx % m_shape1 |
|
|
*/ |
|
|
*/ |
|
|
template <typename ctype> |
|
|
template <typename ctype> |
|
|
class ParamElemVisitor<2, ctype, BCAST_10> { |
|
|
|
|
|
|
|
|
class ParamVisitorBase<2, ctype, BCAST_10> { |
|
|
StridedDivSeq<false> m_shape1; |
|
|
StridedDivSeq<false> m_shape1; |
|
|
int m_stride1; |
|
|
int m_stride1; |
|
|
|
|
|
|
|
@@ -257,9 +282,9 @@ protected: |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
static const int NDIM = 2; |
|
|
static const int NDIM = 2; |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size); |
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size, |
|
|
|
|
|
int packed_size); |
|
|
|
|
|
|
|
|
#if MEGDNN_CC_CUDA |
|
|
#if MEGDNN_CC_CUDA |
|
|
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } |
|
|
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } |
|
@@ -272,13 +297,25 @@ public: |
|
|
#endif |
|
|
#endif |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <typename ctype> |
|
|
|
|
|
class ParamElemVisitor<2, ctype, BCAST_10> |
|
|
|
|
|
: public ParamVisitorBase<2, ctype, BCAST_10> { |
|
|
|
|
|
public: |
|
|
|
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { |
|
|
|
|
|
ParamVisitorBase<2, ctype, BCAST_10>::host_init( |
|
|
|
|
|
rv, grid_size, block_size, packed_size); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
/*! |
|
|
/*! |
|
|
* \brief specialization for ndim == 2 and BCAST_01 |
|
|
* \brief specialization for ndim == 2 and BCAST_01 |
|
|
* |
|
|
* |
|
|
* visit: idx / shape1 |
|
|
* visit: idx / shape1 |
|
|
*/ |
|
|
*/ |
|
|
template <typename ctype> |
|
|
template <typename ctype> |
|
|
class ParamElemVisitor<2, ctype, BCAST_01> { |
|
|
|
|
|
|
|
|
class ParamVisitorBase<2, ctype, BCAST_01> { |
|
|
StridedDivSeq<true> m_shape1; |
|
|
StridedDivSeq<true> m_shape1; |
|
|
int m_stride0; |
|
|
int m_stride0; |
|
|
|
|
|
|
|
@@ -287,9 +324,9 @@ protected: |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
static const int NDIM = 2; |
|
|
static const int NDIM = 2; |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size); |
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size, |
|
|
|
|
|
int packed_size); |
|
|
|
|
|
|
|
|
#if MEGDNN_CC_CUDA |
|
|
#if MEGDNN_CC_CUDA |
|
|
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } |
|
|
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } |
|
@@ -302,9 +339,21 @@ public: |
|
|
#endif |
|
|
#endif |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <typename ctype> |
|
|
|
|
|
class ParamElemVisitor<2, ctype, BCAST_01> |
|
|
|
|
|
: public ParamVisitorBase<2, ctype, BCAST_01> { |
|
|
|
|
|
public: |
|
|
|
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { |
|
|
|
|
|
ParamVisitorBase<2, ctype, BCAST_01>::host_init( |
|
|
|
|
|
rv, grid_size, block_size, packed_size); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
//! specialization for ndim == 1 and BCAST_FULL |
|
|
//! specialization for ndim == 1 and BCAST_FULL |
|
|
template <typename ctype> |
|
|
template <typename ctype> |
|
|
class ParamElemVisitor<1, ctype, BCAST_FULL> { |
|
|
|
|
|
|
|
|
class ParamVisitorBase<1, ctype, BCAST_FULL> { |
|
|
protected: |
|
|
protected: |
|
|
ctype* __restrict m_ptr; |
|
|
ctype* __restrict m_ptr; |
|
|
|
|
|
|
|
@@ -312,7 +361,8 @@ public: |
|
|
static const int NDIM = 1; |
|
|
static const int NDIM = 1; |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size); |
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size, |
|
|
|
|
|
int packed_size); |
|
|
|
|
|
|
|
|
#if MEGDNN_CC_CUDA |
|
|
#if MEGDNN_CC_CUDA |
|
|
devfunc void thread_init(uint32_t) {} |
|
|
devfunc void thread_init(uint32_t) {} |
|
@@ -328,6 +378,18 @@ public: |
|
|
#endif |
|
|
#endif |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
template <typename ctype> |
|
|
|
|
|
class ParamElemVisitor<1, ctype, BCAST_FULL> |
|
|
|
|
|
: public ParamVisitorBase<1, ctype, BCAST_FULL> { |
|
|
|
|
|
public: |
|
|
|
|
|
PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { |
|
|
|
|
|
ParamVisitorBase<1, ctype, BCAST_FULL>::host_init( |
|
|
|
|
|
rv, grid_size, block_size, packed_size); |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
#undef PARAM_ELEM_VISITOR_COMMON_DEV |
|
|
#undef PARAM_ELEM_VISITOR_COMMON_DEV |
|
|
#undef PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
#undef PARAM_ELEM_VISITOR_COMMON_HOST |
|
|
|
|
|
|
|
@@ -340,17 +402,21 @@ public: |
|
|
#else |
|
|
#else |
|
|
#define DEVICE_WRAPPER(x) |
|
|
#define DEVICE_WRAPPER(x) |
|
|
#endif |
|
|
#endif |
|
|
#define INST_PARAM_VECT_VISITOR \ |
|
|
|
|
|
template <int ndim, typename ctype> \ |
|
|
|
|
|
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \ |
|
|
|
|
|
: public ParamElemVisitor<ndim, ctype, _brdcast_mask> { \ |
|
|
|
|
|
public: \ |
|
|
|
|
|
using Super = ParamElemVisitor<ndim, ctype, _brdcast_mask>; \ |
|
|
|
|
|
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ |
|
|
|
|
|
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ |
|
|
|
|
|
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \ |
|
|
|
|
|
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \ |
|
|
|
|
|
}) \ |
|
|
|
|
|
|
|
|
#define INST_PARAM_VECT_VISITOR \ |
|
|
|
|
|
template <int ndim, typename ctype> \ |
|
|
|
|
|
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \ |
|
|
|
|
|
: public ParamVisitorBase<ndim, ctype, _brdcast_mask> { \ |
|
|
|
|
|
public: \ |
|
|
|
|
|
using Super = ParamVisitorBase<ndim, ctype, _brdcast_mask>; \ |
|
|
|
|
|
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ |
|
|
|
|
|
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ |
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { \ |
|
|
|
|
|
ParamVisitorBase<ndim, ctype, _brdcast_mask>::host_init( \ |
|
|
|
|
|
rv, grid_size, block_size, packed_size); \ |
|
|
|
|
|
} \ |
|
|
|
|
|
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \ |
|
|
|
|
|
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \ |
|
|
|
|
|
}) \ |
|
|
}; |
|
|
}; |
|
|
#define _brdcast_mask BCAST_OTHER |
|
|
#define _brdcast_mask BCAST_OTHER |
|
|
INST_PARAM_VECT_VISITOR; |
|
|
INST_PARAM_VECT_VISITOR; |
|
@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR; |
|
|
#define INST_DT_IBYTE(ctype) \ |
|
|
#define INST_DT_IBYTE(ctype) \ |
|
|
template <int ndim> \ |
|
|
template <int ndim> \ |
|
|
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \ |
|
|
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \ |
|
|
: public ParamElemVisitor<ndim, ctype, BCAST_FULL> { \ |
|
|
|
|
|
|
|
|
: public ParamVisitorBase<ndim, ctype, BCAST_FULL> { \ |
|
|
public: \ |
|
|
public: \ |
|
|
using Super = ParamElemVisitor<ndim, ctype, BCAST_FULL>; \ |
|
|
|
|
|
|
|
|
using Super = ParamVisitorBase<ndim, ctype, BCAST_FULL>; \ |
|
|
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ |
|
|
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ |
|
|
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ |
|
|
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ |
|
|
|
|
|
void host_init(const TensorND& rv, int grid_size, int block_size) { \ |
|
|
|
|
|
ParamVisitorBase<ndim, ctype, BCAST_FULL>::host_init( \ |
|
|
|
|
|
rv, grid_size, block_size, packed_size); \ |
|
|
|
|
|
} \ |
|
|
DEVICE_WRAPPER(rwtype vect_scalar; \ |
|
|
DEVICE_WRAPPER(rwtype vect_scalar; \ |
|
|
devfunc rwtype & at(uint32_t /* idx */) { \ |
|
|
devfunc rwtype & at(uint32_t /* idx */) { \ |
|
|
ctype v = Super::m_ptr[0]; \ |
|
|
ctype v = Super::m_ptr[0]; \ |
|
|