GitOrigin-RevId: c68977c5dc
release-1.10
@@ -0,0 +1,21 @@ | |||||
from .core._imperative_rt.core2 import apply | |||||
from .core.ops.builtin import Barrier | |||||
from .tensor import Tensor | |||||
_dummy_tensors = {} | |||||
def _get_dummy_tensor(device): | |||||
if device not in _dummy_tensors: | |||||
_dummy_tensors[device] = Tensor([], device=device) | |||||
return _dummy_tensors[device] | |||||
def record_event(device): | |||||
x = _get_dummy_tensor(device) | |||||
(x,) = apply(Barrier(device, 1), x) | |||||
return x | |||||
def wait_event(device, event): | |||||
apply(Barrier(device, 0), event) |
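For context, a minimal usage sketch of the two helpers above, assuming a GPU build where the two streams "xpux:0"/"xpux:1" used by the tests later in this diff are available, together with the private `_borrow` flag added below:

```python
import megengine as mge
from megengine._multistream import record_event, wait_event

x = mge.Tensor([1, 2, 3], device="xpux:0")
s = x.to("xpux:1", _borrow=True).sum()  # stream 1 reads x's storage without a copy
ev = record_event("xpux:1")             # Barrier(device, 1): emit an event tensor on stream 1
wait_event("xpux:0", ev)                # Barrier(device, 0): stream 0 waits for that event,
x = x + 1                               # so it cannot run arbitrarily far ahead of stream 1
mge._full_sync()
```

`test_stream_mem` below applies exactly this pattern in a loop to bound how much memory the producer stream can keep pinned.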
@@ -51,7 +51,7 @@ _sh = _stream_helper() | |||||
def _valid_device(inp): | def _valid_device(inp): | ||||
if isinstance(inp, str) and re.match( | if isinstance(inp, str) and re.match( | ||||
"^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp | |||||
"^([cxg]pu|rocm|multithread)(x|\d+)(:\d+)?$", inp | |||||
): | ): | ||||
return True | return True | ||||
return False | return False | ||||
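A quick illustration of what the relaxed pattern accepts that the old one did not (the device strings are only examples):

```python
import re

OLD = r"^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$"
NEW = r"^([cxg]pu|rocm|multithread)(x|\d+)(:\d+)?$"

for dev in ("gpu0", "gpu0:1", "cpux", "xpux:1"):
    print(dev, bool(re.match(OLD, dev)), bool(re.match(NEW, dev)))
# Only "xpux:1" differs: an "x" device index followed by a stream id is rejected
# by the old pattern but accepted by the new one, which is what the multi-stream
# tests in this diff rely on.
```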
@@ -255,8 +255,8 @@ def what_is_xpu(): | |||||
def coalesce_free_memory(): | def coalesce_free_memory(): | ||||
r"""This function will try it best to free all consecutive free chunks back to operating system, | r"""This function will try it best to free all consecutive free chunks back to operating system, | ||||
small pieces may not be returned. | small pieces may not be returned. | ||||
because of the async processing of megengine, the effect of this func may not be reflected | |||||
because of the async processing of megengine, the effect of this func may not be reflected | |||||
immediately. if you want to see the effect immediately, you can call megengine._full_sync after | immediately. if you want to see the effect immediately, you can call megengine._full_sync after | ||||
this func was called | this func was called | ||||
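A short sketch of the call sequence the docstring describes (the function lives in `megengine.device`; `_full_sync` is the synchronization hook the docstring itself points to):

```python
import megengine as mge

mge.device.coalesce_free_memory()  # ask the allocator to return consecutive free chunks to the OS
mge._full_sync()                   # execution is asynchronous; sync so the release actually takes effect
```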
@@ -15,7 +15,7 @@ from .core._imperative_rt.core2 import Tensor as _Tensor | |||||
from .core._imperative_rt.core2 import apply, set_py_tensor_type | from .core._imperative_rt.core2 import apply, set_py_tensor_type | ||||
from .core._trace_option import use_symbolic_shape | from .core._trace_option import use_symbolic_shape | ||||
from .core._wrap import as_device | from .core._wrap import as_device | ||||
from .core.ops.builtin import Copy, GetVarShape | |||||
from .core.ops.builtin import Borrow, Copy, GetVarShape | |||||
from .core.tensor.array_method import ArrayMethodMixin | from .core.tensor.array_method import ArrayMethodMixin | ||||
from .device import _valid_device, get_default_device | from .device import _valid_device, get_default_device | ||||
from .logger import get_logger | from .logger import get_logger | ||||
@@ -205,7 +205,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
def reset_zero(self): | def reset_zero(self): | ||||
self *= 0 | self *= 0 | ||||
def to(self, device): | |||||
def to(self, device, *, _borrow=False): | |||||
r"""Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`""" | r"""Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`""" | ||||
if isinstance(device, str) and not _valid_device(device): | if isinstance(device, str) and not _valid_device(device): | ||||
raise ValueError( | raise ValueError( | ||||
@@ -214,7 +214,8 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
) | ) | ||||
) | ) | ||||
cn = as_device(device).to_c() | cn = as_device(device).to_c() | ||||
return apply(Copy(comp_node=cn), self)[0] | |||||
op = Borrow(comp_node=cn) if _borrow else Copy(comp_node=cn) | |||||
return apply(op, self)[0] | |||||
@property | @property | ||||
def requires_grad(self): | def requires_grad(self): | ||||
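A borrow-based transfer then looks like the sketch below; the underscored keyword is private API, and the two streams are assumed to sit on the same physical device, which is what the `mem_node()` assertion in the C++ `Borrow` implementation further down requires:

```python
import numpy as np
import megengine as mge

x0 = mge.Tensor(np.arange(4, dtype="int32"), device="xpux:0")
x1 = x0.to("xpux:1", _borrow=True)  # Borrow: alias x0's storage on stream 1, no device copy
y = (-x1).numpy()                   # the runtime inserts the required cross-stream waits
x2 = x0.to("xpux:1")                # default path: a plain Copy, as before
```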
@@ -232,11 +233,11 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
return id(self) | return id(self) | ||||
def __getnewargs__(self): | def __getnewargs__(self): | ||||
r""" __getnewargs__ will be called for pickle serialization or deep copy""" | |||||
r"""__getnewargs__ will be called for pickle serialization or deep copy""" | |||||
return (self.numpy(), self.dtype, self.device.logical_name) | return (self.numpy(), self.dtype, self.device.logical_name) | ||||
def __getstate__(self): | def __getstate__(self): | ||||
r""" __getstate__ will be called for pickle serialization or deep copy""" | |||||
r"""__getstate__ will be called for pickle serialization or deep copy""" | |||||
state = {} | state = {} | ||||
if self._qparams is not None: | if self._qparams is not None: | ||||
state["qparams"] = self._qparams | state["qparams"] = self._qparams | ||||
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include <list> | |||||
#include "megbrain/imperative/transformations/trace.h" | #include "megbrain/imperative/transformations/trace.h" | ||||
#include "megbrain/imperative/utils/map.h" | #include "megbrain/imperative/utils/map.h" | ||||
#include "megbrain/imperative/utils/stats.h" | #include "megbrain/imperative/utils/stats.h" | ||||
@@ -998,6 +998,9 @@ void init_tensor(py::module m) { | |||||
module_trace_hook.inc_ref(); | module_trace_hook.inc_ref(); | ||||
}); | }); | ||||
auto atexit = py::module::import("atexit"); | |||||
atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; })); | |||||
m.def("begin_record_values", [] { Value::begin_record_values(); }); | m.def("begin_record_values", [] { Value::begin_record_values(); }); | ||||
m.def("end_record_values", [] { | m.def("end_record_values", [] { | ||||
@@ -0,0 +1,74 @@ | |||||
import gc | |||||
import numpy as np | |||||
import pytest | |||||
import megengine as mge | |||||
import megengine.functional as F | |||||
from megengine._multistream import record_event, wait_event | |||||
class MemStat: | |||||
def __init__(self, *args): | |||||
for d in args: | |||||
mge.Tensor([], device=d) | |||||
gc.collect() | |||||
mge._full_sync() | |||||
self.baseline = {d: mge.device.get_allocated_memory(d) for d in args} | |||||
for d in args: | |||||
mge.device.reset_max_memory_stats(d) | |||||
def get_max(self, device): | |||||
return mge.device.get_max_allocated_memory(device) - self.baseline[device] | |||||
@pytest.mark.require_ngpu(1) | |||||
def test_mem_stats(): | |||||
memstat = MemStat("xpux:0", "xpux:1") | |||||
F.arange(1024, device="xpux:0") | |||||
mge._full_sync() | |||||
assert 4096 <= memstat.get_max("xpux:0") == memstat.get_max("xpux:1") <= 4096 + 128 | |||||
@pytest.mark.require_ngpu(1) | |||||
def test_borrow(): | |||||
memstat = MemStat("xpux:0", "xpux:1") | |||||
x_np = np.random.randint(2 ** 30, size=(1 * 1024 * 1024,), dtype="int32") | |||||
unit = x_np.size * 4 | |||||
x0 = mge.Tensor(x_np, device="xpux:0") | |||||
x1 = x0.to("xpux:1", _borrow=True) | |||||
y = -x1 | |||||
np.testing.assert_equal(-x_np, y.numpy()) | |||||
mge._full_sync() | |||||
assert memstat.get_max("xpux:0") / unit < 2.1 | |||||
@pytest.mark.require_ngpu(1) | |||||
def test_stream_mem(): | |||||
memstat = MemStat("xpux:0", "xpux:1") | |||||
x_np = np.random.randint(2 ** 10, size=(1 * 1024 * 1024,), dtype="int32") | |||||
unit = x_np.size * 4 | |||||
x0 = mge.Tensor(x_np, device="xpux:0") | |||||
results = [] | |||||
events = [] | |||||
for i in range(100): | |||||
if len(events) >= 2: | |||||
wait_event("xpux:0", events[-2]) | |||||
x0 = x0 + 1 | |||||
results.append(x0.to("xpux:1", _borrow=True).sum()) | |||||
events.append(record_event("xpux:1")) | |||||
del events[:-2] | |||||
y_np = x_np.sum() | |||||
for i, y in enumerate(results): | |||||
y_np += x_np.size | |||||
assert y_np == y.numpy() | |||||
mge._full_sync() | |||||
assert memstat.get_max("xpux:0") / unit < 2.1 |
@@ -1,84 +0,0 @@ | |||||
/** | |||||
* \file imperative/src/impl/async_releaser.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/imperative/blob_manager.h" | |||||
#include "megbrain/imperative/resource_manager.h" | |||||
#include "megbrain/system.h" | |||||
#include "./event_pool.h" | |||||
namespace mgb { | |||||
namespace imperative { | |||||
class AsyncReleaser : public CompNodeDepedentObject { | |||||
struct WaiterParam { | |||||
CompNode cn; | |||||
CompNode::Event* event; | |||||
BlobPtr blob; | |||||
HostTensorStorage::RawStorage storage; | |||||
}; | |||||
class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> { | |||||
AsyncReleaser* m_par_releaser; | |||||
public: | |||||
// disable busy wait by set max_spin=0 to save CPU cycle | |||||
Waiter(AsyncReleaser* releaser) | |||||
: AsyncQueueSC<WaiterParam, Waiter>(0), m_par_releaser(releaser) {} | |||||
void process_one_task(WaiterParam& param) { | |||||
if (param.event->finished()) { | |||||
param.blob.reset(); | |||||
param.storage.reset(); | |||||
EventPool::without_timer().free(param.event); | |||||
return; | |||||
} | |||||
using namespace std::literals; | |||||
std::this_thread::sleep_for(1us); | |||||
add_task(std::move(param)); | |||||
} | |||||
void on_async_queue_worker_thread_start() override { | |||||
sys::set_thread_name("releaser"); | |||||
} | |||||
}; | |||||
Waiter m_waiter{this}; | |||||
protected: | |||||
std::shared_ptr<void> on_comp_node_finalize() override { | |||||
m_waiter.wait_task_queue_empty(); | |||||
return {}; | |||||
} | |||||
public: | |||||
static AsyncReleaser* inst() { | |||||
static auto* releaser = ResourceManager::create_global<AsyncReleaser>(); | |||||
return releaser; | |||||
} | |||||
~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } | |||||
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); } | |||||
void add(const HostTensorND& hv) { | |||||
add(hv.comp_node(), {}, hv.storage().raw_storage()); | |||||
} | |||||
void add(CompNode cn, BlobPtr blob, HostTensorStorage::RawStorage storage = {}) { | |||||
auto event = EventPool::without_timer().alloc(cn); | |||||
event->record(); | |||||
m_waiter.add_task({cn, event, std::move(blob), std::move(storage)}); | |||||
} | |||||
}; | |||||
} // namespace imperative | |||||
} // namespace mgb |
@@ -16,7 +16,7 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { | |||||
BlobManagerImpl::BlobData::BlobData(OwnedBlob* in_blob) { | |||||
blob = in_blob; | blob = in_blob; | ||||
DeviceTensorStorage d_storage; | DeviceTensorStorage d_storage; | ||||
d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage); | d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage); | ||||
@@ -28,19 +28,19 @@ BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { | |||||
h_storage.copy_from(const_cast<DeviceTensorStorage&>(d_storage), blob->m_size); | h_storage.copy_from(const_cast<DeviceTensorStorage&>(d_storage), blob->m_size); | ||||
} | } | ||||
void BlobManagerImpl::register_blob(Blob* blob) { | |||||
void BlobManagerImpl::register_blob(OwnedBlob* blob) { | |||||
// add blob into the comp2blobs map | // add blob into the comp2blobs map | ||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob)); | mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob)); | ||||
} | } | ||||
void BlobManagerImpl::unregister_blob(Blob* blob) { | |||||
void BlobManagerImpl::unregister_blob(OwnedBlob* blob) { | |||||
// erase the blob from the comp2blobs map | // erase the blob from the comp2blobs map | ||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob)); | mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob)); | ||||
} | } | ||||
void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { | |||||
void BlobManagerImpl::alloc_with_defrag(OwnedBlob* blob, size_t size) { | |||||
if (custom_allocator) { | if (custom_allocator) { | ||||
blob->m_storage = custom_allocator(blob->m_comp_node, size); | blob->m_storage = custom_allocator(blob->m_comp_node, size); | ||||
return; | return; | ||||
@@ -55,7 +55,7 @@ void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { | |||||
}); | }); | ||||
} | } | ||||
void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) { | |||||
void BlobManagerImpl::alloc_direct(OwnedBlob* blob, size_t size) { | |||||
DeviceTensorStorage storage(blob->m_comp_node); | DeviceTensorStorage storage(blob->m_comp_node); | ||||
mgb_assert(blob->m_comp_node.valid()); | mgb_assert(blob->m_comp_node.valid()); | ||||
storage.ensure_size(size); | storage.ensure_size(size); | ||||
@@ -143,7 +143,7 @@ void BlobManagerImpl::defrag(const CompNode& cn) { | |||||
// sort blobs by creation time, which may help reduce memory fragmentation | // sort blobs by creation time, which may help reduce memory fragmentation | ||||
std::sort( | std::sort( | ||||
blob_data_arrary.begin(), blob_data_arrary.end(), | blob_data_arrary.begin(), blob_data_arrary.end(), | ||||
[](auto& lhs, auto& rhs) { return lhs.blob->id() < rhs.blob->id(); }); | |||||
[](auto& lhs, auto& rhs) { return lhs.blob->m_id < rhs.blob->m_id; }); | |||||
// allocate for each storage | // allocate for each storage | ||||
for (auto i : blob_data_arrary) { | for (auto i : blob_data_arrary) { | ||||
@@ -158,19 +158,19 @@ void BlobManagerImpl::defrag(const CompNode& cn) { | |||||
} | } | ||||
struct BlobManagerStub : BlobManager { | struct BlobManagerStub : BlobManager { | ||||
void alloc_direct(Blob* blob, size_t size) { | |||||
void alloc_direct(OwnedBlob* blob, size_t size) { | |||||
mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
}; | }; | ||||
void alloc_with_defrag(Blob* blob, size_t size) { | |||||
void alloc_with_defrag(OwnedBlob* blob, size_t size) { | |||||
mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
}; | }; | ||||
DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) { | DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) { | ||||
mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
}; | }; | ||||
void register_blob(Blob* blob) { | |||||
void register_blob(OwnedBlob* blob) { | |||||
mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
}; | }; | ||||
void unregister_blob(Blob* blob){}; | |||||
void unregister_blob(OwnedBlob* blob){}; | |||||
void defrag(const CompNode& cn) { | void defrag(const CompNode& cn) { | ||||
mgb_assert(0, "prohibited after global variable destruction"); | mgb_assert(0, "prohibited after global variable destruction"); | ||||
}; | }; | ||||
@@ -19,21 +19,21 @@ namespace imperative { | |||||
class BlobManagerImpl final : public BlobManager { | class BlobManagerImpl final : public BlobManager { | ||||
struct BlobSetWithMux { | struct BlobSetWithMux { | ||||
std::mutex mtx; | std::mutex mtx; | ||||
ThinHashSet<Blob*> blobs_set; | |||||
bool insert(Blob* blob) { | |||||
ThinHashSet<OwnedBlob*> blobs_set; | |||||
bool insert(OwnedBlob* blob) { | |||||
MGB_LOCK_GUARD(mtx); | MGB_LOCK_GUARD(mtx); | ||||
return blobs_set.insert(blob).second; | return blobs_set.insert(blob).second; | ||||
} | } | ||||
size_t erase(Blob* blob) { | |||||
size_t erase(OwnedBlob* blob) { | |||||
MGB_LOCK_GUARD(mtx); | MGB_LOCK_GUARD(mtx); | ||||
return blobs_set.erase(blob); | return blobs_set.erase(blob); | ||||
} | } | ||||
}; | }; | ||||
struct BlobData { | struct BlobData { | ||||
Blob* blob; | |||||
OwnedBlob* blob; | |||||
HostTensorStorage h_storage; | HostTensorStorage h_storage; | ||||
BlobData(Blob* in_blob); | |||||
BlobData(OwnedBlob* in_blob); | |||||
}; | }; | ||||
std::mutex m_mtx; | std::mutex m_mtx; | ||||
@@ -41,7 +41,7 @@ class BlobManagerImpl final : public BlobManager { | |||||
void defrag(const CompNode& cn) override; | void defrag(const CompNode& cn) override; | ||||
void alloc_direct(Blob* blob, size_t size) override; | |||||
void alloc_direct(OwnedBlob* blob, size_t size) override; | |||||
DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout); | DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout); | ||||
@@ -50,14 +50,14 @@ class BlobManagerImpl final : public BlobManager { | |||||
public: | public: | ||||
static BlobManager* inst(); | static BlobManager* inst(); | ||||
void alloc_with_defrag(Blob* blob, size_t size) override; | |||||
void alloc_with_defrag(OwnedBlob* blob, size_t size) override; | |||||
DeviceTensorND alloc_workspace_with_defrag( | DeviceTensorND alloc_workspace_with_defrag( | ||||
CompNode cn, TensorLayout& layout) override; | CompNode cn, TensorLayout& layout) override; | ||||
void register_blob(Blob* blob) override; | |||||
void register_blob(OwnedBlob* blob) override; | |||||
void unregister_blob(Blob* blob) override; | |||||
void unregister_blob(OwnedBlob* blob) override; | |||||
void set_allocator(allocator_t allocator) override; | void set_allocator(allocator_t allocator) override; | ||||
}; | }; | ||||
@@ -704,7 +704,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | |||||
m_dtr.update_used_time(dest); | m_dtr.update_used_time(dest); | ||||
MGB_RECORD_EVENT( | MGB_RECORD_EVENT( | ||||
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), | TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), | ||||
ptr->dev_tensor(false).raw_ptr()); | |||||
ptr->raw_ptr_not_for_readwrite()); | |||||
// update tensor desc for static infer | // update tensor desc for static infer | ||||
if (dest->desc.layout.ndim) { | if (dest->desc.layout.ndim) { | ||||
mgb_assert( | mgb_assert( | ||||
@@ -805,8 +805,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||||
inputs[idx]->to_contiguous_inplace(layout_checker); | inputs[idx]->to_contiguous_inplace(layout_checker); | ||||
} | } | ||||
} | } | ||||
return OpDef::apply_on_physical_tensor( | |||||
auto outputs = OpDef::apply_on_physical_tensor( | |||||
def, std::move(inputs), output_descs, validated); | def, std::move(inputs), output_descs, validated); | ||||
for (auto& o : outputs) { | |||||
o->set_ready_event( | |||||
record_event(o->comp_node(), def.same_type<imperative::Barrier>())); | |||||
} | |||||
return outputs; | |||||
}; | }; | ||||
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason); | MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason); | ||||
SmallVector<std::pair<CompNode, uint64_t>> kernels; | SmallVector<std::pair<CompNode, uint64_t>> kernels; | ||||
@@ -1059,7 +1064,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() { | |||||
return valid_tensors; | return valid_tensors; | ||||
} | } | ||||
void ChannelImpl::alloc_tensor_with_evict(Blob* x) { | |||||
void ChannelImpl::alloc_tensor_with_evict(OwnedBlob* x) { | |||||
bool in_worker = (get_worker_tid() == std::this_thread::get_id()); | bool in_worker = (get_worker_tid() == std::this_thread::get_id()); | ||||
auto reserve_size = [&](size_t size) { | auto reserve_size = [&](size_t size) { | ||||
if (!m_dtr.comp_node.valid()) { | if (!m_dtr.comp_node.valid()) { | ||||
@@ -304,7 +304,7 @@ private: | |||||
//! automatically evict an optimal tensor | //! automatically evict an optimal tensor | ||||
bool auto_evict(size_t); | bool auto_evict(size_t); | ||||
void alloc_tensor_with_evict(Blob*); | |||||
void alloc_tensor_with_evict(OwnedBlob*); | |||||
// assert thread id when call get_xxx_state to avoid misuse | // assert thread id when call get_xxx_state to avoid misuse | ||||
ChannelState& get_channel_state(); | ChannelState& get_channel_state(); | ||||
@@ -249,7 +249,7 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | const OpDef& def, const SmallVector<TensorPtr>& inputs, | ||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | ||||
auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3]; | auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3]; | ||||
if (!(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique())) { | |||||
if (!inputs[0]->storage_is_unique()) { | |||||
mgb_log_warn( | mgb_log_warn( | ||||
"This inplace modification may change the elements of other tensors. " | "This inplace modification may change the elements of other tensors. " | ||||
"Fallback to non-inplace update."); | "Fallback to non-inplace update."); | ||||
@@ -13,7 +13,6 @@ | |||||
#include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "../async_releaser.h" | |||||
#include "../dnn_op_helper.h" | #include "../dnn_op_helper.h" | ||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
@@ -267,8 +266,7 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor( | |||||
{srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), | {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), | ||||
output->dev_tensor().as_megdnn(), | output->dev_tensor().as_megdnn(), | ||||
caller.create_workspace({{ws_size}, dtype::Byte()})); | caller.create_workspace({{ws_size}, dtype::Byte()})); | ||||
AsyncReleaser::inst()->add( | |||||
HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); | |||||
async_release(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); | |||||
return {output}; | return {output}; | ||||
} | } | ||||
@@ -23,6 +23,8 @@ | |||||
#include "megbrain/opr/tensor_gen.h" | #include "megbrain/opr/tensor_gen.h" | ||||
#include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/tensor.h" | |||||
#include "megdnn/dtype.h" | |||||
#if MGB_JIT | #if MGB_JIT | ||||
#include "megbrain/jit/executor_opr.h" | #include "megbrain/jit/executor_opr.h" | ||||
@@ -37,6 +39,102 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | |||||
OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback(); | OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback(); | ||||
namespace { | namespace { | ||||
namespace borrow { | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto& op = def.cast_final_safe<Borrow>(); | |||||
SmallVector<TensorPtr> outputs; | |||||
outputs.reserve(inputs.size()); | |||||
for (auto& i : inputs) { | |||||
outputs.push_back(Tensor::make( | |||||
i->blob()->borrow_to(op.comp_node), i->layout(), i->offset())); | |||||
} | |||||
return outputs; | |||||
} | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size()); | |||||
} | |||||
auto infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto& op = def.cast_final_safe<Borrow>(); | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret(inputs, true); | |||||
for (auto& i : std::get<0>(ret)) { | |||||
mgb_assert( | |||||
i.comp_node.mem_node() == op.comp_node.mem_node(), | |||||
"cannot borrow memory from %s to %s", i.comp_node.to_string().c_str(), | |||||
op.comp_node.to_string().c_str()); | |||||
i.comp_node = op.comp_node; | |||||
} | |||||
return ret; | |||||
} | |||||
OP_TRAIT_REG(Borrow, Borrow) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | |||||
} // namespace borrow | |||||
} // namespace | |||||
namespace { | |||||
namespace barrier { | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
auto& op = def.cast_final_safe<imperative::Barrier>(); | |||||
SmallVector<TensorPtr> outputs; | |||||
for (auto& i : inputs) { | |||||
if (i->comp_node() != op.comp_node) { | |||||
device_wait_event(op.comp_node, i->comp_node(), i->get_ready_event()); | |||||
} | |||||
} | |||||
if (op.nr_outputs) { | |||||
outputs.resize(op.nr_outputs); | |||||
TensorLayout layout(TensorShape({0}), dtype::Int32{}); | |||||
for (auto& i : outputs) { | |||||
i = Tensor::make(layout, op.comp_node); | |||||
} | |||||
} | |||||
return outputs; | |||||
} | |||||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size()); | |||||
} | |||||
auto infer_output_attrs_fallible( | |||||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
auto& op = def.cast_final_safe<imperative::Barrier>(); | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret; | |||||
auto& [descs, checked] = ret; | |||||
descs.resize(op.nr_outputs); | |||||
checked = true; | |||||
for (auto& desc : descs) { | |||||
desc.comp_node = op.comp_node; | |||||
desc.layout.dtype = dtype::Int32{}; | |||||
desc.layout.ndim = 1; | |||||
desc.layout.shape[0] = 0; | |||||
desc.layout.stride[0] = 1; | |||||
} | |||||
return ret; | |||||
} | |||||
OP_TRAIT_REG(Barrier, imperative::Barrier) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
.get_input_layout_constraint(get_input_layout_constraint) | |||||
.fallback(); | |||||
} // namespace barrier | |||||
} // namespace | |||||
namespace { | |||||
namespace fastpathcopy { | namespace fastpathcopy { | ||||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
return inputs; | return inputs; | ||||
@@ -9,57 +9,384 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megbrain/imperative/physical_tensor.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
#include "megbrain/imperative/blob_manager.h" | #include "megbrain/imperative/blob_manager.h" | ||||
#include "megbrain/imperative/profiler.h" | #include "megbrain/imperative/profiler.h" | ||||
#include "megbrain/imperative/resource_manager.h" | #include "megbrain/imperative/resource_manager.h" | ||||
#include "./async_releaser.h" | |||||
#include "./event_pool.h" | #include "./event_pool.h" | ||||
#include "./profiler/events.h" | #include "./profiler/events.h" | ||||
#include <condition_variable> | |||||
#include <cstdint> | |||||
#include <deque> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <mutex> | #include <mutex> | ||||
#include <type_traits> | |||||
#include <unordered_map> | |||||
#include <utility> | |||||
#include <variant> | |||||
#include <vector> | |||||
#ifndef WIN32 | |||||
#include <pthread.h> | |||||
#endif | |||||
#include "range/v3/all.hpp" | |||||
namespace views = ranges::views; | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
namespace { | namespace { | ||||
class CompNodeSyncManager : public CompNodeDepedentObject { | |||||
ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event; | |||||
struct CompNodeHash { | |||||
auto operator()(CompNode cn) const { return mgb::hash(cn); } | |||||
}; | |||||
template <typename T> | |||||
struct NoThrowMovable : T { | |||||
using T::T; | |||||
NoThrowMovable(NoThrowMovable&&) noexcept = default; | |||||
}; | |||||
template <typename... Ts> | |||||
using Map = NoThrowMovable<std::map<Ts...>>; | |||||
class CompNodeSyncManager { | |||||
struct CompNodeData { | |||||
template <typename T> | |||||
class ReleaseQueue { | |||||
Map<uint64_t, T> map; | |||||
public: | |||||
template <typename A> | |||||
void emplace(uint64_t t, A&& a) { | |||||
map.emplace_hint(map.end(), t, std::forward<A>(a)); | |||||
} | |||||
void release(uint64_t t) { | |||||
auto it = map.upper_bound(t); | |||||
map.erase(map.begin(), it); | |||||
} | |||||
}; | |||||
//! next virtual event | |||||
uint64_t next = 1; | |||||
//! last completed virtual event | |||||
uint64_t completed = 0; | |||||
//! virtual event to real event | |||||
Map<uint64_t, EventPtr> events; | |||||
//! ordering information at some virtual events: | |||||
//! which virtual events on other comp nodes are _sequenced before_ this virtual | |||||
//! event | |||||
Map<uint64_t, std::vector<uint64_t>> ordering; | |||||
//! release queue for dev storage, keyed by releaser. this comp node is the | |||||
//! **receiver** | |||||
std::vector<ReleaseQueue<BlobPtr>> release_queues; | |||||
//! release queue for host storage. this comp node is the **releaser** | |||||
ReleaseQueue<HostTensorStorage::RawStorage> host_release_queue; | |||||
}; | |||||
std::mutex m_mtx; | std::mutex m_mtx; | ||||
std::condition_variable m_cv; | |||||
bool m_should_stop = false; | |||||
std::thread m_polling_thread; | |||||
std::unordered_map<CompNode, size_t, CompNodeHash> m_cn2id; | |||||
std::vector<CompNodeData> m_cndata; | |||||
auto do_record(CompNode cn, size_t cnid, std::unique_lock<std::mutex>& lock) { | |||||
// CAUTION: don't keep references across the locking boundary | |||||
lock.unlock(); | |||||
auto e = EventPool::without_timer().alloc(cn); | |||||
e->record(); | |||||
lock.lock(); | |||||
auto& cndata = m_cndata[cnid]; | |||||
return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e); | |||||
} | |||||
std::pair<uint64_t, CompNode::Event*> get_event( | |||||
CompNode cn, size_t cnid, uint64_t t, std::unique_lock<std::mutex>& lock) { | |||||
auto& cndata = m_cndata[cnid]; | |||||
auto it = cndata.events.lower_bound(t); | |||||
if (it == cndata.events.end()) { | |||||
it = do_record(cn, cnid, lock); | |||||
} | |||||
return {it->first, it->second.get()}; | |||||
} | |||||
size_t get_cnid_unsafe(CompNode cn) { | |||||
auto [it, unseen] = m_cn2id.try_emplace(cn, m_cndata.size()); | |||||
if (unseen) { | |||||
m_cndata.emplace_back(); | |||||
} | |||||
return it->second; | |||||
} | |||||
void monitor_events() { | |||||
#if defined(__APPLE__) | |||||
pthread_setname_np("CompNodeSync"); | |||||
#elif defined(__unix__) | |||||
pthread_setname_np(pthread_self(), "CompNodeSync"); | |||||
#endif | |||||
// poll events in rounds, sleeping for a fixed duration between rounds. | |||||
// the number of events to query is decided by the number of successful queries | |||||
// in the last round, independently for each comp node: | |||||
// a. all -> double | |||||
// b. 0 -> 1 | |||||
// c. otherwise -> #successful | |||||
struct Item { | |||||
size_t cnid; | |||||
decltype(CompNodeData::events)::iterator it; | |||||
}; | |||||
struct Stat { | |||||
size_t num_success = 0; | |||||
size_t num_attempts = 0; | |||||
// iterator to the last finished event | |||||
decltype(CompNodeData::events)::iterator it; | |||||
}; | |||||
std::vector<Stat> stats; | |||||
std::vector<Item> todos; | |||||
std::unique_lock lock(m_mtx); | |||||
for (;;) { | |||||
// copy events to a temporary storage so that we may unlock while polling | |||||
stats.resize(m_cndata.size()); | |||||
for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) { | |||||
// decide max number of events to query | |||||
// rule c: #successful | |||||
size_t n = stats[cnid].num_success; | |||||
if (n == stats[cnid].num_attempts) { | |||||
// rule a: double | |||||
n *= 2; | |||||
} | |||||
if (n == 0) { | |||||
// rule b: 1 | |||||
n = 1; | |||||
} | |||||
// now copy up to n events | |||||
auto& events = m_cndata[cnid].events; | |||||
size_t i = 0; | |||||
for (auto it = events.begin(); i < n && it != events.end(); ++i, ++it) { | |||||
todos.push_back({cnid, it}); | |||||
} | |||||
// reset stats for this round | |||||
stats[cnid].num_success = 0; | |||||
stats[cnid].num_attempts = n; | |||||
} | |||||
lock.unlock(); | |||||
bool last_result = false; | |||||
size_t last_cnid = -1; | |||||
for (auto item : todos) { | |||||
if (item.cnid == last_cnid && !last_result) { | |||||
// the previous event on this comp node was unfinished, so this one almost certainly is too | |||||
continue; | |||||
} | |||||
last_cnid = item.cnid; | |||||
last_result = item.it->second->finished(); | |||||
if (last_result) { | |||||
stats[item.cnid].num_success++; | |||||
stats[item.cnid].it = item.it; | |||||
} | |||||
} | |||||
todos.clear(); | |||||
lock.lock(); | |||||
// release dev storage | |||||
for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size(); | |||||
++receiver_cnid) { | |||||
for (size_t releaser_cnid = 0; | |||||
releaser_cnid < m_cndata[receiver_cnid].release_queues.size(); | |||||
++releaser_cnid) { | |||||
if (releaser_cnid >= stats.size() || | |||||
stats[releaser_cnid].num_success == 0) { | |||||
continue; | |||||
} | |||||
auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid]; | |||||
q.release(stats[releaser_cnid].it->first); | |||||
} | |||||
} | |||||
for (size_t cnid = 0; cnid < stats.size(); ++cnid) { | |||||
if (stats[cnid].num_success == 0) { | |||||
continue; | |||||
} | |||||
auto& cndata = m_cndata[cnid]; | |||||
auto it = stats[cnid].it; | |||||
auto t = it->first; | |||||
// update completed | |||||
cndata.completed = t; | |||||
// release host storage | |||||
cndata.host_release_queue.release(t); | |||||
// remove completed events | |||||
auto& events = cndata.events; | |||||
events.erase(events.begin(), std::next(it)); | |||||
} | |||||
using namespace std::literals; | |||||
if (m_cv.wait_for(lock, 10us, [&] { return m_should_stop; })) { | |||||
return; | |||||
} | |||||
} | |||||
} | |||||
CompNodeSyncManager() { | |||||
m_polling_thread = std::thread([this] { monitor_events(); }); | |||||
} | |||||
public: | public: | ||||
std::shared_ptr<void> on_comp_node_finalize() override { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
m_blob2event.clear(); | |||||
return {}; | |||||
~CompNodeSyncManager() { | |||||
{ | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
m_should_stop = true; | |||||
m_cv.notify_all(); | |||||
} | |||||
m_polling_thread.join(); | |||||
} | } | ||||
static CompNodeSyncManager& inst() { | |||||
static auto* sl_inst = ResourceManager::create_global<CompNodeSyncManager>(); | |||||
return *sl_inst; | |||||
static CompNodeSyncManager& inst(); | |||||
uint64_t record(CompNode cn, bool doitnow = false) { | |||||
std::unique_lock lock(m_mtx); | |||||
auto cnid = get_cnid_unsafe(cn); | |||||
if (doitnow) { | |||||
return do_record(cn, cnid, lock)->first; | |||||
} | |||||
return m_cndata[cnid].next++; | |||||
} | } | ||||
CompNode::Event* get_or_create_event(Blob* blob) { | |||||
mgb_assert(!is_finalized()); | |||||
void async_release(CompNode cn, uint64_t t, BlobPtr blob) { | |||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
auto&& e = m_blob2event[blob]; | |||||
if (!e) { | |||||
e = blob->comp_node().create_event(); | |||||
auto releaser_cnid = get_cnid_unsafe(cn); | |||||
if (t <= m_cndata[releaser_cnid].completed) { | |||||
return; | |||||
} | |||||
auto receiver_cnid = get_cnid_unsafe(blob->comp_node()); | |||||
auto& qs = m_cndata[receiver_cnid].release_queues; | |||||
if (releaser_cnid >= qs.size()) { | |||||
qs.resize(releaser_cnid + 1); | |||||
} | } | ||||
return e.get(); | |||||
auto& q = qs[releaser_cnid]; | |||||
q.emplace(t, std::move(blob)); | |||||
} | } | ||||
void remove(Blob* blob) { | |||||
void async_release(CompNode cn, uint64_t t, HostTensorStorage::RawStorage storage) { | |||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
m_blob2event.erase(blob); | |||||
auto releaser_cnid = get_cnid_unsafe(cn); | |||||
if (t <= m_cndata[releaser_cnid].completed) { | |||||
return; | |||||
} | |||||
auto& q = m_cndata[releaser_cnid].host_release_queue; | |||||
q.emplace(t, std::move(storage)); | |||||
} | |||||
void device_wait(CompNode waiter, CompNode waitee, uint64_t t) { | |||||
std::unique_lock lock(m_mtx); | |||||
auto waiter_id = get_cnid_unsafe(waiter); | |||||
auto waitee_id = get_cnid_unsafe(waitee); | |||||
auto& waiter_data = m_cndata.at(waiter_id); | |||||
auto& waitee_data = m_cndata.at(waitee_id); | |||||
auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock); | |||||
// DO NOT unlock around this line! Event* could be invalidated! | |||||
e->device_wait_by(waiter); | |||||
auto t_waiter = waiter_data.next++; | |||||
std::vector<uint64_t> ordering(m_cndata.size(), 0); | |||||
if (!waiter_data.ordering.empty()) { | |||||
auto& o = waiter_data.ordering.rbegin()->second; | |||||
std::copy(o.begin(), o.end(), ordering.begin()); | |||||
} | |||||
ordering[waitee_id] = t_waitee; | |||||
ordering[waiter_id] = t_waiter; | |||||
{ | |||||
auto it = waitee_data.ordering.lower_bound(t_waitee); | |||||
if (it != waitee_data.ordering.begin()) { | |||||
for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) { | |||||
static_assert(std::is_lvalue_reference_v<decltype(a)>); | |||||
a = std::max(a, b); | |||||
} | |||||
} | |||||
} | |||||
waiter_data.ordering.emplace_hint( | |||||
waiter_data.ordering.end(), t_waiter, ordering); | |||||
for (auto [t, q] : views::zip(ordering, waiter_data.release_queues)) { | |||||
q.release(t); | |||||
} | |||||
} | } | ||||
}; | }; | ||||
CompNodeSyncManager& CompNodeSyncManager::inst() { | |||||
static std::mutex mtx; | |||||
static std::unique_ptr<CompNodeSyncManager> inst; | |||||
struct Guard final : CompNodeDepedentObject { | |||||
std::shared_ptr<void> on_comp_node_finalize() override { | |||||
MGB_LOCK_GUARD(mtx); | |||||
inst.reset(); | |||||
return {}; | |||||
} | |||||
}; | |||||
static std::optional<Guard> guard; | |||||
#ifndef WIN32 | |||||
static bool broken = false; | |||||
static struct ForkGuard { | |||||
ForkGuard() { | |||||
mgb_assert(0 == pthread_atfork(NULL, NULL, [] { | |||||
if (inst) { | |||||
inst.release(); // deliberate leak, unfixable | |||||
broken = true; | |||||
} | |||||
})); | |||||
} | |||||
} fork_guard; | |||||
#endif | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (!inst) { | |||||
#ifndef WIN32 | |||||
mgb_assert(!broken); | |||||
#endif | |||||
EventPool::without_timer(); | |||||
inst.reset(new CompNodeSyncManager); | |||||
guard.emplace(); | |||||
} | |||||
return *inst; | |||||
} | |||||
} // namespace | } // namespace | ||||
uint64_t record_event(CompNode cn, bool doitnow) { | |||||
return CompNodeSyncManager::inst().record(cn, doitnow); | |||||
} | |||||
void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event) { | |||||
CompNodeSyncManager::inst().device_wait(waiter, waitee, event); | |||||
} | |||||
void async_release(CompNode cn, uint64_t event, BlobPtr blob) { | |||||
CompNodeSyncManager::inst().async_release(cn, event, std::move(blob)); | |||||
} | |||||
void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage) { | |||||
CompNodeSyncManager::inst().async_release(cn, event, std::move(storage)); | |||||
} | |||||
void EventDeleter::operator()(CompNode::Event* event) { | void EventDeleter::operator()(CompNode::Event* event) { | ||||
EventPool::without_timer().free(event); | EventPool::without_timer().free(event); | ||||
} | } | ||||
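The per-comp-node polling budget in `monitor_events` above follows the three rules spelled out in its comment; a standalone sketch of just that rule (the function name and asserts are mine):

```python
def next_budget(num_success: int, num_attempts: int) -> int:
    """Number of events to query for one comp node in the next polling round."""
    n = num_success         # rule c: query as many events as succeeded last round
    if n == num_attempts:   # rule a: every query succeeded -> double the budget
        n *= 2
    if n == 0:              # rule b: nothing succeeded (or first round) -> probe one event
        n = 1
    return n

assert next_budget(4, 4) == 8
assert next_budget(0, 3) == 1
assert next_budget(2, 5) == 2
```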
@@ -68,31 +395,74 @@ namespace { | |||||
std::atomic_uint64_t next_blob_id = 0; | std::atomic_uint64_t next_blob_id = 0; | ||||
} | } | ||||
Blob::Blob(const DeviceTensorStorage& s) | |||||
: m_comp_node{s.comp_node()}, | |||||
OwnedBlob::OwnedBlob(const DeviceTensorStorage& s) | |||||
: Blob(s.comp_node(), s.size() + s.offset()), | |||||
m_storage{s.raw_storage()}, | m_storage{s.raw_storage()}, | ||||
m_size{s.size() + s.offset()} { | |||||
m_id = next_blob_id++; | |||||
m_id{next_blob_id++} { | |||||
BlobManager::inst()->register_blob(this); | BlobManager::inst()->register_blob(this); | ||||
} | } | ||||
Blob::Blob(CompNode cn, size_t sz) : m_comp_node{cn}, m_storage{}, m_size{sz} { | |||||
m_id = next_blob_id++; | |||||
OwnedBlob::OwnedBlob(CompNode cn, size_t sz) | |||||
: Blob(cn, sz), m_storage{}, m_id{next_blob_id++} { | |||||
BlobManager::inst()->register_blob(this); | BlobManager::inst()->register_blob(this); | ||||
} | } | ||||
Blob::~Blob() { | |||||
OwnedBlob::~OwnedBlob() { | |||||
BlobManager::inst()->unregister_blob(this); | BlobManager::inst()->unregister_blob(this); | ||||
CompNodeSyncManager::inst().remove(this); | |||||
} | } | ||||
const Blob::RawStorage& Blob::storage() { | |||||
const Blob::RawStorage& OwnedBlob::storage() { | |||||
if (!m_storage && m_size) { | if (!m_storage && m_size) { | ||||
BlobManager::inst()->alloc_with_defrag(this, m_size); | BlobManager::inst()->alloc_with_defrag(this, m_size); | ||||
} | } | ||||
return m_storage; | return m_storage; | ||||
} | } | ||||
BlobPtr OwnedBlob::borrow_to(CompNode cn) { | |||||
return std::make_shared<BorrowedBlob>( | |||||
cn, std::static_pointer_cast<OwnedBlob>(shared_from_this())); | |||||
} | |||||
bool OwnedBlob::storage_is_unique() { | |||||
return m_storage.unique(); | |||||
} | |||||
void* OwnedBlob::raw_ptr_not_for_readwrite() { | |||||
return m_storage.get(); | |||||
} | |||||
BorrowedBlob::BorrowedBlob(CompNode cn, std::shared_ptr<OwnedBlob> owner) | |||||
: Blob(cn, owner->size()), | |||||
m_owner(std::move(owner)), | |||||
m_event(record_event(m_owner->comp_node(), true)) {} | |||||
BorrowedBlob::~BorrowedBlob() { | |||||
async_release(m_comp_node, record_event(m_comp_node, true), std::move(m_owner)); | |||||
} | |||||
const Blob::RawStorage& BorrowedBlob::storage() { | |||||
{ | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
if (!m_initialized) { | |||||
device_wait_event(m_comp_node, m_owner->comp_node(), m_event); | |||||
m_initialized = true; | |||||
} | |||||
} | |||||
return m_owner->storage(); | |||||
} | |||||
BlobPtr BorrowedBlob::borrow_to(CompNode cn) { | |||||
return std::make_shared<BorrowedBlob>(cn, m_owner); | |||||
} | |||||
bool BorrowedBlob::storage_is_unique() { | |||||
return m_owner.unique() && m_owner->storage_is_unique(); | |||||
} | |||||
void* BorrowedBlob::raw_ptr_not_for_readwrite() { | |||||
return m_owner->raw_ptr_not_for_readwrite(); | |||||
} | |||||
Tensor::Tensor( | Tensor::Tensor( | ||||
BlobPtr blob, const TensorLayout& layout, size_t offset, const HostTensorND& hv) | BlobPtr blob, const TensorLayout& layout, size_t offset, const HostTensorND& hv) | ||||
: m_cn(blob->comp_node()), | : m_cn(blob->comp_node()), | ||||
@@ -119,7 +489,7 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { | |||||
MGB_RECORD_EVENT( | MGB_RECORD_EVENT( | ||||
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | ||||
hv.raw_ptr(), dev_tensor().raw_ptr()); | hv.raw_ptr(), dev_tensor().raw_ptr()); | ||||
AsyncReleaser::inst()->add(hv); | |||||
async_release(hv); | |||||
} | } | ||||
} | } | ||||
@@ -174,7 +544,7 @@ void Tensor::to_contiguous_inplace(VarNode::LayoutConstraintCallback& layout_che | |||||
DeviceTensorND dv_contig; | DeviceTensorND dv_contig; | ||||
dv_contig.copy_from(dv); | dv_contig.copy_from(dv); | ||||
m_layout = dv_contig.layout(); | m_layout = dv_contig.layout(); | ||||
std::atomic_store(&m_blob, Blob::make(dv_contig.storage())); | |||||
std::atomic_store(&m_blob, BlobPtr(Blob::make(dv_contig.storage()))); | |||||
mgb_assert(m_layout.is_contiguous()); | mgb_assert(m_layout.is_contiguous()); | ||||
m_offset = 0; | m_offset = 0; | ||||
} | } | ||||
@@ -188,7 +558,7 @@ void Tensor::to_contiguous_inplace() { | |||||
void Tensor::assign_from_dev_tensor(DeviceTensorND dv) { | void Tensor::assign_from_dev_tensor(DeviceTensorND dv) { | ||||
MGB_LOCK_GUARD(m_blob_mtx); | MGB_LOCK_GUARD(m_blob_mtx); | ||||
std::atomic_store(&m_blob, Blob::make(dv.storage())); | |||||
std::atomic_store(&m_blob, BlobPtr(Blob::make(dv.storage()))); | |||||
m_offset = dv.storage().offset(); | m_offset = dv.storage().offset(); | ||||
m_layout = dv.layout(); | m_layout = dv.layout(); | ||||
} | } | ||||
@@ -254,21 +624,20 @@ TensorPtr Tensor::sub(size_t offset, TensorShape shape) { | |||||
return Tensor::make(m_blob, offset + m_offset, layout); | return Tensor::make(m_blob, offset + m_offset, layout); | ||||
} | } | ||||
void Tensor::add_release_callback(CompNode cn) { | |||||
AsyncReleaser::inst()->add(m_blob, cn); | |||||
uint64_t Tensor::get_ready_event() { | |||||
if (m_produced_at == 0) { | |||||
m_produced_at = record_event(comp_node()); | |||||
} | |||||
return m_produced_at; | |||||
} | } | ||||
CompNode::Event* Tensor::get_or_create_event() { | |||||
auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get()); | |||||
e->record(); | |||||
return e; | |||||
bool Tensor::storage_is_unique() { | |||||
return m_blob.unique() && m_blob->storage_is_unique(); | |||||
} | } | ||||
void Tensor::static_initialize() { | void Tensor::static_initialize() { | ||||
EventPool::with_timer(); | EventPool::with_timer(); | ||||
EventPool::without_timer(); | EventPool::without_timer(); | ||||
AsyncReleaser::inst(); | |||||
CompNodeSyncManager::inst(); | |||||
MultiCNConstTensorCache::inst(); | MultiCNConstTensorCache::inst(); | ||||
} | } | ||||
@@ -836,8 +836,8 @@ public: | |||||
if (used_cns.insert(cn).second) { | if (used_cns.insert(cn).second) { | ||||
for (auto&& in : inputs) { | for (auto&& in : inputs) { | ||||
if (in->comp_node() != cn) { | if (in->comp_node() != cn) { | ||||
auto&& e = in->get_or_create_event(); | |||||
e->device_wait_by(cn); | |||||
auto e = in->get_ready_event(); | |||||
device_wait_event(cn, in->comp_node(), e); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -847,11 +847,17 @@ public: | |||||
// so we need create inference session here | // so we need create inference session here | ||||
minigraph.execute(raw_inputs, raw_outputs, m_env); | minigraph.execute(raw_inputs, raw_outputs, m_env); | ||||
for (auto&& cn : used_cns) { | for (auto&& cn : used_cns) { | ||||
bool should_record = false; | |||||
for (auto&& in : inputs) { | for (auto&& in : inputs) { | ||||
if (in->comp_node() != cn) { | if (in->comp_node() != cn) { | ||||
in->add_release_callback(cn); | |||||
should_record = true; | |||||
auto e = record_event(cn); | |||||
async_release(cn, e, *in); | |||||
} | } | ||||
} | } | ||||
if (should_record) { | |||||
record_event(cn, true); | |||||
} | |||||
} | } | ||||
return outputs; | return outputs; | ||||
@@ -15,7 +15,6 @@ | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "../async_releaser.h" | |||||
#include "../mgb_cg_impl.h" | #include "../mgb_cg_impl.h" | ||||
namespace mgb { | namespace mgb { | ||||
@@ -73,7 +72,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||||
args.device.emplace(); | args.device.emplace(); | ||||
args.device->copy_from(*args.host); | args.device->copy_from(*args.host); | ||||
// every h2d in imperative runtime should notify AsyncReleaser | // every h2d in imperative runtime should notify AsyncReleaser | ||||
AsyncReleaser::inst()->add(*args.host); | |||||
async_release(*args.host); | |||||
} | } | ||||
return *args.device; | return *args.device; | ||||
}; | }; | ||||
@@ -155,7 +154,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||||
host_value.copy_from(inferred_value); | host_value.copy_from(inferred_value); | ||||
DeviceTensorND dev_value; | DeviceTensorND dev_value; | ||||
dev_value.copy_from(host_value); | dev_value.copy_from(host_value); | ||||
AsyncReleaser::inst()->add(host_value); | |||||
async_release(host_value); | |||||
return {DeviceValue::make(dev_value)}; | return {DeviceValue::make(dev_value)}; | ||||
} | } | ||||
default: | default: | ||||
@@ -24,18 +24,18 @@ public: | |||||
static BlobManager* inst(); | static BlobManager* inst(); | ||||
virtual void alloc_direct(Blob* blob, size_t size) = 0; | |||||
virtual void alloc_direct(OwnedBlob* blob, size_t size) = 0; | |||||
virtual void alloc_with_defrag(Blob* blob, size_t size) = 0; | |||||
virtual void alloc_with_defrag(OwnedBlob* blob, size_t size) = 0; | |||||
virtual void set_allocator(allocator_t allocator) = 0; | virtual void set_allocator(allocator_t allocator) = 0; | ||||
virtual DeviceTensorND alloc_workspace_with_defrag( | virtual DeviceTensorND alloc_workspace_with_defrag( | ||||
CompNode cn, TensorLayout& layout) = 0; | CompNode cn, TensorLayout& layout) = 0; | ||||
virtual void register_blob(Blob* blob) = 0; | |||||
virtual void register_blob(OwnedBlob* blob) = 0; | |||||
virtual void unregister_blob(Blob* blob) = 0; | |||||
virtual void unregister_blob(OwnedBlob* blob) = 0; | |||||
virtual void defrag(const CompNode& cn) = 0; | virtual void defrag(const CompNode& cn) = 0; | ||||
}; | }; | ||||
@@ -11,12 +11,16 @@ | |||||
#pragma once | #pragma once | ||||
#include <cstdint> | |||||
#include <memory> | #include <memory> | ||||
#include <mutex> | #include <mutex> | ||||
#include <type_traits> | |||||
#include <variant> | |||||
#include "megbrain/graph.h" | #include "megbrain/graph.h" | ||||
#include "megbrain/imperative/resource_manager.h" | #include "megbrain/imperative/resource_manager.h" | ||||
#include "megbrain/tensor.h" | #include "megbrain/tensor.h" | ||||
#include "megbrain/utils/metahelper.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -26,35 +30,67 @@ class Blob; | |||||
using BlobPtr = std::shared_ptr<Blob>; | using BlobPtr = std::shared_ptr<Blob>; | ||||
class BlobManagerImpl; | class BlobManagerImpl; | ||||
class OwnedBlob; | |||||
class Blob : public NonCopyableObj, public std::enable_shared_from_this<Blob> { | |||||
protected: | |||||
CompNode m_comp_node; | |||||
size_t m_size; | |||||
Blob(CompNode cn, size_t size) : m_comp_node(cn), m_size(size) {} | |||||
class Blob : public NonCopyableObj { | |||||
public: | public: | ||||
Blob(const DeviceTensorStorage& s); | |||||
Blob(CompNode cn, size_t sz); | |||||
~Blob(); | |||||
virtual ~Blob() = default; | |||||
template <typename... Args> | template <typename... Args> | ||||
static BlobPtr make(Args&&... args) { | |||||
return std::make_shared<Blob>(std::forward<Args>(args)...); | |||||
static std::shared_ptr<OwnedBlob> make(Args&&... args) { | |||||
return std::make_shared<OwnedBlob>(std::forward<Args>(args)...); | |||||
} | } | ||||
const CompNode& comp_node() const { return m_comp_node; } | |||||
size_t size() const { return m_size; } | |||||
using RawStorage = DeviceTensorStorage::RawStorage; | using RawStorage = DeviceTensorStorage::RawStorage; | ||||
const RawStorage& storage(); | |||||
virtual const RawStorage& storage() = 0; | |||||
virtual BlobPtr borrow_to(CompNode) = 0; | |||||
virtual bool storage_is_unique() = 0; | |||||
virtual void* raw_ptr_not_for_readwrite() = 0; | |||||
}; | |||||
const CompNode& comp_node() const { return m_comp_node; } | |||||
class OwnedBlob final : public Blob { | |||||
friend class Blob; | |||||
size_t size() const { return m_size; } | |||||
public: | |||||
OwnedBlob(const DeviceTensorStorage& s); | |||||
OwnedBlob(CompNode cn, size_t sz); | |||||
~OwnedBlob() override; | |||||
size_t id() const { return m_id; } | |||||
const RawStorage& storage() override; | |||||
BlobPtr borrow_to(CompNode) override; | |||||
bool storage_is_unique() override; | |||||
void* raw_ptr_not_for_readwrite() override; | |||||
private: | private: | ||||
friend class BlobManagerImpl; | friend class BlobManagerImpl; | ||||
CompNode m_comp_node; | |||||
mutable RawStorage m_storage; | |||||
size_t m_size = 0; | |||||
RawStorage m_storage; | |||||
size_t m_id; | size_t m_id; | ||||
}; | }; | ||||
class BorrowedBlob final : public Blob { | |||||
std::mutex m_mtx; | |||||
std::shared_ptr<OwnedBlob> m_owner; | |||||
uint64_t m_event; | |||||
bool m_initialized = false; | |||||
public: | |||||
BorrowedBlob(CompNode, std::shared_ptr<OwnedBlob>); | |||||
~BorrowedBlob() override; | |||||
const RawStorage& storage() override; | |||||
BlobPtr borrow_to(CompNode) override; | |||||
bool storage_is_unique() override; | |||||
void* raw_ptr_not_for_readwrite() override; | |||||
}; | |||||
struct EventDeleter { | struct EventDeleter { | ||||
void operator()(CompNode::Event*); | void operator()(CompNode::Event*); | ||||
}; | }; | ||||
@@ -121,6 +157,8 @@ public: | |||||
BlobPtr& blob() { return m_blob; } | BlobPtr& blob() { return m_blob; } | ||||
void* raw_ptr_not_for_readwrite() { return m_blob->raw_ptr_not_for_readwrite(); } | |||||
void fetch_value(); | void fetch_value(); | ||||
bool value_fetched(); | bool value_fetched(); | ||||
TensorPtr sub(size_t offset, TensorShape shape); | TensorPtr sub(size_t offset, TensorShape shape); | ||||
@@ -131,8 +169,10 @@ public: | |||||
// return a pointer instead of a reference to ensure thread safety | // return a pointer instead of a reference to ensure thread safety | ||||
const HostTensorND* try_get_value(); | const HostTensorND* try_get_value(); | ||||
void add_release_callback(CompNode cn); | |||||
CompNode::Event* get_or_create_event(); | |||||
void set_ready_event(uint64_t event) { m_produced_at = event; } | |||||
uint64_t get_ready_event(); | |||||
bool storage_is_unique(); | |||||
// Make sure all static objects required to destruct a tensor has completed | // Make sure all static objects required to destruct a tensor has completed | ||||
// construction. All static storage duration object that holds tensors must | // construction. All static storage duration object that holds tensors must | ||||
@@ -152,8 +192,37 @@ private: | |||||
std::mutex m_value_mtx; | std::mutex m_value_mtx; | ||||
HostTensorND m_value; | HostTensorND m_value; | ||||
EventPtr m_value_ready = nullptr; | EventPtr m_value_ready = nullptr; | ||||
uint64_t m_produced_at = 0; | |||||
}; | }; | ||||
/*! | |||||
* \brief record a virtual event | |||||
* \param doitnow also record a real event | |||||
*/ | |||||
uint64_t record_event(CompNode cn, bool doitnow = false); | |||||
//! make a device wait on a virtual event | |||||
void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event); | |||||
//! hold a blob until a virtual event on a device is completed | |||||
void async_release(CompNode cn, uint64_t event, BlobPtr blob); | |||||
//! hold a host tensor until a virtual event on a device is completed | |||||
void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage); | |||||
inline void async_release(CompNode cn, uint64_t event, Tensor& tensor) { | |||||
async_release(cn, event, tensor.blob()); | |||||
} | |||||
inline void async_release(CompNode cn, uint64_t event, const HostTensorND& hnd) { | |||||
async_release(cn, event, hnd.storage().raw_storage()); | |||||
} | |||||
inline void async_release(const HostTensorND& hnd) { | |||||
auto cn = hnd.comp_node(); | |||||
async_release(cn, record_event(cn, true), hnd); | |||||
} | |||||
// Cache for small blobs | // Cache for small blobs | ||||
// 1. A blob has to be seen twice (within a window) to be eligible for cache | // 1. A blob has to be seen twice (within a window) to be eligible for cache | ||||
// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently | // 2. Cache eviction occurs when cache size reaches a threshold, in least frequently | ||||
@@ -32,7 +32,7 @@ bool CompNodeImplHelper::EventImplHelper::finished() { | |||||
mgb_assert(m_recorded); | mgb_assert(m_recorded); | ||||
if (do_finished()) { | if (do_finished()) { | ||||
m_finished = true; | m_finished = true; | ||||
m_recorded = false; | |||||
// m_recorded = false; | |||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
@@ -120,6 +120,19 @@ def Copy: MgbHashableOp<"Copy"> { | |||||
); | ); | ||||
} | } | ||||
def Borrow: MgbHashableOp<"Borrow"> { | |||||
let extraArguments = (ins | |||||
MgbCompNodeAttr:$comp_node | |||||
); | |||||
} | |||||
def Barrier: MgbHashableOp<"Barrier"> { | |||||
let extraArguments = (ins | |||||
MgbCompNodeAttr:$comp_node, | |||||
MgbUI32Attr:$nr_outputs | |||||
); | |||||
} | |||||
def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>; | def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>; | ||||
def Argmax : MgbHashableOp<"Argmax", [AxisParam]>; | def Argmax : MgbHashableOp<"Argmax", [AxisParam]>; | ||||
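For orientation, these declarations back the Python pieces earlier in the diff. A rough sketch of the mapping, with the constructor calls written the way `_multistream.py` and `Tensor.to` use them (whether the comp-node attributes also accept plain device strings everywhere is an assumption based on those call sites):

```python
import megengine as mge
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import Barrier, Borrow

x = mge.Tensor([], device="xpux:1")              # dummy tensor, as in _multistream.py
(event,) = apply(Barrier("xpux:1", 1), x)        # nr_outputs=1 -> record_event("xpux:1")
apply(Barrier("xpux:0", 0), event)               # nr_outputs=0 -> wait_event("xpux:0", event)
(view,) = apply(Borrow(comp_node="xpux:0"), x)   # what Tensor.to(device, _borrow=True) applies
```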