|
@@ -40,11 +40,17 @@ class NMSKeep::CUDAKern final : public Kern { |
|
|
void init(const NMSKeep* opr, const TensorShape& boxes) { |
|
|
void init(const NMSKeep* opr, const TensorShape& boxes) { |
|
|
auto align = opr->comp_node().get_mem_addr_alignment(); |
|
|
auto align = opr->comp_node().get_mem_addr_alignment(); |
|
|
size_t nr_boxes = boxes[1]; |
|
|
size_t nr_boxes = boxes[1]; |
|
|
m_workspace_overlap_mask_bytes = |
|
|
|
|
|
nr_boxes * DIVUP(nr_boxes, 64) * sizeof(uint64_t); |
|
|
|
|
|
m_workspace_overlap_mask_bytes_align = |
|
|
|
|
|
get_aligned_power2(m_workspace_overlap_mask_bytes, align); |
|
|
|
|
|
m_workspace_rm_mask_bytes = DIVUP(nr_boxes, 64) * sizeof(uint64_t); |
|
|
|
|
|
|
|
|
if (nr_boxes == 0) { |
|
|
|
|
|
m_workspace_overlap_mask_bytes = 0; |
|
|
|
|
|
m_workspace_overlap_mask_bytes_align = 0; |
|
|
|
|
|
m_workspace_rm_mask_bytes = 0; |
|
|
|
|
|
} else { |
|
|
|
|
|
m_workspace_overlap_mask_bytes = |
|
|
|
|
|
nr_boxes * DIVUP(nr_boxes, 64) * sizeof(uint64_t); |
|
|
|
|
|
m_workspace_overlap_mask_bytes_align = |
|
|
|
|
|
get_aligned_power2(m_workspace_overlap_mask_bytes, align); |
|
|
|
|
|
m_workspace_rm_mask_bytes = DIVUP(nr_boxes, 64) * sizeof(uint64_t); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
@@ -88,7 +94,10 @@ void NMSKeep::CUDAKern::exec(const NMSKeep* opr, const DeviceTensorND& inp, |
|
|
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()), |
|
|
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()), |
|
|
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>()); |
|
|
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>()); |
|
|
size_t batch = inp.shape(0), nr_boxes = inp.shape(1); |
|
|
size_t batch = inp.shape(0), nr_boxes = inp.shape(1); |
|
|
|
|
|
|
|
|
|
|
|
if (nr_boxes == 0) { |
|
|
|
|
|
MGB_CUDA_CHECK(cudaMemsetAsync(out_size_ptr, 0, batch*sizeof(uint32_t), stream)); |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
MGB_CUDA_CHECK(cudaMemsetAsync(dev_overlap_mask, 0, |
|
|
MGB_CUDA_CHECK(cudaMemsetAsync(dev_overlap_mask, 0, |
|
|
m_workspace_overlap_mask_bytes, stream)); |
|
|
m_workspace_overlap_mask_bytes, stream)); |
|
|
|
|
|
|
|
@@ -136,6 +145,12 @@ void NMSKeep::CPUKern::exec(const NMSKeep* opr, const DeviceTensorND& inp, |
|
|
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()), |
|
|
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()), |
|
|
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>()); |
|
|
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>()); |
|
|
size_t batch = inp.shape(0), nr_boxes = inp.shape(1); |
|
|
size_t batch = inp.shape(0), nr_boxes = inp.shape(1); |
|
|
|
|
|
if (nr_boxes == 0) { |
|
|
|
|
|
for (size_t i = 0; i < batch; ++i) { |
|
|
|
|
|
*(out_size_ptr + i) = 0; |
|
|
|
|
|
} |
|
|
|
|
|
return; |
|
|
|
|
|
} |
|
|
auto param = opr->param(); |
|
|
auto param = opr->param(); |
|
|
|
|
|
|
|
|
auto workspace_ptr = workspace.raw_ptr(); |
|
|
auto workspace_ptr = workspace.raw_ptr(); |
|
@@ -183,7 +198,8 @@ NMSKeep::NMSKeep(VarNode* boxes, const Param& param, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
add_input({boxes}); |
|
|
add_input({boxes}); |
|
|
add_output("indices")->dtype(dtype::Int32()); |
|
|
|
|
|
|
|
|
add_output("indices")->dtype(dtype::Int32()) |
|
|
|
|
|
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); |
|
|
add_output("sizes")->dtype(dtype::Int32()); |
|
|
add_output("sizes")->dtype(dtype::Int32()); |
|
|
cg::add_workspace_output(this); // workspace is also an output var |
|
|
cg::add_workspace_output(this); // workspace is also an output var |
|
|
|
|
|
|
|
@@ -233,6 +249,13 @@ void NMSKeep::scn_do_execute() { |
|
|
: empty_workspace); |
|
|
: empty_workspace); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
//! Build the node property via the base class, then register input(0) with
//! dep type VALUE_ALLOW_EMPTY before handing the property back.
NMSKeep::NodeProp* NMSKeep::do_make_node_prop() const {
    NMSKeep::NodeProp* prop = Super::do_make_node_prop();
    prop->add_dep_type_existing_var(
            input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return prop;
}
|
|
|
|
|
|
|
|
#if MGB_ENABLE_FBS_SERIALIZATION |
|
|
#if MGB_ENABLE_FBS_SERIALIZATION |
|
|
|
|
|
|
|
|
namespace mgb { |
|
|
namespace mgb { |
|
|