Browse Source

fix(mgb/jit): fix a pointer bug in mlir executable_cuda

GitOrigin-RevId: 3ec79b7602
release-1.2
Megvii Engine Team 4 years ago
parent
commit
f7731bd437
1 changed files with 15 additions and 6 deletions
  1. +15
    -6
      src/jit/impl/mlir/executable_cuda.cpp

+ 15
- 6
src/jit/impl/mlir/executable_cuda.cpp View File

@@ -36,13 +36,14 @@ template <int out_dim, typename ctype>
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
int block_size) {
auto&& args = fusion_opr->args();
std::vector<StridedMemRefType<ctype, out_dim>> param_holders;
size_t num_memrefs = args.inputs.size() + args.outputs.size();
std::vector<StridedMemRefType<ctype, out_dim>> param_holders(num_memrefs);
std::vector<void*> params;

auto set_params = [&param_holders, &params](
void* ptr, const megdnn::TensorLayout& layout) {
param_holders.push_back(StridedMemRefType<ctype, out_dim>{});
StridedMemRefType<ctype, out_dim>& desc = param_holders.back();
size_t idx, void* ptr,
const megdnn::TensorLayout& layout) {
auto& desc = param_holders[idx];
desc.basePtr = static_cast<ctype*>(ptr);
params.push_back(&(desc.basePtr));
desc.data = static_cast<ctype*>(ptr);
@@ -56,9 +57,12 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
params.push_back(&(desc.strides[i]));
}
};

size_t idx = 0;
for (const auto& arg : args.inputs) {
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}

int64_t nr_elements = 0;
for (const auto& arg : args.outputs) {
if (nr_elements == 0) {
@@ -73,8 +77,13 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
arg.from->layout().to_string().c_str());
}

set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}

mgb_assert(param_holders.size() == num_memrefs,
"calling push_back method of param_holders is unsafe as it "
"might cause reallocation of std::vector");

const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());



Loading…
Cancel
Save