|
|
@@ -81,18 +81,10 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, |
|
|
|
values.ptr<int32_t>(), indices, |
|
|
|
workspace.raw_ptr); |
|
|
|
return; |
|
|
|
// #if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
// case DTypeEnum::Float16: |
|
|
|
// dispatch_with_ctype<dt_float16>(k, data.layout[0], data.layout[1], |
|
|
|
// data.layout.stride[0], data.ptr<dt_float16>(), |
|
|
|
// values.ptr<dt_float16>(), indices, |
|
|
|
// workspace.raw_ptr); |
|
|
|
// return; |
|
|
|
// #endif |
|
|
|
default: |
|
|
|
megdnn_throw( |
ssprintf("only float32, int32 are supported for " |
"rocm topk, got: %s", |
data.layout.dtype.name())); |
|
|
|
} |
|
|
|
} |
|
|
|