Browse Source

feat(mgb): add megengine inference

GitOrigin-RevId: 6ffec6b418
release-1.5
Megvii Engine Team 4 years ago
parent
commit
8a918717c0
2 changed files with 12 additions and 1 deletions
  1. +8
    -0
      imperative/python/megengine/tools/load_network_and_run.py
  2. +4
    -1
      imperative/python/src/graph_rt.cpp

+ 8
- 0
imperative/python/megengine/tools/load_network_and_run.py View File

@@ -156,6 +156,9 @@ def run_model(args, graph, inputs, outputs, data):

func = graph.compile(outputs)

if args.get_static_mem_info:
func.get_static_memory_alloc_info(args.get_static_mem_info)

def run():
if not args.embed_input:
for key in inp_dict:
@@ -389,6 +392,11 @@ def main():
help="embed input data as SharedDeviceTensor in model, "
"to remove memory copy for inputs",
)
parser.add_argument(
"--get-static-mem-info",
type=str,
help="Record the static graph's static memory info.",
)
args = parser.parse_args()

if args.verbose:


+ 4
- 1
imperative/python/src/graph_rt.cpp View File

@@ -215,7 +215,10 @@ void init_graph_rt(py::module m) {
}
}
return ret;
});
})
.def("get_static_memory_alloc_info",
&cg::AsyncExecutable::get_static_memory_alloc_info,
py::call_guard<py::gil_scoped_release>());

auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
.def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))


Loading…
Cancel
Save