refactor(imperative): remove command buffer

GitOrigin-RevId: 83c8cb6d3b
tags/v1.9.0
Megvii Engine Team, 3 years ago
parent commit: a5af35c18c
10 changed files with 38 additions and 229 deletions
 1. imperative/python/megengine/dtr/dtr.py                         (+0, -1)
 2. imperative/python/src/tensor.cpp                               (+0, -4)
 3. imperative/python/test/integration/test_converge_with_drop.py  (+0, -3)
 4. imperative/python/test/unit/random/test_rng.py                 (+4, -0)
 5. imperative/src/impl/interpreter/commands.h                     (+0, -2)
 6. imperative/src/impl/interpreter/interpreter_impl.cpp           (+31, -148)
 7. imperative/src/impl/interpreter/interpreter_impl.h             (+0, -45)
 8. imperative/src/impl/interpreter/option_manager.h               (+0, -3)
 9. imperative/src/impl/proxy_graph.cpp                            (+3, -20)
10. imperative/src/impl/proxy_graph.h                              (+0, -3)

imperative/python/megengine/dtr/dtr.py (+0, -1)

@@ -120,7 +120,6 @@ def enable():
     r"""Enable to record computing path of tensors and to perform DTR policy."""
     _set_option("enable_dtr_auto_drop", 1)
     _set_option("enable_drop", 1)
-    _set_option("buffer_length", 0)
     _set_option("record_computing_path", 1)




imperative/python/src/tensor.cpp (+0, -4)

@@ -702,10 +702,6 @@ void init_tensor(py::module m) {
     });
     m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
-    m.def("set_buffer_length", [channel](int length) {
-        mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
-        channel->set_option("buffer_length", length);
-    });
     m.def("push_scope", [channel](std::string name) {
         Transformation::push_scope(name);
         channel->push_scope(name);

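Note: with the dedicated set_buffer_length binding removed (and the buffer_length option itself deleted from option_manager.h below), option tuning goes only through the generic set_option/get_option bindings kept above. A minimal pybind11-style sketch of that surviving pattern; Channel here is a hypothetical stand-in for the real interpreter channel, and init_options is an illustrative name:

#include <pybind11/pybind11.h>

#include <string>

namespace py = pybind11;

// Hypothetical stand-in for the interpreter channel used in the diff above.
struct Channel {
    void set_option(std::string name, size_t value);
    size_t get_option(std::string name);
};

void init_options(py::module m, Channel* channel) {
    // One generic getter/setter pair covers every option, so per-option
    // bindings with bespoke range checks (like set_buffer_length) can go.
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
}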

imperative/python/test/integration/test_converge_with_drop.py (+0, -3)

@@ -76,8 +76,6 @@ class XORNet(Module):

 def test_training_converge_with_drop():
     set_option("enable_drop", 1)
-    old_buffer_length = get_option("buffer_length")
-    set_option("buffer_length", 0)
     net = XORNet()
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())
@@ -119,4 +117,3 @@ def test_training_converge_with_drop():
     )

     set_option("enable_drop", 0)
-    set_option("buffer_length", old_buffer_length)

imperative/python/test/unit/random/test_rng.py (+4, -0)

@@ -9,6 +9,7 @@
 import numpy as np
 import pytest

+import megengine as mge
 import megengine.functional as F
 from megengine import Tensor, jit, random
 from megengine.core._imperative_rt import CompNode
@@ -209,9 +210,12 @@ def test_permutation_op():
         assert str(output.device) == str(cn)
         assert output.dtype == dtype

+    # FIXME: remove this sync
+    mge.core.set_option("async_level", 0)
     test_permutation_op_dtype(np.float32)
     test_permutation_op_dtype(np.int32)
     test_permutation_op_dtype(np.int16)
+    mge.core.set_option("async_level", 2)


 @pytest.mark.skipif(


imperative/src/impl/interpreter/commands.h (+0, -2)

@@ -49,14 +49,12 @@ struct ApplyOp {
     std::shared_ptr<OpDef> op;
     SmallVector<TensorInfo*> inputs;
     SmallVector<TensorInfo*> outputs;
-    SmallVector<TensorInfo*> dels;

     template <typename TFunctor>
     void get_props(TFunctor&& functor) const {
         functor("op", op);
         functor("inputs", inputs);
         functor("outputs", outputs);
-        functor("dels", dels);
     }

     const char* get_name() const { return "ApplyOp"; }

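Note: after this change ApplyOp carries only the op and its input/output lists. A compile-checkable sketch of the resulting shape, with std::vector standing in for mgb's SmallVector and OpDef/TensorInfo reduced to forward declarations:

#include <memory>
#include <vector>

struct OpDef;
struct TensorInfo;

// The `dels` list existed solely so CommandBuffer::fuse_del could fold a
// pending Del into an ApplyOp; with the buffer gone it has no remaining user.
struct ApplyOp {
    std::shared_ptr<OpDef> op;
    std::vector<TensorInfo*> inputs;
    std::vector<TensorInfo*> outputs;

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
        functor("op", op);
        functor("inputs", inputs);
        functor("outputs", outputs);
    }

    const char* get_name() const { return "ApplyOp"; }
};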

imperative/src/impl/interpreter/interpreter_impl.cpp (+31, -148)

@@ -156,7 +156,9 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
         info->desc.value = value.proxy_to_default_cpu();
     }
     info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
-    m_buffer.enqueue(Put{info, value, no_cache});
+    m_worker.add_task(
+            {Profiler::next_id(), Put{info, value, no_cache},
+             get_channel_state().stack_manager.dump()});
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();
@@ -200,7 +202,8 @@ void ChannelImpl::del_impl(Handle handle) {
     mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
     auto* info = reinterpret_cast<TensorInfo*>(handle);
     m_valid_handle.erase(handle);
-    m_buffer.enqueue(Del{info});
+    m_worker.add_task(
+            {Profiler::next_id(), Del{info}, get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::drop(Handle handle) {
@@ -212,7 +215,9 @@ void ChannelImpl::drop(Handle handle) {
                 m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
-        m_buffer.enqueue(Drop{info});
+        m_worker.add_task(
+                {Profiler::next_id(), Drop{info},
+                 get_channel_state().stack_manager.dump()});
     }
 }

@@ -333,7 +338,9 @@ void ChannelImpl::dispatch_kernel(
     MGB_RECORD_EVENT(
             OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs),
             tinfo_to_tid(cmd.outputs), state.stack_manager.dump());
-    m_buffer.enqueue(std::move(cmd));
+    m_worker.add_task(
+            {Profiler::next_id(), std::move(cmd),
+             get_channel_state().stack_manager.dump()});
     if (!validated && options.async_level == 1) {
         sync_impl();
     } else if (options.async_level == 0) {
@@ -466,7 +473,6 @@ void ChannelImpl::sync() {
 }

 void ChannelImpl::sync_impl() {
-    m_buffer.flush();
     m_worker.wait_all_task_finish();
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
@@ -499,7 +505,9 @@ void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
     auto& state = get_channel_state();
     state.options.set_option(name, value);
-    m_buffer.enqueue(SetOption{name, value});
+    m_worker.add_task(
+            {Profiler::next_id(), SetOption{name, value},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::clear_candidates() {
@@ -604,7 +612,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
     m_pool.free(ptr);
 }

-ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this) {}
+ChannelImpl::ChannelImpl() : m_worker(this) {}

 ChannelImpl::~ChannelImpl() {
     close();
@@ -645,7 +653,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
     if (dest->evict_type == EvictType::DROP) {
         auto&& path = dest->producer;
         m_apply_stack.push(
-                {ApplyOp{path->id, path->op, path->inputs, path->outputs, {}}, 0, dest,
+                {ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
                  "dtr"});
         if (!m_applying)
             flush_apply_stack();
@@ -748,19 +756,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
         MGB_RECORD_EVENT(TensorUsageEvent, input_id);
         MGB_RECORD_EVENT(OpInputFinishEvent, input_id);
     }
-    // Fused by command buffer. @see: CommandBuffer::fuse_del
-    // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by
-    // tensor_inputs after Del. Note for exprs like 'y = x op x', inplace is unsupported
-    // yet but Del would be also fused.
-    for (auto* del : cmd.dels) {
-        // refcnt --, owners: [tensor_inputs]
-        // if it's decreased to 1, would be detected at @see:
-        // proxy_graph_detail::apply_on_physical_tensor
-        uint64_t del_id = del->id;
-        MGB_RECORD_EVENT(TensorCommandEvent, del_id, TensorCommandKind::Del);
-        free(del);
-        MGB_RECORD_EVENT(TensorCommandFinishEvent, del_id, TensorCommandKind::Del);
-    }
     // Before wait
     // TODO: split operator wait and execute so that OpWait could be corrected recorded.
     // Before execute
@@ -931,7 +926,6 @@ bool ChannelImpl::check_available() {
 }

 TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
-    m_buffer.flush();
     std::unique_lock<decltype(m_mutex)> lock(m_mutex);
     mgb_assert(!m_waitee, "duplicate waitee");
     m_waitee = info;
@@ -943,8 +937,9 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     if (require_host && !host_available()) {
         // avoid dead lock
         lock.unlock();
-        m_buffer.enqueue(GetValue{info});
-        m_buffer.flush();
+        m_worker.add_task(
+                {Profiler::next_id(), GetValue{info},
+                 get_channel_state().stack_manager.dump()});
         lock.lock();
         wait_host = true;
     }
@@ -1266,141 +1261,25 @@ void ChannelImpl::check_worker_exc_unsafe() {
     }
 }

-void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
-    auto& state = m_owner->get_channel_state();
-    if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
-        return;
-    }
-    m_commands.push_back(
-            {Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
-    auto flush_pos = flush_pos_for(m_commands.back());
-    flush(flush_pos);
-}
-
-void ChannelImpl::CommandBuffer::flush() {
-    flush(m_commands.end());
-}
-
-void ChannelImpl::CommandBuffer::flush(Handle pos) {
-    for (auto iter = m_commands.begin(); iter != pos; ++iter) {
-        if (Profiler::is_profiling()) {
-            mgb_log_debug("%s Flushed", to_string(*iter).c_str());
-        }
-        m_owner->m_worker.add_task(std::move(*iter));
-    }
-    m_commands.erase(m_commands.begin(), pos);
-}
-
-auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    auto& state = m_owner->get_channel_state();
-    return std::visit(
-            [this, &state](const auto& cmd) {
-                using T = std::decay_t<decltype(cmd)>;
-                if constexpr (std::is_same_v<T, ApplyOp>) {
-                    auto* op_type = cmd.op->dyn_typeinfo();
-                    if (op_type == RemoteRecv::typeinfo() ||
-                        op_type == RemoteSend::typeinfo() ||
-                        op_type == CollectiveComm::typeinfo() ||
-                        op_type == opr::InputCallback::typeinfo() ||
-                        op_type == opr::OutputCallback::typeinfo()) {
-                        return m_commands.end();
-                    }
-                } else if constexpr (std::is_same_v<T, GetValue>) {
-                    return m_commands.end();
-                }
-                size_t buffer_length = state.options.buffer_length;
-                if (m_commands.size() > buffer_length) {
-                    return m_commands.begin() + (m_commands.size() - buffer_length);
-                }
-                return m_commands.begin();
-            },
-            cmd.data);
-}
-
-/**
- * 1. Find ApplyOp(dest) in buffered commands
- * 2. Check if there are other usages between ApplyOp and Del, return false if not
- * 3. Fuse Del into ApplyOp, return true
- */
-bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
-    auto* dest = cmd.dest;
-    // TODO: eliminate Puts
-    auto begin = m_commands.begin(), end = m_commands.end();
-    auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd) {
-        if (auto* apply = std::get_if<ApplyOp>(&cmd.data)) {
-            return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0;
-        }
-        return false;
-    });
-    if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
-        return false;
-    }
-    std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
-    return true;
-}
-
-auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range)
-        -> Handle {
-    auto found = range[1];
-    for (auto iter = range[0]; iter != range[1]; ++iter) {
-        std::visit(
-                [&](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        if (std::count(cmd.inputs.begin(), cmd.inputs.end(), dest) >
-                            0) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, GetValue>) {
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    } else if constexpr (std::is_same_v<T, Drop>) {
-                        // TODO: ignore swap-like commands, just remove them from buffer
-                        if (cmd.dest == dest) {
-                            found = iter;
-                        }
-                    }
-                },
-                iter->data);
-    };
-    return found;
-}
-
-auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) -> Handle {
-    return std::find_if(range[0], range[1], [dest](auto& cmd) {
-        return std::visit(
-                [dest](const auto& cmd) {
-                    using T = std::decay_t<decltype(cmd)>;
-                    if constexpr (std::is_same_v<T, ApplyOp>) {
-                        return std::count(
-                                       cmd.outputs.begin(), cmd.outputs.end(), dest) >
-                               0;
-                    } else if constexpr (std::is_same_v<T, Put>) {
-                        return cmd.dest == dest;
-                    }
-                    return false;
-                },
-                cmd.data);
-    });
-}
-
 void ChannelImpl::start_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto capture_tensors = collect_valid_tensors();
     if (capture_tensors.size() > 0) {
-        m_buffer.enqueue(StartProfile{std::move(capture_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StartProfile{std::move(capture_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }

 void ChannelImpl::stop_profile() {
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
-    m_buffer.flush();
     auto escape_tensors = collect_valid_tensors();
     if (escape_tensors.size() > 0) {
-        m_buffer.enqueue(StopProfile{std::move(escape_tensors)});
+        m_worker.add_task(
+                {Profiler::next_id(), StopProfile{std::move(escape_tensors)},
+                 get_channel_state().stack_manager.dump()});
     }
 }

@@ -1410,7 +1289,9 @@ void ChannelImpl::push_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.enter(name);
     MGB_RECORD_EVENT(ScopeEvent, name);
-    m_buffer.enqueue(PushScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PushScope{name},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::pop_scope(std::string name) {
@@ -1419,7 +1300,9 @@ void ChannelImpl::pop_scope(std::string name) {
     auto& state = get_channel_state();
     state.stack_manager.exit(name);
     MGB_RECORD_EVENT(ScopeFinishEvent, name);
-    m_buffer.enqueue(PopScope{name});
+    m_worker.add_task(
+            {Profiler::next_id(), PopScope{name},
+             get_channel_state().stack_manager.dump()});
 }

 void ChannelImpl::assert_in_channel() {

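Note: every call site in this file follows the same mechanical rewrite: m_buffer.enqueue(Cmd{...}) becomes a direct m_worker.add_task(...), with the profiler id and dispatch-stack dump that CommandBuffer::enqueue used to attach now supplied at the call site. A reduced, runnable sketch of that shape; all types are stand-ins, not the real interpreter classes:

#include <cstdint>
#include <deque>
#include <string>
#include <variant>

struct Put {};
struct Del {};
struct GetValue {};
using CommandData = std::variant<Put, Del, GetValue>;

struct Command {
    uint64_t id;            // Profiler::next_id() in the real code
    CommandData data;
    std::string backtrace;  // stack_manager.dump() in the real code
};

struct Worker {
    std::deque<Command> m_queue;
    void add_task(Command cmd) { m_queue.push_back(std::move(cmd)); }
};

int main() {
    Worker m_worker;
    // Before: m_buffer.enqueue(Del{});  // id/backtrace attached on flush
    // After: the command goes straight onto the worker queue.
    m_worker.add_task({/*id=*/1, Del{}, /*backtrace=*/"<stack dump>"});
    return 0;
}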

imperative/src/impl/interpreter/interpreter_impl.h (+0, -45)

@@ -126,11 +126,6 @@ private:
     void assert_in_worker();
     std::thread::id get_worker_tid();

-    // template <typename TCommand>
-    // void enqueue_command(TCommand&& cmd) {
-    //     m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
-    // }
-
     void sample_on_device(CompNode device, bool force);

     // valid => status != Deleted
@@ -178,46 +173,6 @@ private:
         ChannelImpl* m_owner;
     } m_worker;

-    /**
-     * Buf a command window for following fuse
-     * example:
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} |
-     * ---------------------------------------------------------------------
-     * | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... |
-     * ---------------------------------------------------------------------
-     * Then the fused Apply may be invoked inplace. see:
-     * ChannelImpl::process_one_task
-     */
-    struct CommandBuffer {
-        CommandBuffer(ChannelImpl* owner) : m_owner(owner) {}
-        void enqueue(CommandData cmd);
-        bool empty() const { return m_commands.empty(); }
-        void flush();
-
-    private:
-        ChannelImpl* m_owner;
-        std::deque<Command> m_commands;
-
-        using Handle = decltype(m_commands)::iterator;
-        // [begin, end)
-        using Range = std::array<Handle, 2>;
-
-        // Launch commands in range [m_commands.begin(), pos)
-        void flush(Handle pos);
-        // Select flush position for incoming cmd
-        Handle flush_pos_for(const Command& cmd);
-        // Fuse del command into suitable ApplyOp
-        bool fuse_del(const Del& cmd);
-        // Returns the last handle that dest is used within range. If dest is not used,
-        // returns range[1]
-        Handle find_last_usage(TensorInfo* dest, Range range);
-        // Returns the produce position of dest. If not found, returns range[1]
-        Handle find_produce(TensorInfo* dest, Range range);
-    } m_buffer;
-
     //! config whether raise error exactly when invoking op.
     //! level 2: both device and user side errors are async;
     //! level 1: user side errors are sync;

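Note: the one nontrivial optimization lost with the buffer is fuse_del (see the removed doc comment above): a pending Del{t} was folded into the buffered ApplyOp that consumes t, letting the apply free t itself and potentially run in place. A self-contained sketch of that selection logic over plain vectors, for context only; the real code walked a std::deque<Command> of variants:

#include <algorithm>
#include <cassert>
#include <vector>

// Tensors reduced to ints; the buffer window reduced to pending Applies.
struct Apply {
    std::vector<int> inputs;
    std::vector<int> dels;  // the field this commit removes from ApplyOp
};

// Fold Del{t} into the first pending Apply reading t, unless t is used later.
bool fuse_del(std::vector<Apply>& window, int t) {
    auto reads_t = [&](const Apply& a) {
        return std::count(a.inputs.begin(), a.inputs.end(), t) > 0;
    };
    auto it = std::find_if(window.begin(), window.end(), reads_t);
    if (it == window.end() || std::any_of(it + 1, window.end(), reads_t))
        return false;  // not in window, or still needed: emit a real Del
    it->dels.push_back(t);  // the consumer now frees t itself
    return true;
}

int main() {
    std::vector<Apply> window = {{{0, 1}, {}}, {{1, 2}, {}}};
    assert(fuse_del(window, 0));   // tensor 0: only read by the first apply
    assert(!fuse_del(window, 1));  // tensor 1: read again by the second apply
    return 0;
}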

imperative/src/impl/interpreter/option_manager.h (+0, -3)

@@ -41,9 +41,6 @@ public:
             catch_worker_execption, "MEGENGINE_CATCH_WORKER_EXEC", 1,
             "catch worker exception if enabled, close it when debugging");
-    DEF_OPTION(
-            buffer_length, "MEGENGINE_COMMAND_BUFFER_LENGTH", 3,
-            "set command buffer length.");
     DEF_OPTION(
             enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1,
             "enable host compute, thus computation may be done in host event if it's "
             "device is gpu.");


imperative/src/impl/proxy_graph.cpp (+3, -20)

@@ -626,23 +626,12 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(

 /*********************** Logical Tensor Impl ***********************/

-size_t ProxyGraph::get_opr_output_size(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    return get_proxy_opr(opdef, inputs)->usable_output().size();
-}
-
 std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
         infer_output_attrs_fallible(
                 const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    auto opr = get_proxy_opr(opdef, inputs);
-    CUR_OPR_GUARD(opr);
-    SmallVector<LogicalTensorDesc> outputs;
-    bool validated = do_shape_infer(false);
-    for (auto&& i : opr->usable_output()) {
-        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
-    }
-    bool need_check = opr->same_type<opr::Reshape>();
-    return {outputs, validated && !need_check};
+    // this function is just a placeholder
+    // it will be overrided by ProxyGraphTypeI::infer_output_attrs_fallible in minigraph
+    mgb_assert(0);
 }

 std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
@@ -823,12 +812,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
     return result;
 }

-cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
-        const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs) {
-    mgb_assert(!m_cur_opr);
-    auto vinputs = make_input_place_holders(inputs);
-    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
-}
-
 VarNodeArray ProxyGraph::make_input_place_holders(
         const SmallVector<LogicalTensorDesc>& inputs) {

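Note: infer_output_attrs_fallible keeps its symbol but becomes a trap, since the functional version lives in ProxyGraphTypeI (the minigraph path); with that body gone, this get_proxy_opr overload and get_opr_output_size presumably have no remaining callers and are deleted too. A minimal sketch of the stub pattern, with illustrative names and plain assert playing the role of mgb_assert:

#include <cassert>
#include <tuple>
#include <vector>

struct Desc {};  // stand-in for LogicalTensorDesc

// Keep the symbol so the interface still compiles, but trap any accidental
// call loudly instead of silently returning stale results.
std::tuple<std::vector<Desc>, bool> infer_output_attrs_fallible_stub() {
    assert(0 && "unreachable: handled by the minigraph implementation");
    return {};
}

int main() {
    (void)&infer_output_attrs_fallible_stub;  // intentionally never called
    return 0;
}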

imperative/src/impl/proxy_graph.h (+0, -3)

@@ -85,9 +85,6 @@ private:

     /********************** Logical Tensor Helper **********************/

-    cg::OperatorNodeBase* get_proxy_opr(
-            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& inputs);
-
     cg::VarNodeArray make_input_place_holders(
             const SmallVector<LogicalTensorDesc>& inputs);


