GitOrigin-RevId: 7e78bdae91
release-1.7
@@ -78,6 +78,72 @@ struct ArgmxxOp {
     const wtype INIT;
 };
+template <bool is_max>
+struct ArgmxxOp<dt_float32, is_max> {
+    using stype_ = dt_float32;
+    struct wtype {
+        stype_ key;
+        dt_int32 val;
+        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
+                : key(key), val(val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
+            this->key = rhs.key;
+            this->val = rhs.val;
+            return *this;
+        }
+    };
+    MEGDNN_HOST MEGDNN_DEVICE
+    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
+            : src(src),
+              dst(dst),
+              A(A),
+              B(B),
+              C(C),
+              INIT(wtype(
+                      is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
+                      0)) {}
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        wtype res;
+        res.key = src[idx];
+        res.val = idx / C % B;
+        return res;
+    }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
+        dst[idx] = val.val;
+    }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        if (isnan(lhs.key))
+#else
+        if (std::isnan(lhs.key))
+#endif
+            return lhs;
+        if (is_max) {
+            if (lhs.key > rhs.key)
+                return lhs;
+            else
+                return rhs;
+        } else {
+            if (lhs.key < rhs.key)
+                return lhs;
+            else
+                return rhs;
+        }
+    }
+    stype_* src;
+    dt_int32* dst;
+    uint32_t A, B, C;
+    const wtype INIT;
+};
 }  // namespace argmxx
 }  // namespace megdnn

 // vim: syntax=cpp.doxygen
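
A host-side sketch of what the new ArgmxxOp<dt_float32, is_max>::apply means for the result (illustrative only, with a hypothetical helper name; the real kernel combines elements in a device-specific tree order): once a NaN key becomes the running value it is never replaced, so argmax/argmin report the index of a NaN element instead of skipping it.

    // Illustrative serial fold mirroring apply() for is_max == true.
    #include <cmath>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Returns {value, index} of the "max"; assumes v is non-empty.
    static std::pair<float, size_t> argmax_nan_aware(const std::vector<float>& v) {
        std::pair<float, size_t> best{v[0], 0};
        for (size_t i = 1; i < v.size(); ++i) {
            // apply(lhs, rhs): keep lhs if it is NaN or strictly greater, else take rhs.
            if (!(std::isnan(best.first) || best.first > v[i]))
                best = {v[i], i};
        }
        return best;
    }
    // argmax_nan_aware({1.f, NAN, 3.f}) yields {NaN, 1}: the NaN's index is reported.
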
@@ -119,6 +119,28 @@ struct MinOp {
             : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
 };
+template <typename src_ctype, typename dst_ctype>
+struct MinOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
+};
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct MaxOp {
     typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
+template <typename src_ctype, typename dst_ctype>
+struct MaxOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
+};
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
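
The MinOp/MaxOp specializations above rely on `(isnan(lhs) || lhs < rhs) ? lhs : rhs` keeping a NaN once it is the running value, whereas the plain `lhs < rhs` in the generic op silently drops it. A minimal standalone check of that difference (hypothetical helper names, not from the patch):

    #include <cassert>
    #include <cmath>

    static float min_old(float lhs, float rhs) { return lhs < rhs ? lhs : rhs; }
    static float min_new(float lhs, float rhs) {
        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
    }

    int main() {
        float data[] = {NAN, 2.f, 3.f};
        float a = data[0], b = data[0];
        for (int i = 1; i < 3; ++i) {
            a = min_old(a, data[i]);
            b = min_new(b, data[i]);
        }
        assert(a == 2.f);       // old comparison: the leading NaN is lost
        assert(std::isnan(b));  // new comparison: the NaN propagates
        return 0;
    }
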
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
+        if (isnan(x)) {
+            output[idx] = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output[idx] = (x - zero_point) * scale;
     }
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
-        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
+        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
     }

 #if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
     __device__ void operator()(
             uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
+        if (isnan(x)) {
+            output = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output = (x - zero_point) * scale;
     }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
             uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
             ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
-        grad = x <= qmax && x >= qmin ? diff : 0.0;
+        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
     }

 #if MEGDNN_CC_HOST
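
All four fake-quant kernels above follow the same pattern: the forward paths return NaN directly, because fminf/fmaxf return their non-NaN operand and would otherwise clamp a NaN input to a finite value in [qmin, qmax]; the backward paths write NaN into the gradient instead of treating the NaN as in-range. A host-side sketch of the forward math (a simplified stand-in with a hypothetical name, not the kernel itself):

    #include <cmath>

    static float fake_quant_fwd(
            float x, float scale, float zero_point, float qmin, float qmax) {
        float q = std::round(x / scale) + zero_point;
        if (std::isnan(q))
            return NAN;  // NaN input now stays NaN
        q = std::fmax(std::fmin(q, qmax), qmin);  // previously this clamped NaN to a finite value
        return (q - zero_point) * scale;
    }
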
@@ -26,14 +26,18 @@ struct traits;
 template <>
 struct traits<true> {
     static const float init;
-    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs > rhs;
+    }
 };
 const float traits<true>::init = std::numeric_limits<float>::lowest();
 template <>
 struct traits<false> {
     static const float init;
-    static float better_than(float lhs, float rhs) { return lhs < rhs; }
+    static float better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs < rhs;
+    }
 };
 const float traits<false>::init = std::numeric_limits<float>::max();
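
The explicit std::isnan branch in better_than is what lets a NaN candidate win at all: every ordered comparison involving NaN is false, so the previous one-liners could never report a NaN element as the arg-max/arg-min. A standalone demonstration of that fact:

    #include <cassert>
    #include <cmath>

    int main() {
        float qnan = std::nanf("");
        assert(!(qnan > 1.f) && !(qnan < 1.f));  // ordered comparisons with NaN are false,
        // so better_than(qnan, best) was always false before this change.
        return 0;
    }
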
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
 template <typename ctype>
 struct Trait<Mode::MIN, ctype> {
-    static const ctype INIT;
     static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();
+template <>
+struct Trait<Mode::MIN, dt_float32> {
+    using ctype = dt_float32;
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};

 template <typename ctype>
 struct Trait<Mode::MAX, ctype> {
-    static const ctype INIT;
     static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();
+template <>
+struct Trait<Mode::MAX, dt_float32> {
+    using ctype = dt_float32;
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};

 template <Mode mode, typename ctype>
 void reduce_fwd(
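
The naive-backend Trait specializations mirror the CUDA ops so both backends agree on NaN inputs. reduce_fwd (whose signature begins in the context above) combines elements pairwise; the standalone sketch below (hypothetical helpers, not the real implementation) shows that with the NaN-aware apply a NaN anywhere in the range reaches the final result regardless of how the range is split:

    #include <cmath>
    #include <cstddef>

    // Same rule as Trait<Mode::MIN, dt_float32>::apply.
    static float apply_min(float x, float y) {
        return (std::isnan(x) || x < y) ? x : y;
    }

    // Pairwise reduction over data[lo, hi): a left-subtree NaN is kept by the
    // isnan(x) branch; a right-subtree NaN survives because `x < NaN` is false.
    static float tree_min(const float* data, size_t lo, size_t hi) {
        if (lo + 1 == hi)
            return data[lo];
        size_t mid = lo + (hi - lo) / 2;
        return apply_min(tree_min(data, lo, mid), tree_min(data, mid, hi));
    }
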
@@ -21,7 +21,9 @@ using namespace fake_quant;
 TEST_F(CUDA, FAKE_QUANT) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
-    std::unique_ptr<RNG> rng;
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);

     for (auto&& arg : args) {
         auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                 .set_dtype(2, dtype)
                 .set_dtype(3, dtype)
                 .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -53,12 +66,25 @@
                  {scale_shape, dtype::Float32()},
                  {zeropoint_shape, dtype::Float32()},
                  ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }

 TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);

     for (auto&& arg : args) {
         auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 .set_dtype(4, dtype)
                 .execs(TensorShapeArray{
                         ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs(TensorShapeArray{
+                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                  {scale_shape, dtype::Float32()},
                  {zeropoint_shape, dtype::Float32()},
                  ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
     // very large reduce
     checker.execs({{1, 4194304, 1}, {}});
+    // inputs have nan
+    {
+        const auto nan = std::numeric_limits<float>::quiet_NaN();
+        UniformFloatWithValueRNG rng1 =
+                UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
+        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
+        for (auto mode : {Mode::MIN, Mode::MAX}) {
+            checker.set_param({mode, 1});
+            checker.execs({{2, 64, 32}, {}});
+        }
+        checker.set_allow_invalid_check(false);
+    }
+    checker.set_rng(0, &rng);
     auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                      Reduce::DataType data_type) {
         for (int32_t axis : {0, 1, 2, 3}) {
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
     data2_shape = (2, 9, 12)
     data1 = np.random.random(data1_shape).astype(np.float32)
     data2 = np.random.random(data2_shape).astype(np.float32)
-    cases = [{"input": data1}, {"input": data2}]
+    cases = [
+        {"input": data1},
+        {"input": data2},
+        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
+    ]

     if opr not in (F.argmin, F.argmax):
         # test default axis
@@ -143,6 +143,11 @@ def test_fakequant():
         assert np.allclose(x.grad.numpy(), x1.grad.numpy())
         assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
+        # test nan
+        x = F.full((1, 32, 3, 3), np.nan)
+        y = fake_quant_tensor(x, qparams).numpy()
+        assert np.isnan(y).all()

     zero_point = tensor([1.0], dtype=np.float32)
     scale = tensor([4.0], dtype=np.float32)
     run(zero_point, scale)