GitOrigin-RevId: 1444f69cce
tags/v1.7.2.m1
@@ -119,6 +119,15 @@ __device__ __host__ inline float dispatch_powf(float x, float y) { | |||||
return powf(x, y); | return powf(x, y); | ||||
} | } | ||||
// Integer floor division: returns the quotient of x / y rounded toward
// negative infinity, whereas the built-in `/` on ints truncates toward zero.
//
// x: dividend
// y: divisor; y == 0 is not guarded against here
// returns: floor(x / y) as an int
//
// NOTE(review): x == INT_MIN with y == -1 takes the plain `x / y` path and is
// not special-cased, so the quotient overflows int -- confirm callers never
// pass that pair.
__device__ __host__ inline int dispatch_floordiv_int(int x, int y) {
    // The XOR of the operands is negative exactly when their sign bits differ,
    // which is the only case where truncation and flooring can disagree.
    if ((x ^ y) < 0) {
        const auto quot = x / y;
        const auto rem = x % y;
        // A non-zero remainder means truncation rounded the negative quotient
        // up toward zero; step one down to reach the floor.
        return rem ? quot - 1 : quot;
    }
    // Same-sign operands give a non-negative quotient, where truncation
    // already equals the floor.
    return x / y;
}
#include "src/common/elemwise/each_mode.inl" | #include "src/common/elemwise/each_mode.inl" | ||||
template <megcorePlatform_t plat, uint32_t mode, typename dtype> | template <megcorePlatform_t plat, uint32_t mode, typename dtype> | ||||
@@ -227,7 +236,7 @@ DEF_KERN(dt_bool, LT, x < y); | |||||
DEF_KERN(dt_bool, LEQ, x <= y); | DEF_KERN(dt_bool, LEQ, x <= y); | ||||
DEF_KERN(dt_bool, EQ, x == y); | DEF_KERN(dt_bool, EQ, x == y); | ||||
DEF_KERN_INT(FLOOR_DIV, x / y); | |||||
DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); | |||||
DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); | ||||
DEF_KERN_INT(MOD, x % y); | DEF_KERN_INT(MOD, x % y); | ||||
@@ -59,7 +59,7 @@ def test_multiply(): | |||||
def test_div(): | def test_div(): | ||||
np.testing.assert_allclose( | np.testing.assert_allclose( | ||||
F.div(tensor([3, 4]), 2).numpy(), | |||||
F.div(tensor([3.0, 4.0]), 2).numpy(), | |||||
np.divide(np.array([3, 4], dtype=np.float32), 2), | np.divide(np.array([3, 4], dtype=np.float32), 2), | ||||
) | ) | ||||
@@ -67,6 +67,16 @@ def test_div(): | |||||
(tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), | (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), | ||||
) | ) | ||||
np.testing.assert_allclose( | |||||
F.floor_div(tensor([-5.0, -7.0]), 2).numpy(), | |||||
np.floor_divide(np.array([-5.0, -7.0], dtype=np.float32), 2), | |||||
) | |||||
np.testing.assert_allclose( | |||||
(tensor([-5, -7]) // 2).numpy(), | |||||
np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), | |||||
) | |||||
def test_clamp(): | def test_clamp(): | ||||
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and | """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and | ||||
@@ -39,6 +39,19 @@ int do_mod(int a, int b) { | |||||
return a % b; | return a % b; | ||||
} | } | ||||
// Floating-point floor division: divides a by b and rounds the result toward
// negative infinity with std::floor.
//
// a: dividend
// b: divisor; no special handling of b == 0 is done here, the result is
//    whatever std::floor yields for the float quotient a / b
// returns: std::floor(a / b)
float do_floor_div(float a, float b) {
    return std::floor(a / b);
}
// Integer floor division: returns the quotient of a / b rounded toward
// negative infinity, whereas the built-in `/` on ints truncates toward zero.
//
// a: dividend
// b: divisor; b == 0 is not guarded against here
// returns: floor(a / b) as an int
//
// NOTE(review): a == INT_MIN with b == -1 takes the plain `a / b` path and is
// not special-cased, so the quotient overflows int -- confirm callers never
// pass that pair.
int do_floor_div(int a, int b) {
    // The XOR of the operands is negative exactly when their sign bits differ,
    // which is the only case where truncation and flooring can disagree.
    if ((a ^ b) < 0) {
        const auto quot = a / b;
        const auto rem = a % b;
        // A non-zero remainder means truncation rounded the negative quotient
        // up toward zero; step one down to reach the floor.
        return rem ? quot - 1 : quot;
    }
    // Same-sign operands give a non-negative quotient, where truncation
    // already equals the floor.
    return a / b;
}
float do_erfinv(float x) { | float do_erfinv(float x) { | ||||
return erfinvf(x); | return erfinvf(x); | ||||
} | } | ||||
@@ -41,7 +41,7 @@ DEF_TRAIT(LT, x < y) | |||||
#define _ALLOW_INT true | #define _ALLOW_INT true | ||||
DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) | DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) | ||||
DEF_TRAIT(ADD, x + y) | DEF_TRAIT(ADD, x + y) | ||||
DEF_TRAIT(FLOOR_DIV, floor(x / y)) | |||||
DEF_TRAIT(FLOOR_DIV, do_floor_div(x, y)) | |||||
DEF_TRAIT(MAX, std::max(x, y)) | DEF_TRAIT(MAX, std::max(x, y)) | ||||
DEF_TRAIT(MIN, std::min(x, y)) | DEF_TRAIT(MIN, std::min(x, y)) | ||||
DEF_TRAIT(MOD, do_mod(x, y)) | DEF_TRAIT(MOD, do_mod(x, y)) | ||||