|
|
@@ -339,16 +339,17 @@ void ChannelImpl::dispatch_kernel( |
|
|
|
auto& state = get_channel_state(); |
|
|
|
auto& options = state.options; |
|
|
|
|
|
|
|
auto name = op->trait()->make_name(*op); |
|
|
|
auto _ = StackManager::Guard{name, &state.stack_manager}; |
|
|
|
|
|
|
|
auto [output_descs, validated] = |
|
|
|
OpDef::infer_output_attrs_fallible(*op, input_descs); |
|
|
|
MGB_RECORD_EVENT(ShapeInferEvent, validated); |
|
|
|
|
|
|
|
SmallVector<TensorInfo*> output_infos; |
|
|
|
output_infos.reserve(output_descs.size()); |
|
|
|
uint64_t apply_id = Profiler::next_id(); |
|
|
|
|
|
|
|
outputs->reserve(output_descs.size()); |
|
|
|
|
|
|
|
for (int i = 0; i < output_descs.size(); ++i) { |
|
|
|
auto&& desc = output_descs[i]; |
|
|
|
auto info = alloc(); |
|
|
@@ -361,31 +362,28 @@ void ChannelImpl::dispatch_kernel( |
|
|
|
output_infos.push_back(info); |
|
|
|
outputs->push_back(reinterpret_cast<Handle>(info)); |
|
|
|
} |
|
|
|
auto op_info_getter = [op] { |
|
|
|
std::unordered_map<std::string, std::string> op_info; |
|
|
|
auto props = OpDef::props(*op); |
|
|
|
for (auto&& [key, value] : props) { |
|
|
|
op_info[key] = value; |
|
|
|
} |
|
|
|
return op_info; |
|
|
|
}; |
|
|
|
ApplyOp cmd{ |
|
|
|
Profiler::next_id(), std::move(op), std::move(input_infos), |
|
|
|
std::move(output_infos), validated}; |
|
|
|
if (Profiler::is_profiling()) { |
|
|
|
auto name = op->trait()->make_name(*op); |
|
|
|
auto _ = StackManager::Guard{name, &state.stack_manager}; |
|
|
|
auto op_info_getter = [op = cmd.op] { |
|
|
|
std::unordered_map<std::string, std::string> op_info; |
|
|
|
auto props = OpDef::props(*op); |
|
|
|
for (auto&& [key, value] : props) { |
|
|
|
op_info[key] = value; |
|
|
|
} |
|
|
|
return op_info; |
|
|
|
}; |
|
|
|
MGB_RECORD_EVENT( |
|
|
|
OpDispatchEvent, apply_id, name, op_info_getter, |
|
|
|
tinfo_to_tid(std::move(input_infos)), |
|
|
|
tinfo_to_tid(std::move(output_infos)), state.stack_manager.dump()); |
|
|
|
OpDispatchEvent, cmd.id, name, op_info_getter, tinfo_to_tid(cmd.inputs), |
|
|
|
tinfo_to_tid(cmd.outputs), state.stack_manager.dump()); |
|
|
|
m_worker.add_task( |
|
|
|
{Profiler::next_id(), |
|
|
|
ApplyOp{apply_id, std::move(op), std::move(input_infos), |
|
|
|
std::move(output_infos), validated}, |
|
|
|
{Profiler::next_id(), std::move(cmd), |
|
|
|
get_channel_state().stack_manager.dump()}); |
|
|
|
} else { |
|
|
|
m_worker.add_task({ |
|
|
|
Profiler::next_id(), |
|
|
|
ApplyOp{apply_id, std::move(op), std::move(input_infos), |
|
|
|
std::move(output_infos), validated}, |
|
|
|
std::move(cmd), |
|
|
|
}); |
|
|
|
} |
|
|
|
if (!validated && options.async_level == 1) { |
|
|
|