@@ -2,6 +2,7 @@ | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
import megengine.autodiff as ad | |||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.functional.elemwise as elemwise | import megengine.functional.elemwise as elemwise | ||||
from megengine import tensor | from megengine import tensor | ||||
@@ -293,3 +294,25 @@ def test_empty_tensor(is_trace): | |||||
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) | ||||
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) | ||||
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) | run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) | ||||
@pytest.mark.parametrize("is_trace", [True, False])
def test_maximum_grad_consistency(is_trace):
    """Check that the gradient of ``F.maximum(x, x)`` w.r.t. ``x`` is all ones.

    Both operands are the same tensor, so every element is a tie. The
    gradient is compared against ``np.ones(10)`` on three consecutive
    evaluations, either eagerly or, when ``is_trace`` is true, through
    ``trace`` with ``symbolic=False`` and then ``symbolic=True``.
    """

    def grad_of_self_maximum(inp):
        # Record maximum(inp, inp) and back-propagate through it.
        with ad.GradManager() as manager:
            manager.attach(inp)
            manager.backward(F.maximum(inp, inp))
        # Return the gradient and clear it on the tensor before the next call.
        grad = inp.grad
        inp.grad = None
        return grad

    def check_repeatedly(grad_fn):
        data = F.arange(10)
        expected = np.ones(10)
        # Evaluate several times on the same input; each result must match.
        for _ in range(3):
            np.testing.assert_equal(grad_fn(data).numpy(), expected)

    if not is_trace:
        check_repeatedly(grad_of_self_maximum)
        return

    for symbolic in (False, True):
        check_repeatedly(trace(symbolic=symbolic)(grad_of_self_maximum))
@@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { | |||||
// misc | // misc | ||||
ENTRY(COND_LEQ_MOV, | ENTRY(COND_LEQ_MOV, | ||||
ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]), | ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]), | ||||
ENTRY(COND_LT_MOV, | |||||
ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]), | |||||
ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), | ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), | ||||
ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), | ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), | ||||
ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), | ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), | ||||
@@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode( | |||||
// ternary | // ternary | ||||
case Mode::COND_LEQ_MOV: | case Mode::COND_LEQ_MOV: | ||||
return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); | return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); | ||||
case Mode::COND_LT_MOV: | |||||
return Halide::select(inp(0) < inp(1), inp(2), cv(0)); | |||||
case Mode::FUSE_MUL_ADD3: | case Mode::FUSE_MUL_ADD3: | ||||
return inp(0) * inp(1) + inp(2); | return inp(0) * inp(1) + inp(2); | ||||
case Mode::FUSE_MUL_ADD4: | case Mode::FUSE_MUL_ADD4: | ||||
@@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>( | |||||
helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | ||||
} | } | ||||
//! COND_LT_MOV: x < y ? z : ctype(0) | |||||
template <> | |||||
mlir::Value lower_mode<Mode::COND_LT_MOV>( | |||||
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { | |||||
ValueBuilderHelper helper(builder, loc); | |||||
return helper.select( | |||||
helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); | |||||
} | |||||
//! FUSE_MUL_ADD3: x * y + z | //! FUSE_MUL_ADD3: x * y + z | ||||
template <> | template <> | ||||
mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>( | mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>( | ||||
@@ -60,6 +60,7 @@ | |||||
// X-macro over the ternary elemwise modes: expands to one cb(first, second)
// invocation per entry below. The second argument is the mode name
// (COND_LEQ_MOV, COND_LT_MOV, FUSE_MUL_ADD3); the first presumably names the
// matching MLIR op class -- TODO confirm against the op definitions.
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
    cb(CondLeqMovOp, COND_LEQ_MOV) \
    cb(CondLtMovOp, COND_LT_MOV) \
    cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
// clang-format on | // clang-format on | ||||
@@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { | |||||
// clang-format off | // clang-format off | ||||
// X-macro listing the ternary elemwise modes: expands to cb(<mode name>) once
// per entry below.
// NOTE(review): the last entry ends with a line-continuation backslash, so the
// comment line that follows is spliced into the macro definition. It is
// harmless because comments are removed after splicing, but a code line
// inserted directly after the list would silently become part of the macro.
#define FOREACH_TERNARY_MODE(cb) \
    cb(COND_LEQ_MOV) \
    cb(COND_LT_MOV) \
    cb(FUSE_MUL_ADD3) \
// clang-format on
template <typename tag> | template <typename tag> | ||||
@@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) { | |||||
CHECK_ELEM2(ATAN2, true, gt0); | CHECK_ELEM2(ATAN2, true, gt0); | ||||
CHECK_ELEM3(COND_LEQ_MOV, false, none); | CHECK_ELEM3(COND_LEQ_MOV, false, none); | ||||
CHECK_ELEM3(COND_LT_MOV, false, none); | |||||
CHECK_ELEM3(FUSE_MUL_ADD3, true, none); | CHECK_ELEM3(FUSE_MUL_ADD3, true, none); | ||||
CHECK_ELEM4(FUSE_MUL_ADD4, true, none); | CHECK_ELEM4(FUSE_MUL_ADD4, true, none); | ||||
@@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
case Mode::FLOOR_DIV: | case Mode::FLOOR_DIV: | ||||
return nullptr; | return nullptr; | ||||
case Mode::MAX: | case Mode::MAX: | ||||
RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og)); | |||||
if (wrt_idx) { | |||||
RET(EL3(COND_LT_MOV, i[0], i[1], og)); | |||||
} else { | |||||
RET(EL3(COND_LEQ_MOV, i[1], i[0], og)); | |||||
} | |||||
case Mode::MIN: | case Mode::MIN: | ||||
RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og)); | |||||
if (wrt_idx) { | |||||
RET(EL3(COND_LT_MOV, i[1], i[0], og)); | |||||
} else { | |||||
RET(EL3(COND_LEQ_MOV, i[0], i[1], og)); | |||||
} | |||||
case Mode::MOD: | case Mode::MOD: | ||||
if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
RET(og); | RET(og); | ||||
@@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { | |||||
if (wrt_idx <= 1) | if (wrt_idx <= 1) | ||||
return nullptr; | return nullptr; | ||||
RET(EL3(COND_LEQ_MOV, i0, i1, og)); | RET(EL3(COND_LEQ_MOV, i0, i1, og)); | ||||
case Mode::COND_LT_MOV: | |||||
if (wrt_idx <= 1) | |||||
return nullptr; | |||||
RET(EL3(COND_LT_MOV, i0, i1, og)); | |||||
// fuse oprs | // fuse oprs | ||||
case Mode::FUSE_MUL_ADD3: | case Mode::FUSE_MUL_ADD3: | ||||
if (wrt_idx < 2) { | if (wrt_idx < 2) { | ||||
@@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {}; | |||||
/* ======================= ternary config ======================= */ | /* ======================= ternary config ======================= */ | ||||
template <> | template <> | ||||
struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; | struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {}; | ||||
//! Checker configuration for COND_LT_MOV. It adds nothing of its own and
//! inherits everything from BinaryInputMinGap<false>.
//! NOTE(review): presumably the base keeps the two compared inputs a minimum
//! distance apart so the `<` result is numerically stable -- confirm against
//! the definition of BinaryInputMinGap.
template <>
struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {};
/* ======================= test runner ======================= */ | /* ======================= test runner ======================= */ | ||||
namespace detail { | namespace detail { | ||||
@@ -13,6 +13,7 @@ | |||||
#define _ALLOW_FLOAT true | #define _ALLOW_FLOAT true | ||||
#define _ALLOW_INT true | #define _ALLOW_INT true | ||||
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) | ||||
DEF_TRAIT(COND_LT_MOV, x < y ? z : 0) | |||||
DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) | DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) | ||||
#undef _ALLOW_INT | #undef _ALLOW_INT | ||||
#undef _ALLOW_FLOAT | #undef _ALLOW_FLOAT | ||||
@@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) { | |||||
switch (mode) { | switch (mode) { | ||||
MAKE_TERNARY(FUSE_MUL_ADD3); | MAKE_TERNARY(FUSE_MUL_ADD3); | ||||
MAKE_TERNARY(COND_LEQ_MOV); | MAKE_TERNARY(COND_LEQ_MOV); | ||||
MAKE_TERNARY(COND_LT_MOV); | |||||
default: | default: | ||||
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | ||||
break; | break; | ||||
@@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) { | |||||
switch (mode) { | switch (mode) { | ||||
MAKE_TERNARY(FUSE_MUL_ADD3); | MAKE_TERNARY(FUSE_MUL_ADD3); | ||||
MAKE_TERNARY(COND_LEQ_MOV); | MAKE_TERNARY(COND_LEQ_MOV); | ||||
MAKE_TERNARY(COND_LT_MOV); | |||||
default: | default: | ||||
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); | ||||
break; | break; | ||||