GitOrigin-RevId: bff62b33a0
@@ -12,13 +12,3 @@ from contextlib import contextmanager
 from ._imperative_rt.core2 import get_option, set_option
 from .tensor.megbrain_graph import Graph
-@contextmanager
-def option(key, value):
-    value = int(value)
-    old = get_option(key)
-    set_option(key, value)
-    yield
-    assert get_option(key) == value
-    set_option(key, old)
@@ -76,10 +76,11 @@ def test_drop_basic():
 def test_finalize():
     prog = """
 import megengine
-with megengine.core.option("enable_host_compute", 0):
-    x = megengine.tensor(0)
-    y = x + 1
-    y.numpy()
+megengine.core.set_option("enable_host_compute", 0)
+x = megengine.tensor(0)
+y = x + 1
+y.numpy()
+megengine.core.set_option("enable_host_compute", 1)
 """
     subprocess.check_call([sys.executable, "-c", prog])
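Note: with the option() context manager removed above, temporarily overriding an
interpreter option now means saving and restoring it explicitly. A minimal sketch of
that pattern (not part of this patch; it assumes get_option is exposed alongside
set_option in megengine.core, as the import in the first hunk suggests):

    import megengine

    # Save the current value, override it, and restore it even if the body raises.
    old = megengine.core.get_option("enable_host_compute")
    megengine.core.set_option("enable_host_compute", 0)
    try:
        x = megengine.tensor(0)
        y = x + 1
        y.numpy()  # force computation with host compute disabled
    finally:
        megengine.core.set_option("enable_host_compute", old)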
@@ -15,7 +15,6 @@ import pytest
 from megengine import Parameter
 from megengine import distributed as dist
 from megengine import tensor
-from megengine.core import option
 from megengine.jit import trace
 from megengine.module import Module
 from megengine.utils.profiler import Profiler, scope
@@ -155,7 +155,6 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
         info->h_value = value;
         info->desc.value = value.proxy_to_default_cpu();
     }
-    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
     m_worker.add_task(
             {Profiler::next_id(), Put{info, value, no_cache},
              get_channel_state().stack_manager.dump()});
@@ -180,7 +179,6 @@ TensorInfo* ChannelImpl::put_impl(
     auto info = alloc();
     MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
     init(info, {data.layout(), data.comp_node()});
-    info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
     info->ptr = Tensor::make(data, hvalue);
     MGB_RECORD_EVENT(
             TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
@@ -536,9 +534,6 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
     MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
     info->status = TensorInfo::Allocated;
     info->desc = std::move(desc);
-    info->mem_desc.layout = info->desc.layout;
-    info->mem_desc.cn = info->desc.comp_node;
-    info->mem_desc.offset = 0;
 }
 void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
@@ -667,18 +662,14 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
     bool profiling_device =
             Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
     uint64_t apply_id = cmd.id;
-    struct TensorWithDesc {
-        TensorPtr tensor;
-        MemoryDesc desc;
-    };
-    SmallVector<TensorWithDesc> inputs;
+    SmallVector<TensorPtr> inputs;
     inputs.reserve(cmd.inputs.size());
     // refcnt == 1, owners: [TensorInfo::ptr]
     for (auto i : cmd.inputs) {
         mgb_assert(i->ptr, "Invalid input tensor ptr!");
         // refcnt ++, owners: [i->ptr, tensor_inputs]
         // tensor_inputs.push_back(i->ptr);
-        inputs.push_back({i->ptr, i->mem_desc});
+        inputs.push_back(i->ptr);
     }
     if (state.options.enable_dtr_auto_drop &&
         state.options.dtr_eviction_threshold > 0) {
@@ -686,56 +677,28 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
     }
     auto apply_on_physical_tensor =
             [&](auto&& self, const OpDef& def,
-                SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
+                SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> {
         auto apply_functor = [&](std::shared_ptr<OpDef> op,
-                                 SmallVector<TensorWithDesc> inputs,
-                                 size_t nr_outputs) -> SmallVector<TensorWithDesc> {
+                                 SmallVector<TensorPtr> inputs,
+                                 size_t nr_outputs) -> SmallVector<TensorPtr> {
             auto opname = op->trait()->make_name(*op);
             imperative_log_profile_begin(opname.c_str());
             auto outputs = self(self, *op, inputs);
             imperative_log_profile_end(opname.c_str());
             return outputs;
         };
-        auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
-            return {value, MemoryDesc{
-                                   value->layout(), 0, value->comp_node(),
-                                   StorageIdentifier::make()}};
-        };
+        auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
         if (def.trait()->make_forward_graph) {
             // apply recursivily
             SmallVector<LogicalTensorDesc> input_descs;
             for (auto&& input : inputs) {
-                input_descs.push_back(
-                        {{{}, input.tensor->dtype()}, input.tensor->comp_node()});
+                input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
             }
             auto forward_graph = OpDef::make_forward_graph(def, input_descs);
             auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
             return outputs;
         }
-        SmallVector<TensorPtr> input_tensors;
-        SmallVector<MemoryDesc> input_descs;
-        for (auto&& input : inputs) {
-            input_tensors.push_back(input.tensor);
-            input_descs.push_back(input.desc);
-        }
-        auto [output_descs, output_tensors, workspaces] =
-                init_output_and_workspace(def, input_tensors, input_descs);
-        if (!output_descs.empty()) {
-            OpDef::execute(def, input_tensors, output_tensors, workspaces);
-        } else {
-            output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
-            for (auto&& output_tensor : output_tensors) {
-                output_descs.push_back(MemoryDesc{
-                        output_tensor->layout(), 0, output_tensor->comp_node(),
-                        StorageIdentifier::make()});
-            }
-        }
-        SmallVector<TensorWithDesc> outputs;
-        for (auto&& [output_tensor, output_desc] :
-             ranges::zip_view(output_tensors, output_descs)) {
-            outputs.push_back({output_tensor, output_desc});
-        }
-        return outputs;
+        return OpDef::apply_on_physical_tensor(def, inputs);
     };
     MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
     // Begin profiling operator
@@ -787,8 +750,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
             MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
         } else {
             MGB_RECORD_EVENT(OpOutputEvent, output->id);
-            produce_tensor(output, outputs[i].tensor);
-            output->mem_desc = outputs[i].desc;
+            produce_tensor(output, outputs[i]);
             MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
             sample_on_device(output->desc.comp_node, false);
         }
@@ -800,7 +762,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
         estimate_compute_time += i->memory;
     }
     for (auto i : outputs) {
-        estimate_compute_time += i.tensor->blob()->size();
+        estimate_compute_time += i->blob()->size();
     }
     m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
     for (auto i : cmd.outputs) {
@@ -1012,52 +974,6 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
     set_log_level(pre_level);
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
-ChannelImpl::init_output_and_workspace(
-        const OpDef& def, SmallVector<TensorPtr> inputs,
-        SmallVector<MemoryDesc> inputs_mem_desc) {
-    auto [outputs_desc, workspaces_desc] =
-            OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
-    if (!outputs_desc.size()) {
-        // failed to infer memplan
-        return {{}, {}, {}};
-    }
-    // refine storage id to make it unique
-    for (auto&& desc : outputs_desc) {
-        if (desc.id->is_sys_alloc()) {
-            // TODO: there may be some outputs sharing the same storage id
-            desc.id->id = ++m_storage_id;
-        }
-    }
-    auto& state = get_worker_state();
-    auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
-        SmallVector<TensorPtr> tensors;
-        for (size_t i = 0; i < desc.size(); i++) {
-            if (desc[i].id->is_sys_alloc()) {
-                tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
-                if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
-                    alloc_tensor_with_evict(tensors.back()->blob().get());
-                }
-            } else if (desc[i].id->is_from_other()) {
-                for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
-                    if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
-                        tensors.push_back(
-                                inputs[j]->sub(desc[i].offset, desc[i].layout));
-                        break;
-                    }
-                }
-            } else if (desc[i].id->is_device_ptr()) {
-                tensors.push_back(desc[i].id->ptr);
-            } else {
-                mgb_assert(0, "not implemented");
-            }
-        }
-        return tensors;
-    };
-    return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
-}
 void ChannelImpl::process_one_task(Command& icmd) {
     using namespace ranges;
     using namespace ranges::views;
@@ -105,11 +105,6 @@ private:
     void flush_apply_stack();
     void do_apply_op(const ApplyOp& cmd, std::string reason);
-    std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
-    init_output_and_workspace(
-            const OpDef& def, SmallVector<TensorPtr> inputs,
-            SmallVector<MemoryDesc> inputs_mem_desc);
     void dispatch_default_cpu(
             std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
             const SmallVector<LogicalTensorDesc>& input_descs,
@@ -296,6 +291,8 @@ private:
                op_blacklist.end();
     }
+    // operators that cannot be re-computed, including :
+    // distributed operators, inplace operator, random generator operators
     std::vector<std::string> op_blacklist = {
             "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
             "GaussianRNG", "UniformRNG", "GammaRNG", "PermutationRNG",
@@ -59,7 +59,6 @@ struct TensorInfo {
     // Lock interpreter when visiting `ptr`.
     TensorPtr ptr;
     LogicalTensorDesc desc;
-    MemoryDesc mem_desc;
     double compute_time;
     size_t memory;
@@ -41,20 +41,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
         const OpDef& def, SmallVector<TensorPtr> inputs) {
     return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::
-        infer_output_mem_desc(
-                const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-                const SmallVector<MemoryDesc>& inputs_mems) {
-    return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
-}
-void OpDef::execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
-}
 void OpDef::apply_on_device_tensornd(
         const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
         SmallVector<DeviceTensorND>* outputs) {
@@ -43,13 +43,6 @@ void OpMethFallbackByProxyGraph::impl(
         ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) {
     func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
 }
-void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
-    func.Base::operator=(proxy_graph_detail::execute);
-}
-void OpMethFallbackByProxyGraph::impl(
-        InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
-    func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
-}
 void OpMethFallbackByProxyGraph::impl(
         InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) {
     func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
@@ -63,10 +56,6 @@ void OpMethFallbackFromSubgraph::impl(
     func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
 }
 void OpMethFallbackFromSubgraph::impl(
-        InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
-    func.Base::operator=(subgraph_detail::infer_output_mem_desc);
-}
-void OpMethFallbackFromSubgraph::impl(
         ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode) {
     func.Base::operator=(subgraph_detail::apply_on_var_node);
 }
@@ -64,12 +64,6 @@ OpMethType(DecideDispatchMode,
 OpMethType(ApplyOnPhysicalTensor,
            decltype(OpDef::apply_on_physical_tensor));
-OpMethType(InferOutputMemDesc,
-           decltype(OpDef::infer_output_mem_desc));
-OpMethType(Execute,
-           decltype(OpDef::execute));
 OpMethType(ApplyOnDeviceTensorND,
            decltype(OpDef::apply_on_device_tensornd));
@@ -123,8 +117,6 @@ struct OpMethFallback : OpMethImplBase {
 struct OpMethFallbackByProxyGraph : OpMethImplBase {
     using OpMethImplBase::impl;
     static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
-    static void impl(Execute& func, op_meth_tag::Execute);
-    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
     static void impl(
             InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
     static void impl(GradMaker& func, op_meth_tag::GradMaker);
@@ -133,7 +125,6 @@ struct OpMethFallbackByProxyGraph : OpMethImplBase {
 struct OpMethFallbackFromSubgraph : OpMethImplBase {
     using OpMethImplBase::impl;
     static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
-    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
     static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
     static void impl(
             InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
@@ -185,8 +176,6 @@ struct OpTrait {
     OpDefMaker make_from_op_node;
     DecideDispatchMode decide_dispatch_mode;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
-    InferOutputMemDesc infer_output_mem_desc;
-    Execute execute;
     ApplyOnDeviceTensorND apply_on_device_tensornd;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
@@ -207,8 +196,6 @@ struct OpTrait {
     cb(make_from_op_node)           \
     cb(decide_dispatch_mode)        \
     cb(apply_on_physical_tensor)    \
-    cb(infer_output_mem_desc)       \
-    cb(execute)                     \
     cb(apply_on_device_tensornd)    \
     cb(apply_on_var_node)           \
     cb(infer_output_attrs_fallible) \
@@ -81,50 +81,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto& input = inputs_tensors[0];
-    TensorShape target_shape;
-    cg::copy_tensor_value_to_shape(
-            target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
-    // TODO: memory forward
-    // if (input->shape().eq_shape(target_shape)) {
-    //     return {{{input->layout(), 0, input->comp_node(),
-    //               StorageIdentifier::make(&inputs_mems[0])}}, {}};
-    // }
-    return {{{{target_shape, input->dtype()},
-              0,
-              input->comp_node(),
-              StorageIdentifier::make(0)}},
-            {}};
-}
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    if (outputs[0]->layout().is_empty()) {
-        return;
-    }
-    if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
-        mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
-        // TODO: memory forward
-        // mgb_assert(inputs[0]->offset() == outputs[0]->offset());
-        // mgb_assert(inputs[0]->blob() == outputs[0]->blob());
-        outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
-    } else {
-        TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape());
-        outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(
-                SubTensorSpec::make_from_layout(input_layout)));
-    }
-}
 OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 } // namespace broadcast
@@ -187,41 +147,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& op_def = def.cast_final_safe<Reshape>();
-    size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
-    auto&& src = inputs[0];
-    auto&& tshp_nd = inputs[1];
-    auto slayout = src->layout();
-    TensorShape tshp;
-    cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
-    if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(tshp[op_def.axis] == -1);
-        tshp[op_def.axis] = 1;
-        tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
-    }
-    TensorLayout tlayout = slayout.reshape(tshp);
-    // memory forward
-    return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}},
-            {}};
-}
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    mgb_assert(inputs[0]->offset() == outputs[0]->offset());
-    mgb_assert(inputs[0]->blob() == outputs[0]->blob());
-}
 OP_TRAIT_REG(Reshape, Reshape)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 } // namespace reshape
@@ -78,25 +78,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             false};
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    return {{}, {}};
-}
-void execute(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs,
-        const SmallVector<TensorPtr>& workspace) {
-    mgb_assert(0);
-}
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 } // namespace
@@ -234,12 +234,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return op.infer_output_attrs(inputs);
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    return {{}, {}};
-}
 size_t hash(const OpDef& def) {
     auto&& op = static_cast<const CustomOpDef&>(def);
     const custom::Param& param = op.param();
@@ -279,7 +273,6 @@ OP_TRAIT_REG(CustomOpDef, CustomOpDef)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_device_tensornd(apply_on_device_tensornd)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
-        .infer_output_mem_desc(infer_output_mem_desc)
         .hash(hash)
         .is_same_st(is_same_st)
         .props(props)
@@ -110,35 +110,6 @@ void apply_on_device_tensornd(
     opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
 }
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    mgb_assert(outputs.size() == 1);
-    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        inp_tensornds[i] = inputs[i]->dev_tensor();
-    }
-    SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
-    apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
-}
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& op_def = def.cast_final_safe<Elemwise>();
-    TensorShapeArray inp_shapes(inputs_tensors.size());
-    for (size_t i = 0; i < inputs_tensors.size(); ++i) {
-        inp_shapes[i] = inputs_tensors[i]->layout();
-    }
-    TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
-    SmallVector<MemoryDesc> outputs = {
-            {{shape, inputs_tensors[0]->dtype()},
-             0,
-             inputs_tensors[0]->comp_node(),
-             StorageIdentifier::make(1)}};
-    return {outputs, {}};
-}
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
@@ -251,7 +222,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
 SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     mgb_assert(
-            inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
+            inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
             "This inplace modification may change the elements of other tensors. "
            "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs "
            "correctly.");
@@ -265,23 +236,6 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
     return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
 }
-void execute_inplace(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    apply_inplace_add_on_physical_tensor(def, inputs);
-}
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
-infer_inplace_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto dest = inputs_tensors[0];
-    SmallVector<MemoryDesc> outputs = {
-            {dest->layout(), 0, dest->comp_node(),
-             StorageIdentifier::make(&inputs_mems[0])}};
-    return {outputs, {}};
-}
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
     mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
@@ -319,16 +273,12 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_device_tensornd(apply_on_device_tensornd)
         .apply_on_physical_tensor(apply_on_physical_tensor)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
         .apply_on_var_node(apply_inplace_add_on_var_node)
         .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
         .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
-        .infer_output_mem_desc(infer_inplace_output_mem_desc)
-        .execute(execute_inplace)
         .fallback();
 } // anonymous namespace
@@ -75,16 +75,11 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
     dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
     return dests;
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    return {{}, {}};
-}
 OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_physical_tensor(apply_on_physical_tensor)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
-        .infer_output_mem_desc(infer_output_mem_desc)
         .fallback();
 } // namespace check_non_finite
@@ -36,6 +36,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     return Reduce::make(node->param());
 }
+// TODO: using this for apply_on_physical_tensor
 bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
     if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
@@ -49,31 +50,9 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
     return false;
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    if (memory_forward_success(def, inputs_tensors)) {
-        auto& src_desc = inputs_mems[0];
-        return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}},
-                {}};
-    }
-    return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems);
-}
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    if (memory_forward_success(def, inputs)) {
-        return;
-    }
-    return proxy_graph_detail::execute(def, inputs, outputs, workspace);
-}
 OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
         .make_from_op_node(make_from_op_node)
         .apply_on_var_node(apply_on_var_node)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 } // namespace reduce
 } // namespace
@@ -518,20 +518,6 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
 }
 template <typename Op>
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
-    SmallVector<MemoryDesc> outputs;
-    for (size_t i = 0; i < dests.size(); ++i) {
-        outputs.push_back(
-                {dests[i].layout, 0, dests[i].comp_node,
-                 StorageIdentifier::make(i + 1)});
-    }
-    return {outputs, {}};
-}
-template <typename Op>
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     SmallVector<TensorPtr> outputs;
@@ -543,13 +529,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return outputs;
 }
-template <typename Op>
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    exec<Op>(def, inputs, outputs, {});
-}
 template <typename Op, typename Output>
 Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
@@ -641,8 +620,6 @@ CompNode get_rng_handle_compnode(Handle handle) {
             .apply_on_var_node(apply_on_var_node<NAME, Output>)             \
             .apply_on_physical_tensor(apply_on_physical_tensor<NAME>)       \
             .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
-            .infer_output_mem_desc(infer_output_mem_desc<NAME>)             \
-            .execute(execute<NAME>)                                         \
             .fallback();                                                    \
     }
@@ -141,39 +141,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
-    SmallVector<MemoryDesc> ret;
-    auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
-    if (blob) {
-        ret.push_back(
-                {tensor.layout(), 0, inputs[0]->comp_node(),
-                 StorageIdentifier::make(Tensor::make(
-                         std::forward<decltype(blob)>(blob), tensor.layout(),
-                         tensor))});
-    } else {
-        ret.push_back(
-                {tensor.layout(), 0, inputs[0]->comp_node(),
-                 StorageIdentifier::make(1)});
-    }
-    return {ret, {}};
-}
-void execute(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs,
-        const SmallVector<TensorPtr>& workspace) {
-    HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
-    SmallVector<MemoryDesc> ret;
-    auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
-    if (!blob || blob->storage() != outputs[0]->blob()->storage()) {
-        outputs[0]->dev_tensor().copy_from_fixlayout(tensor);
-        AsyncReleaser::inst()->add(tensor);
-    }
-}
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     auto* node = &node_->cast_final_safe<opr::GetVarShape>();
     return GetVarShape::make(node->param());
@@ -186,8 +153,6 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
         .apply_on_var_node(apply_on_var_node)
         .apply_on_device_tensornd(apply_on_device_tensornd)
         .apply_on_physical_tensor(apply_on_physical_tensor)
-        .infer_output_mem_desc(infer_output_mem_desc)
-        .execute(execute)
         .fallback();
 } // namespace get_var_shape
@@ -215,38 +180,6 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
     return opr;
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
-param_pack_split_infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& param = def.cast_final_safe<ParamPackSplit>();
-    mgb_assert(
-            inputs.size() == 1, "ParamPackSplit take 1 input, got %lu", inputs.size());
-    auto&& inp = inputs[0];
-    auto&& shp = inp->layout();
-    mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
-    mgb_assert(param.shapes.size() * 2 == param.offsets.size());
-    SmallVector<MemoryDesc> ret;
-    auto&& shapes = get_shapes(param.shapes);
-    size_t dtype_size = inputs[0]->layout().dtype.size();
-    for (size_t i = 0; i < shapes.size(); ++i) {
-        // memory forward
-        ret.push_back(
-                {{shapes[i], inputs[0]->dtype()},
-                 param.offsets[i * 2] * dtype_size,
-                 inp->comp_node(),
-                 StorageIdentifier::make(&inputs_mems[0])});
-    }
-    return {ret, {}};
-}
-void param_pack_split_execute(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs,
-        const SmallVector<TensorPtr>& workspace) {
-    // do nothing
-}
 SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     auto&& param = def.cast_final_safe<ParamPackSplit>();
@@ -268,8 +201,6 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
 OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
         .apply_on_var_node(param_pack_split_apply_on_var_node)
-        .infer_output_mem_desc(param_pack_split_infer_output_mem_desc)
-        .execute(param_pack_split_execute)
         .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
         .fallback();
@@ -286,75 +217,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
     return opr;
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
-param_pack_concat_infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    def.cast_final_safe<ParamPackConcat>();
-    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
-    auto comp_node = inputs.front()->comp_node();
-    auto dtype = inputs.front()->dtype();
-    size_t nr_inputs = inputs.size() - 1;
-    size_t nr_elems = 0;
-    for (size_t i = 0; i < nr_inputs; ++i) {
-        auto& input = inputs[i];
-        mgb_assert(
-                comp_node == input->comp_node(),
-                "inputs for param_pack_concat must in same comp_node");
-        mgb_assert(
-                dtype == input->dtype(),
-                "inputs for param_pack_concat must have same dtype");
-        nr_elems += input->layout().total_nr_elems();
-    }
-    auto dest_layout = TensorLayout({nr_elems}, dtype);
-    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
-    size_t ws_size;
-    {
-        TensorShapeArray src_shapes;
-        for (size_t i = 0; i < nr_inputs; ++i) {
-            src_shapes.push_back(inputs[i]->shape());
-        }
-        ws_size = caller.op->get_workspace_in_bytes(
-                src_shapes, inputs.back()->shape(), TensorShape{});
-    }
-    SmallVector<MemoryDesc> outputs = {
-            {dest_layout, 0, comp_node, StorageIdentifier::make(1)}};
-    MemoryDesc workspace = {
-            {{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)};
-    return {outputs, {workspace}};
-}
-void param_pack_concat_execute(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs,
-        const SmallVector<TensorPtr>& outputs,
-        const SmallVector<TensorPtr>& workspace) {
-    def.cast_final_safe<ParamPackConcat>();
-    mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
-    auto comp_node = inputs.front()->comp_node();
-    size_t nr_inputs = inputs.size() - 1;
-    auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
-    size_t srcs_size = sizeof(void*) * nr_inputs;
-    void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
-    std::shared_ptr<dt_byte> srcs_ptr = {
-            (dt_byte*)srcs_raw_ptr,
-            [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
-    TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
-    for (size_t i = 0; i < nr_inputs; ++i) {
-        srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
-    }
-    HostTensorStorage srcs_storage;
-    srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
-    megdnn::Workspace dnn_wk(
-            workspace[0]->blob()->storage().get(), workspace[0]->blob()->size());
-    caller.op->exec(
-            {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
-            outputs[0]->dev_tensor().as_megdnn(), dnn_wk);
-    AsyncReleaser::inst()->add(
-            HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
-}
 SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     def.cast_final_safe<ParamPackConcat>();
@@ -407,8 +269,6 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
 OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
         .apply_on_var_node(param_pack_concat_apply_on_var_node)
-        .infer_output_mem_desc(param_pack_concat_infer_output_mem_desc)
-        .execute(param_pack_concat_execute)
         .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
         .fallback();
 } // namespace param_pack
@@ -445,12 +445,6 @@ auto make_name(const OpDef& def) {
     return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    return {};
-}
 EncodedSubgraph make_backward_graph(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,
@@ -498,7 +492,6 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .make_backward_graph(make_backward_graph)
         .make_name(make_name)
-        .infer_output_mem_desc(infer_output_mem_desc)
         .props(props)
         .hash(hash)
         .is_same_st(is_same_st)
@@ -634,36 +634,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
     mgb_assert(0);
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
-        infer_output_mem_desc(
-                const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
-                const SmallVector<MemoryDesc>& inputs_mems) {
-    auto opr = get_proxy_opr(def, inputs_tensors);
-    CUR_OPR_GUARD(opr);
-    ::mgb::opr::intl::WorkspaceLimitHook::set_impl(
-            m_graph.get(), ProxyGraph::get_workspace_limit);
-    do_shape_infer(true);
-    SmallVector<MemoryDesc> outputs;
-    SmallVector<MemoryDesc> workspaces;
-    size_t cur_id = 0;
-    for (auto&& i : opr->output()) {
-        if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
-            workspaces.push_back(
-                    {{i->shape(), i->dtype(), i->format()},
-                     0,
-                     i->comp_node(),
-                     StorageIdentifier::make(++cur_id)});
-        } else {
-            outputs.push_back(
-                    {{i->shape(), i->dtype()},
-                     0,
-                     i->comp_node(),
-                     StorageIdentifier::make(++cur_id)});
-        }
-    }
-    return {outputs, workspaces};
-}
 struct ProxyGraph::GradGraph {
     cg::VarNodeArray inputs;
     cg::VarNodeArray outputs;
@@ -812,7 +782,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
     return result;
 }
 VarNodeArray ProxyGraph::make_input_place_holders(
         const SmallVector<LogicalTensorDesc>& inputs) {
     VarNodeArray vinputs(inputs.size());
@@ -47,10 +47,6 @@ public:
             const SmallVector<bool>& input_requires_grad,
             const SmallVector<bool>& output_has_grad);
-    std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-            const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
-            const SmallVector<MemoryDesc>& inputs_mems);
     /********************** Logical Tensor API **********************/
     size_t get_opr_output_size(
@@ -83,25 +83,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return outputs;
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    auto&& graph = ProxyGraph::get_default_graph();
-    return graph->infer_output_mem_desc(
-            def, to_raw_ptr_array(inputs_tensors), inputs_mems);
-}
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace) {
-    exec(def, inputs, outputs, workspace);
-    auto async_error = ProxyGraph::get_async_error();
-    if (async_error) {
-        throw *async_error;
-    }
-    return;
-}
 // std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const
 // OpDef& def,
 //         const SmallVector<LogicalTensorDesc>& inputs) {
@@ -162,12 +162,6 @@ EncodedSubgraph make_backward_graph(
             inputs, input_requires_grad, output_has_grad, forward_graph);
 }
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems) {
-    return {{}, {}};
-}
 } // namespace subgraph_detail
 } // namespace imperative
 } // namespace mgb
@@ -53,10 +53,6 @@ public:
     static SmallVector<TensorPtr> apply_on_physical_tensor(
             const OpDef& def, SmallVector<TensorPtr> inputs);
-    static void execute(
-            const OpDef& def, SmallVector<TensorPtr> inputs,
-            SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace);
     /*!
      * \brief Call the corresponding dnn op to calculate results. Output
      * tensors' device memory should be allocated outside.
@@ -71,11 +67,6 @@ public:
     static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
-    static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
-    infer_output_mem_desc(
-            const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-            const SmallVector<MemoryDesc>& inputs_mems);
     static EncodedSubgraph make_backward_graph(
             const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
             const SmallVector<bool>& input_requires_grad,
@@ -288,36 +288,6 @@ struct LogicalTensorDesc {
     CompNode comp_node;
     DeviceTensorND value;  // cpu:default
 };
-struct StorageIdentifier;
-struct MemoryDesc {
-    TensorLayout layout;
-    size_t offset;
-    CompNode cn;
-    std::shared_ptr<StorageIdentifier> id;
-};
-struct StorageIdentifier {
-    enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag;
-    union {
-        size_t id;
-        MemoryDesc* desc;
-    };
-    TensorPtr ptr;
-    StorageIdentifier() = default;
-    StorageIdentifier(size_t id) : tag(SYS_ALLOC), id(id) {}
-    StorageIdentifier(const MemoryDesc* desc) : tag(FROM_OTHER), desc(desc->id->desc) {}
-    StorageIdentifier(TensorPtr dev_ptr) : tag(DEVICE_PTR), ptr(dev_ptr) {}
-    template <typename... Args>
-    static std::shared_ptr<StorageIdentifier> make(Args&&... args) {
-        return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...);
-    }
-    bool is_sys_alloc() { return tag == SYS_ALLOC; }
-    bool is_from_other() { return tag == FROM_OTHER; }
-    bool is_device_ptr() { return tag == DEVICE_PTR; }
-    bool is_invalid() { return tag == INVALID; }
-};
 } // namespace imperative
 } // namespace mgb
@@ -20,17 +20,9 @@ namespace proxy_graph_detail {
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, SmallVector<TensorPtr> inputs);
-void execute(
-        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
-        SmallVector<TensorPtr> workspace);
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems);
 void exec(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         const SmallVector<TensorPtr>& outputs);
@@ -35,10 +35,6 @@ EncodedSubgraph make_backward_graph(
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad);
-std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
-        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
-        const SmallVector<MemoryDesc>& inputs_mems);
 } // namespace subgraph_detail
 } // namespace imperative
 } // namespace mgb