
feat(ops): add several utility ops

GitOrigin-RevId: 623cb5ddfc
tags/v1.6.0-rc1
Megvii Engine Team · 3 years ago
commit 3fd3e000d1
6 changed files with 470 additions and 12 deletions
  1. imperative/src/impl/interpreter/interpreter_impl.cpp (+19, -7)
  2. imperative/src/impl/interpreter/interpreter_impl.h (+2, -5)
  3. imperative/src/impl/op_def.cpp (+11, -0)
  4. imperative/src/impl/ops/utility.cpp (+383, -0)
  5. imperative/src/include/megbrain/imperative/op_def.h (+4, -0)
  6. imperative/src/include/megbrain/imperative/ops/utility.h (+51, -0)

imperative/src/impl/interpreter/interpreter_impl.cpp (+19, -7)

@@ -18,6 +18,7 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/to_string.h"

#include "../blob_manager_impl.h"
@@ -99,6 +100,16 @@ ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
return m_worker_state;
}

void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
sys::set_thread_name("worker");
m_owner->m_worker_state.tid = std::this_thread::get_id();
OpDef::set_allocator([&](CompNode device, size_t size) {
auto blob = Blob::make(device, size);
m_owner->alloc_tensor_with_evict(blob.get());
return blob->storage();
});
}

// Do not use m_xxx_state directly
#define m_channel_state
#define m_worker_state
@@ -649,7 +660,9 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
auto apply_on_physical_tensor = [&](auto&& self, const OpDef& def, SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> {
auto apply_functor = [&](std::shared_ptr<OpDef> op, SmallVector<TensorWithDesc> inputs, size_t nr_outputs) -> SmallVector<TensorWithDesc> {
auto opname = op->trait()->make_name(*op);
imperative_log_profile_begin(opname.c_str());
auto outputs = self(self, *op, inputs);
imperative_log_profile_end(opname.c_str());
return outputs;
};
auto const_functor = [&](TensorPtr value) -> TensorWithDesc {
@@ -667,7 +680,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
}
SmallVector<TensorPtr> input_tensors;
SmallVector<MemoryDesc> input_descs;
// size_t next_mem_desc_id = 0;
for (auto&& input: inputs) {
input_tensors.push_back(input.tensor);
input_descs.push_back(input.desc);
@@ -890,7 +902,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
return valid_tensors;
}

void ChannelImpl::alloc_tensor_with_evict(TensorPtr x) {
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
auto reserve_size = [&](size_t size) {
if (!m_dtr.comp_node.valid()) {
return false;
@@ -902,15 +914,15 @@ void ChannelImpl::alloc_tensor_with_evict(TensorPtr x) {
return true;
};
auto pre_level = set_log_level(LogLevel::NO_LOG);
reserve_size(x->blob()->size());
MGB_TRY { BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size()); }
reserve_size(x->size());
MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
MGB_CATCH(MemAllocError&, {
bool suc = false;
while (!suc) {
if (!auto_evict(1)) {
break;
}
MGB_TRY { BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size()); }
MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
MGB_CATCH(MemAllocError&, { continue; });
suc = true;
}
@@ -919,7 +931,7 @@ void ChannelImpl::alloc_tensor_with_evict(TensorPtr x) {
mgb_log_warn("reallocating all cuda memory to alleviate fragmentation, the performance may be affected");
set_log_level(LogLevel::NO_LOG);
BlobManager::inst()->defrag(x->comp_node());
BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size());
BlobManager::inst()->alloc_direct(x, x->size());
}
});
set_log_level(pre_level);
@@ -949,7 +961,7 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPt
if (desc[i].id->is_sys_alloc()) {
tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) {
alloc_tensor_with_evict(tensors.back());
alloc_tensor_with_evict(tensors.back()->blob().get());
}
} else if (desc[i].id->is_from_other()) {
for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
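
The new alloc_tensor_with_evict(Blob*) follows a try / evict / retry / defragment sequence. Below is a minimal standalone model of that control flow (not MegEngine API): try_alloc, evict_one and defragment are hypothetical callbacks standing in for BlobManager::alloc_direct, auto_evict(1) and BlobManager::defrag.

#include <cstddef>
#include <functional>
#include <new>

void* alloc_with_evict(size_t size,
                       const std::function<void*(size_t)>& try_alloc,   // throws std::bad_alloc on failure
                       const std::function<bool()>& evict_one,          // returns false when nothing is left to evict
                       const std::function<void()>& defragment) {
    try {
        return try_alloc(size);
    } catch (const std::bad_alloc&) {
        // Evict one candidate at a time and retry until the allocation
        // succeeds or there is nothing left to evict.
        while (evict_one()) {
            try {
                return try_alloc(size);
            } catch (const std::bad_alloc&) {
                continue;
            }
        }
        // Last resort: defragment the device pool, then allocate or let the failure propagate.
        defragment();
        return try_alloc(size);
    }
}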


imperative/src/impl/interpreter/interpreter_impl.h (+2, -5)

@@ -164,10 +164,7 @@ private:
void process_one_task(IdentifiedCommand& icmd) {
m_owner->process_one_task(icmd);
}
void on_async_queue_worker_thread_start() override {
sys::set_thread_name("worker");
m_owner->m_worker_state.tid = std::this_thread::get_id();
}
void on_async_queue_worker_thread_start() override;
private:
ChannelImpl* m_owner;
} m_worker;
@@ -419,7 +416,7 @@ private:
//! automatically evict an optimal tensor
bool auto_evict(size_t);

void alloc_tensor_with_evict(TensorPtr);
void alloc_tensor_with_evict(Blob*);

// assert thread id when call get_xxx_state to avoid misuse
ChannelState& get_channel_state();


imperative/src/impl/op_def.cpp (+11, -0)

@@ -155,6 +155,17 @@ const std::string OpDef::make_name() const {
return m_scope + "." + trait()->make_name(*this);
}

static thread_local OpDef::allocator_t local_allocator;

void OpDef::set_allocator(allocator_t allocator) {
mgb_assert(!local_allocator, "allocator has been set before");
local_allocator = allocator;
}

DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
return local_allocator(device, size);
}

std::string Subgraph::repr() const {
std::ostringstream buf;
buf << "(";
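
The allocator hook is stored in a thread_local, so every thread that ends up calling OpDef::allocate must install its own callback exactly once (a second set_allocator on the same thread trips the assert). A hedged sketch of such a registration, using only calls that already appear in this diff (Blob::make and Blob::storage); the plain, eviction-free allocator and the Blob header path are assumptions, not part of this commit:

#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/physical_tensor.h"  // assumed location of Blob

void install_plain_allocator() {
    using namespace mgb;
    using namespace mgb::imperative;
    // Call once per thread: local_allocator is thread_local and set_allocator
    // asserts that it has not been set on this thread before.
    OpDef::set_allocator([](CompNode device, size_t size) {
        auto blob = Blob::make(device, size);  // plain allocation, no DTR eviction
        return blob->storage();                // DeviceTensorStorage::RawStorage
    });
}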


imperative/src/impl/ops/utility.cpp (+383, -0)

@@ -12,7 +12,13 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/io.h"
#include "../op_trait.h"

namespace mgb::imperative {
@@ -32,6 +38,125 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)
.fallback();
}} // fastpathcopy

namespace { namespace shape_infer {
auto apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto& op = def.cast_final_safe<ShapeInfer>();
size_t nr_inputs = inputs.size();
mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
SmallVector<LogicalTensorDesc> input_descs;
for (size_t i = 0; i < nr_inputs; ++i) {
auto input = inputs[i]->get_value();
TensorLayout layout;
layout.ndim = input.shape(0);
for (size_t i = 0; i < layout.ndim; ++i) {
layout[i] = input.ptr<int32_t>()[i];
}
layout.dtype = op.dtypes[i];
layout.init_contiguous_stride();
input_descs.push_back({layout, op.devices[i]});
}
auto [output_descs, valid] = OpDef::infer_output_attrs_fallible(*op.op, input_descs);
mgb_assert(valid, "shape inference incomplete");
SmallVector<TensorPtr> outputs;
for (auto&& output_desc: output_descs) {
HostTensorND shape_tensor{output_desc.comp_node, {output_desc.layout.ndim}, dtype::Int32()};
for (size_t i = 0; i < output_desc.layout.ndim; ++i) {
shape_tensor.ptr<int32_t>()[i] = output_desc.layout[i];
}
auto output = Tensor::make(shape_tensor);
outputs.push_back(output);
}
return outputs;
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto& op = def.cast_final_safe<ShapeInfer>();
size_t nr_inputs = inputs.size();
VarNodeArray input_values, outputs;
mgb_assert(nr_inputs > 0, "no inputs for ShapeInfer");
for (size_t i = 0; i < nr_inputs; ++i) {
auto input_value = opr::Alloc::make(SymbolVar(inputs[i]), op.dtypes[i], {op.devices[i]});
input_values.push_back(input_value.node());
}
auto output_values = OpDef::apply_on_var_node(*op.op, input_values);
for (auto&& output_value: output_values) {
outputs.push_back(opr::GetVarShape::make(output_value).node());
}
return outputs;
}

auto infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& input_descs) {
auto& op = def.cast_final_safe<ShapeInfer>();
SmallVector<LogicalTensorDesc> input_shape_descs;
size_t nr_inputs = op.devices.size();
mgb_assert(op.dtypes.size() == nr_inputs, "number of input devices and dtypes mismatch");
for (size_t i = 0; i < nr_inputs; ++i) {
LogicalTensorDesc input_shape_desc;
input_shape_desc.comp_node = op.devices[i];
input_shape_desc.layout.ndim = 0;
input_shape_desc.layout.dtype = op.dtypes[i];
input_shape_descs.push_back(input_shape_desc);
}
auto [output_shape_descs, _] = OpDef::infer_output_attrs_fallible(*op.op, input_shape_descs);
SmallVector<LogicalTensorDesc> output_descs;
for (auto&& output_shape_desc: output_shape_descs) {
LogicalTensorDesc output_desc;
output_desc.comp_node = output_shape_desc.comp_node;
output_desc.layout.ndim = 1;
output_desc.layout.dtype = dtype::Int32();
output_descs.push_back(output_desc);
}
return std::make_tuple(output_descs, false);
}

auto props(const OpDef& def) {
auto& op = def.cast_final_safe<ShapeInfer>();
return OpDef::props(*op.op);
}

auto make_name(const OpDef& def) {
auto& op = def.cast_final_safe<ShapeInfer>();
MGB_MARK_USED_VAR(op);
return ssprintf("ShapeInfer[%s]", op.op->make_name().c_str());
}

auto hash(const OpDef& def) {
auto& op = def.cast_final_safe<ShapeInfer>();
return op.op->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
if (!another.same_type<ShapeInfer>()) {
return false;
}
auto& lhs = def.cast_final_safe<ShapeInfer>();
auto& rhs = another.cast_final_safe<ShapeInfer>();
if (!lhs.op->is_same(*rhs.op)) {
return false;
}
return std::tie(lhs.devices, lhs.dtypes) ==
std::tie(rhs.devices, rhs.dtypes);
}

OP_TRAIT_REG(ShapeInfer,ShapeInfer)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.make_name(make_name)
.props(props)
.hash(hash)
.is_same_st(is_same_st)
.fallback();
}} // shape_infer


MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShapeInfer);
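
ShapeInfer exchanges shapes as rank-length Int32 tensors: the value tensor's length is the rank and element i is the extent of dimension i, which apply_on_physical_tensor above decodes on input and re-encodes on output. A standalone sketch of that convention, with std::vector standing in for HostTensorND:

#include <cstddef>
#include <cstdint>
#include <vector>

// Encode a shape as a 1-D int32 buffer: length == rank, element i == extent of dim i.
std::vector<int32_t> encode_shape(const std::vector<size_t>& shape) {
    return std::vector<int32_t>(shape.begin(), shape.end());
}

// Decode it back, mirroring how the op rebuilds a TensorLayout from the incoming
// shape tensor before running the wrapped op's shape inference.
std::vector<size_t> decode_shape(const std::vector<int32_t>& buf) {
    return std::vector<size_t>(buf.begin(), buf.end());
}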

namespace { namespace identity {
auto apply_on_var_node(
const OpDef& def,
@@ -53,4 +178,262 @@ OP_TRAIT_REG(Identity, Identity)
.fallback();
}} // identity

namespace { namespace subgraph {

EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
}

EncodedSubraph make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
SmallVector<bool> output_has_grad) {
auto& op = def.cast_final_safe<SubgraphOp>();
mgb_assert(output_has_grad.size() == op.output_grad_mask.size());
for (size_t i = 0; i < output_has_grad.size(); ++i) {
if (!op.output_grad_mask[i]) {
output_has_grad[i] = false;
}
}
auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
}

std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
auto& op = def.cast_final_safe<SubgraphOp>();
return {
{"name", op.name},
{"inputs", mgb::imperative::to_string(op.graph.inputs)},
{"exprs", mgb::imperative::to_string(op.graph.exprs)},
{"outputs", mgb::imperative::to_string(op.graph.outputs)},
};
}

std::string make_name(const OpDef& def) {
auto& op = def.cast_final_safe<SubgraphOp>();
if (op.name.empty()) {
return "SubgraphOp";
} else {
return op.name;
}
}

auto hash(const OpDef& def) {
auto& op = def.cast_final_safe<SubgraphOp>();
if (!op.graph_key) {
return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
}
return op.graph_key->hash();
}

auto is_same_st(const OpDef& def, const OpDef& another) {
if (!another.same_type<SubgraphOp>()) {
return false;
}
auto& lhs = def.cast_final_safe<SubgraphOp>();
auto& rhs = another.cast_final_safe<SubgraphOp>();
auto has_graph_key = bool(lhs.graph_key);
bool graph_same = false;
if (has_graph_key) {
graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
} else {
graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
}
return graph_same;
}

OP_TRAIT_REG(SubgraphOp, SubgraphOp)
.make_forward_graph(make_forward_graph)
.make_backward_graph(make_backward_graph)
.props(props)
.make_name(make_name)
.hash(hash)
.is_same_st(is_same_st)
.fallback();

}} // subgraph

namespace { namespace compiled_op {

struct DeviceMemoryAllocatorImpl: cg::DeviceMemoryAllocator {
std::shared_ptr<OpDef> current_op;
void alloc_static(ComputingGraph* graph, DeviceTensorStorage& dest, size_t size) override {
mgb_assert(0, "alloc_static is not allowed in CompiledOp");
}
void alloc_dynamic(VarNode* var, DeviceTensorStorage& dest, size_t size) override {
auto comp_node = var->comp_node();
auto storage = current_op->allocate(comp_node, size);
dest.reset(comp_node, size, storage);
}
};

struct ComputingGraphHolder {
std::shared_ptr<ComputingGraph> graph;
std::unique_ptr<cg::AsyncExecutable> executable;
SmallVector<std::shared_ptr<DeviceTensorND>> inputs;
SmallVector<std::shared_ptr<DeviceTensorND>> outputs;
std::shared_ptr<DeviceMemoryAllocatorImpl> allocator;
};

thread_local OpMethResultCache<ComputingGraphHolder> cg_cache;

ComputingGraphHolder& get_computing_graph(std::shared_ptr<OpDef> compiled_op, SmallVector<LogicalTensorDesc> descs) {
OpMethArgs<> key = {compiled_op, descs};
auto& cg_holder = cg_cache[key];
if (!cg_holder.graph) {
cg_holder.allocator = std::make_shared<DeviceMemoryAllocatorImpl>();
cg_holder.graph = ComputingGraph::make();
cg_holder.graph->options().force_dynamic_alloc = true;
cg_holder.graph->options().async_exec_level = 0;
cg_holder.graph->options().graph_opt_level = compiled_op->cast_final_safe<CompiledOp>().gopt_level;
cg_holder.graph->options().enable_var_mem_defragment = false;
cg_holder.graph->set_device_memory_allocator(cg_holder.allocator);
// cg_holder.graph->options().graph_opt.jit = 2;
VarNodeArray input_vars;
for (auto&& desc: descs) {
auto input_device_nd = std::make_shared<DeviceTensorND>();
input_device_nd->dtype(desc.layout.dtype);
input_device_nd->comp_node(desc.comp_node);
input_device_nd->resize(desc.layout);
cg_holder.inputs.push_back(input_device_nd);
auto callback = [input_device_nd]{
return *input_device_nd;
};
auto* input_var = opr::InputCallback::make(*cg_holder.graph, callback, desc.comp_node, desc.layout.dtype, TensorShape())[0].node();
input_vars.push_back(input_var);
}
// forward to inner op
auto output_vars = OpDef::apply_on_var_node(*compiled_op, input_vars);
ComputingGraph::OutputSpec output_spec;
size_t nr_outputs = output_vars.size();
for (size_t i = 0; i < nr_outputs; ++i) {
auto* output_var = output_vars[i];
auto output_ptr = std::make_shared<DeviceTensorND>();
auto callback = [output_ptr](DeviceTensorND output){
output_ptr->reset(output.storage(), output.layout());
};
output_spec.push_back({output_var, callback});
cg_holder.outputs.push_back(output_ptr);
}
cg_holder.executable = cg_holder.graph->compile(output_spec);
}
return cg_holder;
}

auto apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> input_descs;
for (auto&& input: inputs) {
input_descs.push_back({input->layout(), input->comp_node()});
}
size_t nr_inputs = inputs.size();
auto shared_def = const_cast<OpDef&>(def).shared_from_this();
auto& cg_holder = get_computing_graph(shared_def, input_descs);
for (size_t i = 0; i < nr_inputs; ++i) {
auto input_dev_tensor = inputs[i]->dev_tensor();
cg_holder.inputs[i]->reset(input_dev_tensor.storage(), input_dev_tensor.layout());
}
cg_holder.allocator->current_op = shared_def;
cg_holder.executable->execute();
cg_holder.executable->wait();
SmallVector<TensorPtr> outputs;
for (auto input_nd: cg_holder.inputs) {
*input_nd = {};
}
for (auto output_nd: cg_holder.outputs) {
outputs.push_back(Tensor::make(*output_nd));
*output_nd = {};
}
cg_holder.executable->clear_device_memory();
cg_holder.allocator->current_op = nullptr;
return outputs;
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
}

auto infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& input_descs) {
return OpDef::infer_output_attrs_fallible(*def.cast_final_safe<CompiledOp>().op, input_descs);
}

auto props(const OpDef& def) {
return OpDef::props(*def.cast_final_safe<CompiledOp>().op);
}

auto make_name(const OpDef& def) {
auto& op = def.cast_final_safe<CompiledOp>();
MGB_MARK_USED_VAR(op);
return ssprintf("CompiledOp[%s]", op.op->make_name().c_str());
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {};
}

EncodedSubraph make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
auto& op = def.cast_final_safe<CompiledOp>();
auto backward_graph = OpDef::make_backward_graph(*op.op, inputs, input_requires_grad, output_has_grad);
auto name = def.trait()->make_name(def);
auto key = std::make_shared<BackwardOpKey>();
key->op = op.op;
key->inputs = inputs;
key->extras = {input_requires_grad, output_has_grad};
SmallVector<bool> grad_outputs_has_grad(backward_graph.graph.outputs.size(), true);
std::shared_ptr<OpDef> bgraph_op;
if (backward_graph.graph.is_single()) {
bgraph_op = backward_graph.graph.as_single();
} else {
bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
}
auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
return encoded_graph;
}

auto hash(const OpDef& def) {
auto& op = def.cast_final_safe<CompiledOp>();
return mgb::hash_pair_combine(op.op->hash(), op.gopt_level);
}

auto is_same_st(const OpDef& def, const OpDef& another) {
if (!another.same_type<CompiledOp>()) {
return false;
}
auto& lhs = def.cast_final_safe<CompiledOp>();
auto& rhs = another.cast_final_safe<CompiledOp>();
return lhs.op->is_same(*rhs.op) && lhs.gopt_level == rhs.gopt_level;
}

OP_TRAIT_REG(CompiledOp, CompiledOp)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.make_backward_graph(make_backward_graph)
.make_name(make_name)
.infer_output_mem_desc(infer_output_mem_desc)
.props(props)
.hash(hash)
.is_same_st(is_same_st)
.fallback();
}} // compiled_op

MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp);

} // namespace mgb::imperative
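
CompiledOp builds and compiles a small ComputingGraph once per distinct (op, input descriptors) key and replays the cached executable afterwards, feeding inputs through InputCallback and copying outputs out via the output-spec callbacks. A standalone model of that compile-on-first-use cache (the real code keys a thread_local OpMethResultCache by op and LogicalTensorDescs; a string key and a Compiled struct stand in here):

#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for the compiled artifact (graph + executable + pinned input/output slots).
struct Compiled {
    std::function<int(int)> run;  // models executable->execute()/wait() plus output readback
};

// thread_local, like cg_cache above: each worker thread keeps its own cache,
// so no locking is needed and compiled graphs are never shared across threads.
thread_local std::unordered_map<std::string, Compiled> compiled_cache;

int apply_compiled(const std::string& key, int input,
                   const std::function<Compiled()>& compile) {
    auto it = compiled_cache.find(key);
    if (it == compiled_cache.end()) {
        // First call with this (op, input-desc) signature: compile and remember.
        it = compiled_cache.emplace(key, compile()).first;
    }
    // Later calls with the same signature reuse the cached executable.
    return it->second.run(input);
}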

imperative/src/include/megbrain/imperative/op_def.h (+4, -0)

@@ -36,6 +36,7 @@ class OpDef : public Hashable,
mutable const OpTrait* m_trait = nullptr;
std::string m_scope;
public:
using allocator_t = std::function<DeviceTensorStorage::RawStorage(CompNode, size_t)>;
virtual ~OpDef() = default;

static std::shared_ptr<OpDef> make_from_op_node(
@@ -112,6 +113,9 @@ public:
virtual size_t hash() const;

virtual bool is_same_st(const Hashable&) const;

static void set_allocator(allocator_t allocator);
DeviceTensorStorage::RawStorage allocate(CompNode, size_t) const;
};

template<typename T>


imperative/src/include/megbrain/imperative/ops/utility.h (+51, -0)

@@ -12,6 +12,7 @@
#pragma once

#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/graph_cache.h"

#include "megbrain/utils/hash.h"

@@ -35,4 +36,54 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
std::shared_ptr<OpDef> op;
SmallVector<CompNode> devices;
SmallVector<DType> dtypes;
EncodedSubraph graph;
ShapeInfer() = default;
ShapeInfer(std::shared_ptr<OpDef> op, SmallVector<CompNode> devices,
SmallVector<DType> dtypes)
: op{op}, devices{devices}, dtypes{dtypes}{}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
std::string name;
Subgraph graph;
SmallVector<bool> output_grad_mask;
std::shared_ptr<Hashable> graph_key;
SubgraphOp() = default;
SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
: name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
if (this->output_grad_mask.empty()) {
this->output_grad_mask.resize(graph.outputs.size(), true);
}
}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

struct BackwardOpKey final: Hashable, OpMethArgs<SmallVector<bool>, SmallVector<bool>> {
public:
using OpMethArgs<SmallVector<bool>, SmallVector<bool>>::OpMethArgs;
size_t hash() const override {
return OpMethArgs<SmallVector<bool>, SmallVector<bool>>::hash();
}
protected:
bool is_same_st(const Hashable& rhs) const override {
return OpMethArgs<SmallVector<bool>, SmallVector<bool>>::
operator==(rhs.cast_final_safe<BackwardOpKey>());
}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

struct CompiledOp final: OpDefImplBase<CompiledOp> {
std::shared_ptr<OpDef> op;
int gopt_level;
CompiledOp() = default;
CompiledOp(std::shared_ptr<OpDef> op, int gopt_level = 2)
: op{op}, gopt_level{gopt_level}{}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

} // namespace mgb::imperative
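
The new op types compose: a Subgraph can be wrapped in a SubgraphOp and handed to CompiledOp to be executed through a cached static graph. A hedged construction sketch using the make helpers these OpDefImplBase subclasses inherit (the prebuilt_graph argument and the function name are placeholders, not part of this commit):

#include <memory>
#include <utility>

#include "megbrain/imperative/ops/utility.h"

std::shared_ptr<mgb::imperative::OpDef> build_compiled_subgraph(
        mgb::imperative::Subgraph prebuilt_graph) {
    using namespace mgb::imperative;
    // Wrap the subgraph; output_grad_mask defaults to all-true and graph_key to nullptr.
    auto sub = SubgraphOp::make("MySubgraph", std::move(prebuilt_graph));
    // Compile it with graph_opt_level 2, the default used elsewhere in this diff.
    return CompiledOp::make(sub, /*gopt_level=*/2);
}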
