GitOrigin-RevId: ebf03d814a
release-1.7
@@ -48,10 +48,10 @@ struct MeanOp { | |||||
src_ctype* src; | src_ctype* src; | ||||
dst_ctype* dst; | dst_ctype* dst; | ||||
const size_t B; | const size_t B; | ||||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | ||||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | ||||
dst[idx] = val / static_cast<dst_ctype>(B); | |||||
dst[idx] = val / static_cast<wtype>(B); | |||||
} | } | ||||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | ||||
return lhs + rhs; | return lhs + rhs; | ||||
@@ -103,6 +103,16 @@ TEST_F(CUDA, REDUCE) { | |||||
.set_param(param) | .set_param(param) | ||||
.execs({{1, 4194304, 1}, {1, 1, 1}}); | .execs({{1, 4194304, 1}, {1, 1, 1}}); | ||||
} | } | ||||
{ | |||||
// large reduce_mean for O16C32 | |||||
Reduce::Param param{Mode::MEAN, 1, | |||||
Reduce::Param::DataType::FLOAT_O16xC32}; | |||||
checker.set_dtype(0, dtype::Float16()) | |||||
.set_dtype(1, dtype::Float16()) | |||||
.set_param(param) | |||||
.execs({{1, 65536, 5}, {1, 1, 5}}); | |||||
} | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -74,6 +74,15 @@ TEST_F(FALLBACK, REDUCE) { | |||||
Config config(param, dtype, shape); | Config config(param, dtype, shape); | ||||
configs.push_back(config); | configs.push_back(config); | ||||
} | } | ||||
{ | |||||
// large reduce_mean for O16C32 | |||||
TensorShape shape{1, 65536, 5}; | |||||
Param param(Mode::MEAN, 1, DataType::FLOAT_O16xC32); | |||||
Config config(param, dtype::Float16(), shape); | |||||
configs.push_back(config); | |||||
} | |||||
for (auto&& config : configs) { | for (auto&& config : configs) { | ||||
auto&& dtype = config.dtype; | auto&& dtype = config.dtype; | ||||
auto&& param = config.param; | auto&& param = config.param; | ||||
@@ -103,6 +103,16 @@ TEST_F(ROCM, REDUCE) { | |||||
.set_param(param) | .set_param(param) | ||||
.execs({{1, 4194304, 1}, {1, 1, 1}}); | .execs({{1, 4194304, 1}, {1, 1, 1}}); | ||||
} | } | ||||
{ | |||||
// large reduce_mean for O16C32 | |||||
Reduce::Param param{Mode::MEAN, 1, | |||||
Reduce::Param::DataType::FLOAT_O16xC32}; | |||||
checker.set_dtype(0, dtype::Float16()) | |||||
.set_dtype(1, dtype::Float16()) | |||||
.set_param(param) | |||||
.execs({{1, 65536, 5}, {1, 1, 5}}); | |||||
} | |||||
#endif | #endif | ||||
} | } | ||||