|
|
@@ -6,6 +6,7 @@ |
|
|
|
#include <unordered_set> |
|
|
|
#include "megbrain/custom/op.h" |
|
|
|
#include "megbrain/custom/utils.h" |
|
|
|
#include "megbrain/utils/thin/function.h" |
|
|
|
|
|
|
|
using namespace mgb; |
|
|
|
|
|
|
@@ -99,40 +100,6 @@ std::string ArgInfo::str() const { |
|
|
|
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \ |
|
|
|
static_cast<int>((real_shape).ndim())) |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
class Function; |
|
|
|
|
|
|
|
template <typename RType, typename... Args> |
|
|
|
class Function<RType(Args...)> { |
|
|
|
public: |
|
|
|
using Functor = RType (*)(Args...); |
|
|
|
|
|
|
|
Function() = default; |
|
|
|
Function(Functor f) : m_f(f) {} |
|
|
|
Function(const Function& rhs) { m_f = rhs.m_f; } |
|
|
|
|
|
|
|
RType operator()(Args... args) { |
|
|
|
custom_assert(m_f != nullptr, "invalid function ptr\n"); |
|
|
|
return m_f(std::forward<Args>(args)...); |
|
|
|
} |
|
|
|
|
|
|
|
void operator=(const Function& rhs) { // not allowed continuous assignment |
|
|
|
m_f = rhs.m_f; |
|
|
|
} |
|
|
|
|
|
|
|
void operator=(const Functor f) { m_f = f; } |
|
|
|
|
|
|
|
private: |
|
|
|
Functor m_f = nullptr; |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename Functions> |
|
|
|
class FuncWithSig : public Functions { |
|
|
|
public: |
|
|
|
using Functions::operator(); |
|
|
|
using Functions::operator=; |
|
|
|
}; |
|
|
|
|
|
|
|
class CustomOpImpl { |
|
|
|
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; |
|
|
|
const uint32_t m_version; |
|
|
@@ -143,29 +110,26 @@ class CustomOpImpl { |
|
|
|
std::vector<ArgInfo> m_output_infos; |
|
|
|
ParamInfo m_param_infos; |
|
|
|
|
|
|
|
using DeviceInfer = FuncWithSig<Function<void( |
|
|
|
const std::vector<Device>&, const Param&, std::vector<Device>&)>>; |
|
|
|
using ShapeInfer = FuncWithSig<Function<void( |
|
|
|
const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>; |
|
|
|
using DTypeInfer = FuncWithSig<Function<void( |
|
|
|
const std::vector<DType>&, const Param&, std::vector<DType>&)>>; |
|
|
|
using FormatInfer = FuncWithSig<Function<void( |
|
|
|
const std::vector<Format>&, const Param&, std::vector<Format>&)>>; |
|
|
|
using Preprocess = FuncWithSig<Function<void( |
|
|
|
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; |
|
|
|
using Postprocess = FuncWithSig<Function<void( |
|
|
|
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; |
|
|
|
using Compute = FuncWithSig<Function<void( |
|
|
|
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; |
|
|
|
using DeviceInfer = thin_function<void( |
|
|
|
const std::vector<Device>&, const Param&, std::vector<Device>&)>; |
|
|
|
using ShapeInfer = thin_function<void( |
|
|
|
const std::vector<Shape>&, const Param&, std::vector<Shape>&)>; |
|
|
|
using DTypeInfer = thin_function<void( |
|
|
|
const std::vector<DType>&, const Param&, std::vector<DType>&)>; |
|
|
|
using FormatInfer = thin_function<void( |
|
|
|
const std::vector<Format>&, const Param&, std::vector<Format>&)>; |
|
|
|
using Process = thin_function<void( |
|
|
|
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&, |
|
|
|
const RuntimeArgs&)>; |
|
|
|
|
|
|
|
DeviceInfer infer_output_device_func; |
|
|
|
ShapeInfer infer_output_shape_func; |
|
|
|
DTypeInfer infer_output_dtype_func; |
|
|
|
FormatInfer infer_output_format_func; |
|
|
|
|
|
|
|
std::unordered_map<std::string, Compute> compute_funcs; |
|
|
|
std::unordered_map<std::string, Preprocess> preprocess_funcs; |
|
|
|
std::unordered_map<std::string, Postprocess> postprocess_funcs; |
|
|
|
std::unordered_map<std::string, Process> compute_funcs; |
|
|
|
std::unordered_map<std::string, Process> preprocess_funcs; |
|
|
|
std::unordered_map<std::string, Process> postprocess_funcs; |
|
|
|
|
|
|
|
public: |
|
|
|
CustomOpImpl(const std::string&, uint32_t version); |
|
|
@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) |
|
|
|
|
|
|
|
for (const auto& device : Device::legal_devices()) { |
|
|
|
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, |
|
|
|
std::vector<Tensor>& outputs) -> void { |
|
|
|
std::vector<Tensor>& outputs, |
|
|
|
const RuntimeArgs&) -> void { |
|
|
|
auto device = outputs[0].device(); |
|
|
|
mgb_assert( |
|
|
|
false, |
|
|
@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) |
|
|
|
device.str().c_str()); |
|
|
|
}; |
|
|
|
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, |
|
|
|
std::vector<Tensor>&) -> void { return; }; |
|
|
|
std::vector<Tensor>&, |
|
|
|
const RuntimeArgs&) -> void { return; }; |
|
|
|
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, |
|
|
|
std::vector<Tensor>&) -> void { return; }; |
|
|
|
std::vector<Tensor>&, |
|
|
|
const RuntimeArgs&) -> void { return; }; |
|
|
|
} |
|
|
|
m_param_infos.set_tag(op_type); |
|
|
|
} |
|
|
@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) { |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_preprocess(PreprocessFuncPtr func) { |
|
|
|
CustomOp& CustomOp::set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
set_preprocess("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_preprocess( |
|
|
|
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param, |
|
|
|
std::vector<Tensor>& output, const RuntimeArgs&) -> void { |
|
|
|
return func(input, param, output); |
|
|
|
}; |
|
|
|
|
|
|
|
OpImplRef(m_impl.get())->preprocess_funcs[device] = wrap_func; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_preprocess(ProcessFuncPtr func) { |
|
|
|
set_preprocess("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_preprocess(const std::string& device, PreprocessFuncPtr func) { |
|
|
|
CustomOp& CustomOp::set_preprocess(const std::string& device, ProcessFuncPtr func) { |
|
|
|
OpImplRef(m_impl.get())->preprocess_funcs[device] = func; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_postprocess(PostprocessFuncPtr func) { |
|
|
|
CustomOp& CustomOp::set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
set_postprocess("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_postprocess( |
|
|
|
const std::string& device, PostprocessFuncPtr func) { |
|
|
|
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param, |
|
|
|
std::vector<Tensor>& output, |
|
|
|
const RuntimeArgs&) -> void { func(input, param, output); }; |
|
|
|
|
|
|
|
OpImplRef(m_impl.get())->postprocess_funcs[device] = wrap_func; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_postprocess(ProcessFuncPtr func) { |
|
|
|
set_postprocess("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_postprocess(const std::string& device, ProcessFuncPtr func) { |
|
|
|
OpImplRef(m_impl.get())->postprocess_funcs[device] = func; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_compute(ComputeFuncPtr func) { |
|
|
|
CustomOp& CustomOp::set_compute(ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
set_compute("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_compute( |
|
|
|
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { |
|
|
|
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param, |
|
|
|
std::vector<Tensor>& output, |
|
|
|
const RuntimeArgs&) -> void { func(input, param, output); }; |
|
|
|
|
|
|
|
OpImplRef(m_impl.get())->compute_funcs[device] = wrap_func; |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_compute(ProcessFuncPtr func) { |
|
|
|
set_compute("x86", func); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
CustomOp& CustomOp::set_compute(const std::string& device, ComputeFuncPtr func) { |
|
|
|
CustomOp& CustomOp::set_compute(const std::string& device, ProcessFuncPtr func) { |
|
|
|
OpImplRef(m_impl.get())->compute_funcs[device] = func; |
|
|
|
return *this; |
|
|
|
} |
|
|
@@ -513,23 +525,28 @@ void CustomOp::compute( |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
std::string device = outputs[0].device().str(); |
|
|
|
Device device = outputs[0].device(); |
|
|
|
std::string device_str = device.str(); |
|
|
|
for (size_t i = 1; i < outputs.size(); ++i) { |
|
|
|
mgb_assert( |
|
|
|
outputs[i].device().str() == device, |
|
|
|
outputs[i].device().str() == device_str, |
|
|
|
"all output tensors should have the same device attribute"); |
|
|
|
} |
|
|
|
|
|
|
|
// need to add other input/output check |
|
|
|
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str()); |
|
|
|
mgb_assert( |
|
|
|
Device::is_legal(device_str), "unsupported device type: %s", |
|
|
|
device_str.c_str()); |
|
|
|
|
|
|
|
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device_str]; |
|
|
|
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device_str]; |
|
|
|
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device_str]; |
|
|
|
|
|
|
|
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device]; |
|
|
|
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device]; |
|
|
|
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device]; |
|
|
|
RuntimeArgs rt_args(device); |
|
|
|
|
|
|
|
preprocess_func(inputs, param, outputs); |
|
|
|
forward_func(inputs, param, outputs); |
|
|
|
postprocess_func(outputs, param, outputs); |
|
|
|
preprocess_func(inputs, param, outputs, rt_args); |
|
|
|
forward_func(inputs, param, outputs, rt_args); |
|
|
|
postprocess_func(outputs, param, outputs, rt_args); |
|
|
|
assert_outputs_size_right(outputs); |
|
|
|
} |
|
|
|
|
|
|
|