GitOrigin-RevId: 7e78bdae91
release-1.7
@@ -78,6 +78,72 @@ struct ArgmxxOp {
    const wtype INIT;
};
template <bool is_max>
struct ArgmxxOp<dt_float32, is_max> {
    using stype_ = dt_float32;
    struct wtype {
        stype_ key;
        dt_int32 val;
        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
                : key(key), val(val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
                : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
                : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
                : key(rhs.key), val(rhs.val) {}
        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
            this->key = rhs.key;
            this->val = rhs.val;
            return *this;
        }
    };
    MEGDNN_HOST MEGDNN_DEVICE
    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
            : src(src),
              dst(dst),
              A(A),
              B(B),
              C(C),
              INIT(wtype(
                      is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
                      0)) {}
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
        wtype res;
        res.key = src[idx];
        res.val = idx / C % B;
        return res;
    }
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
        dst[idx] = val.val;
    }
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
#if defined(__CUDA_ARCH__)
        if (isnan(lhs.key))
#else
        if (std::isnan(lhs.key))
#endif
            return lhs;
        if (is_max) {
            if (lhs.key > rhs.key)
                return lhs;
            else
                return rhs;
        } else {
            if (lhs.key < rhs.key)
                return lhs;
            else
                return rhs;
        }
    }
    stype_* src;
    dt_int32* dst;
    uint32_t A, B, C;
    const wtype INIT;
};
}  // namespace argmxx
}  // namespace megdnn
// vim: syntax=cpp.doxygen
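For reference, a minimal host-only sketch of the combine rule the ArgmxxOp<dt_float32, is_max> specialization above implements; the struct and helper names here are hypothetical and not part of the patch. It shows that once a NaN key reaches the accumulator, the pairwise apply keeps it, so argmax/argmin over data containing NaN reports the NaN's index.

#include <cmath>
#include <cstdio>
#include <limits>

// Hypothetical stand-ins for the kernel's wtype / apply (argmax flavour).
struct KeyVal {
    float key;
    int val;
};

static KeyVal combine_max(KeyVal lhs, KeyVal rhs) {
    if (std::isnan(lhs.key))
        return lhs;  // NaN on the left always wins
    // If rhs.key is NaN, lhs.key > rhs.key is false, so the NaN wins here too.
    return lhs.key > rhs.key ? lhs : rhs;
}

int main() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    const float src[] = {1.f, 5.f, nan, 3.f};
    KeyVal acc{src[0], 0};
    for (int i = 1; i < 4; ++i)
        acc = combine_max(acc, KeyVal{src[i], i});
    std::printf("argmax = %d\n", acc.val);  // prints 2, the index of the NaN
}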
@@ -119,6 +119,28 @@ struct MinOp {
            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype>
struct MinOp<src_ctype, dst_ctype, dt_float32> {
    typedef dt_float32 wtype;
    const wtype INIT;
    src_ctype* src;
    dst_ctype* dst;
    const size_t B;
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
#if defined(__CUDA_ARCH__)
        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
#else
        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
#endif
    }
    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct MaxOp {
    typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype>
struct MaxOp<src_ctype, dst_ctype, dt_float32> {
    typedef dt_float32 wtype;
    const wtype INIT;
    src_ctype* src;
    dst_ctype* dst;
    const size_t B;
    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
#if defined(__CUDA_ARCH__)
        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
#else
        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
#endif
    }
    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct CheckNonFiniteOp {
    typedef wtype_ wtype;
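A standalone sketch (hypothetical helper, not MegDNN API) of the float32 combine rule in the MinOp specialization above: the left operand wins if it is NaN or strictly smaller, so a NaN that enters the running value is never displaced and the reduced result stays NaN.

#include <cmath>
#include <cstdio>
#include <limits>

// Mirrors MinOp<..., dt_float32>::apply on the host side.
static float combine_min(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
}

int main() {
    const float nan = std::numeric_limits<float>::quiet_NaN();
    const float src[] = {2.f, nan, -1.f, 7.f};
    float acc = std::numeric_limits<float>::max();  // same role as INIT
    for (float v : src)
        acc = combine_min(acc, v);
    std::printf("min = %f\n", acc);  // nan; the plain (x < y ? x : y) rule would give -1
}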
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
        ctype x = round(input[idx] / scale) + zero_point;
        if (isnan(x)) {
            output[idx] = NAN;
            return;
        }
        x = fmaxf(fminf(x, qmax), qmin);
        output[idx] = (x - zero_point) * scale;
    }
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
    __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
        ctype x = round(input[idx] / scale) + zero_point;
        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
    }
#if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
    __device__ void operator()(
            uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
        ctype x = round(input / scale) + zero_point;
        if (isnan(x)) {
            output = NAN;
            return;
        }
        x = fmaxf(fminf(x, qmax), qmin);
        output = (x - zero_point) * scale;
    }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
            uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
            ctype zero_point) {
        ctype x = round(input / scale) + zero_point;
        grad = x <= qmax && x >= qmin ? diff : 0.0;
        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
    }
#if MEGDNN_CC_HOST
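A host-side sketch (hypothetical function, not the CUDA kernel itself) of the fake-quant forward rule after this change: a NaN quantized value is passed through instead of being clamped into [qmin, qmax] and silently turned into a finite output.

#include <algorithm>
#include <cassert>
#include <cmath>

static float fake_quant(float x, float scale, float zero_point, float qmin, float qmax) {
    float q = std::round(x / scale) + zero_point;
    if (std::isnan(q))
        return q;  // propagate NaN, matching the kernel's early return
    q = std::min(std::max(q, qmin), qmax);  // clamp to the quantized range
    return (q - zero_point) * scale;        // dequantize back to float
}

int main() {
    assert(std::isnan(fake_quant(NAN, 0.1f, 0.f, -128.f, 127.f)));
    assert(!std::isnan(fake_quant(0.5f, 0.1f, 0.f, -128.f, 127.f)));
}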
@@ -26,14 +26,18 @@ struct traits;
template <>
struct traits<true> {
    static const float init;
    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
    static bool better_than(float lhs, float rhs) {
        return std::isnan(lhs) ? true : lhs > rhs;
    }
};
const float traits<true>::init = std::numeric_limits<float>::lowest();
template <>
struct traits<false> {
    static const float init;
    static float better_than(float lhs, float rhs) { return lhs < rhs; }
    static float better_than(float lhs, float rhs) {
        return std::isnan(lhs) ? true : lhs < rhs;
    }
};
const float traits<false>::init = std::numeric_limits<float>::max();
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
template <typename ctype>
struct Trait<Mode::MIN, ctype> {
    static const ctype INIT;
    static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();
template <>
struct Trait<Mode::MIN, dt_float32> {
    using ctype = dt_float32;
    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
struct Trait<Mode::MAX, ctype> {
    static const ctype INIT;
    static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};
template <typename ctype>
const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();
template <>
struct Trait<Mode::MAX, dt_float32> {
    using ctype = dt_float32;
    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};
template <Mode mode, typename ctype>
void reduce_fwd(
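A condensed standalone sketch (hypothetical names, simplified from the naive reduce traits above) of the specialization pattern: the generic trait keeps the plain comparison while the float specialization adds the NaN check, and the fold picks whichever apply matches the element type.

#include <cmath>
#include <cstddef>
#include <limits>

template <typename ctype>
struct MinTrait {
    static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
};

template <>
struct MinTrait<float> {
    // NaN in the accumulator sticks, so the reduction propagates it.
    static float apply(float x, float y) { return (std::isnan(x) || x < y) ? x : y; }
};

template <typename ctype>
ctype reduce_min(const ctype* data, size_t n, ctype init) {
    ctype acc = init;
    for (size_t i = 0; i < n; ++i)
        acc = MinTrait<ctype>::apply(acc, data[i]);  // specialization chosen by ctype
    return acc;
}

int main() {
    const float vals[] = {3.f, std::numeric_limits<float>::quiet_NaN(), 1.f};
    float r = reduce_min(vals, 3, std::numeric_limits<float>::max());
    return std::isnan(r) ? 0 : 1;  // exits 0: the NaN propagated
}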
@@ -21,7 +21,9 @@ using namespace fake_quant;
TEST_F(CUDA, FAKE_QUANT) {
    std::vector<TestArg> args = get_args();
    auto dtype = dtype::Float32();
    std::unique_ptr<RNG> rng;
    UniformFloatRNG rng(-1.0f, 1.0f);
    const auto nan = std::numeric_limits<float>::quiet_NaN();
    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
    for (auto&& arg : args) {
        auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
        checker.set_allow_invalid_check(true);
        checker.set_rng(0, &rng1);
        checker.set_param(param)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
        checker.set_rng(0, &rng);
        checker.set_allow_invalid_check(false);
    }
    // test noncontiguous layout
    for (auto&& arg : args) {
@@ -53,12 +66,25 @@ TEST_F(CUDA, FAKE_QUANT) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
        checker.set_allow_invalid_check(true);
        checker.set_rng(0, &rng1);
        checker.set_param(param).execl(
                {ilayout,
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
        checker.set_rng(0, &rng);
        checker.set_allow_invalid_check(false);
    }
}
TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
    std::vector<TestArg> args = get_args();
    auto dtype = dtype::Float32();
    UniformFloatRNG rng(-1.0f, 1.0f);
    const auto nan = std::numeric_limits<float>::quiet_NaN();
    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
    for (auto&& arg : args) {
        auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                .set_dtype(4, dtype)
                .execs(TensorShapeArray{
                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
        checker.set_allow_invalid_check(true);
        checker.set_rng(0, &rng1);
        checker.set_param(param)
                .set_dtype(0, dtype)
                .set_dtype(1, dtype)
                .set_dtype(2, dtype)
                .set_dtype(3, dtype)
                .set_dtype(4, dtype)
                .execs(TensorShapeArray{
                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
        checker.set_rng(0, &rng);
        checker.set_allow_invalid_check(false);
    }
    // test noncontiguous layout
    for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
        checker.set_allow_invalid_check(true);
        checker.set_rng(0, &rng1);
        checker.set_param(param).execl(
                {ilayout,
                 ilayout,
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
        checker.set_rng(0, &rng);
        checker.set_allow_invalid_check(false);
    }
}
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
    // very large reduce
    checker.execs({{1, 4194304, 1}, {}});
    // inputs have nan
    {
        const auto nan = std::numeric_limits<float>::quiet_NaN();
        UniformFloatWithValueRNG rng1 =
                UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
        for (auto mode : {Mode::MIN, Mode::MAX}) {
            checker.set_param({mode, 1});
            checker.execs({{2, 64, 32}, {}});
        }
        checker.set_allow_invalid_check(false);
    }
    checker.set_rng(0, &rng);
    auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                     Reduce::DataType data_type) {
        for (int32_t axis : {0, 1, 2, 3}) {
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
    data2_shape = (2, 9, 12)
    data1 = np.random.random(data1_shape).astype(np.float32)
    data2 = np.random.random(data2_shape).astype(np.float32)
    cases = [{"input": data1}, {"input": data2}]
    cases = [
        {"input": data1},
        {"input": data2},
        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
    ]
    if opr not in (F.argmin, F.argmax):
        # test default axis
@@ -143,6 +143,11 @@ def test_fakequant():
    assert np.allclose(x.grad.numpy(), x1.grad.numpy())
    assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
    # test nan
    x = F.full((1, 32, 3, 3), np.nan)
    y = fake_quant_tensor(x, qparams).numpy()
    assert np.isnan(y).all()
    zero_point = tensor([1.0], dtype=np.float32)
    scale = tensor([4.0], dtype=np.float32)
    run(zero_point, scale)