From 877bda4180943728c552b24a6e024a193633c7a0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 14 Feb 2022 14:06:31 +0800 Subject: [PATCH] perf(mge): improve cross stream memory borrowing GitOrigin-RevId: c68977c5dce157bfc624f887f085d7a7aa6c54ac --- imperative/python/megengine/_multistream.py | 21 + imperative/python/megengine/device.py | 6 +- imperative/python/megengine/tensor.py | 11 +- imperative/python/src/module_trace.h | 1 + imperative/python/src/tensor.cpp | 3 + imperative/python/test/unit/core/test_stream.py | 74 ++++ imperative/src/impl/async_releaser.h | 84 ---- imperative/src/impl/blob_manager_impl.cpp | 20 +- imperative/src/impl/blob_manager_impl.h | 18 +- .../src/impl/interpreter/interpreter_impl.cpp | 11 +- imperative/src/impl/interpreter/interpreter_impl.h | 2 +- imperative/src/impl/ops/elemwise.cpp | 2 +- imperative/src/impl/ops/tensor_manip.cpp | 4 +- imperative/src/impl/ops/utility.cpp | 98 +++++ imperative/src/impl/physical_tensor.cpp | 445 +++++++++++++++++++-- imperative/src/impl/proxy_graph/mini_graph.h | 12 +- imperative/src/impl/transformations/lazy.cpp | 5 +- .../src/include/megbrain/imperative/blob_manager.h | 8 +- .../include/megbrain/imperative/physical_tensor.h | 99 ++++- src/core/impl/comp_node/impl_helper.cpp | 2 +- src/core/include/megbrain/ir/ops.td | 13 + 21 files changed, 756 insertions(+), 183 deletions(-) create mode 100644 imperative/python/megengine/_multistream.py create mode 100644 imperative/python/test/unit/core/test_stream.py delete mode 100644 imperative/src/impl/async_releaser.h diff --git a/imperative/python/megengine/_multistream.py b/imperative/python/megengine/_multistream.py new file mode 100644 index 00000000..0004b3c4 --- /dev/null +++ b/imperative/python/megengine/_multistream.py @@ -0,0 +1,21 @@ +from .core._imperative_rt.core2 import apply +from .core.ops.builtin import Barrier +from .tensor import Tensor + +_dummy_tensors = {} + + +def _get_dummy_tensor(device): + if device not in _dummy_tensors: + _dummy_tensors[device] = Tensor([], device=device) + return _dummy_tensors[device] + + +def record_event(device): + x = _get_dummy_tensor(device) + (x,) = apply(Barrier(device, 1), x) + return x + + +def wait_event(device, event): + apply(Barrier(device, 0), event) diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py index d9a8f2a7..c6497b89 100644 --- a/imperative/python/megengine/device.py +++ b/imperative/python/megengine/device.py @@ -51,7 +51,7 @@ _sh = _stream_helper() def _valid_device(inp): if isinstance(inp, str) and re.match( - "^([cxg]pu|rocm|multithread)(\d+|\d+:\d+|x)$", inp + "^([cxg]pu|rocm|multithread)(x|\d+)(:\d+)?$", inp ): return True return False @@ -255,8 +255,8 @@ def what_is_xpu(): def coalesce_free_memory(): r"""This function will try it best to free all consecutive free chunks back to operating system, small pieces may not be returned. - - because of the async processing of megengine, the effect of this func may not be reflected + + because of the async processing of megengine, the effect of this func may not be reflected immediately. 
if you want to see the effect immediately, you can call megengine._full_sync after this func was called diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 33daff40..40432f5a 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -15,7 +15,7 @@ from .core._imperative_rt.core2 import Tensor as _Tensor from .core._imperative_rt.core2 import apply, set_py_tensor_type from .core._trace_option import use_symbolic_shape from .core._wrap import as_device -from .core.ops.builtin import Copy, GetVarShape +from .core.ops.builtin import Borrow, Copy, GetVarShape from .core.tensor.array_method import ArrayMethodMixin from .device import _valid_device, get_default_device from .logger import get_logger @@ -205,7 +205,7 @@ class Tensor(_Tensor, ArrayMethodMixin): def reset_zero(self): self *= 0 - def to(self, device): + def to(self, device, *, _borrow=False): r"""Copy self :class:`~.Tensor` to specified device. See :func:`~.copy`""" if isinstance(device, str) and not _valid_device(device): raise ValueError( @@ -214,7 +214,8 @@ class Tensor(_Tensor, ArrayMethodMixin): ) ) cn = as_device(device).to_c() - return apply(Copy(comp_node=cn), self)[0] + op = Borrow(comp_node=cn) if _borrow else Copy(comp_node=cn) + return apply(op, self)[0] @property def requires_grad(self): @@ -232,11 +233,11 @@ class Tensor(_Tensor, ArrayMethodMixin): return id(self) def __getnewargs__(self): - r""" __getnewargs__ will be called for pickle serialization or deep copy""" + r"""__getnewargs__ will be called for pickle serialization or deep copy""" return (self.numpy(), self.dtype, self.device.logical_name) def __getstate__(self): - r""" __getstate__ will be called for pickle serialization or deep copy""" + r"""__getstate__ will be called for pickle serialization or deep copy""" state = {} if self._qparams is not None: state["qparams"] = self._qparams diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index 7ed0be90..8a9f9ce8 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -11,6 +11,7 @@ #pragma once +#include #include "megbrain/imperative/transformations/trace.h" #include "megbrain/imperative/utils/map.h" #include "megbrain/imperative/utils/stats.h" diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index e61772c4..a1c6aaa7 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -998,6 +998,9 @@ void init_tensor(py::module m) { module_trace_hook.inc_ref(); }); + auto atexit = py::module::import("atexit"); + atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; })); + m.def("begin_record_values", [] { Value::begin_record_values(); }); m.def("end_record_values", [] { diff --git a/imperative/python/test/unit/core/test_stream.py b/imperative/python/test/unit/core/test_stream.py new file mode 100644 index 00000000..ccf885cf --- /dev/null +++ b/imperative/python/test/unit/core/test_stream.py @@ -0,0 +1,74 @@ +import gc + +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +from megengine._multistream import record_event, wait_event + + +class MemStat: + def __init__(self, *args): + for d in args: + mge.Tensor([], device=d) + gc.collect() + mge._full_sync() + self.baseline = {d: mge.device.get_allocated_memory(d) for d in args} + for d in args: + mge.device.reset_max_memory_stats(d) + + def get_max(self, device): + return 
mge.device.get_max_allocated_memory(device) - self.baseline[device] + + +@pytest.mark.require_ngpu(1) +def test_mem_stats(): + memstat = MemStat("xpux:0", "xpux:1") + + F.arange(1024, device="xpux:0") + + mge._full_sync() + assert 4096 <= memstat.get_max("xpux:0") == memstat.get_max("xpux:1") <= 4096 + 128 + + +@pytest.mark.require_ngpu(1) +def test_borrow(): + memstat = MemStat("xpux:0", "xpux:1") + + x_np = np.random.randint(2 ** 30, size=(1 * 1024 * 1024,), dtype="int32") + unit = x_np.size * 4 + x0 = mge.Tensor(x_np, device="xpux:0") + x1 = x0.to("xpux:1", _borrow=True) + y = -x1 + np.testing.assert_equal(-x_np, y.numpy()) + + mge._full_sync() + assert memstat.get_max("xpux:0") / unit < 2.1 + + +@pytest.mark.require_ngpu(1) +def test_stream_mem(): + memstat = MemStat("xpux:0", "xpux:1") + + x_np = np.random.randint(2 ** 10, size=(1 * 1024 * 1024,), dtype="int32") + unit = x_np.size * 4 + x0 = mge.Tensor(x_np, device="xpux:0") + + results = [] + events = [] + for i in range(100): + if len(events) >= 2: + wait_event("xpux:0", events[-2]) + x0 = x0 + 1 + results.append(x0.to("xpux:1", _borrow=True).sum()) + events.append(record_event("xpux:1")) + del events[:-2] + + y_np = x_np.sum() + for i, y in enumerate(results): + y_np += x_np.size + assert y_np == y.numpy() + + mge._full_sync() + assert memstat.get_max("xpux:0") / unit < 2.1 diff --git a/imperative/src/impl/async_releaser.h b/imperative/src/impl/async_releaser.h deleted file mode 100644 index 6cf0d6dc..00000000 --- a/imperative/src/impl/async_releaser.h +++ /dev/null @@ -1,84 +0,0 @@ -/** - * \file imperative/src/impl/async_releaser.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- */ - -#pragma once - -#include "megbrain/comp_node.h" -#include "megbrain/imperative/blob_manager.h" -#include "megbrain/imperative/resource_manager.h" -#include "megbrain/system.h" - -#include "./event_pool.h" - -namespace mgb { -namespace imperative { - -class AsyncReleaser : public CompNodeDepedentObject { - struct WaiterParam { - CompNode cn; - CompNode::Event* event; - BlobPtr blob; - HostTensorStorage::RawStorage storage; - }; - class Waiter final : public AsyncQueueSC { - AsyncReleaser* m_par_releaser; - - public: - // disable busy wait by set max_spin=0 to save CPU cycle - Waiter(AsyncReleaser* releaser) - : AsyncQueueSC(0), m_par_releaser(releaser) {} - - void process_one_task(WaiterParam& param) { - if (param.event->finished()) { - param.blob.reset(); - param.storage.reset(); - EventPool::without_timer().free(param.event); - return; - } - - using namespace std::literals; - std::this_thread::sleep_for(1us); - add_task(std::move(param)); - } - void on_async_queue_worker_thread_start() override { - sys::set_thread_name("releaser"); - } - }; - Waiter m_waiter{this}; - -protected: - std::shared_ptr on_comp_node_finalize() override { - m_waiter.wait_task_queue_empty(); - return {}; - } - -public: - static AsyncReleaser* inst() { - static auto* releaser = ResourceManager::create_global(); - return releaser; - } - - ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } - - void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); } - - void add(const HostTensorND& hv) { - add(hv.comp_node(), {}, hv.storage().raw_storage()); - } - - void add(CompNode cn, BlobPtr blob, HostTensorStorage::RawStorage storage = {}) { - auto event = EventPool::without_timer().alloc(cn); - event->record(); - m_waiter.add_task({cn, event, std::move(blob), std::move(storage)}); - } -}; -} // namespace imperative -} // namespace mgb diff --git a/imperative/src/impl/blob_manager_impl.cpp b/imperative/src/impl/blob_manager_impl.cpp index 4d1d1b28..130885a8 100644 --- a/imperative/src/impl/blob_manager_impl.cpp +++ b/imperative/src/impl/blob_manager_impl.cpp @@ -16,7 +16,7 @@ namespace mgb { namespace imperative { -BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { +BlobManagerImpl::BlobData::BlobData(OwnedBlob* in_blob) { blob = in_blob; DeviceTensorStorage d_storage; d_storage.reset(blob->m_comp_node, blob->m_size, blob->m_storage); @@ -28,19 +28,19 @@ BlobManagerImpl::BlobData::BlobData(Blob* in_blob) { h_storage.copy_from(const_cast(d_storage), blob->m_size); } -void BlobManagerImpl::register_blob(Blob* blob) { +void BlobManagerImpl::register_blob(OwnedBlob* blob) { // add blob into the comp2blobs map MGB_LOCK_GUARD(m_mtx); mgb_assert(m_comp2blobs_map[blob->m_comp_node].insert(blob)); } -void BlobManagerImpl::unregister_blob(Blob* blob) { +void BlobManagerImpl::unregister_blob(OwnedBlob* blob) { // erase blob into the comp2blobs map MGB_LOCK_GUARD(m_mtx); mgb_assert(1 == m_comp2blobs_map[blob->m_comp_node].erase(blob)); } -void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { +void BlobManagerImpl::alloc_with_defrag(OwnedBlob* blob, size_t size) { if (custom_allocator) { blob->m_storage = custom_allocator(blob->m_comp_node, size); return; @@ -55,7 +55,7 @@ void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) { }); } -void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) { +void BlobManagerImpl::alloc_direct(OwnedBlob* blob, size_t size) { DeviceTensorStorage storage(blob->m_comp_node); mgb_assert(blob->m_comp_node.valid()); storage.ensure_size(size); @@ -143,7 +143,7 
@@ void BlobManagerImpl::defrag(const CompNode& cn) { // sort blobs by created time, may be helpful for reduce memory fragment std::sort( blob_data_arrary.begin(), blob_data_arrary.end(), - [](auto& lhs, auto& rhs) { return lhs.blob->id() < rhs.blob->id(); }); + [](auto& lhs, auto& rhs) { return lhs.blob->m_id < rhs.blob->m_id; }); // allocate for each storage for (auto i : blob_data_arrary) { @@ -158,19 +158,19 @@ void BlobManagerImpl::defrag(const CompNode& cn) { } struct BlobManagerStub : BlobManager { - void alloc_direct(Blob* blob, size_t size) { + void alloc_direct(OwnedBlob* blob, size_t size) { mgb_assert(0, "prohibited after global variable destruction"); }; - void alloc_with_defrag(Blob* blob, size_t size) { + void alloc_with_defrag(OwnedBlob* blob, size_t size) { mgb_assert(0, "prohibited after global variable destruction"); }; DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout& layout) { mgb_assert(0, "prohibited after global variable destruction"); }; - void register_blob(Blob* blob) { + void register_blob(OwnedBlob* blob) { mgb_assert(0, "prohibited after global variable destruction"); }; - void unregister_blob(Blob* blob){}; + void unregister_blob(OwnedBlob* blob){}; void defrag(const CompNode& cn) { mgb_assert(0, "prohibited after global variable destruction"); }; diff --git a/imperative/src/impl/blob_manager_impl.h b/imperative/src/impl/blob_manager_impl.h index 3b2a9abe..ad75c9cd 100644 --- a/imperative/src/impl/blob_manager_impl.h +++ b/imperative/src/impl/blob_manager_impl.h @@ -19,21 +19,21 @@ namespace imperative { class BlobManagerImpl final : public BlobManager { struct BlobSetWithMux { std::mutex mtx; - ThinHashSet blobs_set; - bool insert(Blob* blob) { + ThinHashSet blobs_set; + bool insert(OwnedBlob* blob) { MGB_LOCK_GUARD(mtx); return blobs_set.insert(blob).second; } - size_t erase(Blob* blob) { + size_t erase(OwnedBlob* blob) { MGB_LOCK_GUARD(mtx); return blobs_set.erase(blob); } }; struct BlobData { - Blob* blob; + OwnedBlob* blob; HostTensorStorage h_storage; - BlobData(Blob* in_blob); + BlobData(OwnedBlob* in_blob); }; std::mutex m_mtx; @@ -41,7 +41,7 @@ class BlobManagerImpl final : public BlobManager { void defrag(const CompNode& cn) override; - void alloc_direct(Blob* blob, size_t size) override; + void alloc_direct(OwnedBlob* blob, size_t size) override; DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout); @@ -50,14 +50,14 @@ class BlobManagerImpl final : public BlobManager { public: static BlobManager* inst(); - void alloc_with_defrag(Blob* blob, size_t size) override; + void alloc_with_defrag(OwnedBlob* blob, size_t size) override; DeviceTensorND alloc_workspace_with_defrag( CompNode cn, TensorLayout& layout) override; - void register_blob(Blob* blob) override; + void register_blob(OwnedBlob* blob) override; - void unregister_blob(Blob* blob) override; + void unregister_blob(OwnedBlob* blob) override; void set_allocator(allocator_t allocator) override; }; diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index a2c3b54c..6a5af98f 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -704,7 +704,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { m_dtr.update_used_time(dest); MGB_RECORD_EVENT( TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), - ptr->dev_tensor(false).raw_ptr()); + ptr->raw_ptr_not_for_readwrite()); // update tensor desc for static 
infer if (dest->desc.layout.ndim) { mgb_assert( @@ -805,8 +805,13 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { inputs[idx]->to_contiguous_inplace(layout_checker); } } - return OpDef::apply_on_physical_tensor( + auto outputs = OpDef::apply_on_physical_tensor( def, std::move(inputs), output_descs, validated); + for (auto& o : outputs) { + o->set_ready_event( + record_event(o->comp_node(), def.same_type())); + } + return outputs; }; MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason); SmallVector> kernels; @@ -1059,7 +1064,7 @@ std::unordered_set ChannelImpl::collect_valid_tensors() { return valid_tensors; } -void ChannelImpl::alloc_tensor_with_evict(Blob* x) { +void ChannelImpl::alloc_tensor_with_evict(OwnedBlob* x) { bool in_worker = (get_worker_tid() == std::this_thread::get_id()); auto reserve_size = [&](size_t size) { if (!m_dtr.comp_node.valid()) { diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 970c2b2c..cdcafa66 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -304,7 +304,7 @@ private: //! automatically evict an optimal tensor bool auto_evict(size_t); - void alloc_tensor_with_evict(Blob*); + void alloc_tensor_with_evict(OwnedBlob*); // assert thread id when call get_xxx_state to avoid misuse ChannelState& get_channel_state(); diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 51372629..621e7f34 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -249,7 +249,7 @@ SmallVector apply_inplace_add_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3]; - if (!(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique())) { + if (!inputs[0]->storage_is_unique()) { mgb_log_warn( "This inplace modification may change the elements of other tensors. 
" "Fallback to non-inplace update."); diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 7f06da31..40335638 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -13,7 +13,6 @@ #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/opr_attr.h" -#include "../async_releaser.h" #include "../dnn_op_helper.h" #include "../op_trait.h" @@ -267,8 +266,7 @@ SmallVector param_pack_concat_apply_on_physical_tensor( {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), output->dev_tensor().as_megdnn(), caller.create_workspace({{ws_size}, dtype::Byte()})); - AsyncReleaser::inst()->add( - HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); + async_release(HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); return {output}; } diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp index 870444e4..727c7056 100644 --- a/imperative/src/impl/ops/utility.cpp +++ b/imperative/src/impl/ops/utility.cpp @@ -23,6 +23,8 @@ #include "megbrain/opr/tensor_gen.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" +#include "megbrain/tensor.h" +#include "megdnn/dtype.h" #if MGB_JIT #include "megbrain/jit/executor_opr.h" @@ -37,6 +39,102 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp); OP_TRAIT_REG(GenericPyOp, GenericPyOp).fallback(); namespace { +namespace borrow { + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto& op = def.cast_final_safe(); + SmallVector outputs; + outputs.reserve(inputs.size()); + for (auto& i : inputs) { + outputs.push_back(Tensor::make( + i->blob()->borrow_to(op.comp_node), i->layout(), i->offset())); + } + return outputs; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + return SmallVector(inputs.size()); +} + +auto infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto& op = def.cast_final_safe(); + std::tuple, bool> ret(inputs, true); + for (auto& i : std::get<0>(ret)) { + mgb_assert( + i.comp_node.mem_node() == op.comp_node.mem_node(), + "cannot borrow memory from %s to %s", i.comp_node.to_string().c_str(), + op.comp_node.to_string().c_str()); + i.comp_node = op.comp_node; + } + return ret; +} + +OP_TRAIT_REG(Borrow, Borrow) + .apply_on_physical_tensor(apply_on_physical_tensor) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .get_input_layout_constraint(get_input_layout_constraint) + .fallback(); +} // namespace borrow +} // namespace + +namespace { +namespace barrier { + +SmallVector apply_on_physical_tensor( + const OpDef& def, const SmallVector& inputs, + SmallVector& output_descs, const bool& validated) { + auto& op = def.cast_final_safe(); + SmallVector outputs; + for (auto& i : inputs) { + if (i->comp_node() != op.comp_node) { + device_wait_event(op.comp_node, i->comp_node(), i->get_ready_event()); + } + } + if (op.nr_outputs) { + outputs.resize(op.nr_outputs); + TensorLayout layout(TensorShape({0}), dtype::Int32{}); + for (auto& i : outputs) { + i = Tensor::make(layout, op.comp_node); + } + } + return outputs; +} + +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + return SmallVector(inputs.size()); +} + +auto infer_output_attrs_fallible( + const OpDef& def, const SmallVector& inputs) { + auto& op = def.cast_final_safe(); + std::tuple, bool> ret; + auto& 
[descs, checked] = ret; + descs.resize(op.nr_outputs); + checked = true; + for (auto& desc : descs) { + desc.comp_node = op.comp_node; + desc.layout.dtype = dtype::Int32{}; + desc.layout.ndim = 1; + desc.layout.shape[0] = 0; + desc.layout.stride[0] = 1; + } + return ret; +} + +OP_TRAIT_REG(Barrier, imperative::Barrier) + .apply_on_physical_tensor(apply_on_physical_tensor) + .infer_output_attrs_fallible(infer_output_attrs_fallible) + .get_input_layout_constraint(get_input_layout_constraint) + .fallback(); +} // namespace barrier +} // namespace + +namespace { namespace fastpathcopy { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { return inputs; diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index eea7b62d..1d2a6b69 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -9,57 +9,384 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megbrain/imperative/physical_tensor.h" +#include "megbrain/common.h" +#include "megbrain/comp_node.h" #include "megbrain/imperative.h" #include "megbrain/imperative/blob_manager.h" #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/resource_manager.h" -#include "./async_releaser.h" #include "./event_pool.h" #include "./profiler/events.h" +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include + +#ifndef WIN32 +#include +#endif + +#include "range/v3/all.hpp" + +namespace views = ranges::views; namespace mgb { namespace imperative { namespace { -class CompNodeSyncManager : public CompNodeDepedentObject { - ThinHashMap> m_blob2event; +struct CompNodeHash { + auto operator()(CompNode cn) const { return mgb::hash(cn); } +}; + +template +struct NoThrowMovable : T { + using T::T; + NoThrowMovable(NoThrowMovable&&) noexcept = default; +}; + +template +using Map = NoThrowMovable>; + +class CompNodeSyncManager { + struct CompNodeData { + template + class ReleaseQueue { + Map map; + + public: + template + void emplace(uint64_t t, A&& a) { + map.emplace_hint(map.end(), t, std::forward(a)); + } + void release(uint64_t t) { + auto it = map.upper_bound(t); + map.erase(map.begin(), it); + } + }; + + //! next virtual event + uint64_t next = 1; + //! last completed virtual event + uint64_t completed = 0; + //! virtual event to real event + Map events; + //! ordering information at some virtual events: + //! what virtual events on other comp nodes is _sequenced before_ this virtual + //! event + Map> ordering; + //! release queue for dev storage, keyed by releaser. this comp node is the + //! **receiver** + std::vector> release_queues; + //! release queue for host storage. 
this comp node is the **releaser** + ReleaseQueue host_release_queue; + }; + std::mutex m_mtx; + std::condition_variable m_cv; + bool m_should_stop = false; + std::thread m_polling_thread; + std::unordered_map m_cn2id; + std::vector m_cndata; + + auto do_record(CompNode cn, size_t cnid, std::unique_lock& lock) { + // CAUSION: don't keep reference across locking boundary + lock.unlock(); + auto e = EventPool::without_timer().alloc(cn); + e->record(); + lock.lock(); + auto& cndata = m_cndata[cnid]; + return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e); + } + + std::pair get_event( + CompNode cn, size_t cnid, uint64_t t, std::unique_lock& lock) { + auto& cndata = m_cndata[cnid]; + auto it = cndata.events.lower_bound(t); + if (it == cndata.events.end()) { + it = do_record(cn, cnid, lock); + } + return {it->first, it->second.get()}; + } + + size_t get_cnid_unsafe(CompNode cn) { + auto [it, unseen] = m_cn2id.try_emplace(cn, m_cndata.size()); + if (unseen) { + m_cndata.emplace_back(); + } + return it->second; + } + + void monitor_events() { +#if defined(__APPLE__) + pthread_setname_np("CompNodeSync"); +#elif defined(__unix__) + pthread_setname_np(pthread_self(), "CompNodeSync"); +#endif + + // poll events in rounds. sleep for a fixed duration between rounds. + // number of events to query is decided by the number of successful queries in + // last round, independently for each comp node: + // a. all -> double + // b. 0 -> 1 + // c. otherwise -> #successful + + struct Item { + size_t cnid; + decltype(CompNodeData::events)::iterator it; + }; + + struct Stat { + size_t num_success = 0; + size_t num_attempts = 0; + // iterator to the last finished event + decltype(CompNodeData::events)::iterator it; + }; + + std::vector stats; + std::vector todos; + std::unique_lock lock(m_mtx); + for (;;) { + // copy events to a temporary storage so that we may unlock while polling + stats.resize(m_cndata.size()); + for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) { + // decide max number of events to query + // rule c: #successful + size_t n = stats[cnid].num_success; + if (n == stats[cnid].num_attempts) { + // rule a: double + n *= 2; + } + if (n == 0) { + // rule b: 1 + n = 1; + } + // now copy upto n events + auto& events = m_cndata[cnid].events; + size_t i = 0; + for (auto it = events.begin(); i < n && it != events.end(); ++i, ++it) { + todos.push_back({cnid, it}); + } + // reset stats for this round + stats[cnid].num_success = 0; + stats[cnid].num_attempts = n; + } + + lock.unlock(); + + bool last_result = false; + size_t last_cnid = -1; + for (auto item : todos) { + if (item.cnid == last_cnid && !last_result) { + // previous failed, this one almost certainly should fail + continue; + } + last_cnid = item.cnid; + last_result = item.it->second->finished(); + if (last_result) { + stats[item.cnid].num_success++; + stats[item.cnid].it = item.it; + } + } + todos.clear(); + + lock.lock(); + + // release dev storage + for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size(); + ++receiver_cnid) { + for (size_t releaser_cnid = 0; + releaser_cnid < m_cndata[receiver_cnid].release_queues.size(); + ++releaser_cnid) { + if (releaser_cnid >= stats.size() || + stats[releaser_cnid].num_success == 0) { + continue; + } + auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid]; + q.release(stats[releaser_cnid].it->first); + } + } + + for (size_t cnid = 0; cnid < stats.size(); ++cnid) { + if (stats[cnid].num_success == 0) { + continue; + } + auto& cndata = m_cndata[cnid]; + auto it = 
stats[cnid].it; + auto t = it->first; + // update completed + cndata.completed = t; + // release host storage + cndata.host_release_queue.release(t); + // remove completed events + auto& events = cndata.events; + events.erase(events.begin(), std::next(it)); + } + + using namespace std::literals; + if (m_cv.wait_for(lock, 10us, [&] { return m_should_stop; })) { + return; + } + } + } + + CompNodeSyncManager() { + m_polling_thread = std::thread([this] { monitor_events(); }); + } public: - std::shared_ptr on_comp_node_finalize() override { - MGB_LOCK_GUARD(m_mtx); - m_blob2event.clear(); - return {}; + ~CompNodeSyncManager() { + { + MGB_LOCK_GUARD(m_mtx); + m_should_stop = true; + m_cv.notify_all(); + } + m_polling_thread.join(); } - static CompNodeSyncManager& inst() { - static auto* sl_inst = ResourceManager::create_global(); - return *sl_inst; + static CompNodeSyncManager& inst(); + + uint64_t record(CompNode cn, bool doitnow = false) { + std::unique_lock lock(m_mtx); + auto cnid = get_cnid_unsafe(cn); + if (doitnow) { + return do_record(cn, cnid, lock)->first; + } + return m_cndata[cnid].next++; } - CompNode::Event* get_or_create_event(Blob* blob) { - mgb_assert(!is_finalized()); + void async_release(CompNode cn, uint64_t t, BlobPtr blob) { MGB_LOCK_GUARD(m_mtx); - auto&& e = m_blob2event[blob]; - if (!e) { - e = blob->comp_node().create_event(); + auto releaser_cnid = get_cnid_unsafe(cn); + if (t <= m_cndata[releaser_cnid].completed) { + return; + } + auto receiver_cnid = get_cnid_unsafe(blob->comp_node()); + auto& qs = m_cndata[receiver_cnid].release_queues; + if (releaser_cnid >= qs.size()) { + qs.resize(releaser_cnid + 1); } - return e.get(); + auto& q = qs[releaser_cnid]; + q.emplace(t, std::move(blob)); } - void remove(Blob* blob) { + void async_release(CompNode cn, uint64_t t, HostTensorStorage::RawStorage storage) { MGB_LOCK_GUARD(m_mtx); - m_blob2event.erase(blob); + auto releaser_cnid = get_cnid_unsafe(cn); + if (t <= m_cndata[releaser_cnid].completed) { + return; + } + auto& q = m_cndata[releaser_cnid].host_release_queue; + q.emplace(t, std::move(storage)); + } + + void device_wait(CompNode waiter, CompNode waitee, uint64_t t) { + std::unique_lock lock(m_mtx); + + auto waiter_id = get_cnid_unsafe(waiter); + auto waitee_id = get_cnid_unsafe(waitee); + auto& waiter_data = m_cndata.at(waiter_id); + auto& waitee_data = m_cndata.at(waitee_id); + auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock); + + // DO NOT unlock around this line! Event* could be invalidated! 
+ e->device_wait_by(waiter); + + auto t_waiter = waiter_data.next++; + std::vector ordering(m_cndata.size(), 0); + if (!waiter_data.ordering.empty()) { + auto& o = waiter_data.ordering.rbegin()->second; + std::copy(o.begin(), o.end(), ordering.begin()); + } + ordering[waitee_id] = t_waitee; + ordering[waiter_id] = t_waiter; + { + auto it = waitee_data.ordering.lower_bound(t_waitee); + if (it != waitee_data.ordering.begin()) { + for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) { + static_assert(std::is_lvalue_reference_v); + a = std::max(a, b); + } + } + } + waiter_data.ordering.emplace_hint( + waiter_data.ordering.end(), t_waiter, ordering); + + for (auto [t, q] : views::zip(ordering, waiter_data.release_queues)) { + q.release(t); + } } }; +CompNodeSyncManager& CompNodeSyncManager::inst() { + static std::mutex mtx; + static std::unique_ptr inst; + + struct Guard final : CompNodeDepedentObject { + std::shared_ptr on_comp_node_finalize() override { + MGB_LOCK_GUARD(mtx); + inst.reset(); + return {}; + } + }; + + static std::optional guard; + +#ifndef WIN32 + static bool broken = false; + static struct ForkGuard { + ForkGuard() { + mgb_assert(0 == pthread_atfork(NULL, NULL, [] { + if (inst) { + inst.release(); // deliberate leak, unfixable + broken = true; + } + })); + } + } fork_guard; +#endif + + MGB_LOCK_GUARD(mtx); + if (!inst) { +#ifndef WIN32 + mgb_assert(!broken); +#endif + EventPool::without_timer(); + inst.reset(new CompNodeSyncManager); + guard.emplace(); + } + return *inst; +} + } // namespace +uint64_t record_event(CompNode cn, bool doitnow) { + return CompNodeSyncManager::inst().record(cn, doitnow); +} + +void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event) { + CompNodeSyncManager::inst().device_wait(waiter, waitee, event); +} + +void async_release(CompNode cn, uint64_t event, BlobPtr blob) { + CompNodeSyncManager::inst().async_release(cn, event, std::move(blob)); +} + +void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage) { + CompNodeSyncManager::inst().async_release(cn, event, std::move(storage)); +} + void EventDeleter::operator()(CompNode::Event* event) { EventPool::without_timer().free(event); } @@ -68,31 +395,74 @@ namespace { std::atomic_uint64_t next_blob_id = 0; } -Blob::Blob(const DeviceTensorStorage& s) - : m_comp_node{s.comp_node()}, +OwnedBlob::OwnedBlob(const DeviceTensorStorage& s) + : Blob(s.comp_node(), s.size() + s.offset()), m_storage{s.raw_storage()}, - m_size{s.size() + s.offset()} { - m_id = next_blob_id++; + m_id{next_blob_id++} { BlobManager::inst()->register_blob(this); } -Blob::Blob(CompNode cn, size_t sz) : m_comp_node{cn}, m_storage{}, m_size{sz} { - m_id = next_blob_id++; +OwnedBlob::OwnedBlob(CompNode cn, size_t sz) + : Blob(cn, sz), m_storage{}, m_id{next_blob_id++} { BlobManager::inst()->register_blob(this); } -Blob::~Blob() { +OwnedBlob::~OwnedBlob() { BlobManager::inst()->unregister_blob(this); - CompNodeSyncManager::inst().remove(this); } -const Blob::RawStorage& Blob::storage() { +const Blob::RawStorage& OwnedBlob::storage() { if (!m_storage && m_size) { BlobManager::inst()->alloc_with_defrag(this, m_size); } return m_storage; } +BlobPtr OwnedBlob::borrow_to(CompNode cn) { + return std::make_shared( + cn, std::static_pointer_cast(shared_from_this())); +} + +bool OwnedBlob::storage_is_unique() { + return m_storage.unique(); +} + +void* OwnedBlob::raw_ptr_not_for_readwrite() { + return m_storage.get(); +} + +BorrowedBlob::BorrowedBlob(CompNode cn, std::shared_ptr owner) + : 
Blob(cn, owner->size()), + m_owner(std::move(owner)), + m_event(record_event(m_owner->comp_node(), true)) {} + +BorrowedBlob::~BorrowedBlob() { + async_release(m_comp_node, record_event(m_comp_node, true), std::move(m_owner)); +} + +const Blob::RawStorage& BorrowedBlob::storage() { + { + MGB_LOCK_GUARD(m_mtx); + if (!m_initialized) { + device_wait_event(m_comp_node, m_owner->comp_node(), m_event); + m_initialized = true; + } + } + return m_owner->storage(); +} + +BlobPtr BorrowedBlob::borrow_to(CompNode cn) { + return std::make_shared(cn, m_owner); +} + +bool BorrowedBlob::storage_is_unique() { + return m_owner.unique() && m_owner->storage_is_unique(); +} + +void* BorrowedBlob::raw_ptr_not_for_readwrite() { + return m_owner->raw_ptr_not_for_readwrite(); +} + Tensor::Tensor( BlobPtr blob, const TensorLayout& layout, size_t offset, const HostTensorND& hv) : m_cn(blob->comp_node()), @@ -119,7 +489,7 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { MGB_RECORD_EVENT( profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), dev_tensor().raw_ptr()); - AsyncReleaser::inst()->add(hv); + async_release(hv); } } @@ -174,7 +544,7 @@ void Tensor::to_contiguous_inplace(VarNode::LayoutConstraintCallback& layout_che DeviceTensorND dv_contig; dv_contig.copy_from(dv); m_layout = dv_contig.layout(); - std::atomic_store(&m_blob, Blob::make(dv_contig.storage())); + std::atomic_store(&m_blob, BlobPtr(Blob::make(dv_contig.storage()))); mgb_assert(m_layout.is_contiguous()); m_offset = 0; } @@ -188,7 +558,7 @@ void Tensor::to_contiguous_inplace() { void Tensor::assign_from_dev_tensor(DeviceTensorND dv) { MGB_LOCK_GUARD(m_blob_mtx); - std::atomic_store(&m_blob, Blob::make(dv.storage())); + std::atomic_store(&m_blob, BlobPtr(Blob::make(dv.storage()))); m_offset = dv.storage().offset(); m_layout = dv.layout(); } @@ -254,21 +624,20 @@ TensorPtr Tensor::sub(size_t offset, TensorShape shape) { return Tensor::make(m_blob, offset + m_offset, layout); } -void Tensor::add_release_callback(CompNode cn) { - AsyncReleaser::inst()->add(m_blob, cn); +uint64_t Tensor::get_ready_event() { + if (m_produced_at == 0) { + m_produced_at = record_event(comp_node()); + } + return m_produced_at; } -CompNode::Event* Tensor::get_or_create_event() { - auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get()); - e->record(); - return e; +bool Tensor::storage_is_unique() { + return m_blob.unique() && m_blob->storage_is_unique(); } void Tensor::static_initialize() { EventPool::with_timer(); EventPool::without_timer(); - AsyncReleaser::inst(); - CompNodeSyncManager::inst(); MultiCNConstTensorCache::inst(); } diff --git a/imperative/src/impl/proxy_graph/mini_graph.h b/imperative/src/impl/proxy_graph/mini_graph.h index be393406..36786d54 100644 --- a/imperative/src/impl/proxy_graph/mini_graph.h +++ b/imperative/src/impl/proxy_graph/mini_graph.h @@ -836,8 +836,8 @@ public: if (used_cns.insert(cn).second) { for (auto&& in : inputs) { if (in->comp_node() != cn) { - auto&& e = in->get_or_create_event(); - e->device_wait_by(cn); + auto e = in->get_ready_event(); + device_wait_event(cn, in->comp_node(), e); } } } @@ -847,11 +847,17 @@ public: // so we need create inference session here minigraph.execute(raw_inputs, raw_outputs, m_env); for (auto&& cn : used_cns) { + bool should_record = false; for (auto&& in : inputs) { if (in->comp_node() != cn) { - in->add_release_callback(cn); + should_record = true; + auto e = record_event(cn); + async_release(cn, e, *in); } } + if (should_record) { + 
record_event(cn, true); + } } return outputs; diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp index 89c4b355..baf53aa2 100644 --- a/imperative/src/impl/transformations/lazy.cpp +++ b/imperative/src/impl/transformations/lazy.cpp @@ -15,7 +15,6 @@ #include "megbrain/opr/utility.h" -#include "../async_releaser.h" #include "../mgb_cg_impl.h" namespace mgb { @@ -73,7 +72,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( args.device.emplace(); args.device->copy_from(*args.host); // every h2d in imperative runtime should notify AsyncReleaser - AsyncReleaser::inst()->add(*args.host); + async_release(*args.host); } return *args.device; }; @@ -155,7 +154,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( host_value.copy_from(inferred_value); DeviceTensorND dev_value; dev_value.copy_from(host_value); - AsyncReleaser::inst()->add(host_value); + async_release(host_value); return {DeviceValue::make(dev_value)}; } default: diff --git a/imperative/src/include/megbrain/imperative/blob_manager.h b/imperative/src/include/megbrain/imperative/blob_manager.h index 9c6a41cb..cce46d9a 100644 --- a/imperative/src/include/megbrain/imperative/blob_manager.h +++ b/imperative/src/include/megbrain/imperative/blob_manager.h @@ -24,18 +24,18 @@ public: static BlobManager* inst(); - virtual void alloc_direct(Blob* blob, size_t size) = 0; + virtual void alloc_direct(OwnedBlob* blob, size_t size) = 0; - virtual void alloc_with_defrag(Blob* blob, size_t size) = 0; + virtual void alloc_with_defrag(OwnedBlob* blob, size_t size) = 0; virtual void set_allocator(allocator_t allocator) = 0; virtual DeviceTensorND alloc_workspace_with_defrag( CompNode cn, TensorLayout& layout) = 0; - virtual void register_blob(Blob* blob) = 0; + virtual void register_blob(OwnedBlob* blob) = 0; - virtual void unregister_blob(Blob* blob) = 0; + virtual void unregister_blob(OwnedBlob* blob) = 0; virtual void defrag(const CompNode& cn) = 0; }; diff --git a/imperative/src/include/megbrain/imperative/physical_tensor.h b/imperative/src/include/megbrain/imperative/physical_tensor.h index 9787a13f..b0c68365 100644 --- a/imperative/src/include/megbrain/imperative/physical_tensor.h +++ b/imperative/src/include/megbrain/imperative/physical_tensor.h @@ -11,12 +11,16 @@ #pragma once +#include #include #include +#include +#include #include "megbrain/graph.h" #include "megbrain/imperative/resource_manager.h" #include "megbrain/tensor.h" +#include "megbrain/utils/metahelper.h" namespace mgb { namespace imperative { @@ -26,35 +30,67 @@ class Blob; using BlobPtr = std::shared_ptr; class BlobManagerImpl; +class OwnedBlob; + +class Blob : public NonCopyableObj, public std::enable_shared_from_this { +protected: + CompNode m_comp_node; + size_t m_size; + + Blob(CompNode cn, size_t size) : m_comp_node(cn), m_size(size) {} -class Blob : public NonCopyableObj { public: - Blob(const DeviceTensorStorage& s); - Blob(CompNode cn, size_t sz); - ~Blob(); + virtual ~Blob() = default; template - static BlobPtr make(Args&&... args) { - return std::make_shared(std::forward(args)...); + static std::shared_ptr make(Args&&... 
args) { + return std::make_shared(std::forward(args)...); } + const CompNode& comp_node() const { return m_comp_node; } + size_t size() const { return m_size; } using RawStorage = DeviceTensorStorage::RawStorage; - const RawStorage& storage(); + virtual const RawStorage& storage() = 0; + virtual BlobPtr borrow_to(CompNode) = 0; + virtual bool storage_is_unique() = 0; + virtual void* raw_ptr_not_for_readwrite() = 0; +}; - const CompNode& comp_node() const { return m_comp_node; } +class OwnedBlob final : public Blob { + friend class Blob; - size_t size() const { return m_size; } +public: + OwnedBlob(const DeviceTensorStorage& s); + OwnedBlob(CompNode cn, size_t sz); + ~OwnedBlob() override; - size_t id() const { return m_id; } + const RawStorage& storage() override; + BlobPtr borrow_to(CompNode) override; + bool storage_is_unique() override; + void* raw_ptr_not_for_readwrite() override; private: friend class BlobManagerImpl; - CompNode m_comp_node; - mutable RawStorage m_storage; - size_t m_size = 0; + RawStorage m_storage; size_t m_id; }; +class BorrowedBlob final : public Blob { + std::mutex m_mtx; + std::shared_ptr m_owner; + uint64_t m_event; + bool m_initialized = false; + +public: + BorrowedBlob(CompNode, std::shared_ptr); + ~BorrowedBlob() override; + + const RawStorage& storage() override; + BlobPtr borrow_to(CompNode) override; + bool storage_is_unique() override; + void* raw_ptr_not_for_readwrite() override; +}; + struct EventDeleter { void operator()(CompNode::Event*); }; @@ -121,6 +157,8 @@ public: BlobPtr& blob() { return m_blob; } + void* raw_ptr_not_for_readwrite() { return m_blob->raw_ptr_not_for_readwrite(); } + void fetch_value(); bool value_fetched(); TensorPtr sub(size_t offset, TensorShape shape); @@ -131,8 +169,10 @@ public: // return a pointer instead of a reference to ensure thread safety const HostTensorND* try_get_value(); - void add_release_callback(CompNode cn); - CompNode::Event* get_or_create_event(); + void set_ready_event(uint64_t event) { m_produced_at = event; } + uint64_t get_ready_event(); + + bool storage_is_unique(); // Make sure all static objects required to destruct a tensor has completed // construction. All static storage duration object that holds tensors must @@ -152,8 +192,37 @@ private: std::mutex m_value_mtx; HostTensorND m_value; EventPtr m_value_ready = nullptr; + uint64_t m_produced_at = 0; }; +/*! + * \brief record a virtual event + * \param doitnow also record a real event + */ +uint64_t record_event(CompNode cn, bool doitnow = false); + +//! make a device wait on a virtual event +void device_wait_event(CompNode waiter, CompNode waitee, uint64_t event); + +//! hold a blob until a virtual event on a device is completed +void async_release(CompNode cn, uint64_t event, BlobPtr blob); + +//! hold a host tensor until a virtual event on a device is completed +void async_release(CompNode cn, uint64_t event, HostTensorStorage::RawStorage storage); + +inline void async_release(CompNode cn, uint64_t event, Tensor& tensor) { + async_release(cn, event, tensor.blob()); +} + +inline void async_release(CompNode cn, uint64_t event, const HostTensorND& hnd) { + async_release(cn, event, hnd.storage().raw_storage()); +} + +inline void async_release(const HostTensorND& hnd) { + auto cn = hnd.comp_node(); + async_release(cn, record_event(cn, true), hnd); +} + // Cache for small blobs // 1. A blob has to be seen twice (within a window) to be eligible for cache // 2. 
Cache eviction occurs when cache size reaches a threshold, in least frequently diff --git a/src/core/impl/comp_node/impl_helper.cpp b/src/core/impl/comp_node/impl_helper.cpp index 5e01da18..67fac48a 100644 --- a/src/core/impl/comp_node/impl_helper.cpp +++ b/src/core/impl/comp_node/impl_helper.cpp @@ -32,7 +32,7 @@ bool CompNodeImplHelper::EventImplHelper::finished() { mgb_assert(m_recorded); if (do_finished()) { m_finished = true; - m_recorded = false; + // m_recorded = false; return true; } return false; diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 9f1dd3a3..79e97951 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -120,6 +120,19 @@ def Copy: MgbHashableOp<"Copy"> { ); } +def Borrow: MgbHashableOp<"Borrow"> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node + ); +} + +def Barrier: MgbHashableOp<"Barrier"> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node, + MgbUI32Attr:$nr_outputs + ); +} + def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>; def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;
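
Editorial note (not part of the patch): below is a minimal usage sketch of the Python-facing API this patch adds, namely Tensor.to(..., _borrow=True) plus record_event/wait_event from megengine._multistream, modeled on the included test_stream.py. It assumes a device that exposes multiple streams (the same xpux:0/xpux:1 pair the new test uses); the pipeline depth and tensor values are illustrative, not prescribed by the patch.

    import megengine as mge
    from megengine._multistream import record_event, wait_event

    # Producer runs on stream 0, consumer on stream 1 of the same device.
    x = mge.Tensor([1.0, 2.0, 3.0], device="xpux:0")

    events, results = [], []
    for step in range(10):
        # Bound the pipeline: before overwriting x on stream 0, make sure
        # stream 1 has finished consuming the borrow from two steps ago.
        if len(events) >= 2:
            wait_event("xpux:0", events[-2])
        x = x + 1
        # Borrow x's device storage on stream 1 instead of copying it.
        results.append(x.to("xpux:1", _borrow=True).sum())
        # Mark the end of this step on stream 1 with an event tensor.
        events.append(record_event("xpux:1"))
        del events[:-2]

    mge._full_sync()
    print([float(r.numpy()) for r in results])

As in test_borrow and test_stream_mem, borrowing avoids a cross-stream copy, so peak usage on the owning stream stays close to a single buffer; the explicit Barrier-based events are only needed to keep the producer from racing ahead of the consumer.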
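
Editorial note (not part of the patch): for readers of the new CompNodeSyncManager in physical_tensor.cpp, this is a deliberately simplified Python model of its virtual-event bookkeeping: events are per-device monotonic counters, deferred releases are parked in per-device queues, and a poller (modeled here as mark_completed) frees everything at or below the last finished id. All names are illustrative; the real implementation additionally records real CompNode events lazily, tracks cross-device ordering for device_wait, and runs the poller on a dedicated thread.

    import threading
    from collections import defaultdict


    class VirtualEventManager:
        # Toy model of the virtual-event scheme in this patch: record() hands
        # out a per-device monotonic id, async_release() parks a resource until
        # that id is known to be complete, and mark_completed() plays the role
        # of the polling thread that queries real CompNode events.

        def __init__(self):
            self._lock = threading.Lock()
            self._next = defaultdict(lambda: 1)   # next virtual event per device
            self._completed = defaultdict(int)    # last completed virtual event
            self._pending = defaultdict(list)     # device -> [(event_id, release_fn)]

        def record(self, device):
            # Virtual events are just counters; a real event would only be
            # recorded when somebody actually needs to wait on it.
            with self._lock:
                t = self._next[device]
                self._next[device] += 1
                return t

        def async_release(self, device, event_id, release_fn):
            # Defer releasing a resource until `device` has passed `event_id`.
            with self._lock:
                if event_id <= self._completed[device]:
                    ready = True
                else:
                    # ids are handed out monotonically, so appends stay sorted
                    self._pending[device].append((event_id, release_fn))
                    ready = False
            if ready:
                release_fn()

        def mark_completed(self, device, event_id):
            # Called by the poller when the real event at `event_id` finished.
            with self._lock:
                self._completed[device] = max(self._completed[device], event_id)
                done = [fn for t, fn in self._pending[device] if t <= event_id]
                self._pending[device] = [
                    (t, fn) for t, fn in self._pending[device] if t > event_id
                ]
            for fn in done:
                fn()


    # Example: defer freeing a buffer until a hypothetical "gpu0:1" stream
    # passes the recorded event.
    mgr = VirtualEventManager()
    t = mgr.record("gpu0:1")
    mgr.async_release("gpu0:1", t, lambda: print("buffer released"))
    mgr.mark_completed("gpu0:1", t)  # poller observed the event -> prints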