Browse Source

feat(mge/imperative): add sublinear options

GitOrigin-RevId: f0e917f716
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
0e82b959a1
10 changed files with 116 additions and 21 deletions
  1. +1
    -0
      imperative/python/megengine/jit/__init__.py
  2. +56
    -0
      imperative/python/megengine/jit/sublinear_memory_config.py
  3. +27
    -1
      imperative/python/megengine/jit/tracing.py
  4. +11
    -0
      imperative/python/src/graph_rt.cpp
  5. +8
    -7
      imperative/python/test/integration/test_correctness.py
  6. +5
    -5
      python_module/megengine/jit/__init__.py
  7. +5
    -5
      python_module/src/cpp/megbrain_config.cpp
  8. +1
    -1
      src/core/impl/graph/cg_impl.cpp
  9. +1
    -1
      src/core/include/megbrain/graph/cg.h
  10. +1
    -1
      src/core/test/sublinear_memory.cpp

+ 1
- 0
imperative/python/megengine/jit/__init__.py View File

@@ -1 +1,2 @@
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import exclude_from_trace, trace from .tracing import exclude_from_trace, trace

+ 56
- 0
imperative/python/megengine/jit/sublinear_memory_config.py View File

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..device import get_device_count


class SublinearMemoryConfig:
r"""
Configuration for sublinear memory optimization.

:param thresh_nr_try: number of samples both for searching in linear space
and around current thresh in sublinear memory optimization. Default: 10.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
:param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm.
Default: 0.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
:param genetic_pool_size: number of samples for the crossover random selection
during genetic optimization. Default: 20.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
:param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization.
It can be used to perform manual tradeoff between memory and speed. Default: 0.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
:param num_worker: number of thread workers to search the optimum checkpoints
in sublinear memory optimization. Default: half of cpu number in the system.
Note: the value must be greater or equal to one.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.

Note that the environmental variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
in order for the above environmental variable to be effective.
"""

def __init__(
self,
thresh_nr_try: int = 10,
genetic_nr_iter: int = 0,
genetic_pool_size: int = 20,
lb_memory: int = 0,
num_worker: int = max(1, get_device_count("cpu") // 2),
):
assert thresh_nr_try >= 0, "thresh_nr_try must be greater or equal to zero"
self.thresh_nr_try = thresh_nr_try
assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater or equal to zero"
self.genetic_nr_iter = genetic_nr_iter
assert (
genetic_pool_size >= 0
), "genetic_pool_size must be greater or equal to zero"
self.genetic_pool_size = genetic_pool_size
self.lb_memory = lb_memory
assert num_worker > 0, "num_worker must be greater or equal to one"
self.num_worker = num_worker

+ 27
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -7,6 +7,7 @@ from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from .sublinear_memory_config import SublinearMemoryConfig




class TraceMismatchError(RuntimeError): class TraceMismatchError(RuntimeError):
@@ -72,11 +73,18 @@ class trace:
self.__init__(*args, **kwargs) self.__init__(*args, **kwargs)
return self return self


def __init__(self, function, symbolic=False, capture_as_const=False):
def __init__(
self,
function,
symbolic=False,
capture_as_const=False,
sublinear_memory_config: SublinearMemoryConfig = None,
):
self.__wrapped__ = function self.__wrapped__ = function
self._symbolic = symbolic self._symbolic = symbolic
self._capture_as_const = capture_as_const self._capture_as_const = capture_as_const
self._capture_static_shape = False self._capture_static_shape = False
self._sublinear_memory_config = sublinear_memory_config


self._untraced = True self._untraced = True
self._tinfo = [] # handle -> TensorInfo self._tinfo = [] # handle -> TensorInfo
@@ -227,6 +235,7 @@ class trace:
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
for x in lazy_eval_tensors for x in lazy_eval_tensors
] ]
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers) self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph() self._lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors): for r, x in zip(readers, lazy_eval_tensors):
@@ -259,9 +268,26 @@ class trace:
info.exported = True info.exported = True
info.data_read = True info.data_read = True


def _apply_graph_options(self, graph):

# sublinear
if self._sublinear_memory_config is not None:
graph.options.enable_sublinear_memory_opt = True
sublinear_config = graph.options.sublinear_mem_config
sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
sublinear_config.genetic_nr_iter = (
self._sublinear_memory_config.genetic_nr_iter
)
sublinear_config.genetic_pool_size = (
self._sublinear_memory_config.genetic_pool_size
)
sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
sublinear_config.num_worker = self._sublinear_memory_config.num_worker

def _compile(self): def _compile(self):
graph = self._graph = G.Graph() graph = self._graph = G.Graph()
graph.options.no_force_inplace = True graph.options.no_force_inplace = True
self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0 # graph.options.graph_opt_level = 0
need_reset_nodes = self._need_reset_nodes = [] need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes # links enforce ordering of I/O nodes


+ 11
- 0
imperative/python/src/graph_rt.cpp View File

@@ -119,6 +119,7 @@ void init_graph_rt(py::module m) {
DEF_READWRITE(enable_memory_swap) DEF_READWRITE(enable_memory_swap)
DEF_READWRITE(comp_node_seq_record_level) DEF_READWRITE(comp_node_seq_record_level)
DEF_READWRITE(no_force_inplace) DEF_READWRITE(no_force_inplace)
DEF_READWRITE(sublinear_mem_config)
// DEF_READWRITE(eager_evaluation) // DEF_READWRITE(eager_evaluation)
// DEF_READWRITE(imperative_proxy_graph) // DEF_READWRITE(imperative_proxy_graph)
// DEF_READWRITE(extra_vardeps) // DEF_READWRITE(extra_vardeps)
@@ -142,6 +143,16 @@ void init_graph_rt(py::module m) {


#undef CURRENT_CLASS #undef CURRENT_CLASS


#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig

py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
DEF_READWRITE(thresh_nr_try)
DEF_READWRITE(genetic_nr_iter)
DEF_READWRITE(genetic_pool_size)
DEF_READWRITE(lb_memory)
DEF_READWRITE(num_worker);

#undef CURRENT_CLASS
auto common = rel_import("common", m, 1); auto common = rel_import("common", m, 1);


common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {


+ 8
- 7
imperative/python/test/integration/test_correctness.py View File

@@ -19,6 +19,7 @@ import megengine.functional as F
from megengine import jit from megengine import jit
from megengine.core._trace_option import set_tensor_shape from megengine.core._trace_option import set_tensor_shape
from megengine.functional.debug_param import set_conv_execution_strategy from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemoryConfig
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.tensor import Tensor from megengine.tensor import Tensor
@@ -217,14 +218,14 @@ def test_correctness():
set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE") set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")


run_train(model_path, False, False, max_err=1e-5) run_train(model_path, False, False, max_err=1e-5)
# run_test(model_path, True, False)
# run_test(model_path, True, True)
run_train(model_path, True, False, max_err=1e-5)
run_train(model_path, True, True, max_err=1e-5)


# sublinear # sublinear
# config = SublinearMemoryConfig(genetic_nr_iter=10)
# run_test(
# model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
# )
config = SublinearMemoryConfig(genetic_nr_iter=10)
run_train(
model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
)


run_eval(model_path, False, max_err=1e-7) run_eval(model_path, False, max_err=1e-7)
# run_eval(model_path, True, max_err=1e-7) # XXX: fix me
run_eval(model_path, True, max_err=1e-7)

+ 5
- 5
python_module/megengine/jit/__init__.py View File

@@ -298,23 +298,23 @@ class trace:
if self._sublinear_memory_config is not None: if self._sublinear_memory_config is not None:
cg.set_option("enable_sublinear_memory_opt", True) cg.set_option("enable_sublinear_memory_opt", True)
cg.set_option( cg.set_option(
"sublinear_mem_cofig.lb_memory",
"sublinear_mem_config.lb_memory",
self._sublinear_memory_config.lb_memory, self._sublinear_memory_config.lb_memory,
) )
cg.set_option( cg.set_option(
"sublinear_mem_cofig.genetic_nr_iter",
"sublinear_mem_config.genetic_nr_iter",
self._sublinear_memory_config.genetic_nr_iter, self._sublinear_memory_config.genetic_nr_iter,
) )
cg.set_option( cg.set_option(
"sublinear_mem_cofig.genetic_pool_size",
"sublinear_mem_config.genetic_pool_size",
self._sublinear_memory_config.genetic_pool_size, self._sublinear_memory_config.genetic_pool_size,
) )
cg.set_option( cg.set_option(
"sublinear_mem_cofig.thresh_nr_try",
"sublinear_mem_config.thresh_nr_try",
self._sublinear_memory_config.thresh_nr_try, self._sublinear_memory_config.thresh_nr_try,
) )
cg.set_option( cg.set_option(
"sublinear_mem_cofig.num_worker",
"sublinear_mem_config.num_worker",
self._sublinear_memory_config.num_worker, self._sublinear_memory_config.num_worker,
) )
# pack allreduce # pack allreduce


+ 5
- 5
python_module/src/cpp/megbrain_config.cpp View File

@@ -116,11 +116,11 @@ bool _config::set_comp_graph_option(
SET_CG_OPTION(allocate_static_mem_after_graph_compile); SET_CG_OPTION(allocate_static_mem_after_graph_compile);
SET_CG_OPTION(log_level); SET_CG_OPTION(log_level);
SET_CG_OPTION(enable_sublinear_memory_opt); SET_CG_OPTION(enable_sublinear_memory_opt);
SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
SET_CG_OPTION(sublinear_mem_cofig.num_worker);
SET_CG_OPTION(sublinear_mem_config.lb_memory);
SET_CG_OPTION(sublinear_mem_config.genetic_nr_iter);
SET_CG_OPTION(sublinear_mem_config.genetic_pool_size);
SET_CG_OPTION(sublinear_mem_config.thresh_nr_try);
SET_CG_OPTION(sublinear_mem_config.num_worker);
SET_CG_OPTION(enable_var_mem_defragment); SET_CG_OPTION(enable_var_mem_defragment);
SET_CG_OPTION(eager_evaluation); SET_CG_OPTION(eager_evaluation);
SET_CG_OPTION(enable_memory_swap); SET_CG_OPTION(enable_memory_swap);


+ 1
- 1
src/core/impl/graph/cg_impl.cpp View File

@@ -219,7 +219,7 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
grad_manager{owner}, grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR #if MGB_ENABLE_SUBLINEAR
seq_modifier_for_sublinear_memory{owner, seq_modifier_for_sublinear_memory{owner,
&(owner->options().sublinear_mem_cofig)},
&(owner->options().sublinear_mem_config)},
#endif #endif
#if MGB_ENABLE_MEMORY_SWAP #if MGB_ENABLE_MEMORY_SWAP
memory_swap_support{owner}, memory_swap_support{owner},


+ 1
- 1
src/core/include/megbrain/graph/cg.h View File

@@ -409,7 +409,7 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
int genetic_pool_size = 20; int genetic_pool_size = 20;
int lb_memory = 0; int lb_memory = 0;
int num_worker = sys::get_cpu_count() / 2; int num_worker = sys::get_cpu_count() / 2;
} sublinear_mem_cofig;
} sublinear_mem_config;


//! do not re-profile to select best impl algo when input shape //! do not re-profile to select best impl algo when input shape
//! changes (use previous algo) //! changes (use previous algo)


+ 1
- 1
src/core/test/sublinear_memory.cpp View File

@@ -522,7 +522,7 @@ TEST(TestSublinearMemory, BadOpr) {
set_priority(z, 3); set_priority(z, 3);
graph->options().graph_opt_level = 0; graph->options().graph_opt_level = 0;
graph->options().enable_sublinear_memory_opt = 1; graph->options().enable_sublinear_memory_opt = 1;
graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
graph->options().sublinear_mem_config.genetic_nr_iter = 50;
auto func = graph->compile({{y, {}}, {z, {}}}); auto func = graph->compile({{y, {}}, {z, {}}});
auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get()) auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
->seq_modifier_for_sublinear_memory().prev_min_bottleneck(); ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();


Loading…
Cancel
Save