Browse Source

feat(mge/opr): add cumsum

GitOrigin-RevId: 740f00a8e5
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
dd1fecdf29
3 changed files with 47 additions and 0 deletions
  1. +33
    -0
      imperative/python/megengine/functional/tensor.py
  2. +12
    -0
      imperative/src/impl/ops/specializations.cpp
  3. +2
    -0
      src/core/include/megbrain/ir/ops.td

+ 33
- 0
imperative/python/megengine/functional/tensor.py View File

@@ -27,6 +27,7 @@ __all__ = [
"broadcast_to", "broadcast_to",
"concat", "concat",
"cond_take", "cond_take",
"cumsum",
"expand_dims", "expand_dims",
"eye", "eye",
"flatten", "flatten",
@@ -1328,3 +1329,35 @@ def roll(
if shp_bak is not None: if shp_bak is not None:
out = out.reshape(shp_bak) out = out.reshape(shp_bak)
return out return out


def cumsum(inp: Tensor, axis: int):
"""
Computes the cumulative sum of elements along given axis.

:param inp: input tensor.
:param axis: axis along which cumsum is performed.

Examples:

.. testcode::

from megengine import tensor
import megengine.functional as F

x = tensor([[1, 2, 3], [4, 5, 6]], "int32")
y = F.cumsum(x, 1)
print(y.numpy())

Outputs:

.. testoutput::

[[ 1 3 6]
[ 4 9 15]]

"""
assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
assert axis >= 0 and axis < inp.ndim, "input axis {} out of bound".format(axis)
op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
return apply(op, inp)[0]

+ 12
- 0
imperative/src/impl/ops/specializations.cpp View File

@@ -673,4 +673,16 @@ OP_TRAIT_REG(SlidingWindowTranspose, SlidingWindowTranspose)
.fallback(); .fallback();
}} // sliding_window_transpose }} // sliding_window_transpose


namespace {
namespace cumsum {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const Cumsum&>(def);
OperatorNodeConfig config{op.make_name()};
return opr::Cumsum::make(inputs[0], op.param(), config);
}

OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
} // namespace cumsum
} // namespace

} // namespace mgb::imperative } // namespace mgb::imperative

+ 2
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -377,4 +377,6 @@ def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;


def FastpathCopy: MgbHashableOp<"FastpathCopy">; def FastpathCopy: MgbHashableOp<"FastpathCopy">;


def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;

#endif // MGB_OPS #endif // MGB_OPS

Loading…
Cancel
Save