Browse Source

refactor(profiler): use macro to simplify event recording/definition

GitOrigin-RevId: 4d9dcfd6c2
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
5b5a826166
3 changed files with 101 additions and 146 deletions
  1. +40
    -57
      imperative/src/impl/interpreter/events.h
  2. +51
    -77
      imperative/src/impl/interpreter/interpreter_impl.cpp
  3. +10
    -12
      imperative/src/impl/interpreter/profiler.h

+ 40
- 57
imperative/src/impl/interpreter/events.h View File

@@ -16,77 +16,60 @@


namespace mgb::imperative::interpreter::intl { namespace mgb::imperative::interpreter::intl {


struct CommandEvent {
IdentifiedCommand icmd;
};

struct CommandEnqueueEvent: CommandEvent {};

struct CommandExecuteEvent: CommandEvent {};
#define DEF_EVENT(X, ...) struct X##Event __VA_ARGS__;
#define DEF_DUR_EVENT(X, ...) struct X##Event __VA_ARGS__; struct X##FinishEvent __VA_ARGS__;


struct CommandFinishEvent: CommandEvent {};
DEF_EVENT(Command, {
IdentifiedCommand icmd;
});


struct OpEvent {
DEF_EVENT(CommandEnqueue, :CommandEvent);
DEF_EVENT(CommandExecute, :CommandEvent);
DEF_EVENT(CommandFinish, :CommandEvent);
DEF_DUR_EVENT(OpExecute, {
uint64_t id; uint64_t id;
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
SmallVector<uint64_t> inputs; SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs; SmallVector<uint64_t> outputs;
};

struct HostOpExecuteEvent: OpEvent {};

struct DeviceOpExecuteEvent: OpEvent {};

struct HostOpFinishEvent: OpEvent {};

struct DeviceOpFinishEvent: OpEvent {};

struct TensorDeclareEvent {
});
DEF_DUR_EVENT(KernelExecute, {
uint64_t id;
std::shared_ptr<OpDef> op;
SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs;
});
DEF_EVENT(TensorDeclare, {
uint64_t tensor_id; uint64_t tensor_id;
};

struct TensorProduceEvent {
});
DEF_EVENT(TensorProduce, {
uint64_t tensor_id; uint64_t tensor_id;
TensorLayout layout; TensorLayout layout;
CompNode device; CompNode device;
};

struct TensorEraseEvent {
});
DEF_EVENT(TensorErase, {
uint64_t tensor_id; uint64_t tensor_id;
};

struct TensorPropEvent {
});
DEF_EVENT(TensorGetProp, {
uint64_t tensor_id; uint64_t tensor_id;
TensorInfo::Prop prop; TensorInfo::Prop prop;
std::string prop_desc; std::string prop_desc;
};

struct TensorGetPropEvent: TensorPropEvent{};

struct TensorWaitPropEvent: TensorPropEvent{};

struct TensorNotifyPropEvent: TensorPropEvent{};

struct TensorWaitPropFinishEvent: TensorPropEvent{};

struct SyncStartEvent {};

struct SyncFinishEvent {};

struct ScopeEvent {
});
DEF_DUR_EVENT(TensorWaitProp, {
uint64_t tensor_id;
TensorInfo::Prop prop;
std::string prop_desc;
});
DEF_EVENT(TensorNotifyProp, {
uint64_t tensor_id;
TensorInfo::Prop prop;
std::string prop_desc;
});
DEF_DUR_EVENT(Sync, {});
DEF_DUR_EVENT(Scope, {
std::string name; std::string name;
};

struct ChannelBeginScope: ScopeEvent {};

struct ChannelEndScope: ScopeEvent {};

struct WorkerBeginScope: ScopeEvent {};

struct WorkerEndScope: ScopeEvent {};

struct DeviceBeginScope: ScopeEvent {};

struct DeviceEndScope: ScopeEvent {};
});
DEF_DUR_EVENT(DeviceScope, {
std::string name;
});


} }

+ 51
- 77
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -23,6 +23,17 @@ using namespace imperative;
using namespace interpreter; using namespace interpreter;
using namespace interpreter::intl; using namespace interpreter::intl;


#define RECORD_EVENT(type, ...) \
if (state.profiler->is_profiling()) { \
state.profiler->record_host<type>(type{__VA_ARGS__}); \
} \

#define RECORD_DEVICE_EVENT(type, device, ...) \
if (state.profiler->is_profiling()) { \
state.profiler->record_device<type>((device), type{__VA_ARGS__}); \
} \


std::thread::id ChannelImpl::get_worker_tid() { std::thread::id ChannelImpl::get_worker_tid() {
return m_worker_state.tid; return m_worker_state.tid;
} }
@@ -71,9 +82,7 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info->desc.layout = data.layout(); info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node(); info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data); info->ptr = Tensor::make(data);
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
}
RECORD_EVENT(TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node);
return info; return info;
} }


@@ -168,10 +177,8 @@ void ChannelImpl::dispatch_default_cpu(
} }
return tid; return tid;
}; };
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
if (state.profiler->is_profiling()) {
state.profiler->record_host<HostOpExecuteEvent>(event_data);
}
auto apply_id = ++m_last_id;
RECORD_EVENT(OpExecuteEvent, apply_id, op, tinfo_to_tid(input_infos), {});


OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);


@@ -187,10 +194,8 @@ void ChannelImpl::dispatch_default_cpu(
outputs->push_back(info); outputs->push_back(info);
} }


event_data.outputs = tinfo_to_tid(output_infos);
if (state.profiler->is_profiling()) {
state.profiler->record_host<HostOpFinishEvent>(event_data);
}
RECORD_EVENT(OpExecuteFinishEvent, apply_id, op,
tinfo_to_tid(input_infos), tinfo_to_tid(output_infos));
} }


void ChannelImpl::dispatch_kernel( void ChannelImpl::dispatch_kernel(
@@ -287,17 +292,13 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
if (!value_fetched()) { if (!value_fetched()) {
m_waitee = info; m_waitee = info;
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
}
RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::HostValue);
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
tensor_ptr = info->ptr; tensor_ptr = info->ptr;
return value_fetched(); return value_fetched();
}); });
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
}
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::HostValue);
m_waitee = nullptr; m_waitee = nullptr;
} }
return tensor_ptr->get_value(); return tensor_ptr->get_value();
@@ -316,16 +317,12 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(!m_waitee); mgb_assert(!m_waitee);
m_waitee = info; m_waitee = info;
m_buffer.flush(); m_buffer.flush();
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
}
RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::Shape);
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
return static_cast<bool>(info->ptr); return static_cast<bool>(info->ptr);
}); });
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
}
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::Shape);
m_waitee = nullptr; m_waitee = nullptr;
TensorShape ret = info->ptr->layout(); TensorShape ret = info->ptr->layout();
mgb_assert(ret.ndim != 0); mgb_assert(ret.ndim != 0);
@@ -338,9 +335,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
}
RECORD_EVENT(TensorGetPropEvent, info->id, TensorInfo::DType);
auto ret = info->desc.layout.dtype; auto ret = info->desc.layout.dtype;
mgb_assert(ret.valid()); mgb_assert(ret.valid());
return ret; return ret;
@@ -352,9 +347,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
}
RECORD_EVENT(TensorGetPropEvent, info->id, TensorInfo::Device);
auto ret = info->desc.comp_node; auto ret = info->desc.comp_node;
mgb_assert(ret.valid()); mgb_assert(ret.valid());
return ret; return ret;
@@ -370,16 +363,12 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(!m_waitee); mgb_assert(!m_waitee);
m_waitee = info; m_waitee = info;
m_buffer.flush(); m_buffer.flush();
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
}
RECORD_EVENT(TensorWaitPropEvent, info->id, TensorInfo::DevValue);
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
return static_cast<bool>(info->ptr); return static_cast<bool>(info->ptr);
}); });
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
}
RECORD_EVENT(TensorWaitPropFinishEvent, info->id, TensorInfo::DevValue);
m_waitee = nullptr; m_waitee = nullptr;
return info->ptr->dev_tensor(); return info->ptr->dev_tensor();
} }
@@ -388,14 +377,10 @@ void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state(); auto& state = get_channel_state();
m_buffer.flush(); m_buffer.flush();
if (state.profiler->is_profiling()) {
state.profiler->record_host<SyncStartEvent>();
}
RECORD_EVENT(SyncEvent);
m_worker.wait_all_task_finish(); m_worker.wait_all_task_finish();
CompNode::sync_all(); CompNode::sync_all();
if (state.profiler->is_profiling()) {
state.profiler->record_host<SyncFinishEvent>();
}
RECORD_EVENT(SyncFinishEvent);
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
check_worker_exc_unsafe(); check_worker_exc_unsafe();
} }
@@ -433,9 +418,7 @@ TensorInfo* ChannelImpl::alloc() {
auto info = m_pool.alloc(); auto info = m_pool.alloc();
m_valid_handle.insert(info); m_valid_handle.insert(info);
info->id = m_last_id++; info->id = m_last_id++;
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorDeclareEvent>(info->id);
}
RECORD_EVENT(TensorDeclareEvent, info->id);
return info; return info;
} }


@@ -491,9 +474,7 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
void ChannelImpl::real_free(TensorInfo* ptr) { void ChannelImpl::real_free(TensorInfo* ptr) {
auto& state = get_worker_state(); auto& state = get_worker_state();
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
if (state.profiler->is_profiling()) {
state.profiler->record_host<TensorEraseEvent>(ptr->id);
}
RECORD_EVENT(TensorEraseEvent, ptr->id);
if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) { if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(ptr); m_dtr.erase_candidate(ptr);
} }
@@ -515,8 +496,8 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
lock.lock(); lock.lock();
} }
m_dtr.update_used_time(dest); m_dtr.update_used_time(dest);
if (notice && state.profiler->is_profiling()) {
state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
if (notice) {
RECORD_EVENT(TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node());
} }
dest->value_fetched = ptr->value_fetched(); dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer // update tensor desc for static infer
@@ -636,10 +617,10 @@ void ChannelImpl::sync_device_scope(CompNode device) {
auto& prev = state.device_scope_map[device]; auto& prev = state.device_scope_map[device];
auto& current = state.scopes; auto& current = state.scopes;
auto push_scope = [&](std::string name) { auto push_scope = [&](std::string name) {
state.profiler->record_device<DeviceBeginScope>(device, name);
RECORD_DEVICE_EVENT(DeviceScopeEvent, device, name);
}; };
auto pop_scope = [&](std::string name) { auto pop_scope = [&](std::string name) {
state.profiler->record_device<DeviceEndScope>(device, name);
RECORD_DEVICE_EVENT(DeviceScopeFinishEvent, device, name);
}; };
size_t similarity = 0; size_t similarity = 0;
for (size_t i = 0; i < prev.size() && i < current.size(); i++) { for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
@@ -661,17 +642,13 @@ void ChannelImpl::sync_device_scope(CompNode device) {


void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
auto& state = get_worker_state(); auto& state = get_worker_state();
if (state.profiler->is_profiling()) {
state.profiler->record_host<CommandExecuteEvent>(icmd);
}
RECORD_EVENT(CommandExecuteEvent, icmd);
bool finished = false; bool finished = false;
auto do_finish_command = [&]{ auto do_finish_command = [&]{
if (finished) { if (finished) {
return; return;
} }
if (state.profiler->is_profiling()) {
state.profiler->record_host<CommandFinishEvent>(icmd);
}
RECORD_EVENT(CommandFinishEvent, icmd);
finished = true; finished = true;
}; };
//TODO: remove std::visit for support osx 10.12 //TODO: remove std::visit for support osx 10.12
@@ -701,16 +678,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
tensor_inputs.push_back(i->ptr); tensor_inputs.push_back(i->ptr);
} }
// Begin profiling operator // Begin profiling operator
OpEvent event_data;
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
if (state.profiler->is_profiling()) { if (state.profiler->is_profiling()) {
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
// Collecting devices // Collecting devices
for (auto i : cmd.inputs) { for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node); devices.push_back(i->desc.comp_node);
@@ -731,11 +706,12 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
// Before wait // Before wait
//TODO: split operator wait and execute so that OpWait could be corrected recorded. //TODO: split operator wait and execute so that OpWait could be corrected recorded.
// Before execute // Before execute
RECORD_EVENT(OpExecuteEvent, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
if (state.profiler->is_profiling()) { if (state.profiler->is_profiling()) {
state.profiler->record_host<HostOpExecuteEvent>(event_data);
for (auto&& device: devices) { for (auto&& device: devices) {
sync_device_scope(device); sync_device_scope(device);
state.profiler->record_device<DeviceOpExecuteEvent>(device, event_data);
RECORD_DEVICE_EVENT(KernelExecuteEvent, device, apply_id, cmd.op,
tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
} }
} }
if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) { if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
@@ -746,10 +722,10 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
auto tensor_outputs = OpDef::apply_on_physical_tensor( auto tensor_outputs = OpDef::apply_on_physical_tensor(
*cmd.op, std::move(tensor_inputs)); *cmd.op, std::move(tensor_inputs));
// After execute // After execute
RECORD_EVENT(OpExecuteFinishEvent, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
if (state.profiler->is_profiling()) { if (state.profiler->is_profiling()) {
state.profiler->record_host<HostOpFinishEvent>(event_data);
for (auto&& device: devices) { for (auto&& device: devices) {
state.profiler->record_device<DeviceOpFinishEvent>(device, event_data);
RECORD_DEVICE_EVENT(KernelExecuteFinishEvent, device, apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs));
} }
} }
// End profiling operator // End profiling operator
@@ -875,12 +851,12 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
} else if constexpr (std::is_same_v<T, PushScope>) { } else if constexpr (std::is_same_v<T, PushScope>) {
state.scopes.push_back(cmd.scope_name); state.scopes.push_back(cmd.scope_name);
do_finish_command(); do_finish_command();
state.profiler->record_host<WorkerBeginScope>(cmd.scope_name);
RECORD_EVENT(ScopeEvent, cmd.scope_name);
} else if constexpr (std::is_same_v<T, PopScope>) { } else if constexpr (std::is_same_v<T, PopScope>) {
mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch"); mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch");
state.scopes.pop_back(); state.scopes.pop_back();
do_finish_command(); do_finish_command();
state.profiler->record_host<WorkerEndScope>(cmd.scope_name);
RECORD_EVENT(ScopeFinishEvent, cmd.scope_name);
} else { } else {
static_assert(!std::is_same_v<T, T>); static_assert(!std::is_same_v<T, T>);
} }
@@ -938,9 +914,7 @@ void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) { for (auto iter = m_commands.begin(); iter != pos; ++iter) {
// mgb_log_debug("%s Flushed", to_string(*iter).c_str()); // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)}; IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
if (state.profiler->is_profiling()) {
state.profiler->record_host<CommandEnqueueEvent>(icmd);
}
RECORD_EVENT(CommandEnqueueEvent, icmd);
m_owner->m_worker.add_task(std::move(icmd)); m_owner->m_worker.add_task(std::move(icmd));
} }
m_commands.erase(m_commands.begin(), pos); m_commands.erase(m_commands.begin(), pos);
@@ -1060,8 +1034,8 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
void ChannelImpl::push_scope(std::string name) { void ChannelImpl::push_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state(); auto& state = get_channel_state();
RECORD_EVENT(ScopeEvent, name);
if (state.profiler->is_profiling()) { if (state.profiler->is_profiling()) {
state.profiler->record_host<ChannelBeginScope>(name);
state.scopes.push_back(name); state.scopes.push_back(name);
m_buffer.enqueue(PushScope{name}); m_buffer.enqueue(PushScope{name});
} }
@@ -1070,10 +1044,10 @@ void ChannelImpl::push_scope(std::string name) {
void ChannelImpl::pop_scope(std::string name) { void ChannelImpl::pop_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed"); mgb_assert(check_available(), "Channel already closed");
auto& state = get_channel_state(); auto& state = get_channel_state();
RECORD_EVENT(ScopeFinishEvent, name);
if (state.profiler->is_profiling()) { if (state.profiler->is_profiling()) {
mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch"); mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch");
state.scopes.pop_back(); state.scopes.pop_back();
state.profiler->record_host<ChannelEndScope>(name);
m_buffer.enqueue(PopScope{name}); m_buffer.enqueue(PopScope{name});
} }
} }


+ 10
- 12
imperative/src/impl/interpreter/profiler.h View File

@@ -21,15 +21,13 @@ namespace mgb::imperative::interpreter::intl {


class InterpreterProfiler: public Profiler< class InterpreterProfiler: public Profiler<
CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent, CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent,
HostOpExecuteEvent, HostOpFinishEvent,
DeviceOpExecuteEvent, DeviceOpFinishEvent,
OpExecuteEvent, OpExecuteFinishEvent,
KernelExecuteEvent, KernelExecuteFinishEvent,
TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent, TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent,
TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent, TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent,
SyncStartEvent, SyncFinishEvent,
ChannelBeginScope, ChannelEndScope,
WorkerBeginScope, WorkerEndScope,
DeviceBeginScope, DeviceEndScope> {
/*22 events now. Enum code may be a better solution*/
SyncEvent, SyncFinishEvent,
ScopeEvent, ScopeFinishEvent,
DeviceScopeEvent, DeviceScopeFinishEvent> {


public: public:
enum Topic { enum Topic {
@@ -71,8 +69,8 @@ public:
result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>(); result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>();
} }
if (topic & Operator) { if (topic & Operator) {
result |= mask_of<HostOpExecuteEvent, HostOpFinishEvent>();
result |= mask_of<DeviceOpExecuteEvent, DeviceOpFinishEvent>();
result |= mask_of<OpExecuteEvent, OpExecuteFinishEvent>();
result |= mask_of<KernelExecuteEvent, KernelExecuteFinishEvent>();
} }
if (topic & TensorLifetime) { if (topic & TensorLifetime) {
result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>(); result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>();
@@ -81,11 +79,11 @@ public:
result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>(); result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>();
} }
if (topic & Sync) { if (topic & Sync) {
result |= mask_of<SyncStartEvent, SyncFinishEvent>();
result |= mask_of<SyncEvent, SyncFinishEvent>();
} }
if (topic & Scope) { if (topic & Scope) {
result |= mask_of<ChannelBeginScope, ChannelEndScope, WorkerBeginScope, WorkerEndScope>();
result |= mask_of<DeviceBeginScope, DeviceEndScope>();
result |= mask_of<ScopeEvent, ScopeFinishEvent>();
result |= mask_of<DeviceScopeEvent, DeviceScopeFinishEvent>();
} }
return result; return result;
} }


Loading…
Cancel
Save