GitOrigin-RevId: 2663504470
tags/v1.3.0
@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() { | |||||
return inst_; | return inst_; | ||||
} | } | ||||
void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||||
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||||
auto info = alloc(); | auto info = alloc(); | ||||
info->desc.layout = value.layout(); | info->desc.layout = value.layout(); | ||||
info->desc.comp_node = value.comp_node(); | info->desc.comp_node = value.comp_node(); | ||||
@@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||||
return info; | return info; | ||||
} | } | ||||
void* ChannelImpl::put(const DeviceTensorND& data) { | |||||
Handle ChannelImpl::put(const DeviceTensorND& data) { | |||||
auto info = alloc(); | auto info = alloc(); | ||||
info->desc.layout = data.layout(); | info->desc.layout = data.layout(); | ||||
info->desc.comp_node = data.comp_node(); | info->desc.comp_node = data.comp_node(); | ||||
@@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) { | |||||
return info; | return info; | ||||
} | } | ||||
void ChannelImpl::del(void* handle) { | |||||
void ChannelImpl::del(Handle handle) { | |||||
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | ||||
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)}); | m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)}); | ||||
} | } | ||||
void ChannelImpl::swap_in(void* handle) { | |||||
void ChannelImpl::swap_in(Handle handle) { | |||||
if (m_enable_evict & SWAP) { | if (m_enable_evict & SWAP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
@@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) { | |||||
} | } | ||||
} | } | ||||
void ChannelImpl::swap_out(void* handle) { | |||||
void ChannelImpl::swap_out(Handle handle) { | |||||
if (m_enable_evict & SWAP) { | if (m_enable_evict & SWAP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
@@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) { | |||||
} | } | ||||
} | } | ||||
void ChannelImpl::drop(void* handle) { | |||||
void ChannelImpl::drop(Handle handle) { | |||||
if (m_enable_evict & DROP) { | if (m_enable_evict & DROP) { | ||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
@@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) { | |||||
} | } | ||||
} | } | ||||
SmallVector<void*> ChannelImpl::apply_op( | |||||
void ChannelImpl::dispatch_default_cpu( | |||||
std::shared_ptr<OpDef> op, | std::shared_ptr<OpDef> op, | ||||
const SmallVector<void*>& inputs) { | |||||
for (auto i : inputs) { | |||||
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||||
"invalid handle: %p", i); | |||||
} | |||||
SmallVector<TensorInfo*> input_infos; | |||||
input_infos.reserve(inputs.size()); | |||||
SmallVector<LogicalTensorDesc> input_descs; | |||||
input_descs.reserve(inputs.size()); | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs) { | |||||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||||
SmallVector<DeviceTensorND> input_tensornds; | |||||
input_tensornds.reserve(input_descs.size()); | |||||
CompNode output_cn; | |||||
{ | { | ||||
MGB_LOCK_GUARD(m_mutex); | MGB_LOCK_GUARD(m_mutex); | ||||
for (auto i : inputs) { | |||||
auto info = reinterpret_cast<TensorInfo*>(i); | |||||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||||
input_infos.push_back(info); | |||||
input_descs.push_back(info->desc); | |||||
for (auto&& info : input_infos) { | |||||
mgb_assert(info->ptr, "invalid tensor ptr!"); | |||||
if (!output_cn.valid()) { | |||||
output_cn = info->ptr->comp_node(); | |||||
} else { | |||||
mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node"); | |||||
} | |||||
mgb_assert(info->ptr->try_get_value(), "no valid host value"); | |||||
input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu()); | |||||
} | |||||
} | |||||
outputs->reserve(output_descs.size()); | |||||
SmallVector<DeviceTensorND> output_tensornds; | |||||
output_tensornds.reserve(output_descs.size()); | |||||
for (auto&& desc : output_descs) { | |||||
// TODO: may conflict with condtake, which need alloc inside | |||||
mgb_assert(!desc.layout.is_empty()); | |||||
// use HostTensorND alloc_host for cuda pinned memory | |||||
output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | |||||
} | |||||
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | |||||
SmallVector<TensorInfo*> output_infos; | |||||
output_infos.reserve(output_descs.size()); | |||||
for (auto&& tensornd : output_tensornds) { | |||||
// tensornd -> host_tensornd | |||||
HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd) | |||||
.proxy_to_comp_node(output_cn); | |||||
// tensornd -> desc | |||||
LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd}; | |||||
// tensornd -> tensor | |||||
auto info = alloc(); | |||||
info->desc = desc; | |||||
m_valid_handle.insert(info); | |||||
output_infos.push_back(info); | |||||
info->ptr = Tensor::make(host_tensornd, true); // host_only=true | |||||
info->value_fetched = true; | |||||
outputs->push_back(info); | |||||
} | |||||
if (m_enable_evict & DROP) { | |||||
for (auto out : output_infos) { | |||||
out->path.op = op; | |||||
for (auto out_ : output_infos) { | |||||
out->path.outputs.push_back(m_st.at(out_)); | |||||
} | |||||
for (auto inp : input_infos) { | |||||
out->path.inputs.push_back(m_st.at(inp)); | |||||
inp->path.dep_outputs.push_back(m_st.at(out)); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | |||||
void ChannelImpl::dispatch_kernel( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs) { | |||||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | ||||
ApplyOp cmd{std::move(op)}; | ApplyOp cmd{std::move(op)}; | ||||
cmd.inputs = std::move(input_infos); | cmd.inputs = std::move(input_infos); | ||||
cmd.outputs.reserve(output_descs.size()); | cmd.outputs.reserve(output_descs.size()); | ||||
SmallVector<void*> outputs; | |||||
// FIXME: remove this check when op check is correct | |||||
bool validated_bkp = true; | |||||
for (size_t i = 0;i < output_descs.size();i ++) { | |||||
auto&& desc = output_descs[i]; | |||||
if (desc.layout.ndim == 0) { | |||||
validated_bkp = false; | |||||
} | |||||
outputs->reserve(output_descs.size()); | |||||
for (auto&& desc : output_descs) { | |||||
auto info = alloc(); | auto info = alloc(); | ||||
info->desc = desc; | info->desc = desc; | ||||
m_valid_handle.insert(info); | m_valid_handle.insert(info); | ||||
cmd.outputs.push_back(info); | cmd.outputs.push_back(info); | ||||
outputs.push_back(info); | |||||
outputs->push_back(info); | |||||
} | } | ||||
if (m_enable_evict & DROP) { | if (m_enable_evict & DROP) { | ||||
for (auto out : cmd.outputs) { | for (auto out : cmd.outputs) { | ||||
@@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op( | |||||
} | } | ||||
} | } | ||||
m_buffer.enqueue(std::move(cmd)); | m_buffer.enqueue(std::move(cmd)); | ||||
if (!(validated && validated_bkp) && m_async_level == 1) { | |||||
if (!validated && m_async_level == 1) { | |||||
sync(); | sync(); | ||||
} else if (m_async_level == 0) { | } else if (m_async_level == 0) { | ||||
sync(); | sync(); | ||||
// check device error | // check device error | ||||
for (auto&& oup : outputs) { | |||||
for (auto&& oup : *outputs) { | |||||
auto info = reinterpret_cast<TensorInfo*>(oup); | auto info = reinterpret_cast<TensorInfo*>(oup); | ||||
info->ptr->comp_node().sync(); | info->ptr->comp_node().sync(); | ||||
} | } | ||||
} | } | ||||
} | |||||
SmallVector<Handle> ChannelImpl::apply_op( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<Handle>& inputs) { | |||||
for (auto i : inputs) { | |||||
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||||
"invalid handle: %p", i); | |||||
} | |||||
SmallVector<TensorInfo*> input_infos; | |||||
input_infos.reserve(inputs.size()); | |||||
SmallVector<LogicalTensorDesc> input_descs; | |||||
input_descs.reserve(inputs.size()); | |||||
{ | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
for (auto i : inputs) { | |||||
auto info = reinterpret_cast<TensorInfo*>(i); | |||||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||||
input_infos.push_back(info); | |||||
input_descs.push_back(info->desc); | |||||
} | |||||
} | |||||
SmallVector<Handle> outputs; | |||||
switch (OpDef::decide_dispatch_mode(*op, input_descs)) { | |||||
case DEFAULT_CPU: { | |||||
dispatch_default_cpu(op, input_infos, input_descs, &outputs); | |||||
break; | |||||
} | |||||
case KERNEL: { | |||||
dispatch_kernel(op, input_infos, input_descs, &outputs); | |||||
break; | |||||
} | |||||
} | |||||
mgb_assert(outputs.size() > 0, "Invalid dispatch mode!"); | |||||
return outputs; | return outputs; | ||||
} | } | ||||
HostTensorND ChannelImpl::get_value(void* handle) { | |||||
HostTensorND ChannelImpl::get_value(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { | |||||
return info->ptr->get_value(); | return info->ptr->get_value(); | ||||
} | } | ||||
TensorShape ChannelImpl::get_shape(void* handle) { | |||||
TensorShape ChannelImpl::get_shape(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) { | |||||
return ret; | return ret; | ||||
} | } | ||||
DType ChannelImpl::get_dtype(void* handle) { | |||||
DType ChannelImpl::get_dtype(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) { | |||||
return ret; | return ret; | ||||
} | } | ||||
CompNode ChannelImpl::get_device(void* handle) { | |||||
CompNode ChannelImpl::get_device(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) { | |||||
return ret; | return ret; | ||||
} | } | ||||
DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||||
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { | |||||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | ||||
"invalid handle: %p", handle); | "invalid handle: %p", handle); | ||||
auto info = reinterpret_cast<TensorInfo*>(handle); | auto info = reinterpret_cast<TensorInfo*>(handle); | ||||
@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() { | |||||
} | } | ||||
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { | void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { | ||||
if (notice) { | |||||
MGB_LOCK_GUARD(m_mutex); | |||||
dest->value_fetched = ptr->value_fetched(); | |||||
// update tensor desc for static infer | |||||
// if (dest->desc.layout.ndim) { | |||||
// mgb_assert(dest->desc.layout.eq_shape(ptr->layout())); | |||||
// } | |||||
dest->desc.layout = ptr->layout(); | |||||
dest->desc.comp_node = ptr->comp_node(); | |||||
dest->ptr = std::move(ptr); | |||||
if (m_waitee == dest) { | |||||
m_cv.notify_all(); | |||||
} | |||||
} else { | |||||
dest->value_fetched = ptr->value_fetched(); | |||||
// update tensor desc for static infer | |||||
dest->desc.layout = ptr->layout(); | |||||
dest->desc.comp_node = ptr->comp_node(); | |||||
dest->ptr = std::move(ptr); | |||||
auto lock = notice ? std::unique_lock<std::mutex>(m_mutex) | |||||
: std::unique_lock<std::mutex>(); | |||||
dest->value_fetched = ptr->value_fetched(); | |||||
// update tensor desc for static infer | |||||
dest->desc.layout = ptr->layout(); | |||||
dest->desc.comp_node = ptr->comp_node(); | |||||
dest->ptr = std::move(ptr); | |||||
if (notice && m_waitee == dest) { | |||||
m_cv.notify_all(); | |||||
} | } | ||||
} | } | ||||
@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) { | |||||
dest->evict_type = SWAP; | dest->evict_type = SWAP; | ||||
dest->value_fetched = false; | dest->value_fetched = false; | ||||
// TODO: swap in parallel | // TODO: swap in parallel | ||||
dest->h_value.copy_from(dest->ptr->dev_tensor()).sync(); | |||||
dest->h_value = dest->ptr->get_value(); | |||||
dest->ptr.reset(); | dest->ptr.reset(); | ||||
} | } | ||||
@@ -198,6 +198,17 @@ private: | |||||
void do_drop(TensorInfo* dest); | void do_drop(TensorInfo* dest); | ||||
void regenerate(TensorInfo* dest, bool must_drop); | void regenerate(TensorInfo* dest, bool must_drop); | ||||
void dispatch_default_cpu( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
void dispatch_kernel( | |||||
std::shared_ptr<OpDef> op, | |||||
const SmallVector<TensorInfo*>& input_infos, | |||||
const SmallVector<LogicalTensorDesc>& input_descs, | |||||
SmallVector<Handle>* outputs); | |||||
std::mutex m_mutex; | std::mutex m_mutex; | ||||
std::condition_variable m_cv; | std::condition_variable m_cv; | ||||
MemPool<TensorInfo> m_pool; | MemPool<TensorInfo> m_pool; | ||||
@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node( | |||||
return trait->make_from_op_node(node); | return trait->make_from_op_node(node); | ||||
} | } | ||||
DispatchMode OpDef::decide_dispatch_mode( | |||||
const OpDef& def, | |||||
const SmallVector<LogicalTensorDesc>& inputs) { | |||||
return def.trait()->decide_dispatch_mode(def, inputs); | |||||
} | |||||
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
SmallVector<TensorPtr> inputs) { | SmallVector<TensorPtr> inputs) { | ||||
return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | ||||
} | } | ||||
void OpDef::apply_on_device_tensornd( | |||||
const OpDef& def, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
SmallVector<DeviceTensorND>* outputs) { | |||||
def.trait()->apply_on_device_tensornd(def, inputs, outputs); | |||||
return; | |||||
} | |||||
VarNodeArray OpDef::apply_on_var_node( | VarNodeArray OpDef::apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
@@ -9,12 +9,16 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include <exception> | |||||
#include <sstream> | #include <sstream> | ||||
#include <stdexcept> | |||||
#include "megbrain/imperative/op_def.h" | |||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | |||||
#include "megbrain/tensor.h" | |||||
#include "./op_trait.h" | #include "./op_trait.h" | ||||
#include "megbrain/imperative/proxy_graph_detail.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||||
} | } | ||||
} | } | ||||
DispatchMode fallback_decide_dispatch_mode( | |||||
const OpDef& def, | |||||
const SmallVector<LogicalTensorDesc>& inputs) { | |||||
return KERNEL; | |||||
} | |||||
OpTraitRegistry& OpTraitRegistry::fallback() { | OpTraitRegistry& OpTraitRegistry::fallback() { | ||||
if (trait->apply_on_var_node) { | if (trait->apply_on_var_node) { | ||||
// fallback to proxy graph impl | // fallback to proxy graph impl | ||||
@@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() { | |||||
proxy_graph_detail::make_backward_graph; | proxy_graph_detail::make_backward_graph; | ||||
} | } | ||||
} | } | ||||
if (!trait->decide_dispatch_mode) { | |||||
trait->decide_dispatch_mode = fallback_decide_dispatch_mode; | |||||
} | |||||
return *this; | return *this; | ||||
} | } | ||||
@@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||||
using OpDefMaker = detail::OpMeth< | using OpDefMaker = detail::OpMeth< | ||||
decltype(OpDef::make_from_op_node)>; | decltype(OpDef::make_from_op_node)>; | ||||
using DecideDispatchMode = detail::OpMeth< | |||||
decltype(OpDef::decide_dispatch_mode)>; | |||||
using ApplyOnPhysicalTensor = detail::OpMeth< | using ApplyOnPhysicalTensor = detail::OpMeth< | ||||
decltype(OpDef::apply_on_physical_tensor)>; | decltype(OpDef::apply_on_physical_tensor)>; | ||||
using ApplyOnDeviceTensorND = detail::OpMeth< | |||||
decltype(OpDef::apply_on_device_tensornd)>; | |||||
using ApplyOnVarNode = detail::OpMeth< | using ApplyOnVarNode = detail::OpMeth< | ||||
decltype(OpDef::apply_on_var_node)>; | decltype(OpDef::apply_on_var_node)>; | ||||
using InferOutputAttrsFallible = detail::OpMeth< | using InferOutputAttrsFallible = detail::OpMeth< | ||||
@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||||
struct OpTrait { | struct OpTrait { | ||||
const char* name; | const char* name; | ||||
OpDefMaker make_from_op_node; | OpDefMaker make_from_op_node; | ||||
DecideDispatchMode decide_dispatch_mode; | |||||
ApplyOnPhysicalTensor apply_on_physical_tensor; | ApplyOnPhysicalTensor apply_on_physical_tensor; | ||||
ApplyOnDeviceTensorND apply_on_device_tensornd; | |||||
ApplyOnVarNode apply_on_var_node; | ApplyOnVarNode apply_on_var_node; | ||||
InferOutputAttrsFallible infer_output_attrs_fallible; | InferOutputAttrsFallible infer_output_attrs_fallible; | ||||
GradMaker make_backward_graph; | GradMaker make_backward_graph; | ||||
@@ -88,7 +94,9 @@ struct OpTrait { | |||||
#define FOR_EACH_OP_METH(cb) \ | #define FOR_EACH_OP_METH(cb) \ | ||||
cb(make_from_op_node) \ | cb(make_from_op_node) \ | ||||
cb(decide_dispatch_mode) \ | |||||
cb(apply_on_physical_tensor) \ | cb(apply_on_physical_tensor) \ | ||||
cb(apply_on_device_tensornd) \ | |||||
cb(apply_on_var_node) \ | cb(apply_on_var_node) \ | ||||
cb(infer_output_attrs_fallible) \ | cb(infer_output_attrs_fallible) \ | ||||
cb(make_backward_graph) \ | cb(make_backward_graph) \ | ||||
@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; | return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; | ||||
} | } | ||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
DispatchMode decide_dispatch_mode( | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | |||||
const SmallVector<LogicalTensorDesc>& inputs) { | |||||
bool host_computable = true; | |||||
constexpr int size_threshhold = TensorShape::MAX_NDIM; | |||||
for (auto&& inp : inputs) { | |||||
if (inp.value.empty() || inp.value.layout().ndim == 0 | |||||
|| inp.value.layout().total_nr_elems() > size_threshhold) { | |||||
host_computable = false; | |||||
break; | |||||
} | |||||
} | |||||
return host_computable ? DEFAULT_CPU : KERNEL; | |||||
} | |||||
void apply_on_device_tensornd( | |||||
const OpDef& def, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
SmallVector<DeviceTensorND>* outputs) { | |||||
auto&& op_def = def.cast_final_safe<Elemwise>(); | auto&& op_def = def.cast_final_safe<Elemwise>(); | ||||
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | ||||
mgb_assert(inputs.size() == trait.arity, | mgb_assert(inputs.size() == trait.arity, | ||||
"%s expects %u inputs; got %zu actually", trait.name, | "%s expects %u inputs; got %zu actually", trait.name, | ||||
trait.arity, inputs.size()); | trait.arity, inputs.size()); | ||||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node()); | |||||
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr); | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | |||||
DeviceTensorND out; | |||||
SmallVector<DeviceTensorND> dt_inputs(inputs.size()); | |||||
SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | |||||
for (unsigned i = 0; i < inputs.size(); ++i){ | for (unsigned i = 0; i < inputs.size(); ++i){ | ||||
dt_inputs[i] = inputs[i]->dev_tensor(); | |||||
inp_tensornds[i] = inputs[i]->dev_tensor(); | |||||
} | } | ||||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node()); | |||||
opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr); | |||||
return {Tensor::make(out)}; | |||||
SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}}; | |||||
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | |||||
return {Tensor::make(oup_tensornds[0])}; | |||||
} | } | ||||
MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | ||||
@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_ | |||||
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | ||||
.make_from_op_node(make_from_op_node) | .make_from_op_node(make_from_op_node) | ||||
.decide_dispatch_mode(decide_dispatch_mode) | |||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | .infer_output_attrs_fallible(infer_output_attrs_fallible) | ||||
.apply_on_device_tensornd(apply_on_device_tensornd) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | .apply_on_physical_tensor(apply_on_physical_tensor) | ||||
.fallback(); | .fallback(); | ||||
@@ -15,8 +15,8 @@ | |||||
#include "../op_trait.h" | #include "../op_trait.h" | ||||
namespace mgb::imperative { | namespace mgb::imperative { | ||||
namespace { | |||||
namespace get_var_shape { | |||||
cg::OperatorNodeBase* apply_on_var_node( | cg::OperatorNodeBase* apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs) { | const VarNodeArray& inputs) { | ||||
@@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node( | |||||
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | ||||
} | } | ||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
DispatchMode decide_dispatch_mode( | |||||
const OpDef& def, | const OpDef& def, | ||||
const SmallVector<TensorPtr>& inputs) { | |||||
const SmallVector<LogicalTensorDesc>& inputs) { | |||||
bool host_computable = true; | |||||
for (auto&& inp : inputs) { | |||||
// FIXME(czh): remove value chech after proxy graph's | |||||
// apply_on_device_tensornd is supported and output Tensor | |||||
// is made before add_task. | |||||
// then if layout is valid, ptr->layout must be ready | |||||
if (inp.value.empty() || inp.value.layout().ndim == 0) { | |||||
host_computable = false; | |||||
break; | |||||
} | |||||
} | |||||
return host_computable ? DEFAULT_CPU : KERNEL; | |||||
} | |||||
void apply_on_device_tensornd( | |||||
const OpDef& def, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
SmallVector<DeviceTensorND>* outputs) { | |||||
auto&& op_def = def.cast_final_safe<GetVarShape>(); | auto&& op_def = def.cast_final_safe<GetVarShape>(); | ||||
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | ||||
auto&& inp = inputs[0]; | auto&& inp = inputs[0]; | ||||
auto&& shp = inp->layout(); | |||||
auto&& shp = inp.layout(); | |||||
mgb_assert(shp.ndim != 0, "input shape invalid"); | mgb_assert(shp.ndim != 0, "input shape invalid"); | ||||
mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(), | |||||
"GetVarShape's apply_on_device_tensornd should receive default_cpu outputs."); | |||||
HostTensorND hv; | HostTensorND hv; | ||||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||||
hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32()); | |||||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||||
hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32()); | |||||
auto* ptr = hv.ptr<dt_int32>(); | auto* ptr = hv.ptr<dt_int32>(); | ||||
for (size_t i = 0; i < shp.ndim; ++i) { | for (size_t i = 0; i < shp.ndim; ++i) { | ||||
ptr[i] = shp.shape[i]; | ptr[i] = shp.shape[i]; | ||||
@@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
axis += shp.ndim; | axis += shp.ndim; | ||||
} | } | ||||
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); | mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); | ||||
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); | |||||
hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); | |||||
auto* ptr = hv.ptr<dt_int32>(); | auto* ptr = hv.ptr<dt_int32>(); | ||||
ptr[0] = shp.shape[axis]; | ptr[0] = shp.shape[axis]; | ||||
} | } | ||||
return {Tensor::make(std::move(hv))}; | |||||
(*outputs)[0] = DeviceTensorND::make_proxy(hv); | |||||
} | |||||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
const OpDef& def, | |||||
const SmallVector<TensorPtr>& inputs) { | |||||
SmallVector<DeviceTensorND> input_tensornds; | |||||
input_tensornds.reserve(inputs.size()); | |||||
for (auto&& inp : inputs) { | |||||
input_tensornds.push_back(inp->dev_tensor()); | |||||
} | |||||
SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}}; | |||||
apply_on_device_tensornd(def, input_tensornds, &output_tensornds); | |||||
// restore to input comp_node | |||||
HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0]) | |||||
.proxy_to_comp_node(inputs[0]->comp_node()); | |||||
return {Tensor::make(std::move(host_tensornd))}; | |||||
} | } | ||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | ||||
@@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; | return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; | ||||
} | } | ||||
DeviceTensorND value; | DeviceTensorND value; | ||||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||||
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); | value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); | ||||
auto* ptr = value.ptr<dt_int32>(); | auto* ptr = value.ptr<dt_int32>(); | ||||
for (size_t i = 0; i < desc.layout.ndim; ++i) { | for (size_t i = 0; i < desc.layout.ndim; ++i) { | ||||
@@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||||
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) | OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) | ||||
.make_from_op_node(make_from_op_node) | .make_from_op_node(make_from_op_node) | ||||
.decide_dispatch_mode(decide_dispatch_mode) | |||||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | .infer_output_attrs_fallible(infer_output_attrs_fallible) | ||||
.apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
.apply_on_device_tensornd(apply_on_device_tensornd) | |||||
.apply_on_physical_tensor(apply_on_physical_tensor) | .apply_on_physical_tensor(apply_on_physical_tensor) | ||||
.fallback(); | .fallback(); | ||||
} // get_var_shape | |||||
namespace param_pack { | |||||
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | ||||
TensorShapeArray ret; | TensorShapeArray ret; | ||||
for (auto&& i:shapes) { | for (auto&& i:shapes) { | ||||
@@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||||
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | ||||
.apply_on_var_node(param_pack_concat_apply_on_var_node) | .apply_on_var_node(param_pack_concat_apply_on_var_node) | ||||
.fallback(); | .fallback(); | ||||
} // namespace | |||||
} // param_pack | |||||
} // namespace mgb::imperative | } // namespace mgb::imperative |
@@ -20,6 +20,11 @@ namespace imperative { | |||||
class OpDef; | class OpDef; | ||||
struct OpTrait; | struct OpTrait; | ||||
enum DispatchMode { | |||||
DEFAULT_CPU = 0, | |||||
KERNEL = 1 | |||||
}; | |||||
struct BackwardGraphResult { | struct BackwardGraphResult { | ||||
std::shared_ptr<OpDef> backward; | std::shared_ptr<OpDef> backward; | ||||
std::vector<bool> save_for_backward; | std::vector<bool> save_for_backward; | ||||
@@ -36,10 +41,31 @@ public: | |||||
static std::shared_ptr<OpDef> make_from_op_node( | static std::shared_ptr<OpDef> make_from_op_node( | ||||
cg::OperatorNodeBase* node); | cg::OperatorNodeBase* node); | ||||
/*! | |||||
* \brief Decide which dispatch method to be used according to the inputs' | |||||
* host value and size. | |||||
* | |||||
* \param def Specific :c:expr:`OpDef` to be executed. | |||||
* \param inputs Input tensor descriptions. | |||||
* \return Which DispatchMode to be used, such as `CUDA` or `DEFAULT_CPU`. | |||||
*/ | |||||
static DispatchMode decide_dispatch_mode( | |||||
const OpDef& def, | |||||
const SmallVector<LogicalTensorDesc>& inputs); | |||||
static SmallVector<TensorPtr> apply_on_physical_tensor( | static SmallVector<TensorPtr> apply_on_physical_tensor( | ||||
const OpDef& def, | const OpDef& def, | ||||
SmallVector<TensorPtr> inputs); | SmallVector<TensorPtr> inputs); | ||||
/*! | |||||
* \brief Call the corresponding dnn op to calculate results. Output | |||||
* tensors' device memory should be allocated outside. | |||||
*/ | |||||
static void apply_on_device_tensornd( | |||||
const OpDef& def, | |||||
const SmallVector<DeviceTensorND>& inputs, | |||||
SmallVector<DeviceTensorND>* outputs); | |||||
static cg::VarNodeArray apply_on_var_node( | static cg::VarNodeArray apply_on_var_node( | ||||
const OpDef& def, | const OpDef& def, | ||||
const VarNodeArray& inputs); | const VarNodeArray& inputs); | ||||