
feat(dnn): add float16 for remap backward

GitOrigin-RevId: 0263030051
Tag: tags/v1.7.0.m1
Author: Megvii Engine Team
Parent commit: 2696e4efaa
7 changed files with 14 additions and 4 deletions
  1. dnn/src/common/remap.cpp              +6 -4
  2. dnn/src/cuda/remap/backward_data.cpp  +1 -0
  3. dnn/src/cuda/remap/backward_data.cu   +1 -0
  4. dnn/src/cuda/remap/backward_mat.cpp   +1 -0
  5. dnn/src/cuda/remap/backward_mat.cu    +1 -0
  6. dnn/src/naive/remap/opr_impl.cpp      +2 -0
  7. dnn/test/cuda/remap.cpp               +2 -0

+6 -4  dnn/src/common/remap.cpp

@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec(
         size_t workspace_in_bytes) {
     check_layout_fwd(grad, map_xy, diff);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
             "Backward Remap only supports Float32/BFloat16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec(
     check_layout_fwd(src, map_xy, diff);
     megdnn_assert_eq_layout(map_xy, grad);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
             "Backward Remap only supports Float32/BFloat16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
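
For readers unfamiliar with the guard: DNN_INC_FLOAT16 is megdnn's conditional-compilation macro for half-precision support, so the extra clause only exists in builds that enable float16. A sketch of how the macro plausibly works and what the new condition reads as when enabled (the exact guard macro name is an assumption based on megdnn's naming conventions):

// Plausible definition: expands to its argument when float16 support is
// compiled in, and to nothing otherwise (guard name is an assumption).
#if !MEGDNN_DISABLE_FLOAT16
#define DNN_INC_FLOAT16(_x) _x
#else
#define DNN_INC_FLOAT16(_x)
#endif

// With float16 enabled, the new assertion condition expands to:
//   grad.dtype == dtype::Float32()
//           || grad.dtype == dtype::BFloat16()
//           || grad.dtype == dtype::Float16()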


+1 -0  dnn/src/cuda/remap/backward_data.cpp

@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec(
     switch (grad.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }
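
On the C++ side the change is a single new support_dtype line in the dtype switch (backward_mat.cpp below gets the identical line). A standalone sketch of what such a case-generating macro typically does; only the support_dtype name and the throw message come from the diff, everything else is an illustrative stand-in:

// Standalone sketch of the switch-on-dtype dispatch idiom.
#include <cstdio>
#include <stdexcept>

enum class DTypeEnum { Float32, BFloat16, Float16 };

struct dt_bfloat16 { unsigned short bits; };
struct dt_float16 { unsigned short bits; };

template <typename ctype>
void launch_backward_kernel() {
    std::printf("launch kernel for %zu-byte elements\n", sizeof(ctype));
}

void exec(DTypeEnum dt) {
#define support_dtype(enum_v, ctype)     \
    case DTypeEnum::enum_v:              \
        launch_backward_kernel<ctype>(); \
        break
    switch (dt) {
        support_dtype(Float32, float);
        support_dtype(BFloat16, dt_bfloat16);
        support_dtype(Float16, dt_float16);  // new case from this commit
        default:
            throw std::runtime_error("unsupported dtype in remap backward cuda\n");
    }
#undef support_dtype
}

int main() {
    exec(DTypeEnum::Float16);  // now dispatches instead of throwing
}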


+1 -0  dnn/src/cuda/remap/backward_data.cu

@@ -155,6 +155,7 @@ void backwarddata_proxy(

 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))

 #undef FOR_FORMAT_BMODE
 #undef INST
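
The .cu change is one more explicit template instantiation: FOR_FORMAT_BMODE(ctype) presumably expands to one INST(...) line per (format, border mode) combination, so the device kernels actually get compiled for dt_float16 (backward_mat.cu below mirrors this). A trimmed, standalone sketch of that instantiation idiom; the enumerators and the proxy signature are simplified stand-ins, not megdnn's real ones:

// Standalone sketch of macro-driven explicit instantiation, as in the .cu files.
enum Format { NCHW };
enum BorderMode { BORDER_CONSTANT, BORDER_REPLICATE };

struct dt_float16 { unsigned short bits; };  // storage-only half stand-in

template <typename ctype, Format fmt, BorderMode bmode>
void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff,
                        int n, int c, int h, int w) {
    // kernel launch would go here
}

#define INST(ctype, fmt, bmode)                          \
    template void backwarddata_proxy<ctype, fmt, bmode>( \
            ctype*, const float*, const ctype*, int, int, int, int);

#define FOR_FORMAT_BMODE(ctype)        \
    INST(ctype, NCHW, BORDER_CONSTANT) \
    INST(ctype, NCHW, BORDER_REPLICATE)

FOR_FORMAT_BMODE(float)
FOR_FORMAT_BMODE(dt_float16)  // the instantiation this commit adds

#undef FOR_FORMAT_BMODE
#undef INST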


+1 -0  dnn/src/cuda/remap/backward_mat.cpp

@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec(
     switch (src.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }


+1 -0  dnn/src/cuda/remap/backward_mat.cu

@@ -156,6 +156,7 @@ void backwardmat_proxy(

 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))

 #undef FOR_FORMAT_BMODE
 #undef INST


+2 -0  dnn/src/naive/remap/opr_impl.cpp

@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec(

     support_dtype(dtype::Float32);
     DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+    DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 #undef cb
 #undef support_dtype

@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec(

     support_dtype(dtype::Float32);
     DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+    DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 #undef cb
 #undef support_dtype
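
The naive (reference) backend uses a two-level variant of the same idiom: a local cb() macro holds the typed body, support_dtype() wraps it into a case, and both are #undef'ed right after the switch so the next function can redefine them; in builds without float16 the guarded line compiles away entirely. A standalone sketch of that define/undef hygiene; names other than cb and support_dtype are illustrative stand-ins:

// Standalone sketch of the local cb()/support_dtype() idiom.
#include <cstdio>

enum class DTypeEnum { Float32, BFloat16, Float16 };

struct dt_bfloat16 { unsigned short bits; };
struct dt_float16 { unsigned short bits; };

template <typename ctype>
void naive_remap_backward() {
    std::printf("reference kernel, %zu-byte elements\n", sizeof(ctype));
}

bool exec(DTypeEnum dt) {
// cb() carries the typed body; support_dtype() turns it into a case label.
// Both are scoped to this function by the #undef pair below.
#define cb(ctype) naive_remap_backward<ctype>()
#define support_dtype(enum_v, ctype) \
    case DTypeEnum::enum_v:          \
        cb(ctype);                   \
        return true
    switch (dt) {
        support_dtype(Float32, float);
        support_dtype(BFloat16, dt_bfloat16);
        support_dtype(Float16, dt_float16);  // the case this commit adds
    }
#undef cb
#undef support_dtype
    return false;  // unsupported dtype
}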



+2 -0  dnn/test/cuda/remap.cpp

@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) {
                 .execs({arg.map_xy, arg.dst, arg.src}); \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }

@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) {
                 .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }
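
The test-side change just invokes the existing per-dtype macro once more in each test. Based on the visible .execs call, cb presumably configures a Checker for the given dtype and random-number generator; the following is a hedged reconstruction, not the diff's actual macro body, and the dtype slots, epsilon value, and arg fields are assumptions from megdnn's usual test style:

// Hedged reconstruction of what cb(dt, rng) might expand to for
// REMAP_BACKWARD_DATA. Only the .execs argument list and the macro name
// appear in the diff; the rest is assumed.
#define cb(dt, rng)                                         \
    for (auto&& arg : args) {                               \
        Checker<RemapBackwardData> checker(handle_cuda());  \
        checker.set_dtype(0, dtype::Float32()) /* map_xy */ \
                .set_dtype(1, dt) /* diff */                \
                .set_dtype(2, dt) /* grad */                \
                .set_rng(1, &rng)                           \
                .set_epsilon(1e-2) /* looser for 16-bit */  \
                .set_param(arg.param)                       \
                .execs({arg.map_xy, arg.dst, arg.src});     \
    }

Reusing float_rng for Float16 presumably keeps test inputs in a range half precision can represent, so checking against the float32 reference needs only a modestly relaxed tolerance.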


