diff --git a/imperative/python/megengine/tools/load_network_and_run.py b/imperative/python/megengine/tools/load_network_and_run.py index 2b56a2fb..ba9cd76f 100755 --- a/imperative/python/megengine/tools/load_network_and_run.py +++ b/imperative/python/megengine/tools/load_network_and_run.py @@ -156,6 +156,9 @@ def run_model(args, graph, inputs, outputs, data): func = graph.compile(outputs) + if args.get_static_mem_info: + func.get_static_memory_alloc_info(args.get_static_mem_info) + def run(): if not args.embed_input: for key in inp_dict: @@ -389,6 +392,11 @@ def main(): help="embed input data as SharedDeviceTensor in model, " "to remove memory copy for inputs", ) + parser.add_argument( + "--get-static-mem-info", + type=str, + help="Record the static graph's static memory info.", + ) args = parser.parse_args() if args.verbose: diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 15768f7e..8bb8f5ca 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -215,7 +215,10 @@ void init_graph_rt(py::module m) { } } return ret; - }); + }) + .def("get_static_memory_alloc_info", + &cg::AsyncExecutable::get_static_memory_alloc_info, + py::call_guard()); auto PyComputingGraph = py::class_>(m, "ComputingGraph") .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))