Browse Source

fix(mgb/plugin): fix io dump for qint4, quint4 and bool type tensors

GitOrigin-RevId: bad1e88690
tags/v1.7.0.m1
Megvii Engine Team 3 years ago
parent
commit
5b1383e06c
1 changed files with 19 additions and 7 deletions
  1. +19
    -7
      src/plugin/impl/opr_io_dump.cpp

+ 19
- 7
src/plugin/impl/opr_io_dump.cpp View File

@@ -34,6 +34,14 @@ double as_double(megdnn::dt_qint8& a) {
return static_cast<double>(a.as_int8());
}
template <>
double as_double(megdnn::dt_quint4& a) {
return static_cast<double>(a.as_uint8());
}
template <>
double as_double(megdnn::dt_qint4& a) {
return static_cast<double>(a.as_int8());
}
template <>
double as_double(megdnn::dt_qint32& a) {
return static_cast<double>(a.as_int32());
}
@@ -69,7 +77,10 @@ void do_print_host_val(
sum2 += as_double(i) * as_double(i);
};
size_t nr = val.layout.total_nr_elems();
if (val.layout.is_contiguous()) {
bool normal_contig = !val.layout.dtype.is_low_bit() && val.layout.is_contiguous();
bool lowbit_contig =
val.layout.dtype.is_low_bit() && val.layout.is_physical_contiguous();
if (normal_contig || lowbit_contig) {
ctype* ptr = val.ptr<ctype>();
for (size_t i = 0; i < nr; ++i) {
update(ptr[i]);
@@ -99,13 +110,14 @@ void print_host_val(
fout, max_nr_print, val, print_stat);
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
cb(dtype::Bool)
#undef cb
default:
mgb_throw(
MegBrainError,
"can not handle dtype %s in "
"print_host_val",
val.layout.dtype.name());
default : mgb_throw(
MegBrainError,
"can not handle dtype %s in "
"print_host_val",
val.layout.dtype.name());
}
};



Loading…
Cancel
Save