GitOrigin-RevId: 020d1e88d4
release-1.2
@@ -33,7 +33,7 @@ def _run_wrapped( | |||||
class launcher: | class launcher: | ||||
"""Decorator for launching multiple processes in single-machine multi-gpu training. | """Decorator for launching multiple processes in single-machine multi-gpu training. | ||||
:param func: the function you want to launch in distributed mode. | :param func: the function you want to launch in distributed mode. | ||||
:param n_gpus: how many devices each node. | :param n_gpus: how many devices each node. | ||||
:param world_size: how many devices totally. | :param world_size: how many devices totally. | ||||
@@ -32,7 +32,7 @@ namespace views = ranges::views; | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||||
interpreter::Interpreter::Channel* interpreter_for_py; | |||||
PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing, | PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing, | ||||
*cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode; | *cpp_apply_compiled_mode, *cpp_apply_const_compiled_mode; | ||||
@@ -673,7 +673,9 @@ py::object make_empty_tensorwrapper() { | |||||
} | } | ||||
void init_tensor(py::module m) { | void init_tensor(py::module m) { | ||||
interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||||
imperative::Tensor::static_initialize(); | |||||
static auto sl_interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||||
interpreter_for_py = sl_interpreter_for_py.get(); | |||||
auto* tensor_type = TensorWrapper::wrap_t::type() | auto* tensor_type = TensorWrapper::wrap_t::type() | ||||
.def<&TensorWrapper::numpy>("numpy") | .def<&TensorWrapper::numpy>("numpy") | ||||
@@ -724,6 +726,8 @@ void init_tensor(py::module m) { | |||||
[](int level) { interpreter_for_py->config_async_level(level); }); | [](int level) { interpreter_for_py->config_async_level(level); }); | ||||
m.def("get_async_level", | m.def("get_async_level", | ||||
[]() { return interpreter_for_py->get_async_level(); }); | []() { return interpreter_for_py->get_async_level(); }); | ||||
m.def("set_buffer_length", | |||||
[](int length) { interpreter_for_py->set_buffer_length(length); }); | |||||
m.def("sync", | m.def("sync", | ||||
[]() { | []() { | ||||
interpreter_for_py->sync(); | interpreter_for_py->sync(); | ||||
@@ -34,7 +34,7 @@ struct ObjectPtr : B { | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||||
extern interpreter::Interpreter::Channel* interpreter_for_py; | |||||
class SharedHandle { | class SharedHandle { | ||||
using Handle = interpreter::Interpreter::Handle; | using Handle = interpreter::Interpreter::Handle; | ||||
@@ -111,6 +111,11 @@ void BlobManagerImpl::defrag(const CompNode& cn) { | |||||
MGB_TRY{cn.free_device(cn.alloc_device(tot_sz));} | MGB_TRY{cn.free_device(cn.alloc_device(tot_sz));} | ||||
MGB_CATCH(MemAllocError&, {}) | MGB_CATCH(MemAllocError&, {}) | ||||
// sort blobs by created time, may be helpful for reduce memory fragment | |||||
std::sort(blob_data_arrary.begin(), blob_data_arrary.end(), [](auto& lhs, auto& rhs){ | |||||
return lhs.blob->id() < rhs.blob->id(); | |||||
}); | |||||
// allocate for each storage | // allocate for each storage | ||||
for (auto i : blob_data_arrary) { | for (auto i : blob_data_arrary) { | ||||
DeviceTensorStorage d_storage = DeviceTensorStorage(cn); | DeviceTensorStorage d_storage = DeviceTensorStorage(cn); | ||||
@@ -22,10 +22,10 @@ class FunctionHooker; | |||||
template <typename TRet, typename... TArgs> | template <typename TRet, typename... TArgs> | ||||
class FunctionHooker<TRet(TArgs...)> { | class FunctionHooker<TRet(TArgs...)> { | ||||
public: | public: | ||||
using FunctionType = thin_function<TRet(TArgs&&...)>; | |||||
using FunctionType = thin_function<TRet(TArgs...)>; | |||||
//Type of hooks. Hook should accept a real function as argument | //Type of hooks. Hook should accept a real function as argument | ||||
//and invoke it on an appropriate time | //and invoke it on an appropriate time | ||||
using HookType = thin_function<TRet(FunctionType, TArgs&&...)>; | |||||
using HookType = thin_function<TRet(FunctionType, TArgs...)>; | |||||
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { | explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { | ||||
m_backup = {nullptr, [](FunctionType*){}}; | m_backup = {nullptr, [](FunctionType*){}}; | ||||
} | } | ||||
@@ -43,7 +43,7 @@ public: | |||||
m_backup = decltype(m_backup)(backup, restorer); | m_backup = decltype(m_backup)(backup, restorer); | ||||
} | } | ||||
//Replace with hooked version | //Replace with hooked version | ||||
*m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet { | |||||
*m_fptr = [func = *m_fptr, hook](TArgs... args) -> TRet { | |||||
return hook(func, std::forward<TArgs>(args)...); | return hook(func, std::forward<TArgs>(args)...); | ||||
}; | }; | ||||
//Convinent for chain call | //Convinent for chain call | ||||
@@ -58,7 +58,7 @@ private: | |||||
//Helps to deduce template args | //Helps to deduce template args | ||||
template <typename TRet, typename... TArgs> | template <typename TRet, typename... TArgs> | ||||
FunctionHooker(thin_function<TRet(TArgs...)>* f) | FunctionHooker(thin_function<TRet(TArgs...)>* f) | ||||
->FunctionHooker<TRet(TArgs...)>; | |||||
-> FunctionHooker<TRet(TArgs...)>; | |||||
template<typename TSignature> | template<typename TSignature> | ||||
auto make_shared_hook(thin_function<TSignature>* fptr){ | auto make_shared_hook(thin_function<TSignature>* fptr){ | ||||
@@ -11,20 +11,20 @@ | |||||
#include "./interpreter_impl.h" | #include "./interpreter_impl.h" | ||||
#include "megbrain/common.h" | #include "megbrain/common.h" | ||||
#include "megbrain/imperative/opr_utility.h" | |||||
#include "megbrain/imperative/ops/backward_graph.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace imperative; | using namespace imperative; | ||||
using namespace interpreter; | using namespace interpreter; | ||||
using namespace interpreter::intl; | using namespace interpreter::intl; | ||||
std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() { | std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() { | ||||
return std::make_unique<ChannelImpl>(); | return std::make_unique<ChannelImpl>(); | ||||
} | } | ||||
Interpreter& Interpreter::inst() { | Interpreter& Interpreter::inst() { | ||||
Tensor::_static_init(); | |||||
static InterpreterImpl inst_; | static InterpreterImpl inst_; | ||||
return inst_; | return inst_; | ||||
} | } | ||||
@@ -35,7 +35,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||||
info->desc.comp_node = value.comp_node(); | info->desc.comp_node = value.comp_node(); | ||||
info->desc.value = value.proxy_to_default_cpu(); | info->desc.value = value.proxy_to_default_cpu(); | ||||
m_valid_handle.insert(info); | m_valid_handle.insert(info); | ||||
m_worker.add_task(Put{info, value, no_cache}); | |||||
m_buffer.enqueue(Put{info, value, no_cache}); | |||||
return info; | return info; | ||||
} | } | ||||
@@ -50,14 +50,14 @@ void* ChannelImpl::put(const DeviceTensorND& data) { | |||||
void ChannelImpl::del(void* handle) { | void ChannelImpl::del(void* handle) { | ||||
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | ||||
m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||||
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||||
} | } | ||||
void ChannelImpl::swap_in(void* handle) { | void ChannelImpl::swap_in(void* handle) { | ||||
if (m_enable_evict & SWAP) { | if (m_enable_evict & SWAP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
m_worker.add_task(SwapIn{reinterpret_cast<TensorInfo*>(handle)}); | |||||
m_buffer.enqueue(SwapIn{reinterpret_cast<TensorInfo*>(handle)}); | |||||
} | } | ||||
} | } | ||||
@@ -65,7 +65,7 @@ void ChannelImpl::swap_out(void* handle) { | |||||
if (m_enable_evict & SWAP) { | if (m_enable_evict & SWAP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
m_worker.add_task(SwapOut{reinterpret_cast<TensorInfo*>(handle)}); | |||||
m_buffer.enqueue(SwapOut{reinterpret_cast<TensorInfo*>(handle)}); | |||||
} | } | ||||
} | } | ||||
@@ -73,7 +73,7 @@ void ChannelImpl::drop(void* handle) { | |||||
if (m_enable_evict & DROP) { | if (m_enable_evict & DROP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
m_worker.add_task(Drop{reinterpret_cast<TensorInfo*>(handle)}); | |||||
m_buffer.enqueue(Drop{reinterpret_cast<TensorInfo*>(handle)}); | |||||
} | } | ||||
} | } | ||||
@@ -88,14 +88,16 @@ SmallVector<void*> ChannelImpl::apply_op( | |||||
input_infos.reserve(inputs.size()); | input_infos.reserve(inputs.size()); | ||||
SmallVector<LogicalTensorDesc> input_descs; | SmallVector<LogicalTensorDesc> input_descs; | ||||
input_descs.reserve(inputs.size()); | input_descs.reserve(inputs.size()); | ||||
std::unique_lock<decltype(m_mutex)> lock(m_mutex); | |||||
for (auto i : inputs) { | |||||
auto info = reinterpret_cast<TensorInfo*>(i); | |||||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||||
input_infos.push_back(info); | |||||
input_descs.push_back(info->desc); | |||||
{ | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
for (auto i : inputs) { | |||||
auto info = reinterpret_cast<TensorInfo*>(i); | |||||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||||
input_infos.push_back(info); | |||||
input_descs.push_back(info->desc); | |||||
} | |||||
} | } | ||||
lock.unlock(); | |||||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | ||||
ApplyOp cmd{std::move(op)}; | ApplyOp cmd{std::move(op)}; | ||||
@@ -127,7 +129,7 @@ SmallVector<void*> ChannelImpl::apply_op( | |||||
} | } | ||||
} | } | ||||
} | } | ||||
m_worker.add_task(std::move(cmd)); | |||||
m_buffer.enqueue(std::move(cmd)); | |||||
if (!(validated && validated_bkp) && m_async_level == 1) { | if (!(validated && validated_bkp) && m_async_level == 1) { | ||||
sync(); | sync(); | ||||
} else if (m_async_level == 0) { | } else if (m_async_level == 0) { | ||||
@@ -150,7 +152,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { | |||||
if (!info->value_fetched) { | if (!info->value_fetched) { | ||||
mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); | mgb_assert(!info->invalid, "Invalid tensor, unable to get_value!"); | ||||
m_waitee = info; | m_waitee = info; | ||||
m_worker.add_task(GetValue{info}); | |||||
m_buffer.enqueue(GetValue{info}); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return info->value_fetched; | return info->value_fetched; | ||||
@@ -171,6 +173,7 @@ TensorShape ChannelImpl::get_shape(void* handle) { | |||||
std::unique_lock<decltype(m_mutex)> lock(m_mutex); | std::unique_lock<decltype(m_mutex)> lock(m_mutex); | ||||
mgb_assert(!m_waitee); | mgb_assert(!m_waitee); | ||||
m_waitee = info; | m_waitee = info; | ||||
m_buffer.enqueue(Flush{info}); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return bool(info->ptr); | return bool(info->ptr); | ||||
@@ -206,6 +209,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||||
std::unique_lock<decltype(m_mutex)> lock(m_mutex); | std::unique_lock<decltype(m_mutex)> lock(m_mutex); | ||||
mgb_assert(!m_waitee); | mgb_assert(!m_waitee); | ||||
m_waitee = info; | m_waitee = info; | ||||
m_buffer.enqueue(Flush{info}); | |||||
m_cv.wait(lock, [&]() { | m_cv.wait(lock, [&]() { | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
return bool(info->ptr); | return bool(info->ptr); | ||||
@@ -215,6 +219,9 @@ DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||||
} | } | ||||
void ChannelImpl::sync() { | void ChannelImpl::sync() { | ||||
if (!m_buffer.empty()) { | |||||
m_buffer.enqueue(Flush{}); | |||||
} | |||||
m_worker.wait_all_task_finish(); | m_worker.wait_all_task_finish(); | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
check_worker_exc_unsafe(); | check_worker_exc_unsafe(); | ||||
@@ -350,6 +357,10 @@ void ChannelImpl::set_drop_flag(bool flag) { | |||||
} | } | ||||
} | } | ||||
void ChannelImpl::set_buffer_length(int length) { | |||||
m_buffer.set_capacity(length); | |||||
} | |||||
void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) { | void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) { | ||||
if (!info->ptr && info->evict_type != NONE) { | if (!info->ptr && info->evict_type != NONE) { | ||||
if (info->evict_type == SWAP) { | if (info->evict_type == SWAP) { | ||||
@@ -401,6 +412,7 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
} else if constexpr (std::is_same_v<T, ApplyOp>) { | } else if constexpr (std::is_same_v<T, ApplyOp>) { | ||||
SmallVector<TensorPtr> tensor_inputs; | SmallVector<TensorPtr> tensor_inputs; | ||||
tensor_inputs.reserve(cmd.inputs.size()); | tensor_inputs.reserve(cmd.inputs.size()); | ||||
// refcnt == 1, owners: [TensorInfo::ptr] | |||||
for (auto i : cmd.inputs) { | for (auto i : cmd.inputs) { | ||||
if (m_enable_evict && i->evict_type != NONE) { | if (m_enable_evict && i->evict_type != NONE) { | ||||
if (!i->ptr) { | if (!i->ptr) { | ||||
@@ -408,9 +420,20 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
} | } | ||||
} | } | ||||
mgb_assert(i->ptr, "Invalid input tensor ptr!"); | mgb_assert(i->ptr, "Invalid input tensor ptr!"); | ||||
// refcnt ++, owners: [i->ptr, tensor_inputs] | |||||
tensor_inputs.push_back(i->ptr); | tensor_inputs.push_back(i->ptr); | ||||
} | } | ||||
auto tensor_outputs = OpDef::apply_on_physical_tensor(*cmd.op, tensor_inputs); | |||||
// Fused by command buffer. @see: CommandBuffer::fuse_del | |||||
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del. | |||||
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused. | |||||
for (auto* del : cmd.dels) { | |||||
// refcnt --, owners: [tensor_inputs] | |||||
// if it's decreased to 1, would be detected at @see: proxy_graph_detail::apply_on_physical_tensor | |||||
free(del); | |||||
} | |||||
// Here std::move is REQUIRED for removing duplicated references. | |||||
auto tensor_outputs = OpDef::apply_on_physical_tensor( | |||||
*cmd.op, std::move(tensor_inputs)); | |||||
mgb_assert(tensor_outputs.size() == cmd.outputs.size()); | mgb_assert(tensor_outputs.size() == cmd.outputs.size()); | ||||
for (size_t i = 0; i < tensor_outputs.size(); ++i) { | for (size_t i = 0; i < tensor_outputs.size(); ++i) { | ||||
produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i])); | produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i])); | ||||
@@ -436,8 +459,12 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
do_swap_out(cmd.dest); | do_swap_out(cmd.dest); | ||||
} else if constexpr (std::is_same_v<T, Drop>) { | } else if constexpr (std::is_same_v<T, Drop>) { | ||||
do_drop(cmd.dest); | do_drop(cmd.dest); | ||||
} else if constexpr (std::is_same_v<T, Move>) { | |||||
produce_tensor(cmd.dest, cmd.src->ptr); | |||||
free(cmd.src); | |||||
} else { | } else { | ||||
static_assert(!std::is_same_v<T, T>); | |||||
static_assert(std::is_same_v<T, Flush> || | |||||
std::is_same_v<T, Nop>); | |||||
} | } | ||||
} catch (...) { | } catch (...) { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
@@ -454,7 +481,6 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||||
}, cmd); | }, cmd); | ||||
} | } | ||||
void ChannelImpl::check_worker_exc_unsafe() { | void ChannelImpl::check_worker_exc_unsafe() { | ||||
if (m_worker_exc) { | if (m_worker_exc) { | ||||
std::exception_ptr exc; | std::exception_ptr exc; | ||||
@@ -462,3 +488,120 @@ void ChannelImpl::check_worker_exc_unsafe() { | |||||
std::rethrow_exception(exc); | std::rethrow_exception(exc); | ||||
} | } | ||||
} | } | ||||
void ChannelImpl::CommandBuffer::enqueue(Command cmd) { | |||||
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { | |||||
return; | |||||
} | |||||
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, cmd); | |||||
mgb_log_debug("%s Enqueued", command_repr.c_str()); | |||||
m_commands.push_back(std::move(cmd)); | |||||
auto flush_pos = flush_pos_for(m_commands.back()); | |||||
flush(flush_pos); | |||||
} | |||||
void ChannelImpl::CommandBuffer::flush(Handle pos) { | |||||
for (auto iter = m_commands.begin(); iter != pos; ++iter) { | |||||
auto command_repr = std::visit([](auto& cmd){ return cmd.to_string(); }, *iter); | |||||
mgb_log_debug("%s Flushed", command_repr.c_str()); | |||||
m_owner->m_worker.add_task(std::move(*iter)); | |||||
} | |||||
m_commands.erase(m_commands.begin(), pos); | |||||
} | |||||
auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { | |||||
return std::visit([this](const auto& cmd) { | |||||
using T = std::decay_t<decltype(cmd)>; | |||||
if constexpr (std::is_same_v<T, ApplyOp>) { | |||||
auto* op_type = cmd.op->dyn_typeinfo(); | |||||
if (op_type == RemoteRecv::typeinfo() || | |||||
op_type == RemoteSend::typeinfo() || | |||||
op_type == CollectiveComm::typeinfo() || | |||||
op_type == opr::InputCallback::typeinfo() || | |||||
op_type == opr::OutputCallback::typeinfo() || | |||||
op_type == BackwardGraph::typeinfo()) { | |||||
return m_commands.end(); | |||||
} | |||||
} else if constexpr (std::is_same_v<T, GetValue>) { | |||||
return m_commands.end(); | |||||
} else if constexpr (std::is_same_v<T, Flush>) { | |||||
if (cmd.dest == nullptr) { | |||||
return m_commands.end(); | |||||
} | |||||
auto produce_iter = find_produce(cmd.dest, {m_commands.begin(), m_commands.end()}); | |||||
if (produce_iter != m_commands.end()) { | |||||
return produce_iter + 1; | |||||
} | |||||
} | |||||
if (m_commands.size() > m_capacity) { | |||||
return m_commands.begin() + (m_commands.size() - m_capacity); | |||||
} | |||||
return m_commands.begin(); | |||||
}, cmd); | |||||
} | |||||
/** | |||||
* 1. Find ApplyOp(dest) in buffered commands | |||||
* 2. Check if there are other usages between ApplyOp and Del, return false if not | |||||
* 3. Fuse Del into ApplyOp, return true | |||||
*/ | |||||
bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { | |||||
auto* dest = cmd.dest; | |||||
// TODO: eliminate Puts | |||||
auto begin = m_commands.begin(), end = m_commands.end(); | |||||
auto apply_iter = std::find_if(begin, end, [dest](const Command& cmd){ | |||||
if (auto* apply = std::get_if<ApplyOp>(&cmd)) { | |||||
return std::count(apply->inputs.begin(), apply->inputs.end(), dest) > 0; | |||||
} | |||||
return false; | |||||
}); | |||||
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { | |||||
return false; | |||||
} | |||||
mgb_log_debug("%s Fused", cmd.to_string().c_str()); | |||||
std::get<ApplyOp>(*apply_iter).dels.push_back(dest); | |||||
return true; | |||||
} | |||||
auto ChannelImpl::CommandBuffer::find_last_usage(TensorInfo* dest, Range range) | |||||
-> Handle { | |||||
auto found = range[1]; | |||||
for (auto iter = range[0]; iter != range[1]; ++iter) { | |||||
std::visit([&](const auto& cmd) { | |||||
using T = std::decay_t<decltype(cmd)>; | |||||
if constexpr (std::is_same_v<T, ApplyOp>) { | |||||
if (std::count(cmd.inputs.begin(), cmd.inputs.end(), | |||||
dest) > 0) { | |||||
found = iter; | |||||
} | |||||
} else if constexpr (std::is_same_v<T, GetValue>) { | |||||
if (cmd.dest == dest) { | |||||
found = iter; | |||||
} | |||||
} else if constexpr (std::is_same_v<T, SwapIn> || | |||||
std::is_same_v<T, SwapOut> || | |||||
std::is_same_v<T, Drop>) { | |||||
//TODO: ignore swap-like commands, just remove them from buffer | |||||
if (cmd.dest == dest) { | |||||
found = iter; | |||||
} | |||||
} | |||||
}, *iter); | |||||
}; | |||||
return found; | |||||
} | |||||
auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) | |||||
-> Handle { | |||||
return std::find_if(range[0], range[1], [dest](auto& cmd) { | |||||
return std::visit([dest](const auto& cmd){ | |||||
using T = std::decay_t<decltype(cmd)>; | |||||
if constexpr (std::is_same_v<T, ApplyOp>) { | |||||
return std::count(cmd.outputs.begin(), cmd.outputs.end(), dest) > 0; | |||||
} else if constexpr (std::is_same_v<T, Put>) { | |||||
return cmd.dest == dest; | |||||
} | |||||
return false; | |||||
}, cmd); | |||||
}); | |||||
} |
@@ -9,13 +9,15 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include <variant> | |||||
#include <deque> | |||||
#include <future> | #include <future> | ||||
#include <list> | |||||
#include <unordered_set> | |||||
#include <variant> | |||||
#include "megbrain/utils/mempool.h" | #include "megbrain/utils/mempool.h" | ||||
#include "megbrain/imperative/interpreter.h" | #include "megbrain/imperative/interpreter.h" | ||||
namespace mgb::imperative::interpreter::intl { | namespace mgb::imperative::interpreter::intl { | ||||
using Handle = Interpreter::Handle; | using Handle = Interpreter::Handle; | ||||
@@ -58,39 +60,99 @@ struct Put { | |||||
TensorInfo* dest; | TensorInfo* dest; | ||||
HostTensorND value; | HostTensorND value; | ||||
bool no_cache = false; | bool no_cache = false; | ||||
std::string to_string() const { return ssprintf("Command: Put %p", dest); } | |||||
}; | }; | ||||
struct ApplyOp { | struct ApplyOp { | ||||
std::shared_ptr<OpDef> op; | std::shared_ptr<OpDef> op; | ||||
SmallVector<TensorInfo*> inputs; | SmallVector<TensorInfo*> inputs; | ||||
SmallVector<TensorInfo*> outputs; | SmallVector<TensorInfo*> outputs; | ||||
SmallVector<TensorInfo*> dels; | |||||
std::string to_string() const { | |||||
std::string builder{"Command: ApplyOp {"}; | |||||
builder += "inputs ["; | |||||
for (auto* input : inputs) { | |||||
builder += ssprintf("%p, ", input); | |||||
} | |||||
builder += "], outputs ["; | |||||
for (auto* output : outputs) { | |||||
builder += ssprintf("%p, ", output); | |||||
} | |||||
builder += "], dels ["; | |||||
for (auto* del : dels) { | |||||
builder += ssprintf("%p, ", del); | |||||
} | |||||
builder += "]"; | |||||
return builder; | |||||
} | |||||
}; | }; | ||||
struct Del { | struct Del { | ||||
TensorInfo* dest; | TensorInfo* dest; | ||||
std::string to_string() const { return ssprintf("Command: Del %p", dest); } | |||||
}; | }; | ||||
struct GetValue { | struct GetValue { | ||||
TensorInfo* dest; | TensorInfo* dest; | ||||
}; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: GetValue %p", dest); | |||||
} | |||||
}; | |||||
struct SwapIn { | struct SwapIn { | ||||
TensorInfo* dest; | TensorInfo* dest; | ||||
std::string to_string() const { | |||||
return ssprintf("Command: SwapIn %p", dest); | |||||
} | |||||
}; | }; | ||||
struct SwapOut { | struct SwapOut { | ||||
TensorInfo* dest; | TensorInfo* dest; | ||||
std::string to_string() const { | |||||
return ssprintf("Command: SwapOut %p", dest); | |||||
} | |||||
}; | }; | ||||
struct Drop { | struct Drop { | ||||
TensorInfo* dest; | TensorInfo* dest; | ||||
std::string to_string() const { | |||||
return ssprintf("Command: Drop %p", dest); | |||||
} | |||||
}; | |||||
struct Move { | |||||
TensorInfo* src; | |||||
TensorInfo* dest; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: Move %s to %s", | |||||
src->desc.layout.to_string().c_str(), | |||||
dest->desc.layout.to_string().c_str()); | |||||
} | |||||
}; | }; | ||||
struct Flush { | |||||
TensorInfo* dest = nullptr; | |||||
std::string to_string() const { | |||||
return ssprintf("Command: Flush %p", dest); | |||||
} | |||||
}; | |||||
struct Nop { | |||||
std::string to_string() const { return "Command: Nop"; } | |||||
}; | |||||
using Command = std::variant<Put, | using Command = std::variant<Put, | ||||
ApplyOp, | ApplyOp, | ||||
Del, | Del, | ||||
GetValue, | GetValue, | ||||
SwapIn, | SwapIn, | ||||
SwapOut, | SwapOut, | ||||
Drop>; | |||||
Drop, | |||||
Move, | |||||
Flush, | |||||
Nop>; | |||||
struct ChannelImpl : Interpreter::Channel { | struct ChannelImpl : Interpreter::Channel { | ||||
ChannelImpl() : m_worker(this) {} | |||||
ChannelImpl() : m_worker(this), m_buffer(this) {} | |||||
~ChannelImpl() override; | ~ChannelImpl() override; | ||||
Handle put(const HostTensorND& value, bool no_cache) override; | Handle put(const HostTensorND& value, bool no_cache) override; | ||||
@@ -116,6 +178,7 @@ struct ChannelImpl : Interpreter::Channel { | |||||
void close() override; | void close() override; | ||||
void set_swap_flag(bool) override; | void set_swap_flag(bool) override; | ||||
void set_drop_flag(bool) override; | void set_drop_flag(bool) override; | ||||
void set_buffer_length(int) override; | |||||
void config_async_level(int level) override; | void config_async_level(int level) override; | ||||
int get_async_level() override; | int get_async_level() override; | ||||
@@ -174,7 +237,56 @@ private: | |||||
std::mutex mtx; | std::mutex mtx; | ||||
std::unordered_map<TensorInfo*, TensorInfoPtr> tmap; | std::unordered_map<TensorInfo*, TensorInfoPtr> tmap; | ||||
}m_st; | }m_st; | ||||
/** | |||||
* Buf a command window for following fuse | |||||
* example: | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1)}, ... + Del{i0} + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0)}, ... + Del{i1} | | |||||
* --------------------------------------------------------------------- | |||||
* | ..., Apply{in: (i0, i1), out: (o0, o1), del: (i0, i1)}, ... | | |||||
* --------------------------------------------------------------------- | |||||
* Then the fused Apply may be invoked inplace. see: ChannelImpl::process_one_task | |||||
*/ | |||||
struct CommandBuffer { | |||||
CommandBuffer(ChannelImpl* owner) : m_owner(owner) { | |||||
int capacity = 3; | |||||
if(const char* capacity_str = MGB_GETENV("MEGENGINE_COMMAND_BUFFER_LENGTH")) { | |||||
capacity = atoi(capacity_str); | |||||
} | |||||
set_capacity(capacity); | |||||
} | |||||
void enqueue(Command cmd); | |||||
bool empty() const { | |||||
return m_commands.empty(); | |||||
} | |||||
void set_capacity(int capacity) { | |||||
mgb_assert(capacity >= 0 && capacity < 100, "invalid command buffer length"); | |||||
m_capacity = capacity; | |||||
} | |||||
private: | |||||
ChannelImpl* m_owner; | |||||
size_t m_capacity; | |||||
std::deque<Command> m_commands; | |||||
using Handle = decltype(m_commands)::iterator; | |||||
// [begin, end) | |||||
using Range = std::array<Handle, 2>; | |||||
// Launch commands in range [m_commands.begin(), pos) | |||||
void flush(Handle pos); | |||||
// Select flush position for incoming cmd | |||||
Handle flush_pos_for(const Command& cmd); | |||||
// Fuse del command into suitable ApplyOp | |||||
bool fuse_del(const Del& cmd); | |||||
// Returns the last handle that dest is used within range. If dest is not used, returns range[1] | |||||
Handle find_last_usage(TensorInfo* dest, Range range); | |||||
// Returns the produce position of dest. If not found, returns range[1] | |||||
Handle find_produce(TensorInfo* dest, Range range); | |||||
} m_buffer; | |||||
//! config whether raise error exactly when invoking op. | //! config whether raise error exactly when invoking op. | ||||
//! level 2: both device and user side errors are async; | //! level 2: both device and user side errors are async; | ||||
//! level 1: user side errors are sync; | //! level 1: user side errors are sync; | ||||
@@ -32,8 +32,8 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node( | |||||
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | |||||
return def.trait()->apply_on_physical_tensor(def, inputs); | |||||
SmallVector<TensorPtr> inputs) { | |||||
return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | |||||
} | } | ||||
VarNodeArray OpDef::apply_on_var_node( | VarNodeArray OpDef::apply_on_var_node( | ||||
@@ -17,17 +17,17 @@ namespace mgb { | |||||
namespace imperative { | namespace imperative { | ||||
namespace detail { | namespace detail { | ||||
template<typename Signature> | |||||
template <typename Signature> | |||||
struct OpMeth; | struct OpMeth; | ||||
template<typename RType, typename ...Args> | |||||
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { | |||||
template <typename RType, typename... Args> | |||||
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> { | |||||
using Base = thin_function<RType(Args...)>; | using Base = thin_function<RType(Args...)>; | ||||
using Base::Base; | using Base::Base; | ||||
RType operator()(Args... args) const { | RType operator()(Args... args) const { | ||||
if (!this->Base::operator bool()) { | if (!this->Base::operator bool()) { | ||||
mgb_throw(MegBrainError, "Not Implemented"); | mgb_throw(MegBrainError, "Not Implemented"); | ||||
} | } | ||||
return this->Base::operator ()(args...); | |||||
return this->Base::operator()(std::forward<Args>(args)...); | |||||
} | } | ||||
}; | }; | ||||
template<typename T> | template<typename T> | ||||
@@ -56,7 +56,7 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||||
return opr->usable_output(); | return opr->usable_output(); | ||||
} | } | ||||
}; | }; | ||||
} // detail | |||||
} // namespace detail | |||||
using OpDefMaker = detail::OpMeth< | using OpDefMaker = detail::OpMeth< | ||||
decltype(OpDef::make_from_op_node)>; | decltype(OpDef::make_from_op_node)>; | ||||
@@ -56,17 +56,15 @@ protected: | |||||
return {}; | return {}; | ||||
} | } | ||||
AsyncReleaser() { | |||||
EventPool::without_timer(); | |||||
} | |||||
public: | public: | ||||
static AsyncReleaser* inst() { | static AsyncReleaser* inst() { | ||||
static AsyncReleaser releaser; | static AsyncReleaser releaser; | ||||
return &releaser; | return &releaser; | ||||
} | } | ||||
~AsyncReleaser() { m_waiter.wait_task_queue_empty(); } | |||||
~AsyncReleaser() { | |||||
m_waiter.wait_task_queue_empty(); | |||||
} | |||||
void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); } | void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); } | ||||
@@ -85,8 +83,6 @@ public: | |||||
class CompNodeSyncManager : public CompNodeDepedentObject { | class CompNodeSyncManager : public CompNodeDepedentObject { | ||||
ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event; | ThinHashMap<Blob*, std::unique_ptr<CompNode::Event>> m_blob2event; | ||||
std::mutex m_mtx; | std::mutex m_mtx; | ||||
private: | |||||
static CompNodeSyncManager mgr; | |||||
public: | public: | ||||
std::shared_ptr<void> on_comp_node_finalize() override { | std::shared_ptr<void> on_comp_node_finalize() override { | ||||
MGB_LOCK_GUARD(m_mtx); | MGB_LOCK_GUARD(m_mtx); | ||||
@@ -94,8 +90,9 @@ public: | |||||
return {}; | return {}; | ||||
} | } | ||||
static CompNodeSyncManager* inst() { | |||||
return &mgr; | |||||
static CompNodeSyncManager& inst() { | |||||
static CompNodeSyncManager sl_inst; | |||||
return sl_inst; | |||||
} | } | ||||
CompNode::Event* get_or_create_event(Blob* blob) { | CompNode::Event* get_or_create_event(Blob* blob) { | ||||
@@ -113,7 +110,6 @@ public: | |||||
m_blob2event.erase(blob); | m_blob2event.erase(blob); | ||||
} | } | ||||
}; | }; | ||||
CompNodeSyncManager CompNodeSyncManager::mgr; | |||||
// Cache for small blobs | // Cache for small blobs | ||||
// 1. A blob has to be seen twice (within a window) to be eligible for cache | // 1. A blob has to be seen twice (within a window) to be eligible for cache | ||||
@@ -236,9 +232,12 @@ struct MultiCNConstTensorCache : CompNodeDepedentObject { | |||||
MGB_LOCK_GUARD(mtx); | MGB_LOCK_GUARD(mtx); | ||||
return cn2cache[hv.comp_node()].lookup(hv); | return cn2cache[hv.comp_node()].lookup(hv); | ||||
} | } | ||||
}; | |||||
MultiCNConstTensorCache const_tensor_cache; | |||||
static MultiCNConstTensorCache& inst() { | |||||
static MultiCNConstTensorCache sl_inst; | |||||
return sl_inst; | |||||
} | |||||
}; | |||||
} // namespace | } // namespace | ||||
@@ -246,20 +245,26 @@ void EventDeleter::operator()(CompNode::Event* event) { | |||||
EventPool::without_timer().free(event); | EventPool::without_timer().free(event); | ||||
} | } | ||||
namespace { | |||||
std::atomic_uint64_t next_blob_id = 0; | |||||
} | |||||
Blob::Blob(const DeviceTensorStorage& s): | Blob::Blob(const DeviceTensorStorage& s): | ||||
m_comp_node{s.comp_node()}, m_storage{s.raw_storage()}, | m_comp_node{s.comp_node()}, m_storage{s.raw_storage()}, | ||||
m_size{s.size()} { | m_size{s.size()} { | ||||
m_id = next_blob_id++; | |||||
BlobManager::inst()->register_blob(this); | BlobManager::inst()->register_blob(this); | ||||
} | } | ||||
Blob::Blob(CompNode cn, size_t sz): | Blob::Blob(CompNode cn, size_t sz): | ||||
m_comp_node{cn}, m_storage{}, m_size{sz} { | m_comp_node{cn}, m_storage{}, m_size{sz} { | ||||
m_id = next_blob_id++; | |||||
BlobManager::inst()->register_blob(this); | BlobManager::inst()->register_blob(this); | ||||
} | } | ||||
Blob::~Blob() { | Blob::~Blob() { | ||||
BlobManager::inst()->unregister_blob(this); | BlobManager::inst()->unregister_blob(this); | ||||
CompNodeSyncManager::inst()->remove(this); | |||||
CompNodeSyncManager::inst().remove(this); | |||||
} | } | ||||
const Blob::RawStorage& Blob::storage() { | const Blob::RawStorage& Blob::storage() { | ||||
@@ -302,7 +307,7 @@ Tensor::Tensor(const BlobPtr blob, const size_t offset, const TensorLayout& layo | |||||
: m_layout{layout}, m_blob{blob}, m_offset{offset} {} | : m_layout{layout}, m_blob{blob}, m_offset{offset} {} | ||||
TensorPtr Tensor::make(const HostTensorND& hv) { | TensorPtr Tensor::make(const HostTensorND& hv) { | ||||
auto&& blob = const_tensor_cache.lookup(hv); | |||||
auto&& blob = MultiCNConstTensorCache::inst().lookup(hv); | |||||
if (blob) { | if (blob) { | ||||
return make(std::forward<decltype(blob)>(blob), hv.layout(), hv); | return make(std::forward<decltype(blob)>(blob), hv.layout(), hv); | ||||
} | } | ||||
@@ -366,13 +371,17 @@ void Tensor::add_release_callback(CompNode cn) { | |||||
} | } | ||||
CompNode::Event* Tensor::get_or_create_event() { | CompNode::Event* Tensor::get_or_create_event() { | ||||
auto e = CompNodeSyncManager::inst()->get_or_create_event(m_blob.get()); | |||||
auto e = CompNodeSyncManager::inst().get_or_create_event(m_blob.get()); | |||||
e->record(); | e->record(); | ||||
return e; | return e; | ||||
} | } | ||||
void Tensor::_static_init() { | |||||
void Tensor::static_initialize() { | |||||
EventPool::with_timer(); | |||||
EventPool::without_timer(); | EventPool::without_timer(); | ||||
AsyncReleaser::inst(); | |||||
CompNodeSyncManager::inst(); | |||||
MultiCNConstTensorCache::inst(); | |||||
} | } | ||||
} // namespace imperative | } // namespace imperative | ||||
@@ -117,7 +117,7 @@ void Profiler::start(uint32_t flags) { | |||||
auto hook_apply_on_var_node = | auto hook_apply_on_var_node = | ||||
make_shared_hook(&trait.apply_on_var_node); | make_shared_hook(&trait.apply_on_var_node); | ||||
hook_apply_on_physical_tensor->apply_hook([this, flags] | hook_apply_on_physical_tensor->apply_hook([this, flags] | ||||
(auto&& apply, const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
(auto&& apply, const OpDef& def, SmallVector<TensorPtr> inputs) { | |||||
auto shape2vector = [](const TensorShape& shape) { | auto shape2vector = [](const TensorShape& shape) { | ||||
std::vector<size_t> vector_shape; | std::vector<size_t> vector_shape; | ||||
for (size_t i = 0; i < shape.ndim; i++) { | for (size_t i = 0; i < shape.ndim; i++) { | ||||
@@ -11,6 +11,7 @@ | |||||
#include "./proxy_graph.h" | #include "./proxy_graph.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | #include "megbrain/imperative/proxy_graph_detail.h" | ||||
#include "megbrain/imperative/ops/autogen.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -70,11 +71,34 @@ void exec(const OpDef& def, | |||||
SmallVector<TensorPtr> | SmallVector<TensorPtr> | ||||
apply_on_physical_tensor(const OpDef& def, | apply_on_physical_tensor(const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | |||||
auto desc = infer_output_attrs(def, inputs); | |||||
SmallVector<TensorPtr> outputs; | |||||
for (auto&& i : desc) { | |||||
outputs.push_back(Tensor::make(i.layout, i.comp_node)); | |||||
SmallVector<TensorPtr> inputs) { | |||||
auto output_descs = infer_output_attrs(def, inputs); | |||||
SmallVector<TensorPtr> outputs(output_descs.size(), {}); | |||||
for (size_t i = 0; i < outputs.size(); i++) { | |||||
auto& output = outputs[i]; | |||||
auto& output_desc = output_descs[i]; | |||||
if (def.same_type<Elemwise>()) { | |||||
for (size_t j = 0; j < inputs.size(); j++) { | |||||
// TODO: reindex inputs to support inplace exprs like 'y = x op x'. | |||||
auto& input = inputs[j]; | |||||
// Because we pass inputs by value, if input and input->blob() are all unique, | |||||
// their ownerships are on the stack, thus we can reuse them safely. | |||||
// @see: interpreter::intl::ChannelImpl::process_one_task | |||||
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() && | |||||
input->layout().dtype == output_desc.layout.dtype && | |||||
input->layout().eq_layout(output_desc.layout) && | |||||
input->comp_node() == output_desc.comp_node) { | |||||
static std::atomic_llong inplace_count = 0; | |||||
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld", | |||||
output_desc.layout.to_string().c_str(), ++inplace_count); | |||||
output = Tensor::make(input->blob(), input->layout(), input->offset()); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
if (!output) { | |||||
output = Tensor::make(output_desc.layout, output_desc.comp_node); | |||||
} | |||||
} | } | ||||
exec(def, inputs, outputs); | exec(def, inputs, outputs); | ||||
return outputs; | return outputs; | ||||
@@ -44,6 +44,7 @@ struct Interpreter { | |||||
virtual void close() = 0; | virtual void close() = 0; | ||||
virtual void set_swap_flag(bool) = 0; | virtual void set_swap_flag(bool) = 0; | ||||
virtual void set_drop_flag(bool) = 0; | virtual void set_drop_flag(bool) = 0; | ||||
virtual void set_buffer_length(int) = 0; | |||||
virtual void config_async_level(int level) = 0; | virtual void config_async_level(int level) = 0; | ||||
virtual int get_async_level() = 0; | virtual int get_async_level() = 0; | ||||
@@ -38,7 +38,7 @@ public: | |||||
static SmallVector<TensorPtr> apply_on_physical_tensor( | static SmallVector<TensorPtr> apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs); | |||||
SmallVector<TensorPtr> inputs); | |||||
static cg::VarNodeArray apply_on_var_node( | static cg::VarNodeArray apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
@@ -46,11 +46,16 @@ public: | |||||
size_t size() const { | size_t size() const { | ||||
return m_size; | return m_size; | ||||
} | } | ||||
size_t id() const { | |||||
return m_id; | |||||
} | |||||
private: | private: | ||||
friend class BlobManagerImpl; | friend class BlobManagerImpl; | ||||
CompNode m_comp_node; | CompNode m_comp_node; | ||||
mutable RawStorage m_storage; | mutable RawStorage m_storage; | ||||
size_t m_size = 0; | size_t m_size = 0; | ||||
size_t m_id; | |||||
}; | }; | ||||
struct EventDeleter { | struct EventDeleter { | ||||
@@ -134,8 +139,7 @@ public: | |||||
// Make sure all static objects required to destruct a tensor has completed | // Make sure all static objects required to destruct a tensor has completed | ||||
// construction. All static storage duration object that holds tensors must | // construction. All static storage duration object that holds tensors must | ||||
// call this method before their constructors completes. | // call this method before their constructors completes. | ||||
static void _static_init(); | |||||
static void static_initialize(); | |||||
private: | private: | ||||
TensorLayout m_layout; | TensorLayout m_layout; | ||||
@@ -19,7 +19,7 @@ namespace proxy_graph_detail { | |||||
SmallVector<TensorPtr> | SmallVector<TensorPtr> | ||||
apply_on_physical_tensor(const OpDef& def, | apply_on_physical_tensor(const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs); | |||||
SmallVector<TensorPtr> inputs); | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | ||||
const SmallVector<LogicalTensorDesc>& inputs); | const SmallVector<LogicalTensorDesc>& inputs); | ||||