This reverts commitHuaHua404-patch-43436c3bdaa
. GitOrigin-RevId:95ab3d1aa7
@@ -314,48 +314,6 @@ protected: | |||
}; | |||
using Cumsum = CumsumForward; | |||
class CumprodForward : public OperatorBase { | |||
DEF_OPR_PARAM(Cumprod); | |||
DEF_OPR_IMPL(CumprodForward, OperatorBase, 1, 1); | |||
public: | |||
/** | |||
* \param[in] src input tensor | |||
* \param[out] dst output tensor | |||
* | |||
* src and dst should be contiguous. | |||
* src and dst should have the same shape. | |||
* | |||
* The exclusive flag specifies whether the current element it taken | |||
* into account when calculating results. | |||
* | |||
* The reverse flag specifies whether cumprod is forward ( | |||
* from 0 to n) or backward (from n downto 0). | |||
* | |||
* Example: | |||
* exclusive && reverse: | |||
* dst_i = src_{i+1} * src_{i+2} * ... * src_{n-1} | |||
* exclusive && !reverse | |||
* dst_i = src_0 * src_1 * ... * src_{i-1} | |||
* !exclusive && reverse: | |||
* dst_i = src_i * src_{i+1} * ... * src_{n-1} | |||
* !exclusive && !reverse: | |||
* dst_i = src_0 * src_1 * ... * src_i | |||
*/ | |||
virtual void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Cumprod = CumprodForward; | |||
// mxx can be max or min | |||
class ArgmxxBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); | |||
@@ -24,7 +24,7 @@ MODES = { | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', | |||
'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', | |||
'RELU6_GRAD', 'HSIGMOID_GRAD', 'SAFE_DIV'], | |||
'RELU6_GRAD', 'HSIGMOID_GRAD'], | |||
3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], | |||
} | |||
@@ -31,7 +31,7 @@ MODES = { | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', | |||
'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', | |||
'RELU6_GRAD', 'HSIGMOID_GRAD', 'SAFE_DIV'], | |||
'RELU6_GRAD', 'HSIGMOID_GRAD'], | |||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], | |||
(1, 'BOOL'): ['NOT'], | |||
(2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], | |||
@@ -443,10 +443,9 @@ pdef('Elemwise').add_enum( | |||
Doc('SQRT = 80', 'unary: x^(1/2)'), | |||
Doc('SQUARE = 81', 'unary: x^2'), | |||
Doc('SIGN = 82', 'unary: sgn(x)'), | |||
Doc('SAFE_DIV = 83', 'safe div: x / y'), | |||
Doc('NEQ = 84', 'binary: x != y'), | |||
Doc('ISNAN = 85', 'unary: isnan(x)'), | |||
Doc('ISINF = 86', 'unary: isinf(x)'), | |||
Doc('NEQ = 83', 'binary: x != y'), | |||
Doc('ISNAN = 84', 'unary: isnan(x)'), | |||
Doc('ISINF = 85', 'unary: isinf(x)'), | |||
) | |||
pdef('ElemwiseMultiType').add_enum( | |||
@@ -740,20 +739,6 @@ Currently, ```DEFAULT``` mode means: | |||
'whether the cumsum is forward or backward'), | |||
'false')) | |||
(pdef('Cumprod', 'calculate accumulated product along given axis'). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which cumprod is performed, default with INT_MAX'), | |||
(1<<31)-1). | |||
add_fields('bool', | |||
Doc('exclusive', | |||
'whether the current element is taken into account'), | |||
'true'). | |||
add_fields('bool', | |||
Doc('reverse', | |||
'whether the cumprod is forward or backward'), | |||
'false')) | |||
(pdef('CondTake'). | |||
add_enum('Mode', | |||
Doc('EQ = 0', 'take if ``abs(data-val)<eps``'), | |||
@@ -1,25 +0,0 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void CumprodForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { | |||
megdnn_assert_contiguous(src); | |||
dst = src; | |||
} | |||
void CumprodForward::check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_eq_layout(src, dst); | |||
megdnn_assert(param().axis >= 0); | |||
megdnn_assert(static_cast<size_t>(param().axis) < src.ndim); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -89,8 +89,7 @@ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) | |||
#define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ | |||
@@ -247,8 +247,6 @@ DEF_KERN(dt_bool, EQ, x == y); | |||
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); | |||
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | |||
DEF_KERN_INT(SAFE_DIV, y != 0 ? x / y : 0); | |||
DEF_KERN_FLOAT(SAFE_DIV, y != 0.f ? x / y : 0.f); | |||
DEF_KERN_INT(MOD, x % y); | |||
DEF_KERN_FLOAT(MOD, fmodf(x, y)); | |||
@@ -242,7 +242,6 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { | |||
CB_MODE(Mode::SQRT); | |||
CB_MODE(Mode::SQUARE); | |||
CB_MODE(Mode::SIGN); | |||
CB_MODE(Mode::SAFE_DIV); | |||
default: | |||
megdnn_assert( | |||
0, | |||
@@ -106,7 +106,6 @@ private: | |||
cb(SVDForward) \ | |||
cb(ReduceForward) \ | |||
cb(CondTake) \ | |||
cb(CumprodForward) \ | |||
cb(CumsumForward) \ | |||
cb(ArgmaxForward) \ | |||
cb(ArgminForward) \ | |||
@@ -62,7 +62,6 @@ DEF(BatchedMatrixMulForward, 3, true, true); | |||
DEF(MatrixInverse, 2, true, true); | |||
DEF(SVDForward, 4, true, true); | |||
DEF(ReduceForward, 2, true, true); | |||
DEF(CumprodForward, 2, true, true); | |||
DEF(CumsumForward, 2, true, true); | |||
DEF(ArgmaxForward, 2, true, true); | |||
DEF(ArgminForward, 2, true, true); | |||
@@ -1,25 +0,0 @@ | |||
#include "./kern_impl.cuinl" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace cumprod { | |||
#define INST_(T, Op, exclusive, reverse) \ | |||
template void run_kern<T, Op, exclusive, reverse>( \ | |||
T*, void*, uint32_t, uint32_t, uint32_t, uint32_t, const Op&, \ | |||
cudaStream_t) | |||
#define INST(T) \ | |||
INST_(T, ProdOp<T>, true, true); \ | |||
INST_(T, ProdOp<T>, false, true); \ | |||
INST_(T, ProdOp<T>, true, false); \ | |||
INST_(T, ProdOp<T>, false, false); | |||
#define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
} // namespace cumprod | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: ft=cuda syntax=cuda.doxygen |
@@ -1,62 +0,0 @@ | |||
#pragma once | |||
#include "src/cuda/utils.cuh" | |||
#include <cuda_runtime_api.h> | |||
#include <stdint.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace cumprod { | |||
//! compute conventional sum of elements | |||
template <typename T> | |||
struct ProdOp { | |||
const T* data; | |||
typedef ProdOp ContigOp; | |||
ProdOp(const T* d) : data(d) {} | |||
__host__ __device__ static T init() { return T(1); } | |||
__device__ static T apply(T lhs, T rhs) { return lhs * rhs; } | |||
__device__ T visit(uint32_t idx) const { return data[idx]; } | |||
static ProdOp make_contig(const T* data) { return ProdOp(data); } | |||
}; | |||
/*! | |||
* \brief cumprod kernel launcher; defined in kern_impl.cuinl | |||
* \tparam T output data type | |||
* \tparam Op reduction operator class, which must provide following interface: | |||
* typdef ContigOp | |||
* static T init(): the identity element | |||
* static T apply(T lhs, T rhs): the reduction operation | |||
* T visit(uint32_t idx) const: access input | |||
* static ContigOp make_contig(const T *data): make an Oo to continue | |||
* reduction on temp buffer | |||
* | |||
* Note that Op::init() must be accessible from both host and device. | |||
* | |||
* In exclusive mode, Op::init() would be filled to the boundary | |||
* | |||
* The buffer in *op* and *dst* should not have identical memory addresses. | |||
*/ | |||
template <typename T, typename Op, bool exclusive, bool reverse> | |||
void run_kern( | |||
T* dst, void* workspace, uint32_t workspace_size, uint32_t A, uint32_t B, | |||
uint32_t C, const Op& op, cudaStream_t stream); | |||
/*! | |||
* \brief get required workspace size for cumprod, in bytes | |||
* \param item_size size of item; i.e. sizeof(T) in run_kern | |||
* | |||
* Note: cuda device must be set to the computing device before calling this | |||
* function. | |||
*/ | |||
uint32_t get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, uint32_t item_size); | |||
} // namespace cumprod | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: ft=cpp syntax=cpp.doxygen |
@@ -1,18 +0,0 @@ | |||
#pragma once | |||
#include <cuda_runtime_api.h> | |||
#include <stdint.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace cumprod { | |||
void get_BX_BY(uint32_t A, uint32_t B, uint32_t C, uint32_t& BX, uint32_t& BY); | |||
uint32_t get_workspace_bytes_for_cub_1d(uint32_t nr_item, uint32_t item_size); | |||
} // namespace cumprod | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: ft=cpp syntax=cpp.doxygen |
@@ -1,82 +0,0 @@ | |||
#include "./kern.cuh" | |||
#include "./kern_helper.cuh" | |||
#include "./kern_impl.cuinl" | |||
#include "src/cuda/kernel_common/diagnostic_prologue.cuh" | |||
using namespace megdnn::cuda; | |||
using namespace cumprod::detail::cubwrap; | |||
namespace { | |||
template <typename T> | |||
struct FakeOp { | |||
__device__ T visit(int) { return 0; } | |||
__device__ static T apply(T, T) { return 0; } | |||
}; | |||
template <bool reverse, typename T> | |||
uint32_t get_workspace_elems_for_cub_1d_with_dtype_reverse(uint32_t nr_item) { | |||
typedef FakeOp<T> Op; | |||
Op op; | |||
InputIterator<T, Op, reverse> inp_iter(op, nr_item); | |||
OutputIterator<T, reverse> out_iter(NULL, nr_item); | |||
ScanOp<T, Op> scan_op; | |||
size_t wk_size0 = 0, wk_size1 = 0; | |||
cuda_check(cub::DeviceScan::ExclusiveScan( | |||
NULL, wk_size0, inp_iter, out_iter, scan_op, 0, nr_item)); | |||
cuda_check(cub::DeviceScan::InclusiveScan( | |||
NULL, wk_size1, inp_iter, out_iter, scan_op, nr_item)); | |||
return std::max(wk_size0, wk_size1); | |||
} | |||
template <typename T> | |||
uint32_t get_workspace_elems_for_cub_1d_with_dtype(uint32_t nr_item) { | |||
return std::max( | |||
get_workspace_elems_for_cub_1d_with_dtype_reverse<false, T>(nr_item), | |||
get_workspace_elems_for_cub_1d_with_dtype_reverse<true, T>(nr_item)); | |||
} | |||
} // namespace | |||
uint32_t cumprod::get_workspace_bytes_for_cub_1d(uint32_t nr_item, uint32_t item_size) { | |||
switch (item_size) { | |||
#define CASE(size, type) \ | |||
case size: \ | |||
return get_workspace_elems_for_cub_1d_with_dtype<type>(nr_item) | |||
CASE(1, uint8_t); | |||
CASE(2, uint16_t); | |||
CASE(4, uint32_t); | |||
CASE(8, uint64_t); | |||
#undef CASE | |||
default: | |||
report_error("unsupported item size in cumprod"); | |||
} | |||
} | |||
uint32_t cumprod::get_workspace_in_bytes( | |||
uint32_t A, uint32_t B, uint32_t C, uint32_t item_size) { | |||
if (A == 1 && C == 1) { | |||
return get_workspace_bytes_for_cub_1d(B, item_size); | |||
} | |||
uint32_t BX, BY; | |||
get_BX_BY(A, B, C, BX, BY); | |||
uint32_t BY2 = BY * 2; | |||
uint32_t res = 0; | |||
while (B > BY2) { | |||
B = (B + BY2 - 1) / BY2; | |||
res += A * B * C; | |||
} | |||
return res * item_size; | |||
} | |||
void cumprod::get_BX_BY( | |||
uint32_t /* A */, uint32_t /* B */, uint32_t C, uint32_t& BX, uint32_t& BY) { | |||
BX = 1; | |||
while (BX < C && BX * 2 <= 32) | |||
BX *= 2; | |||
BY = 512 / BX; | |||
} | |||
#include "src/cuda/kernel_common/diagnostic_epilogue.cuh" | |||
// vim: syntax=cpp.doxygen |
@@ -1,326 +0,0 @@ | |||
#include "./kern.cuh" | |||
#include "./kern_helper.cuh" | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/cub/device/device_scan.cuh" | |||
#include "src/cuda/cub/util_ptx.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace cumprod { | |||
namespace detail { | |||
/** | |||
* src shape is (A, B, C), performing blockwise scan over B axis. | |||
* Each CUDA block calculates a blockwise scan result of size (BY2, BX). | |||
* The block area corresponds to a 2-D area on (B, C) dimension of src. | |||
* | |||
* Per-block prefix sum is stored in dst (dst has the same shape as src). | |||
* | |||
* The whole scan result of each block as a single value is stored in | |||
* block_sum (of shape (A, B/BY2, C)). | |||
* | |||
* block_sum can be NULL. | |||
* | |||
* src and dst can be inplace. | |||
* | |||
* We need to launch (C/BX)*ceil(B/BY2)*A blocks in total. | |||
* Because in CUDA the number of launched blocks over y and z axis are | |||
* limited (at most 65535), we launch all blocks over axis x. | |||
* | |||
* Param: exclusive | |||
* This flag specifies whether the scan is inclusive or exclusive, namely | |||
* whether src_i influences dst_i. | |||
* | |||
* Param: reverse: | |||
* This flag specifies whether the scan is forward or backward. | |||
* | |||
* Example: | |||
* !exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_i) | |||
* !exclusive && reverse: dst_i = op(src_i, src_{i+1}, ..., src_{n-1}) | |||
* exclusive && !reverse: dst_i = op(src_0, src_1, ..., src{i-1}) | |||
* exclusive && reverse: dst_i = op(src_{i+1}, src{i+2}, ..., src{n-1}) | |||
* | |||
* Op should have the following methods: | |||
* static T init() | |||
* static T apply(T lhs, T rhs) | |||
*/ | |||
template <typename T, typename Op, bool exclusive, bool reverse, | |||
uint32_t BY, uint32_t BX> | |||
__global__ void scan_kernel(T *dst, T *block_sum, | |||
uint32_t A, uint32_t B, uint32_t C, const Op op) { | |||
constexpr size_t warp_size = 32; | |||
const uint32_t BY2 = BY*2; | |||
const uint32_t B_ = (B+BY2-1) / BY2; | |||
const uint32_t C_ = (C+BX-1) / BX; | |||
const uint32_t GX = C_; | |||
const uint32_t GY = B_; | |||
// src, dst: (A, B, C) | |||
// block_sum: (A, B_, C) | |||
// shared: (BY2+1, BX) | |||
const uint32_t bx = blockIdx.x % GX; | |||
const uint32_t by = blockIdx.x / GX % GY; | |||
const uint32_t bz = blockIdx.x / GX / GY; | |||
const uint32_t tx = threadIdx.x; | |||
const uint32_t ty = threadIdx.y; | |||
// TODO: shared memory bank conflict optimization | |||
#define shared_idx(x) ((x) + ((x) >> 5)) | |||
volatile __shared__ T cache[shared_idx((BY2+1)*BX)]; | |||
uint32_t base_offset = (bz)*B*C + (by*BY2)*C + (bx*BX); | |||
dst += base_offset; | |||
// load to cache | |||
if (reverse) { | |||
cache[shared_idx((BY2-ty)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ? | |||
op.visit(base_offset + ty*C + tx) : Op::init(); | |||
} else { | |||
cache[shared_idx((ty+1)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ? | |||
op.visit(base_offset + ty*C + tx) : Op::init(); | |||
} | |||
if (reverse) { | |||
cache[shared_idx((BY-ty)*BX+tx)] = | |||
(ty+BY) + by*BY2 < B && tx+bx*BX < C ? | |||
op.visit(base_offset + (ty+BY)*C + tx) : Op::init(); | |||
} else { | |||
cache[shared_idx((ty+BY+1)*BX+tx)] = | |||
(ty+BY) + by*BY2 < B && tx+bx*BX < C ? | |||
op.visit(base_offset + (ty+BY)*C + tx) : Op::init(); | |||
} | |||
if (ty == 0) { | |||
cache[shared_idx(tx)] = Op::init(); | |||
} | |||
__syncthreads(); | |||
uint32_t total, stride; | |||
// first pass | |||
#pragma unroll | |||
for (total = BY, stride = 1; | |||
total > 0; | |||
total >>= 1, stride <<= 1) | |||
{ | |||
if (ty < total) { | |||
uint32_t ai = shared_idx(stride * (2*ty+1) * BX + tx); | |||
uint32_t bi = shared_idx(stride * (2*ty+2) * BX + tx); | |||
cache[bi] = Op::apply(cache[bi], cache[ai]); | |||
} | |||
if (total > warp_size/BX) __syncthreads(); | |||
else cub::WARP_SYNC(0xffffffff); | |||
} | |||
// second pass | |||
#pragma unroll | |||
for (total = 1, stride = BY; | |||
stride > 0; | |||
total <<= 1, stride >>= 1) | |||
{ | |||
if (total > warp_size/BX) __syncthreads(); | |||
else cub::WARP_SYNC(0xffffffff); | |||
if (ty < total) { | |||
uint32_t ai = shared_idx(stride * (2*ty+0) * BX + tx); | |||
uint32_t bi = shared_idx(stride * (2*ty+1) * BX + tx); | |||
cache[bi] = Op::apply(cache[bi], cache[ai]); | |||
} | |||
} | |||
__syncthreads(); | |||
uint32_t ty_offset = (exclusive ? 0 : 1); | |||
if (ty+by*BY2 < B && tx+bx*BX < C) { | |||
if (reverse) { | |||
dst[ty*C + tx] = cache[shared_idx((BY2-1-ty+ty_offset)*BX + tx)]; | |||
} else { | |||
dst[ty*C + tx] = cache[shared_idx((ty+ty_offset)*BX + tx)]; | |||
} | |||
} | |||
if (ty+BY+by*BY2 < B && tx+bx*BX < C) { | |||
if (reverse) { | |||
dst[(ty+BY)*C + tx] = | |||
cache[shared_idx((BY2-1-(ty+BY)+ty_offset)*BX + tx)]; | |||
} else { | |||
dst[(ty+BY)*C + tx] = | |||
cache[shared_idx((ty+BY+ty_offset)*BX + tx)]; | |||
} | |||
} | |||
if (block_sum && ty == 0 && bx*BX+tx < C) { | |||
block_sum[(bz)*B_*C + (by)*C + (bx*BX) + tx] = | |||
cache[shared_idx(BY2*BX + tx)]; | |||
} | |||
} | |||
template <typename T, typename Op, uint32_t BY, uint32_t BX> | |||
__global__ void update_kernel(T *dst, const T *delta, | |||
uint32_t A, uint32_t B, uint32_t C) { | |||
const uint32_t BY2 = BY*2; | |||
const uint32_t B_ = (B+BY2-1) / BY2; | |||
const uint32_t C_ = (C+BX-1) / BX; | |||
const uint32_t GX = C_; | |||
const uint32_t GY = B_; | |||
// src: (A, B, C) | |||
// delta: (A, B_, C) | |||
const uint32_t bx = blockIdx.x % GX; | |||
const uint32_t by = blockIdx.x / GX % GY; | |||
const uint32_t bz = blockIdx.x / GX / GY; | |||
const uint32_t tx = threadIdx.x; | |||
const uint32_t ty = threadIdx.y; | |||
if (tx + bx*BX < C) { | |||
T delta_v = delta[(bz)*B_*C + (by)*C + (bx*BX) + tx]; | |||
if (ty+by*BY2 < B && tx+bx*BX < C) { | |||
T &res = dst[bz*B*C + (ty+by*BY2)*C + (tx+bx*BX)]; | |||
res = Op::apply(res, delta_v); | |||
} | |||
if (ty+BY+by*BY2 < B && tx+bx*BX < C) { | |||
T &res = dst[bz*B*C + (ty+BY+by*BY2)*C + (tx+bx*BX)]; | |||
res = Op::apply(res, delta_v); | |||
} | |||
} | |||
} | |||
template <typename T, typename Op, bool exclusive, bool reverse> | |||
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, | |||
uint32_t C, const Op& op, cudaStream_t stream); | |||
template <typename T, typename Op, bool exclusive, bool reverse, | |||
uint32_t BX, uint32_t BY> | |||
void do_run_kern(T *dst, T *workspace, | |||
uint32_t A, uint32_t B, uint32_t C, const Op &op, cudaStream_t stream) { | |||
const uint32_t BY2 = BY*2; | |||
const uint32_t B_ = (B+BY2-1)/BY2; | |||
const uint32_t C_ = (C+BX-1)/BX; | |||
dim3 blocks(C_*B_*A); | |||
dim3 threads(BX, BY); | |||
scan_kernel<T, Op, exclusive, reverse, BY, BX> | |||
<<<blocks, threads, 0, stream>>>( | |||
dst, B > BY2 ? workspace : NULL, A, B, C, op); | |||
if (B <= BY2) | |||
return; | |||
run_kern_multiAC<T, typename Op::ContigOp, true, reverse>( | |||
workspace, workspace + A*B_*C, A, B_, C, | |||
Op::make_contig(workspace), stream); | |||
update_kernel<T, Op, BY, BX><<<blocks, threads, 0, stream>>>( | |||
dst, workspace, A, B, C); | |||
} | |||
template <typename T, typename Op, bool exclusive, bool reverse> | |||
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, uint32_t C, | |||
const Op& op, cudaStream_t stream) { | |||
#define IF(BX, BY) \ | |||
do { \ | |||
if (vBX == BX && vBY == BY) { \ | |||
return do_run_kern<T, Op, exclusive, reverse, BX, BY>( \ | |||
dst, workspace, A, B, C, op, stream); \ | |||
} \ | |||
} while (0) | |||
uint32_t vBX, vBY; | |||
get_BX_BY(A, B, C, vBX, vBY); | |||
IF(1, 512); | |||
IF(2, 256); | |||
IF(4, 128); | |||
IF(8, 64); | |||
IF(16, 32); | |||
IF(32, 16); | |||
megdnn_trap(); | |||
#undef IF | |||
} | |||
//! wrap cub library for 1-dim scan | |||
namespace cubwrap { | |||
template <typename T, typename Op, bool reverse> | |||
class InputIterator : public std::iterator<std::random_access_iterator_tag, T> { | |||
int m_offset, m_len; | |||
Op m_op; | |||
public: | |||
InputIterator(Op op, int len) : m_offset(0), m_len(len), m_op(op) {} | |||
__device__ InputIterator(int offset, int len, Op op) | |||
: m_offset(offset), m_len(len), m_op(op) {} | |||
__device__ T operator[](int idx) { | |||
idx += m_offset; | |||
if (reverse) { | |||
idx = m_len - 1 - idx; | |||
} | |||
return m_op.visit(idx); | |||
} | |||
__device__ InputIterator operator+(int offset) { | |||
return InputIterator(m_offset + offset, m_len, m_op); | |||
} | |||
}; | |||
template <typename T, bool reverse> | |||
class OutputIterator | |||
: public std::iterator<std::random_access_iterator_tag, T> { | |||
int m_offset, m_len; | |||
T* m_dst; | |||
public: | |||
OutputIterator(T* dst, int len) : m_offset(0), m_len(len), m_dst(dst) {} | |||
__device__ OutputIterator(int offset, int len, T* dst) | |||
: m_offset(offset), m_len(len), m_dst(dst) {} | |||
__device__ T& operator[](int idx) { | |||
idx += m_offset; | |||
if (reverse) { | |||
idx = m_len - 1 - idx; | |||
} | |||
return m_dst[idx]; | |||
} | |||
__device__ OutputIterator operator+(int offset) { | |||
return OutputIterator(m_offset + offset, m_len, m_dst); | |||
} | |||
}; | |||
template <typename T, typename Op> | |||
struct ScanOp { | |||
__device__ __host__ T operator()(T a, T b) { | |||
// cub requires it to be a __device__ __host__ function but MegDNN has | |||
// no such contraint on Op::apply; so we just trap on host | |||
#ifdef __CUDA_ARCH__ | |||
return Op::apply(a, b); | |||
#else | |||
megdnn_trap(); | |||
#endif | |||
} | |||
}; | |||
template <typename T, typename Op, bool exclusive, bool reverse> | |||
void invoke(T* dst, void* workspace, size_t wk_size, const Op& op, uint32_t len, | |||
cudaStream_t stream) { | |||
InputIterator<T, Op, reverse> inp_iter(op, len); | |||
OutputIterator<T, reverse> out_iter(dst, len); | |||
ScanOp<T, Op> scan_op; | |||
if (exclusive) { | |||
cuda_check(cub::DeviceScan::ExclusiveScan(workspace, wk_size, inp_iter, | |||
out_iter, scan_op, Op::init(), | |||
len, stream)); | |||
} else { | |||
cuda_check(cub::DeviceScan::InclusiveScan( | |||
workspace, wk_size, inp_iter, out_iter, scan_op, len, stream)); | |||
} | |||
} | |||
} // namespace cubwrap | |||
} // namespace detail | |||
template <typename T, typename Op, bool exclusive, bool reverse> | |||
void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A, | |||
uint32_t B, uint32_t C, const Op& op, cudaStream_t stream) { | |||
if (A == 1 && C == 1) { | |||
return detail::cubwrap::invoke<T, Op, exclusive, reverse>( | |||
dst, workspace, workspace_size, op, B, stream); | |||
} | |||
return detail::run_kern_multiAC<T, Op, exclusive, reverse>( | |||
dst, static_cast<T*>(workspace), A, B, C, op, stream); | |||
} | |||
} // namespace cumprod | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: ft=cuda syntax=cuda.doxygen |
@@ -1,63 +0,0 @@ | |||
#include "./opr_impl.h" | |||
#include "./kern.cuh" | |||
#include "src/common/reduce_helper_device.h" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace cumprod; | |||
namespace { | |||
/*! | |||
* \brief compute cumprod reduction on (A, B, C) tensor to (A, 1, C) | |||
*/ | |||
template <typename T, class Op> | |||
void dispatch( | |||
T* dst, T* workspace, size_t workspace_size, size_t A, size_t B, size_t C, | |||
bool exclusive, bool reverse, const Op& op, cudaStream_t stream) { | |||
#define IF(exclusive_v, reverse_v) \ | |||
if (exclusive == exclusive_v && reverse == reverse_v) { \ | |||
run_kern<T, Op, exclusive_v, reverse_v>( \ | |||
dst, workspace, workspace_size, A, B, C, op, stream); \ | |||
return; \ | |||
} | |||
IF(true, true) | |||
IF(true, false) | |||
IF(false, true) | |||
IF(false, false) | |||
megdnn_assert_internal(false); | |||
#undef IF | |||
} | |||
} // anonymous namespace | |||
void CumprodForwardImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
size_t A, B, C; | |||
reduce::get_ABC(src.layout, A, B, C, param().axis); | |||
auto stream = cuda_stream(handle()); | |||
#define cb(DType) \ | |||
if (src.layout.dtype == DType()) { \ | |||
using ctype = DTypeTrait<DType>::ctype; \ | |||
dispatch<ctype, ProdOp<ctype>>( \ | |||
dst.ptr<ctype>(), workspace.ptr<ctype>(), workspace.size, A, B, C, \ | |||
param().exclusive, param().reverse, src.ptr<ctype>(), stream); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
#undef cb | |||
megdnn_assert_internal(false); | |||
} | |||
size_t CumprodForwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout&) { | |||
size_t A, B, C; | |||
reduce::get_ABC(src, A, B, C, param().axis); | |||
cuda_check(cudaSetDevice(concrete_handle(handle())->device_id())); | |||
return cumprod::get_workspace_in_bytes(A, B, C, src.dtype.size()); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -1,19 +0,0 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class CumprodForwardImpl : public CumprodForward { | |||
public: | |||
using CumprodForward::CumprodForward; | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) override; | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,5 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float32 | |||
#include "../kern_impl.inl" |
@@ -1,6 +0,0 @@ | |||
// generated by gen_elemwise_multi_type_kern_impls.py | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_STYPE dt_qint8 | |||
#define KERN_IMPL_DTYPE dt_qint8 | |||
#include "../kern_impl.inl" |
@@ -268,7 +268,6 @@ IMPL_MODE_DISPATCHER(2, dt_quint4, dt_quint4); | |||
#undef FOREACH | |||
#define FOREACH(cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ | |||
MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) | |||
@@ -16,7 +16,6 @@ | |||
#include "src/cuda/convolution3d/opr_impl.h" | |||
#include "src/cuda/convpooling/opr_impl.h" | |||
#include "src/cuda/correlation/opr_impl.h" | |||
#include "src/cuda/cumprod/opr_impl.h" | |||
#include "src/cuda/cumsum/opr_impl.h" | |||
#include "src/cuda/cvt_color/opr_impl.h" | |||
#include "src/cuda/dct/opr_impl.h" | |||
@@ -117,7 +116,6 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SVDForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CondTake); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumprodForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumsumForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | |||
@@ -266,7 +266,6 @@ INST(Mode::ATANH_GRAD); | |||
INST(Mode::SOFTPLUS_GRAD); | |||
INST(Mode::RELU6_GRAD); | |||
INST(Mode::HSIGMOID_GRAD); | |||
INST(Mode::SAFE_DIV); | |||
#undef INST | |||
} // namespace fallback | |||
} // namespace megdnn | |||
@@ -1,72 +0,0 @@ | |||
#include "src/naive/cumprod/opr_impl.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/reduce_helper.h" | |||
#include "src/common/utils.h" | |||
namespace { | |||
template <typename T> | |||
void exec_internal( | |||
const T* __restrict src, T* __restrict dst, size_t A, size_t B, size_t C, | |||
bool exclusive, bool reverse) { | |||
for (size_t a = 0; a < A; ++a) | |||
for (size_t c = 0; c < C; ++c) { | |||
if (exclusive && reverse) { | |||
T prod = T(1); | |||
for (size_t b = B; b > 0; --b) { | |||
dst[a * B * C + (b - 1) * C + c] = prod; | |||
prod *= src[a * B * C + (b - 1) * C + c]; | |||
} | |||
} else if (exclusive && !reverse) { | |||
T prod = T(1); | |||
for (size_t b = 0; b < B; ++b) { | |||
dst[a * B * C + b * C + c] = prod; | |||
prod *= src[a * B * C + b * C + c]; | |||
} | |||
} else if (!exclusive && reverse) { | |||
T prod = T(1); | |||
for (size_t b = B; b > 0; --b) { | |||
prod *= src[a * B * C + (b - 1) * C + c]; | |||
dst[a * B * C + (b - 1) * C + c] = prod; | |||
} | |||
} else if (!exclusive && !reverse) { | |||
T prod = T(1); | |||
for (size_t b = 0; b < B; ++b) { | |||
prod *= src[a * B * C + b * C + c]; | |||
dst[a * B * C + b * C + c] = prod; | |||
} | |||
} else { | |||
megdnn_assert_internal(false); | |||
} | |||
} | |||
} | |||
} // anonymous namespace | |||
namespace megdnn { | |||
namespace naive { | |||
void CumprodForwardImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
size_t A, B, C; | |||
reduce::get_ABC(src.layout, A, B, C, param().axis); | |||
#define cb(DType) \ | |||
if (src.layout.dtype == DType()) { \ | |||
using ctype = DTypeTrait<DType>::ctype; \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>( \ | |||
src.ptr<ctype>(), dst.ptr<ctype>(), A, B, C, param().exclusive, \ | |||
param().reverse)); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
megdnn_assert_internal(0); | |||
#undef cb | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,20 +0,0 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class CumprodForwardImpl : public CumprodForward { | |||
public: | |||
using CumprodForward::CumprodForward; | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,5 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float32 | |||
#include "../kern_impl.inl" |
@@ -5,8 +5,6 @@ | |||
#include "src/naive/elemwise/kern_caller.h" | |||
#include "src/naive/handle.h" | |||
#include <iostream> | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_naive_elemwise) | |||
@@ -18,7 +18,6 @@ | |||
#include "src/naive/convolution3d/opr_impl.h" | |||
#include "src/naive/convpooling/opr_impl.h" | |||
#include "src/naive/correlation/opr_impl.h" | |||
#include "src/naive/cumprod/opr_impl.h" | |||
#include "src/naive/cumsum/opr_impl.h" | |||
#include "src/naive/cvt_color/opr_impl.h" | |||
#include "src/naive/dct/opr_impl.h" | |||
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_bfloat16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,7 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float16 | |||
#include "../kern_impl.inl" | |||
#endif |
@@ -1,5 +0,0 @@ | |||
// generated by gen_elemwise_kern_impls.py | |||
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) | |||
#define KERN_IMPL_ARITY 2 | |||
#define KERN_IMPL_CTYPE dt_float32 | |||
#include "../kern_impl.inl" |
@@ -815,7 +815,7 @@ DEF_TEST(all_modes) { | |||
checker.set_rng(0, &abslt1_rng_f32); | |||
} else if ( | |||
mode == Mode::MOD || mode == Mode::TRUE_DIV || | |||
mode == Mode::FLOOR_DIV || mode == Mode::SAFE_DIV) { | |||
mode == Mode::FLOOR_DIV) { | |||
if (dtype.category() == DTypeCategory::INT) { | |||
checker.set_rng(0, &default_rng_i32); | |||
checker.set_rng(1, &nonzero_rng_i32); | |||
@@ -1,63 +0,0 @@ | |||
#include "test/cuda/fixture.h" | |||
#include "megdnn/oprs.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, CUMPROD) { | |||
Checker<Cumprod> checker(handle_cuda()); | |||
struct TestArg { | |||
param::Cumprod param; | |||
TensorShape shape; | |||
TestArg(param::Cumprod param, TensorShape shape) : param(param), shape(shape) {} | |||
}; | |||
std::vector<TestArg> args, args_int32; | |||
for (auto shape : | |||
TensorShapeArray{{10000}, {33000, 33}, {100, 100, 100}, {30, 30, 30, 30}}) { | |||
for (size_t axis = 0; axis < shape.ndim; ++axis) { | |||
args.emplace_back(param::Cumprod(axis, true, true), shape); | |||
args.emplace_back(param::Cumprod(axis, true, false), shape); | |||
args.emplace_back(param::Cumprod(axis, false, true), shape); | |||
args.emplace_back(param::Cumprod(axis, false, false), shape); | |||
} | |||
} | |||
for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}, {100000}}) { | |||
args.emplace_back(param::Cumprod(0, true, true), shape); | |||
args.emplace_back(param::Cumprod(0, true, false), shape); | |||
args.emplace_back(param::Cumprod(0, false, true), shape); | |||
args.emplace_back(param::Cumprod(0, false, false), shape); | |||
} | |||
for (auto shape : TensorShapeArray{ | |||
{1}, | |||
{10}, | |||
{100}, | |||
{1000}, | |||
{10000}, | |||
{100000}, | |||
{1000000}, | |||
{1050000}, | |||
{2100000}}) { | |||
args_int32.emplace_back(param::Cumprod(0, true, true), shape); | |||
args_int32.emplace_back(param::Cumprod(0, true, false), shape); | |||
args_int32.emplace_back(param::Cumprod(0, false, true), shape); | |||
args_int32.emplace_back(param::Cumprod(0, false, false), shape); | |||
} | |||
for (auto arg : args) { | |||
checker.set_param(arg.param); | |||
checker.set_epsilon(1e-2); | |||
checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}}); | |||
checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}}); | |||
checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); | |||
} | |||
for (auto arg : args_int32) { | |||
checker.set_param(arg.param); | |||
checker.set_epsilon(1e-2); | |||
checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); | |||
} | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -494,49 +494,6 @@ TEST_F(NAIVE, CONV3D_RECORD) { | |||
} | |||
} | |||
//! cumprod | |||
TEST_F(NAIVE, CUMPROD_RECORD) { | |||
TaskRecordChecker<Cumprod> checker(2); | |||
struct TestArg { | |||
param::Cumprod param; | |||
TensorShape shape; | |||
TestArg(param::Cumprod param, TensorShape shape) : param(param), shape(shape) {} | |||
}; | |||
std::vector<TestArg> args, args_int32; | |||
for (auto shape : TensorShapeArray{{1000}, {330, 33}, {10, 10, 10}, {5, 5, 5, 5}}) { | |||
for (size_t axis = 0; axis < shape.ndim; ++axis) { | |||
args.emplace_back(param::Cumprod(axis, true, true), shape); | |||
args.emplace_back(param::Cumprod(axis, true, false), shape); | |||
args.emplace_back(param::Cumprod(axis, false, true), shape); | |||
args.emplace_back(param::Cumprod(axis, false, false), shape); | |||
} | |||
} | |||
for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) { | |||
args.emplace_back(param::Cumprod(0, true, true), shape); | |||
args.emplace_back(param::Cumprod(0, true, false), shape); | |||
args.emplace_back(param::Cumprod(0, false, true), shape); | |||
args.emplace_back(param::Cumprod(0, false, false), shape); | |||
} | |||
for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) { | |||
args_int32.emplace_back(param::Cumprod(0, true, true), shape); | |||
args_int32.emplace_back(param::Cumprod(0, true, false), shape); | |||
args_int32.emplace_back(param::Cumprod(0, false, true), shape); | |||
args_int32.emplace_back(param::Cumprod(0, false, false), shape); | |||
} | |||
for (auto arg : args) { | |||
checker.set_param(arg.param); | |||
checker.set_epsilon(1e-2); | |||
checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}}); | |||
checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}}); | |||
checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); | |||
} | |||
for (auto arg : args_int32) { | |||
checker.set_param(arg.param); | |||
checker.set_epsilon(1e-2); | |||
checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); | |||
} | |||
} | |||
//! cumsum | |||
TEST_F(NAIVE, CUMSUM_RECORD) { | |||
TaskRecordChecker<Cumsum> checker(2); | |||
@@ -27,7 +27,6 @@ __all__ = [ | |||
"broadcast_to", | |||
"concat", | |||
"cond_take", | |||
"cumprod", | |||
"cumsum", | |||
"diag", | |||
"expand_dims", | |||
@@ -1140,22 +1139,6 @@ def cumsum(inp: Tensor, axis: int): | |||
Tensor([[ 1 3 6] | |||
[ 4 9 15]], dtype=int32, device=xpux:0) | |||
""" | |||
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" | |||
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) | |||
return apply(op, inp)[0] | |||
def cumprod(inp: Tensor, axis: int): | |||
r"""Computes the cumulative product of elements along given axis. | |||
Args: | |||
inp: input tensor. | |||
axis: axis along which cumprod is performed. | |||
Examples: | |||
>>> x = Tensor([[1, 2, 3], [4, 5, 6]], "int32") | |||
>>> F.cumprod(x, 1) | |||
Tensor([[ 1 2 6] | |||
[ 4 20 120]], dtype=int32, device=xpux:0) | |||
""" | |||
op = builtin.Cumprod(axis=axis, exclusive=False, reverse=False) | |||
return apply(op, inp)[0] |
@@ -501,22 +501,6 @@ def test_dot(): | |||
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) | |||
def test_cumprod(): | |||
x = mge.Parameter(F.full((2, 2), 2.0, dtype="float32")) | |||
with Grad() as grad: | |||
grad.wrt(x, callback=save_to(x)) | |||
def f(x): | |||
return F.cumprod(x, axis=0) | |||
y = f(x) | |||
grad(y, F.ones_like(y)) | |||
expected = np.array([[3.0, 3.0], [2.0, 2.0]], dtype=np.float32) | |||
np.testing.assert_almost_equal(x.grad.numpy(), expected) | |||
def test_pixel_shuffle(): | |||
x = np.random.rand(2, 3, 16, 3, 4).astype("float32") | |||
@@ -1,34 +0,0 @@ | |||
#include "../blob_manager_impl.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/opr/misc.h" | |||
namespace mgb::imperative { | |||
namespace { | |||
namespace cumsum { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Cumsum&>(def); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Cumsum::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace cumsum | |||
} // namespace | |||
namespace { | |||
namespace cumprod { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<Cumprod>(); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Cumprod::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Cumprod, Cumprod).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace cumprod | |||
} // namespace | |||
} // namespace mgb::imperative |
@@ -652,6 +652,18 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose) | |||
} // namespace sliding_window_transpose | |||
} // namespace | |||
namespace { | |||
namespace cumsum { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Cumsum&>(def); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Cumsum::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); | |||
} // namespace cumsum | |||
} // namespace | |||
namespace lrn { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const LRN&>(def); | |||
@@ -1,7 +1,7 @@ | |||
792a91e469d151906d5275ebd65cedf0 ../../dnn/scripts/opr_param_defs.py | |||
dcf9ae8b7881e9f93870aaed1b18f1dd ../../src/core/include/megbrain/ir/ops.td | |||
59bf6c34770e60b030d962484f2eac3b generated/opdef.h.inl | |||
d406f17445b9cec6c34b77377f15a61f generated/opdef.cpp.inl | |||
67fd2dee346d795220231fc266d260cc generated/opdef.py.inl | |||
65740fff1c79fdb2f955dd52df88fb6e generated/opdef.cpy.inl | |||
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py | |||
759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td | |||
2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl | |||
55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl | |||
f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl | |||
6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl | |||
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h |
@@ -2303,49 +2303,6 @@ OP_TRAIT_REG(Correlation, Correlation) | |||
.props(Correlation_props_impl) | |||
.make_name(Correlation_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumprod); | |||
namespace { | |||
size_t Cumprod_hash_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<Cumprod>(); | |||
static_cast<void>(op_); | |||
size_t val = mgb::hash(op_.dyn_typeinfo()); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.axis)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.exclusive)); | |||
val = mgb::hash_pair_combine(val, mgb::hash(op_.reverse)); | |||
return val; | |||
} | |||
bool Cumprod_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { | |||
auto &&a_ = lhs_.cast_final_safe<Cumprod>(), | |||
&&b_ = rhs_.cast_final_safe<Cumprod>(); | |||
static_cast<void>(a_); | |||
static_cast<void>(b_); | |||
if (a_.axis != b_.axis) return false; | |||
if (a_.exclusive != b_.exclusive) return false; | |||
if (a_.reverse != b_.reverse) return false; | |||
return true; | |||
} | |||
std::vector<std::pair<const char*, std::string>> Cumprod_props_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<Cumprod>(); | |||
static_cast<void>(op_); | |||
std::vector<std::pair<const char*, std::string>> props_; | |||
props_.emplace_back("axis", std::to_string(op_.axis)); | |||
props_.emplace_back("exclusive", std::to_string(op_.exclusive)); | |||
props_.emplace_back("reverse", std::to_string(op_.reverse)); | |||
return props_; | |||
} | |||
std::string Cumprod_make_name_impl(const OpDef& def_) { | |||
auto&& op_ = def_.cast_final_safe<Cumprod>(); | |||
static_cast<void>(op_); | |||
return "Cumprod"; | |||
} | |||
} // anonymous namespace | |||
OP_TRAIT_REG(Cumprod, Cumprod) | |||
.hash(Cumprod_hash_impl) | |||
.is_same_st(Cumprod_is_same_st_impl) | |||
.props(Cumprod_props_impl) | |||
.make_name(Cumprod_make_name_impl); | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); | |||
namespace { | |||
@@ -3110,75 +3067,6 @@ std::vector<std::pair<const char*, std::string>> Elemwise_props_impl(const OpDef | |||
case Elemwise::Mode::COND_LT_MOV: | |||
props_.emplace_back("mode", "COND_LT_MOV"); | |||
break; | |||
case Elemwise::Mode::SINH: | |||
props_.emplace_back("mode", "SINH"); | |||
break; | |||
case Elemwise::Mode::COSH: | |||
props_.emplace_back("mode", "COSH"); | |||
break; | |||
case Elemwise::Mode::ASINH: | |||
props_.emplace_back("mode", "ASINH"); | |||
break; | |||
case Elemwise::Mode::ACOSH: | |||
props_.emplace_back("mode", "ACOSH"); | |||
break; | |||
case Elemwise::Mode::ATANH: | |||
props_.emplace_back("mode", "ATANH"); | |||
break; | |||
case Elemwise::Mode::TAN: | |||
props_.emplace_back("mode", "TAN"); | |||
break; | |||
case Elemwise::Mode::ASINH_GRAD: | |||
props_.emplace_back("mode", "ASINH_GRAD"); | |||
break; | |||
case Elemwise::Mode::ACOSH_GRAD: | |||
props_.emplace_back("mode", "ACOSH_GRAD"); | |||
break; | |||
case Elemwise::Mode::ATANH_GRAD: | |||
props_.emplace_back("mode", "ATANH_GRAD"); | |||
break; | |||
case Elemwise::Mode::PRELU: | |||
props_.emplace_back("mode", "PRELU"); | |||
break; | |||
case Elemwise::Mode::CLIP: | |||
props_.emplace_back("mode", "CLIP"); | |||
break; | |||
case Elemwise::Mode::PRELU_GRAD: | |||
props_.emplace_back("mode", "PRELU_GRAD"); | |||
break; | |||
case Elemwise::Mode::SOFTPLUS: | |||
props_.emplace_back("mode", "SOFTPLUS"); | |||
break; | |||
case Elemwise::Mode::SOFTPLUS_GRAD: | |||
props_.emplace_back("mode", "SOFTPLUS_GRAD"); | |||
break; | |||
case Elemwise::Mode::RELU6: | |||
props_.emplace_back("mode", "RELU6"); | |||
break; | |||
case Elemwise::Mode::RELU6_GRAD: | |||
props_.emplace_back("mode", "RELU6_GRAD"); | |||
break; | |||
case Elemwise::Mode::HSIGMOID: | |||
props_.emplace_back("mode", "HSIGMOID"); | |||
break; | |||
case Elemwise::Mode::HSIGMOID_GRAD: | |||
props_.emplace_back("mode", "HSIGMOID_GRAD"); | |||
break; | |||
case Elemwise::Mode::LOGSIGMOID: | |||
props_.emplace_back("mode", "LOGSIGMOID"); | |||
break; | |||
case Elemwise::Mode::SQRT: | |||
props_.emplace_back("mode", "SQRT"); | |||
break; | |||
case Elemwise::Mode::SQUARE: | |||
props_.emplace_back("mode", "SQUARE"); | |||
break; | |||
case Elemwise::Mode::SIGN: | |||
props_.emplace_back("mode", "SIGN"); | |||
break; | |||
case Elemwise::Mode::SAFE_DIV: | |||
props_.emplace_back("mode", "SAFE_DIV"); | |||
break; | |||
case Elemwise::Mode::NEQ: | |||
props_.emplace_back("mode", "NEQ"); | |||
break; | |||
@@ -6674,131 +6674,6 @@ void _init_py_Correlation(py::module m) { | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Correlation::typeinfo(), &py_type).second); | |||
} | |||
PyOpDefBegin(Cumprod) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
static PyObject* getstate(PyObject* self, PyObject*) { | |||
auto& opdef = reinterpret_cast<PyOp(Cumprod)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
std::unordered_map<std::string, py::object> state { | |||
{"axis", serialization<decltype(opdef.axis)>::dump(opdef.axis)}, | |||
{"exclusive", serialization<decltype(opdef.exclusive)>::dump(opdef.exclusive)}, | |||
{"reverse", serialization<decltype(opdef.reverse)>::dump(opdef.reverse)} | |||
}; | |||
return py::cast(state).release().ptr(); | |||
} | |||
static PyObject* setstate(PyObject* self, PyObject* args) { | |||
PyObject* dict = PyTuple_GetItem(args, 0); | |||
if (!dict) return NULL; | |||
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict); | |||
auto& opdef = reinterpret_cast<PyOp(Cumprod)*>(self)->inst(); | |||
static_cast<void>(opdef); | |||
{ | |||
auto&& iter = state.find("axis"); | |||
if (iter != state.end()) { | |||
opdef.axis = serialization<decltype(opdef.axis)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("exclusive"); | |||
if (iter != state.end()) { | |||
opdef.exclusive = serialization<decltype(opdef.exclusive)>::load(iter->second); | |||
} | |||
} | |||
{ | |||
auto&& iter = state.find("reverse"); | |||
if (iter != state.end()) { | |||
opdef.reverse = serialization<decltype(opdef.reverse)>::load(iter->second); | |||
} | |||
} | |||
Py_RETURN_NONE; | |||
} | |||
static int py_init(PyObject *self, PyObject *args, PyObject *kwds); | |||
// }; | |||
PyOpDefEnd(Cumprod) | |||
int PyOp(Cumprod)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { | |||
static const char* kwlist[] = {"axis", "exclusive", "reverse", "scope", NULL}; | |||
PyObject *axis = NULL, *exclusive = NULL, *reverse = NULL, *scope = NULL; | |||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast<char**>(kwlist), &axis, &exclusive, &reverse, &scope)) | |||
return -1; | |||
if (axis) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(Cumprod)*>(self)->inst().axis = | |||
py::cast<decltype(Cumprod::axis)>(py::handle(axis)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (exclusive) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(Cumprod)*>(self)->inst().exclusive = | |||
py::cast<decltype(Cumprod::exclusive)>(py::handle(exclusive)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (reverse) { | |||
try { | |||
// TODO: remove this guard which is used for pybind11 implicit conversion | |||
py::detail::loader_life_support guard{}; | |||
reinterpret_cast<PyOp(Cumprod)*>(self)->inst().reverse = | |||
py::cast<decltype(Cumprod::reverse)>(py::handle(reverse)); | |||
} CATCH_ALL(-1) | |||
} | |||
if (scope) { | |||
try { | |||
reinterpret_cast<PyOp(OpDef)*>(self)->op | |||
->set_scope(py::cast<std::string>(py::handle(scope))); | |||
} CATCH_ALL(-1) | |||
} | |||
return 0; | |||
} | |||
PyGetSetDef PyOp(Cumprod)::py_getsetters[] = { | |||
{const_cast<char*>("axis"), py_get_generic(Cumprod, axis), py_set_generic(Cumprod, axis), const_cast<char*>("axis"), NULL}, | |||
{const_cast<char*>("exclusive"), py_get_generic(Cumprod, exclusive), py_set_generic(Cumprod, exclusive), const_cast<char*>("exclusive"), NULL}, | |||
{const_cast<char*>("reverse"), py_get_generic(Cumprod, reverse), py_set_generic(Cumprod, reverse), const_cast<char*>("reverse"), NULL}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
PyMethodDef PyOp(Cumprod)::tp_methods[] = { | |||
{const_cast<char*>("__getstate__"), PyOp(Cumprod)::getstate, METH_NOARGS, "Cumprod getstate"}, | |||
{const_cast<char*>("__setstate__"), PyOp(Cumprod)::setstate, METH_VARARGS, "Cumprod setstate"}, | |||
{NULL} /* Sentinel */ | |||
}; | |||
void _init_py_Cumprod(py::module m) { | |||
using py_op = PyOp(Cumprod); | |||
auto& py_type = PyOpType(Cumprod); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.Cumprod"; | |||
py_type.tp_basicsize = sizeof(PyOp(Cumprod)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "Cumprod"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
py_type.tp_init = py_op::py_init; | |||
py_type.tp_methods = py_op::tp_methods; | |||
py_type.tp_getset = py_op::py_getsetters; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
PyType_Modified(&py_type); | |||
m.add_object("Cumprod", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Cumprod::typeinfo(), &py_type).second); | |||
} | |||
PyOpDefBegin(Cumsum) // { | |||
static PyGetSetDef py_getsetters[]; | |||
static PyMethodDef tp_methods[]; | |||
@@ -8135,16 +8010,16 @@ void _init_py_Dropout(py::module m) { | |||
template<> struct EnumTrait<Elemwise::Mode> { | |||
static constexpr const char *name = "Elemwise.Mode"; | |||
static constexpr std::underlying_type_t<Elemwise::Mode> max = 87 - 1; | |||
static constexpr std::underlying_type_t<Elemwise::Mode> max = 64 - 1; | |||
}; | |||
template<> PyTypeObject* EnumWrapper<Elemwise::Mode>::type = nullptr; | |||
template<> const char* | |||
EnumWrapper<Elemwise::Mode>::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", "FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "SINH", "COSH", "ASINH", "ACOSH", "ATANH", "TAN", "ASINH_GRAD", "ACOSH_GRAD", "ATANH_GRAD", "PRELU", "CLIP", "PRELU_GRAD", "SOFTPLUS", "SOFTPLUS_GRAD", "RELU6", "RELU6_GRAD", "HSIGMOID", "HSIGMOID_GRAD", "LOGSIGMOID", "SQRT", "SQUARE", "SIGN", "SAFE_DIV", "NEQ", "ISNAN", "ISINF"}; | |||
EnumWrapper<Elemwise::Mode>::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", "FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "NEQ", "ISNAN", "ISINF"}; | |||
template<> std::unordered_map<std::string, Elemwise::Mode> | |||
EnumWrapper<Elemwise::Mode>::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, {normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, {normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("SINH"), Elemwise::Mode::SINH}, {normalize_enum("COSH"), Elemwise::Mode::COSH}, {normalize_enum("ASINH"), Elemwise::Mode::ASINH}, {normalize_enum("ACOSH"), Elemwise::Mode::ACOSH}, {normalize_enum("ATANH"), Elemwise::Mode::ATANH}, {normalize_enum("TAN"), Elemwise::Mode::TAN}, {normalize_enum("ASINH_GRAD"), Elemwise::Mode::ASINH_GRAD}, {normalize_enum("ACOSH_GRAD"), Elemwise::Mode::ACOSH_GRAD}, {normalize_enum("ATANH_GRAD"), Elemwise::Mode::ATANH_GRAD}, {normalize_enum("PRELU"), Elemwise::Mode::PRELU}, {normalize_enum("CLIP"), Elemwise::Mode::CLIP}, {normalize_enum("PRELU_GRAD"), Elemwise::Mode::PRELU_GRAD}, {normalize_enum("SOFTPLUS"), Elemwise::Mode::SOFTPLUS}, {normalize_enum("SOFTPLUS_GRAD"), Elemwise::Mode::SOFTPLUS_GRAD}, {normalize_enum("RELU6"), Elemwise::Mode::RELU6}, {normalize_enum("RELU6_GRAD"), Elemwise::Mode::RELU6_GRAD}, {normalize_enum("HSIGMOID"), Elemwise::Mode::HSIGMOID}, {normalize_enum("HSIGMOID_GRAD"), Elemwise::Mode::HSIGMOID_GRAD}, {normalize_enum("LOGSIGMOID"), Elemwise::Mode::LOGSIGMOID}, {normalize_enum("SQRT"), Elemwise::Mode::SQRT}, {normalize_enum("SQUARE"), Elemwise::Mode::SQUARE}, {normalize_enum("SIGN"), Elemwise::Mode::SIGN}, {normalize_enum("SAFE_DIV"), Elemwise::Mode::SAFE_DIV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}}; | |||
template<> PyObject* EnumWrapper<Elemwise::Mode>::pyobj_insts[87] = {nullptr}; | |||
EnumWrapper<Elemwise::Mode>::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, {normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, {normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}}; | |||
template<> PyObject* EnumWrapper<Elemwise::Mode>::pyobj_insts[64] = {nullptr}; | |||
void _init_py_Elemwise_Mode(PyTypeObject& py_type) { | |||
auto& e_type = EnumWrapper<Elemwise::Mode>::type; | |||
@@ -8499,134 +8374,19 @@ void _init_py_Elemwise_Mode(PyTypeObject& py_type) { | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[60] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SINH; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SINH", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[61] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::COSH; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "COSH", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[62] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ASINH; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[63] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ACOSH; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[64] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ATANH; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ATANH", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[65] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::TAN; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "TAN", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[66] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ASINH_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[67] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ACOSH_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[68] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ATANH_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ATANH_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[69] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::PRELU; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[70] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::CLIP; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CLIP", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[71] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::PRELU_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[72] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SOFTPLUS; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[73] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SOFTPLUS_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[74] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::RELU6; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[75] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::RELU6_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[76] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::HSIGMOID; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[77] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::HSIGMOID_GRAD; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID_GRAD", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[78] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::LOGSIGMOID; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "LOGSIGMOID", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[79] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SQRT; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQRT", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[80] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SQUARE; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQUARE", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[81] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SIGN; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SIGN", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[82] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::SAFE_DIV; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SAFE_DIV", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[83] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::NEQ; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NEQ", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[84] = inst; | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[61] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ISNAN; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ISNAN", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[85] = inst; | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[62] = inst; | |||
}{ | |||
PyObject* inst = e_type->tp_alloc(e_type, 0); | |||
reinterpret_cast<EnumWrapper<Elemwise::Mode>*>(inst)->value = Elemwise::Mode::ISINF; | |||
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ISINF", inst) >= 0); | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[86] = inst; | |||
EnumWrapper<Elemwise::Mode>::pyobj_insts[63] = inst; | |||
} | |||
Py_INCREF(e_type); | |||
mgb_assert(PyDict_SetItemString( | |||
@@ -18456,7 +18216,6 @@ void _init_py_WarpPerspective(py::module m) { | |||
_init_py_ConvolutionBackwardData(m); \ | |||
_init_py_Copy(m); \ | |||
_init_py_Correlation(m); \ | |||
_init_py_Cumprod(m); \ | |||
_init_py_Cumsum(m); \ | |||
_init_py_CvtColor(m); \ | |||
_init_py_DeformableConv(m); \ | |||
@@ -567,21 +567,6 @@ public: | |||
} | |||
}; | |||
class Cumprod : public OpDefImplBase<Cumprod> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
int32_t axis = 2147483647; | |||
bool exclusive = true; | |||
bool reverse = false; | |||
Cumprod() = default; | |||
Cumprod(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); } | |||
Cumprod(::megdnn::param::Cumprod packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {} | |||
::megdnn::param::Cumprod param() const { | |||
return {axis, exclusive, reverse}; | |||
} | |||
}; | |||
class Cumsum : public OpDefImplBase<Cumsum> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
@@ -795,29 +780,6 @@ case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD"; | |||
case Elemwise::Mode::GELU: return "GELU"; | |||
case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD"; | |||
case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV"; | |||
case Elemwise::Mode::SINH: return "SINH"; | |||
case Elemwise::Mode::COSH: return "COSH"; | |||
case Elemwise::Mode::ASINH: return "ASINH"; | |||
case Elemwise::Mode::ACOSH: return "ACOSH"; | |||
case Elemwise::Mode::ATANH: return "ATANH"; | |||
case Elemwise::Mode::TAN: return "TAN"; | |||
case Elemwise::Mode::ASINH_GRAD: return "ASINH_GRAD"; | |||
case Elemwise::Mode::ACOSH_GRAD: return "ACOSH_GRAD"; | |||
case Elemwise::Mode::ATANH_GRAD: return "ATANH_GRAD"; | |||
case Elemwise::Mode::PRELU: return "PRELU"; | |||
case Elemwise::Mode::CLIP: return "CLIP"; | |||
case Elemwise::Mode::PRELU_GRAD: return "PRELU_GRAD"; | |||
case Elemwise::Mode::SOFTPLUS: return "SOFTPLUS"; | |||
case Elemwise::Mode::SOFTPLUS_GRAD: return "SOFTPLUS_GRAD"; | |||
case Elemwise::Mode::RELU6: return "RELU6"; | |||
case Elemwise::Mode::RELU6_GRAD: return "RELU6_GRAD"; | |||
case Elemwise::Mode::HSIGMOID: return "HSIGMOID"; | |||
case Elemwise::Mode::HSIGMOID_GRAD: return "HSIGMOID_GRAD"; | |||
case Elemwise::Mode::LOGSIGMOID: return "LOGSIGMOID"; | |||
case Elemwise::Mode::SQRT: return "SQRT"; | |||
case Elemwise::Mode::SQUARE: return "SQUARE"; | |||
case Elemwise::Mode::SIGN: return "SIGN"; | |||
case Elemwise::Mode::SAFE_DIV: return "SAFE_DIV"; | |||
case Elemwise::Mode::NEQ: return "NEQ"; | |||
case Elemwise::Mode::ISNAN: return "ISNAN"; | |||
case Elemwise::Mode::ISINF: return "ISINF"; | |||
@@ -678,14 +678,6 @@ CorrelationInst | |||
.def_readwrite("pad_size", &Correlation::pad_size) | |||
.def_readwrite("is_multiply", &Correlation::is_multiply); | |||
py::class_<Cumprod, std::shared_ptr<Cumprod>, OpDef> CumprodInst(m, "Cumprod"); | |||
CumprodInst | |||
.def(py::init<int32_t, bool, bool, std::string>(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {}) | |||
.def_readwrite("axis", &Cumprod::axis) | |||
.def_readwrite("exclusive", &Cumprod::exclusive) | |||
.def_readwrite("reverse", &Cumprod::reverse); | |||
py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum"); | |||
CumsumInst | |||
@@ -901,29 +893,6 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode") | |||
.value("GELU", Elemwise::Mode::GELU) | |||
.value("GELU_GRAD", Elemwise::Mode::GELU_GRAD) | |||
.value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV) | |||
.value("SINH", Elemwise::Mode::SINH) | |||
.value("COSH", Elemwise::Mode::COSH) | |||
.value("ASINH", Elemwise::Mode::ASINH) | |||
.value("ACOSH", Elemwise::Mode::ACOSH) | |||
.value("ATANH", Elemwise::Mode::ATANH) | |||
.value("TAN", Elemwise::Mode::TAN) | |||
.value("ASINH_GRAD", Elemwise::Mode::ASINH_GRAD) | |||
.value("ACOSH_GRAD", Elemwise::Mode::ACOSH_GRAD) | |||
.value("ATANH_GRAD", Elemwise::Mode::ATANH_GRAD) | |||
.value("PRELU", Elemwise::Mode::PRELU) | |||
.value("CLIP", Elemwise::Mode::CLIP) | |||
.value("PRELU_GRAD", Elemwise::Mode::PRELU_GRAD) | |||
.value("SOFTPLUS", Elemwise::Mode::SOFTPLUS) | |||
.value("SOFTPLUS_GRAD", Elemwise::Mode::SOFTPLUS_GRAD) | |||
.value("RELU6", Elemwise::Mode::RELU6) | |||
.value("RELU6_GRAD", Elemwise::Mode::RELU6_GRAD) | |||
.value("HSIGMOID", Elemwise::Mode::HSIGMOID) | |||
.value("HSIGMOID_GRAD", Elemwise::Mode::HSIGMOID_GRAD) | |||
.value("LOGSIGMOID", Elemwise::Mode::LOGSIGMOID) | |||
.value("SQRT", Elemwise::Mode::SQRT) | |||
.value("SQUARE", Elemwise::Mode::SQUARE) | |||
.value("SIGN", Elemwise::Mode::SIGN) | |||
.value("SAFE_DIV", Elemwise::Mode::SAFE_DIV) | |||
.value("NEQ", Elemwise::Mode::NEQ) | |||
.value("ISNAN", Elemwise::Mode::ISNAN) | |||
.value("ISINF", Elemwise::Mode::ISINF) | |||
@@ -990,29 +959,6 @@ py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode") | |||
if (str == "GELU") return Elemwise::Mode::GELU; | |||
if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD; | |||
if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV; | |||
if (str == "SINH") return Elemwise::Mode::SINH; | |||
if (str == "COSH") return Elemwise::Mode::COSH; | |||
if (str == "ASINH") return Elemwise::Mode::ASINH; | |||
if (str == "ACOSH") return Elemwise::Mode::ACOSH; | |||
if (str == "ATANH") return Elemwise::Mode::ATANH; | |||
if (str == "TAN") return Elemwise::Mode::TAN; | |||
if (str == "ASINH_GRAD") return Elemwise::Mode::ASINH_GRAD; | |||
if (str == "ACOSH_GRAD") return Elemwise::Mode::ACOSH_GRAD; | |||
if (str == "ATANH_GRAD") return Elemwise::Mode::ATANH_GRAD; | |||
if (str == "PRELU") return Elemwise::Mode::PRELU; | |||
if (str == "CLIP") return Elemwise::Mode::CLIP; | |||
if (str == "PRELU_GRAD") return Elemwise::Mode::PRELU_GRAD; | |||
if (str == "SOFTPLUS") return Elemwise::Mode::SOFTPLUS; | |||
if (str == "SOFTPLUS_GRAD") return Elemwise::Mode::SOFTPLUS_GRAD; | |||
if (str == "RELU6") return Elemwise::Mode::RELU6; | |||
if (str == "RELU6_GRAD") return Elemwise::Mode::RELU6_GRAD; | |||
if (str == "HSIGMOID") return Elemwise::Mode::HSIGMOID; | |||
if (str == "HSIGMOID_GRAD") return Elemwise::Mode::HSIGMOID_GRAD; | |||
if (str == "LOGSIGMOID") return Elemwise::Mode::LOGSIGMOID; | |||
if (str == "SQRT") return Elemwise::Mode::SQRT; | |||
if (str == "SQUARE") return Elemwise::Mode::SQUARE; | |||
if (str == "SIGN") return Elemwise::Mode::SIGN; | |||
if (str == "SAFE_DIV") return Elemwise::Mode::SAFE_DIV; | |||
if (str == "NEQ") return Elemwise::Mode::NEQ; | |||
if (str == "ISNAN") return Elemwise::Mode::ISNAN; | |||
if (str == "ISINF") return Elemwise::Mode::ISINF; | |||
@@ -472,7 +472,6 @@ def ExternOpr: MgbHashableOp<"ExternOpr"> { | |||
} | |||
def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; | |||
def Cumprod: MgbHashableOp<"Cumprod", [CumprodParam]>; | |||
def Split: MgbHashableOp<"Split", [EmptyParam]> { | |||
let extraArguments = (ins | |||
@@ -107,12 +107,6 @@ TEST(TestGoptBasicArithInplace, Absorbing) { | |||
ASSERT_EQ(y.as_immutable_scalar()->get_cast<float>(), 0.f); | |||
} | |||
auto gen_postive = [](HostTensorND& dest) { | |||
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> mask_generator{ | |||
2.f, 4.f}; | |||
dest = *mask_generator(dest.shape(), dest.comp_node()); | |||
}; | |||
TEST(TestGoptBasicArithInplace, LogExpExpand) { | |||
// test log(exp(a) * (exp(b) / (exp(c) * d**2))) -> a + b - c - log(d**2) | |||
@@ -150,13 +144,9 @@ TEST(TestGoptBasicArithInplace, LogExpExpand) { | |||
opt.numdiff_eps_single_inp[3] = 1e-3; | |||
opt.numdiff_max_err_single_inp[3] = 1e-2; | |||
Checker{make_graph, fwd} | |||
.set_input_generator(0, gen_postive) | |||
.set_input_generator(1, gen_postive) | |||
.set_input_generator(2, gen_postive) | |||
.set_input_generator(3, gen_postive) | |||
.run(ms({32, 1}, {32, 1}), opt) | |||
.run(ms({2, 32}, {2, 32}), opt) | |||
.run(ms({1, 32}, {1, 32}), opt); | |||
.run(ms({2, 3}, {2, 3}), opt) | |||
.run(ms({1, 3}, {2, 3}), opt) | |||
.run(ms({3, 2}, {1}), opt); | |||
} | |||
TEST(TestGoptBasicArithInplace, LogSumExp) { | |||
@@ -133,7 +133,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { | |||
0.f}) / | |||
6.f), | |||
}; | |||
mgb_assert(map.size() + 42 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||
mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER); | |||
// unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, | |||
// ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF | |||
return map; | |||
@@ -580,8 +580,6 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||
RET(EL2(ABS_GRAD, i0, og)); | |||
case Mode::ADD: | |||
RET(og); | |||
case Mode::SAFE_DIV: | |||
RET_INVALID(); | |||
case Mode::FLOOR_DIV: | |||
return nullptr; | |||
case Mode::MAX: | |||
@@ -119,75 +119,6 @@ MGB_IMPL_OPR_GRAD(ArgsortForward) { | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortBackward); | |||
MEGDNN_OPR_INIT3(ArgsortBackward, "argsort_bwd", 2, false) | |||
/* ================= Cumprod ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumprod); | |||
Cumprod::Cumprod(VarNode* opr, const Param& param, const OperatorNodeConfig& config) | |||
: Super{opr->owner_graph(), config, "Cumprod", {opr}} { | |||
init_megdnn_opr(*this, param); | |||
add_input({opr}, AddInputSortType::CUR_ADDED); | |||
} | |||
#if MGB_ENABLE_GRAD | |||
MGB_IMPL_OPR_GRAD(Cumprod) { | |||
mgb_assert(out_grad[0] && !out_grad[1]); | |||
auto x = SymbolVar{opr.input(0)}, y = SymbolVar{opr.output(0)}, | |||
grad = SymbolVar{out_grad[0]}; | |||
auto prod_param = opr.param(); | |||
Cumsum::Param reversed_param; | |||
reversed_param.axis = prod_param.axis; | |||
reversed_param.exclusive = prod_param.exclusive; | |||
reversed_param.reverse = !prod_param.reverse; | |||
auto w = y * grad; | |||
return Elemwise::make( | |||
{Cumsum::make(w, reversed_param), x}, Elemwise::Mode::SAFE_DIV) | |||
.node(); | |||
} | |||
#endif | |||
SymbolVar Cumprod::make( | |||
SymbolVar opr, const Param& param, const OperatorNodeConfig& config) { | |||
return opr.insert_single_output_opr<Cumprod>(opr.node(), param, config); | |||
} | |||
void Cumprod::scn_do_execute() { | |||
megdnn_opr()->exec( | |||
input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), | |||
intl::get_megdnn_workspace_from_var(output().back())); | |||
} | |||
void Cumprod::add_input_layout_constraint() { | |||
input(0)->add_layout_constraint_contiguous(); | |||
} | |||
void Cumprod::init_output_static_infer_desc() { | |||
using namespace cg::static_infer; | |||
auto infer_shape = [](TensorShape& dest, const InpVal& iv) { | |||
auto ishp = iv.val.at(0).shape(); | |||
dest = ishp; | |||
return true; | |||
}; | |||
owner_graph()->static_infer_manager().register_shape_infer( | |||
output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape}); | |||
auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) { | |||
auto dtype = input(0)->dtype(); | |||
auto ishp = iv.val.at(0).shape(); | |||
TensorLayout ily(ishp, dtype); | |||
Param real_param = param(); | |||
if (real_param.axis < 0) | |||
real_param.axis += ishp.ndim; | |||
megdnn_opr()->param() = real_param; | |||
dest.ndim = 1; | |||
dest[0] = megdnn_opr()->get_workspace_in_bytes(ily, ily); | |||
return true; | |||
}; | |||
owner_graph()->static_infer_manager().register_shape_infer( | |||
output(1), | |||
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); | |||
} | |||
/* ================= Cumsum ================= */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); | |||
@@ -42,16 +42,6 @@ decl_opr('Argsort', | |||
'performed. Two vars are returned: the sorted array, and the ' | |||
'indices. ') | |||
decl_opr('Cumprod', | |||
inputs=['src'], params='Cumprod', | |||
body=[ | |||
'if param.axis == (1<<31)-1:', | |||
' all_inputs[0] = all_inputs[0].flatten()', | |||
' param.axis = 0' | |||
], | |||
desc='Return the cumulative product of the elements along a given axis.' | |||
' If axis is INT_MAX, compute on flattened input.', version=1) | |||
decl_opr('Cumsum', | |||
inputs=['src'], params='Cumsum', | |||
body=[ | |||
@@ -70,7 +70,6 @@ MGB_SEREG_OPR(TopK, 2); | |||
//! current cumsum version | |||
using CumsumV1 = opr::Cumsum; | |||
MGB_SEREG_OPR(CumsumV1, 1); | |||
MGB_SEREG_OPR(Cumprod, 1); | |||
#if MGB_CUDA | |||
MGB_SEREG_OPR(NvOf, 1); | |||
@@ -76,25 +76,6 @@ public: | |||
} | |||
}; | |||
//! cumulative product along given axis | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
Cumprod, | |||
cg::SingleCNOperatorNodeBaseT<mixin::MegDNNOprHolderImpl<megdnn::Cumprod>>) // { | |||
void add_input_layout_constraint() override; | |||
public: | |||
MGE_WIN_DECLSPEC_FUC Cumprod( | |||
VarNode* src, const Param& param, const OperatorNodeConfig& config); | |||
// for serialization | |||
MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
SymbolVar opr, const Param& param, const OperatorNodeConfig& config = {}); | |||
protected: | |||
void scn_do_execute() override; | |||
void init_output_static_infer_desc() override; | |||
}; | |||
//! cumulative sum along given axis | |||
MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
Cumsum, | |||
@@ -350,9 +350,6 @@ template <> | |||
struct CheckerConfig<H_SWISH_GRAD> : public NoGradCheckerConfig {}; | |||
template <> | |||
struct CheckerConfig<SAFE_DIV> : public NoGradCheckerConfig {}; | |||
template <> | |||
struct CheckerConfig<TAN> : public NoGradCheckerConfig { | |||
template <typename ctype> | |||
static InputGenerator get_inp_gen(size_t) { | |||
@@ -696,13 +693,9 @@ struct CheckerConfig<HSIGMOID_GRAD> : public NoGradCheckerConfig { | |||
/* ======================= ternary config ======================= */ | |||
template <> | |||
struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; | |||
template <> | |||
struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {}; | |||
template <> | |||
struct CheckerConfig<PRELU_GRAD> : public NoGradCheckerConfig {}; | |||
template <> | |||
struct CheckerConfig<CLIP> : public CheckerConfig<void> { | |||
template <typename ctype, class Checker> | |||
@@ -893,10 +886,6 @@ void TestRunner<Trait, dtype, true>::run() { | |||
} | |||
TensorShape shapes[] = {{1}, {23, 3}, {666}}; | |||
if (Trait::ARITY == 4) { | |||
checker.disable_graph_opt(); | |||
shapes[0] = {32}; | |||
} | |||
typename Checker::RunOptions opt; | |||
Config::update_opt(opt); | |||
Config::update_checker(checker); | |||
@@ -1045,13 +1034,13 @@ TEST(TestOprBasicArithElemwise, FuseMulAdd4Shapes) { | |||
}; | |||
Checker checker{make_graph, fwd}; | |||
checker.run({TensorShape{1, 32}, {1, 32}, {1, 32}, {1, 32}}) | |||
.run({TensorShape{1, 1, 1, 1, 1, 32}, | |||
{1, 1, 1, 1, 1, 32}, | |||
{1, 1, 1, 1, 1, 32}, | |||
{1, 1, 1, 1, 1, 32}}); | |||
checker.run({TensorShape{1, 2}, {2, 1}, {1, 2}, {2, 1}}) | |||
.run({TensorShape{1, 2, 1, 2, 1, 2}, | |||
{2, 1, 2, 1, 2, 1}, | |||
{2, 1, 2, 1, 2, 1}, | |||
{1, 2, 1, 2, 1, 2}}); | |||
ASSERT_FALSE(opr->fuse_badlayout_warn_printed()); | |||
checker.run({TensorShape{1, 32}, {32, 1}, {32, 32}, {32, 32}}); | |||
checker.run({TensorShape{1, 2}, {2, 1}, {2, 2}, {2, 2}}); | |||
ASSERT_TRUE(opr->fuse_badlayout_warn_printed()); | |||
} | |||
@@ -47,7 +47,6 @@ DEF_TRAIT(PRELU, (x > 0) ? x : (x* y)) | |||
#define _ALLOW_INT false | |||
DEF_TRAIT(POW, std::pow(x, y)) | |||
DEF_TRAIT(TRUE_DIV, x / y) | |||
DEF_TRAIT(SAFE_DIV, y != 0 ? x / y : 0) | |||
DEF_TRAIT(LOG_SUM_EXP, do_log_sum_exp(x, y)) | |||
DEF_TRAIT(FUSE_ADD_SIGMOID, 1 / (1 + std::exp(-(x + y)))) | |||
DEF_TRAIT(FUSE_ADD_TANH, std::tanh(x + y)) | |||
@@ -145,60 +145,6 @@ TEST(TestOprMisc, Argsort) { | |||
run(Order::DESCENDING); | |||
} | |||
TEST(TestOprMisc, Cumprod) { | |||
using Param = opr::Cumprod::Param; | |||
auto run = [](const Param& param) { | |||
using Checker = AutoOprChecker<1, 1>; | |||
auto make_graph = | |||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
return {opr::Cumprod::make(inputs[0], param)}; | |||
}; | |||
auto fwd = [&](Checker::NumOutArray& out, Checker::NumInpArray inp) { | |||
out[0].resize(inp[0]->shape()); | |||
auto pin = inp[0]->ptr<float>(), pout = out[0].ptr<float>(); | |||
size_t A, B, C; | |||
int real_axis = param.axis; | |||
if (real_axis < 0) | |||
real_axis += 3; | |||
shape_abc(inp[0]->shape(), real_axis, A, B, C); | |||
ptrdiff_t stride = C; | |||
if (param.reverse) | |||
stride = -stride; | |||
for (size_t i = 0; i < A; ++i) { | |||
for (size_t k = 0; k < C; ++k) { | |||
auto pi = pin + i * B * C + k, po = pout + i * B * C + k; | |||
if (param.reverse) { | |||
pi += (B - 1) * C; | |||
po += (B - 1) * C; | |||
} | |||
if (param.exclusive) { | |||
*po = 1; | |||
po += stride; | |||
} | |||
float prod = 1; | |||
for (size_t j = 0; j < B - 1; ++j) { | |||
prod *= pi[j * stride]; | |||
po[j * stride] = prod; | |||
} | |||
if (!param.exclusive) { | |||
po[(B - 1) * stride] = prod * pi[(B - 1) * stride]; | |||
} | |||
} | |||
} | |||
}; | |||
Checker{make_graph, fwd} | |||
.run({TensorShape{2, 3, 4}}) | |||
.run({TensorShape{3, 1, 2}}) | |||
.run({TensorShape{4, 2, 3}}); | |||
}; | |||
// test negative axis | |||
for (int32_t axis = -3; axis < 3; ++axis) | |||
for (int mask = 0; mask < 4; ++mask) | |||
run({axis, bool(mask >> 1), bool(mask & 1)}); | |||
} | |||
TEST(TestOprMisc, Cumsum) { | |||
using Param = opr::Cumsum::Param; | |||
auto run = [](const Param& param) { | |||
@@ -123,7 +123,6 @@ union OperatorParam { | |||
param.LSTM = 89, | |||
param.Softmax = 90, | |||
param.Diag = 91, | |||
param.Cumprod = 92, | |||
} | |||
table Operator { | |||
@@ -224,12 +224,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) { | |||
m_inputs_generator[i](*m_inputs[i]); | |||
mgb_assert(m_inputs[i]->shape().eq_shape(shapes[i])); | |||
} | |||
if (shapes.size() == 4u) { | |||
m_extra_err_msg = ssprintf("%d,", *((int*)(m_inputs[0]->raw_ptr()) + 11)); | |||
m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[1]->raw_ptr()) + 11)); | |||
m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[2]->raw_ptr()) + 11)); | |||
m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[3]->raw_ptr()) + 11)); | |||
} | |||
if (MGB_GETENV("MGB_AUTOCHECK_DUMP_INPUT")) { | |||
static size_t run_id; | |||
auto fname = output_file(ssprintf("autocheck-inp-%zu.bin", run_id++)); | |||