GitOrigin-RevId: f0e917f716
@@ -1 +1,2 @@
+from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import exclude_from_trace, trace
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..device import get_device_count
+
+
+class SublinearMemoryConfig:
+    r"""
+    Configuration for sublinear memory optimization.
+
+    :param thresh_nr_try: number of samples, both for searching in linear space
+        and around the current thresh, in sublinear memory optimization. Default: 10.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
+    :param genetic_nr_iter: number of iterations used by the genetic algorithm to find the best checkpoints.
+        Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
+    :param genetic_pool_size: number of samples for the crossover random selection
+        during genetic optimization. Default: 20.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
+    :param lb_memory: lower bound of the bottleneck memory size, in MB, for sublinear memory optimization.
+        It can be used to trade off memory against speed manually. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
+    :param num_worker: number of worker threads used to search for the optimal checkpoints
+        in sublinear memory optimization. Default: half of the number of CPUs in the system.
+        Note: the value must be greater than or equal to one.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
+
+    Note that the environment variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
+    in order for the environment variables above to take effect.
+    """
+
+    def __init__(
+        self,
+        thresh_nr_try: int = 10,
+        genetic_nr_iter: int = 0,
+        genetic_pool_size: int = 20,
+        lb_memory: int = 0,
+        num_worker: int = max(1, get_device_count("cpu") // 2),
+    ):
+        assert thresh_nr_try >= 0, "thresh_nr_try must be greater than or equal to zero"
+        self.thresh_nr_try = thresh_nr_try
+        assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater than or equal to zero"
+        self.genetic_nr_iter = genetic_nr_iter
+        assert (
+            genetic_pool_size >= 0
+        ), "genetic_pool_size must be greater than or equal to zero"
+        self.genetic_pool_size = genetic_pool_size
+        self.lb_memory = lb_memory
+        assert num_worker > 0, "num_worker must be greater than or equal to one"
+        self.num_worker = num_worker
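
A minimal usage sketch of the new config class, assuming the megengine.jit import path that the rest of this diff wires up; the training function, its arguments, and the genetic_nr_iter value are illustrative placeholders, not part of the change:

    from megengine.jit import SublinearMemoryConfig, trace

    # Ask the checkpoint search to run 10 genetic iterations (illustrative value).
    config = SublinearMemoryConfig(genetic_nr_iter=10)

    @trace(symbolic=True, sublinear_memory_config=config)
    def train_step(data, label):
        ...  # forward / backward computation goes here

The config object is only consulted when the traced graph is compiled, as the _apply_graph_options helper added below shows.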
@@ -7,6 +7,7 @@ from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.core import OpBase, apply
 from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
+from .sublinear_memory_config import SublinearMemoryConfig


 class TraceMismatchError(RuntimeError):
@@ -72,11 +73,18 @@ class trace:
         self.__init__(*args, **kwargs)
         return self

-    def __init__(self, function, symbolic=False, capture_as_const=False):
+    def __init__(
+        self,
+        function,
+        symbolic=False,
+        capture_as_const=False,
+        sublinear_memory_config: SublinearMemoryConfig = None,
+    ):
         self.__wrapped__ = function
         self._symbolic = symbolic
         self._capture_as_const = capture_as_const
         self._capture_static_shape = False
+        self._sublinear_memory_config = sublinear_memory_config
         self._untraced = True
         self._tinfo = []  # handle -> TensorInfo
@@ -227,6 +235,7 @@ class trace:
                G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                for x in lazy_eval_tensors
            ]
+           self._apply_graph_options(self._lazy_eval_graph)
            self._lazy_eval_graph.compile(*readers)
            self._lazy_eval_graph()
            for r, x in zip(readers, lazy_eval_tensors):
@@ -259,9 +268,26 @@ class trace:
            info.exported = True
            info.data_read = True

+    def _apply_graph_options(self, graph):
+        # sublinear
+        if self._sublinear_memory_config is not None:
+            graph.options.enable_sublinear_memory_opt = True
+            sublinear_config = graph.options.sublinear_mem_config
+            sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
+            sublinear_config.genetic_nr_iter = (
+                self._sublinear_memory_config.genetic_nr_iter
+            )
+            sublinear_config.genetic_pool_size = (
+                self._sublinear_memory_config.genetic_pool_size
+            )
+            sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
+            sublinear_config.num_worker = self._sublinear_memory_config.num_worker
+
     def _compile(self):
         graph = self._graph = G.Graph()
         graph.options.no_force_inplace = True
+        self._apply_graph_options(graph)
         # graph.options.graph_opt_level = 0
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
@@ -119,6 +119,7 @@ void init_graph_rt(py::module m) {
        DEF_READWRITE(enable_memory_swap)
        DEF_READWRITE(comp_node_seq_record_level)
        DEF_READWRITE(no_force_inplace)
+       DEF_READWRITE(sublinear_mem_config)
        // DEF_READWRITE(eager_evaluation)
        // DEF_READWRITE(imperative_proxy_graph)
        // DEF_READWRITE(extra_vardeps)
@@ -142,6 +143,16 @@ void init_graph_rt(py::module m) {
 #undef CURRENT_CLASS
+#define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
+
+    py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
+        DEF_READWRITE(thresh_nr_try)
+        DEF_READWRITE(genetic_nr_iter)
+        DEF_READWRITE(genetic_pool_size)
+        DEF_READWRITE(lb_memory)
+        DEF_READWRITE(num_worker);
+
+#undef CURRENT_CLASS

    auto common = rel_import("common", m, 1);
    common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
@@ -19,6 +19,7 @@ import megengine.functional as F
 from megengine import jit
 from megengine.core._trace_option import set_tensor_shape
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemoryConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.tensor import Tensor
@@ -217,14 +218,14 @@ def test_correctness():
     set_conv_execution_strategy("HEURISTIC_REPRODUCIBLE")

     run_train(model_path, False, False, max_err=1e-5)
-    # run_test(model_path, True, False)
-    # run_test(model_path, True, True)
+    run_train(model_path, True, False, max_err=1e-5)
+    run_train(model_path, True, True, max_err=1e-5)

     # sublinear
-    # config = SublinearMemoryConfig(genetic_nr_iter=10)
-    # run_test(
-    #     model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
-    # )
+    config = SublinearMemoryConfig(genetic_nr_iter=10)
+    run_train(
+        model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
+    )

     run_eval(model_path, False, max_err=1e-7)
-    # run_eval(model_path, True, max_err=1e-7)  # XXX: fix me
+    run_eval(model_path, True, max_err=1e-7)
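
As the SublinearMemoryConfig docstring notes, the same knobs can also be driven by environment variables instead of a config object; a hedged sketch (the variable names come from that docstring, and how they interact with an explicit config passed to trace is not specified in this diff):

    import os

    # Required for the per-knob variables below to take effect.
    os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1"
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "10"
    os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "0"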
@@ -298,23 +298,23 @@ class trace:
         if self._sublinear_memory_config is not None:
             cg.set_option("enable_sublinear_memory_opt", True)
             cg.set_option(
-                "sublinear_mem_cofig.lb_memory",
+                "sublinear_mem_config.lb_memory",
                 self._sublinear_memory_config.lb_memory,
             )
             cg.set_option(
-                "sublinear_mem_cofig.genetic_nr_iter",
+                "sublinear_mem_config.genetic_nr_iter",
                 self._sublinear_memory_config.genetic_nr_iter,
             )
             cg.set_option(
-                "sublinear_mem_cofig.genetic_pool_size",
+                "sublinear_mem_config.genetic_pool_size",
                 self._sublinear_memory_config.genetic_pool_size,
             )
             cg.set_option(
-                "sublinear_mem_cofig.thresh_nr_try",
+                "sublinear_mem_config.thresh_nr_try",
                 self._sublinear_memory_config.thresh_nr_try,
             )
             cg.set_option(
-                "sublinear_mem_cofig.num_worker",
+                "sublinear_mem_config.num_worker",
                 self._sublinear_memory_config.num_worker,
             )
         # pack allreduce
@@ -116,11 +116,11 @@ bool _config::set_comp_graph_option(
     SET_CG_OPTION(allocate_static_mem_after_graph_compile);
     SET_CG_OPTION(log_level);
     SET_CG_OPTION(enable_sublinear_memory_opt);
-    SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
-    SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
-    SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
-    SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
-    SET_CG_OPTION(sublinear_mem_cofig.num_worker);
+    SET_CG_OPTION(sublinear_mem_config.lb_memory);
+    SET_CG_OPTION(sublinear_mem_config.genetic_nr_iter);
+    SET_CG_OPTION(sublinear_mem_config.genetic_pool_size);
+    SET_CG_OPTION(sublinear_mem_config.thresh_nr_try);
+    SET_CG_OPTION(sublinear_mem_config.num_worker);
     SET_CG_OPTION(enable_var_mem_defragment);
     SET_CG_OPTION(eager_evaluation);
     SET_CG_OPTION(enable_memory_swap);
@@ -219,7 +219,7 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
          grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{owner,
-                  &(owner->options().sublinear_mem_cofig)},
+                  &(owner->options().sublinear_mem_config)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
@@ -409,7 +409,7 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
            int genetic_pool_size = 20;
            int lb_memory = 0;
            int num_worker = sys::get_cpu_count() / 2;
-        } sublinear_mem_cofig;
+        } sublinear_mem_config;

        //! do not re-profile to select best impl algo when input shape
        //! changes (use previous algo)
@@ -522,7 +522,7 @@ TEST(TestSublinearMemory, BadOpr) {
    set_priority(z, 3);
    graph->options().graph_opt_level = 0;
    graph->options().enable_sublinear_memory_opt = 1;
-    graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
+    graph->options().sublinear_mem_config.genetic_nr_iter = 50;
    auto func = graph->compile({{y, {}}, {z, {}}});
    auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
                             ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();