From b3e958d0bc53bd0f97e2131d1944f2d2d3af5cee Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 9 Sep 2021 13:24:23 +0800 Subject: [PATCH] fix(src): fix the warnings and copy.bara.sky in custom op GitOrigin-RevId: 4ade45589798cce48286ddb0f3a4298f7e03cd49 --- imperative/python/src/ops.cpp | 26 +- imperative/src/impl/ops/custom_opdef.cpp | 6 +- src/custom/impl/manager.cpp | 181 ++++++++ src/custom/impl/op.cpp | 531 ++++++++++++++++++++++ src/custom/impl/param.cpp | 179 ++++++++ src/custom/impl/param_val.cpp | 400 ++++++++++++++++ src/custom/impl/tensor.cpp | 486 ++++++++++++++++++++ src/custom/impl/utils.cpp | 41 ++ src/custom/include/megbrain/custom/accessor.h | 185 ++++++++ src/custom/include/megbrain/custom/custom.h | 108 +++++ src/custom/include/megbrain/custom/data_adaptor.h | 58 +++ src/custom/include/megbrain/custom/manager.h | 75 +++ src/custom/include/megbrain/custom/op.h | 109 +++++ src/custom/include/megbrain/custom/param.h | 61 +++ src/custom/include/megbrain/custom/param_val.h | 290 ++++++++++++ src/custom/include/megbrain/custom/tensor.h | 280 ++++++++++++ src/custom/include/megbrain/custom/utils.h | 104 +++++ src/custom/test/manager.cpp | 96 ++++ src/custom/test/op.cpp | 205 +++++++++ src/custom/test/param.cpp | 208 +++++++++ src/custom/test/tensor.cpp | 325 +++++++++++++ src/opr/impl/custom_opnode.cpp | 24 +- src/opr/include/megbrain/opr/custom_opnode.h | 4 +- test/CMakeLists.txt | 2 +- 24 files changed, 3959 insertions(+), 25 deletions(-) create mode 100644 src/custom/impl/manager.cpp create mode 100644 src/custom/impl/op.cpp create mode 100644 src/custom/impl/param.cpp create mode 100644 src/custom/impl/param_val.cpp create mode 100644 src/custom/impl/tensor.cpp create mode 100644 src/custom/impl/utils.cpp create mode 100644 src/custom/include/megbrain/custom/accessor.h create mode 100644 src/custom/include/megbrain/custom/custom.h create mode 100644 src/custom/include/megbrain/custom/data_adaptor.h create mode 100644 src/custom/include/megbrain/custom/manager.h create mode 100644 src/custom/include/megbrain/custom/op.h create mode 100644 src/custom/include/megbrain/custom/param.h create mode 100644 src/custom/include/megbrain/custom/param_val.h create mode 100644 src/custom/include/megbrain/custom/tensor.h create mode 100644 src/custom/include/megbrain/custom/utils.h create mode 100644 src/custom/test/manager.cpp create mode 100644 src/custom/test/op.cpp create mode 100644 src/custom/test/param.cpp create mode 100644 src/custom/test/tensor.cpp diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index f9e1fbde..60127c4f 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -613,17 +613,17 @@ void init_ops(py::module m) { } #define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \ - case mgb::custom::ParamDynType::dyn_type: { \ + case custom::ParamDynType::dyn_type: { \ param_val = py::handle(kv.second).cast(); \ break; \ } #define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \ - case mgb::custom::ParamDynType::dyn_type: { \ + case custom::ParamDynType::dyn_type: { \ auto pyvals = py::handle(kv.second).cast(); \ static_type vals; \ using basic_type = \ - mgb::custom::get_vector_template_arg_type::type; \ + custom::get_vector_template_arg_type::type; \ for (auto &pyval: pyvals) { \ vals.push_back(py::handle(pyval).cast()); \ } \ @@ -631,7 +631,7 @@ void init_ops(py::module m) { break; \ } -PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) { +PyObject 
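/* the kwnames parameter has been dropped: this handler now takes the
 * FASTCALL-style argument array directly; on Python 3.5 builds where
 * METH_FASTCALL is undefined, the py35_make_custom_op wrapper defined
 * below unpacks a METH_VARARGS tuple and forwards to this function. */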
*make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) { auto op_name = py::handle(args[0]).cast(); auto kwargs = py::handle(args[1]).cast(); @@ -680,7 +680,7 @@ PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyOb py::list install_custom(const std::string &name, const std::string &path) { py::list ret; - const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path); + const auto &ops_in_lib = custom::LibManager::inst()->install(name, path); for (const auto &op: ops_in_lib) { ret.append(op); } @@ -688,7 +688,7 @@ py::list install_custom(const std::string &name, const std::string &path) { } bool uninstall_custom(const std::string &name) { - return mgb::custom::LibManager::inst()->uninstall(name); + return custom::LibManager::inst()->uninstall(name); } py::list get_custom_op_list(void) { @@ -697,16 +697,28 @@ py::list get_custom_op_list(void) { for (auto &op: all_ops) { ret.append(op); } - return std::move(ret); + return ret; } +#ifndef METH_FASTCALL + PyObject* py35_make_custom_op(PyObject* self, PyObject* args) { + auto* arr = &PyTuple_GET_ITEM(args, 0); + auto size = PyTuple_GET_SIZE(args); + return make_custom_op(self, arr, size); + }; +#endif + void init_custom(pybind11::module m) { m.def("_install", &install_custom); m.def("_uninstall", &uninstall_custom); m.def("_get_custom_op_list", &get_custom_op_list); static PyMethodDef method_def = { +#ifdef METH_FASTCALL "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" +#else + "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, "" +#endif }; auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); pybind11::setattr(m, method_def.ml_name, func); diff --git a/imperative/src/impl/ops/custom_opdef.cpp b/imperative/src/impl/ops/custom_opdef.cpp index b178c07c..b59ff282 100644 --- a/imperative/src/impl/ops/custom_opdef.cpp +++ b/imperative/src/impl/ops/custom_opdef.cpp @@ -70,7 +70,7 @@ void CustomOpDef::compute(const SmallVector &inputs, std::tuple, bool> CustomOpDef::infer_output_attrs( const SmallVector &inputs) const { SmallVector input_descs(inputs.size()); - for (int i=0; icomp_node(); input_descs[i].layout = inputs[i]->layout(); } @@ -84,7 +84,7 @@ std::tuple, bool> CustomOpDef::infer_output_attrs SmallVector i_dtypes(inputs.size()); SmallVector i_formats(inputs.size()); - for (int i=0; i, bool> CustomOpDef::infer_output_attrs } SmallVector outputs(this->output_num()); - for (int i=0; ioutput_num(); i++) { + for (size_t i=0; ioutput_num(); i++) { outputs[i].comp_node = std::move(o_devices[i]); outputs[i].layout = std::move( TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]) diff --git a/src/custom/impl/manager.cpp b/src/custom/impl/manager.cpp new file mode 100644 index 00000000..ddef40c8 --- /dev/null +++ b/src/custom/impl/manager.cpp @@ -0,0 +1,181 @@ +/** + * \file src/custom/impl/manager.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/custom/manager.h" +#include "megbrain/common.h" +#include + +#ifndef _WIN32 +#include +#endif + +using namespace mgb; + +namespace custom { + +CustomOpManager *CustomOpManager::inst(void) { + static CustomOpManager op_manager; + return &op_manager; +} + +CustomOpManager::~CustomOpManager() { + mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); + LibManager::inst()->m_custom_libs.clear(); +} + +std::shared_ptr CustomOpManager::insert(const std::string &name, uint32_t version) { + MGB_LOCK_GUARD(m_mtx); + auto iter = m_name2op.find(name); + if (iter != m_name2op.end()) { + mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str()); + return std::const_pointer_cast(iter->second); + } + std::shared_ptr op = std::make_shared(name, version); + m_name2op[op->op_type()] = op; + m_id2op[op->runtime_id()] = op; + return std::const_pointer_cast(op); +} + +bool CustomOpManager::erase(const std::string &name) { + MGB_LOCK_GUARD(m_mtx); + auto iter = m_name2op.find(name); + if (iter == m_name2op.end()) { + mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str()); + return false; + } + std::shared_ptr op = iter->second; + m_id2op.erase(op->runtime_id()); + m_name2op.erase(op->op_type()); + return true; +} + +bool CustomOpManager::erase(const RunTimeId &id) { + MGB_LOCK_GUARD(m_mtx); + auto iter = m_id2op.find(id); + if (iter == m_id2op.end()) { + mgb_log_warn("Erase Custom Op Failed! The Op has not been registered"); + return false; + } + std::shared_ptr op = iter->second; + m_id2op.erase(op->runtime_id()); + m_name2op.erase(op->op_type()); + return true; +} + +std::shared_ptr CustomOpManager::find_or_reg(const std::string &name, uint32_t version) { + auto iter = m_name2op.find(name); + if (iter == m_name2op.end()) { + return insert(name, version); + } + return std::const_pointer_cast(iter->second); +} + +RunTimeId CustomOpManager::to_id(const std::string &name) const { + std::shared_ptr op = find(name); + return op->runtime_id(); +} + +std::string CustomOpManager::to_name(const RunTimeId &id) const { + std::shared_ptr op = find(id); + return op->op_type(); +} + +std::shared_ptr CustomOpManager::find(const std::string &name) const { + auto ret = m_name2op.find(name); + mgb_assert(ret != m_name2op.end(), + "Find Custom Op Failed! Op %s has not been registered", name.c_str() + ); + return ret->second; +} + +std::shared_ptr CustomOpManager::find(const RunTimeId &id) const { + auto ret = m_id2op.find(id); + mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! 
Op has not been registered"); + return ret->second; +} + +std::vector CustomOpManager::op_name_list(void) { + std::vector ret; + for (auto kv: m_name2op) { + ret.emplace_back(kv.first); + } + return ret; +} + +std::vector CustomOpManager::op_id_list(void) { + std::vector ret; + for (auto kv: m_id2op) { + ret.emplace_back(kv.first); + } + return ret; +} + +#ifndef _WIN32 +CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY) + : m_handle(nullptr, [](void* handle) {dlclose(handle);}) { + auto op_list_before_load = CustomOpManager::inst()->op_name_list(); + std::unordered_set op_set_before_load( + op_list_before_load.begin(), op_list_before_load.end()); + + m_handle.reset(dlopen(path.c_str(), mode)); + mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror()); + + auto op_list_after_load = CustomOpManager::inst()->op_name_list(); + for (auto &op: op_list_after_load) { + if (op_set_before_load.find(op) == op_set_before_load.end()) { + m_ops.emplace_back(op); + } + } +} +#else +CustomLib::CustomLib(const std::string &path, int mode = 0) + : m_handle(nullptr, [](void* handle) {}) { + mgb_assert(false, "custom op is only supported on Linux now"); +} +#endif + +const std::vector &CustomLib::ops_in_lib(void) const { + return m_ops; +} + +CustomLib::~CustomLib() { + for (auto &op: m_ops) { + CustomOpManager::inst()->erase(op); + } +} + +bool CustomLib::valid() const { + return m_handle != nullptr; +} + +LibManager *LibManager::inst(void) { + static LibManager custom_libs; + return &custom_libs; +} + +const std::vector &LibManager::install(const std::string &name, const std::string &path) { + MGB_LOCK_GUARD(m_mtx);; + LibHandle handle = std::make_shared(path); + m_custom_libs.insert({name, handle}); + return m_custom_libs[name]->ops_in_lib(); +} + +bool LibManager::uninstall(const std::string &name) { + MGB_LOCK_GUARD(m_mtx);; + mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error"); + return true; +} + +std::shared_ptr op_insert(std::string opname, uint32_t version) { + return CustomOpManager::inst()->insert(opname, version); +} + +} diff --git a/src/custom/impl/op.cpp b/src/custom/impl/op.cpp new file mode 100644 index 00000000..156ee988 --- /dev/null +++ b/src/custom/impl/op.cpp @@ -0,0 +1,531 @@ +/** + * \file src/custom/impl/op.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/common.h" +#include "megbrain/custom/op.h" +#include "megbrain/custom/utils.h" +#include +#include + +using namespace mgb; + +namespace custom { + +class ArgInfoImpl { + std::string m_name; + std::string m_desc; + std::unordered_set m_dtypes; + int m_ndim; // use int rather than size_t for representing m_dims = -1 + std::string m_mem_stgy; + + friend class ArgInfo; +}; + +CUSTOM_PIMPL_CLS_DEFINE(ArgInfo) + +ArgInfo::ArgInfo(const std::string &name, + const std::string &desc, + const std::unordered_set &dtypes, + const int &ndim, + const std::string &mem_stgy): m_impl(new ArgInfoImpl(), impl_deleter) { + for (auto &&dtype: dtypes) { + mgb_assert(DType::is_legal(dtype), "unsupported tensor data type: %s", dtype.c_str()); + } + mgb_assert(mem_stgy == "default", "only default mem strategy is supported now!"); + TypedRef(ArgInfoImpl, m_impl.get()).m_name = name; + TypedRef(ArgInfoImpl, m_impl.get()).m_desc = desc; + TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes = dtypes; + TypedRef(ArgInfoImpl, m_impl.get()).m_ndim = ndim; + TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy = mem_stgy; +} + +const std::string &ArgInfo::name(void) const { + return TypedRef(ArgInfoImpl, m_impl.get()).m_name; +} + +const std::string &ArgInfo::desc(void) const { + return TypedRef(ArgInfoImpl, m_impl.get()).m_desc; +} + +const std::unordered_set &ArgInfo::dtypes(void) const { + return TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes; +} + +int ArgInfo::ndim(void) const { + return TypedRef(ArgInfoImpl, m_impl.get()).m_ndim; +} + +const std::string &ArgInfo::mem_strategy(void) const { + return TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; +} + +std::string ArgInfo::str() const { + std::stringstream ss; + ss << "name: " << TypedRef(ArgInfoImpl, m_impl.get()).m_name << "\n" + << "desc: " << TypedRef(ArgInfoImpl, m_impl.get()).m_desc << "\nlegal_dtypes: {"; + + for (auto &val: TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes) { + ss << val << ", "; + } + if (TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes.size() != 0) { + ss.seekp(ss.tellp()-std::streampos(2)); + } + + ss << "}\ndims: " << TypedRef(ArgInfoImpl, m_impl.get()).m_ndim << "\n" + << "memory_strategy: " << TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; + return ss.str(); +} + +#define assert_inputs_size_right(inputs_vec) \ + mgb_assert( \ + inputs_vec.size() == input_num(), \ + "op %s need %lu inputs but given %lu", \ + op_type().c_str(), static_cast(input_num()), \ + static_cast(inputs_vec.size()) \ + ) + +#define assert_outputs_size_right(outputs_vec) \ + mgb_assert( \ + outputs_vec.size() == output_num(), \ + "op %s have %lu outputs but given %lu", \ + op_type().c_str(), static_cast(output_num()), \ + static_cast(outputs_vec.size()) \ + ) + +#define assert_arg_shape_dim_right(real_shape, arg_info) \ + mgb_assert( \ + (arg_info).ndim() == -1 || static_cast((real_shape).ndim()) == \ + static_cast((arg_info).ndim()), \ + "%s's args: %s dim match error, need %d but given %d", op_type().c_str(), \ + (arg_info).name().c_str(), static_cast((arg_info).ndim()), \ + static_cast((real_shape).ndim()) \ + ) + +template +class Function; + +template +class Function { +public: + using Functor = RType (*)(Args...); + + Function() = default; + Function(Functor f): m_f(f) {} + Function(const Function &rhs) { + m_f = rhs.m_f; + } + + RType operator()(Args... 
args) { + custom_assert(m_f != nullptr, "invalid function ptr\n"); + return m_f(std::forward(args)...); + } + + void operator=(const Function &rhs) { // not allowed continuous assignment + m_f = rhs.m_f; + } + + void operator=(const Functor f) { + m_f = f; + } + +private: + Functor m_f = nullptr; +}; + +template +class FuncWithSig: public Functions { +public: + using Functions::operator(); + using Functions::operator=; +}; + +class CustomOpImpl { + static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; + const uint32_t m_version; + + const std::string m_op_type; + std::string m_op_desc; + std::vector m_input_infos; + std::vector m_output_infos; + ParamInfo m_param_infos; + + using DeviceInfer = FuncWithSig&, const Param&, std::vector&)>>; + using ShapeInfer = FuncWithSig&, const Param&, std::vector&)>>; + using DTypeInfer = FuncWithSig&, const Param&, std::vector&)>>; + using FormatInfer = FuncWithSig&, const Param&, std::vector&)>>; + using Preprocess = FuncWithSig&, const Param&, std::vector&)>>; + using Postprocess = FuncWithSig&, const Param&, std::vector&)>>; + using Compute = FuncWithSig&, const Param&, std::vector&)>>; + + DeviceInfer infer_output_device_func; + ShapeInfer infer_output_shape_func; + DTypeInfer infer_output_dtype_func; + FormatInfer infer_output_format_func; + + std::unordered_map compute_funcs; + std::unordered_map preprocess_funcs; + std::unordered_map postprocess_funcs; + +public: + CustomOpImpl(const std::string&, uint32_t version); + PREVENT_COPY_AND_ASSIGN(CustomOpImpl); + friend CustomOp; +}; + +CustomOpImpl::CustomOpImpl(const std::string &op_type, uint32_t version) + : m_version(version), m_op_type(op_type) { + if (m_version != CURRENT_VERSION) { + mgb_log_warn( + "the version of loaded custom op %s is %u, but custom op version " + "of the system is %u\n", op_type.c_str(), m_version, CURRENT_VERSION + ); + } + + infer_output_device_func = [](const std::vector &inputs, + const Param&, + std::vector &outputs) -> void { + static UnImpleWarnLog log_once("output_device_infer", "device", "x86"); + for (size_t i=0; i 0 ? inputs[0] : Device("x86"); + } + }; + + infer_output_shape_func = [](const std::vector &inputs, + const Param&, + std::vector &outputs) -> void { + static UnImpleWarnLog log_once("output_shape_infer", "shape", "{1}"); + for (size_t i=0; i 0 ? inputs[0] : Shape({1}); + } + }; + + infer_output_dtype_func = [](const std::vector &inputs, + const Param&, + std::vector &outputs) -> void { + static UnImpleWarnLog log_once("output_dtype_infer", "dtype", "float32"); + for (size_t i=0; i 0 ? inputs[0] : DType("float32"); + } + }; + + infer_output_format_func = [](const std::vector &inputs, + const Param&, + std::vector &outputs) -> void { + for (size_t i=0; i 0 ? inputs[0] : Format("default"); + } + }; + + for (const auto &device: Device::legal_devices()) { + compute_funcs[device] = [](const std::vector&, const Param&, std::vector &outputs) -> void { + auto device = outputs[0].device(); + mgb_assert(false, "There is no forward function for your op on device `%s`. 
" + "Please implement this function and register it.", device.str().c_str()); + }; + preprocess_funcs[device] = [](const std::vector&, const Param&, std::vector&) -> void { + return; + }; + postprocess_funcs[device] = [](const std::vector&, const Param&, std::vector&) -> void { + return; + }; + } + m_param_infos.set_tag(op_type); +} + +CustomOp::CustomOp(const std::string &op_type, uint32_t version) + : m_impl(new CustomOpImpl(op_type, version), impl_deleter) { + +} + +#define OpImplRef(raw_ptr) reinterpret_cast(raw_ptr) + +CustomOp &CustomOp::set_device_infer(DeviceInferFuncPtr func) { + OpImplRef(m_impl.get())->infer_output_device_func = func; + return *this; +} + +CustomOp &CustomOp::set_shape_infer(ShapeInferFuncPtr func) { + OpImplRef(m_impl.get())->infer_output_shape_func = func; + return *this; +} + +CustomOp &CustomOp::set_dtype_infer(DTypeInferFuncPtr func) { + OpImplRef(m_impl.get())->infer_output_dtype_func = func; + return *this; +} + +CustomOp &CustomOp::set_format_infer(FormatInferFuncPtr func) { + OpImplRef(m_impl.get())->infer_output_format_func = func; + return *this; +} + +CustomOp &CustomOp::set_preprocess(PreprocessFuncPtr func) { + set_preprocess("x86", func); + return *this; +} + +CustomOp &CustomOp::set_preprocess(const std::string &device, PreprocessFuncPtr func) { + OpImplRef(m_impl.get())->preprocess_funcs[device] = func; + return *this; +} + +CustomOp &CustomOp::set_postprocess(PostprocessFuncPtr func) { + set_postprocess("x86", func); + return *this; +} + +CustomOp &CustomOp::set_postprocess(const std::string &device, PostprocessFuncPtr func) { + OpImplRef(m_impl.get())->postprocess_funcs[device] = func; + return *this; +} + +CustomOp &CustomOp::set_compute(ComputeFuncPtr func) { + set_compute("x86", func); + return *this; +} + +CustomOp &CustomOp::set_compute(const std::string &device, ComputeFuncPtr func) { + OpImplRef(m_impl.get())->compute_funcs[device] = func; + return *this; +} + +CustomOp &CustomOp::set_description(const std::string &op_desc) { + OpImplRef(m_impl.get())->m_op_desc = op_desc; + return *this; +} + +CustomOp &CustomOp::add_input(const std::string &name, const std::string &desc, const std::initializer_list &legal_dtypes, int dims, const std::string &mem_stgy) { + auto &ref = OpImplRef(m_impl.get())->m_input_infos; + for (const auto &input: ref) { + mgb_assert(input.name() != name, "input %s has been registered", name.c_str()); + } + ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); + return *this; +} + +CustomOp &CustomOp::add_output(const std::string &name, const std::string &desc, const std::initializer_list &legal_dtypes, int dims, const std::string &mem_stgy) { + auto &ref = OpImplRef(m_impl.get())->m_output_infos; + for (const auto &output: ref) { + mgb_assert(output.name() != name, "output %s has been registered", name.c_str()); + } + ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); + return *this; +} + +CustomOp &CustomOp::add_input(const std::string &name, const std::initializer_list &legal_dtypes, int dims, const std::string &mem_stgy) { + add_input(name, name, legal_dtypes, dims, mem_stgy); + return *this; +} + +CustomOp &CustomOp::add_output(const std::string &name, const std::initializer_list &legal_dtypes, int dims, const std::string &mem_stgy) { + add_output(name, name, legal_dtypes, dims, mem_stgy); + return *this; +} + +CustomOp &CustomOp::add_inputs(const size_t &num) { + size_t cur_inp_num = input_num(); + for (size_t i=cur_inp_num; im_param_infos.meta(); + for(const auto &schema: meta) { + 
mgb_assert(name != schema.name(), "param %s has been registered\n", name.c_str()); + } + ParamSchema sch = ParamSchema(name, default_val, desc); + meta.emplace_back(sch); + return *this; +} + +std::string CustomOp::op_type(void) const { + return OpImplRef(m_impl.get())->m_op_type; +} + +std::string CustomOp::op_desc(void) const { + return OpImplRef(m_impl.get())->m_op_desc; +} + +RunTimeId CustomOp::runtime_id(void) const { + return (RunTimeId)(this); +} + +size_t CustomOp::input_num(void) const { + return OpImplRef(m_impl.get())->m_input_infos.size(); +} + +size_t CustomOp::output_num(void) const { + return OpImplRef(m_impl.get())->m_output_infos.size(); +} + +std::string CustomOp::str(void) const { + std::stringstream ss; + ss << "op name: " << op_type() << "\nop desc: " << op_desc() << "\n\ninputs:\n"; + for (const auto &input: inputs_info()) { + ss << input.str(); + ss << "\n--------------------\n"; + } + ss << "\noutputs:\n"; + for (const auto &output: outputs_info()) { + ss << output.str(); + ss << "\n--------------------\n"; + } + ss << "\nparams:\n"; + for (const auto ¶m: param_info().meta()) { + ss << param.str(); + ss << "\n--------------------\n"; + } + return ss.str(); +} + +const ParamInfo &CustomOp::param_info(void) const { + return OpImplRef(m_impl.get())->m_param_infos; +} + +ArgInfo CustomOp::input_info(size_t idx) const { + return OpImplRef(m_impl.get())->m_input_infos[idx]; +} + +ArgInfo CustomOp::output_info(size_t idx) const { + return OpImplRef(m_impl.get())->m_output_infos[idx]; +} + +const std::vector &CustomOp::inputs_info(void) const { + return OpImplRef(m_impl.get())->m_input_infos; +} + +const std::vector &CustomOp::outputs_info(void) const { + return OpImplRef(m_impl.get())->m_output_infos; +} + +std::vector CustomOp::infer_output_device(const std::vector &inputs, const Param ¶m) const { + assert_inputs_size_right(inputs); + + std::vector outputs(output_num()); + OpImplRef(m_impl.get())->infer_output_device_func(inputs, param, outputs); + + assert_outputs_size_right(outputs); + return outputs; +} + +std::vector CustomOp::infer_output_shape(const std::vector &inputs, const Param ¶m) const { + assert_inputs_size_right(inputs); + for (size_t i=0; i outputs(output_num()); + OpImplRef(m_impl.get())->infer_output_shape_func(inputs, param, outputs); + for (size_t i=0; i CustomOp::infer_output_dtype(const std::vector &inputs, const Param ¶m) const { + assert_inputs_size_right(inputs); + + for (size_t i=0; i legal_input_dtypes_i = input_info(i).dtypes(); + mgb_assert( + legal_input_dtypes_i.find(inputs[i].str()) != legal_input_dtypes_i.end(), + "dtypes of input: %s(%s) is not allowed, the info of this input is:\n%s", + input_info(i).name().c_str(), inputs[i].str().c_str(), + input_info(i).str().c_str() + ); + } + std::vector outputs(output_num()); + OpImplRef(m_impl.get())->infer_output_dtype_func(inputs, param, outputs); + + for (size_t i=0; i legal_output_dtypes_i = output_info(i).dtypes(); + mgb_assert( + legal_output_dtypes_i.find(outputs[i].str()) != legal_output_dtypes_i.end(), + "dtypes of output: %s is %s, the info of this output is:\n%s", + output_info(i).name().c_str(), outputs[i].str().c_str(), + output_info(i).str().c_str() + ); + } + + assert_outputs_size_right(outputs); + return outputs; +} + +std::vector CustomOp::infer_output_format(const std::vector &inputs, const Param ¶m) const { + assert_inputs_size_right(inputs); + for (size_t i=0; i outputs(output_num()); + OpImplRef(m_impl.get())->infer_output_format_func(inputs, param, outputs); + + for (size_t 
i=0; i &inputs, const Param ¶m, std::vector &outputs) const { + assert_inputs_size_right(inputs); + assert_outputs_size_right(outputs); + if (outputs.size() == 0) { + return; + } + + std::string device = outputs[0].device().str(); + for (size_t i=1; ipreprocess_funcs[device]; + auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device]; + auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device]; + + preprocess_func(inputs, param, outputs); + forward_func(inputs, param, outputs); + postprocess_func(outputs, param, outputs); + assert_outputs_size_right(outputs); +} + +} diff --git a/src/custom/impl/param.cpp b/src/custom/impl/param.cpp new file mode 100644 index 00000000..5d790b1d --- /dev/null +++ b/src/custom/impl/param.cpp @@ -0,0 +1,179 @@ +/** + * \file src/custom/impl/param.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/custom/param.h" +#include "megbrain/common.h" +#include "megbrain/utils/hash.h" +#include +#include +#include + +using namespace mgb; + +namespace custom { + +class ParamSchemaImpl { + std::string m_name; + std::string m_desc; + ParamVal m_default; + friend ParamSchema; +}; + +class ParamInfoImpl { + std::vector m_meta; + uint32_t TAG; + friend ParamInfo; +}; + +class ParamImpl { + std::unordered_map m_vals; + + ParamImpl() = default; + ParamImpl(const ParamImpl &rhs) = default; + ParamImpl &operator=(const ParamImpl &rhs) { + mgb_assert( + m_vals.size() == rhs.m_vals.size(), + "params of different op, assignment failed!" 
+ ); + for (const auto &kv: rhs.m_vals) { + auto iter = m_vals.find(kv.first); + mgb_assert(iter != m_vals.end(), "params of different op, assignment failed!"); + iter->second = kv.second; + } + return *this; + } + + friend Param; +}; + +CUSTOM_PIMPL_CLS_DEFINE(ParamSchema) + +ParamSchema::ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc) + : m_impl(new ParamSchemaImpl(), impl_deleter) { + TypedRef(ParamSchemaImpl, m_impl.get()).m_name = name; + TypedRef(ParamSchemaImpl, m_impl.get()).m_default = value; + TypedRef(ParamSchemaImpl, m_impl.get()).m_desc = desc; +} + +const std::string &ParamSchema::name(void) const { + return TypedRef(ParamSchemaImpl, m_impl.get()).m_name; +} + +const std::string &ParamSchema::desc(void) const { + return TypedRef(ParamSchemaImpl, m_impl.get()).m_desc; +} + +const ParamVal &ParamSchema::default_val(void) const { + return TypedRef(ParamSchemaImpl, m_impl.get()).m_default; +} + +ParamDynType ParamSchema::type(void) const { + return TypedRef(ParamSchemaImpl, m_impl.get()).m_default.type(); +} + +std::string ParamSchema::str(void) const { + std::stringstream ss; + ss << "name: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_name + << "\ndesc: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_desc + << "\n" << TypedRef(ParamSchemaImpl, m_impl.get()).m_default.str(); + return ss.str(); +} + +CUSTOM_PIMPL_CLS_DEFINE(ParamInfo) + +void ParamInfo::set_tag(const std::string &hash_str) { + const char *ptr = hash_str.c_str(); + TypedRef(ParamInfoImpl, m_impl.get()).TAG = 0; + for (size_t i=0; i::max(); + } +} + +void ParamInfo::set_meta(const std::vector &meta) { + TypedRef(ParamInfoImpl, m_impl.get()).m_meta = meta; +} + +uint32_t ParamInfo::tag(void) const { + return TypedRef(ParamInfoImpl, m_impl.get()).TAG; +} + +std::vector &ParamInfo::meta(void) { + return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; +} + +const std::vector &ParamInfo::meta(void) const { + return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; +} + +CUSTOM_PIMPL_CLS_DEFINE(Param) + +Param::Param(const ParamInfo &info): m_impl(new ParamImpl(), impl_deleter) { + for (const auto &schema: info.meta()) { + TypedRef(ParamImpl, m_impl.get()).m_vals.emplace(schema.name(), schema.default_val()); + } +} + +ParamVal &Param::operator[](const std::string &name) { + return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; +} + +const ParamVal &Param::operator[](const std::string &name) const { + return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; +} + +const std::unordered_map &Param::raw() const { + return TypedRef(ParamImpl, m_impl.get()).m_vals; +} + +bool Param::exist(const std::string &name) const { + return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name) != + TypedRef(ParamImpl, m_impl.get()).m_vals.end(); +} + +std::string Param::to_bytes(void) const { + std::string res; + std::map ordered_vals( + TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), + TypedRef(ParamImpl, m_impl.get()).m_vals.end()); + for (auto &&kv: ordered_vals) { + res += ParamVal::to_bytes(kv.second); + } + return res; +} + +void Param::from_bytes(const std::string &bytes) { + std::map ordered_vals( + TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), + TypedRef(ParamImpl, m_impl.get()).m_vals.end()); + size_t offset = 0; + for (auto &kv: ordered_vals) { + kv.second = ParamVal::from_bytes(bytes, offset); + } + TypedRef(ParamImpl, m_impl.get()).m_vals.clear(); + TypedRef(ParamImpl, m_impl.get()).m_vals.insert(ordered_vals.begin(), ordered_vals.end()); + mgb_assert(offset == 
bytes.size(), "wrong data loader"); +} + +bool operator==(const Param &lhs, const Param &rhs) { + if (lhs.raw().size() != rhs.raw().size()) + return false; + for (const auto &kv: lhs.raw()) { + auto riter = rhs.raw().find(kv.first); + if (riter == rhs.raw().end() || !((kv.second) == riter->second)) { + return false; + } + } + return true; +} + +} diff --git a/src/custom/impl/param_val.cpp b/src/custom/impl/param_val.cpp new file mode 100644 index 00000000..0ff1662f --- /dev/null +++ b/src/custom/impl/param_val.cpp @@ -0,0 +1,400 @@ +/** + * \file src/custom/impl/param_val.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/custom/param_val.h" +#include "megbrain/common.h" + +#pragma GCC diagnostic ignored "-Wsign-compare" + +using namespace mgb; + +namespace custom { + +/** + * Macro Callback for Case + */ + +#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + std::unique_ptr new_ptr( \ + new static_type(TypedRef(static_type, rhs.m_ptr.get())), \ + impl_deleter \ + ); \ + m_ptr.swap(new_ptr); \ + break; \ + } + +#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get());\ + break; \ + } + +#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr) \ + mgb_assert( \ + operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \ + "invalid %s of operator %s of ParamVal", #operand, #opr \ + ) + +#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \ + mgb_assert( \ + lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \ + type2name[lhs.m_type].c_str(), #op, \ + type2name[rhs.m_type].c_str() \ + ) + +#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \ + case (ParamDynType::dyn_type): { \ + const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ + return lval op rval; \ + } + +#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \ + case (ParamDynType::dyn_type): { \ + const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ + switch (rhs.m_type) { \ + CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY( \ + CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op) \ + default: \ + CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ + } \ + break; \ + } + +#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \ + case (ParamDynType::dyn_type): { \ + CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ + const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ + const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ + return lval op rval; \ + } + +#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type) \ + ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ + CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ + CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ + \ + switch (lhs.m_type) { \ + CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ + default: \ + CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ + } \ + return {}; \ + } + +#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type) \ + ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ + CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ + 
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ + \ + switch (lhs.m_type) { \ + CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ + CUSTOM_FOR_STRING_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ + default: \ + CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ + } \ + return {}; \ + } + +#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type) \ + ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ + CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ + CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ + \ + switch (lhs.m_type) { \ + CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ + CUSTOM_FOR_STRING_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ + CUSTOM_FOR_EACH_LIST_PARAMTYPE( \ + CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ + default: \ + CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ + } \ + return {}; \ + } + +#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + auto rval = TypedRef(static_type, m_ptr.get()); \ + ss << rval; \ + break; \ + } + +#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + auto rval = TypedRef(static_type, m_ptr.get()); \ + ss << vec2str(rval); \ + break; \ + } + +#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + return TypedRef(static_type, m_ptr.get()).size(); \ + break; \ + } + +#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + res.resize(sizeof(ParamDynType) + sizeof(static_type)); \ + memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ + memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \ + break; \ + } + +#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + auto &ref = TypedRef(static_type, value.m_ptr.get()); \ + size_t len = ref.size(); \ + size_t elem_size = len != 0 ? sizeof(ref[0]) : 0; \ + res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); \ + memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ + memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); \ + memcpy(&res[sizeof(ParamDynType)+sizeof(len)], ref.data(), len*elem_size); \ + break; \ + } + +#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + static_type val; \ + memcpy(&val, &bytes[offset], sizeof(val)); \ + offset += sizeof(val); \ + return val; \ + break; \ + } + +#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type) \ + case (ParamDynType::dyn_type): { \ + size_t len = 0; \ + memcpy(&len, &bytes[offset], sizeof(len)); \ + offset += sizeof(len); \ + static_type vals; \ + vals.resize(len); \ + size_t elem_size = len != 0 ? 
sizeof(vals[0]) : 0; \ + memcpy(&vals[0], &bytes[offset], len*elem_size); \ + offset += len*elem_size; \ + return vals; \ + break; \ + } + +ParamVal::ParamVal(): m_ptr(nullptr, [](void*) -> void {}) { + m_type = ParamDynType::Invalid; +} + +ParamVal::ParamVal(const char *str): ParamVal(std::string(str)) { + +} + +ParamVal::ParamVal(const std::initializer_list &strs): ParamVal(std::vector(strs)) { +} + +ParamVal::ParamVal(const std::vector &strs) + : m_ptr(new std::vector(), impl_deleter>) { + m_type = ParamDynType::StringList; + for (const auto &str: strs) { + TypedRef(std::vector, m_ptr.get()).emplace_back(str); + } +} + +ParamVal::ParamVal(const ParamVal &rhs): m_ptr(nullptr, [](void*) -> void {}) { + mgb_assert( + rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, + "invalid rhs of copy constructor of ParamVal" + ); + m_type = rhs.m_type; + switch(m_type) { + CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS) + default: { + mgb_assert(false, "invalid rhs of copy constructor of ParamVal"); + } + } +} + +ParamVal &ParamVal::operator=(const char *str) { + this->operator=(std::string(str)); + return *this; +} + +ParamVal &ParamVal::operator=(const std::initializer_list &strs) { + this->operator=(std::vector(strs)); + return *this; +} + +ParamVal &ParamVal::operator=(const std::vector &strs) { + std::vector tmp_strs; + for (const auto &str: strs) { + tmp_strs.emplace_back(str); + } + this->operator=(tmp_strs); + return *this; +} + +ParamVal &ParamVal::operator=(const ParamVal &rhs) { + if (&rhs == this) + return *this; + mgb_assert( + rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, + "invalid rhs of assignment operator of ParamVal" + ); + + if (rhs.m_type == m_type) { + switch(m_type) { + CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS); + default: + mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); + } + } + else { + m_type = rhs.m_type; + switch(m_type) { + CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS); + default: + mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); + } + } + return *this; +} + +const void *ParamVal::raw_ptr(void) const { + return m_ptr.get(); +} + +void *ParamVal::raw_ptr(void) { + return m_ptr.get(); +} + +ParamDynType ParamVal::type(void) const { + return m_type; +} + +std::string ParamVal::str() const { + std::stringstream ss; + ss << "type: " << type2name[m_type] << "\n" << "value: "; + switch (m_type) { + CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) + CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) + CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST) + default: + mgb_assert(false, "invalid data of assignment operator of ParamVal"); + } + return ss.str(); +} + +size_t ParamVal::size(void) const { + switch (m_type) { + CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) + CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) + default: + mgb_assert(false, "there is no size() for basic data types"); + } +} + +std::string ParamVal::to_bytes(const ParamVal &value) { + std::string res; + // because the specialization of std::vector + if (value.type() == ParamDynType::BoolList) { + std::vector &ref = TypedRef(std::vector, value.m_ptr.get()); + size_t len = ref.size(); + size_t elem_size = sizeof(bool); + res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); + memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); + memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); + size_t startpos = 
sizeof(ParamDynType)+sizeof(len); + for (size_t idx=0; idx &ref = TypedRef(std::vector, value.m_ptr.get()); + size_t len = ref.size(); + res.resize(sizeof(ParamDynType) + sizeof(len)); + memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); + memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); + for (size_t idx=0; idx ret; + size_t len = 0; + memcpy(&len, &bytes[offset], sizeof(len)); + offset += sizeof(len); + for (size_t idx =0; idx ret; + size_t len = 0; + memcpy(&len, &bytes[offset], sizeof(len)); + offset += sizeof(len); + for (size_t idx =0; idx=, bool) +CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool) +CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool) +CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool) + +} diff --git a/src/custom/impl/tensor.cpp b/src/custom/impl/tensor.cpp new file mode 100644 index 00000000..88491e11 --- /dev/null +++ b/src/custom/impl/tensor.cpp @@ -0,0 +1,486 @@ +/** + * \file src/custom/impl/tensor.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/custom/tensor.h" +#include "megbrain/comp_node.h" +#include "megbrain/common.h" +#include "megbrain/tensor.h" +#include +#include + +using namespace mgb; + +namespace custom { + +template +SmallVector to_builtin_vector(const std::vector &custom_data) { + SmallVector builtin_data(custom_data.size()); + memcpy(builtin_data.data(), custom_data.data(), sizeof(T)*custom_data.size()); + return builtin_data; +} + +using DeviceImpl = CompNode; +using ShapeImpl = megdnn::TensorShape; +using DTypeImpl = megdnn::DType; +using FormatImpl = megdnn::TensorLayout::Format; +using TensorImpl = DeviceTensorND; + +#define DeviceImplRef(rawptr) (*reinterpret_cast(rawptr)) +#define ShapeImplRef(rawptr) (*reinterpret_cast(rawptr)) +#define DTypeImplRef(rawptr) (*reinterpret_cast(rawptr)) +#define FormatImplRef(rawptr) (*reinterpret_cast(rawptr)) +#define TensorImplRef(rawptr) (*reinterpret_cast(rawptr)) + +#define DeviceImplConstRef(rawptr) static_cast(*reinterpret_cast(rawptr)) +#define ShapeImplConstRef(rawptr) static_cast(*reinterpret_cast(rawptr)) +#define DTypeImplConstRef(rawptr) static_cast(*reinterpret_cast(rawptr)) +#define FormatImplConstRef(rawptr) static_cast(*reinterpret_cast(rawptr)) +#define TensorImplConstRef(rawptr) static_cast(*reinterpret_cast(rawptr)) + +static std::unordered_map, + EnumCmp> dev_benum2cstr; +static std::unordered_map, + EnumCmp> dev_benum2cenum; +static std::unordered_map dev_cstr2bstr; +static std::unordered_map, + EnumCmp> dev_cenum2bstr; + +#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ + auto be2cs##custom_impl = dev_benum2cstr.emplace( \ + DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \ + auto be2ce##custom_impl = dev_benum2cenum.emplace( \ + DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \ + auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \ + std::string(#custom_impl), std::string(builtin_str)); \ + auto ce2bs##custom_impl = dev_cenum2bstr.emplace( \ + DeviceEnum::custom_impl, std::string(builtin_str)); + +CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE) +#undef CUSTOM_BIND_DEVICE + +CUSTOM_PIMPL_CLS_DEFINE(Device) + +const void *Device::impl() 
const { + return m_impl.get(); +} + +Device::Device(const void *impl): m_impl(nullptr, impl_deleter) { + mgb_assert(impl != nullptr, "invalid ptr"); + if (!DeviceImplConstRef(impl).valid()) { + m_impl.reset(new DeviceImpl()); + return; + } + + auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); + mgb_assert( + dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), + "unsupported compnode type: %s", DeviceImplConstRef(impl).to_string().c_str() + ); + m_impl.reset(new DeviceImpl(DeviceImplConstRef(impl))); +} + +Device::Device(const std::string &device): m_impl(nullptr, impl_deleter) { + mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); + std::string builtin_device = dev_cstr2bstr[device]; + m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); +} + +// to avoid the ambiguous from Device(const void *impl) +Device::Device(const char *device): Device(std::string(device)) { + +} + +Device::Device(DeviceEnum device): m_impl(nullptr, impl_deleter) { + mgb_assert(is_legal(device), "invalid device type"); + std::string builtin_device = dev_cenum2bstr[device]; + m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); +} + +std::string Device::str(void) const { + if (!DeviceImplRef(m_impl.get()).valid()) { + return "invalid"; + } + + auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); + auto iter = dev_benum2cstr.find(builtin_device_type); + mgb_assert( + iter != dev_benum2cstr.end(), "invalid device type %s\n", + DeviceImplRef(m_impl.get()).to_string().c_str() + ); + return iter->second; +} + +DeviceEnum Device::enumv(void) const { + mgb_assert( + DeviceImplRef(m_impl.get()).valid(), + "cannot get the enum value of invalid device" + ); + + auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); + auto iter = dev_benum2cenum.find(builtin_device_type); + mgb_assert( + iter != dev_benum2cenum.end(), "invalid device type %s\n", + DeviceImplRef(m_impl.get()).to_string().c_str() + ); + return iter->second; +} + +bool Device::is_legal(const std::string &device_type) { + return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); +} + +bool Device::is_legal(DeviceEnum device_type) { + return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); +} + +std::vector Device::legal_devices(void) { + std::vector ret; + for (const auto &kv: dev_cstr2bstr) { + ret.emplace_back(kv.first); + } + return ret; +} + +bool operator==(const Device &lhs, const Device &rhs) { + return lhs.str() == rhs.str(); +} + +CUSTOM_PIMPL_CLS_DEFINE(Shape) + +const void *Shape::impl() const { + return m_impl.get(); +} + +Shape::Shape(const void *impl): m_impl(nullptr, impl_deleter) { + mgb_assert(impl != nullptr, "invalid ptr"); + m_impl.reset(new ShapeImpl(ShapeImplConstRef(impl))); +} + +Shape::Shape(const std::vector &rhs): m_impl(nullptr, impl_deleter) { + m_impl.reset(new ShapeImpl(to_builtin_vector(rhs))); +} + +Shape::Shape(const std::initializer_list &rhs): m_impl(nullptr, impl_deleter) { + m_impl.reset(new ShapeImpl(rhs)); +} + +size_t &Shape::operator[](size_t idx) { + mgb_assert(idx < ndim(), "wrong tensor dimension idx: %lu < %lu", static_cast(idx), static_cast(ndim())); + return ShapeImplRef(m_impl.get()).operator[](idx); +} + +size_t Shape::operator[](size_t idx) const { + return const_cast(this)->operator[](idx); +} + +void Shape::ndim(size_t dim) { + mgb_assert(dim < ShapeImpl::MAX_NDIM, "dimension must <= %lu", static_cast(ShapeImpl::MAX_NDIM)); + ShapeImplRef(m_impl.get()).ndim = dim; +} + +size_t 
Shape::ndim(void) const { + return ShapeImplRef(m_impl.get()).ndim; +} + +bool operator==(const Shape &lhs, const Shape &rhs) { + return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); +} + +static std::unordered_map dtype_cstr2benum; +static std::unordered_map, + EnumCmp> dtype_cenum2benum; +static std::unordered_map, + EnumCmp> dtype_benum2cstr; +static std::unordered_map, + EnumCmp> dtype_benum2cenum; +static std::unordered_map, + EnumCmp> dtype_cenum2cstr; + +#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ + auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ + std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ + auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ + DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ + auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ + megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \ + auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ + megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \ + auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \ + DTypeEnum::custom_impl, std::string(#custom_impl)); + +CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) +#undef CUSTOM_BIND_DTYPE + +CUSTOM_PIMPL_CLS_DEFINE(DType) + +const void *DType::impl() const { + return m_impl.get(); +} + +DType::DType(const void *impl): m_impl(nullptr, impl_deleter) { + mgb_assert(impl != nullptr, "invalid ptr"); + m_impl.reset(new DTypeImpl(DTypeImplConstRef(impl))); +} + +DType::DType(const std::string &dtype): m_impl(nullptr, impl_deleter) { + auto iter = dtype_cstr2benum.find(dtype); + mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); + mgb_assert( + dtype[0] != 'q', "can not construct quantized dtype " + "%s without scale and zero_point", dtype.c_str() + ); + m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); +} + +DType::DType(const char *dtype): DType(std::string(dtype)) { +} + +DType::DType(const std::string &dtype, float scale, uint8_t zero_point) + : m_impl(nullptr, impl_deleter) { + auto iter = dtype_cstr2benum.find(dtype); + mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); + mgb_assert( + dtype[0] == 'q', "given scale/zero_point to construct " + "non-quantized dtype: %s is not allowed", dtype.c_str() + ); + if (dtype == "quint8") { + m_impl.reset(new megdnn::ParameterizedDType< + megdnn::DTypeEnum::Quantized8Asymm>(scale, zero_point)); + } + else { + mgb_assert( + zero_point == 0, "invalid zero point %d for dtype %s", + zero_point, dtype.c_str() + ); + if (dtype == "qint8") { + m_impl.reset(new megdnn::ParameterizedDType< + megdnn::DTypeEnum::QuantizedS8>(scale)); + } + else if (dtype == "qint16") { + m_impl.reset(new megdnn::ParameterizedDType< + megdnn::DTypeEnum::QuantizedS16>(scale)); + } + else if (dtype == "qint32") { + m_impl.reset(new megdnn::ParameterizedDType< + megdnn::DTypeEnum::QuantizedS32>(scale)); + } + else { + mgb_assert(false, "invalid dtype %s", dtype.c_str()); + } + } + +} + +DType::DType(const char *dtype, float scale, uint8_t zero_point) + : DType(std::string(dtype), scale, zero_point) { +} + +DType::DType(DTypeEnum dtype): m_impl(nullptr, impl_deleter) { + auto iter = dtype_cenum2benum.find(dtype); + mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); + mgb_assert(dtype < DTypeEnum::quint8, + "can not construct quantized dtype without scale and zero_point"); + m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); +} + +DType::DType(DTypeEnum dtype, float 
scale, uint8_t zero_point) + : DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) { +} + +std::string DType::str(void) const { + if (!DTypeImplRef(m_impl.get()).valid()) + return "invalid"; + auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); + if (iter == dtype_benum2cstr.end()) + return "invalid"; + return iter->second; +} + +DTypeEnum DType::enumv(void) const { + auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); + mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); + return iter->second; +} + +float DType::scale() const { + if (enumv() == DTypeEnum::qint8) { + return DTypeImplRef(m_impl.get()).param().scale; + } + else if (enumv() == DTypeEnum::qint16) { + return DTypeImplRef(m_impl.get()).param().scale; + } + else if (enumv() == DTypeEnum::qint32) { + return DTypeImplRef(m_impl.get()).param().scale; + } + else if (enumv() == DTypeEnum::quint8) { + return DTypeImplRef(m_impl.get()).param().scale; + } + else { + mgb_assert(false, "dtype %s has no scale", str().c_str()); + return 0.f; + } +} + +uint8_t DType::zero_point() const { + mgb_assert(enumv()==DTypeEnum::quint8, "dtype %s has no zero point", str().c_str()); + return DTypeImplRef(m_impl.get()).param().zero_point; +} + +bool DType::is_legal(const std::string &dtype) { + return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); +} + +bool DType::is_legal(const DTypeEnum &dtype) { + return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); +} + +std::vector DType::legal_dtypes(void) { + std::vector ret; + for (const auto &kv: dtype_cstr2benum) + ret.emplace_back(kv.first); + return ret; +} + +bool operator==(const DType &lhs, const DType &rhs) { + return DTypeImplRef(lhs.m_impl.get()) == DTypeImplRef(rhs.m_impl.get()); +} + +bool operator==(const DType &lhs, const std::string &rhs) { + return lhs.str() == rhs; +} + +bool operator==(const DType &lhs, const char *rhs) { + return operator==(lhs, std::string(rhs)); +} + +bool operator==(const std::string &lhs, const DType &rhs) { + return operator==(rhs, lhs); +} + +bool operator==(const char *lhs, const DType &rhs) { + return operator==(rhs, std::string(lhs)); +} + +CUSTOM_PIMPL_CLS_DEFINE(Format) + +const void *Format::impl() const { + return m_impl.get(); +} + +Format::Format(const void *impl): m_impl(nullptr, impl_deleter) { + mgb_assert(impl != nullptr, "invalid ptr"); + mgb_assert(FormatImplConstRef(impl).is_default(), "only default format is supported now"); + + m_impl.reset(new FormatImpl(FormatImplConstRef(impl))); +} + +Format::Format(const std::string &format): m_impl(nullptr, impl_deleter) { + mgb_assert(format == "default", "only default format is supported now"); + m_impl.reset(new FormatImpl()); +} + +Format::Format(const char *format): Format(std::string(format)) { + +} + +std::string Format::str(void) const { + return FormatImplRef(m_impl.get()).to_string(); +} + +bool Format::is_default(void) const { + return FormatImplRef(m_impl.get()).is_default(); +} + +const void *Tensor::impl(void) const { + return m_tensor; +} + +Tensor::Tensor(const void *impl) { + mgb_assert(impl != nullptr, "invalid ptr"); + m_tensor = const_cast(impl); +} + +const size_t *Tensor::shapes_raw(void) const { + return TensorImplRef(m_tensor).shape().shape; +} + +const ptrdiff_t *Tensor::strides_raw(void) const { + return TensorImplRef(m_tensor).layout().stride; +} + +Tensor::Tensor(const Tensor &rhs) { + mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for copy constructor\n"); + m_tensor = rhs.m_tensor; +} + +Tensor 
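/* copy-assignment only rebinds the borrowed pointer to the underlying
 * DeviceTensorND held in m_tensor; no tensor storage is copied. */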
&Tensor::operator=(const Tensor &rhs) { + mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for assignment operator"); + if (&rhs == this || rhs.m_tensor == m_tensor) + return *this; + m_tensor = rhs.m_tensor; + return *this; +} + +Shape Tensor::shape(void) const { + auto builtin = TensorImplRef(m_tensor).shape(); + return Shape(&builtin); +} + +DType Tensor::dtype(void) const { + auto builtin = TensorImplRef(m_tensor).dtype(); + return DType(&builtin); +} + +Format Tensor::format(void) const { + auto builtin = TensorImplRef(m_tensor).format(); + return Format(&builtin); +} + +Device Tensor::device(void) const { + auto builtin = TensorImplRef(m_tensor).comp_node(); + return Device(&builtin); +} + +size_t Tensor::size(void) const { + return TensorImplRef(m_tensor).shape().total_nr_elems(); +} + +std::vector Tensor::stride(void) const { + std::vector ret(TensorImplRef(m_tensor).shape().ndim); + for (size_t i=0; i(TensorImplRef(m_tensor).raw_ptr()); +} + +const void *Tensor::data(void) const { + return static_cast(TensorImplRef(m_tensor).raw_ptr()); +} + +} // namespace custom diff --git a/src/custom/impl/utils.cpp b/src/custom/impl/utils.cpp new file mode 100644 index 00000000..1428c7c9 --- /dev/null +++ b/src/custom/impl/utils.cpp @@ -0,0 +1,41 @@ +/** + * \file src/custom/impl/utils.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/custom/utils.h" +#include "megbrain/common.h" +#include + +using namespace mgb; + +namespace custom { + +void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...) { + std::string msg = ssprintf("`%s' is true at %s:%d: %s", expr, file, line, func); + if (msg_fmt) { + msg_fmt = convert_fmt_str(msg_fmt); + va_list ap; + va_start(ap, msg_fmt); + msg.append("\nextra message: "); + msg.append(svsprintf(msg_fmt, ap)); + va_end(ap); + } + printf("%s\n", msg.c_str()); +} + +UnImpleWarnLog::UnImpleWarnLog(const std::string &func, const std::string &attr, + const std::string &val) { + mgb_log_warn("you are using the default custom %s function, the `%s` attribute " + "of all the outputs tensor will be the same with inputs tensor[0]. " + "If there is no input tensor, it will be `%s`", + func.c_str(), attr.c_str(), val.c_str()); +} + +} diff --git a/src/custom/include/megbrain/custom/accessor.h b/src/custom/include/megbrain/custom/accessor.h new file mode 100644 index 00000000..d9135b7b --- /dev/null +++ b/src/custom/include/megbrain/custom/accessor.h @@ -0,0 +1,185 @@ +/** + * \file src/custom/include/megbrain/custom/accessor.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include +#include + +namespace custom { + +#ifdef __CUDACC__ + #define CUSTOM_HOST __host__ + #define CUSTOM_DEVICE __device__ +#else + #define CUSTOM_HOST + #define CUSTOM_DEVICE +#endif + +#define CUSTOM_HOST_DEVICE CUSTOM_HOST CUSTOM_DEVICE + +template +struct DefaultPtrTraits { + using PtrType = T*; +}; + +#ifdef __CUDACC__ +template +struct RestrictPtrTraits { + using PtrType = T* __restrict__; +}; +#endif + +template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class TensorAccessorProxyBase { +public: + using PtrType = typename PtrTraits::PtrType; + +protected: + PtrType m_data; + const index_t* m_sizes; + const index_t* m_strides; + +public: + CUSTOM_HOST_DEVICE TensorAccessorProxyBase(PtrType data, const index_t *sizes, const index_t *strides) { + m_data = data; + m_sizes = sizes; + m_strides = strides; + } + + CUSTOM_HOST_DEVICE index_t stride(index_t i) const { + return m_strides[i]; + } + + CUSTOM_HOST_DEVICE index_t size(index_t i) const { + return m_sizes[i]; + } + + CUSTOM_HOST_DEVICE PtrType data() const { + return m_data; + } +}; + +template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class TensorAccessorProxy: public TensorAccessorProxyBase { +public: + using PtrType = typename PtrTraits::PtrType; + + CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) + : TensorAccessorProxyBase(data, sizes, strides) { + + } + + CUSTOM_HOST_DEVICE TensorAccessorProxy operator[](index_t i) { + return TensorAccessorProxy( + this->m_data + this->m_strides[0] * i, + this->m_sizes + 1, + this->m_strides + 1 + ); + } + + CUSTOM_HOST_DEVICE const TensorAccessorProxy operator[](index_t i) const { + return TensorAccessorProxy( + this->m_data + this->m_strides[0] * i, + this->m_sizes + 1, + this->m_strides + 1 + ); + } +}; + +template class PtrTraits, typename index_t> +class TensorAccessorProxy + : public TensorAccessorProxyBase { +public: + using PtrType = typename PtrTraits::PtrType; + + CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) + : TensorAccessorProxyBase(data, sizes, strides ) { + + } + + CUSTOM_HOST_DEVICE T &operator[](index_t i) { + return this->m_data[this->m_strides[0]*i]; + } + + CUSTOM_HOST_DEVICE const T &operator[](index_t i) const { + return this->m_data[this->m_strides[0]*i]; + } +}; + +template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class TensorAccessorBase { +public: + using PtrType = typename PtrTraits::PtrType; + +protected: + PtrType m_data; + index_t m_sizes[N]; + index_t m_strides[N]; + +public: + CUSTOM_HOST_DEVICE TensorAccessorBase(PtrType data, const size_t *sizes, const ptrdiff_t *strides) { + m_data = data; + for (size_t i=0; i class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> +class TensorAccessor: public TensorAccessorBase { +public: + using PtrType = typename PtrTraits::PtrType; + + CUSTOM_HOST_DEVICE TensorAccessor(PtrType data, const size_t *sizes, const ptrdiff_t *strides) + : TensorAccessorBase(data, sizes, strides) { + + } + + CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) { + return TensorAccessorProxy( + this->m_data, + this->m_sizes, + this->m_strides + )[i]; + } + + CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) const { + return TensorAccessorProxy( + this->m_data, + this->m_sizes, + this->m_strides + )[i]; + } +}; + +} diff --git a/src/custom/include/megbrain/custom/custom.h b/src/custom/include/megbrain/custom/custom.h new 
file mode 100644 index 00000000..feb0fc24 --- /dev/null +++ b/src/custom/include/megbrain/custom/custom.h @@ -0,0 +1,108 @@ +/** + * \file src/custom/include/megbrain/custom/custom.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "op.h" +#include "tensor.h" +#include "param.h" + +namespace custom { + std::shared_ptr op_insert(std::string opname, uint32_t version); +} + +#define CUSTOM_OP_REG(OpName) CustomOp &_##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION))) + +#define CUSTOM_OP_REG_BEGIN(OpName) \ + namespace custom { \ + namespace OpName { + +#define CUSTOM_OP_REG_END(OpName) \ + } \ + } + +#define CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, hint, ...) \ + case (case_type): { \ + using hint = real_type; \ + return __VA_ARGS__(); \ + } + +#define CASE_TO_PERFORM_ON_SCALAR(name, case_type, real_type, ...) \ + CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, scalar_t, __VA_ARGS__) + +#define DISPATCH_FLOAT_TYPES(tensor_dtype, name, ...) \ + [&]() { \ + const auto &dtype = tensor_dtype; \ + switch (dtype.enumv()) { \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ + default: \ + custom_assert(false, "no implemented %s kernel for dtype %s\n", \ + name, dtype.str().c_str()); \ + } \ + }() + +#define DISPATCH_INT_TYPES(tensor_dtype, name, ...) \ + [&]() { \ + const auto &dtype = tensor_dtype; \ + switch (dtype.enumv()) { \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ + default: \ + custom_assert(false, "no implemented %s kernel for dtype %s\n", \ + name, dtype.str().c_str()); \ + } \ + }() + +#define DISPATCH_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ + [&]() { \ + const auto &dtype = tensor_dtype; \ + switch (dtype.enumv()) { \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ + default: \ + custom_assert(false, "no implemented %s kernel for dtype %s\n", \ + name, dtype.str().c_str()); \ + } \ + }() + +#define DISPATCH_SIGN_INT_TYPES(tensor_dtype, name, ...) \ + [&]() { \ + const auto &dtype = tensor_dtype; \ + switch (dtype.enumv()) { \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ + default: \ + custom_assert(false, "no implemented %s kernel for dtype %s\n", \ + name, dtype.str().c_str()); \ + } \ + }() + +#define DISPATCH_SIGN_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) 
\ + [&]() { \ + const auto &dtype = tensor_dtype; \ + switch (dtype.enumv()) { \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ + CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ + default: \ + custom_assert(false, "no implemented %s kernel for dtype %s\n", \ + name, dtype.str().c_str()); \ + } \ + }() diff --git a/src/custom/include/megbrain/custom/data_adaptor.h b/src/custom/include/megbrain/custom/data_adaptor.h new file mode 100644 index 00000000..1d484e2d --- /dev/null +++ b/src/custom/include/megbrain/custom/data_adaptor.h @@ -0,0 +1,58 @@ +/** + * \file src/custom/include/megbrain/custom/data_adaptor.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megdnn/thin/small_vector.h" + +namespace custom { + +template +BuiltinT to_builtin(const CustomT &custom) { + return *reinterpret_cast(custom.impl()); +} + +template +CustomT to_custom(const BuiltinT &builtin) { + return std::move(CustomT(&builtin)); +} + +template +megdnn::SmallVector to_builtin(const std::vector &customs) { + megdnn::SmallVector builtins; + for (size_t i=0; i(customs[i]))); + } + return std::move(builtins); +} + +template +std::vector to_custom( + const megdnn::SmallVector &builtins) { + std::vector customs; + for (size_t i=0; i(builtins[i]))); + } + return std::move(customs); +} + +} + +#define to_custom_device(expr) custom::to_custom(expr) +#define to_builtin_device(expr) custom::to_builtin(expr) +#define to_custom_shape(expr) custom::to_custom(expr) +#define to_builtin_shape(expr) custom::to_builtin(expr) +#define to_custom_dtype(expr) custom::to_custom(expr) +#define to_builtin_dtype(expr) custom::to_builtin(expr) +#define to_custom_format(expr) custom::to_custom(expr) +#define to_builtin_format(expr) custom::to_builtin(expr) +#define to_custom_tensor(expr) custom::to_custom(expr) +#define to_builtin_tensor(expr) custom::to_builtin(expr) diff --git a/src/custom/include/megbrain/custom/manager.h b/src/custom/include/megbrain/custom/manager.h new file mode 100644 index 00000000..d751477b --- /dev/null +++ b/src/custom/include/megbrain/custom/manager.h @@ -0,0 +1,75 @@ +/** + * \file src/custom/include/megbrain/custom/manager.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "custom.h" +#include "megbrain/common.h" + +namespace custom { + +class CustomOpManager { + std::unordered_map> m_name2op; + std::unordered_map> m_id2op; + MGB_MUTEX m_mtx; + CustomOpManager() = default; +public: + PREVENT_COPY_AND_ASSIGN(CustomOpManager); + static CustomOpManager *inst(void); + ~CustomOpManager(); + + std::shared_ptr insert(const std::string &name, uint32_t version); + bool erase(const std::string &name); + bool erase(const RunTimeId &id); + + std::shared_ptr find_or_reg(const std::string &name, uint32_t version); + + RunTimeId to_id(const std::string &name) const; + std::string to_name(const RunTimeId &id) const; + + std::shared_ptr find(const std::string &name) const; + std::shared_ptr find(const RunTimeId &id) const; + + std::vector op_name_list(void); + std::vector op_id_list(void); +}; + +class CustomLib { + std::unique_ptr m_handle; + std::vector m_ops; + +public: + PREVENT_COPY_AND_ASSIGN(CustomLib); + + CustomLib(const std::string &path, int mode); + const std::vector &ops_in_lib(void) const; + ~CustomLib(); + bool valid(void) const; +}; + +using LibHandle = std::shared_ptr; + +class LibManager { + std::unordered_map m_custom_libs; + MGB_MUTEX m_mtx; + + LibManager() = default; + +public: + PREVENT_COPY_AND_ASSIGN(LibManager); + + static LibManager *inst(void); + const std::vector &install(const std::string &name, const std::string &path); + bool uninstall(const std::string &name); + friend class CustomOpManager; +}; + +} diff --git a/src/custom/include/megbrain/custom/op.h b/src/custom/include/megbrain/custom/op.h new file mode 100644 index 00000000..016de112 --- /dev/null +++ b/src/custom/include/megbrain/custom/op.h @@ -0,0 +1,109 @@ +/** + * \file src/custom/include/megbrain/custom/op.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "tensor.h" +#include "param.h" +#include + +#define PREVENT_COPY_AND_ASSIGN(Cls) \ + Cls(const Cls&) = delete; \ + Cls(const Cls&&) = delete; \ + Cls &operator=(const Cls&) = delete; \ + Cls &operator=(const Cls&&) = delete + +#define CUSTOM_OP_MAJOR 0 +#define CUSTOM_OP_MINOR 1 +#define CUSTOM_OP_PATCH 0 + +#define CUSTOM_OP_VERSION CUSTOM_OP_MAJOR*10000 + CUSTOM_OP_MINOR*100 + CUSTOM_OP_PATCH + +namespace custom { + +using RunTimeId = uint64_t; + +class ArgInfo { + CUSTOM_PIMPL_CLS_DECL(ArgInfo); + ArgInfo(const std::string &name, + const std::string &desc, + const std::unordered_set &dtypes, + const int &ndim, + const std::string &mem_stgy); + + const std::string &name(void) const; + const std::string &desc(void) const; + const std::unordered_set &dtypes(void) const; + int ndim(void) const; + const std::string &mem_strategy(void) const; + + std::string str() const; +}; + +class CustomOp { + std::unique_ptr m_impl; +public: + CustomOp(const std::string &op_type, uint32_t version); + PREVENT_COPY_AND_ASSIGN(CustomOp); + + using DeviceInferFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using ShapeInferFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using DTypeInferFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using FormatInferFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using PreprocessFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using PostprocessFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + using ComputeFuncPtr = void(*)(const std::vector&, const Param&, std::vector&); + + // write for forward + CustomOp &set_device_infer(DeviceInferFuncPtr func); + CustomOp &set_shape_infer(ShapeInferFuncPtr func); + CustomOp &set_dtype_infer(DTypeInferFuncPtr func); + CustomOp &set_format_infer(FormatInferFuncPtr func); + CustomOp &set_preprocess(PreprocessFuncPtr func); + CustomOp &set_preprocess(const std::string &device, PreprocessFuncPtr func); + CustomOp &set_postprocess(PostprocessFuncPtr func); + CustomOp &set_postprocess(const std::string &device, PostprocessFuncPtr func); + CustomOp &set_compute(ComputeFuncPtr func); + CustomOp &set_compute(const std::string &device, ComputeFuncPtr func); + + CustomOp &set_description(const std::string &op_desc); + CustomOp &add_input(const std::string &name, const std::string &desc, const std::initializer_list &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); + CustomOp &add_output(const std::string &name, const std::string &desc, const std::initializer_list &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); + CustomOp &add_input(const std::string &name, const std::initializer_list &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); + CustomOp &add_output(const std::string &name, const std::initializer_list &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); + CustomOp &add_inputs(const size_t &input_num); + CustomOp &add_outputs(const size_t &output_num); + CustomOp &add_param(const std::string &name, const ParamVal &default_val); + CustomOp &add_param(const std::string &name, const std::string &desc, const ParamVal &default_val); + + // read + std::string op_type(void) const; + std::string op_desc(void) const; + RunTimeId runtime_id(void) const; + size_t input_num(void) const; + size_t output_num(void) const; + std::string str(void) const; + + const ParamInfo ¶m_info(void) 
const; + ArgInfo input_info(size_t idx) const; + ArgInfo output_info(size_t idx) const; + const std::vector &inputs_info(void) const; + const std::vector &outputs_info(void) const; + + // use + std::vector infer_output_device(const std::vector&, const Param&) const; + std::vector infer_output_shape (const std::vector&, const Param&) const; + std::vector infer_output_dtype (const std::vector&, const Param&) const; + std::vector infer_output_format(const std::vector&, const Param&) const; + void compute(const std::vector&, const Param&, std::vector&) const; +}; + +} diff --git a/src/custom/include/megbrain/custom/param.h b/src/custom/include/megbrain/custom/param.h new file mode 100644 index 00000000..760ca2a4 --- /dev/null +++ b/src/custom/include/megbrain/custom/param.h @@ -0,0 +1,61 @@ +/** + * \file src/custom/include/megbrain/custom/param.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include "param_val.h" + +namespace custom { + +class ParamSchemaImpl; +class ParamInfoImpl; +class ParamImpl; + +// Schema of a param element +class ParamSchema { + CUSTOM_PIMPL_CLS_DECL(ParamSchema); + ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc=""); + + const std::string &name(void) const; + const std::string &desc(void) const; + const ParamVal &default_val(void) const; + ParamDynType type(void) const; + std::string str(void) const; +}; + +class ParamInfo { + CUSTOM_PIMPL_CLS_DECL(ParamInfo); + + void set_tag(const std::string&); + void set_meta(const std::vector &meta); + uint32_t tag(void) const; + std::vector &meta(void); + const std::vector &meta(void) const; +}; + +class Param { + CUSTOM_PIMPL_CLS_DECL(Param); + + Param(const ParamInfo&); + ParamVal &operator[](const std::string&); + const ParamVal &operator[](const std::string&) const; + const std::unordered_map &raw() const; + bool exist(const std::string &name) const; + std::string to_bytes(void) const; + void from_bytes(const std::string&); +}; + +bool operator==(const Param&, const Param&); + +} // custom diff --git a/src/custom/include/megbrain/custom/param_val.h b/src/custom/include/megbrain/custom/param_val.h new file mode 100644 index 00000000..122baa1f --- /dev/null +++ b/src/custom/include/megbrain/custom/param_val.h @@ -0,0 +1,290 @@ +/** + * \file src/custom/include/megbrain/custom/param_val.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +namespace custom { + +/** + * we can add a new basic data type here, basic means we can perform binary + * op such as: +, -, *, /, ==, != between any two of them + */ +#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) 
\ + cb(Int32, int32_t, ##__VA_ARGS__) \ + cb(Int64, int64_t, ##__VA_ARGS__) \ + cb(Uint32, uint32_t, ##__VA_ARGS__) \ + cb(Uint64, uint64_t, ##__VA_ARGS__) \ + cb(Float32, float, ##__VA_ARGS__) \ + cb(Float64, double, ##__VA_ARGS__) \ + cb(Bool, bool, ##__VA_ARGS__) + +#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) \ + cb(String, std::string, ##__VA_ARGS__) + +#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \ + cb(Int32List, std::vector, ##__VA_ARGS__) \ + cb(Int64List, std::vector, ##__VA_ARGS__) \ + cb(Uint32List, std::vector, ##__VA_ARGS__) \ + cb(Uint64List, std::vector, ##__VA_ARGS__) \ + cb(Float32List, std::vector, ##__VA_ARGS__) \ + cb(Float64List, std::vector, ##__VA_ARGS__) + +#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \ + cb(BoolList, std::vector, ##__VA_ARGS__) + +#define CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ...) \ + cb(StringList, std::vector, ##__VA_ARGS__) + +/** + * to avoid the recursive of MACRO + */ +#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \ + cb(Int32, int32_t, ##__VA_ARGS__) \ + cb(Int64, int64_t, ##__VA_ARGS__) \ + cb(Uint32, uint32_t, ##__VA_ARGS__) \ + cb(Uint64, uint64_t, ##__VA_ARGS__) \ + cb(Float32, float, ##__VA_ARGS__) \ + cb(Float64, double, ##__VA_ARGS__) \ + cb(Bool, bool, ##__VA_ARGS__) + +#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \ + CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) + +#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \ + CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ + CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) + +/** + * Macro Callback for Register + */ +#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, +#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) {ParamDynType::dyn_type, #dyn_type}, + +#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ + template <> \ + struct get_dyn_type { \ + static constexpr ParamDynType type = ParamDynType::dyn_type;\ + }; + +#define CUSTOM_REG_STATIC_PARAMTYPE_GETTER(dyn_type, static_type) \ + template <> \ + struct get_static_type { \ + using type = static_type; \ + }; + +enum class ParamDynType: uint32_t { + CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) + Invalid=255 +}; + +static std::unordered_map, EnumCmp> type2name = { + CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME) + {ParamDynType::Invalid, "Invalid"} +}; + +/** + * get the dynamic data type according to the builtin static data type + * we can use it like: + * ParamDynType dyn_type = get_dyn_type::type; + * assert(dyn_type == ParamDynType::Int32) + */ +template +struct get_dyn_type { + static constexpr ParamDynType type = ParamDynType::Invalid; +}; + +/** + * get the static data type according to the dynamic data type + * we can use it like: + * get_static_type::type int_32_value; + * assert(std::is_same::value) + */ +template +struct get_static_type; + +CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) +CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) + +#undef CUSTOM_REG_DYN_PARAMTYPE +#undef CUSTOM_REG_DYN_PARAMTYPE_NAME +#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER +#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER + +template +struct get_vector_template_arg_type; + +template +struct get_vector_template_arg_type> { + using type = 
std::decay_t; +}; + +template +struct is_vector { + static constexpr bool value = false; +}; + +template +struct is_vector > { + static constexpr bool value = true; +}; + +template +std::string vec2str(const std::vector &vec) { + std::stringstream ss; + ss << "{"; + for (const auto &val: vec) { + ss << val << ", "; + } + if (vec.size() != 0) { + ss.seekp(ss.tellp()-std::streampos(2)); + } + ss << "}"; + return ss.str(); +} + +/** + * we use void* rather than template to help us realise a complete dynamic type + * if we use template such as: + * template + * class ParamVal { + * T m_data; + * } + * Con1: user need to set the type explicitly when class template instantiation + * Con2: ParamVal can not be assigned to ParamVal + */ +class ParamVal { + std::unique_ptr m_ptr; + ParamDynType m_type; + +public: + template + ParamVal(const T &val); + template + ParamVal(const std::initializer_list &val); + + ParamVal(); + ParamVal(const char *str); + ParamVal(const std::initializer_list &strs); + ParamVal(const std::vector &strs); + ParamVal(const ParamVal &rhs); + + template + ParamVal &operator=(const T &rhs); + template + ParamVal &operator=(const std::initializer_list &val); + + ParamVal &operator=(const char *str); + ParamVal &operator=(const std::initializer_list &strs); + ParamVal &operator=(const std::vector &strs); + ParamVal &operator=(const ParamVal &rhs); + + template + const T &as(void) const; + template + T &as(void); + + const void *raw_ptr(void) const; + void *raw_ptr(void); + ParamDynType type(void) const; + std::string str(void) const; + size_t size(void) const; + + static std::string to_bytes(const ParamVal &value); + static ParamVal from_bytes(const std::string &bytes, size_t &offset); + + friend ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); + friend ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); + friend ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); + friend ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); + friend bool operator==(const ParamVal &lhs, const ParamVal &rhs); + friend bool operator!=(const ParamVal &lhs, const ParamVal &rhs); + friend bool operator> (const ParamVal &lhs, const ParamVal &rhs); + friend bool operator< (const ParamVal &lhs, const ParamVal &rhs); + friend bool operator>=(const ParamVal &lhs, const ParamVal &rhs); + friend bool operator<=(const ParamVal &lhs, const ParamVal &rhs); +}; + +ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); +ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); +ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); +ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); +bool operator==(const ParamVal &lhs, const ParamVal &rhs); +bool operator!=(const ParamVal &lhs, const ParamVal &rhs); +bool operator> (const ParamVal &lhs, const ParamVal &rhs); +bool operator< (const ParamVal &lhs, const ParamVal &rhs); +bool operator>=(const ParamVal &lhs, const ParamVal &rhs); +bool operator<=(const ParamVal &lhs, const ParamVal &rhs); + +template +ParamVal::ParamVal(const T &val): m_ptr(nullptr, impl_deleter>) { + using DecayType = std::decay_t; + m_type = get_dyn_type::type; + custom_assert(m_type != ParamDynType::Invalid, "param construct error! 
unsupported builtin type"); + m_ptr.reset(new DecayType(val)); +} + +template +ParamVal::ParamVal(const std::initializer_list &val): ParamVal(std::vector>(val)) { + +} + +template +ParamVal &ParamVal::operator=(const T &rhs) { + using DecayType = std::decay_t; + ParamDynType rhs_dyn_type = get_dyn_type::type; + custom_assert(rhs_dyn_type != ParamDynType::Invalid, "unsupported builtin dtype"); + + if (rhs_dyn_type == m_type) { + TypedRef(DecayType, m_ptr.get()) = rhs; + } + else { + m_type = rhs_dyn_type; + std::unique_ptr new_ptr(new DecayType(rhs), impl_deleter); + m_ptr.swap(new_ptr); + } + return *this; +} + +template +ParamVal &ParamVal::operator=(const std::initializer_list &val) { + return this->operator=(std::vector>(val)); +} + +template +const T &ParamVal::as(void) const { + return const_cast(this)->as(); +} + +template +T &ParamVal::as(void) { + using DecayType = std::decay_t; + ParamDynType t_dyn_type = get_dyn_type::type; + custom_assert( + t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n", + type2name[m_type].c_str(), type2name[t_dyn_type].c_str() + ); + return TypedRef(T, m_ptr.get()); +} + +} diff --git a/src/custom/include/megbrain/custom/tensor.h b/src/custom/include/megbrain/custom/tensor.h new file mode 100644 index 00000000..cdbbf07d --- /dev/null +++ b/src/custom/include/megbrain/custom/tensor.h @@ -0,0 +1,280 @@ +/** + * \file src/custom/include/megbrain/custom/tensor.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once + +#include +#include +#include "utils.h" +#include "accessor.h" + +namespace custom { + +#define CUSTOM_DATA_ADAPTOR_FRIEND_DECL \ + template \ + friend BuiltinT to_builtin(const CustomT &custom); \ + template \ + friend CustomT to_custom(const BuiltinT &builtin) + +#define CUSTOM_FOR_EACH_DEVICE_TYPE(cb) \ + cb(x86, CPU, "cpux") \ + cb(cuda, CUDA, "gpux") + +#define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) custom_type, + +class Device { + const void *impl() const; + Device(const void *impl); + CUSTOM_PIMPL_CLS_DECL(Device); + +public: + enum class DeviceEnum: uint32_t { + CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL) + }; + + Device(const std::string &device); + Device(const char *device); + Device(DeviceEnum device); + + std::string str(void) const; + DeviceEnum enumv(void) const; + + static bool is_legal(const std::string &device); + static bool is_legal(DeviceEnum device); + static std::vector legal_devices(void); + + friend class Tensor; + friend bool operator==(const Device &lhs, const Device &rhs); + CUSTOM_DATA_ADAPTOR_FRIEND_DECL; +}; + +using DeviceEnum = Device::DeviceEnum; + +bool operator==(const Device &lhs, const Device &rhs); + +class Shape { + const void *impl() const; + Shape(const void *impl); + CUSTOM_PIMPL_CLS_DECL(Shape); + +public: + Shape(const std::vector &rhs); + Shape(const std::initializer_list &rhs); + + size_t &operator[](size_t idx); + size_t operator[](size_t idx) const; + + void ndim(size_t dim); + size_t ndim(void) const; + + friend class Tensor; + friend bool operator==(const Shape &lhs, const Shape &rhs); + CUSTOM_DATA_ADAPTOR_FRIEND_DECL; +}; + +bool operator==(const Shape &lhs, const Shape &rhs); + +using float16_t = uint16_t; +using bfloat16_t = uint16_t; + +#if MEGDNN_DISABLE_FLOAT16 + #define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) +#else + #define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype) +#endif + +#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \ + cb(float32, Float32, float) \ + cb(uint8, Uint8, uint8_t) \ + cb(int8, Int8, int8_t) \ + cb(int16, Int16, int16_t) \ + cb(int32, Int32, int32_t) \ + fp16_wrap(cb, float16, Float16, float16_t) \ + fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \ + cb(uint16, Uint16, uint16_t) \ + cb(quint8, Quantized8Asymm, uint8_t) \ + cb(qint32, QuantizedS32, int32_t) \ + cb(qint8, QuantizedS8, int8_t) \ + cb(qint16, QuantizedS16, int16_t) + +#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type, + +class DType { + const void *impl() const; + DType(const void *impl); + CUSTOM_PIMPL_CLS_DECL(DType); + +public: + enum class DTypeEnum: uint32_t { + CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL) + }; + + DType(const std::string &dtype); + DType(const char *dtype); + DType(const std::string &dtype, float scale, uint8_t zero_point = 0); + DType(const char *dtype, float scale, uint8_t zero_point = 0); + DType(DTypeEnum dtype); + DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0); + + std::string str(void) const; + DTypeEnum enumv() const; + float scale(void) const; + uint8_t zero_point(void) const; + template + bool is_compatible(void) const; + + static bool is_legal(const std::string &dtype); + static bool is_legal(const DTypeEnum &dtype); + static std::vector legal_dtypes(void); + + friend class Tensor; + friend bool operator==(const DType &lhs, const DType &rhs); + CUSTOM_DATA_ADAPTOR_FRIEND_DECL; +}; + +using DTypeEnum = DType::DTypeEnum; + +template +struct DTypeTrait; + 
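+// Illustrative sketch (comment only, assumes the specializations generated by
+// CUSTOM_DEFINE_DTYPE_TRAIT below): DTypeTrait maps a DTypeEnum value back to its
+// C type, and Tensor::data<T>() relies on the same mapping through
+// DType::is_compatible<T>(), e.g.
+//
+//     using ctype = DTypeTrait<DTypeEnum::float32>::type;  // -> float
+//     ctype *ptr = tensor.data<ctype>();  // asserts the dtype really is float32
+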
+#define CUSTOM_DEFINE_DTYPE_TRAIT(custom_type, builtin_type, ctype) \ +template <> \ +struct DTypeTrait { \ + using type = ctype; \ +}; + +#define CUSTOM_CASE_TO_COMPARE_DTYPE(custom_type, builtin_type, ctype) \ + case (DTypeEnum::custom_type): { \ + return std::is_same::value; \ + } + +CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DEFINE_DTYPE_TRAIT) + +template +bool DType::is_compatible(void) const { + using DecayT = typename std::decay::type; + auto dtype_enum = enumv(); +#if !MEGDNN_DISABLE_FLOAT16 + if (dtype_enum == DTypeEnum::float16) { + return sizeof(DecayT) == sizeof(DTypeTrait::type); + } + else if (dtype_enum == DTypeEnum::bfloat16) { + return sizeof(DecayT) == sizeof(DTypeTrait::type); + } +#endif + switch (dtype_enum) { + CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_CASE_TO_COMPARE_DTYPE) + default: + return false; + } +} + +bool operator==(const DType &lhs, const DType &rhs); +bool operator==(const DType &lhs, const std::string &rhs); +bool operator==(const DType &lhs, const char *rhs); +bool operator==(const std::string &lhs, const DType &rhs); +bool operator==(const char *lhs, const DType &rhs); + +class Format { + const void *impl() const; + Format(const void *impl); + CUSTOM_PIMPL_CLS_DECL(Format); + +public: + Format(const std::string &format); + Format(const char *format); + + std::string str(void) const; + bool is_default(void) const; + + friend class Tensor; + CUSTOM_DATA_ADAPTOR_FRIEND_DECL; +}; + +class Tensor { + void *m_tensor; + + const void *impl(void) const; + Tensor(const void *impl); + + const size_t *shapes_raw(void) const; + const ptrdiff_t *strides_raw(void) const; + +public: + Tensor() = delete; + Tensor(const Tensor &rhs); + Tensor &operator=(const Tensor &rhs); + + Shape shape(void) const; + DType dtype(void) const; + Format format(void) const; + Device device(void) const; + + size_t size(void) const; + std::vector stride(void) const; + float scale(void) const; + uint8_t zero_point(void) const; + + void *data(void); + const void *data(void) const; + + template + T *data(void); + template + const T *data(void) const; + + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> + const TensorAccessor accessor() const; + + template class PtrTraits = DefaultPtrTraits, + typename index_t = int64_t> + TensorAccessor accessor(); + + CUSTOM_DATA_ADAPTOR_FRIEND_DECL; +}; + +template +T *Tensor::data(void) { + custom_assert(dtype().is_compatible(), + "invalid convert, tensor data type is %s", dtype().str().c_str()); + return reinterpret_cast(data()); +} +template +const T *Tensor::data(void) const { + return const_cast(this)->data(); +} + +template class PtrTraits, typename index_t> +const TensorAccessor Tensor::accessor() const { + return const_cast(this)->accessor(); +} + +template class PtrTraits, typename index_t> +TensorAccessor Tensor::accessor() { + custom_assert(N == shape().ndim(), + "cannot get a %lu-d accessor for a tensor with dim %lu", static_cast(N), static_cast(shape().ndim())); + custom_assert(N > 0, "cannot get 0-d accessor"); + + T *ptr = data(); + return TensorAccessor(ptr, shapes_raw(), strides_raw()); +} + +#undef CUSTOM_DATA_ADAPTOR_FRIEND_DECL +#undef CUSTOM_DEVICE_TYPE_ENUM_DECL +#undef CUSTOM_DTYPE_ENUM_DECL +#undef CUSTOM_DEFINE_DTYPE_TRAIT +#undef CUSTOM_CASE_TO_COMPARE_DTYPE + +} // custom diff --git a/src/custom/include/megbrain/custom/utils.h b/src/custom/include/megbrain/custom/utils.h new file mode 100644 index 00000000..19968ab0 --- /dev/null +++ b/src/custom/include/megbrain/custom/utils.h @@ -0,0 +1,104 @@ +/** + * 
\file src/custom/include/megbrain/custom/utils.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include +#include +#include +#include + +namespace custom { + +void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...); + +#define custom_expect(expr, msg...) \ + if (!(expr)) { \ + assert_failed_log( \ + __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ + ); \ + } + +#define custom_assert(expr, msg...) \ + if (!(expr)) { \ + assert_failed_log( \ + __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ + ); \ + } \ + assert((expr)) + +class UnImpleWarnLog { +public: + UnImpleWarnLog(const std::string &func, const std::string &attr, + const std::string &val); +}; + +using void_deleter = void(*)(void*); + +template +void impl_deleter(void *ptr) { + delete reinterpret_cast(ptr); +} + +#define TypedPtr(type, raw_ptr) reinterpret_cast(raw_ptr) +#define TypedRef(type, raw_ptr) (*reinterpret_cast(raw_ptr)) + +#define CUSTOM_PIMPL_CLS_DECL(Cls) \ + std::unique_ptr m_impl; \ + public: \ + Cls(); \ + Cls(const Cls &rhs); \ + Cls &operator=(const Cls &rhs) + +#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \ + Cls::Cls(): m_impl(new Cls##Impl(), impl_deleter) {} \ + \ + Cls::Cls(const Cls &rhs): m_impl(nullptr, impl_deleter) { \ + custom_assert( \ + rhs.m_impl != nullptr, \ + "invalid rhs for the copy constructor of %s", #Cls \ + ); \ + m_impl.reset(new Cls##Impl(TypedRef(Cls##Impl, rhs.m_impl.get()))); \ + } \ + \ + Cls &Cls::operator=(const Cls &rhs) { \ + custom_assert( \ + m_impl != nullptr && rhs.m_impl != nullptr, \ + "invalid assignment of %s, lhs or rhs is invalid", #Cls \ + ); \ + if (&rhs == this) \ + return *this; \ + \ + TypedRef(Cls##Impl, m_impl.get()) = TypedRef(Cls##Impl, rhs.m_impl.get()); \ + return *this; \ + } + +/** + * we define this two function explicitly used for std::unordered_map + * to improve the compatibility with different compiler versions +*/ +template +struct EnumHash { + size_t operator()(const T &rhs) const { + return static_cast(rhs); + } +}; + +template +struct EnumCmp { + bool operator()(const T &lhs, const T &rhs) const { + return static_cast(lhs) == static_cast(rhs); + } +}; + + +} // custom diff --git a/src/custom/test/manager.cpp b/src/custom/test/manager.cpp new file mode 100644 index 00000000..d4389237 --- /dev/null +++ b/src/custom/test/manager.cpp @@ -0,0 +1,96 @@ +/** + * \file src/custom/test/manager.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/custom/manager.h" +#include "megbrain/custom/custom.h" +#include "gtest/gtest.h" + +#define MANAGER_TEST_LOG 0 + +namespace custom { + +TEST(TestOpManager, TestOpManager) { + CustomOpManager *com = CustomOpManager::inst(); + com->insert("Op1", CUSTOM_OP_VERSION); + com->insert("Op2", CUSTOM_OP_VERSION); + std::shared_ptr ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION); + ASSERT_TRUE(ptr != nullptr); + + std::vector op_names = com->op_name_list(); + std::vector op_ids = com->op_id_list(); + + ASSERT_TRUE(op_names.size() == 3); + ASSERT_TRUE(op_ids.size() == 3); + +#if MANAGER_TEST_LOG + for (std::string &name: op_names) { + std::cout << name << std::endl; + } +#endif + + for (std::string &name: op_names) { + std::shared_ptr op = com->find(name); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(op->op_type() == name); + RunTimeId id = com->to_id(name); + ASSERT_TRUE(com->find(id) == op); + } + + for (RunTimeId &id: op_ids) { + std::shared_ptr op = com->find(id); + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(op->runtime_id() == id); + std::string name = com->to_name(id); + ASSERT_TRUE(com->find(name) == op); + } + + ASSERT_FALSE(com->erase("Op0")); +#if MANAGER_TEST_LOG + for (auto &name: com->op_name_list()) { + std::cout << name << std::endl; + } +#endif + ASSERT_TRUE(com->erase("Op1")); + ASSERT_TRUE(com->erase(com->to_id("Op2"))); + ASSERT_TRUE(com->op_id_list().size() == 1); + ASSERT_TRUE(com->op_name_list().size() == 1); + ASSERT_TRUE(com->op_name_list()[0] == "Op3"); + ptr.reset(); + ASSERT_TRUE(com->erase("Op3")); +} + +TEST(TestOpManager, TestOpReg) { + CUSTOM_OP_REG(Op1) + .add_inputs(2) + .add_outputs(3) + .add_input("lhs") + .add_param("param1", 1) + .add_param("param2", 3.45); + + CUSTOM_OP_REG(Op2) + .add_input("lhs") + .add_input("rhs") + .add_output("out") + .add_param("param1", "test") + .add_param("param2", true) + .add_param("", "no name"); + + (void)_Op1; + (void)_Op2; + +#if MANAGER_TEST_LOG + for (const auto &name: CustomOpManager::inst()->op_name_list()) { + std::cout << CustomOpManager::inst()->find(name)->str() << std::endl; + } +#endif +} + +} diff --git a/src/custom/test/op.cpp b/src/custom/test/op.cpp new file mode 100644 index 00000000..20449d52 --- /dev/null +++ b/src/custom/test/op.cpp @@ -0,0 +1,205 @@ +/** + * \file src/custom/test/op.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/custom/op.h" +#include "megbrain/comp_node.h" +#include "megbrain/tensor.h" +#include "megbrain/custom/data_adaptor.h" +#include "gtest/gtest.h" +#include "megbrain_build_config.h" + +#define OP_TEST_LOG 0 + +using namespace mgb; + +namespace custom { + +TEST(TestCustomOp, TestCustomOpInfoSetter) { + CustomOp test("TestOp", CUSTOM_OP_VERSION); + test.set_description("Test Op") + .add_input("lhs", "lhs of test op", {"float32", "int32"}, 2) + .add_inputs(2) + .add_input("rhs", "rhs of test op", {"float32", "int32"}, 2) + .add_outputs(1) + .add_output("out", "out of test op", {"float32", "int32"}, 2) + .add_outputs(3); + + ASSERT_TRUE(test.op_type() == "TestOp"); + ASSERT_TRUE(test.op_desc() == "Test Op"); + ASSERT_TRUE(test.input_num() == 4); + ASSERT_TRUE(test.output_num() == 5); + +#if OP_TEST_LOG + for (auto input: test.inputs_info()) { + std::cout << input.str() << std::endl; + } + for (auto output: test.outputs_info()) { + std::cout << output.str() << std::endl; + } +#endif + + test.add_param("param1", "param1 - float", 1.23f) + .add_param("param2", "param2 - float list", {2.34f, 3.45f}) + .add_param("param3", "param3 - string", "test-string") + .add_param("param4", {"test", "string", "list"}) + .add_param("param5", 1); + +#if OP_TEST_LOG + ParamInfo pinfo = test.param_info(); + for (auto kv: pinfo.meta()) { + std::cout << kv.str() << std::endl; + } +#endif +} + +void device_infer(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; + outputs[0] = inputs[1]; + outputs[1] = inputs[0]; +} + +void shape_infer(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; + outputs[0] = inputs[1]; + outputs[1] = inputs[0]; +} + +void dtype_infer(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; + outputs[0] = inputs[1]; + outputs[1] = inputs[0]; +} + +void format_infer(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; + outputs[0] = inputs[1]; + outputs[1] = inputs[0]; +} + +void cpu_kernel(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; +#if OP_TEST_LOG + std::cout << "Checking CPU Forward - " << params["device"].as() << std::endl; +#endif + ASSERT_TRUE(params["device"] == "x86"); +} + +void gpu_kernel(const std::vector &inputs, const Param ¶ms, + std::vector &outputs) { + (void)inputs; + (void)params; + (void)outputs; +#if OP_TEST_LOG + std::cout << "Checking GPU Forward - " << params["device"].as() << std::endl; +#endif + ASSERT_TRUE(params["device"] == "cuda"); +} + +TEST(TestCustomOp, TestCustomOpFuncSetter) { +#if MGB_CUDA + CustomOp test("TestOp", CUSTOM_OP_VERSION); + test.set_description("Test Op Forward Backward Union") + .add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2) + .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) + .add_output("outl", "outl of Test op", {"float32", "int32"}, 2) + .add_output("outr", "outr of Test op", {"float32", "int32"}, 2) + .add_param("smooth", "smooth", 0.f) + .add_param("device", "using for judge device", "x86"); + + std::vector idevices = {"x86", "cuda"}; + std::vector ishapes = {{2, 3}, {3, 4}}; + std::vector idtypes = {"int32", "float32"}; + std::vector iformats = {"default", "default"}; + Param param(test.param_info()); + + std::vector odevices = 
test.infer_output_device(idevices, param); + std::vector oshapes = test.infer_output_shape (ishapes, param); + std::vector odtypes = test.infer_output_dtype (idtypes, param); + std::vector oformats = test.infer_output_format(iformats, param); + + ASSERT_TRUE(odevices.size() == 2); + ASSERT_TRUE(oshapes.size() == 2); + ASSERT_TRUE(odtypes.size() == 2); + ASSERT_TRUE(oformats.size() == 2); + + ASSERT_TRUE(odevices[0] == "x86"); + ASSERT_TRUE(odevices[1] == "x86"); + ASSERT_TRUE(oshapes[0] == Shape({2,3})); + ASSERT_TRUE(oshapes[1] == Shape({2,3})); + ASSERT_TRUE(odtypes[0] == "int32"); + ASSERT_TRUE(odtypes[1] == "int32"); + ASSERT_TRUE(iformats[0].is_default()); + ASSERT_TRUE(iformats[1].is_default()); + + test.set_device_infer(device_infer) + .set_shape_infer(shape_infer) + .set_dtype_infer(dtype_infer) + .set_format_infer(format_infer); + + odevices = test.infer_output_device(idevices, param); + oshapes = test.infer_output_shape (ishapes, param); + odtypes = test.infer_output_dtype (idtypes, param); + oformats = test.infer_output_format(iformats, param); + + ASSERT_TRUE(odevices.size() == 2); + ASSERT_TRUE(oshapes.size() == 2); + ASSERT_TRUE(odtypes.size() == 2); + ASSERT_TRUE(oformats.size() == 2); + + ASSERT_TRUE(odevices[0] == "cuda"); + ASSERT_TRUE(odevices[1] == "x86"); + ASSERT_TRUE(oshapes[0] == Shape({3,4})); + ASSERT_TRUE(oshapes[1] == Shape({2,3})); + ASSERT_TRUE(odtypes[0] == "float32"); + ASSERT_TRUE(odtypes[1] == "int32"); + ASSERT_TRUE(iformats[0].is_default()); + ASSERT_TRUE(iformats[1].is_default()); + + test.set_compute(cpu_kernel); + DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); + DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); + DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); + DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); + + std::vector cinputs = {to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; + std::vector coutputs ={to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; + param["device"] = "x86"; + test.compute(cinputs, param, coutputs); + + test.set_compute("cuda", gpu_kernel); + DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); + DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); + DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); + DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); + + std::vector ginputs = {to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; + std::vector goutputs ={to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; + param["device"] = "cuda"; + test.compute(ginputs, param, goutputs); +#endif +} + +} diff --git a/src/custom/test/param.cpp b/src/custom/test/param.cpp new file mode 100644 index 00000000..7f22ead8 --- /dev/null +++ b/src/custom/test/param.cpp @@ -0,0 +1,208 @@ +/** + * \file src/custom/test/param.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/custom/param.h" +#include "gtest/gtest.h" +#include + +#define PARAM_TEST_LOG 0 + +namespace custom { + +#define SchemaDef \ + ParamSchema schema_bool("param_bool", true, "bool"); \ + ParamSchema schema_flt("param_flt", 2.3f, "float"); \ + ParamSchema schema_int("param_int", 4, "int"); \ + ParamSchema schema_str("param_str", "test", "string"); \ + ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \ + ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list"); \ + ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list"); \ + ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list") + +#define InfoDef \ + info.meta().emplace_back(schema_bool); \ + info.meta().emplace_back(schema_flt); \ + info.meta().emplace_back(schema_int); \ + info.meta().emplace_back(schema_str); \ + info.meta().emplace_back(schema_bool_list); \ + info.meta().emplace_back(schema_flt_list); \ + info.meta().emplace_back(schema_int_list); \ + info.meta().emplace_back(schema_str_list) + +TEST(TestParam, TestParamScheme) { +#if PARAM_TEST_LOG + SchemaDef; + ParamSchema new_schema = schema_int; + + std::cout << schema_bool.str() << std::endl; + std::cout << schema_flt.str() << std::endl; + std::cout << schema_int.str() << std::endl; + std::cout << schema_str.str() << std::endl; + std::cout << schema_bool_list.str() << "len: "<< schema_bool_list.default_val().size() << std::endl; + std::cout << schema_flt_list.str() << "len: "<< schema_flt_list.default_val().size() << std::endl; + std::cout << schema_int_list.str() << "len: "<< schema_int_list.default_val().size() << std::endl; + std::cout << schema_str_list.str() << "len: "<< schema_str_list.default_val().size() << std::endl; + + std::cout << new_schema.str() << std::endl; +#endif +} + +TEST(TestParam, TestParamVal) { + ParamVal pv1 = 1.2f, pv2 = true, pv3 = "test", pv4 = {0, 1, 2}, + pv5 = {true, false, true}; + +#if PARAM_TEST_LOG + ParamVal pv6 = {"test1", "test2", "test3"}; + std::cout << pv1.str() << std::endl; + std::cout << pv2.str() << std::endl; + std::cout << pv3.str() << std::endl; + std::cout << pv4.str() << std::endl; + std::cout << pv5.str() << std::endl; + std::cout << pv6.str() << std::endl; +#endif + + ParamVal pv_manip = pv1; + ASSERT_TRUE(pv_manip.type() == pv1.type()); + ASSERT_TRUE(pv_manip == pv1); + pv_manip = 1.3; + ASSERT_TRUE(pv_manip.type() != pv1.type()); + ASSERT_TRUE(pv_manip != pv1); + ASSERT_TRUE(pv_manip > pv1); + pv_manip = pv_manip + pv1; + ASSERT_TRUE(pv_manip.type() == ParamDynType::Float64); + ASSERT_TRUE(pv_manip == 1.3 + 1.2f); + pv_manip = 1.3f + 1.2f; + ASSERT_TRUE(pv_manip.type() == pv1.type()); + + pv_manip = false; + ASSERT_TRUE(pv_manip.type() == pv2.type()); + ASSERT_TRUE(pv_manip.type() == ParamDynType::Bool); + ASSERT_TRUE(pv_manip != pv2); + + pv_manip = "test"; + ASSERT_TRUE(pv_manip.type() == pv3.type()); + ASSERT_TRUE(pv_manip.type() == ParamDynType::String); + ASSERT_TRUE(pv_manip == pv3); + pv_manip = "test1"; + ASSERT_TRUE(pv_manip > pv3); + pv_manip = pv_manip + pv3; + ASSERT_TRUE(pv_manip == "test1test"); + + pv_manip = {0, 1, 2}; + ASSERT_TRUE(pv_manip.type() == pv4.type()); + ASSERT_TRUE(pv_manip.type() == ParamDynType::Int32List); + ASSERT_TRUE(pv_manip == pv4); + pv_manip = {3, 2, 1}; + ASSERT_TRUE(pv_manip != pv4); + ASSERT_TRUE(pv_manip > pv4); + + pv_manip = {true, false, true}; + ASSERT_TRUE(pv_manip.type() == pv5.type()); + ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList); + ASSERT_TRUE(pv_manip == pv5); + 
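// Sketch (mirrors the assignments above, not part of the original test): rebuilding
+ // from an initializer list of a different length keeps the BoolList dynamic type;
+ // only the payload changes.
+ pv_manip = {true, true};
+ ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList);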
pv_manip = {false, true, false}; + ASSERT_TRUE(pv_manip != pv5); + +} + +TEST(TestParam, TestParamInfo) { + ParamInfo info; + info.set_tag("Test"); +#if PARAM_TEST_LOG + uint32_t tag = info.tag(); + std::cout << tag << std::endl; +#endif + + SchemaDef; + InfoDef; + + ParamInfo new_info1, new_info2; + new_info1.set_meta(info.meta()); + new_info2.meta() = info.meta(); + +#if PARAM_TEST_LOG + for (auto ele: new_info1.meta()) { + std::cout << ele.str() << std::endl; + } + for (auto ele: new_info2.meta()) { + std::cout << ele.str() << std::endl; + } +#endif +} + +TEST(TestParam, TestParam) { + ParamInfo info; + SchemaDef; + InfoDef; + + Param param(info); + +#if PARAM_TEST_LOG + std::vector names = {"param_bool", "param_flt", "param_int", "param_str", "param_bl", "param_fl", "param_il", "param_sl"}; + for (auto &name: names) { + std::cout << param[name].str() << std::endl;; + } +#endif + ASSERT_TRUE(param["param_bool"] == true); + ASSERT_TRUE(param["param_flt"] == 2.3f); + ASSERT_TRUE(param["param_int"] == 4); + ASSERT_TRUE(param["param_str"] == "test"); + ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true})); + ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f})); + ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3})); + ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"})); + + param["param_bool"] = false; + param["param_flt"] = 3.4f; + param["param_int"] = 5; + param["param_str"] = "tset"; + param["param_bl"] = {false, true, false, true}; + param["param_fl"] = {7.6f, 6.5f}; + param["param_il"] = {5, 4, 3, 2, 1}; + param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"}; + + ASSERT_TRUE(param["param_bool"] != true); + ASSERT_TRUE(param["param_flt"] != 2.3f); + ASSERT_TRUE(param["param_int"] != 4); + ASSERT_TRUE(param["param_str"] != "test"); + ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true})); + ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f})); + ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3})); + ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"})); + + ASSERT_TRUE(param["param_bool"] == false); + ASSERT_TRUE(param["param_flt"] == 3.4f); + ASSERT_TRUE(param["param_int"] == 5); + ASSERT_TRUE(param["param_str"] == "tset"); + ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true})); + ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f})); + ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1})); + ASSERT_TRUE(param["param_sl"] == ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"})); + +#if PARAM_TEST_LOG + Param copy_param = param; + for (auto &name: names) { + std::cout << copy_param[name].str() << std::endl; + } +#endif + + Param loaded_param(info); + std::string bytes = param.to_bytes(); + loaded_param.from_bytes(bytes); + +#if PARAM_TEST_LOG + for (auto &kv: loaded_param.raw()) { + std::cout << kv.first << ":\n" << kv.second.str() << std::endl; + } +#endif +} + +} diff --git a/src/custom/test/tensor.cpp b/src/custom/test/tensor.cpp new file mode 100644 index 00000000..d9d4366c --- /dev/null +++ b/src/custom/test/tensor.cpp @@ -0,0 +1,325 @@ +/** + * \file src/custom/test/tensor.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/custom/tensor.h" +#include "megbrain/custom/data_adaptor.h" +#include "megbrain/comp_node.h" +#include "megbrain/tensor.h" +#include "gtest/gtest.h" +#include "megbrain_build_config.h" + +#define TENSOR_TEST_LOG 0 + +using namespace mgb; + +namespace custom { + +TEST(TestDevice, TestDevice) { +#if MGB_CUDA + ASSERT_TRUE(Device::is_legal("x86")); + ASSERT_TRUE(Device::is_legal(DeviceEnum::cuda)); + ASSERT_FALSE(Device::is_legal("cpu")); + + Device dev1; + ASSERT_TRUE(dev1.str() == "invalid"); + + dev1 = "x86"; + ASSERT_TRUE("x86" == dev1); + + Device dev2 = "cuda"; + ASSERT_TRUE(dev2 == "cuda"); + ASSERT_FALSE(dev2 == dev1); + + Device dev3 = dev2; + ASSERT_TRUE(dev3 == dev2); + ASSERT_FALSE(dev3 == dev1); + + Device dev4 = DeviceEnum::cuda; + ASSERT_TRUE(dev4.enumv() == DeviceEnum::cuda); + +#if TENSOR_TEST_LOG + std::cout << dev1.str() << "\n" << dev2.str() << "\n" + << dev3.str() << "\n" << dev4.str() << std::endl; +#endif + + CompNode compnode = to_builtin(dev3); + ASSERT_TRUE(compnode.to_string_logical() == "gpux:0"); + compnode = CompNode::load("cpu0:0"); + Device dev5 = to_custom(compnode); + ASSERT_TRUE(dev5.str() == "x86"); + + std::vector devs1 = {"x86", "cuda", "x86"}; + megdnn::SmallVector compnodes = to_builtin(devs1); + ASSERT_TRUE(compnodes[0].to_string_logical() == "cpux:0"); + ASSERT_TRUE(compnodes[1].to_string_logical() == "gpux:0"); + ASSERT_TRUE(compnodes[2].to_string_logical() == "cpux:0"); + + std::vector devs2 = to_custom(compnodes); + ASSERT_TRUE(devs2[0] == "x86"); + ASSERT_TRUE(devs2[1].str() == "cuda"); + ASSERT_TRUE(devs2[2] == "x86"); +#endif +} + +TEST(TestShape, TestShape) { + Shape shape1, shape2; + ASSERT_TRUE(shape1.ndim() == 0); + + shape1 = {16, 32, 8, 8}; + shape2 = shape1; + ASSERT_TRUE(shape2.ndim() == 4); + ASSERT_TRUE(shape2[0] == 16); + ASSERT_TRUE(shape2[1] == 32); + ASSERT_TRUE(shape2[2] == 8); + ASSERT_TRUE(shape2[3] == 8); + + Shape shape3 = {16, 32, 8, 8}; + const Shape shape4 = shape1; + ASSERT_TRUE(shape3 == shape4); + shape3[0] = 32; + ASSERT_FALSE(shape3 == shape4); + ASSERT_TRUE(shape3[0] == 32); + ASSERT_TRUE(shape4[0] == 16); + + Shape shape5 = {2, 3, 4}; + TensorShape bshape1 = to_builtin(shape5); + ASSERT_TRUE(bshape1.ndim == 3); + ASSERT_TRUE(bshape1[0] == 2); + ASSERT_TRUE(bshape1[1] == 3); + ASSERT_TRUE(bshape1[2] == 4); + bshape1 = {4, 2, 3}; + Shape shape6 = to_custom(bshape1); + ASSERT_TRUE(shape6.ndim() == 3); + ASSERT_TRUE(shape6[0] == 4); + ASSERT_TRUE(shape6[1] == 2); + ASSERT_TRUE(shape6[2] == 3); + + Shape shape7; + shape7.ndim(3); + shape7[1] = 4; + ASSERT_TRUE(shape7 == Shape({0, 4, 0})); + + std::vector shapes1 = {{2, 3, 4}, {6}, {5, 7}}; + megdnn::SmallVector bshapes = to_builtin(shapes1); + ASSERT_TRUE(bshapes[0].total_nr_elems() == 2*3*4); + ASSERT_TRUE(bshapes[1].total_nr_elems() == 6); + ASSERT_TRUE(bshapes[2].total_nr_elems() == 35); + + std::vector shapes2 = to_custom(bshapes); + ASSERT_TRUE(shapes2[0] == Shape({2, 3, 4})); + ASSERT_TRUE(shapes2[1] == Shape({6})); + ASSERT_TRUE(shapes2[2] == Shape({5, 7})); +} + +TEST(TestDType, TestDType) { +#if !MEGDNN_DISABLE_FLOAT16 + ASSERT_TRUE(DType::is_legal("uint8")); + ASSERT_TRUE(DType::is_legal(DTypeEnum::bfloat16)); + + DType dtype1, dtype2; + ASSERT_TRUE(dtype1.str() == "invalid"); + + dtype1 = "float32"; + ASSERT_TRUE(dtype1.str() == "float32"); + + dtype2 = dtype1; + DType dtype3 = dtype2; + ASSERT_TRUE(dtype3 == dtype1); + ASSERT_TRUE(dtype3 == "float32"); + + dtype3 = "int8"; + ASSERT_FALSE("float32" == dtype3.str()); + 
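// Sketch (not part of the original test): the comparison overloads also accept an
+ // enum-constructed DType, so the same dtype compares equal no matter how it was
+ // built (string literal vs. DTypeEnum).
+ ASSERT_TRUE(dtype3 == DType(DTypeEnum::int8));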
+    ASSERT_FALSE(dtype3 == dtype2);
+
+    DType dtype4 = DTypeEnum::int8, dtype5 = dtype3;
+    ASSERT_TRUE(dtype4 == dtype5);
+    ASSERT_TRUE(dtype4.is_compatible<int8_t>());
+    ASSERT_FALSE(dtype4.is_compatible<int32_t>());
+
+    DType dtype6 = "int32";
+    megdnn::DType bdtype1 = to_builtin(dtype6);
+    ASSERT_TRUE(bdtype1.name() == std::string("Int32"));
+    bdtype1 = megdnn::DType::from_enum(megdnn::DTypeEnum::BFloat16);
+    DType dtype7 = to_custom(bdtype1);
+    ASSERT_TRUE(dtype7.enumv() == DTypeEnum::bfloat16);
+
+    std::vector<DType> dtypes1 = {"int8", "uint8", "float16"};
+    megdnn::SmallVector<megdnn::DType> bdtypes
+        = to_builtin(dtypes1);
+    ASSERT_TRUE(bdtypes[0].name() == std::string("Int8"));
+    ASSERT_TRUE(bdtypes[1].name() == std::string("Uint8"));
+    ASSERT_TRUE(bdtypes[2].name() == std::string("Float16"));
+
+    std::vector<DType> dtypes2 = to_custom(bdtypes);
+    ASSERT_TRUE(dtypes2[0] == "int8");
+    ASSERT_TRUE(dtypes2[1] == "uint8");
+    ASSERT_TRUE(dtypes2[2] == "float16");
+#endif
+}
+
+TEST(TestDType, TestDTypeQuantized) {
+    DType quint8_1("quint8", 3.2, 15);
+    DType quint8_2("quint8", 3.2, 15);
+    DType quint8_3("quint8", 3.2, 16);
+    DType quint8_4("quint8", 3.1, 15);
+
+    ASSERT_TRUE(quint8_1 == quint8_2);
+    ASSERT_FALSE(quint8_1 == quint8_3);
+    ASSERT_FALSE(quint8_1 == quint8_4);
+
+    ASSERT_TRUE(quint8_1.scale() == 3.2f);
+    ASSERT_TRUE(quint8_1.zero_point() == 15);
+
+    DType qint8("qint8", 3.3f);
+    DType qint16("qint16", 3.4f);
+    DType qint32("qint32", 3.5f);
+
+    ASSERT_TRUE(qint8.scale() == 3.3f);
+    ASSERT_TRUE(qint16.scale() == 3.4f);
+    ASSERT_TRUE(qint32.scale() == 3.5f);
+
+    ASSERT_TRUE(qint8.enumv() == DTypeEnum::qint8);
+    ASSERT_TRUE(qint8.str() == "qint8");
+}
+
+TEST(TestFormat, TestFormat) {
+    Format format1, format2("default");
+    ASSERT_TRUE(format1.is_default());
+    ASSERT_TRUE(format2.is_default());
+    Format format3 = format1;
+    ASSERT_TRUE(format3.is_default());
+}
+
+TEST(TestTensor, TestTensor) {
+    CompNode builtin_device = CompNode::load("cpux:0");
+    TensorShape builtin_shape = {3, 2, 4};
+    megdnn::DType builtin_dtype = dtype::Int32{};
+
+    DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
+    Tensor tensor1 = to_custom(dev_tensor);
+    Tensor tensor2 = to_custom(dev_tensor);
+    Device device = tensor1.device();
+    Shape shape = tensor1.shape();
+    DType dtype = tensor1.dtype();
+
+    ASSERT_TRUE(device == "x86");
+    ASSERT_TRUE(shape.ndim() == 3);
+    ASSERT_TRUE(shape[0] == 3);
+    ASSERT_TRUE(shape[1] == 2);
+    ASSERT_TRUE(shape[2] == 4);
+    ASSERT_TRUE(shape == std::vector<size_t>({3, 2, 4}));
+    ASSERT_TRUE(dtype == "int32");
+
+    int *raw_ptr1 = tensor1.data<int>();
+    for (size_t i=0; i<tensor1.size(); ++i)
+        raw_ptr1[i] = i;
+
+    int *raw_ptr2 = tensor2.data<int>();
+    for (size_t i=0; i<tensor2.size(); ++i)
+        ASSERT_TRUE(raw_ptr2[i] == static_cast<int>(i));
+
+    Tensor tensor3 = tensor2;
+    int *raw_ptr3 = tensor3.data<int>();
+    for (size_t i=0; i<tensor3.size(); ++i)
+        ASSERT_TRUE(raw_ptr3[i] == static_cast<int>(i));
+    ASSERT_TRUE(raw_ptr1 == raw_ptr2);
+    ASSERT_TRUE(raw_ptr1 == raw_ptr3);
+
+    for (size_t i=0; i<tensor3.size(); ++i) {
+        raw_ptr3[i] = static_cast<int>(i);
+    }
+    for (size_t i=0; i<tensor1.size(); ++i) {
+        ASSERT_TRUE(raw_ptr1[i] == static_cast<int>(i));
+    }
+
+    DeviceTensorND new_dev_tensor = to_builtin(tensor3);
+
+    int *builtin_ptr = new_dev_tensor.ptr<int>();
+    for (size_t i=0; i<tensor3.size(); ++i) {
+        ASSERT_TRUE(builtin_ptr[i] == static_cast<int>(i));
+    }
+}
+
+TEST(TestTensor, TestTensorQuantized) {
+#if MGB_CUDA
+    CompNode builtin_device = CompNode::load("gpux:0");
+    TensorShape builtin_shape = {3, 2, 4};
+    megdnn::DType builtin_dtype = dtype::Quantized8Asymm{3.2f, uint8_t(15)};
+
+    DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
+
+    Tensor tensor1 = to_custom(dev_tensor);
+    Tensor tensor2 = to_custom(dev_tensor);
+    Device device1 = tensor1.device(), device2 = tensor2.device();
+    Shape shape1 = tensor1.shape(), shape2 = tensor2.shape();
+    DType dtype1 = tensor1.dtype(), dtype2 = tensor2.dtype();
+
+    ASSERT_TRUE(device1 == "cuda");
+    ASSERT_TRUE(shape1.ndim() == 3);
+    ASSERT_TRUE(shape1[0] == 3);
+    ASSERT_TRUE(shape1[1] == 2);
+    ASSERT_TRUE(shape1[2] == 4);
+    ASSERT_TRUE(shape1 == std::vector<size_t>({3, 2, 4}));
+    ASSERT_TRUE(dtype1 == "quint8");
+    ASSERT_TRUE(dtype1.scale() == 3.2f);
+    ASSERT_TRUE(dtype1.zero_point() == 15);
+
+    ASSERT_TRUE(device1 == device2);
+    ASSERT_TRUE(shape1 == shape2);
+    ASSERT_TRUE(dtype1 == dtype2);
+#endif
+}
+
+TEST(TestTensor, TestTensorAccessorND) {
+    size_t N = 2, C = 4, H = 6, W = 8;
+    CompNode builtin_device = CompNode::load("cpux");
+    TensorShape builtin_shape = {N, C, H, W};
+    megdnn::DType builtin_dtype = dtype::Int32{};
+
+    DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype);
+    int *builtin_ptr = dev_tensor.ptr<int>();
+    for (size_t i=0; i<N*C*H*W; ++i)
+        builtin_ptr[i] = i;
+
+    Tensor tensor = to_custom(dev_tensor);
+    auto accessor_nd = tensor.accessor<int32_t, 4>();
+    for (size_t n=0; n<N; ++n)
+        for (size_t c=0; c<C; ++c)
+            for (size_t h=0; h<H; ++h)
+                for (size_t w=0; w<W; ++w)
+                    ASSERT_TRUE(accessor_nd[n][c][h][w] ==
+                            static_cast<int>(((n * C + c) * H + h) * W + w));
+
+    TensorShape builtin_shape_1d = {32};
+    DeviceTensorND dev_tensor_1d(builtin_device, builtin_shape_1d, builtin_dtype);
+    int *builtin_ptr_1d = dev_tensor_1d.ptr<int>();
+    for (size_t i=0; i<32; ++i)
+        builtin_ptr_1d[i] = i;
+
+    Tensor tensor_1d = to_custom(dev_tensor_1d);
+    auto accessor = tensor_1d.accessor<int32_t, 1>();
+    for (size_t n=0; n<32; ++n) {
+        ASSERT_TRUE(accessor[n] == n);
+    }
+}
+
+}
diff --git a/src/opr/impl/custom_opnode.cpp b/src/opr/impl/custom_opnode.cpp
index 2408a7b5..68a627b1 100644
--- a/src/opr/impl/custom_opnode.cpp
+++ b/src/opr/impl/custom_opnode.cpp
@@ -18,7 +18,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);
 
 void CustomOpNode::infer_output_comp_node(void) {
     SmallVector<CompNode> input_comp_nodes(input_num());
-    for (int i=0; i<input_num(); i++) {
+    for (size_t i=0; i<input_num(); i++) {
         input_comp_nodes[i] = input(i)->comp_node();
     }
@@ -28,7 +28,7 @@ void CustomOpNode::infer_output_comp_node(void) {
         )
     );
 
-    for (int i=0; i<output_num(); i++) {
+    for (size_t i=0; i<output_num(); i++) {
         output(i)->comp_node(output_comp_nodes[i]);
     }
@@ -39,7 +39,7 @@ void CustomOpNode::infer_output_dtype(void) {
     SmallVector<DType> input_dtypes(input_num());
-    for (int i=0; i<input_num(); i++) {
+    for (size_t i=0; i<input_num(); i++) {
         input_dtypes[i] = input(i)->dtype();
     }
@@ -49,14 +49,14 @@
         )
     );
 
-    for (int i=0; i<output_num(); i++) {
+    for (size_t i=0; i<output_num(); i++) {
         output(i)->dtype(output_dtypes[i]);
     }
 }
 
 void CustomOpNode::infer_output_format(void) {
     SmallVector<TensorFormat> input_formats(input_num());
-    for (int i=0; i<input_num(); i++) {
+    for (size_t i=0; i<input_num(); i++) {
         input_formats[i] = input(i)->format();
     }
@@ -66,14 +66,14 @@
         )
     );
 
-    for (int i=0; i<output_num(); i++) {
+    for (size_t i=0; i<output_num(); i++) {
         output(i)->format(output_formats[i]);
     }
 }
 
 void CustomOpNode::infer_output_shape(void) {
     SmallVector<TensorShape> input_shapes(input_num());
-    for (int i=0; i<input_num(); i++) {
+    for (size_t i=0; i<input_num(); i++) {
         input_shapes[i] = input(i)->shape();
     }
@@ -83,7 +83,7 @@
         )
     );
 
-    for (int i=0; i<output_num(); i++) {
+    for (size_t i=0; i<output_num(); i++) {
         output(i)->shape(output_shapes[i]);
     }
 }
@@ -235,10 +235,10 @@ CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op,
                            const OperatorNodeConfig &config):
         OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs),
         m_op(op), m_param(param) {
     mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
-    for (int i=0; i < input_num(); ++i)
+    for (size_t i=0; i < input_num(); ++i)
         add_input({inputs[i]});
-    for (int i=0; i < output_num(); ++i)
+    for (size_t i=0; i < output_num(); ++i)
@@ -306,11 +306,11 @@ std::string CustomOpNode::op_desc(void) const {
     return m_op->op_desc();
 }
 
-int CustomOpNode::input_num(void) const {
+size_t CustomOpNode::input_num(void) const {
     return m_op->input_num();
 }
 
-int CustomOpNode::output_num(void) const {
+size_t CustomOpNode::output_num(void) const {
     return m_op->output_num();
 }
diff --git a/src/opr/include/megbrain/opr/custom_opnode.h b/src/opr/include/megbrain/opr/custom_opnode.h
index 2480d801..8385626d 100644
--- a/src/opr/include/megbrain/opr/custom_opnode.h
+++ b/src/opr/include/megbrain/opr/custom_opnode.h
@@ -93,8 +93,8 @@ public:
     custom::Param param(void) const;
     std::string op_type(void) const;
     std::string op_desc(void) const;
-    int input_num(void) const;
-    int output_num(void) const;
+    size_t input_num(void) const;
+    size_t output_num(void) const;
     custom::ArgInfo input_info(size_t idx) const;
     custom::ArgInfo output_info(size_t idx) const;
 };
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 2b13c9b4..fbea06a3 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,7 +1,7 @@
 include_directories("./src/include")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
 
-file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp)
+file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp ../src/custom/test/*.cpp)
 
 if(MGE_WITH_JIT)
   file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp)
   list(APPEND SOURCES ${SOURCES_})
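
Usage note on the new param tests: the serialization path they exercise can be reduced to a minimal standalone sketch. This is not part of the patch; it assumes the same SchemaDef/InfoDef setup macros defined in src/custom/test/param.cpp and only uses calls that appear in the diff (operator[], to_bytes(), from_bytes()).

    // Minimal sketch: round-trip a custom::Param through to_bytes()/from_bytes(),
    // mirroring the tail of TEST(TestParam, TestParam) above.
    TEST(TestParam, TestParamBytesRoundTrip) {
        ParamInfo info;
        SchemaDef;                              // schema entries, macro from param.cpp
        InfoDef;                                // default values, macro from param.cpp

        Param param(info);
        param["param_int"] = 7;                 // mutate one entry before serializing

        std::string bytes = param.to_bytes();   // serialize every param value
        Param loaded_param(info);
        loaded_param.from_bytes(bytes);         // restore into a fresh Param

        ASSERT_TRUE(loaded_param["param_int"] == 7);
    }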