diff --git a/imperative/python/megengine/core/_config.py b/imperative/python/megengine/core/_config.py index b4451320..b55c5c32 100644 --- a/imperative/python/megengine/core/_config.py +++ b/imperative/python/megengine/core/_config.py @@ -20,6 +20,7 @@ __all__ = [ "benchmark_kernel", "deterministic_kernel", "async_level", + "disable_memory_forwarding", "_compute_mode", "_conv_format", "_override", @@ -87,6 +88,25 @@ def async_level(mod, level: int): @property +def disable_memory_forwarding(mod) -> bool: + r"""Get or set config whether to disable memory forwarding. The default option is false, + which means storage may be shared among tensors. + + Examples: + .. code-block:: + + import megengine as mge + mge.config.disable_memory_forwarding = False + """ + return bool(get_option("disable_memory_forwarding")) + + +@disable_memory_forwarding.setter +def disable_memory_forwarding(mod, disable: bool): + set_option("disable_memory_forwarding", disable) + + +@property def _compute_mode(mod): r"""Get or set the precision of intermediate results. The default option is "default", which means that no special requirements will be placed on. When set to 'float32', it diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 5bf55b63..9ee07413 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -120,7 +120,6 @@ void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() { return blob->storage(); }; OpDef::set_allocator(custom_allocator); - BlobManager::inst()->set_allocator(custom_allocator); } // Do not use m_xxx_state directly @@ -358,8 +357,8 @@ void ChannelImpl::dispatch_kernel( init(info, std::move(desc)); // make sure desc's value is consistent with h_value if (!info->desc.value.empty()) { - info->h_value = HostTensorND::make_proxy(desc.value) - .proxy_to_comp_node(desc.comp_node); + info->h_value = HostTensorND::make_proxy(info->desc.value) + .proxy_to_comp_node(info->desc.comp_node); } output_infos.push_back(info); outputs->push_back(reinterpret_cast(info)); @@ -561,6 +560,15 @@ void ChannelImpl::set_option(std::string name, size_t value) { mgb_assert(check_available(), "Channel already closed"); auto& state = get_channel_state(); state.options.set_option(name, value); + // FIXME + if (name == "enable_dtr_auto_drop" && value) { + auto custom_allocator = [&](CompNode device, size_t size) { + auto blob = Blob::make(device, size); + alloc_tensor_with_evict(blob.get()); + return blob->storage(); + }; + BlobManager::inst()->set_allocator(custom_allocator); + } if (Profiler::is_profiling()) { m_worker.add_task( {Profiler::next_id(), SetOption{name, value}, @@ -598,7 +606,7 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) { m_valid_handle.insert(reinterpret_cast(info)); MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name); info->status = TensorInfo::Allocated; - info->desc = desc; + info->desc = std::move(desc); } void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) { @@ -694,7 +702,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { } // in order to avoid performance impact, // memory forwarding is disabled when DTR is enabled - if (state.options.enable_dtr_auto_drop) { + if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) { ptr->to_contiguous_inplace(); } dest->desc.layout = ptr->layout(); diff --git a/imperative/src/impl/interpreter/option_manager.h b/imperative/src/impl/interpreter/option_manager.h index 013747b6..97905e43 100644 --- a/imperative/src/impl/interpreter/option_manager.h +++ b/imperative/src/impl/interpreter/option_manager.h @@ -44,6 +44,9 @@ public: enable_host_compute, "MEGENGINE_HOST_COMPUTE", 1, "enable host compute, thus computation may be done in host event if it's " "device is gpu."); + DEF_OPTION( + disable_memory_forwarding, "MEGENGINE_DISABLE_MEMORY_FORWARDING", 0, + "disable memory forwarding, thus each tensor has its own storage."); DEF_OPTION(enable_dtr_auto_drop, "MEGENGINE_DTR_AUTO_DROP", 0, ""); DEF_OPTION(enable_dtr_sqrt_sampling, "MEGENGINE_DTR_SQRT_SAMPLING", 0, ""); DEF_OPTION(