| Author | SHA1 | Message | Date |
|---|---|---|---|
| | 0b1d516e38 | chore(release): bump version | 5 years ago |
| | f0959dc9c4 | refactor(mge/examples): add example to use modified sublinear API | 5 years ago |
| | a2f0e8788f | feat(mge/api): expose sublinear related parameters at mge api level | 5 years ago |
| | f44d2632da | chore(release): bump version | 5 years ago |
| | d84dc7f75c | build(python_module): use consistent flag declaration in SWIG and c++ (include cmake generated megbrain_build_config.h in SWIG in order to remove usage of MGB_DEF) | 5 years ago |
| | 03728d45da | fix(mgb/build): fix multi-machine macro and add test_distributed | 5 years ago |
@@ -55,16 +55,18 @@ add_custom_command(
 add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)
-set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
+set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
 if(MGE_WITH_DISTRIBUTED)
-    list(APPEND SRCS src/cpp/mm_handler.cpp src/cpp/zmq_rpc.cpp)
+    list(APPEND SRCS src/cpp/zmq_rpc.cpp)
 endif()
 include(UseSWIG)
 set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)
 # cmake < 3.12 do not honor INCLUDE_DIRECTORIES property, just add include directory into SWIG_FLAGS
-set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS ${MGB_DEF} -I${PROJECT_SOURCE_DIR}/src/serialization/include)
+# Add -I${PROJECT_BINARY_DIR}/genfiles in order to include megbrain_build_config.h so that we don't need to pass cmake flags by -D.
+set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/serialization/include -I${PROJECT_BINARY_DIR}/genfiles)
 set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 from megengine._internal.plugin import CompGraphProfiler
 from ..core import Tensor, graph, tensor
+from .sublinear_memory_config import SublinearMemoryConfig
 
 
 def sideeffect(f):
@@ -78,10 +79,12 @@ class trace:
     * accelerated evalutaion via :meth:`.__call__`
 
     :param func: Positional only argument.
-    :param symbolic: Whether to use symbolic tensor.
+    :param symbolic: Whether to use symbolic tensor. Default: False
     :param opt_level: Optimization level for compiling trace.
     :param log_level: Log level.
-    :param profiling: Whether to profile compiled trace.
+    :param sublinear_memory_config: Configuration for sublinear memory optimization.
+        If not None, it enables sublinear memory optimization with given setting.
+    :param profiling: Whether to profile compiled trace. Default: False
     """
 
     _active_instance = None
@@ -103,12 +106,14 @@ class trace:
         symbolic: bool = False,
         opt_level: int = None,
         log_level: int = None,
+        sublinear_memory_config: SublinearMemoryConfig = None,
         profiling: bool = False
     ):
         self.__wrapped__ = func
         self._symbolic = symbolic
         self._graph_opt_level = opt_level
         self._log_level = log_level
+        self._sublinear_memory_config = sublinear_memory_config
         self._status = self._UNSTARTED
         self._args = None
         self._kwargs = None
@@ -280,11 +285,34 @@ class trace:
     def _apply_graph_options(self, cg):
         # graph opt level
-        if not self._graph_opt_level is None:
+        if self._graph_opt_level is not None:
             cg.set_option("graph_opt_level", self._graph_opt_level)
         # log level
-        if not self._log_level is None:
+        if self._log_level is not None:
             cg.set_option("log_level", self._log_level)
+        # sublinear
+        if self._sublinear_memory_config is not None:
+            cg.set_option("enable_sublinear_memory_opt", True)
+            cg.set_option(
+                "sublinear_mem_cofig.lb_memory",
+                self._sublinear_memory_config.lb_memory,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.genetic_nr_iter",
+                self._sublinear_memory_config.genetic_nr_iter,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.genetic_pool_size",
+                self._sublinear_memory_config.genetic_pool_size,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.thresh_nr_try",
+                self._sublinear_memory_config.thresh_nr_try,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.num_worker",
+                self._sublinear_memory_config.num_worker,
+            )
         # profile
         if self._profiling:
             self._profiler = CompGraphProfiler(cg)
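Taken together, the hunks above expose sublinear memory optimization at the `jit.trace` level and forward every `SublinearMemoryConfig` field to the graph options. A minimal usage sketch mirroring the pattern used by the tests in this change (the network, optimizer and input names are placeholders, not part of the diff):

```python
import megengine.functional as F
from megengine import jit
from megengine.jit import SublinearMemoryConfig

config = SublinearMemoryConfig(genetic_nr_iter=10)

# net/opt stand for any MegEngine Module and optimizer built elsewhere.
@jit.trace(symbolic=True, sublinear_memory_config=config)
def train_func(data, label, net=None, opt=None):
    pred = net(data)
    loss = F.cross_entropy_with_softmax(pred, label)
    opt.backward(loss)
    return loss
```

When the trace compiles its graph, `_apply_graph_options` switches on `enable_sublinear_memory_opt` and copies each field into the corresponding `sublinear_mem_cofig.*` option, as shown in the hunk above.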
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core.device import get_device_count
+
+
+class SublinearMemoryConfig:
+    r"""
+    Configuration for sublinear memory optimization.
+
+    :param thresh_nr_try: number of samples both for searching in linear space
+        and around current thresh in sublinear memory optimization. Default: 10.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
+    :param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm.
+        Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
+    :param genetic_pool_size: number of samples for the crossover random selection
+        during genetic optimization. Default: 20.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
+    :param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization.
+        It can be used to perform manual tradeoff between memory and speed. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
+    :param num_worker: number of thread workers to search the optimum checkpoints
+        in sublinear memory optimization. Default: half of cpu number in the system.
+        Note: the value must be greater or equal to one.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
+
+    Note that the environment variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
+    in order for the above environment variables to be effective.
+    """
+
+    def __init__(
+        self,
+        thresh_nr_try: int = 10,
+        genetic_nr_iter: int = 0,
+        genetic_pool_size: int = 20,
+        lb_memory: int = 0,
+        num_worker: int = max(1, get_device_count("cpu") // 2),
+    ):
+        assert thresh_nr_try >= 0, "thresh_nr_try must be greater or equal to zero"
+        self.thresh_nr_try = thresh_nr_try
+        assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater or equal to zero"
+        self.genetic_nr_iter = genetic_nr_iter
+        assert (
+            genetic_pool_size >= 0
+        ), "genetic_pool_size must be greater or equal to zero"
+        self.genetic_pool_size = genetic_pool_size
+        self.lb_memory = lb_memory
+        assert num_worker > 0, "num_worker must be greater or equal to one"
+        self.num_worker = num_worker
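As the docstring notes, every field can alternatively be driven by an environment variable, which the C++ searcher reads and uses to overwrite the per-graph setting. A short sketch of both paths (the concrete values here are illustrative only):

```python
import os

from megengine.jit import SublinearMemoryConfig

# Path 1: configure explicitly and pass the object to jit.trace(...).
config = SublinearMemoryConfig(genetic_nr_iter=50, genetic_pool_size=20)

# Path 2: environment variables; per the docstring, MGB_COMP_GRAPH_OPT must
# also turn the optimization on for these to take effect.
os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1"
os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "1024"
```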
@@ -1,2 +1,2 @@
-__version__ = "0.3.2"
+__version__ = "0.3.4"
@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
                 std::is_same<decltype(opt.name_chk), bool>::value ||     \
                 std::is_same<decltype(opt.name_chk), uint8_t>::value ||  \
                 std::is_same<decltype(opt.name_chk), int16_t>::value ||  \
-                std::is_same<decltype(opt.name_chk), uint16_t>::value,   \
+                std::is_same<decltype(opt.name_chk), uint16_t>::value || \
+                std::is_same<decltype(opt.name_chk), int32_t>::value,    \
                 "not bool/int opt");                                     \
         if (name == #name_chk) {                                         \
             auto ret = opt.name_chk;                                     \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
         SET_CG_OPTION(allocate_static_mem_after_graph_compile);
         SET_CG_OPTION(log_level);
         SET_CG_OPTION(enable_sublinear_memory_opt);
+        SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
+        SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
+        SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
+        SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
+        SET_CG_OPTION(sublinear_mem_cofig.num_worker);
        SET_CG_OPTION(enable_var_mem_defragment);
        SET_CG_OPTION(eager_evaluation);
        SET_CG_OPTION(enable_memory_swap);
@@ -65,12 +65,10 @@
     static std::vector<std::pair<uint64_t, std::string>>
     dump_registered_oprs();
 
-#if MGB_ENABLE_OPR_MM
     static int create_mm_server(const std::string& server_addr, int port);
 
     static void group_barrier(const std::string& server_addr,
             int port, uint32_t size, uint32_t rank);
-#endif
 };
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -12,7 +12,7 @@
 #include "megbrain/exception.h"
 #include "megbrain_config.h"
 
-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM
 #include "zmq_rpc.h"
 
 #include <future>
@@ -242,17 +242,11 @@ int _config::create_mm_server(const std::string& server_addr, int port) {
             server_addr, port, std::make_unique<GroupServerProxy>());
 }
 
-#else
-int _config::create_mm_server(const std::string& server_addr, int port) {
-    mgb_throw(mgb::MegBrainError, "CUDA suppport disable at compile time");
-    return 0;
-}
-#endif
 
 /* ======================== Group Barrier ========================== */
 
-/*! see definition : src/cpp/megbrain_config.h.
- * Block until all ranks in the group reach this barrier
- */
 void _config::group_barrier(const std::string& server_addr,
         int port, uint32_t size, uint32_t rank) {
     mgb_assert(rank < size, "invalid rank %d", rank);
@@ -263,4 +257,18 @@ void _config::group_barrier(const std::string& server_addr,
     mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
 }
 
+#else
+
+int _config::create_mm_server(const std::string& server_addr, int port) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+    return 0;
+}
+
+void _config::group_barrier(const std::string& server_addr,
+        int port, uint32_t size, uint32_t rank) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+}
+
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -11,7 +11,7 @@
 #include "megbrain_build_config.h"
 
-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM
 
 #include "zmq_rpc.h"
 
@@ -49,6 +49,7 @@ intb4 = _mgb.intb4
 #include "plugin.h"
 %}
 
+%include "megbrain_build_config.h"
 %include "comp_node.i"
 %include "comp_graph.i"
 %include "symbol_var.i"
@@ -17,6 +17,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import jit, tensor
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemoryConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.test import assertTensorClose
@@ -130,7 +131,9 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)
 
 
-def run_test(model_path, use_jit, use_symbolic):
+def run_test(
+    model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
+):
     """
     Load the model with test cases and run the training for one iter.
@@ -152,11 +155,16 @@ def run_test(model_path, use_jit, use_symbolic):
     data.set_value(checkpoint["data"])
     label.set_value(checkpoint["label"])
 
-    max_err = 1e-5
+    if max_err is None:
+        max_err = 1e-5
 
     train_func = train
     if use_jit:
-        train_func = jit.trace(train_func, symbolic=use_symbolic)
+        train_func = jit.trace(
+            train_func,
+            symbolic=use_symbolic,
+            sublinear_memory_config=sublinear_memory_config,
+        )
 
     opt.zero_grad()
     loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +191,9 @@ def test_correctness():
     run_test(model_path, False, False)
     run_test(model_path, True, False)
     run_test(model_path, True, True)
+
+    # sublinear
+    config = SublinearMemoryConfig(genetic_nr_iter=10)
+    run_test(
+        model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
+    )
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+import subprocess
+import sys
+
+import numpy as np
+
+
+def worker(master_ip, master_port, world_size, rank, dev, trace):
+    import megengine.distributed as dist
+    import megengine.functional as F
+    from megengine import is_cuda_available
+    from megengine import jit
+    from megengine.module import Linear, Module
+    from megengine.optimizer import SGD
+
+    if not is_cuda_available():
+        return
+
+    class MLP(Module):
+        def __init__(self):
+            super().__init__()
+            self.fc0 = Linear(3 * 224 * 224, 500)
+            self.fc1 = Linear(500, 10)
+
+        def forward(self, x):
+            x = self.fc0(x)
+            x = F.relu(x)
+            x = self.fc1(x)
+            return x
+
+    dist.init_process_group(
+        master_ip=master_ip, master_port=3456, world_size=world_size, rank=rank, dev=dev
+    )
+    net = MLP()
+    opt = SGD(net.parameters(requires_grad=True), lr=0.02)
+
+    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
+    label = np.random.randint(0, 10, size=(64,)).astype(np.int32)
+
+    jit.trace.enabled = trace
+
+    @jit.trace()
+    def train_func(data, label):
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        opt.backward(loss)
+        return loss
+
+    for i in range(5):
+        opt.zero_grad()
+        loss = train_func(data, label)
+        opt.step()
+
+
+def start_workers(worker, world_size, trace=False):
+    def run_subproc(rank):
+        cmd = "from test.integration.test_distributed import worker\n"
+        cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
+            world_size, rank, rank, "True" if trace else "False"
+        )
+        cmd = ["python3", "-c", cmd]
+        ret = subprocess.run(
+            cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
+        )
+        assert ret.returncode == 0, "subprocess failed"
+
+    procs = []
+    for rank in range(world_size):
+        p = mp.Process(target=run_subproc, args=(rank,))
+        p.start()
+        procs.append(p)
+    for p in procs:
+        p.join()
+        assert p.exitcode == 0
+
+
+def test_distributed():
+    start_workers(worker, 2, trace=True)
+    start_workers(worker, 2, trace=False)
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
+from megengine.jit import SublinearMemoryConfig
 from megengine.test import assertTensorClose
@@ -185,3 +186,14 @@ def test_dump_bn_fused():
         mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
         and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
     )
+
+
+# Simply verify the options passed down
+def test_sublinear():
+    config = SublinearMemoryConfig(genetic_nr_iter=10)
+
+    @jit.trace(symbolic=True, sublinear_memory_config=config)
+    def f(x):
+        return x + x
+
+    f([0.0])
@@ -31,10 +31,8 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT)
     list(APPEND SOURCES ${SOURCES_})
 endif()
 
-set(MGB_DEF ${MGB_DEF} PARENT_SCOPE)
 add_library(megbrain STATIC EXCLUDE_FROM_ALL ${SOURCES})
 target_link_libraries(megbrain mgb_opr_param_defs)
-target_compile_definitions(megbrain PUBLIC ${MGB_DEF})
 target_include_directories(megbrain PUBLIC ${MGB_INC})
 
 if(MGE_WITH_CUDA)
@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
           static_infer_comp_seq_manager{owner},
           grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
-          seq_modifier_for_sublinear_memory{owner},
+          seq_modifier_for_sublinear_memory{owner,
+            &(owner->options().sublinear_mem_cofig)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
           memory_swap_support{owner},
@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
     std::vector<std::future<void>> m_futures;
     std::mutex m_mtx;
 
-    struct Config {
-        size_t thresh_nr_try = 10;
-        size_t genetic_nr_iter = 0;
-        size_t genetic_pool_size = 20;
-        double lb_memory = 0;
-    };
-    Config m_config;
-
     /*!
      * \brief check given thresh, and update states
      * \return bottleneck value for given thresh
@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
 public:
     ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
             : m_par_modifier{par} {
+        auto & m_config = m_par_modifier->m_config;
+        //! allow environmental variable to overwrite the setting
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
-            m_config.thresh_nr_try = std::stoi(env);
+            m_config->thresh_nr_try = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
-            m_config.genetic_nr_iter = std::stoi(env);
+            m_config->genetic_nr_iter = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
            auto psize = static_cast<size_t>(std::stoi(env));
-            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
+            mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
                       "invalid pool size %zu in genetic algorithm,", psize);
-            m_config.genetic_pool_size = psize;
+            m_config->genetic_pool_size = psize;
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
-            m_config.lb_memory = std::stod(env) * 1024 * 1024;
+            m_config->lb_memory = std::stoi(env) * 1024 * 1024;
         }
     }
@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
         invoke_search(thresh);
     }
 
-    size_t NR_TRY = m_config.thresh_nr_try;
+    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;
 
     // search in linear space
     auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
     RNGxorshf rng(2333);
-    size_t POOL_SIZE = m_config.genetic_pool_size;
-    size_t NR_ITER = m_config.genetic_nr_iter;
+    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
+    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
     auto mutation = [&](const SplitPointSet& sps) {
         auto s = *sps;
         size_t length = s.size();
@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
 }
 
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
-    size_t lower_bound = m_config.lb_memory;
+    size_t lower_bound = m_par_modifier->m_config->lb_memory;
     if (m_min_bottleneck >= lower_bound)
         return;
     OprFootprint footprint;
@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
     msg.push_back('\n');
     msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                         m_min_bottleneck * SIZE2MB));
-    if(!m_config.genetic_nr_iter) {
+    if(!m_par_modifier->m_config->genetic_nr_iter) {
         msg.append(ssprintf(
                 "\nGenetic algorithm is currently DISABLED, "
                 "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                    "invalid planner concurrency: %zu", set);
         planner_concur = set;
     } else {
-        planner_concur = sys::get_cpu_count() / 2;
+        planner_concur = m_config->num_worker;
     }
 
     mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
 }
 
 SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
-        ComputingGraphImpl* owner)
-        : m_mem_opt(owner), m_owner_graph(owner) {}
+        ComputingGraphImpl* owner, Config* config_p)
+        : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}
 
 #endif // !MGB_ENABLE_SUBLINEAR
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "./memory_optimizer.h"
+#include "megbrain/graph/cg.h"
 #include "megbrain/utils/async_worker.h"
 
 #if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
     using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
     using SplitPointSet = std::shared_ptr<std::vector<size_t>>;
 
+    //! Config options
+    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
+    Config* m_config;
+
     //! get modifications to be taken under some specific constraints
     class ModifyActionPlanner;
@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
     }
 
 public:
-    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
+    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);
 
     //! see memory_optimizer set_priority_before_opt
     void set_priority_before_opt(const VarNodeArray& endpoints) {
@@ -16,6 +16,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/seq_comp_node_opt.h"
 #include "megbrain/utils/event.h"
+#include "megbrain/system.h"
 
 #if MGB_ENABLE_JSON
 #include "megbrain/utils/json.h"
@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! whether to enable sublinear memory optimization
             bool enable_sublinear_memory_opt = false;
 
+            //! Control parameter for sublinear memory optimization
+            struct SublinearMemConfig {
+                int thresh_nr_try = 10;
+                int genetic_nr_iter = 0;
+                int genetic_pool_size = 20;
+                int lb_memory = 0;
+                int num_worker = sys::get_cpu_count() / 2;
+            } sublinear_mem_cofig;
+
             //! do not re-profile to select best impl algo when input shape
             //! changes (use previous algo)
             bool no_profiling_on_shape_change = false;
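The new `SublinearMemConfig` fields are what the Python layer reaches through `_config::set_comp_graph_option` with dotted option names (see the `SET_CG_OPTION(sublinear_mem_cofig.*)` hunk earlier). A hedged sketch of that low-level path; the `mgb.comp_graph()` constructor is an assumption here and not part of this diff:

```python
import megengine._internal as mgb

cg = mgb.comp_graph()  # assumed internal constructor for a computing graph
cg.set_option("enable_sublinear_memory_opt", True)
# Dotted names map onto ComputingGraph::Options::SublinearMemConfig fields.
cg.set_option("sublinear_mem_cofig.genetic_nr_iter", 50)
cg.set_option("sublinear_mem_cofig.lb_memory", 1024)
```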
@@ -504,57 +504,47 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
 }
 
 TEST(TestSublinearMemory, BadOpr) {
-    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
-    auto old_value = getenv(KEY);
-    setenv(KEY, "50", 1);
-    MGB_TRY {
-        HostTensorGenerator<> gen;
-        auto cn = CompNode::load("xpu0");
-        constexpr size_t N = 1024, Scale = 2;
-        auto host_x = gen({N}, cn);
-        for (bool bad : {false, true}) {
-            auto graph = ComputingGraph::make();
-            auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
-                 bad_var = SublinearBadOpr::make(x, bad, Scale),
-                 y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
-                 y1 = SublinearBadOpr::make(y0, false, N * Scale),
-                 y = y1 + 1,
-                 z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
-            set_priority(y0, 0);
-            set_priority(y1, 1);
-            set_priority(y, 2);
-            set_priority(z, 3);
-            graph->options().graph_opt_level = 0;
-            graph->options().enable_sublinear_memory_opt = 1;
-            auto func = graph->compile({{y, {}}, {z, {}}});
-            auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
-                ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
-            // bottleneck:
-            //     if bad : y = y1 + 1, bad_var should be saved to calculate
-            //          z later, total memory usage is
-            //          N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
-            //     else : bad_var = BadOpr(x), total memory usage is
-            //          N(x) + N * scale(bad_var), bad_var would be recomputed
-            //          when calculate z = reduce(bad_var)
-            size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
-            ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
-            size_t nr_bad_opr = 0;
-            auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
-                if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
-                    ++ nr_bad_opr;
-                }
-                return true;
-            };
-            func->iter_opr_seq(count_up);
-            ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
-        }
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-    );
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("xpu0");
+    constexpr size_t N = 1024, Scale = 2;
+    auto host_x = gen({N}, cn);
+    for (bool bad : {false, true}) {
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
+             bad_var = SublinearBadOpr::make(x, bad, Scale),
+             y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
+             y1 = SublinearBadOpr::make(y0, false, N * Scale),
+             y = y1 + 1,
+             z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
+        set_priority(y0, 0);
+        set_priority(y1, 1);
+        set_priority(y, 2);
+        set_priority(z, 3);
+        graph->options().graph_opt_level = 0;
+        graph->options().enable_sublinear_memory_opt = 1;
+        graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
+        auto func = graph->compile({{y, {}}, {z, {}}});
+        auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
+            ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
+        // bottleneck:
+        //     if bad : y = y1 + 1, bad_var should be saved to calculate
+        //          z later, total memory usage is
+        //          N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
+        //     else : bad_var = BadOpr(x), total memory usage is
+        //          N(x) + N * scale(bad_var), bad_var would be recomputed
+        //          when calculate z = reduce(bad_var)
+        size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
+        ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
+        size_t nr_bad_opr = 0;
+        auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
+            if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
+                ++ nr_bad_opr;
+            }
+            return true;
+        };
+        func->iter_opr_seq(count_up);
+        ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
+    }
 }
 
 #else