Compare commits

...

6 Commits

Author SHA1 Message Date
  Xinran Xu 0b1d516e38 chore(release): bump version 5 years ago
  Megvii Engine Team f0959dc9c4 refactor(mge/examples): add example to use modified sublinear API 5 years ago
  Megvii Engine Team a2f0e8788f feat(mge/api): expose sublinear related parameters at mge api level 5 years ago
  Xinran Xu f44d2632da chore(release): bump version 5 years ago
  Megvii Engine Team d84dc7f75c build(python_module): use consistent flag declaration in SWIG and c++ 5 years ago
  Megvii Engine Team 03728d45da fix(mgb/build): fix multi-machine macro and add test_distributed 5 years ago
18 changed files with 313 additions and 101 deletions
Unified View
  1. python_module/CMakeLists.txt (+5, -3)
  2. python_module/megengine/jit/__init__.py (+32, -4)
  3. python_module/megengine/jit/sublinear_memory_config.py (+56, -0)
  4. python_module/megengine/version.py (+1, -1)
  5. python_module/src/cpp/megbrain_config.cpp (+7, -1)
  6. python_module/src/cpp/megbrain_config.h (+0, -2)
  7. python_module/src/cpp/mm_handler.cpp (+18, -10)
  8. python_module/src/cpp/mm_handler.h (+1, -1)
  9. python_module/src/swig/mgb.i (+1, -0)
  10. python_module/test/integration/test_correctness.py (+17, -3)
  11. python_module/test/integration/test_distributed.py (+89, -0)
  12. python_module/test/unit/jit/test_jit.py (+12, -0)
  13. src/CMakeLists.txt (+0, -2)
  14. src/core/impl/graph/cg_impl.cpp (+2, -1)
  15. src/core/impl/graph/seq_sublinear_memory.cpp (+15, -21)
  16. src/core/impl/graph/seq_sublinear_memory.h (+6, -1)
  17. src/core/include/megbrain/graph/cg.h (+10, -0)
  18. src/core/test/sublinear_memory.cpp (+41, -51)

python_module/CMakeLists.txt (+5, -3)

@@ -55,16 +55,18 @@ add_custom_command(
 add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)

-set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
+set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)

 if(MGE_WITH_DISTRIBUTED)
-    list(APPEND SRCS src/cpp/mm_handler.cpp src/cpp/zmq_rpc.cpp)
+    list(APPEND SRCS src/cpp/zmq_rpc.cpp)
 endif()

 include(UseSWIG)
 set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)

 # cmake < 3.12 do not honor INCLUDE_DIRECTORIES property, just add include directory into SWIG_FLAGS
-set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS ${MGB_DEF} -I${PROJECT_SOURCE_DIR}/src/serialization/include)
+# Add -I${PROJECT_BINARY_DIR}/genfiles in order to include megbrain_build_config.h so that we don't need to pass cmake flags by -D.
+set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/serialization/include -I${PROJECT_BINARY_DIR}/genfiles)

 set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)

python_module/megengine/jit/__init__.py (+32, -4)

@@ -18,6 +18,7 @@ import megengine._internal as mgb
 from megengine._internal.plugin import CompGraphProfiler

 from ..core import Tensor, graph, tensor
+from .sublinear_memory_config import SublinearMemoryConfig


 def sideeffect(f):

@@ -78,10 +79,12 @@ class trace:
     * accelerated evalutaion via :meth:`.__call__`

     :param func: Positional only argument.
-    :param symbolic: Whether to use symbolic tensor.
+    :param symbolic: Whether to use symbolic tensor. Default: False
     :param opt_level: Optimization level for compiling trace.
     :param log_level: Log level.
-    :param profiling: Whether to profile compiled trace.
+    :param sublinear_memory_config: Configuration for sublinear memory optimization.
+        If not None, it enables sublinear memory optimization with given setting.
+    :param profiling: Whether to profile compiled trace. Default: False
     """

     _active_instance = None

@@ -103,12 +106,14 @@ class trace:
         symbolic: bool = False,
         opt_level: int = None,
         log_level: int = None,
+        sublinear_memory_config: SublinearMemoryConfig = None,
         profiling: bool = False
     ):
         self.__wrapped__ = func
         self._symbolic = symbolic
         self._graph_opt_level = opt_level
         self._log_level = log_level
+        self._sublinear_memory_config = sublinear_memory_config
         self._status = self._UNSTARTED
         self._args = None
         self._kwargs = None

@@ -280,11 +285,34 @@ class trace:

     def _apply_graph_options(self, cg):
         # graph opt level
-        if not self._graph_opt_level is None:
+        if self._graph_opt_level is not None:
             cg.set_option("graph_opt_level", self._graph_opt_level)
         # log level
-        if not self._log_level is None:
+        if self._log_level is not None:
             cg.set_option("log_level", self._log_level)
+        # sublinear
+        if self._sublinear_memory_config is not None:
+            cg.set_option("enable_sublinear_memory_opt", True)
+            cg.set_option(
+                "sublinear_mem_cofig.lb_memory",
+                self._sublinear_memory_config.lb_memory,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.genetic_nr_iter",
+                self._sublinear_memory_config.genetic_nr_iter,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.genetic_pool_size",
+                self._sublinear_memory_config.genetic_pool_size,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.thresh_nr_try",
+                self._sublinear_memory_config.thresh_nr_try,
+            )
+            cg.set_option(
+                "sublinear_mem_cofig.num_worker",
+                self._sublinear_memory_config.num_worker,
+            )
         # profile
         if self._profiling:
             self._profiler = CompGraphProfiler(cg)
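Taken together, the new sublinear_memory_config argument and the option forwarding in _apply_graph_options are used as follows. This is a minimal illustrative sketch mirroring the test_sublinear case added in test_jit.py further down, not part of the changed code:

from megengine import jit
from megengine.jit import SublinearMemoryConfig

# genetic_nr_iter=10 is an arbitrary example value; all other fields keep the
# defaults defined in sublinear_memory_config.py.
config = SublinearMemoryConfig(genetic_nr_iter=10)

# Passing the config to trace() turns on enable_sublinear_memory_opt and
# forwards each field to the graph via cg.set_option(), as shown above.
@jit.trace(symbolic=True, sublinear_memory_config=config)
def f(x):
    return x + x

f([0.0])  # the first call compiles the graph with the sublinear settings applied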


python_module/megengine/jit/sublinear_memory_config.py (+56, -0)

@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..core.device import get_device_count
+
+
+class SublinearMemoryConfig:
+    r"""
+    Configuration for sublinear memory optimization.
+
+    :param thresh_nr_try: number of samples both for searching in linear space
+        and around current thresh in sublinear memory optimization. Default: 10.
+        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
+    :param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm.
+        Default: 0.
+        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
+    :param genetic_pool_size: number of samples for the crossover random selection
+        during genetic optimization. Default: 20.
+        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
+    :param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization.
+        It can be used to perform manual tradeoff between memory and speed. Default: 0.
+        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
+    :param num_worker: number of thread workers to search the optimum checkpoints
+        in sublinear memory optimization. Default: half of cpu number in the system.
+        Note: the value must be greater or equal to one.
+        It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
+
+    Note that the environmental variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
+    in order for the above environmental variables to be effective.
+    """
+
+    def __init__(
+        self,
+        thresh_nr_try: int = 10,
+        genetic_nr_iter: int = 0,
+        genetic_pool_size: int = 20,
+        lb_memory: int = 0,
+        num_worker: int = max(1, get_device_count("cpu") // 2),
+    ):
+        assert thresh_nr_try >= 0, "thresh_nr_try must be greater or equal to zero"
+        self.thresh_nr_try = thresh_nr_try
+        assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater or equal to zero"
+        self.genetic_nr_iter = genetic_nr_iter
+        assert (
+            genetic_pool_size >= 0
+        ), "genetic_pool_size must be greater or equal to zero"
+        self.genetic_pool_size = genetic_pool_size
+        self.lb_memory = lb_memory
+        assert num_worker > 0, "num_worker must be greater or equal to one"
+        self.num_worker = num_worker
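As the docstring above notes, the same knobs can also be driven through environment variables instead of the Python config object. A rough sketch, assuming the variables are exported before the first compilation of a traced function so that the MGB_GETENV calls in seq_sublinear_memory.cpp can see them:

import os

# Variable names and the MGB_COMP_GRAPH_OPT requirement come from the docstring
# above; the concrete values here are arbitrary examples.
os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1"
os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "1024"
os.environ["MGB_SUBLINEAR_MEMORY_WORKERS"] = "4"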

python_module/megengine/version.py (+1, -1)

@@ -1,2 +1,2 @@
-__version__ = "0.3.2"
+__version__ = "0.3.4"



python_module/src/cpp/megbrain_config.cpp (+7, -1)

@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
         std::is_same<decltype(opt.name_chk), bool>::value ||     \
         std::is_same<decltype(opt.name_chk), uint8_t>::value ||  \
         std::is_same<decltype(opt.name_chk), int16_t>::value ||  \
-        std::is_same<decltype(opt.name_chk), uint16_t>::value,   \
+        std::is_same<decltype(opt.name_chk), uint16_t>::value || \
+        std::is_same<decltype(opt.name_chk), int32_t>::value,    \
         "not bool/int opt");                                     \
     if (name == #name_chk) {                                     \
         auto ret = opt.name_chk;                                 \

@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
     SET_CG_OPTION(allocate_static_mem_after_graph_compile);
     SET_CG_OPTION(log_level);
     SET_CG_OPTION(enable_sublinear_memory_opt);
+    SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
+    SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
+    SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
+    SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
+    SET_CG_OPTION(sublinear_mem_cofig.num_worker);
     SET_CG_OPTION(enable_var_mem_defragment);
     SET_CG_OPTION(eager_evaluation);
     SET_CG_OPTION(enable_memory_swap);


python_module/src/cpp/megbrain_config.h (+0, -2)

@@ -65,12 +65,10 @@ class _config {
     static std::vector<std::pair<uint64_t, std::string>>
         dump_registered_oprs();

-#if MGB_ENABLE_OPR_MM
     static int create_mm_server(const std::string& server_addr, int port);

     static void group_barrier(const std::string& server_addr,
             int port, uint32_t size, uint32_t rank);
-#endif
 };

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

python_module/src/cpp/mm_handler.cpp (+18, -10)

@@ -12,7 +12,7 @@
 #include "megbrain/exception.h"
 #include "megbrain_config.h"

-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM
 #include "zmq_rpc.h"
 #include <future>

@@ -242,17 +242,11 @@ int _config::create_mm_server(const std::string& server_addr, int port) {
             server_addr, port, std::make_unique<GroupServerProxy>());
 }

-#else
-
-int _config::create_mm_server(const std::string& server_addr, int port) {
-    mgb_throw(mgb::MegBrainError, "CUDA suppport disable at compile time");
-    return 0;
-}
-
-#endif
-
 /* ======================== Group Barrier ========================== */

+/*! see definition : src/cpp/megbrain_config.h.
+ * Block until all ranks in the group reach this barrier
+ */
 void _config::group_barrier(const std::string& server_addr,
         int port, uint32_t size, uint32_t rank) {
     mgb_assert(rank < size, "invalid rank %d", rank);

@@ -263,4 +257,18 @@ void _config::group_barrier(const std::string& server_addr,
     mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
 }

+#else
+
+int _config::create_mm_server(const std::string& server_addr, int port) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+    return 0;
+}
+
+void _config::group_barrier(const std::string& server_addr,
+        int port, uint32_t size, uint32_t rank) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+}
+
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

python_module/src/cpp/mm_handler.h (+1, -1)

@@ -11,7 +11,7 @@
 #include "megbrain_build_config.h"

-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM

 #include "zmq_rpc.h"




python_module/src/swig/mgb.i (+1, -0)

@@ -49,6 +49,7 @@ intb4 = _mgb.intb4
 #include "plugin.h"
 %}

+%include "megbrain_build_config.h"
 %include "comp_node.i"
 %include "comp_graph.i"
 %include "symbol_var.i"


python_module/test/integration/test_correctness.py (+17, -3)

@@ -17,6 +17,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import jit, tensor
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemoryConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.test import assertTensorClose

@@ -130,7 +131,9 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)


-def run_test(model_path, use_jit, use_symbolic):
+def run_test(
+    model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
+):

     """
     Load the model with test cases and run the training for one iter.

@@ -152,11 +155,16 @@ def run_test(model_path, use_jit, use_symbolic):
     data.set_value(checkpoint["data"])
     label.set_value(checkpoint["label"])

-    max_err = 1e-5
+    if max_err is None:
+        max_err = 1e-5

     train_func = train
     if use_jit:
-        train_func = jit.trace(train_func, symbolic=use_symbolic)
+        train_func = jit.trace(
+            train_func,
+            symbolic=use_symbolic,
+            sublinear_memory_config=sublinear_memory_config,
+        )

     opt.zero_grad()
     loss = train_func(data, label, net=net, opt=opt)

@@ -183,3 +191,9 @@ def test_correctness():
     run_test(model_path, False, False)
     run_test(model_path, True, False)
     run_test(model_path, True, True)
+
+    # sublinear
+    config = SublinearMemoryConfig(genetic_nr_iter=10)
+    run_test(
+        model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
+    )

python_module/test/integration/test_distributed.py (+89, -0)

@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+import subprocess
+import sys
+
+import numpy as np
+
+
+def worker(master_ip, master_port, world_size, rank, dev, trace):
+    import megengine.distributed as dist
+    import megengine.functional as F
+    from megengine import is_cuda_available
+    from megengine import jit
+    from megengine.module import Linear, Module
+    from megengine.optimizer import SGD
+
+    if not is_cuda_available():
+        return
+
+    class MLP(Module):
+        def __init__(self):
+            super().__init__()
+            self.fc0 = Linear(3 * 224 * 224, 500)
+            self.fc1 = Linear(500, 10)
+
+        def forward(self, x):
+            x = self.fc0(x)
+            x = F.relu(x)
+            x = self.fc1(x)
+            return x
+
+    dist.init_process_group(
+        master_ip=master_ip, master_port=3456, world_size=world_size, rank=rank, dev=dev
+    )
+    net = MLP()
+
+    opt = SGD(net.parameters(requires_grad=True), lr=0.02)
+
+    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
+    label = np.random.randint(0, 10, size=(64,)).astype(np.int32)
+
+    jit.trace.enabled = trace
+
+    @jit.trace()
+    def train_func(data, label):
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        opt.backward(loss)
+        return loss
+
+    for i in range(5):
+        opt.zero_grad()
+        loss = train_func(data, label)
+        opt.step()
+
+
+def start_workers(worker, world_size, trace=False):
+    def run_subproc(rank):
+        cmd = "from test.integration.test_distributed import worker\n"
+        cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
+            world_size, rank, rank, "True" if trace else "False"
+        )
+        cmd = ["python3", "-c", cmd]
+        ret = subprocess.run(
+            cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
+        )
+        assert ret.returncode == 0, "subprocess failed"
+
+    procs = []
+    for rank in range(world_size):
+        p = mp.Process(target=run_subproc, args=(rank,))
+        p.start()
+        procs.append(p)
+
+    for p in procs:
+        p.join()
+        assert p.exitcode == 0
+
+
+def test_distributed():
+    start_workers(worker, 2, trace=True)
+    start_workers(worker, 2, trace=False)

python_module/test/unit/jit/test_jit.py (+12, -0)

@@ -18,6 +18,7 @@ import megengine._internal as mgb
 import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
+from megengine.jit import SublinearMemoryConfig
 from megengine.test import assertTensorClose

@@ -185,3 +186,14 @@ def test_dump_bn_fused():
         mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
         and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
     )
+
+
+# Simply verify the options passed down
+def test_sublinear():
+    config = SublinearMemoryConfig(genetic_nr_iter=10)
+
+    @jit.trace(symbolic=True, sublinear_memory_config=config)
+    def f(x):
+        return x + x
+
+    f([0.0])

src/CMakeLists.txt (+0, -2)

@@ -31,10 +31,8 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT)
     list(APPEND SOURCES ${SOURCES_})
 endif()

-set(MGB_DEF ${MGB_DEF} PARENT_SCOPE)
 add_library(megbrain STATIC EXCLUDE_FROM_ALL ${SOURCES})
 target_link_libraries(megbrain mgb_opr_param_defs)
-target_compile_definitions(megbrain PUBLIC ${MGB_DEF})
 target_include_directories(megbrain PUBLIC ${MGB_INC})

 if(MGE_WITH_CUDA)


src/core/impl/graph/cg_impl.cpp (+2, -1)

@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
           static_infer_comp_seq_manager{owner},
           grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
-          seq_modifier_for_sublinear_memory{owner},
+          seq_modifier_for_sublinear_memory{owner,
+              &(owner->options().sublinear_mem_cofig)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
           memory_swap_support{owner},


src/core/impl/graph/seq_sublinear_memory.cpp (+15, -21)

@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
     std::vector<std::future<void>> m_futures;
     std::mutex m_mtx;

-    struct Config {
-        size_t thresh_nr_try = 10;
-        size_t genetic_nr_iter = 0;
-        size_t genetic_pool_size = 20;
-        double lb_memory = 0;
-    };
-    Config m_config;
-
     /*!
      * \brief check given thresh, and update states
      * \return bottleneck value for given thresh

@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
 public:
     ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
             : m_par_modifier{par} {
+        auto & m_config = m_par_modifier->m_config;
+        //! allow environmental variable to overwrite the setting
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
-            m_config.thresh_nr_try = std::stoi(env);
+            m_config->thresh_nr_try = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
-            m_config.genetic_nr_iter = std::stoi(env);
+            m_config->genetic_nr_iter = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
            auto psize = static_cast<size_t>(std::stoi(env));
-            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
+            mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
                        "invalid pool size %zu in genetic algorithm,", psize);
-            m_config.genetic_pool_size = psize;
+            m_config->genetic_pool_size = psize;
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
-            m_config.lb_memory = std::stod(env) * 1024 * 1024;
+            m_config->lb_memory = std::stoi(env) * 1024 * 1024;
         }
     }

@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
         invoke_search(thresh);
     }

-    size_t NR_TRY = m_config.thresh_nr_try;
+    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;

     // search in linear space
     auto step = init_thresh / (NR_TRY + 1);

@@ -833,8 +827,8 @@
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
     RNGxorshf rng(2333);
-    size_t POOL_SIZE = m_config.genetic_pool_size;
-    size_t NR_ITER = m_config.genetic_nr_iter;
+    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
+    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
     auto mutation = [&](const SplitPointSet& sps) {
         auto s = *sps;
         size_t length = s.size();

@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
 }

 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
-    size_t lower_bound = m_config.lb_memory;
+    size_t lower_bound = m_par_modifier->m_config->lb_memory;
     if (m_min_bottleneck >= lower_bound)
         return;
     OprFootprint footprint;

@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
     msg.push_back('\n');
     msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                         m_min_bottleneck * SIZE2MB));
-    if(!m_config.genetic_nr_iter) {
+    if(!m_par_modifier->m_config->genetic_nr_iter) {
         msg.append(ssprintf(
                 "\nGenetic algorithm is currently DISABLED, "
                 "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"

@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                    "invalid planner concurrency: %zu", set);
         planner_concur = set;
     } else {
-        planner_concur = sys::get_cpu_count() / 2;
+        planner_concur = m_config->num_worker;
     }

     mgb_log_debug("use %zu threads to search for sublinear memory plan; "

@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
 }

 SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
-        ComputingGraphImpl* owner)
-        : m_mem_opt(owner), m_owner_graph(owner) {}
+        ComputingGraphImpl* owner, Config* config_p)
+        : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}

 #endif // !MGB_ENABLE_SUBLINEAR




src/core/impl/graph/seq_sublinear_memory.h (+6, -1)

@@ -12,6 +12,7 @@
 #pragma once

 #include "./memory_optimizer.h"
+#include "megbrain/graph/cg.h"
 #include "megbrain/utils/async_worker.h"

 #if MGB_ENABLE_SUBLINEAR

@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
     using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
     using SplitPointSet = std::shared_ptr<std::vector<size_t>>;

+    //! Config options
+    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
+    Config* m_config;
+
     //! get modifications to be taken under some specific constraints
     class ModifyActionPlanner;

@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
     }

 public:
-    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
+    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);

     //! see memory_optimizer set_priority_before_opt
     void set_priority_before_opt(const VarNodeArray& endpoints) {


src/core/include/megbrain/graph/cg.h (+10, -0)

@@ -16,6 +16,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/seq_comp_node_opt.h"
 #include "megbrain/utils/event.h"
+#include "megbrain/system.h"

 #if MGB_ENABLE_JSON
 #include "megbrain/utils/json.h"

@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
         //! whether to enable sublinear memory optimization
         bool enable_sublinear_memory_opt = false;

+        //! Control parameter for sublinear memory optimization
+        struct SublinearMemConfig {
+            int thresh_nr_try = 10;
+            int genetic_nr_iter = 0;
+            int genetic_pool_size = 20;
+            int lb_memory = 0;
+            int num_worker = sys::get_cpu_count() / 2;
+        } sublinear_mem_cofig;
+
         //! do not re-profile to select best impl algo when input shape
         //! changes (use previous algo)
         bool no_profiling_on_shape_change = false;


src/core/test/sublinear_memory.cpp (+41, -51)

@@ -504,57 +504,47 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
 }

 TEST(TestSublinearMemory, BadOpr) {
-    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
-    auto old_value = getenv(KEY);
-    setenv(KEY, "50", 1);
-    MGB_TRY {
-        HostTensorGenerator<> gen;
-        auto cn = CompNode::load("xpu0");
-        constexpr size_t N = 1024, Scale = 2;
-        auto host_x = gen({N}, cn);
-        for (bool bad : {false, true}) {
-            auto graph = ComputingGraph::make();
-            auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
-                 bad_var = SublinearBadOpr::make(x, bad, Scale),
-                 y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
-                 y1 = SublinearBadOpr::make(y0, false, N * Scale),
-                 y = y1 + 1,
-                 z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
-            set_priority(y0, 0);
-            set_priority(y1, 1);
-            set_priority(y, 2);
-            set_priority(z, 3);
-            graph->options().graph_opt_level = 0;
-            graph->options().enable_sublinear_memory_opt = 1;
-            auto func = graph->compile({{y, {}}, {z, {}}});
-            auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
-                    ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
-            // bottleneck:
-            //  if bad : y = y1 + 1, bad_var should be saved to calculate
-            //      z later, total memory usage is
-            //      N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
-            //  else : bad_var = BadOpr(x), total memory usage is
-            //      N(x) + N * scale(bad_var), bad_var would be recomputed
-            //      when calculate z = reduce(bad_var)
-            size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
-            ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
-            size_t nr_bad_opr = 0;
-            auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
-                if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
-                    ++ nr_bad_opr;
-                }
-                return true;
-            };
-            func->iter_opr_seq(count_up);
-            ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
-        }
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-    );
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("xpu0");
+    constexpr size_t N = 1024, Scale = 2;
+    auto host_x = gen({N}, cn);
+    for (bool bad : {false, true}) {
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
+             bad_var = SublinearBadOpr::make(x, bad, Scale),
+             y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
+             y1 = SublinearBadOpr::make(y0, false, N * Scale),
+             y = y1 + 1,
+             z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
+        set_priority(y0, 0);
+        set_priority(y1, 1);
+        set_priority(y, 2);
+        set_priority(z, 3);
+        graph->options().graph_opt_level = 0;
+        graph->options().enable_sublinear_memory_opt = 1;
+        graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
+        auto func = graph->compile({{y, {}}, {z, {}}});
+        auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
+                ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
+        // bottleneck:
+        //  if bad : y = y1 + 1, bad_var should be saved to calculate
+        //      z later, total memory usage is
+        //      N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
+        //  else : bad_var = BadOpr(x), total memory usage is
+        //      N(x) + N * scale(bad_var), bad_var would be recomputed
+        //      when calculate z = reduce(bad_var)
+        size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
+        ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
+        size_t nr_bad_opr = 0;
+        auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
+            if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
+                ++ nr_bad_opr;
+            }
+            return true;
+        };
+        func->iter_opr_seq(count_up);
+        ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
+    }
 }

 #else

