Compare commits

...

6 Commits

Author SHA1 Message Date
  Xinran Xu 0b1d516e38 chore(release): bump version 5 years ago
  Megvii Engine Team f0959dc9c4 refactor(mge/examples): add example to use modified sublinear API 5 years ago
  Megvii Engine Team a2f0e8788f feat(mge/api): expose sublinear related parameters at mge api level 5 years ago
  Xinran Xu f44d2632da chore(release): bump version 5 years ago
  Megvii Engine Team d84dc7f75c build(python_module): use consistent flag declaration in SWIG and c++ 5 years ago
  Megvii Engine Team 03728d45da fix(mgb/build): fix multi-machine macro and add test_distributed 5 years ago
18 changed files with 313 additions and 101 deletions
  1. python_module/CMakeLists.txt  +5 -3
  2. python_module/megengine/jit/__init__.py  +32 -4
  3. python_module/megengine/jit/sublinear_memory_config.py  +56 -0
  4. python_module/megengine/version.py  +1 -1
  5. python_module/src/cpp/megbrain_config.cpp  +7 -1
  6. python_module/src/cpp/megbrain_config.h  +0 -2
  7. python_module/src/cpp/mm_handler.cpp  +18 -10
  8. python_module/src/cpp/mm_handler.h  +1 -1
  9. python_module/src/swig/mgb.i  +1 -0
  10. python_module/test/integration/test_correctness.py  +17 -3
  11. python_module/test/integration/test_distributed.py  +89 -0
  12. python_module/test/unit/jit/test_jit.py  +12 -0
  13. src/CMakeLists.txt  +0 -2
  14. src/core/impl/graph/cg_impl.cpp  +2 -1
  15. src/core/impl/graph/seq_sublinear_memory.cpp  +15 -21
  16. src/core/impl/graph/seq_sublinear_memory.h  +6 -1
  17. src/core/include/megbrain/graph/cg.h  +10 -0
  18. src/core/test/sublinear_memory.cpp  +41 -51

python_module/CMakeLists.txt  +5 -3

@@ -55,16 +55,18 @@ add_custom_command(

add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)

set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)

if(MGE_WITH_DISTRIBUTED)
list(APPEND SRCS src/cpp/mm_handler.cpp src/cpp/zmq_rpc.cpp)
list(APPEND SRCS src/cpp/zmq_rpc.cpp)
endif()

include(UseSWIG)
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)

# cmake < 3.12 do not honor INCLUDE_DIRECTORIES property, just add include directory into SWIG_FLAGS
set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS ${MGB_DEF} -I${PROJECT_SOURCE_DIR}/src/serialization/include)
# Add -I${PROJECT_BINARY_DIR}/genfiles in order to include megbrain_build_config.h so that we don't need to pass cmake flags by -D.
set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/serialization/include -I${PROJECT_BINARY_DIR}/genfiles)

set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)


python_module/megengine/jit/__init__.py  +32 -4

@@ -18,6 +18,7 @@ import megengine._internal as mgb
from megengine._internal.plugin import CompGraphProfiler

from ..core import Tensor, graph, tensor
from .sublinear_memory_config import SublinearMemoryConfig


def sideeffect(f):
@@ -78,10 +79,12 @@ class trace:
* accelerated evalutaion via :meth:`.__call__`

:param func: Positional only argument.
:param symbolic: Whether to use symbolic tensor.
:param symbolic: Whether to use symbolic tensor. Default: False
:param opt_level: Optimization level for compiling trace.
:param log_level: Log level.
:param profiling: Whether to profile compiled trace.
:param sublinear_memory_config: Configuration for sublinear memory optimization.
If not None, it enables sublinear memory optimization with given setting.
:param profiling: Whether to profile compiled trace. Default: False
"""

_active_instance = None
@@ -103,12 +106,14 @@ class trace:
symbolic: bool = False,
opt_level: int = None,
log_level: int = None,
sublinear_memory_config: SublinearMemoryConfig = None,
profiling: bool = False
):
self.__wrapped__ = func
self._symbolic = symbolic
self._graph_opt_level = opt_level
self._log_level = log_level
self._sublinear_memory_config = sublinear_memory_config
self._status = self._UNSTARTED
self._args = None
self._kwargs = None
@@ -280,11 +285,34 @@ class trace:

def _apply_graph_options(self, cg):
# graph opt level
if not self._graph_opt_level is None:
if self._graph_opt_level is not None:
cg.set_option("graph_opt_level", self._graph_opt_level)
# log level
if not self._log_level is None:
if self._log_level is not None:
cg.set_option("log_level", self._log_level)
# sublinear
if self._sublinear_memory_config is not None:
cg.set_option("enable_sublinear_memory_opt", True)
cg.set_option(
"sublinear_mem_cofig.lb_memory",
self._sublinear_memory_config.lb_memory,
)
cg.set_option(
"sublinear_mem_cofig.genetic_nr_iter",
self._sublinear_memory_config.genetic_nr_iter,
)
cg.set_option(
"sublinear_mem_cofig.genetic_pool_size",
self._sublinear_memory_config.genetic_pool_size,
)
cg.set_option(
"sublinear_mem_cofig.thresh_nr_try",
self._sublinear_memory_config.thresh_nr_try,
)
cg.set_option(
"sublinear_mem_cofig.num_worker",
self._sublinear_memory_config.num_worker,
)
# profile
if self._profiling:
self._profiler = CompGraphProfiler(cg)
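
For reference, a minimal sketch of how the new keyword is intended to be used (it mirrors the test added in test_jit.py further down; the traced function body is only illustrative):

    from megengine import jit
    from megengine.jit import SublinearMemoryConfig

    # run the genetic checkpoint search for 10 iterations; trace forwards each
    # field to the corresponding "sublinear_mem_cofig.*" graph option as above
    config = SublinearMemoryConfig(genetic_nr_iter=10)

    @jit.trace(symbolic=True, sublinear_memory_config=config)
    def f(x):
        return x + x

    f([0.0])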


python_module/megengine/jit/sublinear_memory_config.py  +56 -0

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..core.device import get_device_count


class SublinearMemoryConfig:
r"""
Configuration for sublinear memory optimization.

:param thresh_nr_try: number of samples both for searching in linear space
and around current thresh in sublinear memory optimization. Default: 10.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
:param genetic_nr_iter: number of iterations to find the best checkpoints in genetic algorithm.
Default: 0.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
:param genetic_pool_size: number of samples for the crossover random selection
during genetic optimization. Default: 20.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
:param lb_memory: memory lower bound of bottleneck size in MB for sublinear memory optimization.
It can be used to perform manual tradeoff between memory and speed. Default: 0.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
:param num_worker: number of thread workers to search the optimum checkpoints
in sublinear memory optimization. Default: half of cpu number in the system.
Note: the value must be greater or equal to one.
It can also be set through the environmental variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.

Note that the environmental variable MGB_COMP_GRAPH_OPT must be set to 'enable_sublinear_memory_opt=1'
in order for the above environmental variable to be effective.
"""

def __init__(
self,
thresh_nr_try: int = 10,
genetic_nr_iter: int = 0,
genetic_pool_size: int = 20,
lb_memory: int = 0,
num_worker: int = max(1, get_device_count("cpu") // 2),
):
assert thresh_nr_try >= 0, "thresh_nr_try must be greater or equal to zero"
self.thresh_nr_try = thresh_nr_try
assert genetic_nr_iter >= 0, "genetic_nr_iter must be greater or equal to zero"
self.genetic_nr_iter = genetic_nr_iter
assert (
genetic_pool_size >= 0
), "genetic_pool_size must be greater or equal to zero"
self.genetic_pool_size = genetic_pool_size
self.lb_memory = lb_memory
assert num_worker > 0, "num_worker must be greater or equal to one"
self.num_worker = num_worker
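
As the docstring notes, each knob can also come from an environment variable; a rough sketch of that path (the values are illustrative, and per the note above they only take effect when MGB_COMP_GRAPH_OPT turns the optimization on):

    import os

    # enable the optimization itself, then tune individual knobs;
    # set these before the graph is compiled, since they are read
    # when the checkpoint search runs
    os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1"
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
    os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "1024"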

python_module/megengine/version.py  +1 -1

@@ -1,2 +1,2 @@
__version__ = "0.3.2"
__version__ = "0.3.4"


python_module/src/cpp/megbrain_config.cpp  +7 -1

@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
std::is_same<decltype(opt.name_chk), bool>::value || \
std::is_same<decltype(opt.name_chk), uint8_t>::value || \
std::is_same<decltype(opt.name_chk), int16_t>::value || \
std::is_same<decltype(opt.name_chk), uint16_t>::value, \
std::is_same<decltype(opt.name_chk), uint16_t>::value || \
std::is_same<decltype(opt.name_chk), int32_t>::value, \
"not bool/int opt"); \
if (name == #name_chk) { \
auto ret = opt.name_chk; \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
SET_CG_OPTION(allocate_static_mem_after_graph_compile);
SET_CG_OPTION(log_level);
SET_CG_OPTION(enable_sublinear_memory_opt);
SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
SET_CG_OPTION(sublinear_mem_cofig.num_worker);
SET_CG_OPTION(enable_var_mem_defragment);
SET_CG_OPTION(eager_evaluation);
SET_CG_OPTION(enable_memory_swap);


python_module/src/cpp/megbrain_config.h  +0 -2

@@ -65,12 +65,10 @@ class _config {
static std::vector<std::pair<uint64_t, std::string>>
dump_registered_oprs();

#if MGB_ENABLE_OPR_MM
static int create_mm_server(const std::string& server_addr, int port);

static void group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank);
#endif
};

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

python_module/src/cpp/mm_handler.cpp  +18 -10

@@ -12,7 +12,7 @@
#include "megbrain/exception.h"
#include "megbrain_config.h"

#if MGB_CUDA
#if MGB_ENABLE_OPR_MM
#include "zmq_rpc.h"
#include <future>

@@ -242,17 +242,11 @@ int _config::create_mm_server(const std::string& server_addr, int port) {
server_addr, port, std::make_unique<GroupServerProxy>());
}

#else

int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "CUDA suppport disable at compile time");
return 0;
}

#endif

/* ======================== Group Barrier ========================== */

/*! see definition : src/cpp/megbrain_config.h.
* Block until all ranks in the group reach this barrier
*/
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_assert(rank < size, "invalid rank %d", rank);
@@ -263,4 +257,18 @@ void _config::group_barrier(const std::string& server_addr,
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}

#else

int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
return 0;
}

void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

python_module/src/cpp/mm_handler.h  +1 -1

@@ -11,7 +11,7 @@

#include "megbrain_build_config.h"

#if MGB_CUDA
#if MGB_ENABLE_OPR_MM

#include "zmq_rpc.h"



python_module/src/swig/mgb.i  +1 -0

@@ -49,6 +49,7 @@ intb4 = _mgb.intb4
#include "plugin.h"
%}

%include "megbrain_build_config.h"
%include "comp_node.i"
%include "comp_graph.i"
%include "symbol_var.i"


python_module/test/integration/test_correctness.py  +17 -3

@@ -17,6 +17,7 @@ import megengine as mge
import megengine.functional as F
from megengine import jit, tensor
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.jit import SublinearMemoryConfig
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
from megengine.test import assertTensorClose
@@ -130,7 +131,9 @@ def update_model(model_path):
mge.save(checkpoint, model_path)


def run_test(model_path, use_jit, use_symbolic):
def run_test(
model_path, use_jit, use_symbolic, sublinear_memory_config=None, max_err=None,
):

"""
Load the model with test cases and run the training for one iter.
@@ -152,11 +155,16 @@ def run_test(model_path, use_jit, use_symbolic):
data.set_value(checkpoint["data"])
label.set_value(checkpoint["label"])

max_err = 1e-5
if max_err is None:
max_err = 1e-5

train_func = train
if use_jit:
train_func = jit.trace(train_func, symbolic=use_symbolic)
train_func = jit.trace(
train_func,
symbolic=use_symbolic,
sublinear_memory_config=sublinear_memory_config,
)

opt.zero_grad()
loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +191,9 @@ def test_correctness():
run_test(model_path, False, False)
run_test(model_path, True, False)
run_test(model_path, True, True)

# sublinear
config = SublinearMemoryConfig(genetic_nr_iter=10)
run_test(
model_path, True, True, sublinear_memory_config=config, max_err=1e-5,
)

python_module/test/integration/test_distributed.py  +89 -0

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import subprocess
import sys

import numpy as np


def worker(master_ip, master_port, world_size, rank, dev, trace):
import megengine.distributed as dist
import megengine.functional as F
from megengine import is_cuda_available
from megengine import jit
from megengine.module import Linear, Module
from megengine.optimizer import SGD

if not is_cuda_available():
return

class MLP(Module):
def __init__(self):
super().__init__()
self.fc0 = Linear(3 * 224 * 224, 500)
self.fc1 = Linear(500, 10)

def forward(self, x):
x = self.fc0(x)
x = F.relu(x)
x = self.fc1(x)
return x

dist.init_process_group(
master_ip=master_ip, master_port=3456, world_size=world_size, rank=rank, dev=dev
)
net = MLP()

opt = SGD(net.parameters(requires_grad=True), lr=0.02)

data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
label = np.random.randint(0, 10, size=(64,)).astype(np.int32)

jit.trace.enabled = trace

@jit.trace()
def train_func(data, label):
pred = net(data)
loss = F.cross_entropy_with_softmax(pred, label)
opt.backward(loss)
return loss

for i in range(5):
opt.zero_grad()
loss = train_func(data, label)
opt.step()


def start_workers(worker, world_size, trace=False):
def run_subproc(rank):
cmd = "from test.integration.test_distributed import worker\n"
cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
world_size, rank, rank, "True" if trace else "False"
)
cmd = ["python3", "-c", cmd]
ret = subprocess.run(
cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
)
assert ret.returncode == 0, "subprocess failed"

procs = []
for rank in range(world_size):
p = mp.Process(target=run_subproc, args=(rank,))
p.start()
procs.append(p)

for p in procs:
p.join()
assert p.exitcode == 0


def test_distributed():
start_workers(worker, 2, trace=True)
start_workers(worker, 2, trace=False)

python_module/test/unit/jit/test_jit.py  +12 -0

@@ -18,6 +18,7 @@ import megengine._internal as mgb
import megengine.module as M
from megengine import jit, tensor
from megengine.core.tensor import Tensor
from megengine.jit import SublinearMemoryConfig
from megengine.test import assertTensorClose


@@ -185,3 +186,14 @@ def test_dump_bn_fused():
mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
)


# Simply verify the options passed down
def test_sublinear():
config = SublinearMemoryConfig(genetic_nr_iter=10)

@jit.trace(symbolic=True, sublinear_memory_config=config)
def f(x):
return x + x

f([0.0])

src/CMakeLists.txt  +0 -2

@@ -31,10 +31,8 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT)
list(APPEND SOURCES ${SOURCES_})
endif()

set(MGB_DEF ${MGB_DEF} PARENT_SCOPE)
add_library(megbrain STATIC EXCLUDE_FROM_ALL ${SOURCES})
target_link_libraries(megbrain mgb_opr_param_defs)
target_compile_definitions(megbrain PUBLIC ${MGB_DEF})
target_include_directories(megbrain PUBLIC ${MGB_INC})

if(MGE_WITH_CUDA)


src/core/impl/graph/cg_impl.cpp  +2 -1

@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
static_infer_comp_seq_manager{owner},
grad_manager{owner},
#if MGB_ENABLE_SUBLINEAR
seq_modifier_for_sublinear_memory{owner},
seq_modifier_for_sublinear_memory{owner,
&(owner->options().sublinear_mem_cofig)},
#endif
#if MGB_ENABLE_MEMORY_SWAP
memory_swap_support{owner},


src/core/impl/graph/seq_sublinear_memory.cpp  +15 -21

@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
std::vector<std::future<void>> m_futures;
std::mutex m_mtx;

struct Config {
size_t thresh_nr_try = 10;
size_t genetic_nr_iter = 0;
size_t genetic_pool_size = 20;
double lb_memory = 0;
};
Config m_config;

/*!
* \brief check given thresh, and update states
* \return bottleneck value for given thresh
@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
public:
ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
: m_par_modifier{par} {
auto & m_config = m_par_modifier->m_config;
//! allow environmental variable to overwrite the setting
if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
m_config.thresh_nr_try = std::stoi(env);
m_config->thresh_nr_try = std::stoi(env);
}
if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
m_config.genetic_nr_iter = std::stoi(env);
m_config->genetic_nr_iter = std::stoi(env);
}
if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
auto psize = static_cast<size_t>(std::stoi(env));
mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
"invalid pool size %zu in genetic algorithm,", psize);
m_config.genetic_pool_size = psize;
m_config->genetic_pool_size = psize;
}
if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
m_config.lb_memory = std::stod(env) * 1024 * 1024;
m_config->lb_memory = std::stoi(env) * 1024 * 1024;
}
}

@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
invoke_search(thresh);
}

size_t NR_TRY = m_config.thresh_nr_try;
size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;

// search in linear space
auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {

void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
RNGxorshf rng(2333);
size_t POOL_SIZE = m_config.genetic_pool_size;
size_t NR_ITER = m_config.genetic_nr_iter;
size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;
auto mutation = [&](const SplitPointSet& sps) {
auto s = *sps;
size_t length = s.size();
@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
}

void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
size_t lower_bound = m_config.lb_memory;
size_t lower_bound = m_par_modifier->m_config->lb_memory;
if (m_min_bottleneck >= lower_bound)
return;
OprFootprint footprint;
@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
msg.push_back('\n');
msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
m_min_bottleneck * SIZE2MB));
if(!m_config.genetic_nr_iter) {
if(!m_par_modifier->m_config->genetic_nr_iter) {
msg.append(ssprintf(
"\nGenetic algorithm is currently DISABLED, "
"set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
"invalid planner concurrency: %zu", set);
planner_concur = set;
} else {
planner_concur = sys::get_cpu_count() / 2;
planner_concur = m_config->num_worker;
}

mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
}

SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
ComputingGraphImpl* owner)
: m_mem_opt(owner), m_owner_graph(owner) {}
ComputingGraphImpl* owner, Config* config_p)
: m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}

#endif // !MGB_ENABLE_SUBLINEAR



src/core/impl/graph/seq_sublinear_memory.h  +6 -1

@@ -12,6 +12,7 @@
#pragma once

#include "./memory_optimizer.h"
#include "megbrain/graph/cg.h"
#include "megbrain/utils/async_worker.h"

#if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
using SplitPointSet = std::shared_ptr<std::vector<size_t>>;

//! Config options
using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
Config* m_config;

//! get modifications to be taken under some specific constraints
class ModifyActionPlanner;

@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
}

public:
SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);

//! see memory_optimizer set_priority_before_opt
void set_priority_before_opt(const VarNodeArray& endpoints) {


src/core/include/megbrain/graph/cg.h  +10 -0

@@ -16,6 +16,7 @@
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/seq_comp_node_opt.h"
#include "megbrain/utils/event.h"
#include "megbrain/system.h"

#if MGB_ENABLE_JSON
#include "megbrain/utils/json.h"
@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! whether to enable sublinear memory optimization
bool enable_sublinear_memory_opt = false;

//! Control parameter for sublinear memory optimization
struct SublinearMemConfig {
int thresh_nr_try = 10;
int genetic_nr_iter = 0;
int genetic_pool_size = 20;
int lb_memory = 0;
int num_worker = sys::get_cpu_count() / 2;
} sublinear_mem_cofig;

//! do not re-profile to select best impl algo when input shape
//! changes (use previous algo)
bool no_profiling_on_shape_change = false;


src/core/test/sublinear_memory.cpp  +41 -51

@@ -504,57 +504,47 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
}

TEST(TestSublinearMemory, BadOpr) {
constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
auto old_value = getenv(KEY);
setenv(KEY, "50", 1);
MGB_TRY {
HostTensorGenerator<> gen;
auto cn = CompNode::load("xpu0");
constexpr size_t N = 1024, Scale = 2;
auto host_x = gen({N}, cn);
for (bool bad : {false, true}) {
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
bad_var = SublinearBadOpr::make(x, bad, Scale),
y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
y1 = SublinearBadOpr::make(y0, false, N * Scale),
y = y1 + 1,
z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
set_priority(y0, 0);
set_priority(y1, 1);
set_priority(y, 2);
set_priority(z, 3);
graph->options().graph_opt_level = 0;
graph->options().enable_sublinear_memory_opt = 1;
auto func = graph->compile({{y, {}}, {z, {}}});
auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
// bottleneck:
// if bad : y = y1 + 1, bad_var should be saved to calculate
// z later, total memory usage is
// N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
// else : bad_var = BadOpr(x), total memory usage is
// N(x) + N * scale(bad_var), bad_var would be recomputed
// when calculate z = reduce(bad_var)
size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
size_t nr_bad_opr = 0;
auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
++ nr_bad_opr;
}
return true;
};
func->iter_opr_seq(count_up);
ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
}
} MGB_FINALLY(
if (old_value) {
setenv(KEY, old_value, 1);
} else {
unsetenv(KEY);
}
);
HostTensorGenerator<> gen;
auto cn = CompNode::load("xpu0");
constexpr size_t N = 1024, Scale = 2;
auto host_x = gen({N}, cn);
for (bool bad : {false, true}) {
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
bad_var = SublinearBadOpr::make(x, bad, Scale),
y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
y1 = SublinearBadOpr::make(y0, false, N * Scale),
y = y1 + 1,
z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
set_priority(y0, 0);
set_priority(y1, 1);
set_priority(y, 2);
set_priority(z, 3);
graph->options().graph_opt_level = 0;
graph->options().enable_sublinear_memory_opt = 1;
graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
auto func = graph->compile({{y, {}}, {z, {}}});
auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
// bottleneck:
// if bad : y = y1 + 1, bad_var should be saved to calculate
// z later, total memory usage is
// N * sclae * 2(bad_var and y1) + 1 (immutable tensor 1)
// else : bad_var = BadOpr(x), total memory usage is
// N(x) + N * scale(bad_var), bad_var would be recomputed
// when calculate z = reduce(bad_var)
size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
size_t nr_bad_opr = 0;
auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
++ nr_bad_opr;
}
return true;
};
func->iter_opr_seq(count_up);
ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
}
}

#else

