GitOrigin-RevId: 2663504470
tags/v1.3.0
@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() { | |||
return inst_; | |||
} | |||
void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
auto info = alloc(); | |||
info->desc.layout = value.layout(); | |||
info->desc.comp_node = value.comp_node(); | |||
@@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
return info; | |||
} | |||
void* ChannelImpl::put(const DeviceTensorND& data) { | |||
Handle ChannelImpl::put(const DeviceTensorND& data) { | |||
auto info = alloc(); | |||
info->desc.layout = data.layout(); | |||
info->desc.comp_node = data.comp_node(); | |||
@@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) { | |||
return info; | |||
} | |||
void ChannelImpl::del(void* handle) { | |||
void ChannelImpl::del(Handle handle) { | |||
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | |||
m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||
} | |||
void ChannelImpl::swap_in(void* handle) { | |||
void ChannelImpl::swap_in(Handle handle) { | |||
if (m_enable_evict & SWAP) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
@@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) { | |||
} | |||
} | |||
void ChannelImpl::swap_out(void* handle) { | |||
void ChannelImpl::swap_out(Handle handle) { | |||
if (m_enable_evict & SWAP) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
@@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) { | |||
} | |||
} | |||
void ChannelImpl::drop(void* handle) { | |||
void ChannelImpl::drop(Handle handle) { | |||
if (m_enable_evict & DROP) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
@@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) { | |||
} | |||
} | |||
SmallVector<void*> ChannelImpl::apply_op( | |||
void ChannelImpl::dispatch_default_cpu( | |||
std::shared_ptr<OpDef> op, | |||
const SmallVector<void*>& inputs) { | |||
for (auto i : inputs) { | |||
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||
"invalid handle: %p", i); | |||
} | |||
SmallVector<TensorInfo*> input_infos; | |||
input_infos.reserve(inputs.size()); | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
input_descs.reserve(inputs.size()); | |||
const SmallVector<TensorInfo*>& input_infos, | |||
const SmallVector<LogicalTensorDesc>& input_descs, | |||
SmallVector<Handle>* outputs) { | |||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
SmallVector<DeviceTensorND> input_tensornds; | |||
input_tensornds.reserve(input_descs.size()); | |||
CompNode output_cn; | |||
{ | |||
MGB_LOCK_GUARD(m_mutex); | |||
for (auto i : inputs) { | |||
auto info = reinterpret_cast<TensorInfo*>(i); | |||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||
input_infos.push_back(info); | |||
input_descs.push_back(info->desc); | |||
for (auto&& info : input_infos) { | |||
mgb_assert(info->ptr, "invalid tensor ptr!"); | |||
if (!output_cn.valid()) { | |||
output_cn = info->ptr->comp_node(); | |||
} else { | |||
mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node"); | |||
} | |||
mgb_assert(info->ptr->try_get_value(), "no valid host value"); | |||
input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu()); | |||
} | |||
} | |||
outputs->reserve(output_descs.size()); | |||
SmallVector<DeviceTensorND> output_tensornds; | |||
output_tensornds.reserve(output_descs.size()); | |||
for (auto&& desc : output_descs) { | |||
// TODO: may conflict with condtake, which need alloc inside | |||
mgb_assert(!desc.layout.is_empty()); | |||
// use HostTensorND alloc_host for cuda pinned memory | |||
output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | |||
} | |||
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | |||
SmallVector<TensorInfo*> output_infos; | |||
output_infos.reserve(output_descs.size()); | |||
for (auto&& tensornd : output_tensornds) { | |||
// tensornd -> host_tensornd | |||
HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd) | |||
.proxy_to_comp_node(output_cn); | |||
// tensornd -> desc | |||
LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd}; | |||
// tensornd -> tensor | |||
auto info = alloc(); | |||
info->desc = desc; | |||
m_valid_handle.insert(info); | |||
output_infos.push_back(info); | |||
info->ptr = Tensor::make(host_tensornd, true); // host_only=true | |||
info->value_fetched = true; | |||
outputs->push_back(info); | |||
} | |||
if (m_enable_evict & DROP) { | |||
for (auto out : output_infos) { | |||
out->path.op = op; | |||
for (auto out_ : output_infos) { | |||
out->path.outputs.push_back(m_st.at(out_)); | |||
} | |||
for (auto inp : input_infos) { | |||
out->path.inputs.push_back(m_st.at(inp)); | |||
inp->path.dep_outputs.push_back(m_st.at(out)); | |||
} | |||
} | |||
} | |||
} | |||
void ChannelImpl::dispatch_kernel( | |||
std::shared_ptr<OpDef> op, | |||
const SmallVector<TensorInfo*>& input_infos, | |||
const SmallVector<LogicalTensorDesc>& input_descs, | |||
SmallVector<Handle>* outputs) { | |||
auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
ApplyOp cmd{std::move(op)}; | |||
cmd.inputs = std::move(input_infos); | |||
cmd.outputs.reserve(output_descs.size()); | |||
SmallVector<void*> outputs; | |||
// FIXME: remove this check when op check is correct | |||
bool validated_bkp = true; | |||
for (size_t i = 0;i < output_descs.size();i ++) { | |||
auto&& desc = output_descs[i]; | |||
if (desc.layout.ndim == 0) { | |||
validated_bkp = false; | |||
} | |||
outputs->reserve(output_descs.size()); | |||
for (auto&& desc : output_descs) { | |||
auto info = alloc(); | |||
info->desc = desc; | |||
m_valid_handle.insert(info); | |||
cmd.outputs.push_back(info); | |||
outputs.push_back(info); | |||
outputs->push_back(info); | |||
} | |||
if (m_enable_evict & DROP) { | |||
for (auto out : cmd.outputs) { | |||
@@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
} | |||
} | |||
m_buffer.enqueue(std::move(cmd)); | |||
if (!(validated && validated_bkp) && m_async_level == 1) { | |||
if (!validated && m_async_level == 1) { | |||
sync(); | |||
} else if (m_async_level == 0) { | |||
sync(); | |||
// check device error | |||
for (auto&& oup : outputs) { | |||
for (auto&& oup : *outputs) { | |||
auto info = reinterpret_cast<TensorInfo*>(oup); | |||
info->ptr->comp_node().sync(); | |||
} | |||
} | |||
} | |||
SmallVector<Handle> ChannelImpl::apply_op( | |||
std::shared_ptr<OpDef> op, | |||
const SmallVector<Handle>& inputs) { | |||
for (auto i : inputs) { | |||
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||
"invalid handle: %p", i); | |||
} | |||
SmallVector<TensorInfo*> input_infos; | |||
input_infos.reserve(inputs.size()); | |||
SmallVector<LogicalTensorDesc> input_descs; | |||
input_descs.reserve(inputs.size()); | |||
{ | |||
MGB_LOCK_GUARD(m_mutex); | |||
for (auto i : inputs) { | |||
auto info = reinterpret_cast<TensorInfo*>(i); | |||
mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||
input_infos.push_back(info); | |||
input_descs.push_back(info->desc); | |||
} | |||
} | |||
SmallVector<Handle> outputs; | |||
switch (OpDef::decide_dispatch_mode(*op, input_descs)) { | |||
case DEFAULT_CPU: { | |||
dispatch_default_cpu(op, input_infos, input_descs, &outputs); | |||
break; | |||
} | |||
case KERNEL: { | |||
dispatch_kernel(op, input_infos, input_descs, &outputs); | |||
break; | |||
} | |||
} | |||
mgb_assert(outputs.size() > 0, "Invalid dispatch mode!"); | |||
return outputs; | |||
} | |||
HostTensorND ChannelImpl::get_value(void* handle) { | |||
HostTensorND ChannelImpl::get_value(Handle handle) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
auto info = reinterpret_cast<TensorInfo*>(handle); | |||
@@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { | |||
return info->ptr->get_value(); | |||
} | |||
TensorShape ChannelImpl::get_shape(void* handle) { | |||
TensorShape ChannelImpl::get_shape(Handle handle) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
auto info = reinterpret_cast<TensorInfo*>(handle); | |||
@@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) { | |||
return ret; | |||
} | |||
DType ChannelImpl::get_dtype(void* handle) { | |||
DType ChannelImpl::get_dtype(Handle handle) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
auto info = reinterpret_cast<TensorInfo*>(handle); | |||
@@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) { | |||
return ret; | |||
} | |||
CompNode ChannelImpl::get_device(void* handle) { | |||
CompNode ChannelImpl::get_device(Handle handle) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
auto info = reinterpret_cast<TensorInfo*>(handle); | |||
@@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) { | |||
return ret; | |||
} | |||
DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { | |||
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
"invalid handle: %p", handle); | |||
auto info = reinterpret_cast<TensorInfo*>(handle); | |||
@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() { | |||
} | |||
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { | |||
if (notice) { | |||
MGB_LOCK_GUARD(m_mutex); | |||
dest->value_fetched = ptr->value_fetched(); | |||
// update tensor desc for static infer | |||
// if (dest->desc.layout.ndim) { | |||
// mgb_assert(dest->desc.layout.eq_shape(ptr->layout())); | |||
// } | |||
dest->desc.layout = ptr->layout(); | |||
dest->desc.comp_node = ptr->comp_node(); | |||
dest->ptr = std::move(ptr); | |||
if (m_waitee == dest) { | |||
m_cv.notify_all(); | |||
} | |||
} else { | |||
dest->value_fetched = ptr->value_fetched(); | |||
// update tensor desc for static infer | |||
dest->desc.layout = ptr->layout(); | |||
dest->desc.comp_node = ptr->comp_node(); | |||
dest->ptr = std::move(ptr); | |||
auto lock = notice ? std::unique_lock<std::mutex>(m_mutex) | |||
: std::unique_lock<std::mutex>(); | |||
dest->value_fetched = ptr->value_fetched(); | |||
// update tensor desc for static infer | |||
dest->desc.layout = ptr->layout(); | |||
dest->desc.comp_node = ptr->comp_node(); | |||
dest->ptr = std::move(ptr); | |||
if (notice && m_waitee == dest) { | |||
m_cv.notify_all(); | |||
} | |||
} | |||
@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) { | |||
dest->evict_type = SWAP; | |||
dest->value_fetched = false; | |||
// TODO: swap in parallel | |||
dest->h_value.copy_from(dest->ptr->dev_tensor()).sync(); | |||
dest->h_value = dest->ptr->get_value(); | |||
dest->ptr.reset(); | |||
} | |||
@@ -198,6 +198,17 @@ private: | |||
void do_drop(TensorInfo* dest); | |||
void regenerate(TensorInfo* dest, bool must_drop); | |||
void dispatch_default_cpu( | |||
std::shared_ptr<OpDef> op, | |||
const SmallVector<TensorInfo*>& input_infos, | |||
const SmallVector<LogicalTensorDesc>& input_descs, | |||
SmallVector<Handle>* outputs); | |||
void dispatch_kernel( | |||
std::shared_ptr<OpDef> op, | |||
const SmallVector<TensorInfo*>& input_infos, | |||
const SmallVector<LogicalTensorDesc>& input_descs, | |||
SmallVector<Handle>* outputs); | |||
std::mutex m_mutex; | |||
std::condition_variable m_cv; | |||
MemPool<TensorInfo> m_pool; | |||
@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node( | |||
return trait->make_from_op_node(node); | |||
} | |||
DispatchMode OpDef::decide_dispatch_mode( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
return def.trait()->decide_dispatch_mode(def, inputs); | |||
} | |||
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
const OpDef& def, | |||
SmallVector<TensorPtr> inputs) { | |||
return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | |||
} | |||
void OpDef::apply_on_device_tensornd( | |||
const OpDef& def, | |||
const SmallVector<DeviceTensorND>& inputs, | |||
SmallVector<DeviceTensorND>* outputs) { | |||
def.trait()->apply_on_device_tensornd(def, inputs, outputs); | |||
return; | |||
} | |||
VarNodeArray OpDef::apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
@@ -9,12 +9,16 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include <exception> | |||
#include <sstream> | |||
#include <stdexcept> | |||
#include "megbrain/imperative/op_def.h" | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
#include "megbrain/tensor.h" | |||
#include "./op_trait.h" | |||
#include "megbrain/imperative/proxy_graph_detail.h" | |||
namespace mgb { | |||
namespace imperative { | |||
@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||
} | |||
} | |||
DispatchMode fallback_decide_dispatch_mode( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
return KERNEL; | |||
} | |||
OpTraitRegistry& OpTraitRegistry::fallback() { | |||
if (trait->apply_on_var_node) { | |||
// fallback to proxy graph impl | |||
@@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() { | |||
proxy_graph_detail::make_backward_graph; | |||
} | |||
} | |||
if (!trait->decide_dispatch_mode) { | |||
trait->decide_dispatch_mode = fallback_decide_dispatch_mode; | |||
} | |||
return *this; | |||
} | |||
@@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||
using OpDefMaker = detail::OpMeth< | |||
decltype(OpDef::make_from_op_node)>; | |||
using DecideDispatchMode = detail::OpMeth< | |||
decltype(OpDef::decide_dispatch_mode)>; | |||
using ApplyOnPhysicalTensor = detail::OpMeth< | |||
decltype(OpDef::apply_on_physical_tensor)>; | |||
using ApplyOnDeviceTensorND = detail::OpMeth< | |||
decltype(OpDef::apply_on_device_tensornd)>; | |||
using ApplyOnVarNode = detail::OpMeth< | |||
decltype(OpDef::apply_on_var_node)>; | |||
using InferOutputAttrsFallible = detail::OpMeth< | |||
@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
struct OpTrait { | |||
const char* name; | |||
OpDefMaker make_from_op_node; | |||
DecideDispatchMode decide_dispatch_mode; | |||
ApplyOnPhysicalTensor apply_on_physical_tensor; | |||
ApplyOnDeviceTensorND apply_on_device_tensornd; | |||
ApplyOnVarNode apply_on_var_node; | |||
InferOutputAttrsFallible infer_output_attrs_fallible; | |||
GradMaker make_backward_graph; | |||
@@ -88,7 +94,9 @@ struct OpTrait { | |||
#define FOR_EACH_OP_METH(cb) \ | |||
cb(make_from_op_node) \ | |||
cb(decide_dispatch_mode) \ | |||
cb(apply_on_physical_tensor) \ | |||
cb(apply_on_device_tensornd) \ | |||
cb(apply_on_var_node) \ | |||
cb(infer_output_attrs_fallible) \ | |||
cb(make_backward_graph) \ | |||
@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
DispatchMode decide_dispatch_mode( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
bool host_computable = true; | |||
constexpr int size_threshhold = TensorShape::MAX_NDIM; | |||
for (auto&& inp : inputs) { | |||
if (inp.value.empty() || inp.value.layout().ndim == 0 | |||
|| inp.value.layout().total_nr_elems() > size_threshhold) { | |||
host_computable = false; | |||
break; | |||
} | |||
} | |||
return host_computable ? DEFAULT_CPU : KERNEL; | |||
} | |||
void apply_on_device_tensornd( | |||
const OpDef& def, | |||
const SmallVector<DeviceTensorND>& inputs, | |||
SmallVector<DeviceTensorND>* outputs) { | |||
auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
mgb_assert(inputs.size() == trait.arity, | |||
"%s expects %u inputs; got %zu actually", trait.name, | |||
trait.arity, inputs.size()); | |||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node()); | |||
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr); | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
DeviceTensorND out; | |||
SmallVector<DeviceTensorND> dt_inputs(inputs.size()); | |||
SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | |||
for (unsigned i = 0; i < inputs.size(); ++i){ | |||
dt_inputs[i] = inputs[i]->dev_tensor(); | |||
inp_tensornds[i] = inputs[i]->dev_tensor(); | |||
} | |||
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node()); | |||
opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr); | |||
return {Tensor::make(out)}; | |||
SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}}; | |||
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | |||
return {Tensor::make(oup_tensornds[0])}; | |||
} | |||
MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | |||
@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_ | |||
OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
.make_from_op_node(make_from_op_node) | |||
.decide_dispatch_mode(decide_dispatch_mode) | |||
.apply_on_var_node(apply_on_var_node) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_device_tensornd(apply_on_device_tensornd) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
@@ -15,8 +15,8 @@ | |||
#include "../op_trait.h" | |||
namespace mgb::imperative { | |||
namespace { | |||
namespace get_var_shape { | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
@@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
DispatchMode decide_dispatch_mode( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
const SmallVector<LogicalTensorDesc>& inputs) { | |||
bool host_computable = true; | |||
for (auto&& inp : inputs) { | |||
// FIXME(czh): remove value chech after proxy graph's | |||
// apply_on_device_tensornd is supported and output Tensor | |||
// is made before add_task. | |||
// then if layout is valid, ptr->layout must be ready | |||
if (inp.value.empty() || inp.value.layout().ndim == 0) { | |||
host_computable = false; | |||
break; | |||
} | |||
} | |||
return host_computable ? DEFAULT_CPU : KERNEL; | |||
} | |||
void apply_on_device_tensornd( | |||
const OpDef& def, | |||
const SmallVector<DeviceTensorND>& inputs, | |||
SmallVector<DeviceTensorND>* outputs) { | |||
auto&& op_def = def.cast_final_safe<GetVarShape>(); | |||
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | |||
auto&& inp = inputs[0]; | |||
auto&& shp = inp->layout(); | |||
auto&& shp = inp.layout(); | |||
mgb_assert(shp.ndim != 0, "input shape invalid"); | |||
mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(), | |||
"GetVarShape's apply_on_device_tensornd should receive default_cpu outputs."); | |||
HostTensorND hv; | |||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||
hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32()); | |||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||
hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
for (size_t i = 0; i < shp.ndim; ++i) { | |||
ptr[i] = shp.shape[i]; | |||
@@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
axis += shp.ndim; | |||
} | |||
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); | |||
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); | |||
hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); | |||
auto* ptr = hv.ptr<dt_int32>(); | |||
ptr[0] = shp.shape[axis]; | |||
} | |||
return {Tensor::make(std::move(hv))}; | |||
(*outputs)[0] = DeviceTensorND::make_proxy(hv); | |||
} | |||
SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, | |||
const SmallVector<TensorPtr>& inputs) { | |||
SmallVector<DeviceTensorND> input_tensornds; | |||
input_tensornds.reserve(inputs.size()); | |||
for (auto&& inp : inputs) { | |||
input_tensornds.push_back(inp->dev_tensor()); | |||
} | |||
SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}}; | |||
apply_on_device_tensornd(def, input_tensornds, &output_tensornds); | |||
// restore to input comp_node | |||
HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0]) | |||
.proxy_to_comp_node(inputs[0]->comp_node()); | |||
return {Tensor::make(std::move(host_tensornd))}; | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
@@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; | |||
} | |||
DeviceTensorND value; | |||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); | |||
auto* ptr = value.ptr<dt_int32>(); | |||
for (size_t i = 0; i < desc.layout.ndim; ++i) { | |||
@@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) | |||
.make_from_op_node(make_from_op_node) | |||
.decide_dispatch_mode(decide_dispatch_mode) | |||
.infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
.apply_on_var_node(apply_on_var_node) | |||
.apply_on_device_tensornd(apply_on_device_tensornd) | |||
.apply_on_physical_tensor(apply_on_physical_tensor) | |||
.fallback(); | |||
} // get_var_shape | |||
namespace param_pack { | |||
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | |||
TensorShapeArray ret; | |||
for (auto&& i:shapes) { | |||
@@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | |||
.apply_on_var_node(param_pack_concat_apply_on_var_node) | |||
.fallback(); | |||
} // namespace | |||
} // param_pack | |||
} // namespace mgb::imperative |
@@ -20,6 +20,11 @@ namespace imperative { | |||
class OpDef; | |||
struct OpTrait; | |||
enum DispatchMode { | |||
DEFAULT_CPU = 0, | |||
KERNEL = 1 | |||
}; | |||
struct BackwardGraphResult { | |||
std::shared_ptr<OpDef> backward; | |||
std::vector<bool> save_for_backward; | |||
@@ -36,10 +41,31 @@ public: | |||
static std::shared_ptr<OpDef> make_from_op_node( | |||
cg::OperatorNodeBase* node); | |||
/*! | |||
* \brief Decide which dispatch method to be used according to the inputs' | |||
* host value and size. | |||
* | |||
* \param def Specific :c:expr:`OpDef` to be executed. | |||
* \param inputs Input tensor descriptions. | |||
* \return Which DispatchMode to be used, such as `CUDA` or `DEFAULT_CPU`. | |||
*/ | |||
static DispatchMode decide_dispatch_mode( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs); | |||
static SmallVector<TensorPtr> apply_on_physical_tensor( | |||
const OpDef& def, | |||
SmallVector<TensorPtr> inputs); | |||
/*! | |||
* \brief Call the corresponding dnn op to calculate results. Output | |||
* tensors' device memory should be allocated outside. | |||
*/ | |||
static void apply_on_device_tensornd( | |||
const OpDef& def, | |||
const SmallVector<DeviceTensorND>& inputs, | |||
SmallVector<DeviceTensorND>* outputs); | |||
static cg::VarNodeArray apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs); | |||