GitOrigin-RevId: 5883c68804
release-1.1
@@ -29,6 +29,8 @@ cb(sub, SubFOp);
     cb(mul, MulFOp);
     cb(div, DivFOp);
     cb(mod, RemFOp);
+    cb(bit_and, AndOp);
+    cb(bit_or, OrOp);
 #undef cb
 #define cb(name, mode) \
@@ -72,6 +74,7 @@ cb(exp, ExpOp);
     cb(exp2, Exp2Op);
     cb(log10, Log10Op);
     cb(log2, Log2Op);
+    cb(log, LogOp);
     cb(rsqrt, RsqrtOp);
     cb(sin, SinOp);
     cb(sqrt, SqrtOp);
@@ -79,7 +82,8 @@ cb(tanh, TanhOp);
 #undef cb
 mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) {
-    return max(lhs, const_val(0.f));
+    auto zero = const_val(0.f);
+    return select(ge(lhs, zero), lhs, sub(zero, lhs));
 }
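Worth noting why the abs lowering changed: max(x, 0) discards the magnitude of negative inputs, while the select form is the usual branchless |x|. A tiny scalar comparison of the two formulas (standalone C++ sketch, not part of the patch):

#include <algorithm>
#include <cassert>
#include <cmath>

// Scalar analogue of the old lowering: max(x, 0) -- wrong for x < 0.
static float abs_old(float x) { return std::max(x, 0.f); }

// Scalar analogue of the new lowering: select(x >= 0, x, 0 - x).
static float abs_new(float x) { return x >= 0.f ? x : 0.f - x; }

int main() {
    assert(abs_old(-3.f) == 0.f);             // old formula loses the magnitude
    assert(abs_new(-3.f) == 3.f);             // new formula matches fabs
    assert(abs_new(2.5f) == std::fabs(2.5f));
    return 0;
}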
 mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
@@ -87,11 +91,6 @@ mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) {
     return neg(ceil(neg(lhs)));
 }
-mlir::Value ValueBuilderHelper::log(mlir::Value lhs) {
-    // math.log10(math.e) = 0.4342944819032518f
-    return div(log10(lhs), const_val(0.4342944819032518f));
-}
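The deleted helper computed the natural log through the change-of-base identity ln(x) = log10(x) / log10(e); it is no longer needed once LOG maps directly to LogOp in the first hunk above. A quick standalone check of that identity and its constant (illustrative only, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
    const float log10_e = 0.4342944819032518f;  // log10(e), the constant used above
    for (float x : {0.5f, 1.f, 2.f, 10.f}) {
        // ln(x) == log10(x) / log10(e), up to float rounding
        assert(std::fabs(std::log10(x) / log10_e - std::log(x)) < 1e-5f);
    }
    return 0;
}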
 mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                        mlir::Value false_val) {
     return m_builder.create<mlir::SelectOp>(m_location, cond, true_val,
@@ -47,6 +47,8 @@ public:
     cb(lt);
     cb(le);
     cb(eq);
+    cb(bit_and);
+    cb(bit_or);
 #undef cb
     mlir::Value const_val(float val);
@@ -18,6 +18,7 @@
 #include "megbrain/jit/mlir/ir/dialect.h"
 #include "./common.h"
+#include "./numerical.h"
 #include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/IR/Builders.h>
@@ -28,6 +29,8 @@
     cb(ReluOp, RELU) \
     cb(AbsOp, ABS) \
     cb(NegOp, NEGATE) \
+    cb(AcosOp, ACOS) \
+    cb(AsinOp, ASIN) \
     cb(CeilOp, CEIL) \
     cb(CosOp, COS) \
     cb(ExpOp, EXP) \
@@ -40,7 +43,11 @@
     cb(FastTanhOp, FAST_TANH) \
     cb(HswishOp, H_SWISH) \
     cb(ExpM1Op, EXPM1) \
-    cb(RoundOp, ROUND)
+    cb(RoundOp, ROUND) \
+    cb(ErfOp, ERF) \
+    cb(ErfInvOp, ERFINV) \
+    cb(ErfCOp, ERFC) \
+    cb(ErfCInvOp, ERFCINV)
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \
     cb(AbsGradOp, ABS_GRAD) \
@@ -52,6 +59,7 @@
     cb(SubOp, SUB) \
     cb(MulOp, MUL) \
     cb(TrueDivOp, TRUE_DIV) \
+    cb(PowOp, POW) \
     cb(SigmoidGradOp, SIGMOID_GRAD) \
     cb(SwishGt0Op, SWITCH_GT0) \
     cb(TanhGradOp, TANH_GRAD) \
@@ -64,7 +72,8 @@
     cb(FastTanhGradOp, FAST_TANH_GRAD) \
     cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \
     cb(HswishGradOp, H_SWISH_GRAD) \
-    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH)
+    cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \
+    cb(Atan2Op, ATAN2)
 #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
     cb(CondLeqMovOp, COND_LEQ_MOV) \
@@ -197,6 +206,79 @@ struct StandardOp<jit::RoundOp> {
     }
 };
+
+//! pi / 2 - arctan2(x, sqrt(1 - x * x))
+template <>
+struct StandardOp<jit::AcosOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto x = operands[0];
+        auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
+        auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
+        auto pi_over_2 = helper.const_val(1.57079637f);
+        return helper.sub(pi_over_2, asin);
+    }
+};
+
+//! arctan2(x, sqrt(1 - x * x))
+template <>
+struct StandardOp<jit::AsinOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto x = operands[0];
+        auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x));
+        return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
+    }
+};
+
+//! gauss error function
+template <>
+struct StandardOp<jit::ErfOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return erf_approx(helper, operands[0]);
+    }
+};
+
+//! inverse of gauss error function
+//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
+template <>
+struct StandardOp<jit::ErfInvOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto sqrt2 = helper.const_val(1.4142135623f);
+        auto x = helper.mul(helper.const_val(0.5f),
+                            helper.add(operands[0], helper.const_val(1.f)));
+        return helper.div(ndtri_approx(helper, x), sqrt2);
+    }
+};
+
+//! complementary error function
+template <>
+struct StandardOp<jit::ErfCOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0]));
+    }
+};
+
+//! inverse of complementary gauss error function
+//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
+template <>
+struct StandardOp<jit::ErfCInvOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        auto minus_sqrt2 = helper.const_val(-1.4142135623f);
+        auto x = helper.mul(helper.const_val(0.5f), operands[0]);
+        return helper.div(ndtri_approx(helper, x), minus_sqrt2);
+    }
+};
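These lowerings rely on a few identities: asin(x) = atan2(x, sqrt(1 - x^2)), acos(x) = pi/2 - asin(x), erfinv(y) = ndtri((y + 1) / 2) / sqrt(2), and erfcinv(p) = -ndtri(p / 2) / sqrt(2), where ndtri is the inverse of the standard normal CDF. The standalone sketch below checks them numerically; the bisection-based invert/ndtri/erfinv helpers are hypothetical stand-ins (not from the patch), since <cmath> provides no inverse CDF or inverse error function:

#include <cassert>
#include <cmath>

// Invert a monotone increasing function on [lo, hi] by bisection.
template <typename F>
static double invert(F f, double target, double lo, double hi) {
    for (int i = 0; i < 200; ++i) {
        double mid = 0.5 * (lo + hi);
        (f(mid) < target ? lo : hi) = mid;
    }
    return 0.5 * (lo + hi);
}

int main() {
    const double pi = std::acos(-1.0), sqrt2 = std::sqrt(2.0);
    auto phi = [&](double z) { return 0.5 * std::erfc(-z / sqrt2); };  // normal CDF
    auto ndtri = [&](double p) { return invert(phi, p, -10.0, 10.0); };
    auto erfinv = [](double y) {
        return invert([](double t) { return std::erf(t); }, y, -10.0, 10.0);
    };

    for (double x : {-0.8, -0.2, 0.3, 0.9}) {
        double a = std::atan2(x, std::sqrt(1.0 - x * x));
        assert(std::fabs(a - std::asin(x)) < 1e-12);           // asin via atan2
        assert(std::fabs(pi / 2 - a - std::acos(x)) < 1e-12);  // acos = pi/2 - asin
    }
    for (double y : {-0.9, -0.25, 0.1, 0.75}) {
        // erfinv(y) == ndtri((y + 1) / 2) / sqrt(2)
        assert(std::fabs(erfinv(y) - ndtri((y + 1.0) / 2.0) / sqrt2) < 1e-9);
        // erfcinv(p) == -ndtri(p / 2) / sqrt(2), using erfcinv(p) == erfinv(1 - p)
        double p = 1.0 - y;
        assert(std::fabs(erfinv(1.0 - p) + ndtri(p / 2.0) / sqrt2) < 1e-9);
    }
    return 0;
}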
 /////////////////////////// binary op ///////////////////////////
 //! binary: x > 0 ? y : -y
@@ -210,6 +292,16 @@ struct StandardOp<jit::AbsGradOp> {
     }
 };
+
+//! x^y = exp(y * log(x))
+template <>
+struct StandardOp<jit::PowOp> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
+    }
+};
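POW is lowered with the usual float rewrite x^y = exp(y * ln x), which is only meaningful for x > 0. A minimal scalar sketch of the formula and its domain limit (illustrative, not from the patch):

#include <cassert>
#include <cmath>

// x^y rewritten as exp(y * ln(x)); valid for x > 0 only.
static float pow_via_exp_log(float x, float y) {
    return std::exp(y * std::log(x));
}

int main() {
    assert(std::fabs(pow_via_exp_log(2.f, 10.f) - 1024.f) < 1e-2f);
    assert(std::fabs(pow_via_exp_log(0.25f, 0.5f) - 0.5f) < 1e-6f);
    // For negative bases the rewrite yields NaN rather than std::pow's result.
    assert(std::isnan(pow_via_exp_log(-2.f, 2.f)));
    return 0;
}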
 //! x * (1 - x) * y
 template <>
 struct StandardOp<jit::SigmoidGradOp> {
@@ -382,6 +474,16 @@ struct StandardOp<jit::FuseAddHswishOp> {
     }
 };
+
+//! arctan
+template <>
+struct StandardOp<jit::Atan2Op> {
+    mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc,
+                           ValueRange operands) {
+        ValueBuilderHelper helper(builder, loc);
+        return atan2_approx(helper, operands[0], operands[1]);
+    }
+};
 /////////////////////////// ternary op ///////////////////////////
 //! x <= y ? z : ctype(0)
 template <>
@@ -0,0 +1,248 @@
+/**
+ * \file src/jit/impl/mlir/ir/numerical.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
+ * implied.
+ */
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include "numerical.h"
+
+namespace mgb {
+namespace jit {
+
+mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x,
+                       std::vector<mlir::Value>& coeff) {
+    size_t n = coeff.size();
+    if (n == 0) {
+        return helper.const_val(0);
+    }
+    mlir::Value r = coeff[0];
+    for (size_t i = 1; i < n; i++) {
+        r = helper.add(helper.mul(r, x), coeff[i]);
+    }
+    return r;
+}
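polynomial() is plain Horner evaluation with the highest-degree coefficient first. A scalar illustration of the same loop (standalone sketch with hypothetical names):

#include <cassert>
#include <vector>

// Horner's rule: with coeff = {C_N, ..., C_1, C_0}, evaluate
// C_N * x^N + ... + C_1 * x + C_0 using one multiply-add per coefficient.
static float horner(const std::vector<float>& coeff, float x) {
    if (coeff.empty()) return 0.f;
    float r = coeff[0];
    for (size_t i = 1; i < coeff.size(); i++) {
        r = r * x + coeff[i];
    }
    return r;
}

int main() {
    // 2x^2 + 3x + 5 at x = 4  ->  2*16 + 12 + 5 = 49
    assert(horner({2.f, 3.f, 5.f}, 4.f) == 49.f);
    return 0;
}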
+// polynomial approximation of arctangent
+// atan(t) = t + c3 * t^3 + c5 * t^5 + ... + c17 * t^17
+// original paper:
+// https://arxiv.org/pdf/1508.03211.pdf
+mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y,
+                         mlir::Value x) {
+    auto atan_poly = [&](mlir::Value t) {
+        std::vector<mlir::Value> coeff = {
+                helper.const_val(2.90188402868807315826416015625E-3),
+                helper.const_val(-1.62907354533672332763671875E-2),
+                helper.const_val(4.3082617223262786865234375E-2),
+                helper.const_val(-7.5408883392810821533203125E-2),
+                helper.const_val(0.1066047251224517822265625),
+                helper.const_val(-0.14209578931331634521484375),
+                helper.const_val(0.19993579387664794921875),
+                helper.const_val(-0.3333314359188079833984375)};
+        auto t2 = helper.mul(t, t);
+        auto p = polynomial(helper, t2, coeff);
+        return helper.add(helper.mul(helper.mul(p, t2), t), t);
+    };
+
+    // constants
+    auto zero = helper.const_val(0);
+    auto pi = helper.const_val(3.141592653589793);
+    auto pi_over_2 = helper.const_val(1.570796326794897);
+
+    // transform the angle into interval [0, pi/4]
+    auto ax = helper.abs(x);
+    auto ay = helper.abs(y);
+    auto q = helper.div(helper.min(ax, ay), helper.max(ax, ay));
+    // get approximation for interval [0, pi/4]
+    auto r = atan_poly(q);
+    // [0, pi/4] => [0, pi/2]
+    r = helper.select(helper.le(ax, ay), helper.sub(pi_over_2, r), r);
+    // [0, pi/2] => [0, pi]
+    r = helper.select(helper.le(x, zero), helper.sub(pi, r), r);
+    // [0, pi] => [-pi, pi]
+    r = helper.select(helper.le(y, zero), helper.sub(zero, r), r);
+    return r;
+}
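A scalar transcription of the same scheme (reduce to t = min(|x|, |y|) / max(|x|, |y|) in [0, 1], evaluate the odd polynomial, then unfold by octant and quadrant), checked against std::atan2. The driver and names are illustrative, not from the patch; the coefficients are the ones above:

#include <algorithm>
#include <cassert>
#include <cmath>

// Odd minimax polynomial for atan(t) on [0, 1]: t + t^3 * P(t^2).
static float atan_poly(float t) {
    static const float c[] = {2.90188402868807315826416015625E-3f,
                              -1.62907354533672332763671875E-2f,
                              4.3082617223262786865234375E-2f,
                              -7.5408883392810821533203125E-2f,
                              0.1066047251224517822265625f,
                              -0.14209578931331634521484375f,
                              0.19993579387664794921875f,
                              -0.3333314359188079833984375f};
    float t2 = t * t, p = c[0];
    for (int i = 1; i < 8; i++) p = p * t2 + c[i];
    return p * t2 * t + t;
}

static float atan2_approx_scalar(float y, float x) {
    float ax = std::fabs(x), ay = std::fabs(y);
    float r = atan_poly(std::min(ax, ay) / std::max(ax, ay));  // angle in [0, pi/4]
    if (ax <= ay) r = 1.570796326794897f - r;                  // [0, pi/2]
    if (x <= 0.f) r = 3.141592653589793f - r;                  // [0, pi]
    if (y <= 0.f) r = -r;                                      // [-pi, pi]
    return r;
}

int main() {
    for (float y : {-3.f, -0.5f, 0.7f, 2.f})
        for (float x : {-2.f, -0.1f, 0.4f, 1.5f})
            assert(std::fabs(atan2_approx_scalar(y, x) - std::atan2(y, x)) < 1e-5f);
    return 0;
}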
+// numerical approximation of gauss error function
+// https://en.wikipedia.org/wiki/Error_function#Polynomial
+// original book:
+// Numerical Recipes in Fortran 77: The Art of Scientific Computing
+mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) {
+    auto zero = helper.const_val(0);
+    auto one = helper.const_val(1);
+    auto half = helper.const_val(0.5);
+    auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x))));
+    std::vector<mlir::Value> coeff = {
+            helper.const_val(0.17087277),
+            helper.const_val(-0.82215223),
+            helper.const_val(1.48851587),
+            helper.const_val(-1.13520398),
+            helper.const_val(0.27886807),
+            helper.const_val(-0.18628806),
+            helper.const_val(0.09678418),
+            helper.const_val(0.37409196),
+            helper.const_val(1.00002368),
+            helper.const_val(-1.26551223)};
+    auto p = polynomial(helper, t, coeff);
+    auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x))));
+    return helper.select(helper.ge(x, zero),
+                         helper.sub(one, r),
+                         helper.sub(r, one));
+}
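This is the Numerical Recipes erfc fit (quoted fractional error around 1.2e-7), reshaped so that it returns erf: 1 - erfc(x) for x >= 0 and erfc(|x|) - 1 otherwise. A scalar transcription checked against std::erf (standalone sketch, names are hypothetical):

#include <cassert>
#include <cmath>

// Numerical Recipes style fit: erfc(|x|) ~= t * exp(P(t) - x^2), t = 1 / (1 + |x| / 2);
// erf is then recovered by symmetry.
static float erf_approx_scalar(float x) {
    static const float c[] = {0.17087277f,  -0.82215223f, 1.48851587f,
                              -1.13520398f, 0.27886807f,  -0.18628806f,
                              0.09678418f,  0.37409196f,  1.00002368f,
                              -1.26551223f};
    float t = 1.f / (1.f + 0.5f * std::fabs(x));
    float p = c[0];
    for (int i = 1; i < 10; i++) p = p * t + c[i];
    float r = t * std::exp(p - x * x);    // ~= erfc(|x|)
    return x >= 0.f ? 1.f - r : r - 1.f;  // erf(x) = sign(x) * (1 - erfc(|x|))
}

int main() {
    for (float x = -4.f; x <= 4.f; x += 0.25f)
        assert(std::fabs(erf_approx_scalar(x) - std::erf(x)) < 2e-6f);
    return 0;
}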
+// numerical approximation of the inverse of normal distribution function
+// original algorithm:
+// https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
+//
+// case 1: 0 < x < exp(-2)
+//     z = sqrt(-2 * log(x))
+//     t = 1 / z
+//     res = log(z) / z - z + t * P(t) / Q(t)
+//     where coefficients of P and Q are different
+//     for z < 8 and for z >= 8
+//
+// case 2: exp(-2) <= x <= 1 - exp(-2)
+//     w = x - 0.5
+//     res = sqrt(2pi) * (w + w^3 * R(w^2) / S(w^2))
+//
+// case 3: 1 - exp(-2) < x < 1
+//     0 < 1 - x < exp(-2)
+//     ndtri(x) = -ndtri(1 - x)
+//     fallback to case 1
+mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) {
+    // polynomial P
+    auto P = [&](mlir::Value i, mlir::Value cond) {
+        std::vector<mlir::Value> coeff0 = {
+                helper.const_val(4.05544892305962419923E0),
+                helper.const_val(3.15251094599893866154E1),
+                helper.const_val(5.71628192246421288162E1),
+                helper.const_val(4.40805073893200834700E1),
+                helper.const_val(1.46849561928858024014E1),
+                helper.const_val(2.18663306850790267539E0),
+                helper.const_val(-1.40256079171354495875E-1),
+                helper.const_val(-3.50424626827848203418E-2),
+                helper.const_val(-8.57456785154685413611E-4)};
+        std::vector<mlir::Value> coeff1 = {
+                helper.const_val(3.23774891776946035970E0),
+                helper.const_val(6.91522889068984211695E0),
+                helper.const_val(3.93881025292474443415E0),
+                helper.const_val(1.33303460815807542389E0),
+                helper.const_val(2.01485389549179081538E-1),
+                helper.const_val(1.23716634817820021358E-2),
+                helper.const_val(3.01581553508235416007E-4),
+                helper.const_val(2.65806974686737550832E-6),
+                helper.const_val(6.23974539184983293730E-9)};
+        return helper.select(cond,
+                             polynomial(helper, i, coeff0),
+                             polynomial(helper, i, coeff1));
+    };
+
+    // polynomial Q
+    auto Q = [&](mlir::Value i, mlir::Value cond) {
+        std::vector<mlir::Value> coeff0 = {
+                helper.const_val(1.f),
+                helper.const_val(1.57799883256466749731E1),
+                helper.const_val(4.53907635128879210584E1),
+                helper.const_val(4.13172038254672030440E1),
+                helper.const_val(1.50425385692907503408E1),
+                helper.const_val(2.50464946208309415979E0),
+                helper.const_val(-1.42182922854787788574E-1),
+                helper.const_val(-3.80806407691578277194E-2),
+                helper.const_val(-9.33259480895457427372E-4)};
+        std::vector<mlir::Value> coeff1 = {
+                helper.const_val(1.f),
+                helper.const_val(6.02427039364742014255E0),
+                helper.const_val(3.67983563856160859403E0),
+                helper.const_val(1.37702099489081330271E0),
+                helper.const_val(2.16236993594496635890E-1),
+                helper.const_val(1.34204006088543189037E-2),
+                helper.const_val(3.28014464682127739104E-4),
+                helper.const_val(2.89247864745380683936E-6),
+                helper.const_val(6.79019408009981274425E-9)};
+        return helper.select(cond,
+                             polynomial(helper, i, coeff0),
+                             polynomial(helper, i, coeff1));
+    };
+
+    // polynomial R
+    auto R = [&](mlir::Value i) {
+        std::vector<mlir::Value> coeff = {
+                helper.const_val(-5.99633501014107895267E1),
+                helper.const_val(9.80010754185999661536E1),
+                helper.const_val(-5.66762857469070293439E1),
+                helper.const_val(1.39312609387279679503E1),
+                helper.const_val(-1.23916583867381258016E0)};
+        return polynomial(helper, i, coeff);
+    };
+
+    // polynomial S
+    auto S = [&](mlir::Value i) {
+        std::vector<mlir::Value> coeff = {
+                helper.const_val(1.f),
+                helper.const_val(1.95448858338141759834E0),
+                helper.const_val(4.67627912898881538453E0),
+                helper.const_val(8.63602421390890590575E1),
+                helper.const_val(-2.25462687854119370527E2),
+                helper.const_val(2.00260212380060660359E2),
+                helper.const_val(-8.20372256168333339912E1),
+                helper.const_val(1.59056225126211695515E1),
+                helper.const_val(-1.18331621121330003142E0)};
+        return polynomial(helper, i, coeff);
+    };
+
+    // constants
+    auto zero = helper.const_val(0);
+    auto one = helper.const_val(1);
+    auto half = helper.const_val(0.5);
+    auto eight = helper.const_val(8);
+    auto minus_2 = helper.const_val(-2);
+    auto exp_minus_2 = helper.const_val(0.135335283236);  // exp(-2)
+    auto sqrt_2pi = helper.const_val(2.506628274631);     // sqrt(2pi)
+
+    // conditions
+    auto case1 = helper.lt(x, exp_minus_2);                   // x < exp(-2)
+    auto case3 = helper.gt(x, helper.sub(one, exp_minus_2));  // x > 1 - exp(-2)
+    auto case13 = helper.bit_or(case1, case3);
+
+    // case1 or case3
+    auto x13 = helper.select(case1, x, helper.sub(one, x));  // x or (1 - x)
+    auto z = helper.sqrt(helper.mul(minus_2, helper.log(x13)));
+    auto z_lt_8 = helper.lt(z, eight);
+    auto t = helper.div(one, z);
+    auto res1 = helper.add(helper.sub(helper.div(helper.log(z), z), z),
+                           helper.div(helper.mul(t, P(t, z_lt_8)), Q(t, z_lt_8)));
+    auto res13 = helper.select(case1, res1, helper.sub(zero, res1));
+
+    // case2
+    auto w = helper.sub(x, half);
+    auto w2 = helper.mul(w, w);
+    auto w3 = helper.mul(w, w2);
+    auto res2 = helper.mul(
+            sqrt_2pi, helper.add(w, helper.div(helper.mul(w3, R(w2)), S(w2))));
+
+    return helper.select(case13, res13, res2);
+}
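Two properties worth keeping in mind when reading the case split above: the reflection ndtri(x) = -ndtri(1 - x) that folds case 3 into case 1, and the round trip Phi(ndtri(x)) = x with Phi(z) = 0.5 * erfc(-z / sqrt(2)). The standalone sketch below checks both, using a bisection inverse as a reference stand-in (not the cephes algorithm, and not part of the patch):

#include <cassert>
#include <cmath>

// Reference inverse normal CDF by bisection; used only to check the identities
// the three cases rely on.
static double ndtri_ref(double p) {
    double lo = -20.0, hi = 20.0;
    for (int i = 0; i < 200; ++i) {
        double mid = 0.5 * (lo + hi);
        double cdf = 0.5 * std::erfc(-mid / std::sqrt(2.0));  // Phi(mid)
        (cdf < p ? lo : hi) = mid;
    }
    return 0.5 * (lo + hi);
}

int main() {
    const double e_m2 = std::exp(-2.0);  // case boundary used above (~0.1353)
    // reflection used by case 3: ndtri(x) == -ndtri(1 - x)
    for (double x : {0.01, e_m2 / 2, 0.3, 0.5, 0.9}) {
        assert(std::fabs(ndtri_ref(x) + ndtri_ref(1.0 - x)) < 1e-9);
    }
    // round trip: Phi(ndtri(x)) == x, the property any approximation must preserve
    for (double x : {0.001, 0.1, 0.5, 0.975}) {
        double z = ndtri_ref(x);
        assert(std::fabs(0.5 * std::erfc(-z / std::sqrt(2.0)) - x) < 1e-12);
    }
    return 0;
}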
+
+} // namespace jit
+} // namespace mgb
+
+#endif  // MGB_JIT && MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
@@ -0,0 +1,46 @@
+/**
+ * \file src/jit/impl/mlir/ir/numerical.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
+ * implied.
+ */
+
+#pragma once
+
+#include "megbrain_build_config.h"
+#if MGB_JIT && MGB_JIT_MLIR
+
+#include <vector>
+
+#include "./common.h"
+
+namespace mgb {
+namespace jit {
+
+/*! polynomial of degree N:
+ *  C_0 + C_1 * x + C_2 * x^2 + ... + C_N * x^N
+ *  where coeff = [C_N, ..., C_2, C_1, C_0]
+ */
+mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x,
+                       std::vector<mlir::Value>& coeff);
+
+//! numerical approximation of arctangent
+mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, mlir::Value x);
+
+//! numerical approximation of gauss error function
+mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x);
+
+//! numerical approximation of the inverse of normal distribution function
+mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x);
+
+} // namespace jit
+} // namespace mgb
+
+#endif  // MGB_JIT && MGB_JIT_MLIR
+
+// vim: syntax=cpp.doxygen
@@ -68,8 +68,8 @@ class ElemwiseUnaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
 def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>;
 def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>;
 def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>;
-/* ACOS */
-/* ASIN */
+def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>;
+def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>;
 def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>;
 def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>;
 def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>;
@@ -83,10 +83,10 @@ def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>;
 def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>;
 def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>;
 def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
-/* ERF */
-/* ERFINV */
-/* ERFC */
-/* ERFCINV */
+def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>;
+def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>;
+def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>;
+def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>;
 class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
@@ -130,14 +130,14 @@ def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>;
 def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>;
 def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>;
 def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>;
-/* POW */
+def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>;
 def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>;
 def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>;
 def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>;
 def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>;
 def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>;
 def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
-/* ATAN2 */
+def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>;
 class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
     ElemwiseOp<mnemonic, traits> {
@@ -159,22 +159,48 @@ void run_mlir(CompNode cn) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
 }
+
+struct MlirTestOpt {
+    float low;
+    float high;
+    float maxerr;
+};
+
+struct MlirTestOpt get_mode_opt(opr::Elemwise::Mode mode) {
+    struct MlirTestOpt opt = {0, 1, 1e-6};
+    if (mode == opr::Elemwise::Mode::ABS) {
+        opt.low = -10;
+        opt.high = 10;
+    } else if (mode == opr::Elemwise::Mode::LOG) {
+        opt.low = 0.1;
+        opt.high = 4;
+    } else if (mode == opr::Elemwise::Mode::ERF or
+               mode == opr::Elemwise::Mode::ERFC) {
+        opt.low = -5;
+        opt.high = 5;
+    } else if (mode == opr::Elemwise::Mode::ERFINV) {
+        opt.low = -0.999;
+        opt.high = 0.999;
+        opt.maxerr = 1e-4;
+    } else if (mode == opr::Elemwise::Mode::ERFCINV) {
+        opt.low = 0.001;
+        opt.high = 1.999;
+        opt.maxerr = 1e-4;
+    }
+    return opt;
+}
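The looser maxerr for ERFINV/ERFCINV is plausible once you look at how steep the inverse becomes near the ends of its domain: d/dy erfinv(y) = sqrt(pi)/2 * exp(erfinv(y)^2), so float rounding in intermediates is amplified close to y = +/-1, which is presumably also why the range stops at +/-0.999. A small standalone illustration of that slope (the reasoning here is an editorial assumption, not stated in the patch):

#include <cmath>
#include <cstdio>

int main() {
    const double pi = std::acos(-1.0);
    // d/dy erfinv(y) = 1 / erf'(x) = sqrt(pi)/2 * exp(x^2), evaluated at x = erfinv(y).
    for (double x : {0.5, 1.5, 2.0, 2.33}) {  // x plays the role of erfinv(y)
        double y = std::erf(x);
        double slope = std::sqrt(pi) / 2.0 * std::exp(x * x);
        std::printf("y = %+.6f  ->  |d erfinv / dy| ~= %.1f\n", y, slope);
    }
    return 0;
}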
 template <typename tag, int arity>
 void run_mlir_mode(CompNode cn) {
     set_backend(Backend::MLIR);
     auto graph = ComputingGraph::make();
-    float low = 0.f, high = 1.f;
-    if (tag::mode == opr::Elemwise::Mode::LOG) {
-        low = 0.1;
-        high = 4;
-    }
-    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(low,
-                                                                         high);
+    auto opt = get_mode_opt(tag::mode);
+    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen(opt.low,
+                                                                         opt.high);
     SmallVector<std::shared_ptr<HostTensorND>> hosts;
     VarNodeArray input_vars;
     for (int i = 0; i < arity; i++) {
-        hosts.push_back(gen({23, 42}, cn));
+        hosts.push_back(gen({2323, 4242}, cn));
         input_vars.push_back(
                 opr::Host2DeviceCopy::make(*graph, hosts[i]).node());
     }
@@ -198,7 +224,7 @@ void run_mlir_mode(CompNode cn) {
                                make_callback_copy(y_jit, host_y_jit)});
     func->execute();
-    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, opt.maxerr);
 }
 #endif
@@ -240,18 +266,25 @@ TEST(TestJITMlirCodeGen, BasicGPU) {
     cb(RELU) \
     cb(ABS) \
     cb(NEGATE) \
+    cb(ACOS) \
+    cb(ASIN) \
     cb(CEIL) \
     cb(EXP) \
     cb(FLOOR) \
     cb(LOG) \
     cb(LOG1P) \
     cb(SIN) \
+    cb(COS) \
     cb(TANH) \
     cb(FAST_TANH) \
     cb(H_SWISH) \
     cb(SIGMOID) \
     cb(EXPM1) \
-    cb(ROUND)
+    cb(ROUND) \
+    cb(ERF) \
+    cb(ERFINV) \
+    cb(ERFC) \
+    cb(ERFCINV)
 // clang-format on
 template <typename tag>
 class TestJITMlirUnaryElemwise : public ::testing::Test {};
@@ -268,21 +301,27 @@ FOREACH_UNARY_MODE(def_tag)
         ::testing::Types<FOREACH_UNARY_MODE(t) ABS>;
 #undef t
 TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types);
-TYPED_TEST(TestJITMlirUnaryElemwise, run) {
-    auto cn = CompNode::load("cpu0");
-    run_mlir_mode<TypeParam, 1>(cn);
-}
 #define SKIP_MODE(_mode) \
     if (TypeParam::mode == opr::Elemwise::Mode::_mode) { \
         printf("skip\n"); \
        return; \
     }
+
+TYPED_TEST(TestJITMlirUnaryElemwise, run) {
+    auto cn = CompNode::load("cpu0");
+    SKIP_MODE(ROUND);
+    run_mlir_mode<TypeParam, 1>(cn);
+}
 TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
     SKIP_MODE(SIN);
+    SKIP_MODE(ROUND);
     run_mlir_mode<TypeParam, 1>(cn);
 }
@@ -298,6 +337,7 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
     cb(MOD) \
     cb(SUB) \
     cb(TRUE_DIV) \
+    cb(POW) \
     cb(ABS_GRAD) \
     cb(SIGMOID_GRAD) \
     cb(SWITCH_GT0) \
@@ -311,7 +351,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) {
     cb(FAST_TANH_GRAD) \
     cb(FUSE_ADD_SIGMOID) \
     cb(H_SWISH_GRAD) \
-    cb(FUSE_ADD_H_SWISH)
+    cb(FUSE_ADD_H_SWISH) \
+    cb(ATAN2)
 // clang-format on
 template <typename tag>
 class TestJITMlirBinaryElemwise : public ::testing::Test {};
@@ -336,6 +377,9 @@ TYPED_TEST(TestJITMlirBinaryElemwise, run) {
 TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
+
+    SKIP_MODE(MOD);
+
     run_mlir_mode<TypeParam, 2>(cn);
 }
@@ -373,7 +417,7 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) {
 #undef SKIP_MODE
-#endif
+#endif  // MGB_JIT_MLIR
 #endif  // MGB_JIT