Browse Source

feat(mge): enable memory swap and drop/recomputation

GitOrigin-RevId: c56c87c88b
release-1.2
Megvii Engine Team 4 years ago
parent
commit
7aa54b0ec6
9 changed files with 463 additions and 14 deletions
  1. +12
    -0
      imperative/python/megengine/core/tensor/raw_tensor/__init__.py
  2. +9
    -0
      imperative/python/megengine/core/tensor/tensor.py
  3. +9
    -0
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  4. +27
    -0
      imperative/python/megengine/jit/tracing.py
  5. +12
    -1
      imperative/python/src/imperative_rt.cpp
  6. +124
    -0
      imperative/python/test/integration/test_converge_with_swap_and_drop.py
  7. +196
    -11
      imperative/src/impl/interpreter_impl.cpp
  8. +69
    -2
      imperative/src/impl/interpreter_impl.h
  9. +5
    -0
      imperative/src/include/megbrain/imperative/interpreter.h

+ 12
- 0
imperative/python/megengine/core/tensor/raw_tensor/__init__.py View File

@@ -12,7 +12,10 @@ import numpy as np

from ..._imperative_rt import CompNode, DeviceTensorND
from ..._imperative_rt.imperative import (
_drop,
_get_dev_tensor,
_swap_in,
_swap_out,
apply_op,
delete,
get_device,
@@ -63,6 +66,15 @@ class RawTensor(TensorBase):
def _dev_tensor(self):
return _get_dev_tensor(self._handle)

def _drop(self):
_drop(self._handle)

def _swap_in(self):
_swap_in(self._handle)

def _swap_out(self):
_swap_out(self._handle)

def __repr__(self):
return "{}({}, device='{}')".format(
type(self).__qualname__, repr(self.numpy()), self.device


+ 9
- 0
imperative/python/megengine/core/tensor/tensor.py View File

@@ -53,6 +53,15 @@ class Tensor(TensorBase):
def numpy(self):
return self._data.numpy()

def _drop(self):
self._data._drop()

def _swap_in(self):
self._data._swap_in()

def _swap_out(self):
self._data._swap_out()


class ApplyContext:
__slots__ = ("inputs", "outputs", "key")


+ 9
- 0
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -473,6 +473,15 @@ class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase):
def numpy(self):
return self.__wrapped__.numpy()

def _drop(self):
self.__wrapped__._drop()

def _swap_in(self):
self.__wrapped__._swap_in()

def _swap_out(self):
self.__wrapped__._swap_out()


class TensorWrapper(GenericTensorWrapper):
def __init__(self, data, dtype=None, device=None):


+ 27
- 0
imperative/python/megengine/jit/tracing.py View File

@@ -966,6 +966,15 @@ class CompiledTensorProxy(RawTensor):
self.__data = self.__info.data_reader.get_value()
return self.__data

def _drop(self):
return

def _swap_in(self):
return

def _swap_out(self):
return

def __del__(self):
if self.__info.shape_read and self.__shape is not None:
self.__info.shape_reader.drop_value()
@@ -1001,6 +1010,15 @@ class LazyEvalTensor(RawTensor):
ret = ret.squeeze()
return ret

def _drop(self):
return

def _swap_in(self):
return

def _swap_out(self):
return

def _dev_tensor(self):
raise RuntimeError("cannot access data during symbolic tracing")

@@ -1042,6 +1060,15 @@ class TraceMixin:
active_trace._require_data(self.__handle)
return super()._dev_tensor()

def _drop(self):
return

def _swap_in(self):
return

def _swap_out(self):
return


class TracedRawTensor(TraceMixin, RawTensor):
pass


+ 12
- 1
imperative/python/src/imperative_rt.cpp View File

@@ -68,6 +68,15 @@ void init_imperative_rt(py::module m) {
.def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
return self.del(handle);
})
.def("_swap_in", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.swap_in(handle);
})
.def("_swap_out", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.swap_out(handle);
})
.def("_drop", [](Interpreter::Channel& self, Interpreter::Handle handle) {
self.drop(handle);
})
.def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) {
PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE);
return py::reinterpret_steal<py::object>(optr);
@@ -76,6 +85,8 @@ void init_imperative_rt(py::module m) {
.def("get_device", &Interpreter::Channel::get_device)
.def("get_shape", &Interpreter::Channel::get_shape)
.def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
.def("_set_swap_flag", &Interpreter::Channel::set_swap_flag)
.def("_set_drop_flag", &Interpreter::Channel::set_drop_flag)
.def("apply_op", &Interpreter::Channel::apply_op)
.def("config_async_level", &Interpreter::Channel::config_async_level)
.def("get_async_level", &Interpreter::Channel::get_async_level)
@@ -84,7 +95,7 @@ void init_imperative_rt(py::module m) {
std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
std::move(ch), py::return_value_policy::move, {});
for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) {
for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level", "_drop", "_swap_in", "_swap_out", "_set_drop_flag", "_set_swap_flag"}) {
m.attr(name) = m.attr("interpreter").attr(name);
}



+ 124
- 0
imperative/python/test/integration/test_converge_with_swap_and_drop.py View File

@@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.core._imperative_rt.imperative import _set_drop_flag, _set_swap_flag
from megengine.module import Linear, Module
from megengine.optimizer import SGD

batch_size = 64
data_shape = (batch_size, 2)
label_shape = (batch_size,)


def minibatch_generator():
while True:
inp_data = np.zeros((batch_size, 2))
label = np.zeros(batch_size, dtype=np.int32)
for i in range(batch_size):
# [x0, x1], sampled from U[-1, 1]
inp_data[i, :] = np.random.rand(2) * 2 - 1
label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
yield inp_data.astype(np.float32), label.astype(np.int32)


def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
""" Calculate precision for given data and prediction.

:type data: [[x, y], ...]
:param data: Input data
:type pred: [[x_pred, y_pred], ...]
:param pred: Network output data
"""
correct = 0
assert len(data) == len(pred)
for inp_data, pred_output in zip(data, pred):
label = 0 if np.prod(inp_data) < 0 else 1
pred_label = np.argmax(pred_output)
if pred_label == label:
correct += 1
return float(correct) / len(data)


class XORNet(Module):
def __init__(self):
self.mid_layers = 14
self.num_class = 2
super().__init__()

self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)

self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)

def forward(self, x):
y = self.fc0(x)
x._swap_out()
x = F.tanh(y)
y = self.fc1(x)
x = F.tanh(y)
x = self.fc2(x)
y = (x + x) / 2 # in order to test drop()
y._drop()
return y


def test_training_converge_with_swap_and_drop():
_set_swap_flag(True)
_set_drop_flag(True)

net = XORNet()
opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
gm = ad.GradManager().attach(net.parameters())

def train(data, label):
with gm:
pred = net(data)
loss = F.nn.cross_entropy(pred, label)
gm.backward(loss)
return loss

def infer(data):
return net(data)

train_dataset = minibatch_generator()
losses = []

for data, label in itertools.islice(train_dataset, 2000):
data = Tensor(data, dtype=np.float32)
label = Tensor(label, dtype=np.int32)
opt.clear_grad()
loss = train(data, label)
opt.step()
losses.append(loss.numpy())

assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"

ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = infer(Tensor(data)).numpy()
precision = calculate_precision(data, pred)
assert precision == 1.0, "Test precision must be high enough, get {}".format(
precision
)

_set_swap_flag(False)
_set_drop_flag(False)

+ 196
- 11
imperative/src/impl/interpreter_impl.cpp View File

@@ -52,9 +52,37 @@ void ChannelImpl::del(void* handle) {
m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
}

void ChannelImpl::swap_in(void* handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(SwapIn{reinterpret_cast<TensorInfo*>(handle)});
}
}

void ChannelImpl::swap_out(void* handle) {
if (m_enable_evict & SWAP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(SwapOut{reinterpret_cast<TensorInfo*>(handle)});
}
}

void ChannelImpl::drop(void* handle) {
if (m_enable_evict & DROP) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
m_worker.add_task(Drop{reinterpret_cast<TensorInfo*>(handle)});
}
}

SmallVector<void*> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<void*>& inputs) {
for (auto i : inputs) {
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
"invalid handle: %p", i);
}
SmallVector<TensorInfo*> input_infos;
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
@@ -75,7 +103,8 @@ SmallVector<void*> ChannelImpl::apply_op(
SmallVector<void*> outputs;
// FIXME: remove this check when op check is correct
bool validated_bkp = true;
for (auto&& desc : output_descs) {
for (size_t i = 0;i < output_descs.size();i ++) {
auto&& desc = output_descs[i];
if (desc.layout.ndim == 0) {
validated_bkp = false;
}
@@ -85,6 +114,18 @@ SmallVector<void*> ChannelImpl::apply_op(
cmd.outputs.push_back(info);
outputs.push_back(info);
}
if (m_enable_evict & DROP) {
for (auto out : cmd.outputs) {
out->path.op = cmd.op;
for (auto out_ : cmd.outputs) {
out->path.outputs.push_back(m_st.at(out_));
}
for (auto inp : cmd.inputs) {
out->path.inputs.push_back(m_st.at(inp));
inp->path.dep_outputs.push_back(m_st.at(out));
}
}
}
m_worker.add_task(std::move(cmd));
if (!(validated && validated_bkp) && m_async_level == 1) {
sync();
@@ -192,11 +233,18 @@ int ChannelImpl::get_async_level() {

TensorInfo* ChannelImpl::alloc() {
MGB_LOCK_GUARD(m_mutex);
return m_pool.alloc();
auto info = m_pool.alloc();
m_st.insert(info);
return info;
}

void ChannelImpl::free(TensorInfo* ptr) {
MGB_LOCK_GUARD(m_mutex);
if (ptr->path.dep_outputs.size() > 0) {
remove_dep(ptr);
}
m_st.erase(ptr);
mgb_assert(ptr->allow_delete, "delete before ref_cnt = 0");
m_pool.free(ptr);
}

@@ -204,15 +252,136 @@ ChannelImpl::~ChannelImpl() {
close();
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (m_waitee == dest) {
m_cv.notify_all();
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
if (notice) {
MGB_LOCK_GUARD(m_mutex);
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
if (m_waitee == dest) {
m_cv.notify_all();
}
} else {
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->ptr = std::move(ptr);
}
}

void ChannelImpl::do_swap_out(TensorInfo* dest) {
if (dest->evict_type == DROP) {
mgb_log_warn("the evict type of tensor %p was set to DROP, this SWAP operation will be ignored", dest);
return;
}
if (!dest->ptr) {
return;
}
dest->evict_type = SWAP;
dest->value_fetched = false;
// TODO: swap in parallel
dest->h_value.copy_from(dest->ptr->dev_tensor()).sync();
dest->ptr.reset();
}

void ChannelImpl::do_swap_in(TensorInfo* dest) {
if (dest->ptr) {
return;
}
if (dest->h_value.empty()) {
mgb_log_error("backup of the tensor %p not found", dest);
return;
}
produce_tensor(dest, Tensor::make(dest->h_value), false);
dest->evict_type = NONE;
}

void ChannelImpl::remove_dep(TensorInfo* dest) {
for (auto i : dest->path.dep_outputs) {
auto out_ptr = i.lock();
if (out_ptr) {
regenerate(out_ptr.get(), true);
}
}
}

void ChannelImpl::do_drop(TensorInfo* dest) {
if (dest->evict_type == SWAP) {
mgb_log_warn("the evict type of tensor %p was set to SWAP, this DROP operation will be ignored", dest);
return;
}
if (!dest->path.op) {
mgb_log_warn("the input that produced tensor %p has been deleted, this drop operation will be ignored", dest);
return;
}
if (dest->recompute_times >= m_max_recompute_time) {
mgb_log_warn("the recomputation time for tensor %p exceeds the limit, this drop operation will be ignored", dest);
return;
}
if (!dest->ptr) {
return;
}
dest->evict_type = DROP;
dest->value_fetched = false;
dest->ptr.reset();
}

void ChannelImpl::set_swap_flag(bool flag) {
if (flag) {
m_enable_evict |= SWAP;
} else {
m_enable_evict &= ~SWAP;
}
}

void ChannelImpl::set_drop_flag(bool flag) {
if (flag) {
m_enable_evict |= DROP;
} else {
m_enable_evict &= ~DROP;
}
}

void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
if (!info->ptr && info->evict_type != NONE) {
if (info->evict_type == SWAP) {
do_swap_in(info);
} else {
mgb_assert(info->evict_type == DROP);
mgb_assert(info->path.op, "recomputation path not found");
auto path = info->path;
SmallVector<TensorPtr> inputs;
inputs.reserve(path.inputs.size());
for (auto i : path.inputs) {
mgb_assert(i, "invalid history input");
if (!i->ptr) {
regenerate(i.get(), must_drop);
}
inputs.push_back(i->ptr);
}
auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs);
for (size_t i = 0; i < outputs.size(); i ++) {
auto out_ptr = path.outputs[i].lock();
if (out_ptr) {
out_ptr->recompute_times ++;
if (!out_ptr->ptr && out_ptr->evict_type == DROP) {
produce_tensor(out_ptr.get(), std::move(outputs[i]), false);
}
}
}
}
}
if (must_drop) {
if (info->path.op) {
info->path.op.reset();
info->path.inputs.clear();
if (info->evict_type == DROP) {
info->evict_type = NONE;
}
}
}
}

@@ -227,6 +396,11 @@ void ChannelImpl::process_one_task(Command& cmd) {
SmallVector<TensorPtr> tensor_inputs;
tensor_inputs.reserve(cmd.inputs.size());
for (auto i : cmd.inputs) {
if (m_enable_evict && i->evict_type != NONE) {
if (!i->ptr) {
regenerate(i);
}
}
mgb_assert(i->ptr, "Invalid input tensor ptr!");
tensor_inputs.push_back(i->ptr);
}
@@ -238,6 +412,11 @@ void ChannelImpl::process_one_task(Command& cmd) {
} else if constexpr (std::is_same_v<T, Del>) {
free(cmd.dest);
} else if constexpr (std::is_same_v<T, GetValue>) {
if (m_enable_evict && cmd.dest->evict_type != NONE) {
if (!cmd.dest->ptr) {
regenerate(cmd.dest);
}
}
mgb_assert(cmd.dest->ptr, "Invalid tensor ptr!");
cmd.dest->ptr->fetch_value();
MGB_LOCK_GUARD(m_mutex);
@@ -245,6 +424,12 @@ void ChannelImpl::process_one_task(Command& cmd) {
if (m_waitee == cmd.dest) {
m_cv.notify_all();
}
} else if constexpr (std::is_same_v<T, SwapIn>) {
do_swap_in(cmd.dest);
} else if constexpr (std::is_same_v<T, SwapOut>) {
do_swap_out(cmd.dest);
} else if constexpr (std::is_same_v<T, Drop>) {
do_drop(cmd.dest);
} else {
static_assert(!std::is_same_v<T, T>);
}


+ 69
- 2
imperative/src/impl/interpreter_impl.h View File

@@ -24,11 +24,34 @@ struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};

enum EvictType {
NONE = 0,
SWAP = 1,
DROP = 2,
};

struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

struct TensorInfo {
TensorPtr ptr;
LogicalTensorDesc desc;
bool value_fetched = false;
bool invalid = false;
bool allow_delete = false;

EvictType evict_type = NONE;

HostTensorND h_value;
size_t locked = 0;
size_t recompute_times = 0;
struct ComputePath {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfoPtr> inputs;
SmallVector<std::weak_ptr<TensorInfo>> outputs;
SmallVector<std::weak_ptr<TensorInfo>> dep_outputs;
} path;
};

struct Put {
@@ -46,10 +69,24 @@ struct Del {
struct GetValue {
TensorInfo* dest;
};

struct SwapIn {
TensorInfo* dest;
};
struct SwapOut {
TensorInfo* dest;
};
struct Drop {
TensorInfo* dest;
};

using Command = std::variant<Put,
ApplyOp,
Del,
GetValue>;
GetValue,
SwapIn,
SwapOut,
Drop>;

struct ChannelImpl : Interpreter::Channel {
ChannelImpl() : m_worker(this) {}
@@ -59,6 +96,9 @@ struct ChannelImpl : Interpreter::Channel {
Handle put(const DeviceTensorND& value) override;

void del(Handle) override;
void swap_in(Handle) override;
void swap_out(Handle) override;
void drop(Handle) override;

SmallVector<Handle> apply_op(
std::shared_ptr<OpDef> op,
@@ -73,6 +113,8 @@ struct ChannelImpl : Interpreter::Channel {

void sync() override;
void close() override;
void set_swap_flag(bool) override;
void set_drop_flag(bool) override;

void config_async_level(int level) override;
int get_async_level() override;
@@ -80,12 +122,17 @@ struct ChannelImpl : Interpreter::Channel {
private:
TensorInfo* alloc();
void free(TensorInfo*);
void remove_dep(TensorInfo*);

void process_one_task(Command&);

void check_worker_exc_unsafe();

void produce_tensor(TensorInfo* dest, TensorPtr ptr);
void produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice);
void do_swap_out(TensorInfo* dest);
void do_swap_in(TensorInfo* dest);
void do_drop(TensorInfo* dest);
void regenerate(TensorInfo* dest, bool must_drop);

std::mutex m_mutex;
std::condition_variable m_cv;
@@ -93,6 +140,7 @@ private:
std::unordered_set<Handle> m_valid_handle;
TensorInfo* m_waitee = nullptr;
std::exception_ptr m_worker_exc;
size_t m_enable_evict = 0;

struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
WorkQueue(ChannelImpl* owner) : m_owner(owner) {}
@@ -103,11 +151,30 @@ private:
ChannelImpl* m_owner;
} m_worker;

struct SharedTensorInfoMap {
void insert(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
tmap.emplace(info, TensorInfoPtr{info, [](TensorInfo* ptr){ ptr->allow_delete = true;}});
}
void erase(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
tmap.erase(info);
}
TensorInfoPtr at(TensorInfo* info) {
MGB_LOCK_GUARD(mtx);
return tmap.at(info);
}
private:
std::mutex mtx;
std::unordered_map<TensorInfo*, TensorInfoPtr> tmap;
}m_st;
//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;
int m_max_recompute_time = 1;
};

} // namespace mgb::imperative::interpreter::intl

+ 5
- 0
imperative/src/include/megbrain/imperative/interpreter.h View File

@@ -25,6 +25,9 @@ struct Interpreter {
virtual Handle put(const DeviceTensorND& value) = 0;

virtual void del(Handle) = 0;
virtual void swap_in(Handle) = 0;
virtual void swap_out(Handle) = 0;
virtual void drop(Handle) = 0;

virtual SmallVector<Handle> apply_op(
std::shared_ptr<OpDef> op,
@@ -39,6 +42,8 @@ struct Interpreter {

virtual void sync() = 0;
virtual void close() = 0;
virtual void set_swap_flag(bool) = 0;
virtual void set_drop_flag(bool) = 0;

virtual void config_async_level(int level) = 0;
virtual int get_async_level() = 0;


Loading…
Cancel
Save