|
|
@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec( |
|
|
|
size_t workspace_in_bytes) { |
|
|
|
check_layout_fwd(grad, map_xy, diff); |
|
|
|
megdnn_assert( |
|
|
|
grad.dtype == dtype::Float32() |
|
|
|
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), |
|
|
|
grad.dtype == |
|
|
|
dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()) |
|
|
|
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()), |
|
|
|
"Backward Remap only supports Float32/BFloat16."); |
|
|
|
auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad); |
|
|
|
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); |
|
|
@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec( |
|
|
|
check_layout_fwd(src, map_xy, diff); |
|
|
|
megdnn_assert_eq_layout(map_xy, grad); |
|
|
|
megdnn_assert( |
|
|
|
grad.dtype == dtype::Float32() |
|
|
|
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), |
|
|
|
grad.dtype == |
|
|
|
dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()) |
|
|
|
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()), |
|
|
|
"Backward Remap only supports Float32/BFloat16."); |
|
|
|
auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad); |
|
|
|
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); |
|
|
|