|
|
@@ -36,13 +36,14 @@ template <int out_dim, typename ctype> |
|
|
|
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, |
|
|
|
int block_size) { |
|
|
|
auto&& args = fusion_opr->args(); |
|
|
|
std::vector<StridedMemRefType<ctype, out_dim>> param_holders; |
|
|
|
size_t num_memrefs = args.inputs.size() + args.outputs.size(); |
|
|
|
std::vector<StridedMemRefType<ctype, out_dim>> param_holders(num_memrefs); |
|
|
|
std::vector<void*> params; |
|
|
|
|
|
|
|
auto set_params = [¶m_holders, ¶ms]( |
|
|
|
void* ptr, const megdnn::TensorLayout& layout) { |
|
|
|
param_holders.push_back(StridedMemRefType<ctype, out_dim>{}); |
|
|
|
StridedMemRefType<ctype, out_dim>& desc = param_holders.back(); |
|
|
|
size_t idx, void* ptr, |
|
|
|
const megdnn::TensorLayout& layout) { |
|
|
|
auto& desc = param_holders[idx]; |
|
|
|
desc.basePtr = static_cast<ctype*>(ptr); |
|
|
|
params.push_back(&(desc.basePtr)); |
|
|
|
desc.data = static_cast<ctype*>(ptr); |
|
|
@@ -56,9 +57,12 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, |
|
|
|
params.push_back(&(desc.strides[i])); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
size_t idx = 0; |
|
|
|
for (const auto& arg : args.inputs) { |
|
|
|
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout()); |
|
|
|
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout()); |
|
|
|
} |
|
|
|
|
|
|
|
int64_t nr_elements = 0; |
|
|
|
for (const auto& arg : args.outputs) { |
|
|
|
if (nr_elements == 0) { |
|
|
@@ -73,8 +77,13 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, |
|
|
|
arg.from->layout().to_string().c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout()); |
|
|
|
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout()); |
|
|
|
} |
|
|
|
|
|
|
|
mgb_assert(param_holders.size() == num_memrefs, |
|
|
|
"calling push_back method of param_holders is unsafe as it " |
|
|
|
"might cause reallocation of std::vector"); |
|
|
|
|
|
|
|
const CompNodeEnv& env = |
|
|
|
CompNodeEnv::from_comp_node(fusion_opr->comp_node()); |
|
|
|
|
|
|
|