|
@@ -379,6 +379,8 @@ void TopK::init_output_static_infer_desc() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) { |
|
|
auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) { |
|
|
|
|
|
// active comp_node for cuda launch kernel in get_workspace_in_bytes |
|
|
|
|
|
comp_node().activate(); |
|
|
auto k = iv.val[3].value().ptr<int>()[0]; |
|
|
auto k = iv.val[3].value().ptr<int>()[0]; |
|
|
auto size = megdnn_opr()->get_workspace_in_bytes( |
|
|
auto size = megdnn_opr()->get_workspace_in_bytes( |
|
|
k, {iv.val[0].shape(), input(0)->dtype()}, |
|
|
k, {iv.val[0].shape(), input(0)->dtype()}, |
|
|