GitOrigin-RevId: c68977c5dc
release-1.10
@@ -0,0 +1,21 @@ | |||
from .core._imperative_rt.core2 import apply | |||
from .core.ops.builtin import Barrier | |||
from .tensor import Tensor | |||
_dummy_tensors = {} | |||
def _get_dummy_tensor(device): | |||
if device not in _dummy_tensors: | |||
_dummy_tensors[device] = Tensor([], device=device) | |||
return _dummy_tensors[device] | |||
def record_event(device): | |||
x = _get_dummy_tensor(device) | |||
(x,) = apply(Barrier(device, 1), x) | |||
return x | |||
def wait_event(device, event): | |||
apply(Barrier(device, 0), event) |
@@ -51,7 +51,7 @@ _sh = _stream_helper() | |||
def _valid_device(inp): | |||
if isinstance(inp, str) and re.match( | |||
"^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp | |||
"^([cxg]pu|rocm|multithread)(x|\d+)(:\d+)?$", inp | |||
): | |||
return True | |||
return False | |||
@@ -255,8 +255,8 @@ def what_is_xpu(): | |||
def coalesce_free_memory(): | |||
r"""This function will try it best to free all consecutive free chunks back to operating system, | |||
small pieces may not be returned. | |||
because of the async processing of megengine, the effect of this func may not be reflected | |||
because of the async processing of megengine, the effect of this func may not be reflected | |||
immediately. if you want to see the effect immediately, you can call megengine._full_sync after | |||
this func was called | |||
@@ -15,7 +15,7 @@ from .core._imperative_rt.core2 import Tensor as _Tensor | |||
from .core._imperative_rt.core2 import apply, set_py_tensor_type | |||
from .core._trace_option import use_symbolic_shape | |||
from .core._wrap import as_device | |||
from .core.ops.builtin import Copy, GetVarShape | |||
from .core.ops.builtin import Borrow, Copy, GetVarShape | |||
from .core.tensor.array_method import ArrayMethodMixin | |||
from .device import _valid_device, get_default_device | |||
from .logger import get_logger | |||
@@ -205,7 +205,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
def reset_zero(self): | |||
self *= 0 | |||
def to(self, device): | |||
def to(self, device, *, _borrow=False): | |||
r"""Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`""" | |||
if isinstance(device, str) and not _valid_device(device): | |||
raise ValueError( | |||
@@ -214,7 +214,8 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
) | |||
) | |||
cn = as_device(device).to_c() | |||
return apply(Copy(comp_node=cn), self)[0] | |||
op = Borrow(comp_node=cn) if _borrow else Copy(comp_node=cn) | |||
return apply(op, self)[0] | |||
@property | |||
def requires_grad(self): | |||
@@ -232,11 +233,11 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
return id(self) | |||
def __getnewargs__(self): | |||
r""" __getnewargs__ will be called for pickle serialization or deep copy""" | |||
r"""__getnewargs__ will be called for pickle serialization or deep copy""" | |||
return (self.numpy(), self.dtype, self.device.logical_name) | |||
def __getstate__(self): | |||
r""" __getstate__ will be called for pickle serialization or deep copy""" | |||
r"""__getstate__ will be called for pickle serialization or deep copy""" | |||
state = {} | |||
if self._qparams is not None: | |||
state["qparams"] = self._qparams | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include <list> | |||
#include "megbrain/imperative/transformations/trace.h" | |||
#include "megbrain/imperative/utils/map.h" | |||
#include "megbrain/imperative/utils/stats.h" | |||
@@ -998,6 +998,9 @@ void init_tensor(py::module m) { | |||
module_trace_hook.inc_ref(); | |||
}); | |||
auto atexit = py::module::import("atexit"); | |||
atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; })); | |||
m.def("begin_record_values", [] { Value::begin_record_values(); }); | |||
m.def("end_record_values", [] { | |||
@@ -0,0 +1,74 @@ | |||
import gc | |||
import numpy as np | |||
import pytest | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine._multistream import record_event, wait_event | |||
class MemStat: | |||
def __init__(self, *args): | |||
for d in args: | |||
mge.Tensor([], device=d) | |||
gc.collect() | |||
mge._full_sync() | |||
self.baseline = {d: mge.device.get_allocated_memory(d) for d in args} | |||
for d in args: | |||
mge.device.reset_max_memory_stats(d) | |||
def get_max(self, device): | |||
return mge.device.get_max_allocated_memory(device) - self.baseline[device] | |||
@pytest.mark.require_ngpu(1) | |||
def test_mem_stats(): | |||
memstat = MemStat("xpux:0", "xpux:1") | |||
F.arange(1024, device="xpux:0") | |||
mge._full_sync() | |||
assert 4096 <= memstat.get_max("xpux:0") == memstat.get_max("xpux:1") <= 4096 + 128 | |||
@pytest.mark.require_ngpu(1) | |||
def test_borrow(): | |||
memstat = MemStat("xpux:0", "xpux:1") | |||
x_np = np.random.randint(2 ** 30, size=(1 * 1024 * 1024,), dtype="int32") | |||
unit = x_np.size * 4 | |||
x0 = mge.Tensor(x_np, device="xpux:0") | |||
x1 = x0.to("xpux:1", _borrow=True) | |||
y = -x1 | |||
np.testing.assert_equal(-x_np, y.numpy()) | |||
mge._full_sync() | |||
assert memstat.get_max("xpux:0") / unit < 2.1 | |||
@pytest.mark.require_ngpu(1) | |||
def test_stream_mem(): | |||
memstat = MemStat("xpux:0", "xpux:1") | |||
x_np = np.random.randint(2 ** 10, size=(1 * 1024 * 1024,), dtype="int32") | |||
unit = x_np.size * 4 | |||
x0 = mge.Tensor(x_np, device="xpux:0") | |||
results = [] | |||
events = [] | |||
for i in range(100): | |||
if len(events) >= 2: | |||
wait_event("xpux:0", events[-2]) | |||
x0 = x0 + 1 | |||
results.append(x0.to("xpux:1", _borrow=True).sum()) | |||
events.append(record_event("xpux:1")) | |||
del events[:-2] | |||
y_np = x_np.sum() | |||
for i, y in enumerate(results): | |||
y_np += x_np.size | |||
assert y_np == y.numpy() | |||
mge._full_sync() | |||
assert memstat.get_max("xpux:0") / unit < 2.1 |
@@ -1,84 +0,0 @@ | |||
/** | |||
* \file imperative/src/impl/async_releaser.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/imperative/blob_manager.h" | |||
#include "megbrain/imperative/resource_manager.h" | |||
#include "megbrain/system.h" | |||
#include "./event_pool.h" | |||
namespace mgb { | |||
namespace imperative { | |||
class AsyncReleaser : public CompNodeDepedentObject { | |||
struct WaiterParam { | |||
CompNode cn; | |||
CompNode::Event* event; | |||
BlobPtr blob; | |||
HostTensorStorage::RawStorage storage; | |||
}; | |||
class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> { | |||
AsyncReleaser* m_par_releaser; | |||
public: | |||
// disable busy wait by set max_spin=0 to save CPU cycle | |||
Waiter(AsyncReleaser* releaser) | |||
: AsyncQueueSC<WaiterParam, Waiter>(0), m_par_releaser(releaser) {} | |||
void process_one_task(WaiterParam& param) { | |||
if (param.event->finished()) { | |||
param.blob.reset(); | |||
param.storage.reset(); | |||
EventPool::without_timer().free(param.event); | |||
return; | |||
} | |||
using namespace std::literals; | |||
std::this_thread::sleep_for(1us); | |||
add_task(std::move(param)); | |||
} | |||
void on_async_queue_worker_thread_start() override { | |||
sys::set_thread_name("releaser"); | |||
} | |||
}; | |||
Waiter m_waiter{this}; | |||
protected: | |||
std::shared_ptr<void> on_comp_node_finalize() override { | |||
m_waiter.wait_task_queue_empty(); | |||
return {}; | |||
} | |||
public: | |||
static AsyncReleaser* inst() { | |||
static auto* releaser = ResourceManager::create_global<AsyncReleaser>(); | |||
return releaser; | |||
} | |||
~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } | |||
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); } | |||
void add(const HostTensorND& hv) { | |||
add(hv.comp_node(), {}, hv.storage().raw_storage()); | |||
} | |||
void add(CompNode cn, BlobPtr blob, HostTensorStorage::RawStorage storage = {}) { | |||
auto event = EventPool::without_timer().alloc(cn); | |||
event->record(); | |||
m_waiter.add_task({cn, event, std::move(blob), std::move(storage)}); | |||
} | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb |
@@ -16,7 +16,7 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { | |||
BlobManagerImpl::BlobData::BlobData(OwnedBlob* in_blob) { | |||
blob = in_blob; | |||
DeviceTensorStorage d_storage; | |||
d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage); | |||
@@ -28,19 +28,19 @@ BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { | |||
h_storage.copy_from(const_cast<DeviceTensorStorage&>(d_storage), blob->m_size); | |||
} | |||
void BlobManagerImpl::register_blob(Blob* blob) { | |||
void BlobManagerImpl::register_blob(OwnedBlob* blob) { | |||
// add blob into the comp2blobs map | |||
MGB_LOCK_GUARD(m_mtx); | |||
mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob)); | |||
} | |||
void BlobManagerImpl::unregister_blob(Blob* blob) { | |||
void BlobManagerImpl::unregister_blob(OwnedBlob* blob) { | |||
// erase blob into the comp2blobs map | |||
MGB_LOCK_GUARD(m_mtx); | |||
mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob)); | |||
} | |||
void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { | |||
void BlobManagerImpl::alloc_with_defrag(OwnedBlob* blob, size_t size) { | |||
if (custom_allocator) { | |||
blob->m_storage = custom_allocator(blob->m_comp_node, size); | |||
return; | |||
@@ -55,7 +55,7 @@ void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { | |||
}); | |||
} | |||
void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) { | |||
void BlobManagerImpl::alloc_direct(OwnedBlob* blob, size_t size) { | |||
DeviceTensorStorage storage(blob->m_comp_node); | |||
mgb_assert(blob->m_comp_node.valid()); | |||
storage.ensure_size(size); | |||
@@ -143,7 +143,7 @@ void BlobManagerImpl::defrag(const CompNode& cn) { | |||
// sort blobs by created time, may be helpful for reduce memory fragment | |||
std::sort( | |||
blob_data_arrary.begin(), blob_data_arrary.end(), | |||
[](auto& lhs, auto& rhs) { return lhs.blob->id() < rhs.blob->id(); }); | |||
[](auto& lhs, auto& rhs) { return lhs.blob->m_id < rhs.blob->m_id; }); | |||
// allocate for each storage | |||
for (auto i : blob_data_arrary) { | |||
@@ -158,19 +158,19 @@ void BlobManagerImpl::defrag(const CompNode& cn) { | |||
} | |||
struct BlobManagerStub : BlobManager { | |||
void alloc_direct(Blob* blob, size_t size) { | |||
void alloc_direct(OwnedBlob* blob, size_t size) { | |||
mgb_assert(0, "prohibited after global variable destruction"); | |||
}; | |||
void alloc_with_defrag(Blob* blob, size_t size) { | |||
void alloc_with_defrag(OwnedBlob* blob, size_t size) { | |||
mgb_assert(0, "prohibited after global variable destruction"); | |||
}; | |||
DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) { | |||
mgb_assert(0, "prohibited after global variable destruction"); | |||
}; | |||
void register_blob(Blob* blob) { | |||
void register_blob(OwnedBlob* blob) { | |||
mgb_assert(0, "prohibited after global variable destruction"); | |||
}; | |||
void unregister_blob(Blob* blob){}; | |||
void unregister_blob(OwnedBlob* blob){}; | |||
void defrag(const CompNode& cn) { | |||
mgb_assert(0, "prohibited after global variable destruction"); | |||
}; | |||
@@ -19,21 +19,21 @@ namespace imperative { | |||
class BlobManagerImpl final : public BlobManager { | |||
struct BlobSetWithMux { | |||
std::mutex mtx; | |||
ThinHashSet<Blob*> blobs_set; | |||
bool insert(Blob* blob) { | |||
ThinHashSet<OwnedBlob*> blobs_set; | |||
bool insert(OwnedBlob* blob) { | |||
MGB_LOCK_GUARD(mtx); | |||
return blobs_set.insert(blob).second; | |||
} | |||
size_t erase(Blob* blob) { | |||
size_t erase(OwnedBlob* blob) { | |||
MGB_LOCK_GUARD(mtx); | |||
return blobs_set.erase(blob); | |||
} | |||
}; | |||
struct BlobData { | |||
Blob* blob; | |||
OwnedBlob* blob; | |||
HostTensorStorage h_storage; | |||
BlobData(Blob* in_blob); | |||
BlobData(OwnedBlob* in_blob); | |||
}; | |||
std::mutex m_mtx; | |||
@@ -41,7 +41,7 @@ class BlobManagerImpl final : public BlobManager { | |||
void defrag(const CompNode& cn) override; | |||
void alloc_direct(Blob* blob, size_t size) override; | |||
void alloc_direct(OwnedBlob* blob, size_t size) override; | |||
DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout); | |||
@@ -50,14 +50,14 @@ class BlobManagerImpl final : public BlobManager { | |||
public: | |||
static BlobManager* inst(); | |||
void alloc_with_defrag(Blob* blob, size_t size) override; | |||
void alloc_with_defrag(OwnedBlob* blob, size_t size) override; | |||
DeviceTensorND alloc_workspace_with_defrag( | |||
CompNode cn, TensorLayout& layout) override; | |||
void register_blob(Blob* blob) override; | |||
void register_blob(OwnedBlob* blob) override; | |||
void unregister_blob(Blob* blob) override; | |||
void unregister_blob(OwnedBlob* blob) override; | |||
void set_allocator(allocator_t allocator) override; | |||
}; | |||
@@ -704,7 +704,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | |||
m_dtr.update_used_time(dest); | |||
MGB_RECORD_EVENT( | |||
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), | |||
ptr->dev_tensor(false).raw_ptr()); | |||
ptr->raw_ptr_not_for_readwrite()); | |||
// update tensor desc for static infer | |||
if (dest->desc.layout.ndim) { | |||
mgb_assert( | |||
@@ -805,8 +805,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
inputs[idx]->to_contiguous_inplace(layout_checker); | |||
} | |||
} | |||
return OpDef::apply_on_physical_tensor( | |||
auto outputs = OpDef::apply_on_physical_tensor( | |||
def, std::move(inputs), output_descs, validated); | |||
for (auto& o : outputs) { | |||
o->set_ready_event( | |||
record_event(o->comp_node(), def.same_type<imperative::Barrier>())); | |||
} | |||
return outputs; | |||
}; | |||
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason); | |||
SmallVector<std::pair<CompNode, uint64_t>> kernels; | |||
@@ -1059,7 +1064,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() { | |||
return valid_tensors; | |||
} | |||
void ChannelImpl::alloc_tensor_with_evict(Blob* x) { | |||
void ChannelImpl::alloc_tensor_with_evict(OwnedBlob* x) { | |||
bool in_worker = (get_worker_tid() == std::this_thread::get_id()); | |||
auto reserve_size = [&](size_t size) { | |||
if (!m_dtr.comp_node.valid()) { | |||
@@ -304,7 +304,7 @@ private: | |||
//! automatically evict an optimal tensor | |||
bool auto_evict(size_t); | |||
void alloc_tensor_with_evict(Blob*); | |||
void alloc_tensor_with_evict(OwnedBlob*); | |||
// assert thread id when call get_xxx_state to avoid misuse | |||
ChannelState& get_channel_state(); | |||
@@ -249,7 +249,7 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3]; | |||
if (!(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique())) { | |||
if (!inputs[0]->storage_is_unique()) { | |||
mgb_log_warn( | |||
"This inplace modification may change the elements of other tensors. " | |||
"Fallback to non-inplace update."); | |||
@@ -13,7 +13,6 @@ | |||
#include "megbrain/imperative/ops/autogen.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "../async_releaser.h" | |||
#include "../dnn_op_helper.h" | |||
#include "../op_trait.h" | |||
@@ -267,8 +266,7 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor( | |||
{srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), | |||
output->dev_tensor().as_megdnn(), | |||
caller.create_workspace({{ws_size}, dtype::Byte()})); | |||
AsyncReleaser::inst()->add( | |||
HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); | |||
async_release(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); | |||
return {output}; | |||
} | |||
@@ -23,6 +23,8 @@ | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/tensor.h" | |||
#include "megdnn/dtype.h" | |||
#if MGB_JIT | |||
#include "megbrain/jit/executor_opr.h" | |||
@@ -37,6 +39,102 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); | |||
OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback(); | |||
namespace { | |||
namespace borrow { | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto& op = def.cast_final_safe<Borrow>(); | |||
SmallVector<TensorPtr> outputs; | |||
outputs.reserve(inputs.size()); | |||
for (auto& i : inputs) { | |||
outputs.push_back(Tensor::make( | |||
i->blob()->borrow_to(op.comp_node), i->layout(), i->offset())); | |||
} | |||
return outputs; | |||
} | |||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size()); | |||
} | |||
auto infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto& op = def.cast_final_safe<Borrow>(); | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret(inputs, true); | |||
for (auto& i : std::get<0>(ret)) { | |||
mgb_assert( | |||
i.comp_node.mem_node() == op.comp_node.mem_node(), | |||
"cannot borrow memory from %s to %s", i.comp_node.to_string().c_str(), | |||
op.comp_node.to_string().c_str()); | |||
i.comp_node = op.comp_node; | |||
} | |||
return ret; | |||
} | |||
OP_TRAIT_REG(Borrow, Borrow) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // namespace borrow | |||
} // namespace | |||
namespace { | |||
namespace barrier { | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
auto& op = def.cast_final_safe<imperative::Barrier>(); | |||
SmallVector<TensorPtr> outputs; | |||
for (auto& i : inputs) { | |||
if (i->comp_node() != op.comp_node) { | |||
device_wait_event(op.comp_node, i->comp_node(), i->get_ready_event()); | |||
} | |||
} | |||
if (op.nr_outputs) { | |||
outputs.resize(op.nr_outputs); | |||
TensorLayout layout(TensorShape({0}), dtype::Int32{}); | |||
for (auto& i : outputs) { | |||
i = Tensor::make(layout, op.comp_node); | |||
} | |||
} | |||
return outputs; | |||
} | |||
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size()); | |||
} | |||
auto infer_output_attrs_fallible( | |||
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
auto& op = def.cast_final_safe<imperative::Barrier>(); | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> ret; | |||
auto& [descs, checked] = ret; | |||
descs.resize(op.nr_outputs); | |||
checked = true; | |||
for (auto& desc : descs) { | |||
desc.comp_node = op.comp_node; | |||
desc.layout.dtype = dtype::Int32{}; | |||
desc.layout.ndim = 1; | |||
desc.layout.shape[0] = 0; | |||
desc.layout.stride[0] = 1; | |||
} | |||
return ret; | |||
} | |||
OP_TRAIT_REG(Barrier, imperative::Barrier) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.get_input_layout_constraint(get_input_layout_constraint) | |||
.fallback(); | |||
} // namespace barrier | |||
} // namespace | |||
namespace { | |||
namespace fastpathcopy { | |||
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
return inputs; | |||
@@ -9,57 +9,384 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/imperative/physical_tensor.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/imperative.h" | |||
#include "megbrain/imperative/blob_manager.h" | |||
#include "megbrain/imperative/profiler.h" | |||
#include "megbrain/imperative/resource_manager.h" | |||
#include "./async_releaser.h" | |||
#include "./event_pool.h" | |||
#include "./profiler/events.h" | |||
#include <condition_variable> | |||
#include <cstdint> | |||
#include <deque> | |||
#include <map> | |||
#include <memory> | |||
#include <mutex> | |||
#include <type_traits> | |||
#include <unordered_map> | |||
#include <utility> | |||
#include <variant> | |||
#include <vector> | |||
#ifndef WIN32 | |||
#include <pthread.h> | |||
#endif | |||
#include "range/v3/all.hpp" | |||
namespace views = ranges::views; | |||
namespace mgb { | |||
namespace imperative { | |||
namespace { | |||
class CompNodeSyncManager : public CompNodeDepedentObject { | |||
ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event; | |||
struct CompNodeHash { | |||
auto operator()(CompNode cn) const { return mgb::hash(cn); } | |||
}; | |||
template <typename T> | |||
struct NoThrowMovable : T { | |||
using T::T; | |||
NoThrowMovable(NoThrowMovable&&) noexcept = default; | |||
}; | |||
template <typename... Ts> | |||
using Map = NoThrowMovable<std::map<Ts...>>; | |||
class CompNodeSyncManager { | |||
struct CompNodeData { | |||
template <typename T> | |||
class ReleaseQueue { | |||
Map<uint64_t, T> map; | |||
public: | |||
template <typename A> | |||
void emplace(uint64_t t, A&& a) { | |||
map.emplace_hint(map.end(), t, std::forward<A>(a)); | |||
} | |||
void release(uint64_t t) { | |||
auto it = map.upper_bound(t); | |||
map.erase(map.begin(), it); | |||
} | |||
}; | |||
//! next virtual event | |||
uint64_t next = 1; | |||
//! last completed virtual event | |||
uint64_t completed = 0; | |||
//! virtual event to real event | |||
Map<uint64_t, EventPtr> events; | |||
//! ordering information at some virtual events: | |||
//! what virtual events on other comp nodes is _sequenced before_ this virtual | |||
//! event | |||
Map<uint64_t, std::vector<uint64_t>> ordering; | |||
//! release queue for dev storage, keyed by releaser. this comp node is the | |||
//! **receiver** | |||
std::vector<ReleaseQueue<BlobPtr>> release_queues; | |||
//! release queue for host storage. this comp node is the **releaser** | |||
ReleaseQueue<HostTensorStorage::RawStorage> host_release_queue; | |||
}; | |||
std::mutex m_mtx; | |||
std::condition_variable m_cv; | |||
bool m_should_stop = false; | |||
std::thread m_polling_thread; | |||
std::unordered_map<CompNode, size_t, CompNodeHash> m_cn2id; | |||
std::vector<CompNodeData> m_cndata; | |||
auto do_record(CompNode cn, size_t cnid, std::unique_lock<std::mutex>& lock) { | |||
// CAUSION: don't keep reference across locking boundary | |||
lock.unlock(); | |||
auto e = EventPool::without_timer().alloc(cn); | |||
e->record(); | |||
lock.lock(); | |||
auto& cndata = m_cndata[cnid]; | |||
return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e); | |||
} | |||
std::pair<uint64_t, CompNode::Event*> get_event( | |||
CompNode cn, size_t cnid, uint64_t t, std::unique_lock<std::mutex>& lock) { | |||
auto& cndata = m_cndata[cnid]; | |||
auto it = cndata.events.lower_bound(t); | |||
if (it == cndata.events.end()) { | |||
it = do_record(cn, cnid, lock); | |||
} | |||
return {it->first, it->second.get()}; | |||
} | |||
size_t get_cnid_unsafe(CompNode cn) { | |||
auto [it, unseen] = m_cn2id.try_emplace(cn, m_cndata.size()); | |||
if (unseen) { | |||
m_cndata.emplace_back(); | |||
} | |||
return it->second; | |||
} | |||
void monitor_events() { | |||
#if defined(__APPLE__) | |||
pthread_setname_np("CompNodeSync"); | |||
#elif defined(__unix__) | |||
pthread_setname_np(pthread_self(), "CompNodeSync"); | |||
#endif | |||
// poll events in rounds. sleep for a fixed duration between rounds. | |||
// number of events to query is decided by the number of successful queries in | |||
// last round, independently for each comp node: | |||
// a. all -> double | |||
// b. 0 -> 1 | |||
// c. otherwise -> #successful | |||
struct Item { | |||
size_t cnid; | |||
decltype(CompNodeData::events)::iterator it; | |||
}; | |||
struct Stat { | |||
size_t num_success = 0; | |||
size_t num_attempts = 0; | |||
// iterator to the last finished event | |||
decltype(CompNodeData::events)::iterator it; | |||
}; | |||
std::vector<Stat> stats; | |||
std::vector<Item> todos; | |||
std::unique_lock lock(m_mtx); | |||
for (;;) { | |||
// copy events to a temporary storage so that we may unlock while polling | |||
stats.resize(m_cndata.size()); | |||
for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) { | |||
// decide max number of events to query | |||
// rule c: #successful | |||
size_t n = stats[cnid].num_success; | |||
if (n == stats[cnid].num_attempts) { | |||
// rule a: double | |||
n *= 2; | |||
} | |||
if (n == 0) { | |||
// rule b: 1 | |||
n = 1; | |||
} | |||
// now copy upto n events | |||
auto& events = m_cndata[cnid].events; | |||
size_t i = 0; | |||
for (auto it = events.begin(); i < n && it != events.end(); ++i, ++it) { | |||
todos.push_back({cnid, it}); | |||
} | |||
// reset stats for this round | |||
stats[cnid].num_success = 0; | |||
stats[cnid].num_attempts = n; | |||
} | |||
lock.unlock(); | |||
bool last_result = false; | |||
size_t last_cnid = -1; | |||
for (auto item : todos) { | |||
if (item.cnid == last_cnid && !last_result) { | |||
// previous failed, this one almost certainly should fail | |||
continue; | |||
} | |||
last_cnid = item.cnid; | |||
last_result = item.it->second->finished(); | |||
if (last_result) { | |||
stats[item.cnid].num_success++; | |||
stats[item.cnid].it = item.it; | |||
} | |||
} | |||
todos.clear(); | |||
lock.lock(); | |||
// release dev storage | |||
for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size(); | |||
++receiver_cnid) { | |||
for (size_t releaser_cnid = 0; | |||
releaser_cnid < m_cndata[receiver_cnid].release_queues.size(); | |||
++releaser_cnid) { | |||
if (releaser_cnid >= stats.size() || | |||
stats[releaser_cnid].num_success == 0) { | |||
continue; | |||
} | |||
auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid]; | |||
q.release(stats[releaser_cnid].it->first); | |||
} | |||
} | |||
for (size_t cnid = 0; cnid < stats.size(); ++cnid) { | |||
if (stats[cnid].num_success == 0) { | |||
continue; | |||
} | |||
auto& cndata = m_cndata[cnid]; | |||
auto it = stats[cnid].it; | |||
auto t = it->first; | |||
// update completed | |||
cndata.completed = t; | |||
// release host storage | |||
cndata.host_release_queue.release(t); | |||
// remove completed events | |||
auto& events = cndata.events; | |||
events.erase(events.begin(), std::next(it)); | |||
} | |||
using namespace std::literals; | |||
if (m_cv.wait_for(lock, 10us, [&] { return m_should_stop; })) { | |||
return; | |||
} | |||
} | |||
} | |||
CompNodeSyncManager() { | |||
m_polling_thread = std::thread([this] { monitor_events(); }); | |||
} | |||
public: | |||
std::shared_ptr<void> on_comp_node_finalize() override { | |||
MGB_LOCK_GUARD(m_mtx); | |||
m_blob2event.clear(); | |||
return {}; | |||
~CompNodeSyncManager() { | |||
{ | |||
MGB_LOCK_GUARD(m_mtx); | |||
m_should_stop = true; | |||
m_cv.notify_all(); | |||
} | |||
m_polling_thread.join(); | |||
} | |||
static CompNodeSyncManager& inst() { | |||
static auto* sl_inst = ResourceManager::create_global<CompNodeSyncManager>(); | |||
return *sl_inst; | |||
static CompNodeSyncManager& inst(); | |||
uint64_t record(CompNode cn, bool doitnow = false) { | |||
std::unique_lock lock(m_mtx); | |||
auto cnid = get_cnid_unsafe(cn); | |||
if (doitnow) { | |||
return do_record(cn, cnid, lock)->first; | |||
} | |||
return m_cndata[cnid].next++; | |||
} | |||
CompNode::Event* get_or_create_event(Blob* blob) { | |||
mgb_assert(!is_finalized()); | |||
void async_release(CompNode cn, uint64_t t, BlobPtr blob) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto&& e = m_blob2event[blob]; | |||
if (!e) { | |||
e = blob->comp_node().create_event(); | |||
auto releaser_cnid = get_cnid_unsafe(cn); | |||
if (t <= m_cndata[releaser_cnid].completed) { | |||
return; | |||
} | |||
auto receiver_cnid = get_cnid_unsafe(blob->comp_node()); | |||
auto& qs = m_cndata[receiver_cnid].release_queues; | |||
if (releaser_cnid >= qs.size()) { | |||
qs.resize(releaser_cnid + 1); | |||
} | |||
return e.get(); | |||
auto& q = qs[releaser_cnid]; | |||
q.emplace(t, std::move(blob)); | |||
} | |||
void remove(Blob* blob) { | |||
void async_release(CompNode cn, uint64_t t, HostTensorStorage::RawStorage storage) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
m_blob2event.erase(blob); | |||
auto releaser_cnid = get_cnid_unsafe(cn); | |||
if (t <= m_cndata[releaser_cnid].completed) { | |||
return; | |||
} | |||
auto& q = m_cndata[releaser_cnid].host_release_queue; | |||
q.emplace(t, std::move(storage)); | |||
} | |||
void device_wait(CompNode waiter, CompNode waitee, uint64_t t) { | |||
std::unique_lock lock(m_mtx); | |||
auto waiter_id = get_cnid_unsafe(waiter); | |||
auto waitee_id = get_cnid_unsafe(waitee); | |||
auto& waiter_data = m_cndata.at(waiter_id); | |||
auto& waitee_data = m_cndata.at(waitee_id); | |||
auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock); | |||
// DO NOT unlock around this line! Event* could be invalidated! | |||
e->device_wait_by(waiter); | |||
auto t_waiter = waiter_data.next++; | |||
std::vector<uint64_t> ordering(m_cndata.size(), 0); | |||
if (!waiter_data.ordering.empty()) { | |||
auto& o = waiter_data.ordering.rbegin()->second; | |||
std::copy(o.begin(), o.end(), ordering.begin()); | |||
} | |||
ordering[waitee_id] = t_waitee; | |||
ordering[waiter_id] = t_waiter; | |||
{ | |||
auto it = waitee_data.ordering.lower_bound(t_waitee); | |||
if (it != waitee_data.ordering.begin()) { | |||
for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) { | |||
static_assert(std::is_lvalue_reference_v<decltype(a)>); | |||
a = std::max(a, b); | |||
} | |||
} | |||
} | |||
waiter_data.ordering.emplace_hint( | |||
waiter_data.ordering.end(), t_waiter, ordering); | |||
for (auto [t, q] : views::zip(ordering, waiter_data.release_queues)) { | |||
q.release(t); | |||
} | |||
} | |||
}; | |||
CompNodeSyncManager& CompNodeSyncManager::inst() { | |||
static std::mutex mtx; | |||
static std::unique_ptr<CompNodeSyncManager> inst; | |||
struct Guard final : CompNodeDepedentObject { | |||
std::shared_ptr<void> on_comp_node_finalize() override { | |||
MGB_LOCK_GUARD(mtx); | |||
inst.reset(); | |||
return {}; | |||
} | |||
}; | |||
static std::optional<Guard> guard; | |||
#ifndef WIN32 | |||
static bool broken = false; | |||
static struct ForkGuard { | |||
ForkGuard() { | |||
mgb_assert(0 == pthread_atfork(NULL, NULL, [] { | |||
if (inst) { | |||
inst.release(); // deliberate leak, unfixable | |||
broken = true; | |||
} | |||
})); | |||
} | |||
} fork_guard; | |||
#endif | |||
MGB_LOCK_GUARD(mtx); | |||
if (!inst) { | |||
#ifndef WIN32 | |||
mgb_assert(!broken); | |||
#endif | |||
EventPool::without_timer(); | |||
inst.reset(new CompNodeSyncManager); | |||
guard.emplace(); | |||
} | |||
return *inst; | |||
} | |||
} // namespace | |||
uint64_t record_event(CompNode cn, bool doitnow) { | |||
return CompNodeSyncManager::inst().record(cn, doitnow); | |||
} | |||
void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event) { | |||
CompNodeSyncManager::inst().device_wait(waiter, waitee, event); | |||
} | |||
void async_release(CompNode cn, uint64_t event, BlobPtr blob) { | |||
CompNodeSyncManager::inst().async_release(cn, event, std::move(blob)); | |||
} | |||
void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage) { | |||
CompNodeSyncManager::inst().async_release(cn, event, std::move(storage)); | |||
} | |||
void EventDeleter::operator()(CompNode::Event* event) { | |||
EventPool::without_timer().free(event); | |||
} | |||
@@ -68,31 +395,74 @@ namespace { | |||
std::atomic_uint64_t next_blob_id = 0; | |||
} | |||
Blob::Blob(const DeviceTensorStorage& s) | |||
: m_comp_node{s.comp_node()}, | |||
OwnedBlob::OwnedBlob(const DeviceTensorStorage& s) | |||
: Blob(s.comp_node(), s.size() + s.offset()), | |||
m_storage{s.raw_storage()}, | |||
m_size{s.size() + s.offset()} { | |||
m_id = next_blob_id++; | |||
m_id{next_blob_id++} { | |||
BlobManager::inst()->register_blob(this); | |||
} | |||
Blob::Blob(CompNode cn, size_t sz) : m_comp_node{cn}, m_storage{}, m_size{sz} { | |||
m_id = next_blob_id++; | |||
OwnedBlob::OwnedBlob(CompNode cn, size_t sz) | |||
: Blob(cn, sz), m_storage{}, m_id{next_blob_id++} { | |||
BlobManager::inst()->register_blob(this); | |||
} | |||
Blob::~Blob() { | |||
OwnedBlob::~OwnedBlob() { | |||
BlobManager::inst()->unregister_blob(this); | |||
CompNodeSyncManager::inst().remove(this); | |||
} | |||
const Blob::RawStorage& Blob::storage() { | |||
const Blob::RawStorage& OwnedBlob::storage() { | |||
if (!m_storage && m_size) { | |||
BlobManager::inst()->alloc_with_defrag(this, m_size); | |||
} | |||
return m_storage; | |||
} | |||
BlobPtr OwnedBlob::borrow_to(CompNode cn) { | |||
return std::make_shared<BorrowedBlob>( | |||
cn, std::static_pointer_cast<OwnedBlob>(shared_from_this())); | |||
} | |||
bool OwnedBlob::storage_is_unique() { | |||
return m_storage.unique(); | |||
} | |||
void* OwnedBlob::raw_ptr_not_for_readwrite() { | |||
return m_storage.get(); | |||
} | |||
BorrowedBlob::BorrowedBlob(CompNode cn, std::shared_ptr<OwnedBlob> owner) | |||
: Blob(cn, owner->size()), | |||
m_owner(std::move(owner)), | |||
m_event(record_event(m_owner->comp_node(), true)) {} | |||
BorrowedBlob::~BorrowedBlob() { | |||
async_release(m_comp_node, record_event(m_comp_node, true), std::move(m_owner)); | |||
} | |||
const Blob::RawStorage& BorrowedBlob::storage() { | |||
{ | |||
MGB_LOCK_GUARD(m_mtx); | |||
if (!m_initialized) { | |||
device_wait_event(m_comp_node, m_owner->comp_node(), m_event); | |||
m_initialized = true; | |||
} | |||
} | |||
return m_owner->storage(); | |||
} | |||
BlobPtr BorrowedBlob::borrow_to(CompNode cn) { | |||
return std::make_shared<BorrowedBlob>(cn, m_owner); | |||
} | |||
bool BorrowedBlob::storage_is_unique() { | |||
return m_owner.unique() && m_owner->storage_is_unique(); | |||
} | |||
void* BorrowedBlob::raw_ptr_not_for_readwrite() { | |||
return m_owner->raw_ptr_not_for_readwrite(); | |||
} | |||
Tensor::Tensor( | |||
BlobPtr blob, const TensorLayout& layout, size_t offset, const HostTensorND& hv) | |||
: m_cn(blob->comp_node()), | |||
@@ -119,7 +489,7 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { | |||
MGB_RECORD_EVENT( | |||
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | |||
hv.raw_ptr(), dev_tensor().raw_ptr()); | |||
AsyncReleaser::inst()->add(hv); | |||
async_release(hv); | |||
} | |||
} | |||
@@ -174,7 +544,7 @@ void Tensor::to_contiguous_inplace(VarNode::LayoutConstraintCallback& layout_che | |||
DeviceTensorND dv_contig; | |||
dv_contig.copy_from(dv); | |||
m_layout = dv_contig.layout(); | |||
std::atomic_store(&m_blob, Blob::make(dv_contig.storage())); | |||
std::atomic_store(&m_blob, BlobPtr(Blob::make(dv_contig.storage()))); | |||
mgb_assert(m_layout.is_contiguous()); | |||
m_offset = 0; | |||
} | |||
@@ -188,7 +558,7 @@ void Tensor::to_contiguous_inplace() { | |||
void Tensor::assign_from_dev_tensor(DeviceTensorND dv) { | |||
MGB_LOCK_GUARD(m_blob_mtx); | |||
std::atomic_store(&m_blob, Blob::make(dv.storage())); | |||
std::atomic_store(&m_blob, BlobPtr(Blob::make(dv.storage()))); | |||
m_offset = dv.storage().offset(); | |||
m_layout = dv.layout(); | |||
} | |||
@@ -254,21 +624,20 @@ TensorPtr Tensor::sub(size_t offset, TensorShape shape) { | |||
return Tensor::make(m_blob, offset + m_offset, layout); | |||
} | |||
void Tensor::add_release_callback(CompNode cn) { | |||
AsyncReleaser::inst()->add(m_blob, cn); | |||
uint64_t Tensor::get_ready_event() { | |||
if (m_produced_at == 0) { | |||
m_produced_at = record_event(comp_node()); | |||
} | |||
return m_produced_at; | |||
} | |||
CompNode::Event* Tensor::get_or_create_event() { | |||
auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get()); | |||
e->record(); | |||
return e; | |||
bool Tensor::storage_is_unique() { | |||
return m_blob.unique() && m_blob->storage_is_unique(); | |||
} | |||
void Tensor::static_initialize() { | |||
EventPool::with_timer(); | |||
EventPool::without_timer(); | |||
AsyncReleaser::inst(); | |||
CompNodeSyncManager::inst(); | |||
MultiCNConstTensorCache::inst(); | |||
} | |||
@@ -836,8 +836,8 @@ public: | |||
if (used_cns.insert(cn).second) { | |||
for (auto&& in : inputs) { | |||
if (in->comp_node() != cn) { | |||
auto&& e = in->get_or_create_event(); | |||
e->device_wait_by(cn); | |||
auto e = in->get_ready_event(); | |||
device_wait_event(cn, in->comp_node(), e); | |||
} | |||
} | |||
} | |||
@@ -847,11 +847,17 @@ public: | |||
// so we need create inference session here | |||
minigraph.execute(raw_inputs, raw_outputs, m_env); | |||
for (auto&& cn : used_cns) { | |||
bool should_record = false; | |||
for (auto&& in : inputs) { | |||
if (in->comp_node() != cn) { | |||
in->add_release_callback(cn); | |||
should_record = true; | |||
auto e = record_event(cn); | |||
async_release(cn, e, *in); | |||
} | |||
} | |||
if (should_record) { | |||
record_event(cn, true); | |||
} | |||
} | |||
return outputs; | |||
@@ -15,7 +15,6 @@ | |||
#include "megbrain/opr/utility.h" | |||
#include "../async_releaser.h" | |||
#include "../mgb_cg_impl.h" | |||
namespace mgb { | |||
@@ -73,7 +72,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||
args.device.emplace(); | |||
args.device->copy_from(*args.host); | |||
// every h2d in imperative runtime should notify AsyncReleaser | |||
AsyncReleaser::inst()->add(*args.host); | |||
async_release(*args.host); | |||
} | |||
return *args.device; | |||
}; | |||
@@ -155,7 +154,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||
host_value.copy_from(inferred_value); | |||
DeviceTensorND dev_value; | |||
dev_value.copy_from(host_value); | |||
AsyncReleaser::inst()->add(host_value); | |||
async_release(host_value); | |||
return {DeviceValue::make(dev_value)}; | |||
} | |||
default: | |||
@@ -24,18 +24,18 @@ public: | |||
static BlobManager* inst(); | |||
virtual void alloc_direct(Blob* blob, size_t size) = 0; | |||
virtual void alloc_direct(OwnedBlob* blob, size_t size) = 0; | |||
virtual void alloc_with_defrag(Blob* blob, size_t size) = 0; | |||
virtual void alloc_with_defrag(OwnedBlob* blob, size_t size) = 0; | |||
virtual void set_allocator(allocator_t allocator) = 0; | |||
virtual DeviceTensorND alloc_workspace_with_defrag( | |||
CompNode cn, TensorLayout& layout) = 0; | |||
virtual void register_blob(Blob* blob) = 0; | |||
virtual void register_blob(OwnedBlob* blob) = 0; | |||
virtual void unregister_blob(Blob* blob) = 0; | |||
virtual void unregister_blob(OwnedBlob* blob) = 0; | |||
virtual void defrag(const CompNode& cn) = 0; | |||
}; | |||
@@ -11,12 +11,16 @@ | |||
#pragma once | |||
#include <cstdint> | |||
#include <memory> | |||
#include <mutex> | |||
#include <type_traits> | |||
#include <variant> | |||
#include "megbrain/graph.h" | |||
#include "megbrain/imperative/resource_manager.h" | |||
#include "megbrain/tensor.h" | |||
#include "megbrain/utils/metahelper.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -26,35 +30,67 @@ class Blob; | |||
using BlobPtr = std::shared_ptr<Blob>; | |||
class BlobManagerImpl; | |||
class OwnedBlob; | |||
class Blob : public NonCopyableObj, public std::enable_shared_from_this<Blob> { | |||
protected: | |||
CompNode m_comp_node; | |||
size_t m_size; | |||
Blob(CompNode cn, size_t size) : m_comp_node(cn), m_size(size) {} | |||
class Blob : public NonCopyableObj { | |||
public: | |||
Blob(const DeviceTensorStorage& s); | |||
Blob(CompNode cn, size_t sz); | |||
~Blob(); | |||
virtual ~Blob() = default; | |||
template <typename... Args> | |||
static BlobPtr make(Args&&... args) { | |||
return std::make_shared<Blob>(std::forward<Args>(args)...); | |||
static std::shared_ptr<OwnedBlob> make(Args&&... args) { | |||
return std::make_shared<OwnedBlob>(std::forward<Args>(args)...); | |||
} | |||
const CompNode& comp_node() const { return m_comp_node; } | |||
size_t size() const { return m_size; } | |||
using RawStorage = DeviceTensorStorage::RawStorage; | |||
const RawStorage& storage(); | |||
virtual const RawStorage& storage() = 0; | |||
virtual BlobPtr borrow_to(CompNode) = 0; | |||
virtual bool storage_is_unique() = 0; | |||
virtual void* raw_ptr_not_for_readwrite() = 0; | |||
}; | |||
const CompNode& comp_node() const { return m_comp_node; } | |||
class OwnedBlob final : public Blob { | |||
friend class Blob; | |||
size_t size() const { return m_size; } | |||
public: | |||
OwnedBlob(const DeviceTensorStorage& s); | |||
OwnedBlob(CompNode cn, size_t sz); | |||
~OwnedBlob() override; | |||
size_t id() const { return m_id; } | |||
const RawStorage& storage() override; | |||
BlobPtr borrow_to(CompNode) override; | |||
bool storage_is_unique() override; | |||
void* raw_ptr_not_for_readwrite() override; | |||
private: | |||
friend class BlobManagerImpl; | |||
CompNode m_comp_node; | |||
mutable RawStorage m_storage; | |||
size_t m_size = 0; | |||
RawStorage m_storage; | |||
size_t m_id; | |||
}; | |||
class BorrowedBlob final : public Blob { | |||
std::mutex m_mtx; | |||
std::shared_ptr<OwnedBlob> m_owner; | |||
uint64_t m_event; | |||
bool m_initialized = false; | |||
public: | |||
BorrowedBlob(CompNode, std::shared_ptr<OwnedBlob>); | |||
~BorrowedBlob() override; | |||
const RawStorage& storage() override; | |||
BlobPtr borrow_to(CompNode) override; | |||
bool storage_is_unique() override; | |||
void* raw_ptr_not_for_readwrite() override; | |||
}; | |||
struct EventDeleter { | |||
void operator()(CompNode::Event*); | |||
}; | |||
@@ -121,6 +157,8 @@ public: | |||
BlobPtr& blob() { return m_blob; } | |||
void* raw_ptr_not_for_readwrite() { return m_blob->raw_ptr_not_for_readwrite(); } | |||
void fetch_value(); | |||
bool value_fetched(); | |||
TensorPtr sub(size_t offset, TensorShape shape); | |||
@@ -131,8 +169,10 @@ public: | |||
// return a pointer instead of a reference to ensure thread safety | |||
const HostTensorND* try_get_value(); | |||
void add_release_callback(CompNode cn); | |||
CompNode::Event* get_or_create_event(); | |||
void set_ready_event(uint64_t event) { m_produced_at = event; } | |||
uint64_t get_ready_event(); | |||
bool storage_is_unique(); | |||
// Make sure all static objects required to destruct a tensor has completed | |||
// construction. All static storage duration object that holds tensors must | |||
@@ -152,8 +192,37 @@ private: | |||
std::mutex m_value_mtx; | |||
HostTensorND m_value; | |||
EventPtr m_value_ready = nullptr; | |||
uint64_t m_produced_at = 0; | |||
}; | |||
/*! | |||
* \brief record a virtual event | |||
* \param doitnow also record a real event | |||
*/ | |||
uint64_t record_event(CompNode cn, bool doitnow = false); | |||
//! make a device wait on a virtual event | |||
void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event); | |||
//! hold a blob until a virtual event on a device is completed | |||
void async_release(CompNode cn, uint64_t event, BlobPtr blob); | |||
//! hold a host tensor until a virtual event on a device is completed | |||
void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage); | |||
inline void async_release(CompNode cn, uint64_t event, Tensor& tensor) { | |||
async_release(cn, event, tensor.blob()); | |||
} | |||
inline void async_release(CompNode cn, uint64_t event, const HostTensorND& hnd) { | |||
async_release(cn, event, hnd.storage().raw_storage()); | |||
} | |||
inline void async_release(const HostTensorND& hnd) { | |||
auto cn = hnd.comp_node(); | |||
async_release(cn, record_event(cn, true), hnd); | |||
} | |||
// Cache for small blobs | |||
// 1. A blob has to be seen twice (within a window) to be eligible for cache | |||
// 2. Cache eviction occurs when cache size reaches a threshold, in least frequently | |||
@@ -32,7 +32,7 @@ bool CompNodeImplHelper::EventImplHelper::finished() { | |||
mgb_assert(m_recorded); | |||
if (do_finished()) { | |||
m_finished = true; | |||
m_recorded = false; | |||
// m_recorded = false; | |||
return true; | |||
} | |||
return false; | |||
@@ -120,6 +120,19 @@ def Copy: MgbHashableOp<"Copy"> { | |||
); | |||
} | |||
def Borrow: MgbHashableOp<"Borrow"> { | |||
let extraArguments = (ins | |||
MgbCompNodeAttr:$comp_node | |||
); | |||
} | |||
def Barrier: MgbHashableOp<"Barrier"> { | |||
let extraArguments = (ins | |||
MgbCompNodeAttr:$comp_node, | |||
MgbUI32Attr:$nr_outputs | |||
); | |||
} | |||
def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>; | |||
def Argmax : MgbHashableOp<"Argmax", [AxisParam]>; | |||