Browse Source

fix(mge): fix grad of maximum(x, x)

GitOrigin-RevId: e0e2efb71b
release-1.10
Megvii Engine Team 3 years ago
parent
commit
cfc41648a4
11 changed files with 58 additions and 3 deletions
  1. +23
    -0
      imperative/python/test/unit/functional/test_elemwise.py
  2. +2
    -0
      src/jit/impl/ast_c.cpp
  3. +2
    -0
      src/jit/impl/halide/ast_hl.cpp
  4. +9
    -0
      src/jit/impl/mlir/ir/each_mode.cpp
  5. +1
    -0
      src/jit/impl/mlir/ir/each_mode.h
  6. +1
    -0
      src/jit/test/codegen.cpp
  7. +1
    -0
      src/jit/test/fusion.cpp
  8. +14
    -3
      src/opr/impl/basic_arith.cpp
  9. +2
    -0
      src/opr/test/basic_arith/elemwise.cpp
  10. +1
    -0
      src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
  11. +2
    -0
      src/opr/test/nn_int.cpp

+ 23
- 0
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -2,6 +2,7 @@
import numpy as np import numpy as np
import pytest import pytest


import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
import megengine.functional.elemwise as elemwise import megengine.functional.elemwise as elemwise
from megengine import tensor from megengine import tensor
@@ -293,3 +294,25 @@ def test_empty_tensor(is_trace):
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)


@pytest.mark.parametrize("is_trace", [True, False])
def test_maximum_grad_consistency(is_trace):
def f(x):
with ad.GradManager() as gm:
gm.attach(x)
gm.backward(F.maximum(x, x))
dx = x.grad
x.grad = None
return dx

def run(f):
x = F.arange(10)
for i in range(3):
np.testing.assert_equal(f(x).numpy(), np.ones(10))

if is_trace:
for symbolic in [False, True]:
run(trace(symbolic=symbolic)(f))
else:
run(f)

+ 2
- 0
src/jit/impl/ast_c.cpp View File

@@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
// misc // misc
ENTRY(COND_LEQ_MOV, ENTRY(COND_LEQ_MOV,
ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]), ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]),
ENTRY(COND_LT_MOV,
ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]),
ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]),
ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]),
ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})),


+ 2
- 0
src/jit/impl/halide/ast_hl.cpp View File

@@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode(
// ternary // ternary
case Mode::COND_LEQ_MOV: case Mode::COND_LEQ_MOV:
return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); return Halide::select(inp(0) <= inp(1), inp(2), cv(0));
case Mode::COND_LT_MOV:
return Halide::select(inp(0) < inp(1), inp(2), cv(0));
case Mode::FUSE_MUL_ADD3: case Mode::FUSE_MUL_ADD3:
return inp(0) * inp(1) + inp(2); return inp(0) * inp(1) + inp(2);
case Mode::FUSE_MUL_ADD4: case Mode::FUSE_MUL_ADD4:


+ 9
- 0
src/jit/impl/mlir/ir/each_mode.cpp View File

@@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
} }


//! COND_LT_MOV: x < y ? z : ctype(0)
template <>
mlir::Value lower_mode<Mode::COND_LT_MOV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
}

//! FUSE_MUL_ADD3: x * y + z //! FUSE_MUL_ADD3: x * y + z
template <> template <>
mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>( mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(


+ 1
- 0
src/jit/impl/mlir/ir/each_mode.h View File

@@ -60,6 +60,7 @@


#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
cb(CondLeqMovOp, COND_LEQ_MOV) \ cb(CondLeqMovOp, COND_LEQ_MOV) \
cb(CondLtMovOp, COND_LT_MOV) \
cb(FuseMulAdd3Op, FUSE_MUL_ADD3) cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
// clang-format on // clang-format on




+ 1
- 0
src/jit/test/codegen.cpp View File

@@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
// clang-format off // clang-format off
#define FOREACH_TERNARY_MODE(cb) \ #define FOREACH_TERNARY_MODE(cb) \
cb(COND_LEQ_MOV) \ cb(COND_LEQ_MOV) \
cb(COND_LT_MOV) \
cb(FUSE_MUL_ADD3) \ cb(FUSE_MUL_ADD3) \
// clang-format on // clang-format on
template <typename tag> template <typename tag>


+ 1
- 0
src/jit/test/fusion.cpp View File

@@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
CHECK_ELEM2(ATAN2, true, gt0); CHECK_ELEM2(ATAN2, true, gt0);


CHECK_ELEM3(COND_LEQ_MOV, false, none); CHECK_ELEM3(COND_LEQ_MOV, false, none);
CHECK_ELEM3(COND_LT_MOV, false, none);
CHECK_ELEM3(FUSE_MUL_ADD3, true, none); CHECK_ELEM3(FUSE_MUL_ADD3, true, none);


CHECK_ELEM4(FUSE_MUL_ADD4, true, none); CHECK_ELEM4(FUSE_MUL_ADD4, true, none);


+ 14
- 3
src/opr/impl/basic_arith.cpp View File

@@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
case Mode::FLOOR_DIV: case Mode::FLOOR_DIV:
return nullptr; return nullptr;
case Mode::MAX: case Mode::MAX:
RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og));
if (wrt_idx) {
RET(EL3(COND_LT_MOV, i[0], i[1], og));
} else {
RET(EL3(COND_LEQ_MOV, i[1], i[0], og));
}
case Mode::MIN: case Mode::MIN:
RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og));
if (wrt_idx) {
RET(EL3(COND_LT_MOV, i[1], i[0], og));
} else {
RET(EL3(COND_LEQ_MOV, i[0], i[1], og));
}
case Mode::MOD: case Mode::MOD:
if (wrt_idx == 0) { if (wrt_idx == 0) {
RET(og); RET(og);
@@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
if (wrt_idx <= 1) if (wrt_idx <= 1)
return nullptr; return nullptr;
RET(EL3(COND_LEQ_MOV, i0, i1, og)); RET(EL3(COND_LEQ_MOV, i0, i1, og));

case Mode::COND_LT_MOV:
if (wrt_idx <= 1)
return nullptr;
RET(EL3(COND_LT_MOV, i0, i1, og));
// fuse oprs // fuse oprs
case Mode::FUSE_MUL_ADD3: case Mode::FUSE_MUL_ADD3:
if (wrt_idx < 2) { if (wrt_idx < 2) {


+ 2
- 0
src/opr/test/basic_arith/elemwise.cpp View File

@@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
/* ======================= ternary config ======================= */ /* ======================= ternary config ======================= */
template <> template <>
struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {};
template <>
struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {};


/* ======================= test runner ======================= */ /* ======================= test runner ======================= */
namespace detail { namespace detail {


+ 1
- 0
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl View File

@@ -13,6 +13,7 @@
#define _ALLOW_FLOAT true #define _ALLOW_FLOAT true
#define _ALLOW_INT true #define _ALLOW_INT true
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0)
DEF_TRAIT(COND_LT_MOV, x < y ? z : 0)
DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) DEF_TRAIT(FUSE_MUL_ADD3, x* y + z)
#undef _ALLOW_INT #undef _ALLOW_INT
#undef _ALLOW_FLOAT #undef _ALLOW_FLOAT


+ 2
- 0
src/opr/test/nn_int.cpp View File

@@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) {
switch (mode) { switch (mode) {
MAKE_TERNARY(FUSE_MUL_ADD3); MAKE_TERNARY(FUSE_MUL_ADD3);
MAKE_TERNARY(COND_LEQ_MOV); MAKE_TERNARY(COND_LEQ_MOV);
MAKE_TERNARY(COND_LT_MOV);
default: default:
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
break; break;
@@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) {
switch (mode) { switch (mode) {
MAKE_TERNARY(FUSE_MUL_ADD3); MAKE_TERNARY(FUSE_MUL_ADD3);
MAKE_TERNARY(COND_LEQ_MOV); MAKE_TERNARY(COND_LEQ_MOV);
MAKE_TERNARY(COND_LT_MOV);
default: default:
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
break; break;


Loading…
Cancel
Save