GitOrigin-RevId: e7d64c4987
release-1.6
@@ -424,12 +424,20 @@ size_t TensorLayout::access_bytes() const {
     if (dtype.is_low_bit()) {
         ret = 1;
         int align_size_in_elements = 8 / dtype.low_bit();
+        auto min_stride = contig.stride[0];
         for (size_t i = 0; i < contig.ndim; ++i) {
             if (contig.stride[i] == 1) {
                 ret *= round_up((int)contig.shape[i], align_size_in_elements);
             } else {
                 ret *= contig.shape[i];
             }
+            if (min_stride > contig.stride[i]) {
+                min_stride = contig.stride[i];
+            }
+        }
+        if (min_stride != 1) {
+            megdnn_assert(min_stride == align_size_in_elements);
+            ret *= min_stride;
         }
         ret /= align_size_in_elements;
     } else {
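
For context, the hunk above makes access_bytes() also count the padding slots of a low-bit layout whose smallest stride is larger than 1. Below is a minimal standalone sketch of that arithmetic, assuming a 4-bit dtype and a hypothetical one-dimensional layout with shape {6} and stride {2}; neither value comes from the patch.

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    int main() {
        const int align_size_in_elements = 8 / 4;  // 4-bit dtype: two elements per byte
        const int shape[1] = {6};
        const int stride[1] = {2};  // no stride-1 axis: every element carries a padding nibble

        std::size_t ret = 1;
        int min_stride = stride[0];
        for (int i = 0; i < 1; ++i) {
            ret *= shape[i];  // stride != 1, so no round_up to the alignment
            if (min_stride > stride[i])
                min_stride = stride[i];
        }
        if (min_stride != 1) {
            assert(min_stride == align_size_in_elements);  // mirrors the new megdnn_assert
            ret *= min_stride;                             // count the padding slots as well
        }
        ret /= align_size_in_elements;
        std::printf("access bytes: %zu\n", ret);  // 6; the pre-patch code reported 3
        return 0;
    }
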
@@ -240,6 +240,7 @@ template <int ndim>
 void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -251,7 +252,12 @@ void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
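
The two hunks above record the smallest stride seen while initializing the visitor and assert that it is either 1 or 2. As a rough illustration (the layout is an assumption, not taken from the patch): a 4-bit tensor whose dimensions collapse to shape {2, 2, 3} with strides {12, 6, 2} has no stride-1 axis, so only one of the two 4-bit slots in each byte carries data.

    #include <cassert>
    #include <cstdio>

    int main() {
        // Hypothetical collapsed 4-bit layout with no stride-1 axis.
        const int ndim = 3;
        const int stride[3] = {12, 6, 2};

        int min_stride = stride[0];
        for (int i = 0; i < ndim; ++i) {
            if (min_stride > stride[i])
                min_stride = stride[i];
        }
        assert(min_stride == 1 || min_stride == 2);
        bool is_min_stride_2 = (min_stride == 2);
        std::printf("is_min_stride_2 = %d\n", is_min_stride_2);  // prints 1
        return 0;
    }
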
@@ -542,6 +542,7 @@ protected:
     int m_stride[ndim];
     int m_shape[ndim];
     bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
 
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -592,7 +593,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
             int shape_idx[ndim];
             bool valid = true;
             get_shape_from_access(access_idx, shape_idx);
@@ -605,6 +606,8 @@ public:
                 idx = (idx + shape_idx[i]) * m_shape[i + 1];
             }
             idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else { // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }
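
The new branch above maps physical 4-bit slots back to logical element indices for the min_stride == 2 case: even slots hold data, odd slots are padding and yield -1. A small standalone sketch of that mapping (the loop bound of 6 is an arbitrary assumption):

    #include <cstdio>

    int main() {
        for (unsigned access_idx = 0; access_idx < 6; ++access_idx) {
            // Mirror of the new branch: the logical index is half the physical slot
            // index, and the padding nibble at every odd slot is reported as -1.
            int idx = ((access_idx & 0x1) == 0) ? (int)(access_idx >> 1) : -1;
            std::printf("access_idx=%u -> idx=%d\n", access_idx, idx);
        }
        return 0;  // prints 0, -1, 1, -1, 2, -1
    }
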
@@ -70,6 +70,7 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -81,7 +82,12 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
@@ -132,6 +132,7 @@ class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
     int m_shape[ndim];
     bool m_is_contiguous;
     bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
 
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -197,7 +198,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
             int shape_idx[ndim];
             bool valid = true;
             get_shape_from_access(access_idx, shape_idx);
@@ -209,6 +210,8 @@ public:
                 idx = (idx + shape_idx[i]) * m_shape[i + 1];
             }
             idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else { // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }
@@ -152,7 +152,8 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker,
                 .execs({{1, 4, 5, 5}, {1, 4, 5, 5}});
     } else if (arity == 2) {
         checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}})
-                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}});
+                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}})
+                .execs({{2, 2, 3, 1}, {2, 2, 3, 1}, {2, 2, 3, 1}});
     } else {
         megdnn_assert(0);
     }
@@ -925,6 +925,7 @@ TEST_F(CUDA, RELAYOUT_Q4) {
             .set_rng(1, &rng_int4)
             .set_dtype(0, dtype::QuantizedS4(1.f))
             .set_dtype(1, dtype::QuantizedS4(1.f))
+            .execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
            .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
             .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
             .execl(TensorLayoutArray{
@@ -123,11 +123,13 @@ TEST_F(CUDA, QUANTIZED_TYPECVT_4BIT) {
         set_err(dst_dtype);
         checker.set_dtype(0, src_dtype)
                 .set_dtype(1, dst_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
         set_err(src_dtype);
         checker.set_dtype(0, dst_dtype)
                 .set_dtype(1, src_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
     };
     run(dtype::Quantized4Asymm{1.19990518f, 8},