Browse Source

test(imperative/python): fix testcase for magicmind runtime module

GitOrigin-RevId: baf2f72f01
tags/v1.7.2.m1
Megvii Engine Team XindaH 3 years ago
parent
commit
0c3699ff55
4 changed files with 9 additions and 5 deletions
  1. +1
    -0
      .gitattributes
  2. +2
    -3
      imperative/python/megengine/module/external.py
  3. +2
    -2
      src/cambricon/impl/magicmind_runtime_opr.cpp
  4. +4
    -0
      src/core/impl/comp_node_env.cpp

+ 1
- 0
.gitattributes View File

@@ -20,3 +20,4 @@ ci/resource/dump/batch_conv_bias_with_policy_8.8.0.mdl filter=lfs diff=lfs merge
ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text
ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text
lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text
imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text

+ 2
- 3
imperative/python/megengine/module/external.py View File

@@ -13,6 +13,7 @@ from ..functional.external import (
atlas_runtime_opr,
cambricon_runtime_opr,
extern_opr_subgraph,
magicmind_runtime_opr,
tensorrt_runtime_opr,
)
from .module import Module
@@ -131,6 +132,7 @@ class AtlasRuntimeSubgraph(Module):
def forward(self, *inputs):
return atlas_runtime_opr(inputs, data=self._data)


class MagicMindRuntimeSubgraph(Module):
r"""Load a serialized MagicMindRuntime subgraph.
@@ -151,6 +153,3 @@ class MagicMindRuntimeSubgraph(Module):

def forward(self, *inputs):
return magicmind_runtime_opr(inputs, data=self._data)




+ 2
- 2
src/cambricon/impl/magicmind_runtime_opr.cpp View File

@@ -267,7 +267,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
mgb_assert(
tensor != nullptr, "failed to find input tensor(name:%s)",
iname.c_str());
MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(input(i)->shape())));
MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(inp_shape[i])));
}
if (Status::OK() == m_context->InferOutputShape(inputs, outputs)) {
size_t nr_outputs = output().size();
@@ -283,7 +283,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
}
std::vector<Dims> shape(inp_shape.size());
for (size_t i = 0; i < nr_inputs; ++i) {
shape[i] = mgb_shape_to_mm_dims(input(i)->shape());
shape[i] = mgb_shape_to_mm_dims(inp_shape[i]);
}
size_t wk_size = 0;
MM_CHECK(m_engine->QueryContextMaxWorkspaceSize(shape, &wk_size));


+ 4
- 0
src/core/impl/comp_node_env.cpp View File

@@ -390,7 +390,11 @@ void CompNodeEnv::init_cnrt(
MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev));
// FIXME: doc doesn't describe the aligment requirement for device memory
// address
#if CNRT_MAJOR_VERSION >= 5
m_property.mem_alignment = 256u;
#else
m_property.mem_alignment = 1u;
#endif
// ensure exception safe
bool queue_created = false;
MGB_MARK_USED_VAR(queue_created);


Loading…
Cancel
Save