
test(imperative/python): fix testcase for magicmind runtime module

GitOrigin-RevId: baf2f72f01
tags/v1.7.2.m1
Megvii Engine Team XindaH, 3 years ago
commit 0c3699ff55
4 changed files with 9 additions and 5 deletions
  1. .gitattributes (+1, -0)
  2. imperative/python/megengine/module/external.py (+2, -3)
  3. src/cambricon/impl/magicmind_runtime_opr.cpp (+2, -2)
  4. src/core/impl/comp_node_env.cpp (+4, -0)

.gitattributes (+1, -0)

@@ -20,3 +20,4 @@ ci/resource/dump/batch_conv_bias_with_policy_8.8.0.mdl filter=lfs diff=lfs merge
 ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text
 ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text
 lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text
+imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text

imperative/python/megengine/module/external.py (+2, -3)

@@ -13,6 +13,7 @@ from ..functional.external import (
     atlas_runtime_opr,
     cambricon_runtime_opr,
     extern_opr_subgraph,
+    magicmind_runtime_opr,
     tensorrt_runtime_opr,
 )
 from .module import Module
@@ -131,6 +132,7 @@ class AtlasRuntimeSubgraph(Module):
     def forward(self, *inputs):
         return atlas_runtime_opr(inputs, data=self._data)
 
 
+
 class MagicMindRuntimeSubgraph(Module):
     r"""Load a serialized MagicMindRuntime subgraph.
@@ -151,6 +153,3 @@ class MagicMindRuntimeSubgraph(Module):
 
     def forward(self, *inputs):
         return magicmind_runtime_opr(inputs, data=self._data)
-
-
-
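With magicmind_runtime_opr now imported, MagicMindRuntimeSubgraph.forward can resolve the operator it calls. A minimal usage sketch, not part of this commit: it assumes the constructor takes the serialized model blob like the neighbouring Atlas/TensorRT subgraph modules, and the input shape is hypothetical.

    import numpy as np
    import megengine as mge
    from megengine.module.external import MagicMindRuntimeSubgraph

    # the .mlu file registered in .gitattributes above; reading it this way is
    # an assumption for illustration, not the actual test code
    with open("MagicMindRuntimeOprTest.GraphShapeMutable.mlu", "rb") as f:
        m = MagicMindRuntimeSubgraph(f.read())

    x = mge.tensor(np.random.randn(1, 3, 16, 16).astype("float32"))  # hypothetical shape
    y = m(x)  # forward() dispatches to magicmind_runtime_opr(inputs, data=self._data)

Running such a sketch requires a MegEngine build with Cambricon/MagicMind support and a Cambricon device visible to the runtime.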
src/cambricon/impl/magicmind_runtime_opr.cpp (+2, -2)

@@ -267,7 +267,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
         mgb_assert(
                 tensor != nullptr, "failed to find input tensor(name:%s)",
                 iname.c_str());
-        MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(input(i)->shape())));
+        MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(inp_shape[i])));
     }
     if (Status::OK() == m_context->InferOutputShape(inputs, outputs)) {
         size_t nr_outputs = output().size();
@@ -283,7 +283,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
     }
     std::vector<Dims> shape(inp_shape.size());
     for (size_t i = 0; i < nr_inputs; ++i) {
-        shape[i] = mgb_shape_to_mm_dims(input(i)->shape());
+        shape[i] = mgb_shape_to_mm_dims(inp_shape[i]);
     }
     size_t wk_size = 0;
     MM_CHECK(m_engine->QueryContextMaxWorkspaceSize(shape, &wk_size));


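The new test resource is named GraphShapeMutable, and these two lines appear to target that case: during static shape inference, get_output_var_shape derives MagicMind dimensions from the shapes being queried (inp_shape[i]) rather than from input(i)->shape(), which may still hold an older or empty shape when input shapes change between runs. A hedged Python-level sketch of the scenario (file name taken from .gitattributes above, batch sizes hypothetical):

    import numpy as np
    from megengine import tensor
    from megengine.functional.external import magicmind_runtime_opr

    with open("MagicMindRuntimeOprTest.GraphShapeMutable.mlu", "rb") as f:
        data = f.read()

    # each call presents a different input shape, so output shapes have to be
    # re-inferred from the shapes passed in, not from a previously bound shape
    for batch in (1, 4):
        x = tensor(np.random.randn(batch, 3, 16, 16).astype("float32"))
        outputs = magicmind_runtime_opr([x], data=data)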
src/core/impl/comp_node_env.cpp (+4, -0)

@@ -390,7 +390,11 @@ void CompNodeEnv::init_cnrt(
     MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev));
     // FIXME: doc doesn't describe the aligment requirement for device memory
     // address
+#if CNRT_MAJOR_VERSION >= 5
+    m_property.mem_alignment = 256u;
+#else
     m_property.mem_alignment = 1u;
+#endif
     // ensure exception safe
     bool queue_created = false;
     MGB_MARK_USED_VAR(queue_created);
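The CNRT_MAJOR_VERSION >= 5 branch raises the reported device memory alignment from 1 to 256 bytes; allocators round addresses and offsets up to a multiple of this value. A small illustrative sketch of that align-up arithmetic (not MegEngine code):

    def align_up(offset: int, alignment: int) -> int:
        # alignment is assumed to be a power of two, e.g. 256
        return (offset + alignment - 1) & ~(alignment - 1)

    assert align_up(1000, 256) == 1024  # rounded up to the next 256-byte boundary
    assert align_up(1024, 256) == 1024  # already aligned, unchanged
    assert align_up(1000, 1) == 1000    # alignment 1 never changes an offset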

