|
|
@@ -14,7 +14,6 @@ |
|
|
|
#include "megbrain/exception.h" |
|
|
|
#include "megbrain/utils/metahelper.h" |
|
|
|
#include "megbrain/utils/arith_helper.h" |
|
|
|
#include "megdnn/dtype.h" |
|
|
|
|
|
|
|
#include <cmath> |
|
|
|
#include <cstring> |
|
|
@@ -383,24 +382,40 @@ struct QuantizedLowbitMemcpy<DT, true> { |
|
|
|
// cast with bits that 8 % bits == 0 |
|
|
|
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; |
|
|
|
static constexpr uint8_t MASK = (1 << bits) - 1; |
|
|
|
using Trait = QuantizedLowbitTrait<DT>; |
|
|
|
static constexpr bool signedness = |
|
|
|
std::is_same<DT, dtype::QuantizedS4>::value; |
|
|
|
|
|
|
|
static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { |
|
|
|
auto dest = static_cast<uint8_t*>(dest_raw); |
|
|
|
auto src = static_cast<const int8_t*>(src_raw); |
|
|
|
memset(dest, 0, divup<size_t>(n * bits, 8)); |
|
|
|
for (size_t i = 0; i < n; ++i) { |
|
|
|
int8_t val = src[i] + Trait::SHIFT; |
|
|
|
mgb_assert(val >= 0 && val < (1 << bits)); |
|
|
|
dest[i * bits / 8] |= val << (i * bits % 8); |
|
|
|
int8_t val = src[i]; |
|
|
|
static const auto min_val = DTypeTrait<DT>::min(); |
|
|
|
static const auto max_val = DTypeTrait<DT>::max(); |
|
|
|
MGB_MARK_USED_VAR(min_val); |
|
|
|
MGB_MARK_USED_VAR(max_val); |
|
|
|
mgb_assert(val >= static_cast<int8_t>(min_val) && |
|
|
|
val <= static_cast<int8_t>(max_val), |
|
|
|
"data exceeds range(%d,%d) of data type", min_val, |
|
|
|
max_val); |
|
|
|
dest[i * bits / 8] |= (val & MASK) << (i * bits % 8); |
|
|
|
} |
|
|
|
} |
|
|
|
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { |
|
|
|
auto dest = static_cast<int8_t*>(dest_raw); |
|
|
|
auto dest = reinterpret_cast<int8_t*>(dest_raw); |
|
|
|
auto src = static_cast<const uint8_t*>(src_raw); |
|
|
|
for (size_t i = 0; i < n; ++i) { |
|
|
|
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); |
|
|
|
dest[i] = val - Trait::SHIFT; |
|
|
|
uint8_t intermediate = |
|
|
|
((src[i * bits / 8] >> (i * bits % 8)) & MASK); |
|
|
|
if (signedness) { |
|
|
|
int val = (intermediate & uint8_t(1 << (bits - 1))) |
|
|
|
? ((int)(intermediate) | ~(int)(MASK)) |
|
|
|
: (int)(intermediate); |
|
|
|
dest[i] = static_cast<int8_t>(val); |
|
|
|
} else { |
|
|
|
dest[i] = static_cast<int8_t>(intermediate); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|