Browse Source

fix(mm): fix mm error when using sync

GitOrigin-RevId: 63387bda04
release-1.11
Megvii Engine Team 2 years ago
parent
commit
b8c7557b86
7 changed files with 102 additions and 31 deletions
  1. +6
    -0
      imperative/python/megengine/distributed/helper.py
  2. +2
    -1
      imperative/python/megengine/distributed/launcher.py
  3. +67
    -22
      imperative/src/impl/interpreter/interpreter_impl.cpp
  4. +11
    -8
      imperative/src/impl/interpreter/interpreter_impl.h
  5. +3
    -0
      imperative/src/include/megbrain/imperative/interpreter.h
  6. +11
    -0
      src/core/include/megbrain/utils/metahelper.h
  7. +2
    -0
      src/core/include/megbrain/utils/thread_impl_1.h

+ 6
- 0
imperative/python/megengine/distributed/helper.py View File

@@ -162,6 +162,12 @@ def _check_device_initialized(device_type: str, rank: int):
raise RuntimeError(errmsg) raise RuntimeError(errmsg)




def _check_interpreter_status():
from ..core._imperative_rt.core2 import get_option

_ = get_option("async_level")


get_device_count_by_fork = deprecated_func( get_device_count_by_fork = deprecated_func(
"1.5", "megengine.device", "get_device_count", False "1.5", "megengine.device", "get_device_count", False
) )


+ 2
- 1
imperative/python/megengine/distributed/launcher.py View File

@@ -9,7 +9,7 @@ from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count from ..device import get_device_count
from ..logger import get_logger from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized
from .helper import _check_device_initialized, _check_interpreter_status
from .server import Client, Server from .server import Client, Server


WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
@@ -33,6 +33,7 @@ def _run_wrapped(
machine_ranks: list, machine_ranks: list,
): ):
r"""Init distributed process group and run wrapped function.""" r"""Init distributed process group and run wrapped function."""
_check_interpreter_status()
_check_device_initialized(device_type, dev) _check_device_initialized(device_type, dev)
init_process_group( init_process_group(
master_ip=master_ip, master_ip=master_ip,


+ 67
- 22
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -115,7 +115,18 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
#define m_worker_state #define m_worker_state


std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() { std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
return std::make_unique<ChannelImpl>();
auto ret = std::make_unique<ChannelImpl>();
#if !(defined(_WIN32) || defined(_WIN64))
auto disable_channels = [](void) -> void {
for (ChannelImpl* channel : ChannelImpl::m_all_active_channels) {
if (channel->worker_started()) {
channel->update_status_to_forked();
}
}
};
pthread_atfork(nullptr, nullptr, static_cast<void (*)(void)>(disable_channels));
#endif
return ret;
} }


Interpreter& Interpreter::inst() { Interpreter& Interpreter::inst() {
@@ -125,7 +136,7 @@ Interpreter& Interpreter::inst() {


Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
std::optional<StackManager::Guard> guard; std::optional<StackManager::Guard> guard;
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
auto& state = get_channel_state(); auto& state = get_channel_state();
@@ -158,7 +169,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
Put{info, value, no_cache}, Put{info, value, no_cache},
}); });
} }
if (m_async_level == 0) {

if (get_channel_state().options.async_level == 0) {
sync_impl(); sync_impl();
info->desc.comp_node.sync(); info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error(); auto err = info->desc.comp_node.check_async_error();
@@ -169,7 +181,7 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {


Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) { Handle ChannelImpl::put(const DeviceTensorND& data, const HostTensorND& hvalue) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
return reinterpret_cast<Handle>(put_impl(data, hvalue)); return reinterpret_cast<Handle>(put_impl(data, hvalue));
} }
TensorInfo* ChannelImpl::put_impl( TensorInfo* ChannelImpl::put_impl(
@@ -221,7 +233,7 @@ void ChannelImpl::del_impl(Handle handle) {


void ChannelImpl::drop(Handle handle) { void ChannelImpl::drop(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state(); auto& state = get_channel_state();
if (state.options.enable_drop) { if (state.options.enable_drop) {
mgb_assert( mgb_assert(
@@ -404,7 +416,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op( SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) { std::shared_ptr<OpDef> op, const SmallVector<Handle>& inputs) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto* input = reinterpret_cast<TensorInfo*>(inputs[0]); auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
if (op->same_type<GetVarShape>() && input->shape_valid()) { if (op->same_type<GetVarShape>() && input->shape_valid()) {
size_t ndim = input->desc.layout.ndim; size_t ndim = input->desc.layout.ndim;
@@ -460,7 +472,7 @@ SmallVector<Handle> ChannelImpl::apply_op_impl(


HostTensorND ChannelImpl::get_value(Handle handle) { HostTensorND ChannelImpl::get_value(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert( mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle); handle);
@@ -472,7 +484,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {


TensorShape ChannelImpl::get_shape(Handle handle) { TensorShape ChannelImpl::get_shape(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert( mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle); handle);
@@ -487,7 +499,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {


DType ChannelImpl::get_dtype(Handle handle) { DType ChannelImpl::get_dtype(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert( mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle); handle);
@@ -500,7 +512,7 @@ DType ChannelImpl::get_dtype(Handle handle) {


CompNode ChannelImpl::get_device(Handle handle) { CompNode ChannelImpl::get_device(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert( mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle); handle);
@@ -513,7 +525,7 @@ CompNode ChannelImpl::get_device(Handle handle) {


DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
mgb_assert( mgb_assert(
m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
handle); handle);
@@ -523,7 +535,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {


void ChannelImpl::sync() { void ChannelImpl::sync() {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
sync_impl(); sync_impl();
} }


@@ -545,19 +557,19 @@ void ChannelImpl::close() {
mgb_assert(m_valid_handle.empty()); mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size()); mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
sync_impl(); sync_impl();
m_closed = true;
m_status = ChannelRunningStatus::CLOSED;
} }


size_t ChannelImpl::get_option(std::string name) { size_t ChannelImpl::get_option(std::string name) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state(); auto& state = get_channel_state();
return state.options.get_option(name); return state.options.get_option(name);
} }


void ChannelImpl::set_option(std::string name, size_t value) { void ChannelImpl::set_option(std::string name, size_t value) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state(); auto& state = get_channel_state();
state.options.set_option(name, value); state.options.set_option(name, value);
// FIXME // FIXME
@@ -583,7 +595,7 @@ void ChannelImpl::set_option(std::string name, size_t value) {


void ChannelImpl::clear_candidates() { void ChannelImpl::clear_candidates() {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
m_dtr.candidates.clear(); m_dtr.candidates.clear();
} }


@@ -681,10 +693,18 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr); m_pool.free(ptr);
} }


ChannelImpl::ChannelImpl() : m_worker(this) {}
std::unordered_set<ChannelImpl*> ChannelImpl::m_all_active_channels{};
MGB_MUTEX ChannelImpl::m_all_active_channels_mutex{};

ChannelImpl::ChannelImpl() : m_worker(this) {
MGB_LOCK_GUARD(m_all_active_channels_mutex);
m_all_active_channels.emplace(this);
}


ChannelImpl::~ChannelImpl() { ChannelImpl::~ChannelImpl() {
close(); close();
MGB_LOCK_GUARD(m_all_active_channels_mutex);
m_all_active_channels.erase(this);
} }


void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
@@ -992,7 +1012,7 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
} }


bool ChannelImpl::check_available() { bool ChannelImpl::check_available() {
return !m_closed;
return m_status == ChannelRunningStatus::RUNING;
} }


TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
@@ -1352,7 +1372,7 @@ void ChannelImpl::check_worker_exc_unsafe() {


void ChannelImpl::start_profile() { void ChannelImpl::start_profile() {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto capture_tensors = collect_valid_tensors(); auto capture_tensors = collect_valid_tensors();
if (capture_tensors.size() > 0) { if (capture_tensors.size() > 0) {
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
@@ -1370,7 +1390,7 @@ void ChannelImpl::start_profile() {


void ChannelImpl::stop_profile() { void ChannelImpl::stop_profile() {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto escape_tensors = collect_valid_tensors(); auto escape_tensors = collect_valid_tensors();
if (escape_tensors.size() > 0) { if (escape_tensors.size() > 0) {
if (Profiler::is_profiling()) { if (Profiler::is_profiling()) {
@@ -1388,7 +1408,7 @@ void ChannelImpl::stop_profile() {


void ChannelImpl::push_scope(std::string name) { void ChannelImpl::push_scope(std::string name) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state(); auto& state = get_channel_state();
state.stack_manager.enter(name); state.stack_manager.enter(name);
MGB_RECORD_EVENT(ScopeEvent, name); MGB_RECORD_EVENT(ScopeEvent, name);
@@ -1406,7 +1426,7 @@ void ChannelImpl::push_scope(std::string name) {


void ChannelImpl::pop_scope(std::string name) { void ChannelImpl::pop_scope(std::string name) {
MGB_LOCK_GUARD(m_spin); MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
assert_available();
auto& state = get_channel_state(); auto& state = get_channel_state();
state.stack_manager.exit(name); state.stack_manager.exit(name);
MGB_RECORD_EVENT(ScopeFinishEvent, name); MGB_RECORD_EVENT(ScopeFinishEvent, name);
@@ -1422,6 +1442,31 @@ void ChannelImpl::pop_scope(std::string name) {
} }
} }


bool ChannelImpl::worker_started() const {
return m_worker.worker_started();
}

void ChannelImpl::update_status_to_forked(void) {
MGB_LOCK_GUARD(m_spin);
m_status = ChannelRunningStatus::FORKED;
}

void ChannelImpl::assert_available() const {
if (m_status == ChannelRunningStatus::RUNING) {
return;
} else if (m_status == ChannelRunningStatus::CLOSED) {
mgb_assert(false, "Channel already closed");
} else if (m_status == ChannelRunningStatus::FORKED) {
        mgb_assert(
                false,
                "your program has been forked and megengine is disabled in the "
                "subprocess; if you want to use megengine in a subprocess, please "
                "DO NOT set up or use megengine before forking");
} else {
mgb_assert(false, "impossible, Channel status is undefined");
}
}

void ChannelImpl::assert_in_channel() { void ChannelImpl::assert_in_channel() {
mgb_assert( mgb_assert(
get_worker_tid() != std::this_thread::get_id(), get_worker_tid() != std::this_thread::get_id(),


+ 11
- 8
imperative/src/impl/interpreter/interpreter_impl.h View File

@@ -27,7 +27,7 @@ struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override; std::unique_ptr<Channel> create_channel() override;
}; };


struct ChannelImpl : Interpreter::Channel {
struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj {
ChannelImpl(); ChannelImpl();
~ChannelImpl() override; ~ChannelImpl() override;


@@ -61,6 +61,13 @@ struct ChannelImpl : Interpreter::Channel {
void push_scope(std::string) override; void push_scope(std::string) override;
void pop_scope(std::string) override; void pop_scope(std::string) override;


bool worker_started() const;
void update_status_to_forked(void);
void assert_available() const;

static std::unordered_set<ChannelImpl*> m_all_active_channels;
static MGB_MUTEX m_all_active_channels_mutex;

private: private:
struct WorkQueue; struct WorkQueue;
struct State; struct State;
@@ -130,7 +137,9 @@ private:
// TODO: use explicit struct // TODO: use explicit struct
std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack; std::stack<std::tuple<ApplyOp, size_t, TensorInfo*, std::string>> m_apply_stack;
bool m_applying = false; bool m_applying = false;
bool m_closed = false;

enum class ChannelRunningStatus { RUNING, CLOSED, FORKED };
ChannelRunningStatus m_status = ChannelRunningStatus::RUNING;


struct WorkQueue : AsyncQueueSC<Command, WorkQueue> { struct WorkQueue : AsyncQueueSC<Command, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner. // set max_spin=0 to prevent Queue fetch task in busy wait manner.
@@ -159,12 +168,6 @@ private:
ChannelImpl* m_owner; ChannelImpl* m_owner;
} m_worker; } m_worker;


//! config whether raise error exactly when invoking op.
//! level 2: both device and user side errors are async;
//! level 1: user side errors are sync;
//! level 0: both sync.
int m_async_level = 2;

struct State { struct State {
std::thread::id tid; std::thread::id tid;
OptionManager options; OptionManager options;


+ 3
- 0
imperative/src/include/megbrain/imperative/interpreter.h View File

@@ -60,6 +60,9 @@ struct Interpreter {
virtual std::unique_ptr<Channel> create_channel() = 0; virtual std::unique_ptr<Channel> create_channel() = 0;


static Interpreter& inst(); static Interpreter& inst();

protected:
Interpreter() = default;
}; };


} // namespace mgb::imperative::interpreter } // namespace mgb::imperative::interpreter

+ 11
- 0
src/core/include/megbrain/utils/metahelper.h View File

@@ -151,6 +151,17 @@ public:
NonCopyableObj() = default; NonCopyableObj() = default;
}; };


/*!
* \brief base class for non-moveable objects
*/
class NonMoveableObj {
NonMoveableObj(NonMoveableObj&&) = delete;
NonMoveableObj& operator=(NonMoveableObj&&) = delete;

public:
NonMoveableObj() = default;
};

template <typename T> template <typename T>
class ReverseAdaptor { class ReverseAdaptor {
T& m_t; T& m_t;


+ 2
- 0
src/core/include/megbrain/utils/thread_impl_1.h View File

@@ -253,6 +253,8 @@ public:
} }
} }


inline bool worker_started() const { return m_synchronizer.worker_started(); }

protected: protected:
~AsyncQueueSC() noexcept = default; ~AsyncQueueSC() noexcept = default;




Loading…
Cancel
Save