
fix(tensor): del valid tensors when compnode finalizing

GitOrigin-RevId: bace1f2b51
release-1.4
Megvii Engine Team, 4 years ago
commit b6dc4c824d
6 changed files with 73 additions and 24 deletions
  1. imperative/python/megengine/__init__.py (+2, -1)
  2. imperative/python/src/tensor.cpp (+14, -6)
  3. imperative/python/test/unit/core/test_interpreter.py (+14, -0)
  4. imperative/src/impl/event_pool.cpp (+1, -0)
  5. imperative/src/impl/interpreter/interpreter_impl.cpp (+37, -14)
  6. imperative/src/impl/interpreter/interpreter_impl.h (+5, -3)

imperative/python/megengine/__init__.py (+2, -1)

@@ -71,6 +71,7 @@ if sys.platform == "win32":

kernel32.SetErrorMode(old_error_mode)

from .core._imperative_rt.core2 import close as _close
from .core._imperative_rt.core2 import full_sync as _full_sync
from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
@@ -90,7 +91,7 @@ _set_fork_exec_path_for_timed_func(
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()

atexit.register(_full_sync)
atexit.register(_close)

del _set_fork_exec_path_for_timed_func



imperative/python/src/tensor.cpp (+14, -6)

@@ -897,6 +897,11 @@ void init_tensor(py::module m) {
}
}

static constexpr auto sync_py_task_q = []{
py::gil_scoped_release _;
py_task_q.wait_all_task_finish();
};

m.def("set_option",
[](std::string name, size_t value){ interpreter_for_py->set_option(name, value); });
m.def("get_option",
@@ -928,16 +933,19 @@ void init_tensor(py::module m) {
m.def("sync",
[]() {
interpreter_for_py->sync();
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
sync_py_task_q();
});
m.def("full_sync",
[]() {
interpreter_for_py->sync();
CompNode::sync_all();
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
sync_py_task_q();
});
m.def("close",
[]() {
interpreter_for_py->close();
sync_py_task_q();
});

py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
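
The tensor.cpp hunk replaces the per-binding py::call_guard<py::gil_scoped_release>() with a shared sync_py_task_q helper that releases the GIL only while draining the Python task queue, and reuses it from sync, full_sync, and the new close binding. Below is a minimal pybind11 sketch of that pattern; the module name, DemoTaskQueue, and demo_queue are illustrative placeholders, not MegEngine's API.

// Sketch of the shared shutdown helper pattern used above; placeholder names only.
#include <pybind11/pybind11.h>
namespace py = pybind11;

namespace {
struct DemoTaskQueue {
    void wait_all_task_finish() { /* block until queued host callbacks have run */ }
} demo_queue;

// Shared tail for sync/full_sync/close: release the GIL only while waiting,
// so pending Python callbacks on the queue can still acquire it and finish.
void sync_py_task_q() {
    py::gil_scoped_release release;
    demo_queue.wait_all_task_finish();
}
} // namespace

PYBIND11_MODULE(shutdown_sketch, m) {
    m.def("sync", [] { sync_py_task_q(); });
    m.def("full_sync", [] { /* also sync devices here */ sync_py_task_q(); });
    m.def("close", [] { /* tear down the channel first */ sync_py_task_q(); });
}

Compared with the previous py::call_guard on the whole binding, the GIL is now released only inside the helper, while the interpreter calls themselves run with it held.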


imperative/python/test/unit/core/test_interpreter.py (+14, -0)

@@ -1,3 +1,6 @@
import subprocess
import sys

import numpy as np
import pytest

@@ -76,3 +79,14 @@ def test_swap_drop_basic():
    z.numpy()
    _set_swap_flag(False)
    _set_drop_flag(False)


def test_finalize():
    prog = """
import megengine
with megengine.core.option("enable_host_compute", 0):
    x = megengine.tensor(0)
    y = x + 1
    y.numpy()
"""
    subprocess.check_call([sys.executable, "-c", prog])

imperative/src/impl/event_pool.cpp (+1, -0)

@@ -67,6 +67,7 @@ std::shared_ptr<void> EventPool::on_comp_node_finalize() {
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
m_cn2pool.clear();
return {};
}
EventPool::~EventPool() {
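
The single added line clears the per-comp-node pool map once every event is confirmed freed, so no cached per-device state survives comp node finalization. A generic sketch of such a finalize hook follows; Pool, EventRegistry, and the string device keys are placeholders, not MegEngine's types.

// Sketch of a comp-node-finalize cleanup hook; placeholder types only.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct Pool {
    std::vector<int> live_events;
    void assert_all_freed() const { assert(live_events.empty()); }
};

class EventRegistry {
public:
    Pool& pool_for(const std::string& device) { return m_cn2pool[device]; }

    // Called when comp nodes begin finalizing: verify nothing is still in
    // flight, then drop the map so cached per-device resources are released.
    void on_comp_node_finalize() {
        for (auto&& kv : m_cn2pool) kv.second.assert_all_freed();
        m_cn2pool.clear();
    }

private:
    std::unordered_map<std::string, Pool> m_cn2pool;
};

int main() {
    EventRegistry registry;
    registry.pool_for("gpu0");           // create some per-device state
    registry.on_comp_node_finalize();    // teardown before device shutdown
}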


imperative/src/impl/interpreter/interpreter_impl.cpp (+37, -14)

@@ -33,6 +33,7 @@ Interpreter& Interpreter::inst() {
}

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
mgb_assert(check_available(), "Channel already closed");
auto info = alloc();
info->desc.layout = value.layout();
info->desc.comp_node = value.comp_node();
@@ -47,6 +48,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
}

Handle ChannelImpl::put(const DeviceTensorND& data) {
mgb_assert(check_available(), "Channel already closed");
auto info = alloc();
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
@@ -58,6 +60,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
}

void ChannelImpl::del(Handle handle) {
if (!check_available()){
return;
}
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_valid_handle.erase(handle);
@@ -65,6 +70,7 @@ void ChannelImpl::del(Handle handle) {
}

void ChannelImpl::swap_in(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -74,6 +80,7 @@ void ChannelImpl::swap_in(Handle handle) {
}

void ChannelImpl::swap_out(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -83,6 +90,7 @@ void ChannelImpl::swap_out(Handle handle) {
}

void ChannelImpl::drop(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_drop) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -201,6 +209,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) {
mgb_assert(check_available(), "Channel already closed");
for (auto i : inputs) {
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
"invalid handle: %p", i);
@@ -237,6 +246,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
}

HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
// TODO: maybe get_value should be done on host. i.e. delete GetValue
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
@@ -269,6 +279,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
}

TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -296,6 +307,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
}

DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -308,6 +320,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
}

CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -320,6 +333,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
}

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -342,6 +356,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
}

void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<SyncStartEvent>();
@@ -356,14 +371,26 @@ void ChannelImpl::sync() {
}

void ChannelImpl::close() {
if (!check_available()) {
return;
}
std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
for (auto* handle: valid_handles) {
del(handle);
}
mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
sync();
m_closed = true;
}

size_t ChannelImpl::get_option(std::string name) {
mgb_assert(check_available(), "Channel already closed");
return m_channel_state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
mgb_assert(check_available(), "Channel already closed");
m_channel_state.options.set_option(name, value);
m_buffer.enqueue(SetOption{name, value});
}
@@ -440,9 +467,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){
m_channel_state.tid = std::this_thread::get_id();
}
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}

ChannelImpl::~ChannelImpl() {
close();
@@ -562,6 +587,10 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
//dest->users.clear();
}

bool ChannelImpl::check_available() {
return !m_closed;
}

void ChannelImpl::sync_device_scope(CompNode device) {
auto& prev = m_worker_state.device_scope_map[device];
auto& current = m_worker_state.scopes;
@@ -786,9 +815,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
std::swap(profiler, m_worker_state.profiler);
auto records = profiler->stop();
auto host_map = [this](std::thread::id tid) {
if (tid == m_channel_state.tid) {
return "channel";
} else if (tid == m_worker_state.tid) {
if (tid == m_worker_state.tid) {
return "worker";
} else {
return "unknown";
@@ -959,6 +986,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
}

void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
mgb_assert(check_available(), "Channel already closed");
auto profiler_option = InterpreterProfiler::Option::from_dict(option);
auto profiler = std::make_unique<InterpreterProfiler>();
profiler->set_option(profiler_option);
@@ -968,6 +996,7 @@ void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
}

void ChannelImpl::stop_profile(std::string basename, std::string format) {
mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
auto profiler = std::make_unique<InterpreterProfiler>();
std::swap(profiler, m_channel_state.profiler);
@@ -976,6 +1005,7 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
}

void ChannelImpl::push_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed");
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<ChannelBeginScope>(name);
m_channel_state.scopes.push_back(name);
@@ -984,6 +1014,7 @@ void ChannelImpl::push_scope(std::string name) {
}

void ChannelImpl::pop_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed");
if (m_channel_state.profiler->is_profiling()) {
mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
m_channel_state.scopes.pop_back();
@@ -992,14 +1023,6 @@ void ChannelImpl::pop_scope(std::string name) {
}
}

void ChannelImpl::assert_in_channel() {
mgb_assert(m_channel_state.tid != std::this_thread::get_id());
}

void ChannelImpl::assert_in_worker() {
mgb_assert(m_worker_state.tid == std::this_thread::get_id());
}

void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
for (auto i : vec) {
i->pin();
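
The interpreter_impl.cpp changes follow one pattern: every public ChannelImpl entry point asserts check_available(), del degrades to a no-op after shutdown, and close() deletes whatever handles the caller never released before marking the channel closed. Below is a self-contained sketch of that pattern with simplified placeholder types (Channel, an int payload, printf logging), not the real interpreter API.

// Sketch of the closed-flag guard and close-time cleanup used by ChannelImpl.
#include <cassert>
#include <cstdio>
#include <unordered_set>
#include <vector>

class Channel {
public:
    using Handle = void*;

    Handle put(int value) {
        assert(check_available() && "Channel already closed");
        auto* info = new int(value);
        m_valid_handle.insert(info);
        return info;
    }

    void del(Handle handle) {
        // del stays callable after close(): tensor destructors may still run
        // during interpreter shutdown, so it silently becomes a no-op.
        if (!check_available()) return;
        assert(m_valid_handle.count(handle));
        m_valid_handle.erase(handle);
        delete static_cast<int*>(handle);
    }

    void close() {
        if (!check_available()) return;  // second close is a no-op
        // Free everything the caller never released, before devices finalize.
        std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
        for (auto* handle : valid_handles) del(handle);
        std::printf("%zu tensors existed before channel close\n", valid_handles.size());
        m_closed = true;
    }

    ~Channel() { close(); }

private:
    bool check_available() const { return !m_closed; }

    bool m_closed = false;
    std::unordered_set<Handle> m_valid_handle;
};

int main() {
    Channel channel;
    channel.put(1);    // handle the caller never deletes
    channel.close();   // freed here anyway
    channel.close();   // safe: already closed
}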


imperative/src/impl/interpreter/interpreter_impl.h (+5, -3)

@@ -18,6 +18,7 @@
#include <unordered_set>
#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"
@@ -102,8 +103,7 @@ private:
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);

void assert_in_channel();
void assert_in_worker();
bool check_available();

void sync_device_scope(CompNode device);

@@ -120,6 +120,8 @@ private:
std::exception_ptr m_worker_exc;
std::atomic_uint64_t m_last_id = 0;

bool m_closed = false;

struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner.
// this won't affect throughput when python interpreter is sending enough task,
@@ -186,7 +188,6 @@ private:
int m_async_level = 2;

struct State {
std::thread::id tid;
OptionManager options;
std::vector<std::string> scopes;
std::unique_ptr<InterpreterProfiler> profiler;
@@ -199,6 +200,7 @@ private:
struct ChannelState: State {};

struct WorkerState: State {
std::thread::id tid;
CompNode::UnorderedMap<std::vector<std::string>> device_scope_map;
};


