@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec( | |||||
size_t workspace_in_bytes) { | size_t workspace_in_bytes) { | ||||
check_layout_fwd(grad, map_xy, diff); | check_layout_fwd(grad, map_xy, diff); | ||||
megdnn_assert( | megdnn_assert( | ||||
grad.dtype == dtype::Float32() | |||||
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), | |||||
grad.dtype == | |||||
dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()) | |||||
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()), | |||||
"Backward Remap only supports Float32/BFloat16."); | "Backward Remap only supports Float32/BFloat16."); | ||||
auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad); | auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad); | ||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec( | |||||
check_layout_fwd(src, map_xy, diff); | check_layout_fwd(src, map_xy, diff); | ||||
megdnn_assert_eq_layout(map_xy, grad); | megdnn_assert_eq_layout(map_xy, grad); | ||||
megdnn_assert( | megdnn_assert( | ||||
grad.dtype == dtype::Float32() | |||||
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), | |||||
grad.dtype == | |||||
dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()) | |||||
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()), | |||||
"Backward Remap only supports Float32/BFloat16."); | "Backward Remap only supports Float32/BFloat16."); | ||||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad); | auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad); | ||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec( | |||||
switch (grad.layout.dtype.enumv()) { | switch (grad.layout.dtype.enumv()) { | ||||
support_dtype(dtype::Float32); | support_dtype(dtype::Float32); | ||||
support_dtype(dtype::BFloat16); | support_dtype(dtype::BFloat16); | ||||
support_dtype(dtype::Float16); | |||||
default: | default: | ||||
megdnn_throw("unsupported dtype in remap backward cuda\n"); | megdnn_throw("unsupported dtype in remap backward cuda\n"); | ||||
} | } | ||||
@@ -155,6 +155,7 @@ void backwarddata_proxy( | |||||
FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | ||||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | |||||
#undef FOR_FORMAT_BMODE | #undef FOR_FORMAT_BMODE | ||||
#undef INST | #undef INST | ||||
@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec( | |||||
switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
support_dtype(dtype::Float32); | support_dtype(dtype::Float32); | ||||
support_dtype(dtype::BFloat16); | support_dtype(dtype::BFloat16); | ||||
support_dtype(dtype::Float16); | |||||
default: | default: | ||||
megdnn_throw("unsupported dtype in remap backward cuda\n"); | megdnn_throw("unsupported dtype in remap backward cuda\n"); | ||||
} | } | ||||
@@ -156,6 +156,7 @@ void backwardmat_proxy( | |||||
FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | ||||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | |||||
#undef FOR_FORMAT_BMODE | #undef FOR_FORMAT_BMODE | ||||
#undef INST | #undef INST | ||||
@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec( | |||||
support_dtype(dtype::Float32); | support_dtype(dtype::Float32); | ||||
DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | ||||
DNN_INC_FLOAT16(support_dtype(dtype::Float16)); | |||||
#undef cb | #undef cb | ||||
#undef support_dtype | #undef support_dtype | ||||
@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec( | |||||
support_dtype(dtype::Float32); | support_dtype(dtype::Float32); | ||||
DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | ||||
DNN_INC_FLOAT16(support_dtype(dtype::Float16)); | |||||
#undef cb | #undef cb | ||||
#undef support_dtype | #undef support_dtype | ||||
@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) { | |||||
.execs({arg.map_xy, arg.dst, arg.src}); \ | .execs({arg.map_xy, arg.dst, arg.src}); \ | ||||
} | } | ||||
cb(dtype::BFloat16(), float_rng); | cb(dtype::BFloat16(), float_rng); | ||||
cb(dtype::Float16(), float_rng); | |||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) { | |||||
.execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | ||||
} | } | ||||
cb(dtype::BFloat16(), float_rng); | cb(dtype::BFloat16(), float_rng); | ||||
cb(dtype::Float16(), float_rng); | |||||
#undef cb | #undef cb | ||||
} | } | ||||