GitOrigin-RevId: 7a47f5d0d5
release-0.3
@@ -18,6 +18,7 @@ import megengine._internal as mgb
from megengine._internal.plugin import CompGraphProfiler
from ..core import Tensor, graph, tensor
from .sublinear_memory_config import SublinearMemConfig
def sideeffect(f):
@@ -78,10 +79,12 @@ class trace:
    * accelerated evaluation via :meth:`.__call__`
    :param func: Positional-only argument.
    :param symbolic: Whether to use symbolic tensor.
    :param symbolic: Whether to use symbolic tensor. Default: False
    :param opt_level: Optimization level for compiling trace.
    :param log_level: Log level.
    :param profiling: Whether to profile compiled trace.
    :param enable_sublinear: Whether to enable sublinear memory optimization. Default: False
    :param sublinear_mem_config: Configuration for sublinear memory optimization.
    :param profiling: Whether to profile compiled trace. Default: False
    """
    _active_instance = None
@@ -103,12 +106,16 @@ class trace:
        symbolic: bool = False,
        opt_level: int = None,
        log_level: int = None,
        enable_sublinear: bool = False,
        sublinear_mem_config: SublinearMemConfig = None,
        profiling: bool = False
    ):
        self.__wrapped__ = func
        self._symbolic = symbolic
        self._graph_opt_level = opt_level
        self._log_level = log_level
        self._enable_sublinear = enable_sublinear
        self._sublinear_mem_config = sublinear_mem_config
        self._status = self._UNSTARTED
        self._args = None
        self._kwargs = None
@@ -280,11 +287,35 @@ class trace:
    def _apply_graph_options(self, cg):
        # graph opt level
        if not self._graph_opt_level is None:
        if self._graph_opt_level is not None:
            cg.set_option("graph_opt_level", self._graph_opt_level)
        # log level
        if not self._log_level is None:
        if self._log_level is not None:
            cg.set_option("log_level", self._log_level)
        # sublinear
        if self._enable_sublinear:
            cg.set_option("enable_sublinear_memory_opt", True)
            if self._sublinear_mem_config is not None:
                cg.set_option(
                    "sublinear_mem_config.lb_memory",
                    self._sublinear_mem_config.lb_memory,
                )
                cg.set_option(
                    "sublinear_mem_config.genetic_nr_iter",
                    self._sublinear_mem_config.genetic_nr_iter,
                )
                cg.set_option(
                    "sublinear_mem_config.genetic_pool_size",
                    self._sublinear_mem_config.genetic_pool_size,
                )
                cg.set_option(
                    "sublinear_mem_config.thresh_nr_try",
                    self._sublinear_mem_config.thresh_nr_try,
                )
                cg.set_option(
                    "sublinear_mem_config.num_worker",
                    self._sublinear_mem_config.num_worker,
                )
        # profile
        if self._profiling:
            self._profiler = CompGraphProfiler(cg)
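The two new keyword arguments are forwarded by `_apply_graph_options` into the computing-graph options. A minimal usage sketch, mirroring the `test_sublinear` case added below (the traced function `f` is illustrative only):

    from megengine import jit
    from megengine.jit import SublinearMemConfig

    config = SublinearMemConfig(genetic_nr_iter=10)

    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
    def f(x):
        return x + x

    f([0.0])  # compiling the trace pushes the sublinear options into the graph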
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core.device import get_device_count
class SublinearMemConfig:
    r"""
    Configuration for sublinear memory optimization.
    :param thresh_nr_try: number of search samples, taken both in linear space
        and around the current threshold, in sublinear memory optimization. Default: 10.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
    :param genetic_nr_iter: number of iterations to find the best checkpoints in the genetic algorithm.
        Default: 0.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
    :param genetic_pool_size: number of samples for random selection during genetic-algorithm
        crossover. Default: 20.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
    :param lb_memory: lower bound, in MB, of the memory bottleneck in sublinear memory optimization.
        It can be used to trade memory off against speed manually. Default: 0.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
    :param num_worker: number of worker threads used to search for optimal checkpoints
        in sublinear memory optimization. Default: half the number of CPUs in the system.
        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
    """
    def __init__(
        self,
        thresh_nr_try: int = 10,
        genetic_nr_iter: int = 0,
        genetic_pool_size: int = 20,
        lb_memory: int = 0,
        num_worker: int = get_device_count("cpu") // 2,
    ):
        self.thresh_nr_try = thresh_nr_try
        self.genetic_nr_iter = genetic_nr_iter
        self.genetic_pool_size = genetic_pool_size
        self.lb_memory = lb_memory
        self.num_worker = num_worker
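The defaults above match the values previously hard-coded in `ActionSearcherSingleCN::Config` (removed further down), with `num_worker` newly exposed. A sketch of a non-default configuration (the values are illustrative, not recommendations):

    from megengine.jit import SublinearMemConfig

    config = SublinearMemConfig(
        genetic_nr_iter=50,    # turn on the genetic search (off by default)
        genetic_pool_size=20,  # keep the default crossover pool size
        num_worker=4,          # cap planner threads instead of cpu_count // 2
    )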
@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
            std::is_same<decltype(opt.name_chk), bool>::value ||     \
            std::is_same<decltype(opt.name_chk), uint8_t>::value ||  \
            std::is_same<decltype(opt.name_chk), int16_t>::value ||  \
            std::is_same<decltype(opt.name_chk), uint16_t>::value,   \
            std::is_same<decltype(opt.name_chk), uint16_t>::value || \
            std::is_same<decltype(opt.name_chk), int32_t>::value,    \
            "not bool/int opt");                                     \
    if (name == #name_chk) {                                         \
        auto ret = opt.name_chk;                                     \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
    SET_CG_OPTION(allocate_static_mem_after_graph_compile);
    SET_CG_OPTION(log_level);
    SET_CG_OPTION(enable_sublinear_memory_opt);
    SET_CG_OPTION(sublinear_mem_config.lb_memory);
    SET_CG_OPTION(sublinear_mem_config.genetic_nr_iter);
    SET_CG_OPTION(sublinear_mem_config.genetic_pool_size);
    SET_CG_OPTION(sublinear_mem_config.thresh_nr_try);
    SET_CG_OPTION(sublinear_mem_config.num_worker);
    SET_CG_OPTION(enable_var_mem_defragment);
    SET_CG_OPTION(eager_evaluation);
    SET_CG_OPTION(enable_memory_swap);
@@ -17,6 +17,7 @@ import megengine as mge
import megengine.functional as F
from megengine import jit, tensor
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemConfig
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
from megengine.test import assertTensorClose
@@ -130,7 +131,14 @@ def update_model(model_path):
    mge.save(checkpoint, model_path)
def run_test(model_path, use_jit, use_symbolic):
def run_test(
    model_path,
    use_jit,
    use_symbolic,
    enable_sublinear=False,
    sublinear_mem_config=None,
    max_err=None,
):
    """
    Load the model with test cases and run the training for one iteration.
@@ -152,11 +160,17 @@ def run_test(model_path, use_jit, use_symbolic):
    data.set_value(checkpoint["data"])
    label.set_value(checkpoint["label"])
    max_err = 1e-5
    if max_err is None:
        max_err = 1e-5
    train_func = train
    if use_jit:
        train_func = jit.trace(train_func, symbolic=use_symbolic)
        train_func = jit.trace(
            train_func,
            symbolic=use_symbolic,
            enable_sublinear=enable_sublinear,
            sublinear_mem_config=sublinear_mem_config,
        )
    opt.zero_grad()
    loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +197,14 @@ def test_correctness():
    run_test(model_path, False, False)
    run_test(model_path, True, False)
    run_test(model_path, True, True)
    # sublinear
    config = SublinearMemConfig(genetic_nr_iter=10)
    run_test(
        model_path,
        True,
        True,
        enable_sublinear=True,
        sublinear_mem_config=config,
        max_err=1e-5,
    )
@@ -18,6 +18,7 @@ import megengine._internal as mgb
import megengine.module as M
from megengine import jit, tensor
from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemConfig
from megengine.test import assertTensorClose
@@ -185,3 +186,14 @@ def test_dump_bn_fused():
        mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
        and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
    )
# Simply verify that the options are passed down
def test_sublinear():
    config = SublinearMemConfig(genetic_nr_iter=10)
    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
    def f(x):
        return x + x
    f([0.0])
@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
          static_infer_comp_seq_manager{owner},
          grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
          seq_modifier_for_sublinear_memory{owner},
          seq_modifier_for_sublinear_memory{owner,
                  &(owner->options().sublinear_mem_config)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
          memory_swap_support{owner},
@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
    std::vector<std::future<void>> m_futures;
    std::mutex m_mtx;
    struct Config {
        size_t thresh_nr_try = 10;
        size_t genetic_nr_iter = 0;
        size_t genetic_pool_size = 20;
        double lb_memory = 0;
    };
    Config m_config;
    /*!
     * \brief check given thresh, and update states
     * \return bottleneck value for given thresh
@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
public:
    ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
            : m_par_modifier{par} {
        auto& m_config = m_par_modifier->m_config;
        //! allow environment variables to overwrite the settings
        if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
            m_config.thresh_nr_try = std::stoi(env);
            m_config->thresh_nr_try = std::stoi(env);
        }
        if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
            m_config.genetic_nr_iter = std::stoi(env);
            m_config->genetic_nr_iter = std::stoi(env);
        }
        if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
            auto psize = static_cast<size_t>(std::stoi(env));
            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
            mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
                       "invalid pool size %zu in genetic algorithm", psize);
            m_config.genetic_pool_size = psize;
            m_config->genetic_pool_size = psize;
        }
        if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
            m_config.lb_memory = std::stod(env) * 1024 * 1024;
            m_config->lb_memory = std::stoi(env) * 1024 * 1024;
        }
    }
@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
        invoke_search(thresh);
    }
    size_t NR_TRY = m_config.thresh_nr_try;
    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;
    // search in linear space
    auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
    RNGxorshf rng(2333);
    size_t POOL_SIZE = m_config.genetic_pool_size;
    size_t NR_ITER = m_config.genetic_nr_iter;
    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
    auto mutation = [&](const SplitPointSet& sps) {
        auto s = *sps;
        size_t length = s.size();
@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
}
void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
    size_t lower_bound = m_config.lb_memory;
    size_t lower_bound = m_par_modifier->m_config->lb_memory;
    if (m_min_bottleneck >= lower_bound)
        return;
    OprFootprint footprint;
@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
    msg.push_back('\n');
    msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                        m_min_bottleneck * SIZE2MB));
    if (!m_config.genetic_nr_iter) {
    if (!m_par_modifier->m_config->genetic_nr_iter) {
        msg.append(ssprintf(
                "\nGenetic algorithm is currently DISABLED, "
                "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                   "invalid planner concurrency: %zu", set);
        planner_concur = set;
    } else {
        planner_concur = sys::get_cpu_count() / 2;
        planner_concur = m_config->num_worker;
    }
    mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
}
SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
        ComputingGraphImpl* owner)
        : m_mem_opt(owner), m_owner_graph(owner) {}
        ComputingGraphImpl* owner, Config* config_p)
        : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}
#endif  // !MGB_ENABLE_SUBLINEAR
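As the constructor above shows, values first arrive through the options struct, and the `MGB_SUBLINEAR_MEMORY_*` environment variables then overwrite them, so an environment variable wins over a value set via `SublinearMemConfig`. A sketch of such an override from the Python side (the variables must be set before the trace is compiled):

    import os

    # takes precedence over SublinearMemConfig(genetic_nr_iter=...)
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
    os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "512"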
@@ -12,6 +12,7 @@
#pragma once
#include "./memory_optimizer.h"
#include "megbrain/graph/cg.h"
#include "megbrain/utils/async_worker.h"
#if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
    using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
    using SplitPointSet = std::shared_ptr<std::vector<size_t>>;
    //! config options
    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
    Config* m_config;
    //! get modifications to be taken under some specific constraints
    class ModifyActionPlanner;
@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
    }
public:
    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_p);
    //! see memory_optimizer set_priority_before_opt
    void set_priority_before_opt(const VarNodeArray& endpoints) {
@@ -16,6 +16,7 @@
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/seq_comp_node_opt.h"
#include "megbrain/utils/event.h"
#include "megbrain/system.h"
#if MGB_ENABLE_JSON
#include "megbrain/utils/json.h"
@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
        //! whether to enable sublinear memory optimization
        bool enable_sublinear_memory_opt = false;
        //! control parameters for sublinear memory optimization
        struct SublinearMemConfig {
            int thresh_nr_try = 10;
            int genetic_nr_iter = 0;
            int genetic_pool_size = 20;
            int lb_memory = 0;
            int num_worker = sys::get_cpu_count() / 2;
        } sublinear_mem_config;
        //! do not re-profile to select best impl algo when input shape
        //! changes (use previous algo)
        bool no_profiling_on_shape_change = false;
@@ -504,57 +504,47 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
}
TEST(TestSublinearMemory, BadOpr) {
    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
    auto old_value = getenv(KEY);
    setenv(KEY, "50", 1);
    MGB_TRY {
        HostTensorGenerator<> gen;
        auto cn = CompNode::load("xpu0");
        constexpr size_t N = 1024, Scale = 2;
        auto host_x = gen({N}, cn);
        for (bool bad : {false, true}) {
            auto graph = ComputingGraph::make();
            auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
                 bad_var = SublinearBadOpr::make(x, bad, Scale),
                 y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
                 y1 = SublinearBadOpr::make(y0, false, N * Scale),
                 y = y1 + 1,
                 z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
            set_priority(y0, 0);
            set_priority(y1, 1);
            set_priority(y, 2);
            set_priority(z, 3);
            graph->options().graph_opt_level = 0;
            graph->options().enable_sublinear_memory_opt = 1;
            auto func = graph->compile({{y, {}}, {z, {}}});
            auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
                    ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
            // bottleneck:
            //   if bad: y = y1 + 1, bad_var should be saved to calculate
            //       z later; total memory usage is
            //       N * Scale * 2 (bad_var and y1) + 1 (immutable tensor 1)
            //   else: bad_var = BadOpr(x); total memory usage is
            //       N (x) + N * Scale (bad_var); bad_var would be recomputed
            //       when calculating z = reduce(bad_var)
            size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
            ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
            size_t nr_bad_opr = 0;
            auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
                if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
                    ++nr_bad_opr;
                }
                return true;
            };
            func->iter_opr_seq(count_up);
            ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
        }
    } MGB_FINALLY(
        if (old_value) {
            setenv(KEY, old_value, 1);
        } else {
            unsetenv(KEY);
        }
    );
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("xpu0");
    constexpr size_t N = 1024, Scale = 2;
    auto host_x = gen({N}, cn);
    for (bool bad : {false, true}) {
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
             bad_var = SublinearBadOpr::make(x, bad, Scale),
             y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
             y1 = SublinearBadOpr::make(y0, false, N * Scale),
             y = y1 + 1,
             z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
        set_priority(y0, 0);
        set_priority(y1, 1);
        set_priority(y, 2);
        set_priority(z, 3);
        graph->options().graph_opt_level = 0;
        graph->options().enable_sublinear_memory_opt = 1;
        graph->options().sublinear_mem_config.genetic_nr_iter = 50;
        auto func = graph->compile({{y, {}}, {z, {}}});
        auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
                ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
        // bottleneck:
        //   if bad: y = y1 + 1, bad_var should be saved to calculate
        //       z later; total memory usage is
        //       N * Scale * 2 (bad_var and y1) + 1 (immutable tensor 1)
        //   else: bad_var = BadOpr(x); total memory usage is
        //       N (x) + N * Scale (bad_var); bad_var would be recomputed
        //       when calculating z = reduce(bad_var)
        size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
        ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
        size_t nr_bad_opr = 0;
        auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
            if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
                ++nr_bad_opr;
            }
            return true;
        };
        func->iter_opr_seq(count_up);
        ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
    }
}
#else
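As a quick check of the bottleneck arithmetic in the comment above (assuming the generator's default float32 dtype, i.e. 4 bytes per element):

    N, Scale, dtype_size = 1024, 2, 4
    bad_expect = (N * Scale * 2 + 1) * dtype_size  # bad_var + y1 + immutable scalar
    good_expect = (N * Scale + N) * dtype_size     # x + recomputed bad_var
    print(bad_expect, good_expect)                 # 16388 12288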