Browse Source

fix(mm): fix mm error when using sync

GitOrigin-RevId: 63387bda04
release-1.11
Megvii Engine Team 2 years ago
parent
commit
b8c7557b86
7 changed files with 102 additions and 31 deletions
  1. +6
    -0
      imperative/python/megengine/distributed/helper.py
  2. +2
    -1
      imperative/python/megengine/distributed/launcher.py
  3. +67
    -22
      imperative/src/impl/interpreter/interpreter_impl.cpp
  4. +11
    -8
      imperative/src/impl/interpreter/interpreter_impl.h
  5. +3
    -0
      imperative/src/include/megbrain/imperative/interpreter.h
  6. +11
    -0
      src/core/include/megbrain/utils/metahelper.h
  7. +2
    -0
      src/core/include/megbrain/utils/thread_impl_1.h

+ 6
- 0
imperative/python/megengine/distributed/helper.py View File

@@ -162,6 +162,12 @@ def _check_device_initialized(device_type: str, rank: int):
raise RuntimeError(errmsg)


def _check_interpreter_status():
    """Probe the imperative runtime so a misconfigured interpreter fails early.

    Reading any interpreter option goes through the channel, which asserts
    availability first; if the channel is closed or was disabled by a fork,
    this raises before distributed setup proceeds.
    """
    from ..core._imperative_rt.core2 import get_option

    # The value itself is irrelevant -- only the availability check matters.
    get_option("async_level")


get_device_count_by_fork = deprecated_func(
"1.5", "megengine.device", "get_device_count", False
)


+ 2
- 1
imperative/python/megengine/distributed/launcher.py View File

@@ -9,7 +9,7 @@ from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count
from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized
from .helper import _check_device_initialized, _check_interpreter_status
from .server import Client, Server

WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
@@ -33,6 +33,7 @@ def _run_wrapped(
machine_ranks: list,
):
r"""Init distributed process group and run wrapped function."""
_check_interpreter_status()
_check_device_initialized(device_type, dev)
init_process_group(
master_ip=master_ip,


+ 67
- 22
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -115,7 +115,18 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
#define m_worker_state

std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
return std::make_unique<ChannelImpl>();
auto ret = std::make_unique<ChannelImpl>();
#if !(defined(_WIN32) || defined(_WIN64))
auto disable_channels = [](void) -> void {
for (ChannelImpl* channel : ChannelImpl::m_all_active_channels) {
if (channel->worker_started()) {
channel->update_status_to_forked();
}
}
};
pthread_atfork(nullptr, nullptr, static_cast<void (*)(void)>(disable_channels));
#endif
return ret;
}

Interpreter& Interpreter::inst() {
@@ -125,7 +136,7 @@ Interpreter& Interpreter::inst() {

Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
std::optional<StackManager::Guard> guard;
if (Profiler::is_profiling()) {
auto& state = get_channel_state();
@@ -158,7 +169,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
Put{info, value, no_cache},
});
}
if (m_async_level == 0) {

if (get_channel_state().options.async_level == 0) {
sync_impl();
info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error();
@@ -169,7 +181,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {

Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
return reinterpret_cast<Handle>(put_impl(data, hvalue));
}
TensorInfo* ChannelImpl::put_impl(
@@ -221,7 +233,7 @@ void ChannelImpl::del_impl(Handle handle) {

void ChannelImpl::drop(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
if (state.options.enable_drop) {
mgb_assert(
@@ -404,7 +416,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
if (op->same_type<GetVarShape>() && input->shape_valid()) {
size_t ndim = input->desc.layout.ndim;
@@ -460,7 +472,7 @@ SmallVector<Handle> ChannelImpl::apply_op_impl(

HostTensorND ChannelImpl::get_value(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
@@ -472,7 +484,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {

TensorShape ChannelImpl::get_shape(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
@@ -487,7 +499,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {

DType ChannelImpl::get_dtype(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
@@ -500,7 +512,7 @@ DType ChannelImpl::get_dtype(Handle handle) {

CompNode ChannelImpl::get_device(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
@@ -513,7 +525,7 @@ CompNode ChannelImpl::get_device(Handle handle) {

DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle);
@@ -523,7 +535,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {

void ChannelImpl::sync() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
sync_impl();
}

@@ -545,19 +557,19 @@ void ChannelImpl::close() {
mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
sync_impl();
m_closed = true;
m_status = ChannelRunningStatus::CLOSED;
}

size_t ChannelImpl::get_option(std::string name) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
return state.options.get_option(name);
}

void ChannelImpl::set_option(std::string name, size_t value) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.options.set_option(name, value);
// FIXME
@@ -583,7 +595,7 @@ void ChannelImpl::set_option(std::string name, size_t value) {

void ChannelImpl::clear_candidates() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
m_dtr.candidates.clear();
}

@@ -681,10 +693,18 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr);
}

ChannelImpl::ChannelImpl() : m_worker(this) {}
std::unordered_set<ChannelImpl*> ChannelImpl::m_all_active_channels{};
MGB_MUTEX ChannelImpl::m_all_active_channels_mutex{};

// Register this channel in the global active-channel set under the registry
// mutex, so the pthread_atfork child handler (installed in create_channel)
// can find and disable every live channel after a fork().
ChannelImpl::ChannelImpl() : m_worker(this) {
    MGB_LOCK_GUARD(m_all_active_channels_mutex);
    m_all_active_channels.emplace(this);
}

ChannelImpl::~ChannelImpl() {
    // Flush pending work and transition the channel to CLOSED first.
    close();
    // Then unregister under the registry mutex so the atfork handler can no
    // longer touch this (about to be destroyed) channel.
    MGB_LOCK_GUARD(m_all_active_channels_mutex);
    m_all_active_channels.erase(this);
}

void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
@@ -992,7 +1012,7 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
}

bool ChannelImpl::check_available() {
return !m_closed;
return m_status == ChannelRunningStatus::RUNING;
}

TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
@@ -1352,7 +1372,7 @@ void ChannelImpl::check_worker_exc_unsafe() {

void ChannelImpl::start_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto capture_tensors = collect_valid_tensors();
if (capture_tensors.size() > 0) {
if (Profiler::is_profiling()) {
@@ -1370,7 +1390,7 @@ void ChannelImpl::start_profile() {

void ChannelImpl::stop_profile() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto escape_tensors = collect_valid_tensors();
if (escape_tensors.size() > 0) {
if (Profiler::is_profiling()) {
@@ -1388,7 +1408,7 @@ void ChannelImpl::stop_profile() {

void ChannelImpl::push_scope(std::string name) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.stack_manager.enter(name);
MGB_RECORD_EVENT(ScopeEvent, name);
@@ -1406,7 +1426,7 @@ void ChannelImpl::push_scope(std::string name) {

void ChannelImpl::pop_scope(std::string name) {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state();
state.stack_manager.exit(name);
MGB_RECORD_EVENT(ScopeFinishEvent, name);
@@ -1422,6 +1442,31 @@ void ChannelImpl::pop_scope(std::string name) {
}
}

//! whether this channel's async worker thread has been started; delegates to
//! AsyncQueueSC::worker_started() (see thread_impl_1.h)
bool ChannelImpl::worker_started() const {
    return m_worker.worker_started();
}

//! mark this channel unusable after fork(); invoked from the pthread_atfork
//! child handler registered in create_channel(), so subsequent API calls fail
//! via assert_available() instead of hanging on a dead worker thread
void ChannelImpl::update_status_to_forked(void) {
    MGB_LOCK_GUARD(m_spin);
    m_status = ChannelRunningStatus::FORKED;
}

//! assert that the channel is usable; raises a descriptive error when the
//! channel was closed or disabled by a fork() (see update_status_to_forked)
void ChannelImpl::assert_available() const {
    // Fast path: a running channel is usable.
    if (m_status == ChannelRunningStatus::RUNING) {
        return;
    } else if (m_status == ChannelRunningStatus::CLOSED) {
        mgb_assert(false, "Channel already closed");
    } else if (m_status == ChannelRunningStatus::FORKED) {
        // fixed grammar of the user-facing message ("is be disabled")
        mgb_assert(
                false,
                "your program has forked and megengine is disabled in the "
                "subprocess; if you want to use megengine in a subprocess, "
                "please DO NOT set up or use megengine before fork");
    } else {
        mgb_assert(false, "impossible, Channel status is undefined");
    }
}

void ChannelImpl::assert_in_channel() {
mgb_assert(
get_worker_tid() != std::this_thread::get_id(),


+ 11
- 8
imperative/src/impl/interpreter/interpreter_impl.h View File

@@ -27,7 +27,7 @@ struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override;
};

struct ChannelImpl : Interpreter::Channel {
struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj {
ChannelImpl();
~ChannelImpl() override;

@@ -61,6 +61,13 @@ struct ChannelImpl : Interpreter::Channel {
void push_scope(std::string) override;
void pop_scope(std::string) override;

bool worker_started() const;
void update_status_to_forked(void);
void assert_available() const;

static std::unordered_set<ChannelImpl*> m_all_active_channels;
static MGB_MUTEX m_all_active_channels_mutex;

private:
struct WorkQueue;
struct State;
@@ -130,7 +137,9 @@ private:
// TODO: use explicit struct
std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
bool m_applying = false;
bool m_closed = false;

enum class ChannelRunningStatus { RUNING, CLOSED, FORKED };
ChannelRunningStatus m_status = ChannelRunningStatus::RUNING;

struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner.
@@ -159,12 +168,6 @@ private:
ChannelImpl* m_owner;
} m_worker;

//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;

struct State {
std::thread::id tid;
OptionManager options;


+ 3
- 0
imperative/src/include/megbrain/imperative/interpreter.h View File

@@ -60,6 +60,9 @@ struct Interpreter {
virtual std::unique_ptr<Channel> create_channel() = 0;

static Interpreter& inst();

protected:
Interpreter() = default;
};

} // namespace mgb::imperative::interpreter

+ 11
- 0
src/core/include/megbrain/utils/metahelper.h View File

@@ -151,6 +151,17 @@ public:
NonCopyableObj() = default;
};

/*!
 * \brief mixin base class that disables move construction and move assignment
 *
 * NOTE: per the C++ rules for special member functions, declaring the
 * (deleted) move operations also suppresses the implicitly generated copy
 * operations, so derived classes are non-copyable as well unless they
 * declare copy operations themselves.
 */
class NonMoveableObj {
public:
    NonMoveableObj() = default;
    NonMoveableObj(NonMoveableObj&&) = delete;
    NonMoveableObj& operator=(NonMoveableObj&&) = delete;
};

template <typename T>
class ReverseAdaptor {
T& m_t;


+ 2
- 0
src/core/include/megbrain/utils/thread_impl_1.h View File

@@ -253,6 +253,8 @@ public:
}
}

//! whether the worker thread backing this queue has been started
inline bool worker_started() const { return m_synchronizer.worker_started(); }

protected:
~AsyncQueueSC() noexcept = default;



Loading…
Cancel
Save