diff --git a/src/plugin/impl/opr_io_dump.cpp b/src/plugin/impl/opr_io_dump.cpp index 287c2338..d964dafe 100644 --- a/src/plugin/impl/opr_io_dump.cpp +++ b/src/plugin/impl/opr_io_dump.cpp @@ -34,6 +34,14 @@ double as_double(megdnn::dt_qint8& a) { return static_cast(a.as_int8()); } template <> +double as_double(megdnn::dt_quint4& a) { + return static_cast(a.as_uint8()); +} +template <> +double as_double(megdnn::dt_qint4& a) { + return static_cast(a.as_int8()); +} +template <> double as_double(megdnn::dt_qint32& a) { return static_cast(a.as_int32()); } @@ -69,7 +77,10 @@ void do_print_host_val( sum2 += as_double(i) * as_double(i); }; size_t nr = val.layout.total_nr_elems(); - if (val.layout.is_contiguous()) { + bool normal_contig = !val.layout.dtype.is_low_bit() && val.layout.is_contiguous(); + bool lowbit_contig = + val.layout.dtype.is_low_bit() && val.layout.is_physical_contiguous(); + if (normal_contig || lowbit_contig) { ctype* ptr = val.ptr(); for (size_t i = 0; i < nr; ++i) { update(ptr[i]); @@ -99,13 +110,14 @@ void print_host_val( fout, max_nr_print, val, print_stat); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) + MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) + cb(dtype::Bool) #undef cb - default: - mgb_throw( - MegBrainError, - "can not handle dtype %s in " - "print_host_val", - val.layout.dtype.name()); + default : mgb_throw( + MegBrainError, + "can not handle dtype %s in " + "print_host_val", + val.layout.dtype.name()); } };