GitOrigin-RevId: cb73e62b19
tags/v1.6.0-rc1
@@ -0,0 +1,114 @@
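"""Integration test for DTR (dynamic tensor rematerialization).

Trains a 1202-layer CIFAR-style ResNet on random CIFAR-10-shaped data for a
few iterations with mge.dtr.enable(), so that activations evicted under memory
pressure have to be regenerated (recomputed) on demand.
"""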
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
import megengine.tensor as tensor
from megengine.autodiff import GradManager
from megengine.data import DataLoader, RandomSampler, transform
from megengine.data.dataset import CIFAR10

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, M.Linear) or isinstance(m, M.Conv2d):
        M.init.msra_normal_(m.weight)


# per-channel CIFAR-10 mean/std on the 0-255 pixel scale (unused in this test)
mean = [125.3, 123.0, 113.9]
std = [63.0, 62.1, 66.7]

class BasicBlock(M.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = M.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = M.BatchNorm2d(planes)
        self.conv2 = M.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = M.BatchNorm2d(planes)
        self.shortcut = M.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = M.Sequential(
                M.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                M.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(M.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16
        self.conv1 = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = M.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = M.Linear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return M.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.mean(3).mean(2)  # global average pooling over H and W
        out = self.linear(out)
        return out

@pytest.mark.require_ngpu(1)
def test_dtr_resnet1202():
    batch_size = 64
    resnet1202 = ResNet(BasicBlock, [200, 200, 200])
    opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    gm = GradManager().attach(resnet1202.parameters())

    def train_func(data, label, *, net, gm):
        net.train()
        with gm:
            pred = net(data)
            loss = F.loss.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    # enable dynamic tensor rematerialization: dropped tensors are recomputed
    # on demand instead of being kept resident in GPU memory
    mge.dtr.enable()

    # synthetic CIFAR-10-shaped batch; real data is not needed for this test
    data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
    label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")

    for step in range(10):
        opt.clear_grad()
        _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
        opt.step()
        loss.item()  # pull the loss to host, forcing queued work to complete
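
The interpreter changes below replace the recursive recompute() call chain with an explicit stack: regenerate() now pushes the producer's ApplyOp onto m_apply_stack, and flush_apply_stack() drains that stack iteratively, so restoring a long chain of dropped tensors no longer nests native calls. On the Python side, enabling DTR is a one-liner as in the test above; the memory-budget knob in this short sketch comes from MegEngine's DTR documentation rather than from this patch, so treat it as an assumption:

import megengine as mge

# Assumed knob (documented for MegEngine DTR, not part of this patch):
# bound resident tensor memory before enabling rematerialization.
mge.dtr.eviction_threshold = "5GB"
mge.dtr.enable()  # evicted tensors will be recomputed on demand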
@@ -615,13 +615,15 @@ void ChannelImpl::release_tensor(TensorInfo* dest) {
 }
 
 void ChannelImpl::regenerate(TensorInfo* dest) {
-    RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
     if (dest->evict_type == EvictType::DROP) {
-        recompute(dest->producer);
+        auto &&path = dest->producer;
+        m_apply_stack.push({ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest});
+        if (!m_applying) flush_apply_stack();
     } else if (dest->evict_type == EvictType::SWAP) {
+        RECORD_EVENT(TensorCommandEvent, dest->id, TensorCommandEvent::ReGen);
         produce_tensor(dest, Tensor::make(dest->h_value));
+        RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
     }
-    RECORD_EVENT(TensorCommandFinishEvent, dest->id, TensorCommandFinishEvent::ReGen);
 }
 
 void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
@@ -635,17 +637,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
         MemoryDesc desc;
     };
     SmallVector<TensorWithDesc> inputs;
-    // SmallVector<TensorPtr> tensor_inputs;
-    if (state.options.enable_dtr_auto_drop) {
-        m_dtr.pin(cmd.inputs);
-    }
-    for (auto i : cmd.inputs) {
-        if (!i->ptr && i->evict_type != EvictType::NONE) {
-            regenerate(i);
-        }
-        m_dtr.update_used_time(i);
-    }
-    // tensor_inputs.reserve(cmd.inputs.size());
     inputs.reserve(cmd.inputs.size());
     // refcnt == 1, owners: [TensorInfo::ptr]
     for (auto i : cmd.inputs) {
@@ -781,20 +772,48 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     // End profiling operator
 }
 
-void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+void ChannelImpl::flush_apply_stack() {
+    m_applying = true;
     auto& state = get_worker_state();
-    do_apply_op(ApplyOp{path->id, path->op, path->inputs, path->outputs, {}});
-    for (size_t i = 0;i < path->outputs.size();i ++) {
-        auto&& o = path->outputs[i];
-        if (o) {
-            o->recompute_times ++;
-            if (!o->ptr) {
-                if (state.options.enable_dtr_auto_drop) {
+    while (!m_apply_stack.empty()) {
+        auto& [cmd, idx, recomp] = m_apply_stack.top();  // cmd.inputs[0 .. idx-1] are already in memory
+        if (idx == 0) {
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.pin(cmd.inputs);
+            }
+            if (recomp) {
+                RECORD_EVENT(TensorCommandEvent, recomp->id, TensorCommandEvent::ReGen);
+            }
+        }
+        bool regen = false;
+        for (size_t i = idx; i < cmd.inputs.size(); i++) {
+            auto&& p = cmd.inputs[i];
+            if (state.options.enable_dtr_auto_drop) {
+                m_dtr.update_used_time(p);
+            }
+            if (!p->ptr && p->evict_type != EvictType::NONE) {
+                idx = i + 1;
+                regenerate(p);  // pushes the producer's ApplyOp onto the stack
+                regen = true;
+                break;
+            }
+        }
+        if (regen) continue;
+        // all required input tensors are now in memory
+        auto cmd_backup = cmd;
+        auto recomp_backup = recomp;
+        m_apply_stack.pop();
+        do_apply_op(cmd_backup);
+        if (recomp_backup) {
+            RECORD_EVENT(TensorCommandFinishEvent, recomp_backup->id, TensorCommandFinishEvent::ReGen);
+            for (auto o : cmd_backup.outputs) {
+                if (o) {
                     m_dtr.update_dsu_after_recompute(o);
                 }
             }
         }
     }
+    m_applying = false;
 }
 
 bool ChannelImpl::auto_evict(size_t force_num) {
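
For intuition, flush_apply_stack() performs an iterative depth-first recomputation: each stack entry records how many inputs have already been checked (idx), so when a missing input pushes its producer's ApplyOp onto the stack, the parent entry later resumes where it stopped instead of rescanning from the start. The following is a simplified, self-contained Python sketch of that control flow only (illustrative; Op, in_memory, and the one-output-per-op shortcut are not the interpreter's real types):

class Op:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)  # producer ops of this op's inputs
        self.in_memory = False      # True once this op's output is resident

    def compute(self):
        self.in_memory = True
        print("recomputed", self.name)


def flush(apply_stack):
    while apply_stack:
        op, idx = apply_stack[-1]  # peek; inputs[0:idx] were already checked
        missing = next(
            (i for i in range(idx, len(op.inputs)) if not op.inputs[i].in_memory),
            None,
        )
        if missing is not None:
            apply_stack[-1] = (op, missing + 1)          # resume past this input later
            apply_stack.append((op.inputs[missing], 0))  # recompute the dependency first
            continue
        apply_stack.pop()
        op.compute()  # every input is resident: run the op


a = Op("a")
b = Op("b", [a])
c = Op("c", [b])
flush([(c, 0)])  # prints: recomputed a, recomputed b, recomputed c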
@@ -997,7 +1016,8 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             RECORD_EVENT(TensorCommandFinishEvent, cmd.dest->id, TensorCommandFinishEvent::Put);
             sample_on_device(cmd.dest->desc.comp_node, false);
         } else if constexpr (std::is_same_v<T, ApplyOp>) {
-            do_apply_op(cmd);
+            m_apply_stack.push({cmd, 0, nullptr});
+            flush_apply_stack();
             for (size_t i = 0; i < cmd.outputs.size(); ++i) {
                 auto output = cmd.outputs[i];
                 if (output == nullptr) {
@@ -14,10 +14,10 @@
 #include <deque>
 #include <future>
 #include <list>
+#include <stack>
 #include <thread>
 #include <unordered_set>
 #include <variant>
 #include "megbrain/comp_node.h"
 #include "megbrain/utils/mempool.h"
 #include "megbrain/imperative/interpreter.h"
@@ -103,8 +103,8 @@ private:
     void release_tensor(TensorInfo* dest);
 
     void regenerate(TensorInfo* dest);
-    void recompute(TensorInfo::ComputePath* path);
     void do_apply_op(const ApplyOp& cmd);
+    void flush_apply_stack();
 
     std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> init_output_and_workspace(
             const OpDef& def,
@@ -149,7 +149,8 @@ private: | |||||
std::exception_ptr m_worker_exc; | std::exception_ptr m_worker_exc; | ||||
std::function<void(std::string, std::string)> m_profile_dump_callback; | std::function<void(std::string, std::string)> m_profile_dump_callback; | ||||
size_t m_storage_id = 0; | size_t m_storage_id = 0; | ||||
std::stack<std::tuple<ApplyOp, size_t, TensorInfo*>> m_apply_stack; | |||||
bool m_applying = false; | |||||
bool m_closed = false; | bool m_closed = false; | ||||
struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> { | struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> { | ||||