From 49e14f87b578b535a1002e4a55da4f9d3c812777 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 25 Apr 2022 18:46:17 +0800 Subject: [PATCH] feat(mgb): add cumprod opr GitOrigin-RevId: 3436c3bdaa4c11de35e3e37849a11202318cc99f --- dnn/include/megdnn/oprs/general.h | 42 +++ dnn/scripts/gen_elemwise_multi_type_utils.py | 2 +- dnn/scripts/gen_elemwise_utils.py | 2 +- dnn/scripts/opr_param_defs.py | 21 +- dnn/src/common/cumprod.cpp | 25 ++ dnn/src/common/elemwise/each_mode.inl | 3 +- dnn/src/common/elemwise/kern_defs.cuh | 2 + dnn/src/common/elemwise/opr_impl.cpp | 1 + dnn/src/common/handle_impl.h | 1 + dnn/src/common/opr_trait.h | 1 + dnn/src/cuda/cumprod/cumprod.cu | 25 ++ dnn/src/cuda/cumprod/kern.cuh | 62 ++++ dnn/src/cuda/cumprod/kern_helper.cuh | 18 ++ dnn/src/cuda/cumprod/kern_impl.cu | 82 ++++++ dnn/src/cuda/cumprod/kern_impl.cuinl | 326 +++++++++++++++++++++ dnn/src/cuda/cumprod/opr_impl.cpp | 63 ++++ dnn/src/cuda/cumprod/opr_impl.h | 19 ++ .../cuda/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cu | 7 + dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float16.cu | 7 + dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float32.cu | 5 + .../kimpl/SAFE_DIV_dt_qint8_dt_qint8.cu | 6 + dnn/src/cuda/elemwise_multi_type/opr_impl.cpp | 1 + dnn/src/cuda/handle_create.cpp | 2 + .../elemwise/fallback_impl/opr_binary_impl.cpp | 1 + dnn/src/naive/cumprod/opr_impl.cpp | 72 +++++ dnn/src/naive/cumprod/opr_impl.h | 20 ++ .../naive/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp | 7 + .../naive/elemwise/kimpl/SAFE_DIV_dt_float16.cpp | 7 + .../naive/elemwise/kimpl/SAFE_DIV_dt_float32.cpp | 5 + dnn/src/naive/elemwise/opr_impl.cpp | 2 + dnn/src/naive/handle.cpp | 1 + .../elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp.hip | 7 + .../elemwise/kimpl/SAFE_DIV_dt_float16.cpp.hip | 7 + .../elemwise/kimpl/SAFE_DIV_dt_float32.cpp.hip | 5 + dnn/test/common/elemwise.cpp | 2 +- dnn/test/cuda/cumprod.cpp | 63 ++++ dnn/test/naive/record1.cpp | 43 +++ imperative/python/megengine/functional/tensor.py | 19 +- imperative/python/test/unit/core/test_autodiff.py | 16 + imperative/src/impl/ops/cumxxx.cpp | 34 +++ imperative/src/impl/ops/specializations.cpp | 12 - imperative/tablegen/generated/hash.txt | 12 +- imperative/tablegen/generated/opdef.cpp.inl | 112 +++++++ imperative/tablegen/generated/opdef.cpy.inl | 255 +++++++++++++++- imperative/tablegen/generated/opdef.h.inl | 38 +++ imperative/tablegen/generated/opdef.py.inl | 54 ++++ src/core/include/megbrain/ir/ops.td | 1 + src/gopt/test/basic_arith.cpp | 16 +- src/jit/impl/ast_c.cpp | 2 +- src/opr/impl/basic_arith.cpp | 2 + src/opr/impl/misc.cpp | 69 +++++ src/opr/impl/misc.oprdecl | 10 + src/opr/impl/misc.sereg.h | 1 + src/opr/include/megbrain/opr/misc.h | 19 ++ src/opr/test/basic_arith/elemwise.cpp | 23 +- .../test/basic_arith/elemwise_binary_trait_def.inl | 1 + src/opr/test/misc.cpp | 54 ++++ src/serialization/impl/schema.fbs | 1 + test/src/autocheck.cpp | 6 + 59 files changed, 1679 insertions(+), 43 deletions(-) create mode 100644 dnn/src/common/cumprod.cpp create mode 100644 dnn/src/cuda/cumprod/cumprod.cu create mode 100644 dnn/src/cuda/cumprod/kern.cuh create mode 100644 dnn/src/cuda/cumprod/kern_helper.cuh create mode 100644 dnn/src/cuda/cumprod/kern_impl.cu create mode 100644 dnn/src/cuda/cumprod/kern_impl.cuinl create mode 100644 dnn/src/cuda/cumprod/opr_impl.cpp create mode 100644 dnn/src/cuda/cumprod/opr_impl.h create mode 100644 dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cu create mode 100644 dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float16.cu create mode 100644 
dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float32.cu create mode 100644 dnn/src/cuda/elemwise_multi_type/kimpl/SAFE_DIV_dt_qint8_dt_qint8.cu create mode 100644 dnn/src/naive/cumprod/opr_impl.cpp create mode 100644 dnn/src/naive/cumprod/opr_impl.h create mode 100644 dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float16.cpp create mode 100644 dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float32.cpp create mode 100644 dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float16.cpp.hip create mode 100644 dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float32.cpp.hip create mode 100644 dnn/test/cuda/cumprod.cpp create mode 100644 imperative/src/impl/ops/cumxxx.cpp diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 9b556bc3..0800cab0 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -314,6 +314,48 @@ protected: }; using Cumsum = CumsumForward; +class CumprodForward : public OperatorBase { + DEF_OPR_PARAM(Cumprod); + DEF_OPR_IMPL(CumprodForward, OperatorBase, 1, 1); + +public: + /** + * \param[in] src input tensor + * \param[out] dst output tensor + * + * src and dst should be contiguous. + * src and dst should have the same shape. + * + * The exclusive flag specifies whether the current element it taken + * into account when calculating results. + * + * The reverse flag specifies whether cumprod is forward ( + * from 0 to n) or backward (from n downto 0). + * + * Example: + * exclusive && reverse: + * dst_i = src_{i+1} * src_{i+2} * ... * src_{n-1} + * exclusive && !reverse + * dst_i = src_0 * src_1 * ... * src_{i-1} + * !exclusive && reverse: + * dst_i = src_i * src_{i+1} * ... * src_{n-1} + * !exclusive && !reverse: + * dst_i = src_0 * src_1 * ... 
* src_i + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); +}; +using Cumprod = CumprodForward; + // mxx can be max or min class ArgmxxBase : public OperatorBase { DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); diff --git a/dnn/scripts/gen_elemwise_multi_type_utils.py b/dnn/scripts/gen_elemwise_multi_type_utils.py index 39aec818..71126e2d 100755 --- a/dnn/scripts/gen_elemwise_multi_type_utils.py +++ b/dnn/scripts/gen_elemwise_multi_type_utils.py @@ -24,7 +24,7 @@ MODES = { 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', 'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', - 'RELU6_GRAD', 'HSIGMOID_GRAD'], + 'RELU6_GRAD', 'HSIGMOID_GRAD', 'SAFE_DIV'], 3: ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], } diff --git a/dnn/scripts/gen_elemwise_utils.py b/dnn/scripts/gen_elemwise_utils.py index 52f28bb7..e7848d1e 100755 --- a/dnn/scripts/gen_elemwise_utils.py +++ b/dnn/scripts/gen_elemwise_utils.py @@ -31,7 +31,7 @@ MODES = { 'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', 'FUSE_ADD_H_SWISH', 'SILU_GRAD', 'GELU_GRAD', 'PRELU', 'ASINH_GRAD', 'ACOSH_GRAD', 'ATANH_GRAD', 'SOFTPLUS_GRAD', - 'RELU6_GRAD', 'HSIGMOID_GRAD'], + 'RELU6_GRAD', 'HSIGMOID_GRAD', 'SAFE_DIV'], (3, 'FLOAT'): ['COND_LEQ_MOV', 'COND_LT_MOV', 'FUSE_MUL_ADD3', 'CLIP', 'PRELU_GRAD'], (1, 'BOOL'): ['NOT'], (2, 'BOOL'): ['AND', 'OR', 'XOR', 'LT', 'LEQ', 'EQ'], diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 0e901b57..c5dd9ec7 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -443,9 +443,10 @@ pdef('Elemwise').add_enum( Doc('SQRT = 80', 'unary: x^(1/2)'), Doc('SQUARE = 81', 'unary: x^2'), Doc('SIGN = 82', 'unary: sgn(x)'), - Doc('NEQ = 83', 'binary: x != y'), - Doc('ISNAN = 84', 'unary: isnan(x)'), - Doc('ISINF = 85', 'unary: isinf(x)'), + Doc('SAFE_DIV = 83', 'safe div: x / y'), + Doc('NEQ = 84', 'binary: x != y'), + Doc('ISNAN = 85', 'unary: isnan(x)'), + Doc('ISINF = 86', 'unary: isinf(x)'), ) pdef('ElemwiseMultiType').add_enum( @@ -739,6 +740,20 @@ Currently, ```DEFAULT``` mode means: 'whether the cumsum is forward or backward'), 'false')) +(pdef('Cumprod', 'calculate accumulated product along given axis'). + add_fields('int32', + Doc('axis', + 'axis along which cumprod is performed, default with INT_MAX'), + (1<<31)-1). + add_fields('bool', + Doc('exclusive', + 'whether the current element is taken into account'), + 'true'). + add_fields('bool', + Doc('reverse', + 'whether the cumprod is forward or backward'), + 'false')) + (pdef('CondTake'). 
add_enum('Mode', Doc('EQ = 0', 'take if ``abs(data-val)= 0); + megdnn_assert(static_cast(param().axis) < src.ndim); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 48ca51dc..8c05fcab 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -89,7 +89,8 @@ MEGDNN_ELEMWISE_MODE_ENABLE(ATANH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(SOFTPLUS_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(RELU6_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) + MEGDNN_ELEMWISE_MODE_ENABLE(HSIGMOID_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 95105470..b5dcca44 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -247,6 +247,8 @@ DEF_KERN(dt_bool, EQ, x == y); DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); +DEF_KERN_INT(SAFE_DIV, y != 0 ? x / y : 0); +DEF_KERN_FLOAT(SAFE_DIV, y != 0.f ? x / y : 0.f); DEF_KERN_INT(MOD, x % y); DEF_KERN_FLOAT(MOD, fmodf(x, y)); diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index 2f0a1d5a..3b38ddb2 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -242,6 +242,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { CB_MODE(Mode::SQRT); CB_MODE(Mode::SQUARE); CB_MODE(Mode::SIGN); + CB_MODE(Mode::SAFE_DIV); default: megdnn_assert( 0, diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index a63bd078..07a7535e 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -106,6 +106,7 @@ private: cb(SVDForward) \ cb(ReduceForward) \ cb(CondTake) \ + cb(CumprodForward) \ cb(CumsumForward) \ cb(ArgmaxForward) \ cb(ArgminForward) \ diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 89bfad73..8d13b960 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -62,6 +62,7 @@ DEF(BatchedMatrixMulForward, 3, true, true); DEF(MatrixInverse, 2, true, true); DEF(SVDForward, 4, true, true); DEF(ReduceForward, 2, true, true); +DEF(CumprodForward, 2, true, true); DEF(CumsumForward, 2, true, true); DEF(ArgmaxForward, 2, true, true); DEF(ArgminForward, 2, true, true); diff --git a/dnn/src/cuda/cumprod/cumprod.cu b/dnn/src/cuda/cumprod/cumprod.cu new file mode 100644 index 00000000..a1b21a6f --- /dev/null +++ b/dnn/src/cuda/cumprod/cumprod.cu @@ -0,0 +1,25 @@ +#include "./kern_impl.cuinl" + +namespace megdnn { +namespace cuda { +namespace cumprod { + +#define INST_(T, Op, exclusive, reverse) \ + template void run_kern( \ + T*, void*, uint32_t, uint32_t, uint32_t, uint32_t, const Op&, \ + cudaStream_t) +#define INST(T) \ + INST_(T, ProdOp, true, true); \ + INST_(T, ProdOp, false, true); \ + INST_(T, ProdOp, true, false); \ + INST_(T, ProdOp, false, false); + +#define cb(DType) INST(typename DTypeTrait::ctype) + +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace cumprod +} // namespace cuda +} // namespace megdnn + +// vim: ft=cuda syntax=cuda.doxygen diff --git a/dnn/src/cuda/cumprod/kern.cuh b/dnn/src/cuda/cumprod/kern.cuh new file mode 100644 index 
00000000..d9e94a97 --- /dev/null +++ b/dnn/src/cuda/cumprod/kern.cuh @@ -0,0 +1,62 @@ +#pragma once + +#include "src/cuda/utils.cuh" + +#include +#include + +namespace megdnn { +namespace cuda { +namespace cumprod { + +//! compute conventional sum of elements +template +struct ProdOp { + const T* data; + typedef ProdOp ContigOp; + + ProdOp(const T* d) : data(d) {} + + __host__ __device__ static T init() { return T(1); } + __device__ static T apply(T lhs, T rhs) { return lhs * rhs; } + __device__ T visit(uint32_t idx) const { return data[idx]; } + + static ProdOp make_contig(const T* data) { return ProdOp(data); } +}; + +/*! + * \brief cumprod kernel launcher; defined in kern_impl.cuinl + * \tparam T output data type + * \tparam Op reduction operator class, which must provide following interface: + * typdef ContigOp + * static T init(): the identity element + * static T apply(T lhs, T rhs): the reduction operation + * T visit(uint32_t idx) const: access input + * static ContigOp make_contig(const T *data): make an Oo to continue + * reduction on temp buffer + * + * Note that Op::init() must be accessible from both host and device. + * + * In exclusive mode, Op::init() would be filled to the boundary + * + * The buffer in *op* and *dst* should not have identical memory addresses. + */ +template +void run_kern( + T* dst, void* workspace, uint32_t workspace_size, uint32_t A, uint32_t B, + uint32_t C, const Op& op, cudaStream_t stream); + +/*! + * \brief get required workspace size for cumprod, in bytes + * \param item_size size of item; i.e. sizeof(T) in run_kern + * + * Note: cuda device must be set to the computing device before calling this + * function. + */ +uint32_t get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, uint32_t item_size); + +} // namespace cumprod +} // namespace cuda +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/cumprod/kern_helper.cuh b/dnn/src/cuda/cumprod/kern_helper.cuh new file mode 100644 index 00000000..52a3bdc4 --- /dev/null +++ b/dnn/src/cuda/cumprod/kern_helper.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace megdnn { +namespace cuda { +namespace cumprod { + +void get_BX_BY(uint32_t A, uint32_t B, uint32_t C, uint32_t& BX, uint32_t& BY); + +uint32_t get_workspace_bytes_for_cub_1d(uint32_t nr_item, uint32_t item_size); + +} // namespace cumprod +} // namespace cuda +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/cumprod/kern_impl.cu b/dnn/src/cuda/cumprod/kern_impl.cu new file mode 100644 index 00000000..c3e02a31 --- /dev/null +++ b/dnn/src/cuda/cumprod/kern_impl.cu @@ -0,0 +1,82 @@ +#include "./kern.cuh" +#include "./kern_helper.cuh" +#include "./kern_impl.cuinl" +#include "src/cuda/kernel_common/diagnostic_prologue.cuh" + +using namespace megdnn::cuda; +using namespace cumprod::detail::cubwrap; + +namespace { + +template +struct FakeOp { + __device__ T visit(int) { return 0; } + __device__ static T apply(T, T) { return 0; } +}; + +template +uint32_t get_workspace_elems_for_cub_1d_with_dtype_reverse(uint32_t nr_item) { + typedef FakeOp Op; + Op op; + InputIterator inp_iter(op, nr_item); + OutputIterator out_iter(NULL, nr_item); + ScanOp scan_op; + + size_t wk_size0 = 0, wk_size1 = 0; + cuda_check(cub::DeviceScan::ExclusiveScan( + NULL, wk_size0, inp_iter, out_iter, scan_op, 0, nr_item)); + cuda_check(cub::DeviceScan::InclusiveScan( + NULL, wk_size1, inp_iter, out_iter, scan_op, nr_item)); + return std::max(wk_size0, wk_size1); +} + +template +uint32_t 
get_workspace_elems_for_cub_1d_with_dtype(uint32_t nr_item) { + return std::max( + get_workspace_elems_for_cub_1d_with_dtype_reverse(nr_item), + get_workspace_elems_for_cub_1d_with_dtype_reverse(nr_item)); +} + +} // namespace + +uint32_t cumprod::get_workspace_bytes_for_cub_1d(uint32_t nr_item, uint32_t item_size) { + switch (item_size) { +#define CASE(size, type) \ + case size: \ + return get_workspace_elems_for_cub_1d_with_dtype(nr_item) + CASE(1, uint8_t); + CASE(2, uint16_t); + CASE(4, uint32_t); + CASE(8, uint64_t); +#undef CASE + default: + report_error("unsupported item size in cumprod"); + } +} + +uint32_t cumprod::get_workspace_in_bytes( + uint32_t A, uint32_t B, uint32_t C, uint32_t item_size) { + if (A == 1 && C == 1) { + return get_workspace_bytes_for_cub_1d(B, item_size); + } + uint32_t BX, BY; + get_BX_BY(A, B, C, BX, BY); + uint32_t BY2 = BY * 2; + uint32_t res = 0; + while (B > BY2) { + B = (B + BY2 - 1) / BY2; + res += A * B * C; + } + return res * item_size; +} + +void cumprod::get_BX_BY( + uint32_t /* A */, uint32_t /* B */, uint32_t C, uint32_t& BX, uint32_t& BY) { + BX = 1; + while (BX < C && BX * 2 <= 32) + BX *= 2; + BY = 512 / BX; +} + +#include "src/cuda/kernel_common/diagnostic_epilogue.cuh" +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cumprod/kern_impl.cuinl b/dnn/src/cuda/cumprod/kern_impl.cuinl new file mode 100644 index 00000000..6f8e08a0 --- /dev/null +++ b/dnn/src/cuda/cumprod/kern_impl.cuinl @@ -0,0 +1,326 @@ +#include "./kern.cuh" +#include "./kern_helper.cuh" +#include "megdnn/dtype.h" +#include "src/cuda/cub/device/device_scan.cuh" +#include "src/cuda/cub/util_ptx.cuh" + +namespace megdnn { +namespace cuda { +namespace cumprod { +namespace detail { + +/** + * src shape is (A, B, C), performing blockwise scan over B axis. + * Each CUDA block calculates a blockwise scan result of size (BY2, BX). + * The block area corresponds to a 2-D area on (B, C) dimension of src. + * + * Per-block prefix sum is stored in dst (dst has the same shape as src). + * + * The whole scan result of each block as a single value is stored in + * block_sum (of shape (A, B/BY2, C)). + * + * block_sum can be NULL. + * + * src and dst can be inplace. + * + * We need to launch (C/BX)*ceil(B/BY2)*A blocks in total. + * Because in CUDA the number of launched blocks over y and z axis are + * limited (at most 65535), we launch all blocks over axis x. + * + * Param: exclusive + * This flag specifies whether the scan is inclusive or exclusive, namely + * whether src_i influences dst_i. + * + * Param: reverse: + * This flag specifies whether the scan is forward or backward. 
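+ *
+ * As a concrete illustration (values chosen only for illustration), with Op
+ * being a product and a 1-D src = {2, 3, 4}:
+ *   !exclusive && !reverse: dst = {2, 6, 24}
+ *   !exclusive && reverse:  dst = {24, 12, 4}
+ *   exclusive && !reverse:  dst = {1, 2, 6}
+ *   exclusive && reverse:   dst = {12, 4, 1}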
+ * + * Example: + * !exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_i) + * !exclusive && reverse: dst_i = op(src_i, src_{i+1}, ..., src_{n-1}) + * exclusive && !reverse: dst_i = op(src_0, src_1, ..., src{i-1}) + * exclusive && reverse: dst_i = op(src_{i+1}, src{i+2}, ..., src{n-1}) + * + * Op should have the following methods: + * static T init() + * static T apply(T lhs, T rhs) + */ +template +__global__ void scan_kernel(T *dst, T *block_sum, + uint32_t A, uint32_t B, uint32_t C, const Op op) { + constexpr size_t warp_size = 32; + const uint32_t BY2 = BY*2; + const uint32_t B_ = (B+BY2-1) / BY2; + const uint32_t C_ = (C+BX-1) / BX; + const uint32_t GX = C_; + const uint32_t GY = B_; + // src, dst: (A, B, C) + // block_sum: (A, B_, C) + // shared: (BY2+1, BX) + const uint32_t bx = blockIdx.x % GX; + const uint32_t by = blockIdx.x / GX % GY; + const uint32_t bz = blockIdx.x / GX / GY; + const uint32_t tx = threadIdx.x; + const uint32_t ty = threadIdx.y; + // TODO: shared memory bank conflict optimization +#define shared_idx(x) ((x) + ((x) >> 5)) + volatile __shared__ T cache[shared_idx((BY2+1)*BX)]; + uint32_t base_offset = (bz)*B*C + (by*BY2)*C + (bx*BX); + dst += base_offset; + // load to cache + if (reverse) { + cache[shared_idx((BY2-ty)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ? + op.visit(base_offset + ty*C + tx) : Op::init(); + } else { + cache[shared_idx((ty+1)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ? + op.visit(base_offset + ty*C + tx) : Op::init(); + } + if (reverse) { + cache[shared_idx((BY-ty)*BX+tx)] = + (ty+BY) + by*BY2 < B && tx+bx*BX < C ? + op.visit(base_offset + (ty+BY)*C + tx) : Op::init(); + } else { + cache[shared_idx((ty+BY+1)*BX+tx)] = + (ty+BY) + by*BY2 < B && tx+bx*BX < C ? + op.visit(base_offset + (ty+BY)*C + tx) : Op::init(); + } + if (ty == 0) { + cache[shared_idx(tx)] = Op::init(); + } + __syncthreads(); + uint32_t total, stride; + // first pass +#pragma unroll + for (total = BY, stride = 1; + total > 0; + total >>= 1, stride <<= 1) + { + if (ty < total) { + uint32_t ai = shared_idx(stride * (2*ty+1) * BX + tx); + uint32_t bi = shared_idx(stride * (2*ty+2) * BX + tx); + cache[bi] = Op::apply(cache[bi], cache[ai]); + } + if (total > warp_size/BX) __syncthreads(); + else cub::WARP_SYNC(0xffffffff); + } + // second pass +#pragma unroll + for (total = 1, stride = BY; + stride > 0; + total <<= 1, stride >>= 1) + { + if (total > warp_size/BX) __syncthreads(); + else cub::WARP_SYNC(0xffffffff); + if (ty < total) { + uint32_t ai = shared_idx(stride * (2*ty+0) * BX + tx); + uint32_t bi = shared_idx(stride * (2*ty+1) * BX + tx); + cache[bi] = Op::apply(cache[bi], cache[ai]); + } + } + __syncthreads(); + uint32_t ty_offset = (exclusive ? 
0 : 1); + if (ty+by*BY2 < B && tx+bx*BX < C) { + if (reverse) { + dst[ty*C + tx] = cache[shared_idx((BY2-1-ty+ty_offset)*BX + tx)]; + } else { + dst[ty*C + tx] = cache[shared_idx((ty+ty_offset)*BX + tx)]; + } + } + if (ty+BY+by*BY2 < B && tx+bx*BX < C) { + if (reverse) { + dst[(ty+BY)*C + tx] = + cache[shared_idx((BY2-1-(ty+BY)+ty_offset)*BX + tx)]; + } else { + dst[(ty+BY)*C + tx] = + cache[shared_idx((ty+BY+ty_offset)*BX + tx)]; + } + } + if (block_sum && ty == 0 && bx*BX+tx < C) { + block_sum[(bz)*B_*C + (by)*C + (bx*BX) + tx] = + cache[shared_idx(BY2*BX + tx)]; + } +} + +template +__global__ void update_kernel(T *dst, const T *delta, + uint32_t A, uint32_t B, uint32_t C) { + const uint32_t BY2 = BY*2; + const uint32_t B_ = (B+BY2-1) / BY2; + const uint32_t C_ = (C+BX-1) / BX; + const uint32_t GX = C_; + const uint32_t GY = B_; + // src: (A, B, C) + // delta: (A, B_, C) + const uint32_t bx = blockIdx.x % GX; + const uint32_t by = blockIdx.x / GX % GY; + const uint32_t bz = blockIdx.x / GX / GY; + const uint32_t tx = threadIdx.x; + const uint32_t ty = threadIdx.y; + + if (tx + bx*BX < C) { + T delta_v = delta[(bz)*B_*C + (by)*C + (bx*BX) + tx]; + if (ty+by*BY2 < B && tx+bx*BX < C) { + T &res = dst[bz*B*C + (ty+by*BY2)*C + (tx+bx*BX)]; + res = Op::apply(res, delta_v); + } + if (ty+BY+by*BY2 < B && tx+bx*BX < C) { + T &res = dst[bz*B*C + (ty+BY+by*BY2)*C + (tx+bx*BX)]; + res = Op::apply(res, delta_v); + } + } +} + +template +void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, + uint32_t C, const Op& op, cudaStream_t stream); + +template +void do_run_kern(T *dst, T *workspace, + uint32_t A, uint32_t B, uint32_t C, const Op &op, cudaStream_t stream) { + const uint32_t BY2 = BY*2; + const uint32_t B_ = (B+BY2-1)/BY2; + const uint32_t C_ = (C+BX-1)/BX; + + dim3 blocks(C_*B_*A); + dim3 threads(BX, BY); + + scan_kernel + <<>>( + dst, B > BY2 ? workspace : NULL, A, B, C, op); + if (B <= BY2) + return; + + run_kern_multiAC( + workspace, workspace + A*B_*C, A, B_, C, + Op::make_contig(workspace), stream); + update_kernel<<>>( + dst, workspace, A, B, C); +} + +template +void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, uint32_t C, + const Op& op, cudaStream_t stream) { +#define IF(BX, BY) \ + do { \ + if (vBX == BX && vBY == BY) { \ + return do_run_kern( \ + dst, workspace, A, B, C, op, stream); \ + } \ + } while (0) + + uint32_t vBX, vBY; + get_BX_BY(A, B, C, vBX, vBY); + IF(1, 512); + IF(2, 256); + IF(4, 128); + IF(8, 64); + IF(16, 32); + IF(32, 16); + megdnn_trap(); +#undef IF +} + +//! 
wrap cub library for 1-dim scan +namespace cubwrap { + +template +class InputIterator : public std::iterator { + int m_offset, m_len; + Op m_op; + +public: + InputIterator(Op op, int len) : m_offset(0), m_len(len), m_op(op) {} + + __device__ InputIterator(int offset, int len, Op op) + : m_offset(offset), m_len(len), m_op(op) {} + + __device__ T operator[](int idx) { + idx += m_offset; + if (reverse) { + idx = m_len - 1 - idx; + } + return m_op.visit(idx); + } + + __device__ InputIterator operator+(int offset) { + return InputIterator(m_offset + offset, m_len, m_op); + } +}; + +template +class OutputIterator + : public std::iterator { + int m_offset, m_len; + T* m_dst; + +public: + OutputIterator(T* dst, int len) : m_offset(0), m_len(len), m_dst(dst) {} + + __device__ OutputIterator(int offset, int len, T* dst) + : m_offset(offset), m_len(len), m_dst(dst) {} + + __device__ T& operator[](int idx) { + idx += m_offset; + if (reverse) { + idx = m_len - 1 - idx; + } + return m_dst[idx]; + } + + __device__ OutputIterator operator+(int offset) { + return OutputIterator(m_offset + offset, m_len, m_dst); + } +}; + +template +struct ScanOp { + __device__ __host__ T operator()(T a, T b) { + // cub requires it to be a __device__ __host__ function but MegDNN has + // no such contraint on Op::apply; so we just trap on host +#ifdef __CUDA_ARCH__ + return Op::apply(a, b); +#else + megdnn_trap(); +#endif + } +}; + +template +void invoke(T* dst, void* workspace, size_t wk_size, const Op& op, uint32_t len, + cudaStream_t stream) { + InputIterator inp_iter(op, len); + OutputIterator out_iter(dst, len); + ScanOp scan_op; + + if (exclusive) { + cuda_check(cub::DeviceScan::ExclusiveScan(workspace, wk_size, inp_iter, + out_iter, scan_op, Op::init(), + len, stream)); + } else { + cuda_check(cub::DeviceScan::InclusiveScan( + workspace, wk_size, inp_iter, out_iter, scan_op, len, stream)); + } +} +} // namespace cubwrap + +} // namespace detail + +template +void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A, + uint32_t B, uint32_t C, const Op& op, cudaStream_t stream) { + if (A == 1 && C == 1) { + return detail::cubwrap::invoke( + dst, workspace, workspace_size, op, B, stream); + } + + return detail::run_kern_multiAC( + dst, static_cast(workspace), A, B, C, op, stream); +} + +} // namespace cumprod +} // namespace cuda +} // namespace megdnn + + +// vim: ft=cuda syntax=cuda.doxygen diff --git a/dnn/src/cuda/cumprod/opr_impl.cpp b/dnn/src/cuda/cumprod/opr_impl.cpp new file mode 100644 index 00000000..13826930 --- /dev/null +++ b/dnn/src/cuda/cumprod/opr_impl.cpp @@ -0,0 +1,63 @@ +#include "./opr_impl.h" +#include "./kern.cuh" + +#include "src/common/reduce_helper_device.h" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace cumprod; + +namespace { + +/*! 
+ * \brief compute cumprod reduction on (A, B, C) tensor to (A, 1, C) + */ +template +void dispatch( + T* dst, T* workspace, size_t workspace_size, size_t A, size_t B, size_t C, + bool exclusive, bool reverse, const Op& op, cudaStream_t stream) { +#define IF(exclusive_v, reverse_v) \ + if (exclusive == exclusive_v && reverse == reverse_v) { \ + run_kern( \ + dst, workspace, workspace_size, A, B, C, op, stream); \ + return; \ + } + IF(true, true) + IF(true, false) + IF(false, true) + IF(false, false) + megdnn_assert_internal(false); +#undef IF +} + +} // anonymous namespace + +void CumprodForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + size_t A, B, C; + reduce::get_ABC(src.layout, A, B, C, param().axis); + auto stream = cuda_stream(handle()); +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using ctype = DTypeTrait::ctype; \ + dispatch>( \ + dst.ptr(), workspace.ptr(), workspace.size, A, B, C, \ + param().exclusive, param().reverse, src.ptr(), stream); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb + megdnn_assert_internal(false); +} + +size_t CumprodForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&) { + size_t A, B, C; + reduce::get_ABC(src, A, B, C, param().axis); + cuda_check(cudaSetDevice(concrete_handle(handle())->device_id())); + return cumprod::get_workspace_in_bytes(A, B, C, src.dtype.size()); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cumprod/opr_impl.h b/dnn/src/cuda/cumprod/opr_impl.h new file mode 100644 index 00000000..418b35e2 --- /dev/null +++ b/dnn/src/cuda/cumprod/opr_impl.h @@ -0,0 +1,19 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace cuda { + +class CumprodForwardImpl : public CumprodForward { +public: + using CumprodForward::CumprodForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; +}; + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cu new file mode 100644 index 00000000..cf04b51d --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float16.cu new file mode 100644 index 00000000..bb023926 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float16.cu @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float32.cu new file mode 100644 index 00000000..9c8ef576 --- /dev/null +++ b/dnn/src/cuda/elemwise/kimpl/SAFE_DIV_dt_float32.cu @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define 
KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SAFE_DIV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SAFE_DIV_dt_qint8_dt_qint8.cu new file mode 100644 index 00000000..85d4d9ad --- /dev/null +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SAFE_DIV_dt_qint8_dt_qint8.cu @@ -0,0 +1,6 @@ +// generated by gen_elemwise_multi_type_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 +#include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp index 2f45e524..f381d7b5 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp @@ -268,6 +268,7 @@ IMPL_MODE_DISPATCHER(2, dt_quint4, dt_quint4); #undef FOREACH #define FOREACH(cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LT_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(CLIP, cb) diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index fee747c0..2669eedd 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -16,6 +16,7 @@ #include "src/cuda/convolution3d/opr_impl.h" #include "src/cuda/convpooling/opr_impl.h" #include "src/cuda/correlation/opr_impl.h" +#include "src/cuda/cumprod/opr_impl.h" #include "src/cuda/cumsum/opr_impl.h" #include "src/cuda/cvt_color/opr_impl.h" #include "src/cuda/dct/opr_impl.h" @@ -116,6 +117,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(SVDForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(CondTake); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumprodForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(CumsumForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); diff --git a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp index f0cf6c13..85e79ee1 100644 --- a/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp +++ b/dnn/src/fallback/elemwise/fallback_impl/opr_binary_impl.cpp @@ -266,6 +266,7 @@ INST(Mode::ATANH_GRAD); INST(Mode::SOFTPLUS_GRAD); INST(Mode::RELU6_GRAD); INST(Mode::HSIGMOID_GRAD); +INST(Mode::SAFE_DIV); #undef INST } // namespace fallback } // namespace megdnn diff --git a/dnn/src/naive/cumprod/opr_impl.cpp b/dnn/src/naive/cumprod/opr_impl.cpp new file mode 100644 index 00000000..078b552f --- /dev/null +++ b/dnn/src/naive/cumprod/opr_impl.cpp @@ -0,0 +1,72 @@ +#include "src/naive/cumprod/opr_impl.h" +#include "src/naive/handle.h" + +#include "src/common/reduce_helper.h" +#include "src/common/utils.h" + +namespace { + +template +void exec_internal( + const T* __restrict src, T* __restrict dst, size_t A, size_t B, size_t C, + bool exclusive, bool reverse) { + for (size_t a = 0; a < A; ++a) + for (size_t c = 0; c < C; ++c) { + if (exclusive && reverse) { + T prod = T(1); + for (size_t b = B; b > 0; --b) { + dst[a * B * C + (b - 1) * C + c] = prod; + prod *= src[a * B * C + (b - 1) * C + c]; + } + } else if (exclusive && !reverse) { + T prod = T(1); + for (size_t b = 0; b < B; ++b) { + dst[a * B * C + b * C + c] = prod; + prod *= src[a * B * C + b * C + c]; + } + } else if (!exclusive && reverse) { + T prod = T(1); + for 
(size_t b = B; b > 0; --b) { + prod *= src[a * B * C + (b - 1) * C + c]; + dst[a * B * C + (b - 1) * C + c] = prod; + } + } else if (!exclusive && !reverse) { + T prod = T(1); + for (size_t b = 0; b < B; ++b) { + prod *= src[a * B * C + b * C + c]; + dst[a * B * C + b * C + c] = prod; + } + } else { + megdnn_assert_internal(false); + } + } +} + +} // anonymous namespace + +namespace megdnn { +namespace naive { + +void CumprodForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, workspace.size); + + size_t A, B, C; + reduce::get_ABC(src.layout, A, B, C, param().axis); +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using ctype = DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal( \ + src.ptr(), dst.ptr(), A, B, C, param().exclusive, \ + param().reverse)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + megdnn_assert_internal(0); +#undef cb +} + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/cumprod/opr_impl.h b/dnn/src/naive/cumprod/opr_impl.h new file mode 100644 index 00000000..29b3e88e --- /dev/null +++ b/dnn/src/naive/cumprod/opr_impl.h @@ -0,0 +1,20 @@ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class CumprodForwardImpl : public CumprodForward { +public: + using CumprodForward::CumprodForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp new file mode 100644 index 00000000..cf04b51d --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float16.cpp b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float16.cpp new file mode 100644 index 00000000..bb023926 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float16.cpp @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float32.cpp b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float32.cpp new file mode 100644 index 00000000..9c8ef576 --- /dev/null +++ b/dnn/src/naive/elemwise/kimpl/SAFE_DIV_dt_float32.cpp @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/naive/elemwise/opr_impl.cpp b/dnn/src/naive/elemwise/opr_impl.cpp index 243aa066..cd48b7f6 100644 --- a/dnn/src/naive/elemwise/opr_impl.cpp +++ b/dnn/src/naive/elemwise/opr_impl.cpp @@ -5,6 +5,8 @@ #include "src/naive/elemwise/kern_caller.h" #include "src/naive/handle.h" +#include + #include "midout.h" MIDOUT_DECL(megdnn_naive_elemwise) diff --git 
a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 21d23ad6..d7d06966 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -18,6 +18,7 @@ #include "src/naive/convolution3d/opr_impl.h" #include "src/naive/convpooling/opr_impl.h" #include "src/naive/correlation/opr_impl.h" +#include "src/naive/cumprod/opr_impl.h" #include "src/naive/cumsum/opr_impl.h" #include "src/naive/cvt_color/opr_impl.h" #include "src/naive/dct/opr_impl.h" diff --git a/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp.hip new file mode 100644 index 00000000..94de95ac --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_bfloat16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float16.cpp.hip new file mode 100644 index 00000000..af3eaa2b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float32.cpp.hip new file mode 100644 index 00000000..19234561 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SAFE_DIV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SAFE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 2ef46962..d3123a65 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -815,7 +815,7 @@ DEF_TEST(all_modes) { checker.set_rng(0, &abslt1_rng_f32); } else if ( mode == Mode::MOD || mode == Mode::TRUE_DIV || - mode == Mode::FLOOR_DIV) { + mode == Mode::FLOOR_DIV || mode == Mode::SAFE_DIV) { if (dtype.category() == DTypeCategory::INT) { checker.set_rng(0, &default_rng_i32); checker.set_rng(1, &nonzero_rng_i32); diff --git a/dnn/test/cuda/cumprod.cpp b/dnn/test/cuda/cumprod.cpp new file mode 100644 index 00000000..f7fbe7a9 --- /dev/null +++ b/dnn/test/cuda/cumprod.cpp @@ -0,0 +1,63 @@ +#include "test/cuda/fixture.h" + +#include "megdnn/oprs.h" +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, CUMPROD) { + Checker checker(handle_cuda()); + struct TestArg { + param::Cumprod param; + TensorShape shape; + TestArg(param::Cumprod param, TensorShape shape) : param(param), shape(shape) {} + }; + std::vector args, args_int32; + for (auto shape : + TensorShapeArray{{10000}, {33000, 33}, {100, 100, 100}, {30, 30, 30, 30}}) { + for (size_t axis = 0; axis < shape.ndim; ++axis) { + args.emplace_back(param::Cumprod(axis, true, true), shape); + args.emplace_back(param::Cumprod(axis, true, false), shape); + args.emplace_back(param::Cumprod(axis, false, true), shape); + args.emplace_back(param::Cumprod(axis, false, false), shape); + } + } + for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}, {100000}}) { + 
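+        // note: 1-D shapes with axis 0 give A == 1 && C == 1 in get_ABC, so these
+        // cases exercise the cub DeviceScan path of the CUDA kernel; all four
+        // (exclusive, reverse) combinations are covered below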
args.emplace_back(param::Cumprod(0, true, true), shape); + args.emplace_back(param::Cumprod(0, true, false), shape); + args.emplace_back(param::Cumprod(0, false, true), shape); + args.emplace_back(param::Cumprod(0, false, false), shape); + } + for (auto shape : TensorShapeArray{ + {1}, + {10}, + {100}, + {1000}, + {10000}, + {100000}, + {1000000}, + {1050000}, + {2100000}}) { + args_int32.emplace_back(param::Cumprod(0, true, true), shape); + args_int32.emplace_back(param::Cumprod(0, true, false), shape); + args_int32.emplace_back(param::Cumprod(0, false, true), shape); + args_int32.emplace_back(param::Cumprod(0, false, false), shape); + } + for (auto arg : args) { + checker.set_param(arg.param); + checker.set_epsilon(1e-2); + checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}}); + checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}}); + checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); + } + for (auto arg : args_int32) { + checker.set_param(arg.param); + checker.set_epsilon(1e-2); + checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); + } +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/test/naive/record1.cpp b/dnn/test/naive/record1.cpp index a715782a..cfa84c3f 100644 --- a/dnn/test/naive/record1.cpp +++ b/dnn/test/naive/record1.cpp @@ -494,6 +494,49 @@ TEST_F(NAIVE, CONV3D_RECORD) { } } +//! cumprod +TEST_F(NAIVE, CUMPROD_RECORD) { + TaskRecordChecker checker(2); + struct TestArg { + param::Cumprod param; + TensorShape shape; + TestArg(param::Cumprod param, TensorShape shape) : param(param), shape(shape) {} + }; + std::vector args, args_int32; + for (auto shape : TensorShapeArray{{1000}, {330, 33}, {10, 10, 10}, {5, 5, 5, 5}}) { + for (size_t axis = 0; axis < shape.ndim; ++axis) { + args.emplace_back(param::Cumprod(axis, true, true), shape); + args.emplace_back(param::Cumprod(axis, true, false), shape); + args.emplace_back(param::Cumprod(axis, false, true), shape); + args.emplace_back(param::Cumprod(axis, false, false), shape); + } + } + for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) { + args.emplace_back(param::Cumprod(0, true, true), shape); + args.emplace_back(param::Cumprod(0, true, false), shape); + args.emplace_back(param::Cumprod(0, false, true), shape); + args.emplace_back(param::Cumprod(0, false, false), shape); + } + for (auto shape : TensorShapeArray{{1}, {10}, {100}, {1000}, {10000}}) { + args_int32.emplace_back(param::Cumprod(0, true, true), shape); + args_int32.emplace_back(param::Cumprod(0, true, false), shape); + args_int32.emplace_back(param::Cumprod(0, false, true), shape); + args_int32.emplace_back(param::Cumprod(0, false, false), shape); + } + for (auto arg : args) { + checker.set_param(arg.param); + checker.set_epsilon(1e-2); + checker.set_dtype(0, dtype::Float32()).execs({{arg.shape}, {}}); + checker.set_dtype(0, dtype::Int16()).execs({{arg.shape}, {}}); + checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); + } + for (auto arg : args_int32) { + checker.set_param(arg.param); + checker.set_epsilon(1e-2); + checker.set_dtype(0, dtype::Int32()).execs({{arg.shape}, {}}); + } +} + //! 
cumsum TEST_F(NAIVE, CUMSUM_RECORD) { TaskRecordChecker checker(2); diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 24eabf54..7ce9b2de 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -27,6 +27,7 @@ __all__ = [ "broadcast_to", "concat", "cond_take", + "cumprod", "cumsum", "diag", "expand_dims", @@ -1139,6 +1140,22 @@ def cumsum(inp: Tensor, axis: int): Tensor([[ 1 3 6] [ 4 9 15]], dtype=int32, device=xpux:0) """ - assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor" op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False) return apply(op, inp)[0] + + +def cumprod(inp: Tensor, axis: int): + r"""Computes the cumulative product of elements along given axis. + + Args: + inp: input tensor. + axis: axis along which cumprod is performed. + + Examples: + >>> x = Tensor([[1, 2, 3], [4, 5, 6]], "int32") + >>> F.cumprod(x, 1) + Tensor([[ 1 2 6] + [ 4 20 120]], dtype=int32, device=xpux:0) + """ + op = builtin.Cumprod(axis=axis, exclusive=False, reverse=False) + return apply(op, inp)[0] diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index a404b7c4..770149fe 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -501,6 +501,22 @@ def test_dot(): np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy()) +def test_cumprod(): + x = mge.Parameter(F.full((2, 2), 2.0, dtype="float32")) + + with Grad() as grad: + grad.wrt(x, callback=save_to(x)) + + def f(x): + return F.cumprod(x, axis=0) + + y = f(x) + grad(y, F.ones_like(y)) + + expected = np.array([[3.0, 3.0], [2.0, 2.0]], dtype=np.float32) + np.testing.assert_almost_equal(x.grad.numpy(), expected) + + def test_pixel_shuffle(): x = np.random.rand(2, 3, 16, 3, 4).astype("float32") diff --git a/imperative/src/impl/ops/cumxxx.cpp b/imperative/src/impl/ops/cumxxx.cpp new file mode 100644 index 00000000..c05a235e --- /dev/null +++ b/imperative/src/impl/ops/cumxxx.cpp @@ -0,0 +1,34 @@ +#include "../blob_manager_impl.h" +#include "../dnn_op_helper.h" +#include "../op_trait.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/misc.h" + +namespace mgb::imperative { + +namespace { +namespace cumsum { + +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = static_cast(def); + OperatorNodeConfig config{op.make_name()}; + return opr::Cumsum::make(inputs[0], op.param(), config); +} + +OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace cumsum +} // namespace + +namespace { +namespace cumprod { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& op = def.cast_final_safe(); + OperatorNodeConfig config{op.make_name()}; + return opr::Cumprod::make(inputs[0], op.param(), config); +} + +OP_TRAIT_REG(Cumprod, Cumprod).apply_on_var_node(apply_on_var_node).fallback(); +} // namespace cumprod +} // namespace + +} // namespace mgb::imperative diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index b049c432..3b7b4d38 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -652,18 +652,6 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose) } // namespace sliding_window_transpose } // namespace -namespace { -namespace cumsum { -auto 
apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = static_cast(def); - OperatorNodeConfig config{op.make_name()}; - return opr::Cumsum::make(inputs[0], op.param(), config); -} - -OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback(); -} // namespace cumsum -} // namespace - namespace lrn { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt index d0a73119..39699415 100644 --- a/imperative/tablegen/generated/hash.txt +++ b/imperative/tablegen/generated/hash.txt @@ -1,7 +1,7 @@ -905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py -759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td -2a5851d0e2470d4d045811e7a20b1a3f generated/opdef.h.inl -55b862badeed19aed8e84c5d6f468ff2 generated/opdef.cpp.inl -f3f4c7f0ee1b39392df8a679f6d22596 generated/opdef.py.inl -6b11ca844a7855fdc5eebffaf563a89c generated/opdef.cpy.inl +792a91e469d151906d5275ebd65cedf0 ../../dnn/scripts/opr_param_defs.py +dcf9ae8b7881e9f93870aaed1b18f1dd ../../src/core/include/megbrain/ir/ops.td +59bf6c34770e60b030d962484f2eac3b generated/opdef.h.inl +d406f17445b9cec6c34b77377f15a61f generated/opdef.cpp.inl +67fd2dee346d795220231fc266d260cc generated/opdef.py.inl +65740fff1c79fdb2f955dd52df88fb6e generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index a4b4adf9..6c5ec041 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -2303,6 +2303,49 @@ OP_TRAIT_REG(Correlation, Correlation) .props(Correlation_props_impl) .make_name(Correlation_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumprod); + +namespace { +size_t Cumprod_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::hash(op_.axis)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.exclusive)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.reverse)); + return val; +} +bool Cumprod_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.axis != b_.axis) return false; + if (a_.exclusive != b_.exclusive) return false; + if (a_.reverse != b_.reverse) return false; + return true; +} +std::vector> Cumprod_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + props_.emplace_back("axis", std::to_string(op_.axis)); + props_.emplace_back("exclusive", std::to_string(op_.exclusive)); + props_.emplace_back("reverse", std::to_string(op_.reverse)); + return props_; +} +std::string Cumprod_make_name_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "Cumprod"; +} +} // anonymous namespace +OP_TRAIT_REG(Cumprod, Cumprod) + .hash(Cumprod_hash_impl) + .is_same_st(Cumprod_is_same_st_impl) + .props(Cumprod_props_impl) + .make_name(Cumprod_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); namespace { @@ -3067,6 +3110,75 @@ std::vector> Elemwise_props_impl(const OpDef case Elemwise::Mode::COND_LT_MOV: props_.emplace_back("mode", "COND_LT_MOV"); break; + case Elemwise::Mode::SINH: + props_.emplace_back("mode", "SINH"); + break; + case 
Elemwise::Mode::COSH: + props_.emplace_back("mode", "COSH"); + break; + case Elemwise::Mode::ASINH: + props_.emplace_back("mode", "ASINH"); + break; + case Elemwise::Mode::ACOSH: + props_.emplace_back("mode", "ACOSH"); + break; + case Elemwise::Mode::ATANH: + props_.emplace_back("mode", "ATANH"); + break; + case Elemwise::Mode::TAN: + props_.emplace_back("mode", "TAN"); + break; + case Elemwise::Mode::ASINH_GRAD: + props_.emplace_back("mode", "ASINH_GRAD"); + break; + case Elemwise::Mode::ACOSH_GRAD: + props_.emplace_back("mode", "ACOSH_GRAD"); + break; + case Elemwise::Mode::ATANH_GRAD: + props_.emplace_back("mode", "ATANH_GRAD"); + break; + case Elemwise::Mode::PRELU: + props_.emplace_back("mode", "PRELU"); + break; + case Elemwise::Mode::CLIP: + props_.emplace_back("mode", "CLIP"); + break; + case Elemwise::Mode::PRELU_GRAD: + props_.emplace_back("mode", "PRELU_GRAD"); + break; + case Elemwise::Mode::SOFTPLUS: + props_.emplace_back("mode", "SOFTPLUS"); + break; + case Elemwise::Mode::SOFTPLUS_GRAD: + props_.emplace_back("mode", "SOFTPLUS_GRAD"); + break; + case Elemwise::Mode::RELU6: + props_.emplace_back("mode", "RELU6"); + break; + case Elemwise::Mode::RELU6_GRAD: + props_.emplace_back("mode", "RELU6_GRAD"); + break; + case Elemwise::Mode::HSIGMOID: + props_.emplace_back("mode", "HSIGMOID"); + break; + case Elemwise::Mode::HSIGMOID_GRAD: + props_.emplace_back("mode", "HSIGMOID_GRAD"); + break; + case Elemwise::Mode::LOGSIGMOID: + props_.emplace_back("mode", "LOGSIGMOID"); + break; + case Elemwise::Mode::SQRT: + props_.emplace_back("mode", "SQRT"); + break; + case Elemwise::Mode::SQUARE: + props_.emplace_back("mode", "SQUARE"); + break; + case Elemwise::Mode::SIGN: + props_.emplace_back("mode", "SIGN"); + break; + case Elemwise::Mode::SAFE_DIV: + props_.emplace_back("mode", "SAFE_DIV"); + break; case Elemwise::Mode::NEQ: props_.emplace_back("mode", "NEQ"); break; diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 68704332..0cc0d4f6 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -6674,6 +6674,131 @@ void _init_py_Correlation(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Correlation::typeinfo(), &py_type).second); } +PyOpDefBegin(Cumprod) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"axis", serialization::dump(opdef.axis)}, + {"exclusive", serialization::dump(opdef.exclusive)}, + {"reverse", serialization::dump(opdef.reverse)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("axis"); + if (iter != state.end()) { + opdef.axis = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("exclusive"); + if (iter != state.end()) { + opdef.exclusive = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("reverse"); + if (iter != state.end()) { + opdef.reverse = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd(Cumprod) + +int PyOp(Cumprod)::py_init(PyObject 
*self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"axis", "exclusive", "reverse", "scope", NULL}; + PyObject *axis = NULL, *exclusive = NULL, *reverse = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOO", const_cast(kwlist), &axis, &exclusive, &reverse, &scope)) + return -1; + + if (axis) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().axis = + py::cast(py::handle(axis)); + } CATCH_ALL(-1) + } + + if (exclusive) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().exclusive = + py::cast(py::handle(exclusive)); + } CATCH_ALL(-1) + } + + if (reverse) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().reverse = + py::cast(py::handle(reverse)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(Cumprod)::py_getsetters[] = { + {const_cast("axis"), py_get_generic(Cumprod, axis), py_set_generic(Cumprod, axis), const_cast("axis"), NULL}, + {const_cast("exclusive"), py_get_generic(Cumprod, exclusive), py_set_generic(Cumprod, exclusive), const_cast("exclusive"), NULL}, + {const_cast("reverse"), py_get_generic(Cumprod, reverse), py_set_generic(Cumprod, reverse), const_cast("reverse"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(Cumprod)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(Cumprod)::getstate, METH_NOARGS, "Cumprod getstate"}, + {const_cast("__setstate__"), PyOp(Cumprod)::setstate, METH_VARARGS, "Cumprod setstate"}, + {NULL} /* Sentinel */ + }; + +void _init_py_Cumprod(py::module m) { + using py_op = PyOp(Cumprod); + auto& py_type = PyOpType(Cumprod); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.Cumprod"; + py_type.tp_basicsize = sizeof(PyOp(Cumprod)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "Cumprod"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + + PyType_Modified(&py_type); + m.add_object("Cumprod", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(Cumprod::typeinfo(), &py_type).second); +} + PyOpDefBegin(Cumsum) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -8010,16 +8135,16 @@ void _init_py_Dropout(py::module m) { template<> struct EnumTrait { static constexpr const char *name = "Elemwise.Mode"; - static constexpr std::underlying_type_t max = 64 - 1; + static constexpr std::underlying_type_t max = 87 - 1; }; template<> PyTypeObject* EnumWrapper::type = nullptr; template<> const char* -EnumWrapper::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", 
"FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "NEQ", "ISNAN", "ISINF"}; +EnumWrapper::members[] = {"RELU", "ABS", "ACOS", "ASIN", "CEIL", "COS", "EXP", "EXPM1", "FLOOR", "LOG", "LOG1P", "NEGATE", "SIGMOID", "SIN", "TANH", "ABS_GRAD", "ADD", "FLOOR_DIV", "MAX", "MIN", "MOD", "MUL", "POW", "SIGMOID_GRAD", "SUB", "SWITCH_GT0", "TANH_GRAD", "TRUE_DIV", "LOG_SUM_EXP", "LT", "LEQ", "EQ", "SHL", "SHR", "COND_LEQ_MOV", "FUSE_MUL_ADD3", "FUSE_MUL_ADD4", "FUSE_ADD_RELU", "FUSE_ADD_SIGMOID", "FUSE_ADD_TANH", "FAST_TANH", "FAST_TANH_GRAD", "ROUND", "RMULH", "ATAN2", "ERF", "ERFINV", "ERFC", "ERFCINV", "H_SWISH", "H_SWISH_GRAD", "FUSE_ADD_H_SWISH", "NOT", "AND", "OR", "XOR", "SILU", "SILU_GRAD", "GELU", "GELU_GRAD", "COND_LT_MOV", "SINH", "COSH", "ASINH", "ACOSH", "ATANH", "TAN", "ASINH_GRAD", "ACOSH_GRAD", "ATANH_GRAD", "PRELU", "CLIP", "PRELU_GRAD", "SOFTPLUS", "SOFTPLUS_GRAD", "RELU6", "RELU6_GRAD", "HSIGMOID", "HSIGMOID_GRAD", "LOGSIGMOID", "SQRT", "SQUARE", "SIGN", "SAFE_DIV", "NEQ", "ISNAN", "ISINF"}; template<> std::unordered_map -EnumWrapper::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, 
{normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, {normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}}; -template<> PyObject* EnumWrapper::pyobj_insts[64] = {nullptr}; +EnumWrapper::mem2value = {{normalize_enum("RELU"), Elemwise::Mode::RELU}, {normalize_enum("ABS"), Elemwise::Mode::ABS}, {normalize_enum("ACOS"), Elemwise::Mode::ACOS}, {normalize_enum("ASIN"), Elemwise::Mode::ASIN}, {normalize_enum("CEIL"), Elemwise::Mode::CEIL}, {normalize_enum("COS"), Elemwise::Mode::COS}, {normalize_enum("EXP"), Elemwise::Mode::EXP}, {normalize_enum("EXPM1"), Elemwise::Mode::EXPM1}, {normalize_enum("FLOOR"), Elemwise::Mode::FLOOR}, {normalize_enum("LOG"), Elemwise::Mode::LOG}, {normalize_enum("LOG1P"), Elemwise::Mode::LOG1P}, {normalize_enum("NEGATE"), Elemwise::Mode::NEGATE}, {normalize_enum("SIGMOID"), Elemwise::Mode::SIGMOID}, {normalize_enum("SIN"), Elemwise::Mode::SIN}, {normalize_enum("TANH"), Elemwise::Mode::TANH}, {normalize_enum("ABS_GRAD"), Elemwise::Mode::ABS_GRAD}, {normalize_enum("ADD"), Elemwise::Mode::ADD}, {normalize_enum("FLOOR_DIV"), Elemwise::Mode::FLOOR_DIV}, {normalize_enum("MAX"), Elemwise::Mode::MAX}, {normalize_enum("MIN"), Elemwise::Mode::MIN}, {normalize_enum("MOD"), Elemwise::Mode::MOD}, {normalize_enum("MUL"), Elemwise::Mode::MUL}, {normalize_enum("POW"), Elemwise::Mode::POW}, {normalize_enum("SIGMOID_GRAD"), Elemwise::Mode::SIGMOID_GRAD}, {normalize_enum("SUB"), Elemwise::Mode::SUB}, {normalize_enum("SWITCH_GT0"), Elemwise::Mode::SWITCH_GT0}, {normalize_enum("TANH_GRAD"), Elemwise::Mode::TANH_GRAD}, {normalize_enum("TRUE_DIV"), Elemwise::Mode::TRUE_DIV}, {normalize_enum("LOG_SUM_EXP"), Elemwise::Mode::LOG_SUM_EXP}, {normalize_enum("LT"), Elemwise::Mode::LT}, {normalize_enum("LEQ"), Elemwise::Mode::LEQ}, {normalize_enum("EQ"), Elemwise::Mode::EQ}, {normalize_enum("SHL"), Elemwise::Mode::SHL}, {normalize_enum("SHR"), Elemwise::Mode::SHR}, {normalize_enum("COND_LEQ_MOV"), Elemwise::Mode::COND_LEQ_MOV}, {normalize_enum("FUSE_MUL_ADD3"), Elemwise::Mode::FUSE_MUL_ADD3}, {normalize_enum("FUSE_MUL_ADD4"), Elemwise::Mode::FUSE_MUL_ADD4}, {normalize_enum("FUSE_ADD_RELU"), Elemwise::Mode::FUSE_ADD_RELU}, {normalize_enum("FUSE_ADD_SIGMOID"), Elemwise::Mode::FUSE_ADD_SIGMOID}, {normalize_enum("FUSE_ADD_TANH"), Elemwise::Mode::FUSE_ADD_TANH}, {normalize_enum("FAST_TANH"), Elemwise::Mode::FAST_TANH}, {normalize_enum("FAST_TANH_GRAD"), Elemwise::Mode::FAST_TANH_GRAD}, {normalize_enum("ROUND"), Elemwise::Mode::ROUND}, {normalize_enum("RMULH"), Elemwise::Mode::RMULH}, {normalize_enum("ATAN2"), Elemwise::Mode::ATAN2}, 
{normalize_enum("ERF"), Elemwise::Mode::ERF}, {normalize_enum("ERFINV"), Elemwise::Mode::ERFINV}, {normalize_enum("ERFC"), Elemwise::Mode::ERFC}, {normalize_enum("ERFCINV"), Elemwise::Mode::ERFCINV}, {normalize_enum("H_SWISH"), Elemwise::Mode::H_SWISH}, {normalize_enum("H_SWISH_GRAD"), Elemwise::Mode::H_SWISH_GRAD}, {normalize_enum("FUSE_ADD_H_SWISH"), Elemwise::Mode::FUSE_ADD_H_SWISH}, {normalize_enum("NOT"), Elemwise::Mode::NOT}, {normalize_enum("AND"), Elemwise::Mode::AND}, {normalize_enum("OR"), Elemwise::Mode::OR}, {normalize_enum("XOR"), Elemwise::Mode::XOR}, {normalize_enum("SILU"), Elemwise::Mode::SILU}, {normalize_enum("SILU_GRAD"), Elemwise::Mode::SILU_GRAD}, {normalize_enum("GELU"), Elemwise::Mode::GELU}, {normalize_enum("GELU_GRAD"), Elemwise::Mode::GELU_GRAD}, {normalize_enum("COND_LT_MOV"), Elemwise::Mode::COND_LT_MOV}, {normalize_enum("SINH"), Elemwise::Mode::SINH}, {normalize_enum("COSH"), Elemwise::Mode::COSH}, {normalize_enum("ASINH"), Elemwise::Mode::ASINH}, {normalize_enum("ACOSH"), Elemwise::Mode::ACOSH}, {normalize_enum("ATANH"), Elemwise::Mode::ATANH}, {normalize_enum("TAN"), Elemwise::Mode::TAN}, {normalize_enum("ASINH_GRAD"), Elemwise::Mode::ASINH_GRAD}, {normalize_enum("ACOSH_GRAD"), Elemwise::Mode::ACOSH_GRAD}, {normalize_enum("ATANH_GRAD"), Elemwise::Mode::ATANH_GRAD}, {normalize_enum("PRELU"), Elemwise::Mode::PRELU}, {normalize_enum("CLIP"), Elemwise::Mode::CLIP}, {normalize_enum("PRELU_GRAD"), Elemwise::Mode::PRELU_GRAD}, {normalize_enum("SOFTPLUS"), Elemwise::Mode::SOFTPLUS}, {normalize_enum("SOFTPLUS_GRAD"), Elemwise::Mode::SOFTPLUS_GRAD}, {normalize_enum("RELU6"), Elemwise::Mode::RELU6}, {normalize_enum("RELU6_GRAD"), Elemwise::Mode::RELU6_GRAD}, {normalize_enum("HSIGMOID"), Elemwise::Mode::HSIGMOID}, {normalize_enum("HSIGMOID_GRAD"), Elemwise::Mode::HSIGMOID_GRAD}, {normalize_enum("LOGSIGMOID"), Elemwise::Mode::LOGSIGMOID}, {normalize_enum("SQRT"), Elemwise::Mode::SQRT}, {normalize_enum("SQUARE"), Elemwise::Mode::SQUARE}, {normalize_enum("SIGN"), Elemwise::Mode::SIGN}, {normalize_enum("SAFE_DIV"), Elemwise::Mode::SAFE_DIV}, {normalize_enum("NEQ"), Elemwise::Mode::NEQ}, {normalize_enum("ISNAN"), Elemwise::Mode::ISNAN}, {normalize_enum("ISINF"), Elemwise::Mode::ISINF}}; +template<> PyObject* EnumWrapper::pyobj_insts[87] = {nullptr}; void _init_py_Elemwise_Mode(PyTypeObject& py_type) { auto& e_type = EnumWrapper::type; @@ -8374,19 +8499,134 @@ void _init_py_Elemwise_Mode(PyTypeObject& py_type) { EnumWrapper::pyobj_insts[60] = inst; }{ PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SINH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SINH", inst) >= 0); + EnumWrapper::pyobj_insts[61] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::COSH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "COSH", inst) >= 0); + EnumWrapper::pyobj_insts[62] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ASINH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH", inst) >= 0); + EnumWrapper::pyobj_insts[63] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ACOSH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH", inst) >= 0); + EnumWrapper::pyobj_insts[64] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ATANH; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, 
"ATANH", inst) >= 0); + EnumWrapper::pyobj_insts[65] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::TAN; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "TAN", inst) >= 0); + EnumWrapper::pyobj_insts[66] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ASINH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ASINH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[67] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ACOSH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ACOSH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[68] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::ATANH_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ATANH_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[69] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::PRELU; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU", inst) >= 0); + EnumWrapper::pyobj_insts[70] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::CLIP; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "CLIP", inst) >= 0); + EnumWrapper::pyobj_insts[71] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::PRELU_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "PRELU_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[72] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SOFTPLUS; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS", inst) >= 0); + EnumWrapper::pyobj_insts[73] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SOFTPLUS_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SOFTPLUS_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[74] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::RELU6; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6", inst) >= 0); + EnumWrapper::pyobj_insts[75] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::RELU6_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "RELU6_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[76] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::HSIGMOID; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID", inst) >= 0); + EnumWrapper::pyobj_insts[77] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::HSIGMOID_GRAD; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "HSIGMOID_GRAD", inst) >= 0); + EnumWrapper::pyobj_insts[78] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::LOGSIGMOID; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "LOGSIGMOID", inst) >= 0); + EnumWrapper::pyobj_insts[79] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SQRT; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQRT", inst) >= 0); + EnumWrapper::pyobj_insts[80] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); 
+ reinterpret_cast*>(inst)->value = Elemwise::Mode::SQUARE; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SQUARE", inst) >= 0); + EnumWrapper::pyobj_insts[81] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SIGN; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SIGN", inst) >= 0); + EnumWrapper::pyobj_insts[82] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); + reinterpret_cast*>(inst)->value = Elemwise::Mode::SAFE_DIV; + mgb_assert(PyDict_SetItemString(e_type->tp_dict, "SAFE_DIV", inst) >= 0); + EnumWrapper::pyobj_insts[83] = inst; +}{ + PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast*>(inst)->value = Elemwise::Mode::NEQ; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "NEQ", inst) >= 0); - EnumWrapper::pyobj_insts[61] = inst; + EnumWrapper::pyobj_insts[84] = inst; }{ PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast*>(inst)->value = Elemwise::Mode::ISNAN; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ISNAN", inst) >= 0); - EnumWrapper::pyobj_insts[62] = inst; + EnumWrapper::pyobj_insts[85] = inst; }{ PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast*>(inst)->value = Elemwise::Mode::ISINF; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "ISINF", inst) >= 0); - EnumWrapper::pyobj_insts[63] = inst; + EnumWrapper::pyobj_insts[86] = inst; } Py_INCREF(e_type); mgb_assert(PyDict_SetItemString( @@ -18216,6 +18456,7 @@ void _init_py_WarpPerspective(py::module m) { _init_py_ConvolutionBackwardData(m); \ _init_py_Copy(m); \ _init_py_Correlation(m); \ + _init_py_Cumprod(m); \ _init_py_Cumsum(m); \ _init_py_CvtColor(m); \ _init_py_DeformableConv(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index cd51a572..8545893c 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -567,6 +567,21 @@ public: } }; +class Cumprod : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + int32_t axis = 2147483647; + bool exclusive = true; + bool reverse = false; + Cumprod() = default; + Cumprod(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); } + Cumprod(::megdnn::param::Cumprod packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {} + ::megdnn::param::Cumprod param() const { + return {axis, exclusive, reverse}; + } +}; + class Cumsum : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; @@ -780,6 +795,29 @@ case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD"; case Elemwise::Mode::GELU: return "GELU"; case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD"; case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV"; +case Elemwise::Mode::SINH: return "SINH"; +case Elemwise::Mode::COSH: return "COSH"; +case Elemwise::Mode::ASINH: return "ASINH"; +case Elemwise::Mode::ACOSH: return "ACOSH"; +case Elemwise::Mode::ATANH: return "ATANH"; +case Elemwise::Mode::TAN: return "TAN"; +case Elemwise::Mode::ASINH_GRAD: return "ASINH_GRAD"; +case Elemwise::Mode::ACOSH_GRAD: return "ACOSH_GRAD"; +case Elemwise::Mode::ATANH_GRAD: return "ATANH_GRAD"; +case Elemwise::Mode::PRELU: return "PRELU"; +case Elemwise::Mode::CLIP: return "CLIP"; +case Elemwise::Mode::PRELU_GRAD: return "PRELU_GRAD"; +case Elemwise::Mode::SOFTPLUS: return "SOFTPLUS"; +case Elemwise::Mode::SOFTPLUS_GRAD: return "SOFTPLUS_GRAD"; +case Elemwise::Mode::RELU6: return 
"RELU6"; +case Elemwise::Mode::RELU6_GRAD: return "RELU6_GRAD"; +case Elemwise::Mode::HSIGMOID: return "HSIGMOID"; +case Elemwise::Mode::HSIGMOID_GRAD: return "HSIGMOID_GRAD"; +case Elemwise::Mode::LOGSIGMOID: return "LOGSIGMOID"; +case Elemwise::Mode::SQRT: return "SQRT"; +case Elemwise::Mode::SQUARE: return "SQUARE"; +case Elemwise::Mode::SIGN: return "SIGN"; +case Elemwise::Mode::SAFE_DIV: return "SAFE_DIV"; case Elemwise::Mode::NEQ: return "NEQ"; case Elemwise::Mode::ISNAN: return "ISNAN"; case Elemwise::Mode::ISINF: return "ISINF"; diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index 631639c7..b594b261 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -678,6 +678,14 @@ CorrelationInst .def_readwrite("pad_size", &Correlation::pad_size) .def_readwrite("is_multiply", &Correlation::is_multiply); +py::class_, OpDef> CumprodInst(m, "Cumprod"); + +CumprodInst + .def(py::init(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {}) + .def_readwrite("axis", &Cumprod::axis) + .def_readwrite("exclusive", &Cumprod::exclusive) + .def_readwrite("reverse", &Cumprod::reverse); + py::class_, OpDef> CumsumInst(m, "Cumsum"); CumsumInst @@ -893,6 +901,29 @@ py::enum_(ElemwiseInst, "Mode") .value("GELU", Elemwise::Mode::GELU) .value("GELU_GRAD", Elemwise::Mode::GELU_GRAD) .value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV) + .value("SINH", Elemwise::Mode::SINH) + .value("COSH", Elemwise::Mode::COSH) + .value("ASINH", Elemwise::Mode::ASINH) + .value("ACOSH", Elemwise::Mode::ACOSH) + .value("ATANH", Elemwise::Mode::ATANH) + .value("TAN", Elemwise::Mode::TAN) + .value("ASINH_GRAD", Elemwise::Mode::ASINH_GRAD) + .value("ACOSH_GRAD", Elemwise::Mode::ACOSH_GRAD) + .value("ATANH_GRAD", Elemwise::Mode::ATANH_GRAD) + .value("PRELU", Elemwise::Mode::PRELU) + .value("CLIP", Elemwise::Mode::CLIP) + .value("PRELU_GRAD", Elemwise::Mode::PRELU_GRAD) + .value("SOFTPLUS", Elemwise::Mode::SOFTPLUS) + .value("SOFTPLUS_GRAD", Elemwise::Mode::SOFTPLUS_GRAD) + .value("RELU6", Elemwise::Mode::RELU6) + .value("RELU6_GRAD", Elemwise::Mode::RELU6_GRAD) + .value("HSIGMOID", Elemwise::Mode::HSIGMOID) + .value("HSIGMOID_GRAD", Elemwise::Mode::HSIGMOID_GRAD) + .value("LOGSIGMOID", Elemwise::Mode::LOGSIGMOID) + .value("SQRT", Elemwise::Mode::SQRT) + .value("SQUARE", Elemwise::Mode::SQUARE) + .value("SIGN", Elemwise::Mode::SIGN) + .value("SAFE_DIV", Elemwise::Mode::SAFE_DIV) .value("NEQ", Elemwise::Mode::NEQ) .value("ISNAN", Elemwise::Mode::ISNAN) .value("ISINF", Elemwise::Mode::ISINF) @@ -959,6 +990,29 @@ py::enum_(ElemwiseInst, "Mode") if (str == "GELU") return Elemwise::Mode::GELU; if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD; if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV; + if (str == "SINH") return Elemwise::Mode::SINH; + if (str == "COSH") return Elemwise::Mode::COSH; + if (str == "ASINH") return Elemwise::Mode::ASINH; + if (str == "ACOSH") return Elemwise::Mode::ACOSH; + if (str == "ATANH") return Elemwise::Mode::ATANH; + if (str == "TAN") return Elemwise::Mode::TAN; + if (str == "ASINH_GRAD") return Elemwise::Mode::ASINH_GRAD; + if (str == "ACOSH_GRAD") return Elemwise::Mode::ACOSH_GRAD; + if (str == "ATANH_GRAD") return Elemwise::Mode::ATANH_GRAD; + if (str == "PRELU") return Elemwise::Mode::PRELU; + if (str == "CLIP") return Elemwise::Mode::CLIP; + if (str == "PRELU_GRAD") return Elemwise::Mode::PRELU_GRAD; + if (str == "SOFTPLUS") 
return Elemwise::Mode::SOFTPLUS; + if (str == "SOFTPLUS_GRAD") return Elemwise::Mode::SOFTPLUS_GRAD; + if (str == "RELU6") return Elemwise::Mode::RELU6; + if (str == "RELU6_GRAD") return Elemwise::Mode::RELU6_GRAD; + if (str == "HSIGMOID") return Elemwise::Mode::HSIGMOID; + if (str == "HSIGMOID_GRAD") return Elemwise::Mode::HSIGMOID_GRAD; + if (str == "LOGSIGMOID") return Elemwise::Mode::LOGSIGMOID; + if (str == "SQRT") return Elemwise::Mode::SQRT; + if (str == "SQUARE") return Elemwise::Mode::SQUARE; + if (str == "SIGN") return Elemwise::Mode::SIGN; + if (str == "SAFE_DIV") return Elemwise::Mode::SAFE_DIV; if (str == "NEQ") return Elemwise::Mode::NEQ; if (str == "ISNAN") return Elemwise::Mode::ISNAN; if (str == "ISINF") return Elemwise::Mode::ISINF; diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index d70cd20b..1853cdfc 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -472,6 +472,7 @@ def ExternOpr: MgbHashableOp<"ExternOpr"> { } def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>; +def Cumprod: MgbHashableOp<"Cumprod", [CumprodParam]>; def Split: MgbHashableOp<"Split", [EmptyParam]> { let extraArguments = (ins diff --git a/src/gopt/test/basic_arith.cpp b/src/gopt/test/basic_arith.cpp index c61ce9eb..483cc7af 100644 --- a/src/gopt/test/basic_arith.cpp +++ b/src/gopt/test/basic_arith.cpp @@ -107,6 +107,12 @@ TEST(TestGoptBasicArithInplace, Absorbing) { ASSERT_EQ(y.as_immutable_scalar()->get_cast(), 0.f); } +auto gen_postive = [](HostTensorND& dest) { + HostTensorGenerator mask_generator{ + 2.f, 4.f}; + dest = *mask_generator(dest.shape(), dest.comp_node()); +}; + TEST(TestGoptBasicArithInplace, LogExpExpand) { // test log(exp(a) * (exp(b) / (exp(c) * d**2))) -> a + b - c - log(d**2) @@ -144,9 +150,13 @@ TEST(TestGoptBasicArithInplace, LogExpExpand) { opt.numdiff_eps_single_inp[3] = 1e-3; opt.numdiff_max_err_single_inp[3] = 1e-2; Checker{make_graph, fwd} - .run(ms({2, 3}, {2, 3}), opt) - .run(ms({1, 3}, {2, 3}), opt) - .run(ms({3, 2}, {1}), opt); + .set_input_generator(0, gen_postive) + .set_input_generator(1, gen_postive) + .set_input_generator(2, gen_postive) + .set_input_generator(3, gen_postive) + .run(ms({32, 1}, {32, 1}), opt) + .run(ms({2, 32}, {2, 32}), opt) + .run(ms({1, 32}, {1, 32}), opt); } TEST(TestGoptBasicArithInplace, LogSumExp) { diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index eb1dc794..0d4b47a5 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -133,7 +133,7 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { 0.f}) / 6.f), }; - mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER); + mgb_assert(map.size() + 42 == opr::Elemwise::Param::MODE_NR_MEMBER); // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH, // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF return map; diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 9924f747..d70c7326 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -580,6 +580,8 @@ MGB_IMPL_OPR_GRAD(Elemwise) { RET(EL2(ABS_GRAD, i0, og)); case Mode::ADD: RET(og); + case Mode::SAFE_DIV: + RET_INVALID(); case Mode::FLOOR_DIV: return nullptr; case Mode::MAX: diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 3b45f822..166833f7 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -119,6 +119,75 @@ MGB_IMPL_OPR_GRAD(ArgsortForward) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(ArgsortBackward); MEGDNN_OPR_INIT3(ArgsortBackward, 
"argsort_bwd", 2, false) +/* ================= Cumprod ================= */ + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumprod); + +Cumprod::Cumprod(VarNode* opr, const Param& param, const OperatorNodeConfig& config) + : Super{opr->owner_graph(), config, "Cumprod", {opr}} { + init_megdnn_opr(*this, param); + add_input({opr}, AddInputSortType::CUR_ADDED); +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(Cumprod) { + mgb_assert(out_grad[0] && !out_grad[1]); + auto x = SymbolVar{opr.input(0)}, y = SymbolVar{opr.output(0)}, + grad = SymbolVar{out_grad[0]}; + auto prod_param = opr.param(); + Cumsum::Param reversed_param; + reversed_param.axis = prod_param.axis; + reversed_param.exclusive = prod_param.exclusive; + reversed_param.reverse = !prod_param.reverse; + + auto w = y * grad; + return Elemwise::make( + {Cumsum::make(w, reversed_param), x}, Elemwise::Mode::SAFE_DIV) + .node(); +} +#endif + +SymbolVar Cumprod::make( + SymbolVar opr, const Param& param, const OperatorNodeConfig& config) { + return opr.insert_single_output_opr(opr.node(), param, config); +} + +void Cumprod::scn_do_execute() { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); +} + +void Cumprod::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); +} + +void Cumprod::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto infer_shape = [](TensorShape& dest, const InpVal& iv) { + auto ishp = iv.val.at(0).shape(); + dest = ishp; + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape}); + auto infer_workspace = [this](TensorShape& dest, const InpVal& iv) { + auto dtype = input(0)->dtype(); + auto ishp = iv.val.at(0).shape(); + TensorLayout ily(ishp, dtype); + Param real_param = param(); + if (real_param.axis < 0) + real_param.axis += ishp.ndim; + megdnn_opr()->param() = real_param; + dest.ndim = 1; + dest[0] = megdnn_opr()->get_workspace_in_bytes(ily, ily); + return true; + }; + owner_graph()->static_infer_manager().register_shape_infer( + output(1), + {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); +} + /* ================= Cumsum ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(Cumsum); diff --git a/src/opr/impl/misc.oprdecl b/src/opr/impl/misc.oprdecl index 0e7d5201..05ee72ec 100644 --- a/src/opr/impl/misc.oprdecl +++ b/src/opr/impl/misc.oprdecl @@ -42,6 +42,16 @@ decl_opr('Argsort', 'performed. Two vars are returned: the sorted array, and the ' 'indices. ') +decl_opr('Cumprod', + inputs=['src'], params='Cumprod', + body=[ + 'if param.axis == (1<<31)-1:', + ' all_inputs[0] = all_inputs[0].flatten()', + ' param.axis = 0' + ], + desc='Return the cumulative product of the elements along a given axis.' + ' If axis is INT_MAX, compute on flattened input.', version=1) + decl_opr('Cumsum', inputs=['src'], params='Cumsum', body=[ diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index b06beecb..1fb5523c 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -70,6 +70,7 @@ MGB_SEREG_OPR(TopK, 2); //! 
current cumsum version using CumsumV1 = opr::Cumsum; MGB_SEREG_OPR(CumsumV1, 1); +MGB_SEREG_OPR(Cumprod, 1); #if MGB_CUDA MGB_SEREG_OPR(NvOf, 1); diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index b214b583..fa872ae6 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -76,6 +76,25 @@ public: } }; +//! cumulative product along given axis +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + Cumprod, + cg::SingleCNOperatorNodeBaseT>) // { + void add_input_layout_constraint() override; + +public: + MGE_WIN_DECLSPEC_FUC Cumprod( + VarNode* src, const Param& param, const OperatorNodeConfig& config); + + // for serialization + MGE_WIN_DECLSPEC_FUC static SymbolVar make( + SymbolVar opr, const Param& param, const OperatorNodeConfig& config = {}); + +protected: + void scn_do_execute() override; + void init_output_static_infer_desc() override; +}; + //! cumulative sum along given axis MGB_DEFINE_OPR_CLASS_WITH_EXPORT( Cumsum, diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 89ae9ca3..9ae14bca 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -350,6 +350,9 @@ template <> struct CheckerConfig : public NoGradCheckerConfig {}; template <> +struct CheckerConfig : public NoGradCheckerConfig {}; + +template <> struct CheckerConfig : public NoGradCheckerConfig { template static InputGenerator get_inp_gen(size_t) { @@ -693,9 +696,13 @@ struct CheckerConfig : public NoGradCheckerConfig { /* ======================= ternary config ======================= */ template <> struct CheckerConfig : public BinaryInputMinGap {}; + template <> struct CheckerConfig : public BinaryInputMinGap {}; + +template <> struct CheckerConfig : public NoGradCheckerConfig {}; + template <> struct CheckerConfig : public CheckerConfig { template @@ -886,6 +893,10 @@ void TestRunner::run() { } TensorShape shapes[] = {{1}, {23, 3}, {666}}; + if (Trait::ARITY == 4) { + checker.disable_graph_opt(); + shapes[0] = {32}; + } typename Checker::RunOptions opt; Config::update_opt(opt); Config::update_checker(checker); @@ -1034,13 +1045,13 @@ TEST(TestOprBasicArithElemwise, FuseMulAdd4Shapes) { }; Checker checker{make_graph, fwd}; - checker.run({TensorShape{1, 2}, {2, 1}, {1, 2}, {2, 1}}) - .run({TensorShape{1, 2, 1, 2, 1, 2}, - {2, 1, 2, 1, 2, 1}, - {2, 1, 2, 1, 2, 1}, - {1, 2, 1, 2, 1, 2}}); + checker.run({TensorShape{1, 32}, {1, 32}, {1, 32}, {1, 32}}) + .run({TensorShape{1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}, + {1, 1, 1, 1, 1, 32}}); ASSERT_FALSE(opr->fuse_badlayout_warn_printed()); - checker.run({TensorShape{1, 2}, {2, 1}, {2, 2}, {2, 2}}); + checker.run({TensorShape{1, 32}, {32, 1}, {32, 32}, {32, 32}}); ASSERT_TRUE(opr->fuse_badlayout_warn_printed()); } diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 1ed742db..4d88c852 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -47,6 +47,7 @@ DEF_TRAIT(PRELU, (x > 0) ? x : (x* y)) #define _ALLOW_INT false DEF_TRAIT(POW, std::pow(x, y)) DEF_TRAIT(TRUE_DIV, x / y) +DEF_TRAIT(SAFE_DIV, y != 0 ? 
x / y : 0) DEF_TRAIT(LOG_SUM_EXP, do_log_sum_exp(x, y)) DEF_TRAIT(FUSE_ADD_SIGMOID, 1 / (1 + std::exp(-(x + y)))) DEF_TRAIT(FUSE_ADD_TANH, std::tanh(x + y)) diff --git a/src/opr/test/misc.cpp b/src/opr/test/misc.cpp index e5964108..47ec317a 100644 --- a/src/opr/test/misc.cpp +++ b/src/opr/test/misc.cpp @@ -145,6 +145,60 @@ TEST(TestOprMisc, Argsort) { run(Order::DESCENDING); } +TEST(TestOprMisc, Cumprod) { + using Param = opr::Cumprod::Param; + auto run = [](const Param& param) { + using Checker = AutoOprChecker<1, 1>; + auto make_graph = + [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + return {opr::Cumprod::make(inputs[0], param)}; + }; + auto fwd = [&](Checker::NumOutArray& out, Checker::NumInpArray inp) { + out[0].resize(inp[0]->shape()); + + auto pin = inp[0]->ptr(), pout = out[0].ptr(); + size_t A, B, C; + int real_axis = param.axis; + if (real_axis < 0) + real_axis += 3; + shape_abc(inp[0]->shape(), real_axis, A, B, C); + ptrdiff_t stride = C; + if (param.reverse) + stride = -stride; + for (size_t i = 0; i < A; ++i) { + for (size_t k = 0; k < C; ++k) { + auto pi = pin + i * B * C + k, po = pout + i * B * C + k; + if (param.reverse) { + pi += (B - 1) * C; + po += (B - 1) * C; + } + if (param.exclusive) { + *po = 1; + po += stride; + } + float prod = 1; + for (size_t j = 0; j < B - 1; ++j) { + prod *= pi[j * stride]; + po[j * stride] = prod; + } + if (!param.exclusive) { + po[(B - 1) * stride] = prod * pi[(B - 1) * stride]; + } + } + } + }; + Checker{make_graph, fwd} + .run({TensorShape{2, 3, 4}}) + .run({TensorShape{3, 1, 2}}) + .run({TensorShape{4, 2, 3}}); + }; + + // test negative axis + for (int32_t axis = -3; axis < 3; ++axis) + for (int mask = 0; mask < 4; ++mask) + run({axis, bool(mask >> 1), bool(mask & 1)}); +} + TEST(TestOprMisc, Cumsum) { using Param = opr::Cumsum::Param; auto run = [](const Param& param) { diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index a7ac763b..6f2d9733 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -123,6 +123,7 @@ union OperatorParam { param.LSTM = 89, param.Softmax = 90, param.Diag = 91, + param.Cumprod = 92, } table Operator { diff --git a/test/src/autocheck.cpp b/test/src/autocheck.cpp index 1de0194c..1e502fc7 100644 --- a/test/src/autocheck.cpp +++ b/test/src/autocheck.cpp @@ -224,6 +224,12 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) { m_inputs_generator[i](*m_inputs[i]); mgb_assert(m_inputs[i]->shape().eq_shape(shapes[i])); } + if (shapes.size() == 4u) { + m_extra_err_msg = ssprintf("%d,", *((int*)(m_inputs[0]->raw_ptr()) + 11)); + m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[1]->raw_ptr()) + 11)); + m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[2]->raw_ptr()) + 11)); + m_extra_err_msg += ssprintf("%d,", *((int*)(m_inputs[3]->raw_ptr()) + 11)); + } if (MGB_GETENV("MGB_AUTOCHECK_DUMP_INPUT")) { static size_t run_id; auto fname = output_file(ssprintf("autocheck-inp-%zu.bin", run_id++));
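
For readers following TestOprMisc.Cumprod, the loop in that test is a plain reference cumprod parameterized by axis, exclusive and reverse, and the oprdecl hunk maps the INT_MAX axis default to a flattened input. The same reference is easy to state in NumPy; this sketch is illustrative only, with axis=None standing in for the INT_MAX/flatten convention:

import numpy as np

def ref_cumprod(x, axis=None, exclusive=False, reverse=False):
    # axis=None plays the role of the INT_MAX default: flatten, then scan axis 0.
    if axis is None:
        x, axis = x.reshape(-1), 0
    if reverse:
        x = np.flip(x, axis)
    y = np.cumprod(x, axis=axis)
    if exclusive:
        # Shift the scan by one along `axis`, seeding with 1, so position i
        # excludes x_i itself.
        pad = [(0, 0)] * x.ndim
        pad[axis] = (1, 0)
        y = np.delete(np.pad(y, pad, constant_values=1), -1, axis=axis)
    return np.flip(y, axis) if reverse else y

# e.g. exclusive, reversed scan over the last axis of a 2x3 block
print(ref_cumprod(np.arange(1.0, 7.0).reshape(2, 3), axis=1, exclusive=True, reverse=True))

The C++ test sweeps all four flag combinations and negative axes; the flip/shift composition above reproduces the same results for each case.
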