
refactor(imperative): remove infer_output_mem_desc
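
Drop the two-phase infer_output_mem_desc/execute dispatch path from OpDef,
OpTrait, the proxy-graph and subgraph fallbacks, and the interpreter;
operators now allocate and produce their outputs directly in
apply_on_physical_tensor.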

GitOrigin-RevId: bff62b33a0
tags/v1.9.0
Megvii Engine Team 3 years ago
commit e6706be23a
26 changed files with 20 additions and 607 deletions
  1. +0 -10   imperative/python/megengine/core/__init__.py
  2. +5 -4    imperative/python/test/unit/core/test_interpreter.py
  3. +0 -1    imperative/python/test/unit/utils/test_profiler.py
  4. +10 -94  imperative/src/impl/interpreter/interpreter_impl.cpp
  5. +2 -5    imperative/src/impl/interpreter/interpreter_impl.h
  6. +0 -1    imperative/src/impl/interpreter/tensor_info.h
  7. +0 -14   imperative/src/impl/op_def.cpp
  8. +0 -11   imperative/src/impl/op_trait.cpp
  9. +0 -13   imperative/src/impl/op_trait.h
  10. +0 -72  imperative/src/impl/ops/broadcast.cpp
  11. +0 -15  imperative/src/impl/ops/cond_take.cpp
  12. +0 -7   imperative/src/impl/ops/custom_opdef.cpp
  13. +1 -51  imperative/src/impl/ops/elemwise.cpp
  14. +1 -6   imperative/src/impl/ops/misc.cpp
  15. +1 -22  imperative/src/impl/ops/reduce.cpp
  16. +0 -23  imperative/src/impl/ops/rng.cpp
  17. +0 -140 imperative/src/impl/ops/tensor_manip.cpp
  18. +0 -7   imperative/src/impl/ops/utility.cpp
  19. +0 -31  imperative/src/impl/proxy_graph.cpp
  20. +0 -4   imperative/src/impl/proxy_graph.h
  21. +0 -19  imperative/src/impl/proxy_graph_detail.cpp
  22. +0 -6   imperative/src/impl/subgraph_detail.cpp
  23. +0 -9   imperative/src/include/megbrain/imperative/op_def.h
  24. +0 -30  imperative/src/include/megbrain/imperative/physical_tensor.h
  25. +0 -8   imperative/src/include/megbrain/imperative/proxy_graph_detail.h
  26. +0 -4   imperative/src/include/megbrain/imperative/subgraph_detail.h

+0 -10 imperative/python/megengine/core/__init__.py

@@ -12,13 +12,3 @@ from contextlib import contextmanager

from ._imperative_rt.core2 import get_option, set_option
from .tensor.megbrain_graph import Graph


@contextmanager
def option(key, value):
value = int(value)
old = get_option(key)
set_option(key, value)
yield
assert get_option(key) == value
set_option(key, old)
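
Note: the deleted option() helper set an interpreter option and restored the
previous value on exit; callers such as the test below now pair set_option
calls by hand. A minimal C++ analogue of the same set-and-restore pattern as
an RAII guard (a sketch with toy stand-ins, not the real interpreter API):

// Toy option store standing in for megengine's get_option/set_option.
#include <string>
#include <unordered_map>
#include <utility>

static std::unordered_map<std::string, int> g_options;
int get_option(const std::string& key) { return g_options[key]; }
void set_option(const std::string& key, int value) { g_options[key] = value; }

// Set an option on construction, restore the previous value on scope exit.
class OptionGuard {
public:
    OptionGuard(std::string key, int value)
            : m_key(std::move(key)), m_old(get_option(m_key)) {
        set_option(m_key, value);
    }
    ~OptionGuard() { set_option(m_key, m_old); }
private:
    std::string m_key;
    int m_old;
};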

+5 -4 imperative/python/test/unit/core/test_interpreter.py

@@ -76,10 +76,11 @@ def test_drop_basic():
def test_finalize():
prog = """
import megengine
with megengine.core.option("enable_host_compute", 0):
x = megengine.tensor(0)
y = x + 1
y.numpy()
megengine.core.set_option("enable_host_compute", 0)
x = megengine.tensor(0)
y = x + 1
y.numpy()
megengine.core.set_option("enable_host_compute", 1)
"""
subprocess.check_call([sys.executable, "-c", prog])



+0 -1 imperative/python/test/unit/utils/test_profiler.py

@@ -15,7 +15,6 @@ import pytest
from megengine import Parameter
from megengine import distributed as dist
from megengine import tensor
from megengine.core import option
from megengine.jit import trace
from megengine.module import Module
from megengine.utils.profiler import Profiler, scope


+10 -94 imperative/src/impl/interpreter/interpreter_impl.cpp

@@ -155,7 +155,6 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
info->h_value = value;
info->desc.value = value.proxy_to_default_cpu();
}
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
m_worker.add_task(
{Profiler::next_id(), Put{info, value, no_cache},
get_channel_state().stack_manager.dump()});
@@ -180,7 +179,6 @@ TensorInfo* ChannelImpl::put_impl(
auto info = alloc();
MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put);
init(info, {data.layout(), data.comp_node()});
info->mem_desc.id = StorageIdentifier::make(++m_storage_id);
info->ptr = Tensor::make(data, hvalue);
MGB_RECORD_EVENT(
TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node,
@@ -536,9 +534,6 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated;
info->desc = std::move(desc);
info->mem_desc.layout = info->desc.layout;
info->mem_desc.cn = info->desc.comp_node;
info->mem_desc.offset = 0;
}

void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) {
@@ -667,18 +662,14 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
bool profiling_device =
Profiler::is_profiling() && Profiler::get_option("profile_device", 0);
uint64_t apply_id = cmd.id;
struct TensorWithDesc {
TensorPtr tensor;
MemoryDesc desc;
};
SmallVector<TensorWithDesc> inputs;
SmallVector<TensorPtr> inputs;
inputs.reserve(cmd.inputs.size());
// refcnt == 1, owners: [TensorInfo::ptr]
for (auto i : cmd.inputs) {
mgb_assert(i->ptr, "Invalid input tensor ptr!");
// refcnt ++, owners: [i->ptr, tensor_inputs]
// tensor_inputs.push_back(i->ptr);
inputs.push_back({i->ptr, i->mem_desc});
inputs.push_back(i->ptr);
}
if (state.options.enable_dtr_auto_drop &&
state.options.dtr_eviction_threshold > 0) {
@@ -686,56 +677,28 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
}
auto apply_on_physical_tensor =
[&](auto&& self, const OpDef& def,
SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> {
auto apply_functor = [&](std::shared_ptr<OpDef> op,
SmallVector<TensorWithDesc> inputs,
size_t nr_outputs) -> SmallVector<TensorWithDesc> {
SmallVector<TensorPtr> inputs,
size_t nr_outputs) -> SmallVector<TensorPtr> {
auto opname = op->trait()->make_name(*op);
imperative_log_profile_begin(opname.c_str());
auto outputs = self(self, *op, inputs);
imperative_log_profile_end(opname.c_str());
return outputs;
};
auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
return {value, MemoryDesc{
value->layout(), 0, value->comp_node(),
StorageIdentifier::make()}};
};
auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; };
if (def.trait()->make_forward_graph) {
// apply recursively
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input : inputs) {
input_descs.push_back(
{{{}, input.tensor->dtype()}, input.tensor->comp_node()});
input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
}
auto forward_graph = OpDef::make_forward_graph(def, input_descs);
auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
return outputs;
}
SmallVector<TensorPtr> input_tensors;
SmallVector<MemoryDesc> input_descs;
for (auto&& input : inputs) {
input_tensors.push_back(input.tensor);
input_descs.push_back(input.desc);
}
auto [output_descs, output_tensors, workspaces] =
init_output_and_workspace(def, input_tensors, input_descs);
if (!output_descs.empty()) {
OpDef::execute(def, input_tensors, output_tensors, workspaces);
} else {
output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors);
for (auto&& output_tensor : output_tensors) {
output_descs.push_back(MemoryDesc{
output_tensor->layout(), 0, output_tensor->comp_node(),
StorageIdentifier::make()});
}
}
SmallVector<TensorWithDesc> outputs;
for (auto&& [output_tensor, output_desc] :
ranges::zip_view(output_tensors, output_descs)) {
outputs.push_back({output_tensor, output_desc});
}
return outputs;
return OpDef::apply_on_physical_tensor(def, inputs);
};
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
// Begin profiling operator
@@ -787,8 +750,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
} else {
MGB_RECORD_EVENT(OpOutputEvent, output->id);
produce_tensor(output, outputs[i].tensor);
output->mem_desc = outputs[i].desc;
produce_tensor(output, outputs[i]);
MGB_RECORD_EVENT(OpOutputFinishEvent, output->id);
sample_on_device(output->desc.comp_node, false);
}
@@ -800,7 +762,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
estimate_compute_time += i->memory;
}
for (auto i : outputs) {
estimate_compute_time += i.tensor->blob()->size();
estimate_compute_time += i->blob()->size();
}
m_dtr.estimate_timestamp += estimate_compute_time / 1e8;
for (auto i : cmd.outputs) {
@@ -1012,52 +974,6 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
set_log_level(pre_level);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
ChannelImpl::init_output_and_workspace(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc) {
auto [outputs_desc, workspaces_desc] =
OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc);
if (!outputs_desc.size()) {
// failed to infer memplan
return {{}, {}, {}};
}
// refine storage id to make it unique
for (auto&& desc : outputs_desc) {
if (desc.id->is_sys_alloc()) {
// TODO: there may be some outputs sharing the same storage id
desc.id->id = ++m_storage_id;
}
}
auto& state = get_worker_state();
auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
SmallVector<TensorPtr> tensors;
for (size_t i = 0; i < desc.size(); i++) {
if (desc[i].id->is_sys_alloc()) {
tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
alloc_tensor_with_evict(tensors.back()->blob().get());
}
} else if (desc[i].id->is_from_other()) {
for (size_t j = 0; j < inputs_mem_desc.size(); j++) {
if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
tensors.push_back(
inputs[j]->sub(desc[i].offset, desc[i].layout));
break;
}
}
} else if (desc[i].id->is_device_ptr()) {
tensors.push_back(desc[i].id->ptr);
} else {
mgb_assert(0, "not implemented");
}
}
return tensors;
};

return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)};
}

void ChannelImpl::process_one_task(Command& icmd) {
using namespace ranges;
using namespace ranges::views;
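
With init_output_and_workspace and the MemoryDesc bookkeeping gone, do_apply_op
reduces to: expand composite ops through their forward graph, and let every
leaf op allocate and compute its outputs in a single apply_on_physical_tensor
call. A compilable toy model of that recursion (all types here are stand-ins,
not the real MegEngine classes):

#include <functional>
#include <memory>
#include <vector>

struct Tensor {};
using TensorPtr = std::shared_ptr<Tensor>;

struct OpDef {
    // leaf op: allocates and produces outputs directly,
    // with no pre-planned memory descriptors
    std::function<std::vector<TensorPtr>(const std::vector<TensorPtr>&)>
            apply_on_physical_tensor;
    // composite op: expands into a chain of sub-ops instead
    const std::vector<OpDef>* forward_graph = nullptr;
};

std::vector<TensorPtr> apply(const OpDef& def, std::vector<TensorPtr> inputs) {
    if (def.forward_graph) {
        for (const OpDef& sub : *def.forward_graph)
            inputs = apply(sub, inputs);  // apply recursively
        return inputs;
    }
    // formerly: infer_output_mem_desc -> init_output_and_workspace -> execute
    return def.apply_on_physical_tensor(inputs);
}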


+2 -5 imperative/src/impl/interpreter/interpreter_impl.h

@@ -105,11 +105,6 @@ private:
void flush_apply_stack();
void do_apply_op(const ApplyOp& cmd, std::string reason);

std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>>
init_output_and_workspace(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<MemoryDesc> inputs_mem_desc);

void dispatch_default_cpu(
std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos,
const SmallVector<LogicalTensorDesc>& input_descs,
@@ -296,6 +291,8 @@ private:
op_blacklist.end();
}

// operators that cannot be re-computed, including:
// distributed operators, inplace operators, random generator operators
std::vector<std::string> op_blacklist = {
"CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat",
"GaussianRNG", "UniformRNG", "GammaRNG", "PermutationRNG",


+0 -1 imperative/src/impl/interpreter/tensor_info.h

@@ -59,7 +59,6 @@ struct TensorInfo {
// Lock interpreter when visiting `ptr`.
TensorPtr ptr;
LogicalTensorDesc desc;
MemoryDesc mem_desc;

double compute_time;
size_t memory;


+0 -14 imperative/src/impl/op_def.cpp

@@ -41,20 +41,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs) {
return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::
infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}

void OpDef::execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}

void OpDef::apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {


+0 -11 imperative/src/impl/op_trait.cpp

@@ -43,13 +43,6 @@ void OpMethFallbackByProxyGraph::impl(
ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
}
void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
func.Base::operator=(proxy_graph_detail::execute);
}
void OpMethFallbackByProxyGraph::impl(
InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
}
void OpMethFallbackByProxyGraph::impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
@@ -63,10 +56,6 @@ void OpMethFallbackFromSubgraph::impl(
func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
}
void OpMethFallbackFromSubgraph::impl(
InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(subgraph_detail::infer_output_mem_desc);
}
void OpMethFallbackFromSubgraph::impl(
ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode) {
func.Base::operator=(subgraph_detail::apply_on_var_node);
}


+0 -13 imperative/src/impl/op_trait.h

@@ -64,12 +64,6 @@ OpMethType(DecideDispatchMode,
OpMethType(ApplyOnPhysicalTensor,
decltype(OpDef::apply_on_physical_tensor));

OpMethType(InferOutputMemDesc,
decltype(OpDef::infer_output_mem_desc));

OpMethType(Execute,
decltype(OpDef::execute));

OpMethType(ApplyOnDeviceTensorND,
decltype(OpDef::apply_on_device_tensornd));

@@ -123,8 +117,6 @@ struct OpMethFallback : OpMethImplBase {
struct OpMethFallbackByProxyGraph : OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
@@ -133,7 +125,6 @@ struct OpMethFallbackByProxyGraph : OpMethImplBase {
struct OpMethFallbackFromSubgraph : OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
@@ -185,8 +176,6 @@ struct OpTrait {
OpDefMaker make_from_op_node;
DecideDispatchMode decide_dispatch_mode;
ApplyOnPhysicalTensor apply_on_physical_tensor;
InferOutputMemDesc infer_output_mem_desc;
Execute execute;
ApplyOnDeviceTensorND apply_on_device_tensornd;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
@@ -207,8 +196,6 @@ struct OpTrait {
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
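
The trait methods are declared once as struct members and echoed in the
cb(...) list above, an X-macro that generates per-method code; removing
infer_output_mem_desc and execute therefore touches both places. A minimal
self-contained sketch of the pattern, with hypothetical method names:

#include <cstdio>

// One list drives both the struct fields and any generated per-field code,
// so the list and the fields must be kept in sync.
#define FOR_EACH_TRAIT_METHOD(cb) \
    cb(apply_on_physical_tensor)  \
    cb(apply_on_var_node)         \
    cb(infer_output_attrs_fallible)

struct Trait {
#define DECLARE_FIELD(name) void (*name)() = nullptr;
    FOR_EACH_TRAIT_METHOD(DECLARE_FIELD)
#undef DECLARE_FIELD
};

void report_unset(const Trait& t) {
#define CHECK_FIELD(name) \
    if (t.name == nullptr) std::printf("unset: %s\n", #name);
    FOR_EACH_TRAIT_METHOD(CHECK_FIELD)
#undef CHECK_FIELD
}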


+0 -72 imperative/src/impl/ops/broadcast.cpp

@@ -81,50 +81,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto& input = inputs_tensors[0];
TensorShape target_shape;
cg::copy_tensor_value_to_shape(
target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu());
// TODO: memory forward
// if (input->shape().eq_shape(target_shape)) {
// return {{{input->layout(), 0, input->comp_node(),
// StorageIdentifier::make(&inputs_mems[0])}}, {}};
// }
return {{{{target_shape, input->dtype()},
0,
input->comp_node(),
StorageIdentifier::make(0)}},
{}};
}

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (outputs[0]->layout().is_empty()) {
return;
}
if (inputs[0]->shape().eq_shape(outputs[0]->shape())) {
mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout()));
// TODO: memory forward
// mgb_assert(inputs[0]->offset() == outputs[0]->offset());
// mgb_assert(inputs[0]->blob() == outputs[0]->blob());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor());
} else {
TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape());
outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(
SubTensorSpec::make_from_layout(input_layout)));
}
}

OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace broadcast

@@ -187,41 +147,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();

TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout = slayout.reshape(tshp);
// memory forward
return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}},
{}};
}

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(inputs[0]->offset() == outputs[0]->offset());
mgb_assert(inputs[0]->blob() == outputs[0]->blob());
}

OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace reshape
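
The deleted Broadcast execute() copied the input through a layout broadcast to
the output shape: size-1 axes re-read the same elements, which TensorLayout
models as stride 0. A self-contained sketch of that copy rule, using plain
vectors in place of TensorLayout and assuming equal ndim:

#include <cstddef>
#include <vector>

// Copy src (shape in_dims) into a buffer of shape out_dims, repeating
// elements along axes where in_dims[i] == 1.
std::vector<float> broadcast_copy(
        const std::vector<float>& src, const std::vector<size_t>& in_dims,
        const std::vector<size_t>& out_dims) {
    // row-major input strides, forced to 0 on broadcast axes
    std::vector<size_t> stride(in_dims.size());
    size_t s = 1;
    for (size_t i = in_dims.size(); i-- > 0;) {
        stride[i] = (in_dims[i] == 1) ? 0 : s;
        s *= in_dims[i];
    }
    size_t total = 1;
    for (size_t d : out_dims) total *= d;
    std::vector<float> dst(total);
    for (size_t linear = 0; linear < total; ++linear) {
        // decompose the output index and map it through the input strides
        size_t rem = linear, src_off = 0;
        for (size_t i = out_dims.size(); i-- > 0;) {
            src_off += (rem % out_dims[i]) * stride[i];
            rem /= out_dims[i];
        }
        dst[linear] = src[src_off];
    }
    return dst;
}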



+0 -15 imperative/src/impl/ops/cond_take.cpp

@@ -78,25 +78,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
false};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}

void execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
mgb_assert(0);
}

OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();

} // namespace


+0 -7 imperative/src/impl/ops/custom_opdef.cpp

@@ -234,12 +234,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return op.infer_output_attrs(inputs);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}

size_t hash(const OpDef& def) {
auto&& op = static_cast<const CustomOpDef&>(def);
const custom::Param& param = op.param();
@@ -279,7 +273,6 @@ OP_TRAIT_REG(CustomOpDef, CustomOpDef)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.hash(hash)
.is_same_st(is_same_st)
.props(props)


+1 -51 imperative/src/impl/ops/elemwise.cpp

@@ -110,35 +110,6 @@ void apply_on_device_tensornd(
opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
}

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
mgb_assert(outputs.size() == 1);
SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
inp_tensornds[i] = inputs[i]->dev_tensor();
}
SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()};
apply_on_device_tensornd(def, inp_tensornds, &out_tensornds);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& op_def = def.cast_final_safe<Elemwise>();
TensorShapeArray inp_shapes(inputs_tensors.size());
for (size_t i = 0; i < inputs_tensors.size(); ++i) {
inp_shapes[i] = inputs_tensors[i]->layout();
}
TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes);
SmallVector<MemoryDesc> outputs = {
{{shape, inputs_tensors[0]->dtype()},
0,
inputs_tensors[0]->comp_node(),
StorageIdentifier::make(1)}};
return {outputs, {}};
}

SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
@@ -251,7 +222,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
mgb_assert(
inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(),
inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs "
"correctly.");
@@ -265,23 +236,6 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())};
}

void execute_inplace(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
apply_inplace_add_on_physical_tensor(def, inputs);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_inplace_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto dest = inputs_tensors[0];
SmallVector<MemoryDesc> outputs = {
{dest->layout(), 0, dest->comp_node(),
StorageIdentifier::make(&inputs_mems[0])}};
return {outputs, {}};
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(inputs.size() == 4, "invalid input number for inplace_add");
@@ -319,16 +273,12 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();

OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate)
.apply_on_var_node(apply_inplace_add_on_var_node)
.apply_on_physical_tensor(apply_inplace_add_on_physical_tensor)
.infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible)
.infer_output_mem_desc(infer_inplace_output_mem_desc)
.execute(execute_inplace)
.fallback();
} // anonymous namespace
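
Note the tightened guard above: apply_inplace_add_on_physical_tensor now
requires blob().use_count() == 1 rather than 2, presumably because a
bookkeeping reference from the old MemoryDesc/StorageIdentifier plumbing went
away. A toy illustration of the exclusive-ownership check that makes an
in-place update safe:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Mutate the buffer only when we hold the sole reference to it; otherwise
// other tensors aliasing the storage would observe the change.
void inplace_add(std::shared_ptr<std::vector<float>>& buf,
                 const std::vector<float>& delta) {
    assert(buf.use_count() == 1 && "buffer is shared; in-place update unsafe");
    for (size_t i = 0; i < delta.size(); ++i)
        (*buf)[i] += delta[i];
}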



+1 -6 imperative/src/impl/ops/misc.cpp

@@ -75,16 +75,11 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return dests;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}

OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.fallback();
} // namespace check_non_finite



+1 -22 imperative/src/impl/ops/reduce.cpp

@@ -36,6 +36,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
return Reduce::make(node->param());
}

// TODO: use this for apply_on_physical_tensor
bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
@@ -49,31 +50,9 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
return false;
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
if (memory_forward_success(def, inputs_tensors)) {
auto& src_desc = inputs_mems[0];
return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}},
{}};
}
return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
if (memory_forward_success(def, inputs)) {
return;
}
return proxy_graph_detail::execute(def, inputs, outputs, workspace);
}

OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace reduce
} // namespace


+0 -23 imperative/src/impl/ops/rng.cpp

@@ -518,20 +518,6 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
}

template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
SmallVector<MemoryDesc> outputs;
for (size_t i = 0; i < dests.size(); ++i) {
outputs.push_back(
{dests[i].layout, 0, dests[i].comp_node,
StorageIdentifier::make(i + 1)});
}
return {outputs, {}};
}

template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<TensorPtr> outputs;
@@ -543,13 +529,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}

template <typename Op>
void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec<Op>(def, inputs, outputs, {});
}

template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
size_t nr_inp = inputs.size();
@@ -641,8 +620,6 @@ CompNode get_rng_handle_compnode(Handle handle) {
.apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \
.fallback(); \
}



+0 -140 imperative/src/impl/ops/tensor_manip.cpp

@@ -141,39 +141,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (blob) {
ret.push_back(
{tensor.layout(), 0, inputs[0]->comp_node(),
StorageIdentifier::make(Tensor::make(
std::forward<decltype(blob)>(blob), tensor.layout(),
tensor))});
} else {
ret.push_back(
{tensor.layout(), 0, inputs[0]->comp_node(),
StorageIdentifier::make(1)});
}
return {ret, {}};
}

void execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
HostTensorND tensor = get_var_shape_host_tensor(def, inputs);
SmallVector<MemoryDesc> ret;
auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor);
if (!blob || blob->storage() != outputs[0]->blob()->storage()) {
outputs[0]->dev_tensor().copy_from_fixlayout(tensor);
AsyncReleaser::inst()->add(tensor);
}
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::GetVarShape>();
return GetVarShape::make(node->param());
@@ -186,8 +153,6 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_mem_desc(infer_output_mem_desc)
.execute(execute)
.fallback();
} // namespace get_var_shape

@@ -215,38 +180,6 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
return opr;
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
param_pack_split_infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
mgb_assert(
inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp->layout();
mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
mgb_assert(param.shapes.size() * 2 == param.offsets.size());
SmallVector<MemoryDesc> ret;
auto&& shapes = get_shapes(param.shapes);
size_t dtype_size = inputs[0]->layout().dtype.size();
for (size_t i = 0; i < shapes.size(); ++i) {
// memory forward
ret.push_back(
{{shapes[i], inputs[0]->dtype()},
param.offsets[i * 2] * dtype_size,
inp->comp_node(),
StorageIdentifier::make(&inputs_mems[0])});
}
return {ret, {}};
}

void param_pack_split_execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
// do nothing
}

SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& param = def.cast_final_safe<ParamPackSplit>();
@@ -268,8 +201,6 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(

OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
.apply_on_var_node(param_pack_split_apply_on_var_node)
.infer_output_mem_desc(param_pack_split_infer_output_mem_desc)
.execute(param_pack_split_execute)
.apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
.fallback();

@@ -286,75 +217,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
return opr;
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
param_pack_concat_infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<MemoryDesc>& inputs_mems) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
auto dtype = inputs.front()->dtype();
size_t nr_inputs = inputs.size() - 1;
size_t nr_elems = 0;
for (size_t i = 0; i < nr_inputs; ++i) {
auto& input = inputs[i];
mgb_assert(
comp_node == input->comp_node(),
"inputs for param_pack_concat must be in the same comp_node");
mgb_assert(
dtype == input->dtype(),
"inputs for param_pack_concat must have the same dtype");
nr_elems += input->layout().total_nr_elems();
}
auto dest_layout = TensorLayout({nr_elems}, dtype);
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t ws_size;
{
TensorShapeArray src_shapes;
for (size_t i = 0; i < nr_inputs; ++i) {
src_shapes.push_back(inputs[i]->shape());
}
ws_size = caller.op->get_workspace_in_bytes(
src_shapes, inputs.back()->shape(), TensorShape{});
}

SmallVector<MemoryDesc> outputs = {
{dest_layout, 0, comp_node, StorageIdentifier::make(1)}};
MemoryDesc workspace = {
{{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)};

return {outputs, {workspace}};
}

void param_pack_concat_execute(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
size_t nr_inputs = inputs.size() - 1;
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
size_t srcs_size = sizeof(void*) * nr_inputs;
void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
std::shared_ptr<dt_byte> srcs_ptr = {
(dt_byte*)srcs_raw_ptr,
[comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
for (size_t i = 0; i < nr_inputs; ++i) {
srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
}
HostTensorStorage srcs_storage;
srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
megdnn::Workspace dnn_wk(
workspace[0]->blob()->storage().get(), workspace[0]->blob()->size());
caller.op->exec(
{srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
outputs[0]->dev_tensor().as_megdnn(), dnn_wk);
AsyncReleaser::inst()->add(
HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
}

SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
def.cast_final_safe<ParamPackConcat>();
@@ -407,8 +269,6 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(

OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.infer_output_mem_desc(param_pack_concat_infer_output_mem_desc)
.execute(param_pack_concat_execute)
.apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
.fallback();
} // namespace param_pack


+0 -7 imperative/src/impl/ops/utility.cpp

@@ -445,12 +445,6 @@ auto make_name(const OpDef& def) {
return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {};
}

EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
@@ -498,7 +492,6 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.make_backward_graph(make_backward_graph)
.make_name(make_name)
.infer_output_mem_desc(infer_output_mem_desc)
.props(props)
.hash(hash)
.is_same_st(is_same_st)


+0 -31 imperative/src/impl/proxy_graph.cpp

@@ -634,36 +634,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::
mgb_assert(0);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::
infer_output_mem_desc(
const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto opr = get_proxy_opr(def, inputs_tensors);
CUR_OPR_GUARD(opr);
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_graph.get(), ProxyGraph::get_workspace_limit);
do_shape_infer(true);
SmallVector<MemoryDesc> outputs;
SmallVector<MemoryDesc> workspaces;
size_t cur_id = 0;
for (auto&& i : opr->output()) {
if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
workspaces.push_back(
{{i->shape(), i->dtype(), i->format()},
0,
i->comp_node(),
StorageIdentifier::make(++cur_id)});
} else {
outputs.push_back(
{{i->shape(), i->dtype()},
0,
i->comp_node(),
StorageIdentifier::make(++cur_id)});
}
}
return {outputs, workspaces};
}

struct ProxyGraph::GradGraph {
cg::VarNodeArray inputs;
cg::VarNodeArray outputs;
@@ -812,7 +782,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
return result;
}


VarNodeArray ProxyGraph::make_input_place_holders(
const SmallVector<LogicalTensorDesc>& inputs) {
VarNodeArray vinputs(inputs.size());


+0 -4 imperative/src/impl/proxy_graph.h

@@ -47,10 +47,6 @@ public:
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<Tensor*>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);

/********************** Logical Tensor API **********************/

size_t get_opr_output_size(


+0 -19 imperative/src/impl/proxy_graph_detail.cpp

@@ -83,25 +83,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_mem_desc(
def, to_raw_ptr_array(inputs_tensors), inputs_mems);
}

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace) {
exec(def, inputs, outputs, workspace);
auto async_error = ProxyGraph::get_async_error();
if (async_error) {
throw *async_error;
}
return;
}

// std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const
// OpDef& def,
// const SmallVector<LogicalTensorDesc>& inputs) {


+0 -6 imperative/src/impl/subgraph_detail.cpp

@@ -162,12 +162,6 @@ EncodedSubgraph make_backward_graph(
inputs, input_requires_grad, output_has_grad, forward_graph);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}

} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb

+0 -9 imperative/src/include/megbrain/imperative/op_def.h

@@ -53,10 +53,6 @@ public:
static SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);

static void execute(
const OpDef& def, SmallVector<TensorPtr> inputs,
SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace);

/*!
* \brief Call the corresponding dnn op to calculate results. Output
* tensors' device memory should be allocated outside.
@@ -71,11 +67,6 @@ public:
static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);

static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);

static EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,


+0 -30 imperative/src/include/megbrain/imperative/physical_tensor.h

@@ -288,36 +288,6 @@ struct LogicalTensorDesc {
CompNode comp_node;
DeviceTensorND value; // cpu:default
};

struct StorageIdentifier;
struct MemoryDesc {
TensorLayout layout;
size_t offset;
CompNode cn;
std::shared_ptr<StorageIdentifier> id;
};

struct StorageIdentifier {
enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag;
union {
size_t id;
MemoryDesc* desc;
};
TensorPtr ptr;
StorageIdentifier() = default;
StorageIdentifier(size_t id) : tag(SYS_ALLOC), id(id) {}
StorageIdentifier(const MemoryDesc* desc) : tag(FROM_OTHER), desc(desc->id->desc) {}
StorageIdentifier(TensorPtr dev_ptr) : tag(DEVICE_PTR), ptr(dev_ptr) {}

template <typename... Args>
static std::shared_ptr<StorageIdentifier> make(Args&&... args) {
return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...);
}
bool is_sys_alloc() { return tag == SYS_ALLOC; }
bool is_from_other() { return tag == FROM_OTHER; }
bool is_device_ptr() { return tag == DEVICE_PTR; }
bool is_invalid() { return tag == INVALID; }
};
} // namespace imperative
} // namespace mgb
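
MemoryDesc and StorageIdentifier, deleted above, were the planning vocabulary
for the removed path: a layout/offset/comp-node record plus a hand-rolled
tagged union (enum tag over a union of id, MemoryDesc* and TensorPtr) saying
where storage comes from. For reference, the same shape expressed with
std::variant (a sketch, not the original API):

#include <cstddef>
#include <memory>
#include <variant>

struct Tensor {};
using TensorPtr = std::shared_ptr<Tensor>;
struct MemoryDesc;

struct SysAlloc  { size_t id; };         // runtime allocates fresh storage
struct FromOther { MemoryDesc* desc; };  // forward another tensor's storage
struct DevicePtr { TensorPtr ptr; };     // wrap an existing device tensor

// std::monostate plays the role of the old INVALID tag.
using StorageIdentifier =
        std::variant<std::monostate, SysAlloc, FromOther, DevicePtr>;

inline bool is_sys_alloc(const StorageIdentifier& s) {
    return std::holds_alternative<SysAlloc>(s);
}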



+0 -8 imperative/src/include/megbrain/imperative/proxy_graph_detail.h

@@ -20,17 +20,9 @@ namespace proxy_graph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, SmallVector<TensorPtr> inputs);

void execute(
const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
SmallVector<TensorPtr> workspace);

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);

void exec(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);


+0 -4 imperative/src/include/megbrain/imperative/subgraph_detail.h

@@ -35,10 +35,6 @@ EncodedSubgraph make_backward_graph(
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems);

} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
