Browse Source

fix(imperative): fix error message for tensors with intbx data type

GitOrigin-RevId: cbb42f8127
release-1.7
Megvii Engine Team 3 years ago
parent
commit
fe15239ac0
1 changed files with 8 additions and 7 deletions
  1. +8
    -7
      imperative/python/src/numpy_dtypes_intbx.cpp

+ 8
- 7
imperative/python/src/numpy_dtypes_intbx.cpp View File

@@ -24,7 +24,7 @@ template <size_t N>
struct LowBitType {
static_assert(N < 8, "low bit only supports less than 8 bits");
static int npy_typenum;
//! numerical value (-3, -1, 1, 3)
//! allowed numerical value: odd numbers between (-max_value, max_value)
int8_t value;

struct PyObj;
@@ -32,16 +32,17 @@ struct LowBitType {

const static int32_t max_value = (1 << N) - 1;

//! check whether val is (-3, -1, 1, 3) and set python error
//! check whether val is odd and between (-max_value, max_value) and set python error
static bool check_value_set_err(int val) {
int t = val + max_value;
if ((t & 1) || t < 0 || t > (max_value << 1)) {
PyErr_SetString(
PyExc_ValueError, mgb::ssprintf(
"low bit dtype number error: "
"value=%d; allowed {-3, -1, 1, 3}",
val)
.c_str());
PyExc_ValueError,
mgb::ssprintf(
"low bit dtype number error: "
"value=%d; allowed values are odd numbers between [%d,%d]",
val, -max_value, max_value)
.c_str());
return false;
}



Loading…
Cancel
Save