GitOrigin-RevId: 4ade455897
release-1.6
@@ -613,17 +613,17 @@ void init_ops(py::module m) { | |||||
} | } | ||||
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \ | #define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \ | ||||
case mgb::custom::ParamDynType::dyn_type: { \ | |||||
case custom::ParamDynType::dyn_type: { \ | |||||
param_val = py::handle(kv.second).cast<static_type>(); \ | param_val = py::handle(kv.second).cast<static_type>(); \ | ||||
break; \ | break; \ | ||||
} | } | ||||
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \ | #define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \ | ||||
case mgb::custom::ParamDynType::dyn_type: { \ | |||||
case custom::ParamDynType::dyn_type: { \ | |||||
auto pyvals = py::handle(kv.second).cast<py::list>(); \ | auto pyvals = py::handle(kv.second).cast<py::list>(); \ | ||||
static_type vals; \ | static_type vals; \ | ||||
using basic_type = \ | using basic_type = \ | ||||
mgb::custom::get_vector_template_arg_type<static_type>::type; \ | |||||
custom::get_vector_template_arg_type<static_type>::type; \ | |||||
for (auto &pyval: pyvals) { \ | for (auto &pyval: pyvals) { \ | ||||
vals.push_back(py::handle(pyval).cast<basic_type>()); \ | vals.push_back(py::handle(pyval).cast<basic_type>()); \ | ||||
} \ | } \ | ||||
@@ -631,7 +631,7 @@ void init_ops(py::module m) { | |||||
break; \ | break; \ | ||||
} | } | ||||
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) { | |||||
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) { | |||||
auto op_name = py::handle(args[0]).cast<std::string>(); | auto op_name = py::handle(args[0]).cast<std::string>(); | ||||
auto kwargs = py::handle(args[1]).cast<py::dict>(); | auto kwargs = py::handle(args[1]).cast<py::dict>(); | ||||
@@ -680,7 +680,7 @@ PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyOb | |||||
py::list install_custom(const std::string &name, const std::string &path) { | py::list install_custom(const std::string &name, const std::string &path) { | ||||
py::list ret; | py::list ret; | ||||
const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path); | |||||
const auto &ops_in_lib = custom::LibManager::inst()->install(name, path); | |||||
for (const auto &op: ops_in_lib) { | for (const auto &op: ops_in_lib) { | ||||
ret.append(op); | ret.append(op); | ||||
} | } | ||||
@@ -688,7 +688,7 @@ py::list install_custom(const std::string &name, const std::string &path) { | |||||
} | } | ||||
bool uninstall_custom(const std::string &name) { | bool uninstall_custom(const std::string &name) { | ||||
return mgb::custom::LibManager::inst()->uninstall(name); | |||||
return custom::LibManager::inst()->uninstall(name); | |||||
} | } | ||||
py::list get_custom_op_list(void) { | py::list get_custom_op_list(void) { | ||||
@@ -697,16 +697,28 @@ py::list get_custom_op_list(void) { | |||||
for (auto &op: all_ops) { | for (auto &op: all_ops) { | ||||
ret.append(op); | ret.append(op); | ||||
} | } | ||||
return std::move(ret); | |||||
return ret; | |||||
} | } | ||||
#ifndef METH_FASTCALL | |||||
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) { | |||||
auto* arr = &PyTuple_GET_ITEM(args, 0); | |||||
auto size = PyTuple_GET_SIZE(args); | |||||
return make_custom_op(self, arr, size); | |||||
}; | |||||
#endif | |||||
void init_custom(pybind11::module m) { | void init_custom(pybind11::module m) { | ||||
m.def("_install", &install_custom); | m.def("_install", &install_custom); | ||||
m.def("_uninstall", &uninstall_custom); | m.def("_uninstall", &uninstall_custom); | ||||
m.def("_get_custom_op_list", &get_custom_op_list); | m.def("_get_custom_op_list", &get_custom_op_list); | ||||
static PyMethodDef method_def = { | static PyMethodDef method_def = { | ||||
#ifdef METH_FASTCALL | |||||
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | ||||
#else | |||||
"_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, "" | |||||
#endif | |||||
}; | }; | ||||
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | ||||
pybind11::setattr(m, method_def.ml_name, func); | pybind11::setattr(m, method_def.ml_name, func); | ||||
@@ -70,7 +70,7 @@ void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs, | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( | std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( | ||||
const SmallVector<TensorPtr> &inputs) const { | const SmallVector<TensorPtr> &inputs) const { | ||||
SmallVector<LogicalTensorDesc> input_descs(inputs.size()); | SmallVector<LogicalTensorDesc> input_descs(inputs.size()); | ||||
for (int i=0; i<inputs.size(); i++) { | |||||
for (size_t i=0; i<inputs.size(); i++) { | |||||
input_descs[i].comp_node = inputs[i]->comp_node(); | input_descs[i].comp_node = inputs[i]->comp_node(); | ||||
input_descs[i].layout = inputs[i]->layout(); | input_descs[i].layout = inputs[i]->layout(); | ||||
} | } | ||||
@@ -84,7 +84,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs | |||||
SmallVector<megdnn::DType> i_dtypes(inputs.size()); | SmallVector<megdnn::DType> i_dtypes(inputs.size()); | ||||
SmallVector<TensorFormat> i_formats(inputs.size()); | SmallVector<TensorFormat> i_formats(inputs.size()); | ||||
for (int i=0; i<inputs.size(); i++) { | |||||
for (size_t i=0; i<inputs.size(); i++) { | |||||
i_devices[i] = inputs[i].comp_node; | i_devices[i] = inputs[i].comp_node; | ||||
i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape | i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape | ||||
i_dtypes[i] = inputs[i].layout.dtype; | i_dtypes[i] = inputs[i].layout.dtype; | ||||
@@ -132,7 +132,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs | |||||
} | } | ||||
SmallVector<LogicalTensorDesc> outputs(this->output_num()); | SmallVector<LogicalTensorDesc> outputs(this->output_num()); | ||||
for (int i=0; i<this->output_num(); i++) { | |||||
for (size_t i=0; i<this->output_num(); i++) { | |||||
outputs[i].comp_node = std::move(o_devices[i]); | outputs[i].comp_node = std::move(o_devices[i]); | ||||
outputs[i].layout = std::move( | outputs[i].layout = std::move( | ||||
TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]) | TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]) | ||||
@@ -0,0 +1,181 @@ | |||||
/** | |||||
* \file src/custom/impl/manager.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/manager.h" | |||||
#include "megbrain/common.h" | |||||
#include <unordered_set> | |||||
#ifndef _WIN32 | |||||
#include <dlfcn.h> | |||||
#endif | |||||
using namespace mgb; | |||||
namespace custom { | |||||
CustomOpManager *CustomOpManager::inst(void) { | |||||
static CustomOpManager op_manager; | |||||
return &op_manager; | |||||
} | |||||
CustomOpManager::~CustomOpManager() { | |||||
mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); | |||||
LibManager::inst()->m_custom_libs.clear(); | |||||
} | |||||
std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto iter = m_name2op.find(name); | |||||
if (iter != m_name2op.end()) { | |||||
mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str()); | |||||
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second); | |||||
} | |||||
std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version); | |||||
m_name2op[op->op_type()] = op; | |||||
m_id2op[op->runtime_id()] = op; | |||||
return std::const_pointer_cast<CustomOp, const CustomOp>(op); | |||||
} | |||||
bool CustomOpManager::erase(const std::string &name) { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto iter = m_name2op.find(name); | |||||
if (iter == m_name2op.end()) { | |||||
mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str()); | |||||
return false; | |||||
} | |||||
std::shared_ptr<const CustomOp> op = iter->second; | |||||
m_id2op.erase(op->runtime_id()); | |||||
m_name2op.erase(op->op_type()); | |||||
return true; | |||||
} | |||||
bool CustomOpManager::erase(const RunTimeId &id) { | |||||
MGB_LOCK_GUARD(m_mtx); | |||||
auto iter = m_id2op.find(id); | |||||
if (iter == m_id2op.end()) { | |||||
mgb_log_warn("Erase Custom Op Failed! The Op has not been registered"); | |||||
return false; | |||||
} | |||||
std::shared_ptr<const CustomOp> op = iter->second; | |||||
m_id2op.erase(op->runtime_id()); | |||||
m_name2op.erase(op->op_type()); | |||||
return true; | |||||
} | |||||
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) { | |||||
auto iter = m_name2op.find(name); | |||||
if (iter == m_name2op.end()) { | |||||
return insert(name, version); | |||||
} | |||||
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second); | |||||
} | |||||
RunTimeId CustomOpManager::to_id(const std::string &name) const { | |||||
std::shared_ptr<const CustomOp> op = find(name); | |||||
return op->runtime_id(); | |||||
} | |||||
std::string CustomOpManager::to_name(const RunTimeId &id) const { | |||||
std::shared_ptr<const CustomOp> op = find(id); | |||||
return op->op_type(); | |||||
} | |||||
std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const { | |||||
auto ret = m_name2op.find(name); | |||||
mgb_assert(ret != m_name2op.end(), | |||||
"Find Custom Op Failed! Op %s has not been registered", name.c_str() | |||||
); | |||||
return ret->second; | |||||
} | |||||
std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const { | |||||
auto ret = m_id2op.find(id); | |||||
mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered"); | |||||
return ret->second; | |||||
} | |||||
std::vector<std::string> CustomOpManager::op_name_list(void) { | |||||
std::vector<std::string> ret; | |||||
for (auto kv: m_name2op) { | |||||
ret.emplace_back(kv.first); | |||||
} | |||||
return ret; | |||||
} | |||||
std::vector<RunTimeId> CustomOpManager::op_id_list(void) { | |||||
std::vector<RunTimeId> ret; | |||||
for (auto kv: m_id2op) { | |||||
ret.emplace_back(kv.first); | |||||
} | |||||
return ret; | |||||
} | |||||
#ifndef _WIN32 | |||||
CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY) | |||||
: m_handle(nullptr, [](void* handle) {dlclose(handle);}) { | |||||
auto op_list_before_load = CustomOpManager::inst()->op_name_list(); | |||||
std::unordered_set<std::string> op_set_before_load( | |||||
op_list_before_load.begin(), op_list_before_load.end()); | |||||
m_handle.reset(dlopen(path.c_str(), mode)); | |||||
mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror()); | |||||
auto op_list_after_load = CustomOpManager::inst()->op_name_list(); | |||||
for (auto &op: op_list_after_load) { | |||||
if (op_set_before_load.find(op) == op_set_before_load.end()) { | |||||
m_ops.emplace_back(op); | |||||
} | |||||
} | |||||
} | |||||
#else | |||||
CustomLib::CustomLib(const std::string &path, int mode = 0) | |||||
: m_handle(nullptr, [](void* handle) {}) { | |||||
mgb_assert(false, "custom op is only supported on Linux now"); | |||||
} | |||||
#endif | |||||
const std::vector<std::string> &CustomLib::ops_in_lib(void) const { | |||||
return m_ops; | |||||
} | |||||
CustomLib::~CustomLib() { | |||||
for (auto &op: m_ops) { | |||||
CustomOpManager::inst()->erase(op); | |||||
} | |||||
} | |||||
bool CustomLib::valid() const { | |||||
return m_handle != nullptr; | |||||
} | |||||
LibManager *LibManager::inst(void) { | |||||
static LibManager custom_libs; | |||||
return &custom_libs; | |||||
} | |||||
const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) { | |||||
MGB_LOCK_GUARD(m_mtx);; | |||||
LibHandle handle = std::make_shared<CustomLib>(path); | |||||
m_custom_libs.insert({name, handle}); | |||||
return m_custom_libs[name]->ops_in_lib(); | |||||
} | |||||
bool LibManager::uninstall(const std::string &name) { | |||||
MGB_LOCK_GUARD(m_mtx);; | |||||
mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error"); | |||||
return true; | |||||
} | |||||
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) { | |||||
return CustomOpManager::inst()->insert(opname, version); | |||||
} | |||||
} |
@@ -0,0 +1,531 @@ | |||||
/** | |||||
* \file src/custom/impl/op.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/custom/op.h" | |||||
#include "megbrain/custom/utils.h" | |||||
#include <unordered_set> | |||||
#include <sstream> | |||||
using namespace mgb; | |||||
namespace custom { | |||||
class ArgInfoImpl { | |||||
std::string m_name; | |||||
std::string m_desc; | |||||
std::unordered_set<std::string> m_dtypes; | |||||
int m_ndim; // use int rather than size_t for representing m_dims = -1 | |||||
std::string m_mem_stgy; | |||||
friend class ArgInfo; | |||||
}; | |||||
CUSTOM_PIMPL_CLS_DEFINE(ArgInfo) | |||||
ArgInfo::ArgInfo(const std::string &name, | |||||
const std::string &desc, | |||||
const std::unordered_set<std::string> &dtypes, | |||||
const int &ndim, | |||||
const std::string &mem_stgy): m_impl(new ArgInfoImpl(), impl_deleter<ArgInfoImpl>) { | |||||
for (auto &&dtype: dtypes) { | |||||
mgb_assert(DType::is_legal(dtype), "unsupported tensor data type: %s", dtype.c_str()); | |||||
} | |||||
mgb_assert(mem_stgy == "default", "only default mem strategy is supported now!"); | |||||
TypedRef(ArgInfoImpl, m_impl.get()).m_name = name; | |||||
TypedRef(ArgInfoImpl, m_impl.get()).m_desc = desc; | |||||
TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes = dtypes; | |||||
TypedRef(ArgInfoImpl, m_impl.get()).m_ndim = ndim; | |||||
TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy = mem_stgy; | |||||
} | |||||
const std::string &ArgInfo::name(void) const { | |||||
return TypedRef(ArgInfoImpl, m_impl.get()).m_name; | |||||
} | |||||
const std::string &ArgInfo::desc(void) const { | |||||
return TypedRef(ArgInfoImpl, m_impl.get()).m_desc; | |||||
} | |||||
const std::unordered_set<std::string> &ArgInfo::dtypes(void) const { | |||||
return TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes; | |||||
} | |||||
int ArgInfo::ndim(void) const { | |||||
return TypedRef(ArgInfoImpl, m_impl.get()).m_ndim; | |||||
} | |||||
const std::string &ArgInfo::mem_strategy(void) const { | |||||
return TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; | |||||
} | |||||
std::string ArgInfo::str() const { | |||||
std::stringstream ss; | |||||
ss << "name: " << TypedRef(ArgInfoImpl, m_impl.get()).m_name << "\n" | |||||
<< "desc: " << TypedRef(ArgInfoImpl, m_impl.get()).m_desc << "\nlegal_dtypes: {"; | |||||
for (auto &val: TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes) { | |||||
ss << val << ", "; | |||||
} | |||||
if (TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes.size() != 0) { | |||||
ss.seekp(ss.tellp()-std::streampos(2)); | |||||
} | |||||
ss << "}\ndims: " << TypedRef(ArgInfoImpl, m_impl.get()).m_ndim << "\n" | |||||
<< "memory_strategy: " << TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; | |||||
return ss.str(); | |||||
} | |||||
#define assert_inputs_size_right(inputs_vec) \ | |||||
mgb_assert( \ | |||||
inputs_vec.size() == input_num(), \ | |||||
"op %s need %lu inputs but given %lu", \ | |||||
op_type().c_str(), static_cast<unsigned long>(input_num()), \ | |||||
static_cast<unsigned long>(inputs_vec.size()) \ | |||||
) | |||||
#define assert_outputs_size_right(outputs_vec) \ | |||||
mgb_assert( \ | |||||
outputs_vec.size() == output_num(), \ | |||||
"op %s have %lu outputs but given %lu", \ | |||||
op_type().c_str(), static_cast<unsigned long>(output_num()), \ | |||||
static_cast<unsigned long>(outputs_vec.size()) \ | |||||
) | |||||
#define assert_arg_shape_dim_right(real_shape, arg_info) \ | |||||
mgb_assert( \ | |||||
(arg_info).ndim() == -1 || static_cast<int>((real_shape).ndim()) == \ | |||||
static_cast<int>((arg_info).ndim()), \ | |||||
"%s's args: %s dim match error, need %d but given %d", op_type().c_str(), \ | |||||
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \ | |||||
static_cast<int>((real_shape).ndim()) \ | |||||
) | |||||
template <typename T> | |||||
class Function; | |||||
template<typename RType, typename... Args> | |||||
class Function<RType(Args...)> { | |||||
public: | |||||
using Functor = RType (*)(Args...); | |||||
Function() = default; | |||||
Function(Functor f): m_f(f) {} | |||||
Function(const Function &rhs) { | |||||
m_f = rhs.m_f; | |||||
} | |||||
RType operator()(Args... args) { | |||||
custom_assert(m_f != nullptr, "invalid function ptr\n"); | |||||
return m_f(std::forward<Args>(args)...); | |||||
} | |||||
void operator=(const Function &rhs) { // not allowed continuous assignment | |||||
m_f = rhs.m_f; | |||||
} | |||||
void operator=(const Functor f) { | |||||
m_f = f; | |||||
} | |||||
private: | |||||
Functor m_f = nullptr; | |||||
}; | |||||
template <typename Functions> | |||||
class FuncWithSig: public Functions { | |||||
public: | |||||
using Functions::operator(); | |||||
using Functions::operator=; | |||||
}; | |||||
class CustomOpImpl { | |||||
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; | |||||
const uint32_t m_version; | |||||
const std::string m_op_type; | |||||
std::string m_op_desc; | |||||
std::vector<ArgInfo> m_input_infos; | |||||
std::vector<ArgInfo> m_output_infos; | |||||
ParamInfo m_param_infos; | |||||
using DeviceInfer = FuncWithSig<Function<void(const std::vector<Device>&, const Param&, std::vector<Device>&)>>; | |||||
using ShapeInfer = FuncWithSig<Function<void(const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>; | |||||
using DTypeInfer = FuncWithSig<Function<void(const std::vector<DType>&, const Param&, std::vector<DType>&)>>; | |||||
using FormatInfer = FuncWithSig<Function<void(const std::vector<Format>&, const Param&, std::vector<Format>&)>>; | |||||
using Preprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||||
using Postprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||||
using Compute = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||||
DeviceInfer infer_output_device_func; | |||||
ShapeInfer infer_output_shape_func; | |||||
DTypeInfer infer_output_dtype_func; | |||||
FormatInfer infer_output_format_func; | |||||
std::unordered_map<std::string, Compute> compute_funcs; | |||||
std::unordered_map<std::string, Preprocess> preprocess_funcs; | |||||
std::unordered_map<std::string, Postprocess> postprocess_funcs; | |||||
public: | |||||
CustomOpImpl(const std::string&, uint32_t version); | |||||
PREVENT_COPY_AND_ASSIGN(CustomOpImpl); | |||||
friend CustomOp; | |||||
}; | |||||
CustomOpImpl::CustomOpImpl(const std::string &op_type, uint32_t version) | |||||
: m_version(version), m_op_type(op_type) { | |||||
if (m_version != CURRENT_VERSION) { | |||||
mgb_log_warn( | |||||
"the version of loaded custom op %s is %u, but custom op version " | |||||
"of the system is %u\n", op_type.c_str(), m_version, CURRENT_VERSION | |||||
); | |||||
} | |||||
infer_output_device_func = [](const std::vector<Device> &inputs, | |||||
const Param&, | |||||
std::vector<Device> &outputs) -> void { | |||||
static UnImpleWarnLog log_once("output_device_infer", "device", "x86"); | |||||
for (size_t i=0; i<outputs.size(); ++i) { | |||||
outputs[i] = inputs.size() > 0 ? inputs[0] : Device("x86"); | |||||
} | |||||
}; | |||||
infer_output_shape_func = [](const std::vector<Shape> &inputs, | |||||
const Param&, | |||||
std::vector<Shape> &outputs) -> void { | |||||
static UnImpleWarnLog log_once("output_shape_infer", "shape", "{1}"); | |||||
for (size_t i=0; i<outputs.size(); ++i) { | |||||
outputs[i] = inputs.size() > 0 ? inputs[0] : Shape({1}); | |||||
} | |||||
}; | |||||
infer_output_dtype_func = [](const std::vector<DType> &inputs, | |||||
const Param&, | |||||
std::vector<DType> &outputs) -> void { | |||||
static UnImpleWarnLog log_once("output_dtype_infer", "dtype", "float32"); | |||||
for (size_t i=0; i<outputs.size(); ++i) { | |||||
outputs[i] = inputs.size() > 0 ? inputs[0] : DType("float32"); | |||||
} | |||||
}; | |||||
infer_output_format_func = [](const std::vector<Format> &inputs, | |||||
const Param&, | |||||
std::vector<Format> &outputs) -> void { | |||||
for (size_t i=0; i<outputs.size(); ++i) { | |||||
outputs[i] = inputs.size() > 0 ? inputs[0] : Format("default"); | |||||
} | |||||
}; | |||||
for (const auto &device: Device::legal_devices()) { | |||||
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor> &outputs) -> void { | |||||
auto device = outputs[0].device(); | |||||
mgb_assert(false, "There is no forward function for your op on device `%s`. " | |||||
"Please implement this function and register it.", device.str().c_str()); | |||||
}; | |||||
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void { | |||||
return; | |||||
}; | |||||
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void { | |||||
return; | |||||
}; | |||||
} | |||||
m_param_infos.set_tag(op_type); | |||||
} | |||||
CustomOp::CustomOp(const std::string &op_type, uint32_t version) | |||||
: m_impl(new CustomOpImpl(op_type, version), impl_deleter<CustomOpImpl>) { | |||||
} | |||||
#define OpImplRef(raw_ptr) reinterpret_cast<CustomOpImpl*>(raw_ptr) | |||||
CustomOp &CustomOp::set_device_infer(DeviceInferFuncPtr func) { | |||||
OpImplRef(m_impl.get())->infer_output_device_func = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_shape_infer(ShapeInferFuncPtr func) { | |||||
OpImplRef(m_impl.get())->infer_output_shape_func = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_dtype_infer(DTypeInferFuncPtr func) { | |||||
OpImplRef(m_impl.get())->infer_output_dtype_func = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_format_infer(FormatInferFuncPtr func) { | |||||
OpImplRef(m_impl.get())->infer_output_format_func = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_preprocess(PreprocessFuncPtr func) { | |||||
set_preprocess("x86", func); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_preprocess(const std::string &device, PreprocessFuncPtr func) { | |||||
OpImplRef(m_impl.get())->preprocess_funcs[device] = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_postprocess(PostprocessFuncPtr func) { | |||||
set_postprocess("x86", func); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_postprocess(const std::string &device, PostprocessFuncPtr func) { | |||||
OpImplRef(m_impl.get())->postprocess_funcs[device] = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_compute(ComputeFuncPtr func) { | |||||
set_compute("x86", func); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_compute(const std::string &device, ComputeFuncPtr func) { | |||||
OpImplRef(m_impl.get())->compute_funcs[device] = func; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::set_description(const std::string &op_desc) { | |||||
OpImplRef(m_impl.get())->m_op_desc = op_desc; | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||||
auto &ref = OpImplRef(m_impl.get())->m_input_infos; | |||||
for (const auto &input: ref) { | |||||
mgb_assert(input.name() != name, "input %s has been registered", name.c_str()); | |||||
} | |||||
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||||
auto &ref = OpImplRef(m_impl.get())->m_output_infos; | |||||
for (const auto &output: ref) { | |||||
mgb_assert(output.name() != name, "output %s has been registered", name.c_str()); | |||||
} | |||||
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||||
add_input(name, name, legal_dtypes, dims, mem_stgy); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||||
add_output(name, name, legal_dtypes, dims, mem_stgy); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_inputs(const size_t &num) { | |||||
size_t cur_inp_num = input_num(); | |||||
for (size_t i=cur_inp_num; i<cur_inp_num+num; i++) { | |||||
add_input(op_type() + "_Input_" + std::to_string(i)); | |||||
} | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_outputs(const size_t &num) { | |||||
size_t cur_oup_num = output_num(); | |||||
for (size_t i=cur_oup_num; i<cur_oup_num+num; i++) { | |||||
add_output(op_type() + "_Output_" + std::to_string(i)); | |||||
} | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_param(const std::string &name, const ParamVal &default_val) { | |||||
add_param(name, name, default_val); | |||||
return *this; | |||||
} | |||||
CustomOp &CustomOp::add_param(const std::string &name, const std::string &desc, const ParamVal &default_val) { | |||||
auto &meta = OpImplRef(m_impl.get())->m_param_infos.meta(); | |||||
for(const auto &schema: meta) { | |||||
mgb_assert(name != schema.name(), "param %s has been registered\n", name.c_str()); | |||||
} | |||||
ParamSchema sch = ParamSchema(name, default_val, desc); | |||||
meta.emplace_back(sch); | |||||
return *this; | |||||
} | |||||
std::string CustomOp::op_type(void) const { | |||||
return OpImplRef(m_impl.get())->m_op_type; | |||||
} | |||||
std::string CustomOp::op_desc(void) const { | |||||
return OpImplRef(m_impl.get())->m_op_desc; | |||||
} | |||||
RunTimeId CustomOp::runtime_id(void) const { | |||||
return (RunTimeId)(this); | |||||
} | |||||
size_t CustomOp::input_num(void) const { | |||||
return OpImplRef(m_impl.get())->m_input_infos.size(); | |||||
} | |||||
size_t CustomOp::output_num(void) const { | |||||
return OpImplRef(m_impl.get())->m_output_infos.size(); | |||||
} | |||||
std::string CustomOp::str(void) const { | |||||
std::stringstream ss; | |||||
ss << "op name: " << op_type() << "\nop desc: " << op_desc() << "\n\ninputs:\n"; | |||||
for (const auto &input: inputs_info()) { | |||||
ss << input.str(); | |||||
ss << "\n--------------------\n"; | |||||
} | |||||
ss << "\noutputs:\n"; | |||||
for (const auto &output: outputs_info()) { | |||||
ss << output.str(); | |||||
ss << "\n--------------------\n"; | |||||
} | |||||
ss << "\nparams:\n"; | |||||
for (const auto ¶m: param_info().meta()) { | |||||
ss << param.str(); | |||||
ss << "\n--------------------\n"; | |||||
} | |||||
return ss.str(); | |||||
} | |||||
const ParamInfo &CustomOp::param_info(void) const { | |||||
return OpImplRef(m_impl.get())->m_param_infos; | |||||
} | |||||
ArgInfo CustomOp::input_info(size_t idx) const { | |||||
return OpImplRef(m_impl.get())->m_input_infos[idx]; | |||||
} | |||||
ArgInfo CustomOp::output_info(size_t idx) const { | |||||
return OpImplRef(m_impl.get())->m_output_infos[idx]; | |||||
} | |||||
const std::vector<ArgInfo> &CustomOp::inputs_info(void) const { | |||||
return OpImplRef(m_impl.get())->m_input_infos; | |||||
} | |||||
const std::vector<ArgInfo> &CustomOp::outputs_info(void) const { | |||||
return OpImplRef(m_impl.get())->m_output_infos; | |||||
} | |||||
std::vector<Device> CustomOp::infer_output_device(const std::vector<Device> &inputs, const Param ¶m) const { | |||||
assert_inputs_size_right(inputs); | |||||
std::vector<Device> outputs(output_num()); | |||||
OpImplRef(m_impl.get())->infer_output_device_func(inputs, param, outputs); | |||||
assert_outputs_size_right(outputs); | |||||
return outputs; | |||||
} | |||||
std::vector<Shape> CustomOp::infer_output_shape(const std::vector<Shape> &inputs, const Param ¶m) const { | |||||
assert_inputs_size_right(inputs); | |||||
for (size_t i=0; i<inputs_info().size(); i++) { | |||||
assert_arg_shape_dim_right(inputs[i], input_info(i)); | |||||
} | |||||
std::vector<Shape> outputs(output_num()); | |||||
OpImplRef(m_impl.get())->infer_output_shape_func(inputs, param, outputs); | |||||
for (size_t i=0; i<outputs_info().size(); i++) { | |||||
assert_arg_shape_dim_right(outputs[i], output_info(i)); | |||||
} | |||||
assert_outputs_size_right(outputs); | |||||
return outputs; | |||||
} | |||||
std::vector<DType> CustomOp::infer_output_dtype(const std::vector<DType> &inputs, const Param ¶m) const { | |||||
assert_inputs_size_right(inputs); | |||||
for (size_t i=0; i<inputs_info().size(); i++) { | |||||
std::unordered_set<std::string> legal_input_dtypes_i = input_info(i).dtypes(); | |||||
mgb_assert( | |||||
legal_input_dtypes_i.find(inputs[i].str()) != legal_input_dtypes_i.end(), | |||||
"dtypes of input: %s(%s) is not allowed, the info of this input is:\n%s", | |||||
input_info(i).name().c_str(), inputs[i].str().c_str(), | |||||
input_info(i).str().c_str() | |||||
); | |||||
} | |||||
std::vector<DType> outputs(output_num()); | |||||
OpImplRef(m_impl.get())->infer_output_dtype_func(inputs, param, outputs); | |||||
for (size_t i=0; i<outputs_info().size(); i++) { | |||||
std::unordered_set<std::string> legal_output_dtypes_i = output_info(i).dtypes(); | |||||
mgb_assert( | |||||
legal_output_dtypes_i.find(outputs[i].str()) != legal_output_dtypes_i.end(), | |||||
"dtypes of output: %s is %s, the info of this output is:\n%s", | |||||
output_info(i).name().c_str(), outputs[i].str().c_str(), | |||||
output_info(i).str().c_str() | |||||
); | |||||
} | |||||
assert_outputs_size_right(outputs); | |||||
return outputs; | |||||
} | |||||
std::vector<Format> CustomOp::infer_output_format(const std::vector<Format> &inputs, const Param ¶m) const { | |||||
assert_inputs_size_right(inputs); | |||||
for (size_t i=0; i<inputs.size(); i++) { | |||||
mgb_assert( | |||||
inputs[i].is_default(), | |||||
"the tensor format of %s:%s is not default", | |||||
op_type().c_str(), input_info(i).name().c_str() | |||||
); | |||||
} | |||||
std::vector<Format> outputs(output_num()); | |||||
OpImplRef(m_impl.get())->infer_output_format_func(inputs, param, outputs); | |||||
for (size_t i=0; i<outputs.size(); i++) { | |||||
mgb_assert( | |||||
outputs[i].is_default(), | |||||
"the tensor format of %s:%s is not default", | |||||
op_type().c_str(), output_info(i).name().c_str() | |||||
); | |||||
} | |||||
assert_outputs_size_right(outputs); | |||||
return outputs; | |||||
} | |||||
void CustomOp::compute(const std::vector<Tensor> &inputs, const Param ¶m, std::vector<Tensor> &outputs) const { | |||||
assert_inputs_size_right(inputs); | |||||
assert_outputs_size_right(outputs); | |||||
if (outputs.size() == 0) { | |||||
return; | |||||
} | |||||
std::string device = outputs[0].device().str(); | |||||
for (size_t i=1; i<outputs.size(); ++i) { | |||||
mgb_assert( | |||||
outputs[i].device().str() == device, | |||||
"all output tensors should have the same device attribute" | |||||
); | |||||
} | |||||
// need to add other input/output check | |||||
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str()); | |||||
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device]; | |||||
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device]; | |||||
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device]; | |||||
preprocess_func(inputs, param, outputs); | |||||
forward_func(inputs, param, outputs); | |||||
postprocess_func(outputs, param, outputs); | |||||
assert_outputs_size_right(outputs); | |||||
} | |||||
} |
@@ -0,0 +1,179 @@ | |||||
/** | |||||
* \file src/custom/impl/param.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/param.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/utils/hash.h" | |||||
#include <limits> | |||||
#include <sstream> | |||||
#include <map> | |||||
using namespace mgb; | |||||
namespace custom { | |||||
class ParamSchemaImpl { | |||||
std::string m_name; | |||||
std::string m_desc; | |||||
ParamVal m_default; | |||||
friend ParamSchema; | |||||
}; | |||||
class ParamInfoImpl { | |||||
std::vector<ParamSchema> m_meta; | |||||
uint32_t TAG; | |||||
friend ParamInfo; | |||||
}; | |||||
class ParamImpl { | |||||
std::unordered_map<std::string, ParamVal> m_vals; | |||||
ParamImpl() = default; | |||||
ParamImpl(const ParamImpl &rhs) = default; | |||||
ParamImpl &operator=(const ParamImpl &rhs) { | |||||
mgb_assert( | |||||
m_vals.size() == rhs.m_vals.size(), | |||||
"params of different op, assignment failed!" | |||||
); | |||||
for (const auto &kv: rhs.m_vals) { | |||||
auto iter = m_vals.find(kv.first); | |||||
mgb_assert(iter != m_vals.end(), "params of different op, assignment failed!"); | |||||
iter->second = kv.second; | |||||
} | |||||
return *this; | |||||
} | |||||
friend Param; | |||||
}; | |||||
CUSTOM_PIMPL_CLS_DEFINE(ParamSchema) | |||||
ParamSchema::ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc) | |||||
: m_impl(new ParamSchemaImpl(), impl_deleter<ParamSchemaImpl>) { | |||||
TypedRef(ParamSchemaImpl, m_impl.get()).m_name = name; | |||||
TypedRef(ParamSchemaImpl, m_impl.get()).m_default = value; | |||||
TypedRef(ParamSchemaImpl, m_impl.get()).m_desc = desc; | |||||
} | |||||
const std::string &ParamSchema::name(void) const { | |||||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_name; | |||||
} | |||||
const std::string &ParamSchema::desc(void) const { | |||||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_desc; | |||||
} | |||||
const ParamVal &ParamSchema::default_val(void) const { | |||||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default; | |||||
} | |||||
ParamDynType ParamSchema::type(void) const { | |||||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default.type(); | |||||
} | |||||
std::string ParamSchema::str(void) const { | |||||
std::stringstream ss; | |||||
ss << "name: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_name | |||||
<< "\ndesc: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_desc | |||||
<< "\n" << TypedRef(ParamSchemaImpl, m_impl.get()).m_default.str(); | |||||
return ss.str(); | |||||
} | |||||
CUSTOM_PIMPL_CLS_DEFINE(ParamInfo) | |||||
void ParamInfo::set_tag(const std::string &hash_str) { | |||||
const char *ptr = hash_str.c_str(); | |||||
TypedRef(ParamInfoImpl, m_impl.get()).TAG = 0; | |||||
for (size_t i=0; i<hash_str.size(); i++) { | |||||
TypedRef(ParamInfoImpl, m_impl.get()).TAG = | |||||
mgb::hash_pair_combine(TypedRef(ParamInfoImpl, m_impl.get()).TAG, mgb::hash(*(ptr++))) % | |||||
std::numeric_limits<uint32_t>::max(); | |||||
} | |||||
} | |||||
void ParamInfo::set_meta(const std::vector<ParamSchema> &meta) { | |||||
TypedRef(ParamInfoImpl, m_impl.get()).m_meta = meta; | |||||
} | |||||
uint32_t ParamInfo::tag(void) const { | |||||
return TypedRef(ParamInfoImpl, m_impl.get()).TAG; | |||||
} | |||||
std::vector<ParamSchema> &ParamInfo::meta(void) { | |||||
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; | |||||
} | |||||
const std::vector<ParamSchema> &ParamInfo::meta(void) const { | |||||
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; | |||||
} | |||||
CUSTOM_PIMPL_CLS_DEFINE(Param) | |||||
Param::Param(const ParamInfo &info): m_impl(new ParamImpl(), impl_deleter<ParamImpl>) { | |||||
for (const auto &schema: info.meta()) { | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.emplace(schema.name(), schema.default_val()); | |||||
} | |||||
} | |||||
ParamVal &Param::operator[](const std::string &name) { | |||||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; | |||||
} | |||||
const ParamVal &Param::operator[](const std::string &name) const { | |||||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; | |||||
} | |||||
const std::unordered_map<std::string, ParamVal> &Param::raw() const { | |||||
return TypedRef(ParamImpl, m_impl.get()).m_vals; | |||||
} | |||||
bool Param::exist(const std::string &name) const { | |||||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name) != | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.end(); | |||||
} | |||||
std::string Param::to_bytes(void) const { | |||||
std::string res; | |||||
std::map<std::string, ParamVal> ordered_vals( | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.end()); | |||||
for (auto &&kv: ordered_vals) { | |||||
res += ParamVal::to_bytes(kv.second); | |||||
} | |||||
return res; | |||||
} | |||||
void Param::from_bytes(const std::string &bytes) { | |||||
std::map<std::string, ParamVal> ordered_vals( | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.end()); | |||||
size_t offset = 0; | |||||
for (auto &kv: ordered_vals) { | |||||
kv.second = ParamVal::from_bytes(bytes, offset); | |||||
} | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.clear(); | |||||
TypedRef(ParamImpl, m_impl.get()).m_vals.insert(ordered_vals.begin(), ordered_vals.end()); | |||||
mgb_assert(offset == bytes.size(), "wrong data loader"); | |||||
} | |||||
bool operator==(const Param &lhs, const Param &rhs) { | |||||
if (lhs.raw().size() != rhs.raw().size()) | |||||
return false; | |||||
for (const auto &kv: lhs.raw()) { | |||||
auto riter = rhs.raw().find(kv.first); | |||||
if (riter == rhs.raw().end() || !((kv.second) == riter->second)) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
} |
@@ -0,0 +1,400 @@ | |||||
/** | |||||
* \file src/custom/impl/param_val.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/param_val.h" | |||||
#include "megbrain/common.h" | |||||
#pragma GCC diagnostic ignored "-Wsign-compare" | |||||
using namespace mgb; | |||||
namespace custom { | |||||
/** | |||||
* Macro Callback for Case | |||||
*/ | |||||
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
std::unique_ptr<void, void_deleter> new_ptr( \ | |||||
new static_type(TypedRef(static_type, rhs.m_ptr.get())), \ | |||||
impl_deleter<static_type> \ | |||||
); \ | |||||
m_ptr.swap(new_ptr); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get());\ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr) \ | |||||
mgb_assert( \ | |||||
operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \ | |||||
"invalid %s of operator %s of ParamVal", #operand, #opr \ | |||||
) | |||||
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \ | |||||
mgb_assert( \ | |||||
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \ | |||||
type2name[lhs.m_type].c_str(), #op, \ | |||||
type2name[rhs.m_type].c_str() \ | |||||
) | |||||
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ | |||||
return lval op rval; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ | |||||
switch (rhs.m_type) { \ | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY( \ | |||||
CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op) \ | |||||
default: \ | |||||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||||
} \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||||
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ | |||||
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ | |||||
return lval op rval; \ | |||||
} | |||||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type) \ | |||||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||||
\ | |||||
switch (lhs.m_type) { \ | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||||
default: \ | |||||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||||
} \ | |||||
return {}; \ | |||||
} | |||||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type) \ | |||||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||||
\ | |||||
switch (lhs.m_type) { \ | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||||
CUSTOM_FOR_STRING_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||||
default: \ | |||||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||||
} \ | |||||
return {}; \ | |||||
} | |||||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type) \ | |||||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||||
\ | |||||
switch (lhs.m_type) { \ | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||||
CUSTOM_FOR_STRING_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||||
CUSTOM_FOR_EACH_LIST_PARAMTYPE( \ | |||||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||||
default: \ | |||||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||||
} \ | |||||
return {}; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
auto rval = TypedRef(static_type, m_ptr.get()); \ | |||||
ss << rval; \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
auto rval = TypedRef(static_type, m_ptr.get()); \ | |||||
ss << vec2str(rval); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
return TypedRef(static_type, m_ptr.get()).size(); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
res.resize(sizeof(ParamDynType) + sizeof(static_type)); \ | |||||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ | |||||
memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
auto &ref = TypedRef(static_type, value.m_ptr.get()); \ | |||||
size_t len = ref.size(); \ | |||||
size_t elem_size = len != 0 ? sizeof(ref[0]) : 0; \ | |||||
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); \ | |||||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ | |||||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); \ | |||||
memcpy(&res[sizeof(ParamDynType)+sizeof(len)], ref.data(), len*elem_size); \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
static_type val; \ | |||||
memcpy(&val, &bytes[offset], sizeof(val)); \ | |||||
offset += sizeof(val); \ | |||||
return val; \ | |||||
break; \ | |||||
} | |||||
#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type) \ | |||||
case (ParamDynType::dyn_type): { \ | |||||
size_t len = 0; \ | |||||
memcpy(&len, &bytes[offset], sizeof(len)); \ | |||||
offset += sizeof(len); \ | |||||
static_type vals; \ | |||||
vals.resize(len); \ | |||||
size_t elem_size = len != 0 ? sizeof(vals[0]) : 0; \ | |||||
memcpy(&vals[0], &bytes[offset], len*elem_size); \ | |||||
offset += len*elem_size; \ | |||||
return vals; \ | |||||
break; \ | |||||
} | |||||
ParamVal::ParamVal(): m_ptr(nullptr, [](void*) -> void {}) { | |||||
m_type = ParamDynType::Invalid; | |||||
} | |||||
ParamVal::ParamVal(const char *str): ParamVal(std::string(str)) { | |||||
} | |||||
ParamVal::ParamVal(const std::initializer_list<const char*> &strs): ParamVal(std::vector<const char*>(strs)) { | |||||
} | |||||
ParamVal::ParamVal(const std::vector<const char*> &strs) | |||||
: m_ptr(new std::vector<std::string>(), impl_deleter<std::vector<std::string>>) { | |||||
m_type = ParamDynType::StringList; | |||||
for (const auto &str: strs) { | |||||
TypedRef(std::vector<std::string>, m_ptr.get()).emplace_back(str); | |||||
} | |||||
} | |||||
ParamVal::ParamVal(const ParamVal &rhs): m_ptr(nullptr, [](void*) -> void {}) { | |||||
mgb_assert( | |||||
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, | |||||
"invalid rhs of copy constructor of ParamVal" | |||||
); | |||||
m_type = rhs.m_type; | |||||
switch(m_type) { | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS) | |||||
default: { | |||||
mgb_assert(false, "invalid rhs of copy constructor of ParamVal"); | |||||
} | |||||
} | |||||
} | |||||
ParamVal &ParamVal::operator=(const char *str) { | |||||
this->operator=(std::string(str)); | |||||
return *this; | |||||
} | |||||
ParamVal &ParamVal::operator=(const std::initializer_list<const char*> &strs) { | |||||
this->operator=(std::vector<const char*>(strs)); | |||||
return *this; | |||||
} | |||||
ParamVal &ParamVal::operator=(const std::vector<const char*> &strs) { | |||||
std::vector<std::string> tmp_strs; | |||||
for (const auto &str: strs) { | |||||
tmp_strs.emplace_back(str); | |||||
} | |||||
this->operator=(tmp_strs); | |||||
return *this; | |||||
} | |||||
ParamVal &ParamVal::operator=(const ParamVal &rhs) { | |||||
if (&rhs == this) | |||||
return *this; | |||||
mgb_assert( | |||||
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, | |||||
"invalid rhs of assignment operator of ParamVal" | |||||
); | |||||
if (rhs.m_type == m_type) { | |||||
switch(m_type) { | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS); | |||||
default: | |||||
mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); | |||||
} | |||||
} | |||||
else { | |||||
m_type = rhs.m_type; | |||||
switch(m_type) { | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS); | |||||
default: | |||||
mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); | |||||
} | |||||
} | |||||
return *this; | |||||
} | |||||
const void *ParamVal::raw_ptr(void) const { | |||||
return m_ptr.get(); | |||||
} | |||||
void *ParamVal::raw_ptr(void) { | |||||
return m_ptr.get(); | |||||
} | |||||
ParamDynType ParamVal::type(void) const { | |||||
return m_type; | |||||
} | |||||
std::string ParamVal::str() const { | |||||
std::stringstream ss; | |||||
ss << "type: " << type2name[m_type] << "\n" << "value: "; | |||||
switch (m_type) { | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) | |||||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) | |||||
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST) | |||||
default: | |||||
mgb_assert(false, "invalid data of assignment operator of ParamVal"); | |||||
} | |||||
return ss.str(); | |||||
} | |||||
size_t ParamVal::size(void) const { | |||||
switch (m_type) { | |||||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) | |||||
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) | |||||
default: | |||||
mgb_assert(false, "there is no size() for basic data types"); | |||||
} | |||||
} | |||||
std::string ParamVal::to_bytes(const ParamVal &value) { | |||||
std::string res; | |||||
// because the specialization of std::vector<bool> | |||||
if (value.type() == ParamDynType::BoolList) { | |||||
std::vector<bool> &ref = TypedRef(std::vector<bool>, value.m_ptr.get()); | |||||
size_t len = ref.size(); | |||||
size_t elem_size = sizeof(bool); | |||||
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); | |||||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); | |||||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); | |||||
size_t startpos = sizeof(ParamDynType)+sizeof(len); | |||||
for (size_t idx=0; idx<len; idx++) { | |||||
bool b = ref[idx]; | |||||
memcpy(&res[startpos+idx*sizeof(b)], &b, sizeof(b)); | |||||
} | |||||
return res; | |||||
} | |||||
else if (value.type() == ParamDynType::StringList) { | |||||
std::vector<std::string> &ref = TypedRef(std::vector<std::string>, value.m_ptr.get()); | |||||
size_t len = ref.size(); | |||||
res.resize(sizeof(ParamDynType) + sizeof(len)); | |||||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); | |||||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); | |||||
for (size_t idx=0; idx<ref.size(); ++idx) { | |||||
size_t str_len = ref[idx].size(); | |||||
std::string bytes(sizeof(str_len) + str_len, ' '); | |||||
memcpy(&bytes[0], &str_len, sizeof(str_len)); | |||||
memcpy(&bytes[sizeof(str_len)], ref[idx].data(), str_len); | |||||
res += bytes; | |||||
} | |||||
return res; | |||||
} | |||||
switch(value.type()) { | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC) | |||||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST) | |||||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST) | |||||
default: | |||||
mgb_assert(false, "invalid param type"); | |||||
} | |||||
return res; | |||||
} | |||||
ParamVal ParamVal::from_bytes(const std::string &bytes, size_t &offset) { | |||||
ParamDynType data_type = ParamDynType::Invalid; | |||||
memcpy(&data_type, &bytes[offset], sizeof(ParamDynType)); | |||||
offset += sizeof(ParamDynType); | |||||
if (data_type == ParamDynType::BoolList) { | |||||
std::vector<bool> ret; | |||||
size_t len = 0; | |||||
memcpy(&len, &bytes[offset], sizeof(len)); | |||||
offset += sizeof(len); | |||||
for (size_t idx =0; idx<len; ++idx) { | |||||
bool b = true; | |||||
memcpy(&b, &bytes[offset], sizeof(bool)); | |||||
offset += sizeof(bool); | |||||
ret.push_back(b); | |||||
} | |||||
return ret; | |||||
} | |||||
else if (data_type == ParamDynType::StringList) { | |||||
std::vector<std::string> ret; | |||||
size_t len = 0; | |||||
memcpy(&len, &bytes[offset], sizeof(len)); | |||||
offset += sizeof(len); | |||||
for (size_t idx =0; idx<len; ++idx) { | |||||
size_t str_len = 0; | |||||
memcpy(&str_len, &bytes[offset], sizeof(str_len)); | |||||
offset += sizeof(str_len); | |||||
std::string str(str_len, ' '); | |||||
memcpy(&str[0], &bytes[offset], str_len); | |||||
offset += str_len; | |||||
ret.push_back(str); | |||||
} | |||||
return ret; | |||||
} | |||||
switch (data_type) { | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC) | |||||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST) | |||||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST); | |||||
default: | |||||
mgb_assert(false, "invalid param type"); | |||||
} | |||||
return {}; | |||||
} | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool) | |||||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool) | |||||
} |
@@ -0,0 +1,486 @@ | |||||
/** | |||||
* \file src/custom/impl/tensor.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/tensor.h" | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/tensor.h" | |||||
#include <cctype> | |||||
#include <algorithm> | |||||
using namespace mgb; | |||||
namespace custom { | |||||
template<typename T> | |||||
SmallVector<T> to_builtin_vector(const std::vector<T> &custom_data) { | |||||
SmallVector<T> builtin_data(custom_data.size()); | |||||
memcpy(builtin_data.data(), custom_data.data(), sizeof(T)*custom_data.size()); | |||||
return builtin_data; | |||||
} | |||||
using DeviceImpl = CompNode; | |||||
using ShapeImpl = megdnn::TensorShape; | |||||
using DTypeImpl = megdnn::DType; | |||||
using FormatImpl = megdnn::TensorLayout::Format; | |||||
using TensorImpl = DeviceTensorND; | |||||
#define DeviceImplRef(rawptr) (*reinterpret_cast<DeviceImpl*>(rawptr)) | |||||
#define ShapeImplRef(rawptr) (*reinterpret_cast<ShapeImpl*>(rawptr)) | |||||
#define DTypeImplRef(rawptr) (*reinterpret_cast<DTypeImpl*>(rawptr)) | |||||
#define FormatImplRef(rawptr) (*reinterpret_cast<FormatImpl*>(rawptr)) | |||||
#define TensorImplRef(rawptr) (*reinterpret_cast<TensorImpl*>(rawptr)) | |||||
#define DeviceImplConstRef(rawptr) static_cast<const DeviceImpl&>(*reinterpret_cast<const DeviceImpl*>(rawptr)) | |||||
#define ShapeImplConstRef(rawptr) static_cast<const ShapeImpl&>(*reinterpret_cast<const ShapeImpl*>(rawptr)) | |||||
#define DTypeImplConstRef(rawptr) static_cast<const DTypeImpl&>(*reinterpret_cast<const DTypeImpl*>(rawptr)) | |||||
#define FormatImplConstRef(rawptr) static_cast<const FormatImpl&>(*reinterpret_cast<const FormatImpl*>(rawptr)) | |||||
#define TensorImplConstRef(rawptr) static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr)) | |||||
static std::unordered_map<DeviceImpl::DeviceType, std::string, | |||||
EnumHash<DeviceImpl::DeviceType>, | |||||
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cstr; | |||||
static std::unordered_map<DeviceImpl::DeviceType, DeviceEnum, | |||||
EnumHash<DeviceImpl::DeviceType>, | |||||
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cenum; | |||||
static std::unordered_map<std::string, std::string> dev_cstr2bstr; | |||||
static std::unordered_map<DeviceEnum, std::string, | |||||
EnumHash<DeviceEnum>, | |||||
EnumCmp<DeviceEnum>> dev_cenum2bstr; | |||||
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ | |||||
auto be2cs##custom_impl = dev_benum2cstr.emplace( \ | |||||
DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \ | |||||
auto be2ce##custom_impl = dev_benum2cenum.emplace( \ | |||||
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \ | |||||
auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \ | |||||
std::string(#custom_impl), std::string(builtin_str)); \ | |||||
auto ce2bs##custom_impl = dev_cenum2bstr.emplace( \ | |||||
DeviceEnum::custom_impl, std::string(builtin_str)); | |||||
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE) | |||||
#undef CUSTOM_BIND_DEVICE | |||||
CUSTOM_PIMPL_CLS_DEFINE(Device) | |||||
const void *Device::impl() const { | |||||
return m_impl.get(); | |||||
} | |||||
Device::Device(const void *impl): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||||
mgb_assert(impl != nullptr, "invalid ptr"); | |||||
if (!DeviceImplConstRef(impl).valid()) { | |||||
m_impl.reset(new DeviceImpl()); | |||||
return; | |||||
} | |||||
auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); | |||||
mgb_assert( | |||||
dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), | |||||
"unsupported compnode type: %s", DeviceImplConstRef(impl).to_string().c_str() | |||||
); | |||||
m_impl.reset(new DeviceImpl(DeviceImplConstRef(impl))); | |||||
} | |||||
Device::Device(const std::string &device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||||
mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); | |||||
std::string builtin_device = dev_cstr2bstr[device]; | |||||
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||||
} | |||||
// to avoid the ambiguous from Device(const void *impl) | |||||
Device::Device(const char *device): Device(std::string(device)) { | |||||
} | |||||
Device::Device(DeviceEnum device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||||
mgb_assert(is_legal(device), "invalid device type"); | |||||
std::string builtin_device = dev_cenum2bstr[device]; | |||||
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||||
} | |||||
std::string Device::str(void) const { | |||||
if (!DeviceImplRef(m_impl.get()).valid()) { | |||||
return "invalid"; | |||||
} | |||||
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||||
auto iter = dev_benum2cstr.find(builtin_device_type); | |||||
mgb_assert( | |||||
iter != dev_benum2cstr.end(), "invalid device type %s\n", | |||||
DeviceImplRef(m_impl.get()).to_string().c_str() | |||||
); | |||||
return iter->second; | |||||
} | |||||
DeviceEnum Device::enumv(void) const { | |||||
mgb_assert( | |||||
DeviceImplRef(m_impl.get()).valid(), | |||||
"cannot get the enum value of invalid device" | |||||
); | |||||
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||||
auto iter = dev_benum2cenum.find(builtin_device_type); | |||||
mgb_assert( | |||||
iter != dev_benum2cenum.end(), "invalid device type %s\n", | |||||
DeviceImplRef(m_impl.get()).to_string().c_str() | |||||
); | |||||
return iter->second; | |||||
} | |||||
bool Device::is_legal(const std::string &device_type) { | |||||
return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); | |||||
} | |||||
bool Device::is_legal(DeviceEnum device_type) { | |||||
return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); | |||||
} | |||||
std::vector<std::string> Device::legal_devices(void) { | |||||
std::vector<std::string> ret; | |||||
for (const auto &kv: dev_cstr2bstr) { | |||||
ret.emplace_back(kv.first); | |||||
} | |||||
return ret; | |||||
} | |||||
bool operator==(const Device &lhs, const Device &rhs) { | |||||
return lhs.str() == rhs.str(); | |||||
} | |||||
CUSTOM_PIMPL_CLS_DEFINE(Shape) | |||||
const void *Shape::impl() const { | |||||
return m_impl.get(); | |||||
} | |||||
Shape::Shape(const void *impl): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||||
mgb_assert(impl != nullptr, "invalid ptr"); | |||||
m_impl.reset(new ShapeImpl(ShapeImplConstRef(impl))); | |||||
} | |||||
Shape::Shape(const std::vector<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||||
m_impl.reset(new ShapeImpl(to_builtin_vector<size_t>(rhs))); | |||||
} | |||||
Shape::Shape(const std::initializer_list<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||||
m_impl.reset(new ShapeImpl(rhs)); | |||||
} | |||||
size_t &Shape::operator[](size_t idx) { | |||||
mgb_assert(idx < ndim(), "wrong tensor dimension idx: %lu < %lu", static_cast<unsigned long>(idx), static_cast<unsigned long>(ndim())); | |||||
return ShapeImplRef(m_impl.get()).operator[](idx); | |||||
} | |||||
size_t Shape::operator[](size_t idx) const { | |||||
return const_cast<Shape*>(this)->operator[](idx); | |||||
} | |||||
void Shape::ndim(size_t dim) { | |||||
mgb_assert(dim < ShapeImpl::MAX_NDIM, "dimension must <= %lu", static_cast<unsigned long>(ShapeImpl::MAX_NDIM)); | |||||
ShapeImplRef(m_impl.get()).ndim = dim; | |||||
} | |||||
size_t Shape::ndim(void) const { | |||||
return ShapeImplRef(m_impl.get()).ndim; | |||||
} | |||||
bool operator==(const Shape &lhs, const Shape &rhs) { | |||||
return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); | |||||
} | |||||
static std::unordered_map<std::string, megdnn::DTypeEnum> dtype_cstr2benum; | |||||
static std::unordered_map<DTypeEnum, megdnn::DTypeEnum, | |||||
EnumHash<DTypeEnum>, | |||||
EnumCmp<DTypeEnum>> dtype_cenum2benum; | |||||
static std::unordered_map<megdnn::DTypeEnum, std::string, | |||||
EnumHash<megdnn::DTypeEnum>, | |||||
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cstr; | |||||
static std::unordered_map<megdnn::DTypeEnum, DTypeEnum, | |||||
EnumHash<megdnn::DTypeEnum>, | |||||
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cenum; | |||||
static std::unordered_map<DTypeEnum, std::string, | |||||
EnumHash<DTypeEnum>, | |||||
EnumCmp<DTypeEnum>> dtype_cenum2cstr; | |||||
#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ | |||||
auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ | |||||
std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ | |||||
auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ | |||||
DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ | |||||
auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ | |||||
megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \ | |||||
auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ | |||||
megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \ | |||||
auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \ | |||||
DTypeEnum::custom_impl, std::string(#custom_impl)); | |||||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) | |||||
#undef CUSTOM_BIND_DTYPE | |||||
CUSTOM_PIMPL_CLS_DEFINE(DType) | |||||
const void *DType::impl() const { | |||||
return m_impl.get(); | |||||
} | |||||
DType::DType(const void *impl): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||||
mgb_assert(impl != nullptr, "invalid ptr"); | |||||
m_impl.reset(new DTypeImpl(DTypeImplConstRef(impl))); | |||||
} | |||||
DType::DType(const std::string &dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||||
auto iter = dtype_cstr2benum.find(dtype); | |||||
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||||
mgb_assert( | |||||
dtype[0] != 'q', "can not construct quantized dtype " | |||||
"%s without scale and zero_point", dtype.c_str() | |||||
); | |||||
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||||
} | |||||
DType::DType(const char *dtype): DType(std::string(dtype)) { | |||||
} | |||||
DType::DType(const std::string &dtype, float scale, uint8_t zero_point) | |||||
: m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||||
auto iter = dtype_cstr2benum.find(dtype); | |||||
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||||
mgb_assert( | |||||
dtype[0] == 'q', "given scale/zero_point to construct " | |||||
"non-quantized dtype: %s is not allowed", dtype.c_str() | |||||
); | |||||
if (dtype == "quint8") { | |||||
m_impl.reset(new megdnn::ParameterizedDType< | |||||
megdnn::DTypeEnum::Quantized8Asymm>(scale, zero_point)); | |||||
} | |||||
else { | |||||
mgb_assert( | |||||
zero_point == 0, "invalid zero point %d for dtype %s", | |||||
zero_point, dtype.c_str() | |||||
); | |||||
if (dtype == "qint8") { | |||||
m_impl.reset(new megdnn::ParameterizedDType< | |||||
megdnn::DTypeEnum::QuantizedS8>(scale)); | |||||
} | |||||
else if (dtype == "qint16") { | |||||
m_impl.reset(new megdnn::ParameterizedDType< | |||||
megdnn::DTypeEnum::QuantizedS16>(scale)); | |||||
} | |||||
else if (dtype == "qint32") { | |||||
m_impl.reset(new megdnn::ParameterizedDType< | |||||
megdnn::DTypeEnum::QuantizedS32>(scale)); | |||||
} | |||||
else { | |||||
mgb_assert(false, "invalid dtype %s", dtype.c_str()); | |||||
} | |||||
} | |||||
} | |||||
DType::DType(const char *dtype, float scale, uint8_t zero_point) | |||||
: DType(std::string(dtype), scale, zero_point) { | |||||
} | |||||
DType::DType(DTypeEnum dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||||
auto iter = dtype_cenum2benum.find(dtype); | |||||
mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); | |||||
mgb_assert(dtype < DTypeEnum::quint8, | |||||
"can not construct quantized dtype without scale and zero_point"); | |||||
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||||
} | |||||
DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point) | |||||
: DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) { | |||||
} | |||||
std::string DType::str(void) const { | |||||
if (!DTypeImplRef(m_impl.get()).valid()) | |||||
return "invalid"; | |||||
auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); | |||||
if (iter == dtype_benum2cstr.end()) | |||||
return "invalid"; | |||||
return iter->second; | |||||
} | |||||
DTypeEnum DType::enumv(void) const { | |||||
auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); | |||||
mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); | |||||
return iter->second; | |||||
} | |||||
float DType::scale() const { | |||||
if (enumv() == DTypeEnum::qint8) { | |||||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS8>().scale; | |||||
} | |||||
else if (enumv() == DTypeEnum::qint16) { | |||||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS16>().scale; | |||||
} | |||||
else if (enumv() == DTypeEnum::qint32) { | |||||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS32>().scale; | |||||
} | |||||
else if (enumv() == DTypeEnum::quint8) { | |||||
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().scale; | |||||
} | |||||
else { | |||||
mgb_assert(false, "dtype %s has no scale", str().c_str()); | |||||
return 0.f; | |||||
} | |||||
} | |||||
uint8_t DType::zero_point() const { | |||||
mgb_assert(enumv()==DTypeEnum::quint8, "dtype %s has no zero point", str().c_str()); | |||||
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().zero_point; | |||||
} | |||||
bool DType::is_legal(const std::string &dtype) { | |||||
return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); | |||||
} | |||||
bool DType::is_legal(const DTypeEnum &dtype) { | |||||
return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); | |||||
} | |||||
std::vector<std::string> DType::legal_dtypes(void) { | |||||
std::vector<std::string> ret; | |||||
for (const auto &kv: dtype_cstr2benum) | |||||
ret.emplace_back(kv.first); | |||||
return ret; | |||||
} | |||||
bool operator==(const DType &lhs, const DType &rhs) { | |||||
return DTypeImplRef(lhs.m_impl.get()) == DTypeImplRef(rhs.m_impl.get()); | |||||
} | |||||
bool operator==(const DType &lhs, const std::string &rhs) { | |||||
return lhs.str() == rhs; | |||||
} | |||||
bool operator==(const DType &lhs, const char *rhs) { | |||||
return operator==(lhs, std::string(rhs)); | |||||
} | |||||
bool operator==(const std::string &lhs, const DType &rhs) { | |||||
return operator==(rhs, lhs); | |||||
} | |||||
bool operator==(const char *lhs, const DType &rhs) { | |||||
return operator==(rhs, std::string(lhs)); | |||||
} | |||||
CUSTOM_PIMPL_CLS_DEFINE(Format) | |||||
const void *Format::impl() const { | |||||
return m_impl.get(); | |||||
} | |||||
Format::Format(const void *impl): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||||
mgb_assert(impl != nullptr, "invalid ptr"); | |||||
mgb_assert(FormatImplConstRef(impl).is_default(), "only default format is supported now"); | |||||
m_impl.reset(new FormatImpl(FormatImplConstRef(impl))); | |||||
} | |||||
Format::Format(const std::string &format): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||||
mgb_assert(format == "default", "only default format is supported now"); | |||||
m_impl.reset(new FormatImpl()); | |||||
} | |||||
Format::Format(const char *format): Format(std::string(format)) { | |||||
} | |||||
std::string Format::str(void) const { | |||||
return FormatImplRef(m_impl.get()).to_string(); | |||||
} | |||||
bool Format::is_default(void) const { | |||||
return FormatImplRef(m_impl.get()).is_default(); | |||||
} | |||||
const void *Tensor::impl(void) const { | |||||
return m_tensor; | |||||
} | |||||
Tensor::Tensor(const void *impl) { | |||||
mgb_assert(impl != nullptr, "invalid ptr"); | |||||
m_tensor = const_cast<void*>(impl); | |||||
} | |||||
const size_t *Tensor::shapes_raw(void) const { | |||||
return TensorImplRef(m_tensor).shape().shape; | |||||
} | |||||
const ptrdiff_t *Tensor::strides_raw(void) const { | |||||
return TensorImplRef(m_tensor).layout().stride; | |||||
} | |||||
Tensor::Tensor(const Tensor &rhs) { | |||||
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for copy constructor\n"); | |||||
m_tensor = rhs.m_tensor; | |||||
} | |||||
Tensor &Tensor::operator=(const Tensor &rhs) { | |||||
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for assignment operator"); | |||||
if (&rhs == this || rhs.m_tensor == m_tensor) | |||||
return *this; | |||||
m_tensor = rhs.m_tensor; | |||||
return *this; | |||||
} | |||||
Shape Tensor::shape(void) const { | |||||
auto builtin = TensorImplRef(m_tensor).shape(); | |||||
return Shape(&builtin); | |||||
} | |||||
DType Tensor::dtype(void) const { | |||||
auto builtin = TensorImplRef(m_tensor).dtype(); | |||||
return DType(&builtin); | |||||
} | |||||
Format Tensor::format(void) const { | |||||
auto builtin = TensorImplRef(m_tensor).format(); | |||||
return Format(&builtin); | |||||
} | |||||
Device Tensor::device(void) const { | |||||
auto builtin = TensorImplRef(m_tensor).comp_node(); | |||||
return Device(&builtin); | |||||
} | |||||
size_t Tensor::size(void) const { | |||||
return TensorImplRef(m_tensor).shape().total_nr_elems(); | |||||
} | |||||
std::vector<ptrdiff_t> Tensor::stride(void) const { | |||||
std::vector<ptrdiff_t> ret(TensorImplRef(m_tensor).shape().ndim); | |||||
for (size_t i=0; i<ret.size(); i++) | |||||
ret[i] = TensorImplRef(m_tensor).layout().stride[i]; | |||||
return ret; | |||||
} | |||||
float Tensor::scale(void) const { | |||||
return dtype().scale(); | |||||
} | |||||
uint8_t Tensor::zero_point(void) const { | |||||
return dtype().zero_point(); | |||||
} | |||||
void *Tensor::data(void) { | |||||
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr()); | |||||
} | |||||
const void *Tensor::data(void) const { | |||||
return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr()); | |||||
} | |||||
} // namespace custom |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* \file src/custom/impl/utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/utils.h" | |||||
#include "megbrain/common.h" | |||||
#include <sstream> | |||||
using namespace mgb; | |||||
namespace custom { | |||||
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...) { | |||||
std::string msg = ssprintf("`%s' is true at %s:%d: %s", expr, file, line, func); | |||||
if (msg_fmt) { | |||||
msg_fmt = convert_fmt_str(msg_fmt); | |||||
va_list ap; | |||||
va_start(ap, msg_fmt); | |||||
msg.append("\nextra message: "); | |||||
msg.append(svsprintf(msg_fmt, ap)); | |||||
va_end(ap); | |||||
} | |||||
printf("%s\n", msg.c_str()); | |||||
} | |||||
UnImpleWarnLog::UnImpleWarnLog(const std::string &func, const std::string &attr, | |||||
const std::string &val) { | |||||
mgb_log_warn("you are using the default custom %s function, the `%s` attribute " | |||||
"of all the outputs tensor will be the same with inputs tensor[0]. " | |||||
"If there is no input tensor, it will be `%s`", | |||||
func.c_str(), attr.c_str(), val.c_str()); | |||||
} | |||||
} |
@@ -0,0 +1,185 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/accessor.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
#include <cstdint> | |||||
namespace custom { | |||||
#ifdef __CUDACC__ | |||||
#define CUSTOM_HOST __host__ | |||||
#define CUSTOM_DEVICE __device__ | |||||
#else | |||||
#define CUSTOM_HOST | |||||
#define CUSTOM_DEVICE | |||||
#endif | |||||
#define CUSTOM_HOST_DEVICE CUSTOM_HOST CUSTOM_DEVICE | |||||
template <typename T> | |||||
struct DefaultPtrTraits { | |||||
using PtrType = T*; | |||||
}; | |||||
#ifdef __CUDACC__ | |||||
template <typename T> | |||||
struct RestrictPtrTraits { | |||||
using PtrType = T* __restrict__; | |||||
}; | |||||
#endif | |||||
template <typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
class TensorAccessorProxyBase { | |||||
public: | |||||
using PtrType = typename PtrTraits<T>::PtrType; | |||||
protected: | |||||
PtrType m_data; | |||||
const index_t* m_sizes; | |||||
const index_t* m_strides; | |||||
public: | |||||
CUSTOM_HOST_DEVICE TensorAccessorProxyBase(PtrType data, const index_t *sizes, const index_t *strides) { | |||||
m_data = data; | |||||
m_sizes = sizes; | |||||
m_strides = strides; | |||||
} | |||||
CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||||
return m_strides[i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||||
return m_sizes[i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE PtrType data() const { | |||||
return m_data; | |||||
} | |||||
}; | |||||
template<typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
class TensorAccessorProxy: public TensorAccessorProxyBase<T, N, PtrTraits, index_t> { | |||||
public: | |||||
using PtrType = typename PtrTraits<T>::PtrType; | |||||
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||||
: TensorAccessorProxyBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||||
} | |||||
CUSTOM_HOST_DEVICE TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) { | |||||
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||||
this->m_data + this->m_strides[0] * i, | |||||
this->m_sizes + 1, | |||||
this->m_strides + 1 | |||||
); | |||||
} | |||||
CUSTOM_HOST_DEVICE const TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) const { | |||||
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||||
this->m_data + this->m_strides[0] * i, | |||||
this->m_sizes + 1, | |||||
this->m_strides + 1 | |||||
); | |||||
} | |||||
}; | |||||
template<typename T, template <typename U> class PtrTraits, typename index_t> | |||||
class TensorAccessorProxy<T, 1, PtrTraits, index_t> | |||||
: public TensorAccessorProxyBase<T, 1, PtrTraits, index_t> { | |||||
public: | |||||
using PtrType = typename PtrTraits<T>::PtrType; | |||||
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||||
: TensorAccessorProxyBase<T, 1, PtrTraits, index_t>(data, sizes, strides ) { | |||||
} | |||||
CUSTOM_HOST_DEVICE T &operator[](index_t i) { | |||||
return this->m_data[this->m_strides[0]*i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE const T &operator[](index_t i) const { | |||||
return this->m_data[this->m_strides[0]*i]; | |||||
} | |||||
}; | |||||
template<typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
class TensorAccessorBase { | |||||
public: | |||||
using PtrType = typename PtrTraits<T>::PtrType; | |||||
protected: | |||||
PtrType m_data; | |||||
index_t m_sizes[N]; | |||||
index_t m_strides[N]; | |||||
public: | |||||
CUSTOM_HOST_DEVICE TensorAccessorBase(PtrType data, const size_t *sizes, const ptrdiff_t *strides) { | |||||
m_data = data; | |||||
for (size_t i=0; i<N; ++i) { | |||||
m_sizes[i] = sizes[i]; | |||||
m_strides[i] = strides[i]; | |||||
} | |||||
} | |||||
CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||||
return m_strides[i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||||
return m_sizes[i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE PtrType data() const { | |||||
return m_data; | |||||
} | |||||
}; | |||||
template<typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
class TensorAccessor: public TensorAccessorBase<T, N, PtrTraits, index_t> { | |||||
public: | |||||
using PtrType = typename PtrTraits<T>::PtrType; | |||||
CUSTOM_HOST_DEVICE TensorAccessor(PtrType data, const size_t *sizes, const ptrdiff_t *strides) | |||||
: TensorAccessorBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||||
} | |||||
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) { | |||||
return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||||
this->m_data, | |||||
this->m_sizes, | |||||
this->m_strides | |||||
)[i]; | |||||
} | |||||
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) const { | |||||
return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||||
this->m_data, | |||||
this->m_sizes, | |||||
this->m_strides | |||||
)[i]; | |||||
} | |||||
}; | |||||
} |
@@ -0,0 +1,108 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/custom.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "op.h" | |||||
#include "tensor.h" | |||||
#include "param.h" | |||||
namespace custom { | |||||
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version); | |||||
} | |||||
#define CUSTOM_OP_REG(OpName) CustomOp &_##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION))) | |||||
#define CUSTOM_OP_REG_BEGIN(OpName) \ | |||||
namespace custom { \ | |||||
namespace OpName { | |||||
#define CUSTOM_OP_REG_END(OpName) \ | |||||
} \ | |||||
} | |||||
#define CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, hint, ...) \ | |||||
case (case_type): { \ | |||||
using hint = real_type; \ | |||||
return __VA_ARGS__(); \ | |||||
} | |||||
#define CASE_TO_PERFORM_ON_SCALAR(name, case_type, real_type, ...) \ | |||||
CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, scalar_t, __VA_ARGS__) | |||||
#define DISPATCH_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||||
[&]() { \ | |||||
const auto &dtype = tensor_dtype; \ | |||||
switch (dtype.enumv()) { \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||||
default: \ | |||||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||||
name, dtype.str().c_str()); \ | |||||
} \ | |||||
}() | |||||
#define DISPATCH_INT_TYPES(tensor_dtype, name, ...) \ | |||||
[&]() { \ | |||||
const auto &dtype = tensor_dtype; \ | |||||
switch (dtype.enumv()) { \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||||
default: \ | |||||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||||
name, dtype.str().c_str()); \ | |||||
} \ | |||||
}() | |||||
#define DISPATCH_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||||
[&]() { \ | |||||
const auto &dtype = tensor_dtype; \ | |||||
switch (dtype.enumv()) { \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||||
default: \ | |||||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||||
name, dtype.str().c_str()); \ | |||||
} \ | |||||
}() | |||||
#define DISPATCH_SIGN_INT_TYPES(tensor_dtype, name, ...) \ | |||||
[&]() { \ | |||||
const auto &dtype = tensor_dtype; \ | |||||
switch (dtype.enumv()) { \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||||
default: \ | |||||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||||
name, dtype.str().c_str()); \ | |||||
} \ | |||||
}() | |||||
#define DISPATCH_SIGN_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||||
[&]() { \ | |||||
const auto &dtype = tensor_dtype; \ | |||||
switch (dtype.enumv()) { \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||||
default: \ | |||||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||||
name, dtype.str().c_str()); \ | |||||
} \ | |||||
}() |
@@ -0,0 +1,58 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/data_adaptor.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/thin/small_vector.h" | |||||
namespace custom { | |||||
template <typename BuiltinT, typename CustomT> | |||||
BuiltinT to_builtin(const CustomT &custom) { | |||||
return *reinterpret_cast<const BuiltinT*>(custom.impl()); | |||||
} | |||||
template <typename BuiltinT, typename CustomT> | |||||
CustomT to_custom(const BuiltinT &builtin) { | |||||
return std::move(CustomT(&builtin)); | |||||
} | |||||
template <typename BuiltinT, typename CustomT> | |||||
megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT> &customs) { | |||||
megdnn::SmallVector<BuiltinT> builtins; | |||||
for (size_t i=0; i<customs.size(); ++i) { | |||||
builtins.push_back(std::move(to_builtin<BuiltinT, CustomT>(customs[i]))); | |||||
} | |||||
return std::move(builtins); | |||||
} | |||||
template <typename BuiltinT, typename CustomT> | |||||
std::vector<CustomT> to_custom( | |||||
const megdnn::SmallVector<BuiltinT> &builtins) { | |||||
std::vector<CustomT> customs; | |||||
for (size_t i=0; i<builtins.size(); ++i) { | |||||
customs.push_back(std::move(to_custom<BuiltinT, CustomT>(builtins[i]))); | |||||
} | |||||
return std::move(customs); | |||||
} | |||||
} | |||||
#define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr) | |||||
#define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr) | |||||
#define to_custom_shape(expr) custom::to_custom<megdnn::TensorShape, custom::Shape>(expr) | |||||
#define to_builtin_shape(expr) custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr) | |||||
#define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr) | |||||
#define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr) | |||||
#define to_custom_format(expr) custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr) | |||||
#define to_builtin_format(expr) custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr) | |||||
#define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr) | |||||
#define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr) |
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/manager.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "custom.h" | |||||
#include "megbrain/common.h" | |||||
namespace custom { | |||||
class CustomOpManager { | |||||
std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op; | |||||
std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op; | |||||
MGB_MUTEX m_mtx; | |||||
CustomOpManager() = default; | |||||
public: | |||||
PREVENT_COPY_AND_ASSIGN(CustomOpManager); | |||||
static CustomOpManager *inst(void); | |||||
~CustomOpManager(); | |||||
std::shared_ptr<CustomOp> insert(const std::string &name, uint32_t version); | |||||
bool erase(const std::string &name); | |||||
bool erase(const RunTimeId &id); | |||||
std::shared_ptr<CustomOp> find_or_reg(const std::string &name, uint32_t version); | |||||
RunTimeId to_id(const std::string &name) const; | |||||
std::string to_name(const RunTimeId &id) const; | |||||
std::shared_ptr<const CustomOp> find(const std::string &name) const; | |||||
std::shared_ptr<const CustomOp> find(const RunTimeId &id) const; | |||||
std::vector<std::string> op_name_list(void); | |||||
std::vector<RunTimeId> op_id_list(void); | |||||
}; | |||||
class CustomLib { | |||||
std::unique_ptr<void, void_deleter> m_handle; | |||||
std::vector<std::string> m_ops; | |||||
public: | |||||
PREVENT_COPY_AND_ASSIGN(CustomLib); | |||||
CustomLib(const std::string &path, int mode); | |||||
const std::vector<std::string> &ops_in_lib(void) const; | |||||
~CustomLib(); | |||||
bool valid(void) const; | |||||
}; | |||||
using LibHandle = std::shared_ptr<CustomLib>; | |||||
class LibManager { | |||||
std::unordered_map<std::string, LibHandle> m_custom_libs; | |||||
MGB_MUTEX m_mtx; | |||||
LibManager() = default; | |||||
public: | |||||
PREVENT_COPY_AND_ASSIGN(LibManager); | |||||
static LibManager *inst(void); | |||||
const std::vector<std::string> &install(const std::string &name, const std::string &path); | |||||
bool uninstall(const std::string &name); | |||||
friend class CustomOpManager; | |||||
}; | |||||
} |
@@ -0,0 +1,109 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/op.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "tensor.h" | |||||
#include "param.h" | |||||
#include <unordered_set> | |||||
#define PREVENT_COPY_AND_ASSIGN(Cls) \ | |||||
Cls(const Cls&) = delete; \ | |||||
Cls(const Cls&&) = delete; \ | |||||
Cls &operator=(const Cls&) = delete; \ | |||||
Cls &operator=(const Cls&&) = delete | |||||
#define CUSTOM_OP_MAJOR 0 | |||||
#define CUSTOM_OP_MINOR 1 | |||||
#define CUSTOM_OP_PATCH 0 | |||||
#define CUSTOM_OP_VERSION CUSTOM_OP_MAJOR*10000 + CUSTOM_OP_MINOR*100 + CUSTOM_OP_PATCH | |||||
namespace custom { | |||||
using RunTimeId = uint64_t; | |||||
class ArgInfo { | |||||
CUSTOM_PIMPL_CLS_DECL(ArgInfo); | |||||
ArgInfo(const std::string &name, | |||||
const std::string &desc, | |||||
const std::unordered_set<std::string> &dtypes, | |||||
const int &ndim, | |||||
const std::string &mem_stgy); | |||||
const std::string &name(void) const; | |||||
const std::string &desc(void) const; | |||||
const std::unordered_set<std::string> &dtypes(void) const; | |||||
int ndim(void) const; | |||||
const std::string &mem_strategy(void) const; | |||||
std::string str() const; | |||||
}; | |||||
class CustomOp { | |||||
std::unique_ptr<void, void_deleter> m_impl; | |||||
public: | |||||
CustomOp(const std::string &op_type, uint32_t version); | |||||
PREVENT_COPY_AND_ASSIGN(CustomOp); | |||||
using DeviceInferFuncPtr = void(*)(const std::vector<Device>&, const Param&, std::vector<Device>&); | |||||
using ShapeInferFuncPtr = void(*)(const std::vector<Shape>&, const Param&, std::vector<Shape>&); | |||||
using DTypeInferFuncPtr = void(*)(const std::vector<DType>&, const Param&, std::vector<DType>&); | |||||
using FormatInferFuncPtr = void(*)(const std::vector<Format>&, const Param&, std::vector<Format>&); | |||||
using PreprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||||
using PostprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||||
using ComputeFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||||
// write for forward | |||||
CustomOp &set_device_infer(DeviceInferFuncPtr func); | |||||
CustomOp &set_shape_infer(ShapeInferFuncPtr func); | |||||
CustomOp &set_dtype_infer(DTypeInferFuncPtr func); | |||||
CustomOp &set_format_infer(FormatInferFuncPtr func); | |||||
CustomOp &set_preprocess(PreprocessFuncPtr func); | |||||
CustomOp &set_preprocess(const std::string &device, PreprocessFuncPtr func); | |||||
CustomOp &set_postprocess(PostprocessFuncPtr func); | |||||
CustomOp &set_postprocess(const std::string &device, PostprocessFuncPtr func); | |||||
CustomOp &set_compute(ComputeFuncPtr func); | |||||
CustomOp &set_compute(const std::string &device, ComputeFuncPtr func); | |||||
CustomOp &set_description(const std::string &op_desc); | |||||
CustomOp &add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||||
CustomOp &add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||||
CustomOp &add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||||
CustomOp &add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||||
CustomOp &add_inputs(const size_t &input_num); | |||||
CustomOp &add_outputs(const size_t &output_num); | |||||
CustomOp &add_param(const std::string &name, const ParamVal &default_val); | |||||
CustomOp &add_param(const std::string &name, const std::string &desc, const ParamVal &default_val); | |||||
// read | |||||
std::string op_type(void) const; | |||||
std::string op_desc(void) const; | |||||
RunTimeId runtime_id(void) const; | |||||
size_t input_num(void) const; | |||||
size_t output_num(void) const; | |||||
std::string str(void) const; | |||||
const ParamInfo ¶m_info(void) const; | |||||
ArgInfo input_info(size_t idx) const; | |||||
ArgInfo output_info(size_t idx) const; | |||||
const std::vector<ArgInfo> &inputs_info(void) const; | |||||
const std::vector<ArgInfo> &outputs_info(void) const; | |||||
// use | |||||
std::vector<Device> infer_output_device(const std::vector<Device>&, const Param&) const; | |||||
std::vector<Shape> infer_output_shape (const std::vector<Shape>&, const Param&) const; | |||||
std::vector<DType> infer_output_dtype (const std::vector<DType>&, const Param&) const; | |||||
std::vector<Format> infer_output_format(const std::vector<Format>&, const Param&) const; | |||||
void compute(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const; | |||||
}; | |||||
} |
@@ -0,0 +1,61 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/param.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include <string> | |||||
#include <unordered_map> | |||||
#include "param_val.h" | |||||
namespace custom { | |||||
class ParamSchemaImpl; | |||||
class ParamInfoImpl; | |||||
class ParamImpl; | |||||
// Schema of a param element | |||||
class ParamSchema { | |||||
CUSTOM_PIMPL_CLS_DECL(ParamSchema); | |||||
ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc=""); | |||||
const std::string &name(void) const; | |||||
const std::string &desc(void) const; | |||||
const ParamVal &default_val(void) const; | |||||
ParamDynType type(void) const; | |||||
std::string str(void) const; | |||||
}; | |||||
class ParamInfo { | |||||
CUSTOM_PIMPL_CLS_DECL(ParamInfo); | |||||
void set_tag(const std::string&); | |||||
void set_meta(const std::vector<ParamSchema> &meta); | |||||
uint32_t tag(void) const; | |||||
std::vector<ParamSchema> &meta(void); | |||||
const std::vector<ParamSchema> &meta(void) const; | |||||
}; | |||||
class Param { | |||||
CUSTOM_PIMPL_CLS_DECL(Param); | |||||
Param(const ParamInfo&); | |||||
ParamVal &operator[](const std::string&); | |||||
const ParamVal &operator[](const std::string&) const; | |||||
const std::unordered_map<std::string, ParamVal> &raw() const; | |||||
bool exist(const std::string &name) const; | |||||
std::string to_bytes(void) const; | |||||
void from_bytes(const std::string&); | |||||
}; | |||||
bool operator==(const Param&, const Param&); | |||||
} // custom |
@@ -0,0 +1,290 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/param_val.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <string> | |||||
#include <vector> | |||||
#include <cassert> | |||||
#include <sstream> | |||||
#include <memory> | |||||
#include <unordered_map> | |||||
#include "utils.h" | |||||
namespace custom { | |||||
/** | |||||
* we can add a new basic data type here, basic means we can perform binary | |||||
* op such as: +, -, *, /, ==, != between any two of them | |||||
*/ | |||||
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \ | |||||
cb(Int32, int32_t, ##__VA_ARGS__) \ | |||||
cb(Int64, int64_t, ##__VA_ARGS__) \ | |||||
cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||||
cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||||
cb(Float32, float, ##__VA_ARGS__) \ | |||||
cb(Float64, double, ##__VA_ARGS__) \ | |||||
cb(Bool, bool, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) \ | |||||
cb(String, std::string, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \ | |||||
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \ | |||||
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \ | |||||
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \ | |||||
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \ | |||||
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \ | |||||
cb(Float64List, std::vector<double>, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \ | |||||
cb(BoolList, std::vector<bool>, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ...) \ | |||||
cb(StringList, std::vector<std::string>, ##__VA_ARGS__) | |||||
/** | |||||
* to avoid the recursive of MACRO | |||||
*/ | |||||
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \ | |||||
cb(Int32, int32_t, ##__VA_ARGS__) \ | |||||
cb(Int64, int64_t, ##__VA_ARGS__) \ | |||||
cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||||
cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||||
cb(Float32, float, ##__VA_ARGS__) \ | |||||
cb(Float64, double, ##__VA_ARGS__) \ | |||||
cb(Bool, bool, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \ | |||||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||||
#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \ | |||||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||||
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||||
/** | |||||
* Macro Callback for Register | |||||
*/ | |||||
#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, | |||||
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) {ParamDynType::dyn_type, #dyn_type}, | |||||
#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||||
template <> \ | |||||
struct get_dyn_type<static_type> { \ | |||||
static constexpr ParamDynType type = ParamDynType::dyn_type;\ | |||||
}; | |||||
#define CUSTOM_REG_STATIC_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||||
template <> \ | |||||
struct get_static_type<ParamDynType::dyn_type> { \ | |||||
using type = static_type; \ | |||||
}; | |||||
enum class ParamDynType: uint32_t { | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) | |||||
Invalid=255 | |||||
}; | |||||
static std::unordered_map<ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>> type2name = { | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME) | |||||
{ParamDynType::Invalid, "Invalid"} | |||||
}; | |||||
/** | |||||
* get the dynamic data type according to the builtin static data type | |||||
* we can use it like: | |||||
* ParamDynType dyn_type = get_dyn_type<int32_t>::type; | |||||
* assert(dyn_type == ParamDynType::Int32) | |||||
*/ | |||||
template <typename T> | |||||
struct get_dyn_type { | |||||
static constexpr ParamDynType type = ParamDynType::Invalid; | |||||
}; | |||||
/** | |||||
* get the static data type according to the dynamic data type | |||||
* we can use it like: | |||||
* get_static_type<ParamDynType::Int32>::type int_32_value; | |||||
* assert(std::is_same<decltype(int_32_value), int>::value) | |||||
*/ | |||||
template <ParamDynType> | |||||
struct get_static_type; | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) | |||||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) | |||||
#undef CUSTOM_REG_DYN_PARAMTYPE | |||||
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME | |||||
#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER | |||||
#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER | |||||
template <typename T> | |||||
struct get_vector_template_arg_type; | |||||
template <typename T> | |||||
struct get_vector_template_arg_type<std::vector<T>> { | |||||
using type = std::decay_t<T>; | |||||
}; | |||||
template <typename T> | |||||
struct is_vector { | |||||
static constexpr bool value = false; | |||||
}; | |||||
template <typename T> | |||||
struct is_vector <std::vector<T>> { | |||||
static constexpr bool value = true; | |||||
}; | |||||
template <typename T> | |||||
std::string vec2str(const std::vector<T> &vec) { | |||||
std::stringstream ss; | |||||
ss << "{"; | |||||
for (const auto &val: vec) { | |||||
ss << val << ", "; | |||||
} | |||||
if (vec.size() != 0) { | |||||
ss.seekp(ss.tellp()-std::streampos(2)); | |||||
} | |||||
ss << "}"; | |||||
return ss.str(); | |||||
} | |||||
/** | |||||
* we use void* rather than template to help us realise a complete dynamic type | |||||
* if we use template such as: | |||||
* template <typename T> | |||||
* class ParamVal { | |||||
* T m_data; | |||||
* } | |||||
* Con1: user need to set the type explicitly when class template instantiation | |||||
* Con2: ParamVal<int> can not be assigned to ParamVal<double> | |||||
*/ | |||||
class ParamVal { | |||||
std::unique_ptr<void, void_deleter> m_ptr; | |||||
ParamDynType m_type; | |||||
public: | |||||
template <typename T> | |||||
ParamVal(const T &val); | |||||
template <typename T> | |||||
ParamVal(const std::initializer_list<T> &val); | |||||
ParamVal(); | |||||
ParamVal(const char *str); | |||||
ParamVal(const std::initializer_list<const char*> &strs); | |||||
ParamVal(const std::vector<const char*> &strs); | |||||
ParamVal(const ParamVal &rhs); | |||||
template <typename T> | |||||
ParamVal &operator=(const T &rhs); | |||||
template <typename T> | |||||
ParamVal &operator=(const std::initializer_list<T> &val); | |||||
ParamVal &operator=(const char *str); | |||||
ParamVal &operator=(const std::initializer_list<const char*> &strs); | |||||
ParamVal &operator=(const std::vector<const char*> &strs); | |||||
ParamVal &operator=(const ParamVal &rhs); | |||||
template <typename T> | |||||
const T &as(void) const; | |||||
template <typename T> | |||||
T &as(void); | |||||
const void *raw_ptr(void) const; | |||||
void *raw_ptr(void); | |||||
ParamDynType type(void) const; | |||||
std::string str(void) const; | |||||
size_t size(void) const; | |||||
static std::string to_bytes(const ParamVal &value); | |||||
static ParamVal from_bytes(const std::string &bytes, size_t &offset); | |||||
friend ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||||
friend bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||||
}; | |||||
ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||||
ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||||
ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||||
ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||||
bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||||
template <typename T> | |||||
ParamVal::ParamVal(const T &val): m_ptr(nullptr, impl_deleter<std::decay_t<T>>) { | |||||
using DecayType = std::decay_t<T>; | |||||
m_type = get_dyn_type<DecayType>::type; | |||||
custom_assert(m_type != ParamDynType::Invalid, "param construct error! unsupported builtin type"); | |||||
m_ptr.reset(new DecayType(val)); | |||||
} | |||||
template <typename T> | |||||
ParamVal::ParamVal(const std::initializer_list<T> &val): ParamVal(std::vector<std::decay_t<T>>(val)) { | |||||
} | |||||
template <typename T> | |||||
ParamVal &ParamVal::operator=(const T &rhs) { | |||||
using DecayType = std::decay_t<T>; | |||||
ParamDynType rhs_dyn_type = get_dyn_type<DecayType>::type; | |||||
custom_assert(rhs_dyn_type != ParamDynType::Invalid, "unsupported builtin dtype"); | |||||
if (rhs_dyn_type == m_type) { | |||||
TypedRef(DecayType, m_ptr.get()) = rhs; | |||||
} | |||||
else { | |||||
m_type = rhs_dyn_type; | |||||
std::unique_ptr<void, void_deleter> new_ptr(new DecayType(rhs), impl_deleter<DecayType>); | |||||
m_ptr.swap(new_ptr); | |||||
} | |||||
return *this; | |||||
} | |||||
template <typename T> | |||||
ParamVal &ParamVal::operator=(const std::initializer_list<T> &val) { | |||||
return this->operator=(std::vector<std::decay_t<T>>(val)); | |||||
} | |||||
template <typename T> | |||||
const T &ParamVal::as(void) const { | |||||
return const_cast<ParamVal*>(this)->as<T>(); | |||||
} | |||||
template <typename T> | |||||
T &ParamVal::as(void) { | |||||
using DecayType = std::decay_t<T>; | |||||
ParamDynType t_dyn_type = get_dyn_type<DecayType>::type; | |||||
custom_assert( | |||||
t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n", | |||||
type2name[m_type].c_str(), type2name[t_dyn_type].c_str() | |||||
); | |||||
return TypedRef(T, m_ptr.get()); | |||||
} | |||||
} |
@@ -0,0 +1,280 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/tensor.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include <string> | |||||
#include "utils.h" | |||||
#include "accessor.h" | |||||
namespace custom { | |||||
#define CUSTOM_DATA_ADAPTOR_FRIEND_DECL \ | |||||
template <typename BuiltinT, typename CustomT> \ | |||||
friend BuiltinT to_builtin(const CustomT &custom); \ | |||||
template <typename BuiltinT, typename CustomT> \ | |||||
friend CustomT to_custom(const BuiltinT &builtin) | |||||
#define CUSTOM_FOR_EACH_DEVICE_TYPE(cb) \ | |||||
cb(x86, CPU, "cpux") \ | |||||
cb(cuda, CUDA, "gpux") | |||||
#define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) custom_type, | |||||
class Device { | |||||
const void *impl() const; | |||||
Device(const void *impl); | |||||
CUSTOM_PIMPL_CLS_DECL(Device); | |||||
public: | |||||
enum class DeviceEnum: uint32_t { | |||||
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL) | |||||
}; | |||||
Device(const std::string &device); | |||||
Device(const char *device); | |||||
Device(DeviceEnum device); | |||||
std::string str(void) const; | |||||
DeviceEnum enumv(void) const; | |||||
static bool is_legal(const std::string &device); | |||||
static bool is_legal(DeviceEnum device); | |||||
static std::vector<std::string> legal_devices(void); | |||||
friend class Tensor; | |||||
friend bool operator==(const Device &lhs, const Device &rhs); | |||||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||||
}; | |||||
using DeviceEnum = Device::DeviceEnum; | |||||
bool operator==(const Device &lhs, const Device &rhs); | |||||
class Shape { | |||||
const void *impl() const; | |||||
Shape(const void *impl); | |||||
CUSTOM_PIMPL_CLS_DECL(Shape); | |||||
public: | |||||
Shape(const std::vector<size_t> &rhs); | |||||
Shape(const std::initializer_list<size_t> &rhs); | |||||
size_t &operator[](size_t idx); | |||||
size_t operator[](size_t idx) const; | |||||
void ndim(size_t dim); | |||||
size_t ndim(void) const; | |||||
friend class Tensor; | |||||
friend bool operator==(const Shape &lhs, const Shape &rhs); | |||||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||||
}; | |||||
bool operator==(const Shape &lhs, const Shape &rhs); | |||||
using float16_t = uint16_t; | |||||
using bfloat16_t = uint16_t; | |||||
#if MEGDNN_DISABLE_FLOAT16 | |||||
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) | |||||
#else | |||||
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype) | |||||
#endif | |||||
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \ | |||||
cb(float32, Float32, float) \ | |||||
cb(uint8, Uint8, uint8_t) \ | |||||
cb(int8, Int8, int8_t) \ | |||||
cb(int16, Int16, int16_t) \ | |||||
cb(int32, Int32, int32_t) \ | |||||
fp16_wrap(cb, float16, Float16, float16_t) \ | |||||
fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \ | |||||
cb(uint16, Uint16, uint16_t) \ | |||||
cb(quint8, Quantized8Asymm, uint8_t) \ | |||||
cb(qint32, QuantizedS32, int32_t) \ | |||||
cb(qint8, QuantizedS8, int8_t) \ | |||||
cb(qint16, QuantizedS16, int16_t) | |||||
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type, | |||||
class DType { | |||||
const void *impl() const; | |||||
DType(const void *impl); | |||||
CUSTOM_PIMPL_CLS_DECL(DType); | |||||
public: | |||||
enum class DTypeEnum: uint32_t { | |||||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL) | |||||
}; | |||||
DType(const std::string &dtype); | |||||
DType(const char *dtype); | |||||
DType(const std::string &dtype, float scale, uint8_t zero_point = 0); | |||||
DType(const char *dtype, float scale, uint8_t zero_point = 0); | |||||
DType(DTypeEnum dtype); | |||||
DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0); | |||||
std::string str(void) const; | |||||
DTypeEnum enumv() const; | |||||
float scale(void) const; | |||||
uint8_t zero_point(void) const; | |||||
template<typename T> | |||||
bool is_compatible(void) const; | |||||
static bool is_legal(const std::string &dtype); | |||||
static bool is_legal(const DTypeEnum &dtype); | |||||
static std::vector<std::string> legal_dtypes(void); | |||||
friend class Tensor; | |||||
friend bool operator==(const DType &lhs, const DType &rhs); | |||||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||||
}; | |||||
using DTypeEnum = DType::DTypeEnum; | |||||
template <DTypeEnum> | |||||
struct DTypeTrait; | |||||
#define CUSTOM_DEFINE_DTYPE_TRAIT(custom_type, builtin_type, ctype) \ | |||||
template <> \ | |||||
struct DTypeTrait<DTypeEnum::custom_type> { \ | |||||
using type = ctype; \ | |||||
}; | |||||
#define CUSTOM_CASE_TO_COMPARE_DTYPE(custom_type, builtin_type, ctype) \ | |||||
case (DTypeEnum::custom_type): { \ | |||||
return std::is_same<DecayT, ctype>::value; \ | |||||
} | |||||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DEFINE_DTYPE_TRAIT) | |||||
template<typename T> | |||||
bool DType::is_compatible(void) const { | |||||
using DecayT = typename std::decay<T>::type; | |||||
auto dtype_enum = enumv(); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
if (dtype_enum == DTypeEnum::float16) { | |||||
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::float16>::type); | |||||
} | |||||
else if (dtype_enum == DTypeEnum::bfloat16) { | |||||
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::bfloat16>::type); | |||||
} | |||||
#endif | |||||
switch (dtype_enum) { | |||||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_CASE_TO_COMPARE_DTYPE) | |||||
default: | |||||
return false; | |||||
} | |||||
} | |||||
bool operator==(const DType &lhs, const DType &rhs); | |||||
bool operator==(const DType &lhs, const std::string &rhs); | |||||
bool operator==(const DType &lhs, const char *rhs); | |||||
bool operator==(const std::string &lhs, const DType &rhs); | |||||
bool operator==(const char *lhs, const DType &rhs); | |||||
class Format { | |||||
const void *impl() const; | |||||
Format(const void *impl); | |||||
CUSTOM_PIMPL_CLS_DECL(Format); | |||||
public: | |||||
Format(const std::string &format); | |||||
Format(const char *format); | |||||
std::string str(void) const; | |||||
bool is_default(void) const; | |||||
friend class Tensor; | |||||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||||
}; | |||||
class Tensor { | |||||
void *m_tensor; | |||||
const void *impl(void) const; | |||||
Tensor(const void *impl); | |||||
const size_t *shapes_raw(void) const; | |||||
const ptrdiff_t *strides_raw(void) const; | |||||
public: | |||||
Tensor() = delete; | |||||
Tensor(const Tensor &rhs); | |||||
Tensor &operator=(const Tensor &rhs); | |||||
Shape shape(void) const; | |||||
DType dtype(void) const; | |||||
Format format(void) const; | |||||
Device device(void) const; | |||||
size_t size(void) const; | |||||
std::vector<ptrdiff_t> stride(void) const; | |||||
float scale(void) const; | |||||
uint8_t zero_point(void) const; | |||||
void *data(void); | |||||
const void *data(void) const; | |||||
template <typename T> | |||||
T *data(void); | |||||
template <typename T> | |||||
const T *data(void) const; | |||||
template <typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
const TensorAccessor<T, N, PtrTraits, index_t> accessor() const; | |||||
template <typename T, size_t N, | |||||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||||
typename index_t = int64_t> | |||||
TensorAccessor<T, N, PtrTraits, index_t> accessor(); | |||||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||||
}; | |||||
template <typename T> | |||||
T *Tensor::data(void) { | |||||
custom_assert(dtype().is_compatible<T>(), | |||||
"invalid convert, tensor data type is %s", dtype().str().c_str()); | |||||
return reinterpret_cast<T*>(data()); | |||||
} | |||||
template <typename T> | |||||
const T *Tensor::data(void) const { | |||||
return const_cast<Tensor*>(this)->data<T>(); | |||||
} | |||||
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||||
const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const { | |||||
return const_cast<Tensor*>(this)->accessor<T, N, PtrTraits, index_t>(); | |||||
} | |||||
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||||
TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() { | |||||
custom_assert(N == shape().ndim(), | |||||
"cannot get a %lu-d accessor for a tensor with dim %lu", static_cast<unsigned long>(N), static_cast<unsigned long>(shape().ndim())); | |||||
custom_assert(N > 0, "cannot get 0-d accessor"); | |||||
T *ptr = data<T>(); | |||||
return TensorAccessor<T, N, PtrTraits, index_t>(ptr, shapes_raw(), strides_raw()); | |||||
} | |||||
#undef CUSTOM_DATA_ADAPTOR_FRIEND_DECL | |||||
#undef CUSTOM_DEVICE_TYPE_ENUM_DECL | |||||
#undef CUSTOM_DTYPE_ENUM_DECL | |||||
#undef CUSTOM_DEFINE_DTYPE_TRAIT | |||||
#undef CUSTOM_CASE_TO_COMPARE_DTYPE | |||||
} // custom |
@@ -0,0 +1,104 @@ | |||||
/** | |||||
* \file src/custom/include/megbrain/custom/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <vector> | |||||
#include <string> | |||||
#include <memory> | |||||
#include <cassert> | |||||
namespace custom { | |||||
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...); | |||||
#define custom_expect(expr, msg...) \ | |||||
if (!(expr)) { \ | |||||
assert_failed_log( \ | |||||
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||||
); \ | |||||
} | |||||
#define custom_assert(expr, msg...) \ | |||||
if (!(expr)) { \ | |||||
assert_failed_log( \ | |||||
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||||
); \ | |||||
} \ | |||||
assert((expr)) | |||||
class UnImpleWarnLog { | |||||
public: | |||||
UnImpleWarnLog(const std::string &func, const std::string &attr, | |||||
const std::string &val); | |||||
}; | |||||
using void_deleter = void(*)(void*); | |||||
template<typename Impl> | |||||
void impl_deleter(void *ptr) { | |||||
delete reinterpret_cast<Impl*>(ptr); | |||||
} | |||||
#define TypedPtr(type, raw_ptr) reinterpret_cast<type*>(raw_ptr) | |||||
#define TypedRef(type, raw_ptr) (*reinterpret_cast<type*>(raw_ptr)) | |||||
#define CUSTOM_PIMPL_CLS_DECL(Cls) \ | |||||
std::unique_ptr<void, void_deleter> m_impl; \ | |||||
public: \ | |||||
Cls(); \ | |||||
Cls(const Cls &rhs); \ | |||||
Cls &operator=(const Cls &rhs) | |||||
#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \ | |||||
Cls::Cls(): m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \ | |||||
\ | |||||
Cls::Cls(const Cls &rhs): m_impl(nullptr, impl_deleter<Cls##Impl>) { \ | |||||
custom_assert( \ | |||||
rhs.m_impl != nullptr, \ | |||||
"invalid rhs for the copy constructor of %s", #Cls \ | |||||
); \ | |||||
m_impl.reset(new Cls##Impl(TypedRef(Cls##Impl, rhs.m_impl.get()))); \ | |||||
} \ | |||||
\ | |||||
Cls &Cls::operator=(const Cls &rhs) { \ | |||||
custom_assert( \ | |||||
m_impl != nullptr && rhs.m_impl != nullptr, \ | |||||
"invalid assignment of %s, lhs or rhs is invalid", #Cls \ | |||||
); \ | |||||
if (&rhs == this) \ | |||||
return *this; \ | |||||
\ | |||||
TypedRef(Cls##Impl, m_impl.get()) = TypedRef(Cls##Impl, rhs.m_impl.get()); \ | |||||
return *this; \ | |||||
} | |||||
/** | |||||
* we define this two function explicitly used for std::unordered_map | |||||
* to improve the compatibility with different compiler versions | |||||
*/ | |||||
template <typename T> | |||||
struct EnumHash { | |||||
size_t operator()(const T &rhs) const { | |||||
return static_cast<size_t>(rhs); | |||||
} | |||||
}; | |||||
template <typename T> | |||||
struct EnumCmp { | |||||
bool operator()(const T &lhs, const T &rhs) const { | |||||
return static_cast<size_t>(lhs) == static_cast<size_t>(rhs); | |||||
} | |||||
}; | |||||
} // custom |
@@ -0,0 +1,96 @@ | |||||
/** | |||||
* \file src/custom/test/manager.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/manager.h" | |||||
#include "megbrain/custom/custom.h" | |||||
#include "gtest/gtest.h" | |||||
#define MANAGER_TEST_LOG 0 | |||||
namespace custom { | |||||
TEST(TestOpManager, TestOpManager) { | |||||
CustomOpManager *com = CustomOpManager::inst(); | |||||
com->insert("Op1", CUSTOM_OP_VERSION); | |||||
com->insert("Op2", CUSTOM_OP_VERSION); | |||||
std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION); | |||||
ASSERT_TRUE(ptr != nullptr); | |||||
std::vector<std::string> op_names = com->op_name_list(); | |||||
std::vector<RunTimeId> op_ids = com->op_id_list(); | |||||
ASSERT_TRUE(op_names.size() == 3); | |||||
ASSERT_TRUE(op_ids.size() == 3); | |||||
#if MANAGER_TEST_LOG | |||||
for (std::string &name: op_names) { | |||||
std::cout << name << std::endl; | |||||
} | |||||
#endif | |||||
for (std::string &name: op_names) { | |||||
std::shared_ptr<const CustomOp> op = com->find(name); | |||||
ASSERT_TRUE(op != nullptr); | |||||
ASSERT_TRUE(op->op_type() == name); | |||||
RunTimeId id = com->to_id(name); | |||||
ASSERT_TRUE(com->find(id) == op); | |||||
} | |||||
for (RunTimeId &id: op_ids) { | |||||
std::shared_ptr<const CustomOp> op = com->find(id); | |||||
ASSERT_TRUE(op != nullptr); | |||||
ASSERT_TRUE(op->runtime_id() == id); | |||||
std::string name = com->to_name(id); | |||||
ASSERT_TRUE(com->find(name) == op); | |||||
} | |||||
ASSERT_FALSE(com->erase("Op0")); | |||||
#if MANAGER_TEST_LOG | |||||
for (auto &name: com->op_name_list()) { | |||||
std::cout << name << std::endl; | |||||
} | |||||
#endif | |||||
ASSERT_TRUE(com->erase("Op1")); | |||||
ASSERT_TRUE(com->erase(com->to_id("Op2"))); | |||||
ASSERT_TRUE(com->op_id_list().size() == 1); | |||||
ASSERT_TRUE(com->op_name_list().size() == 1); | |||||
ASSERT_TRUE(com->op_name_list()[0] == "Op3"); | |||||
ptr.reset(); | |||||
ASSERT_TRUE(com->erase("Op3")); | |||||
} | |||||
TEST(TestOpManager, TestOpReg) { | |||||
CUSTOM_OP_REG(Op1) | |||||
.add_inputs(2) | |||||
.add_outputs(3) | |||||
.add_input("lhs") | |||||
.add_param("param1", 1) | |||||
.add_param("param2", 3.45); | |||||
CUSTOM_OP_REG(Op2) | |||||
.add_input("lhs") | |||||
.add_input("rhs") | |||||
.add_output("out") | |||||
.add_param("param1", "test") | |||||
.add_param("param2", true) | |||||
.add_param("", "no name"); | |||||
(void)_Op1; | |||||
(void)_Op2; | |||||
#if MANAGER_TEST_LOG | |||||
for (const auto &name: CustomOpManager::inst()->op_name_list()) { | |||||
std::cout << CustomOpManager::inst()->find(name)->str() << std::endl; | |||||
} | |||||
#endif | |||||
} | |||||
} |
@@ -0,0 +1,205 @@ | |||||
/** | |||||
* \file src/custom/test/op.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/op.h" | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/tensor.h" | |||||
#include "megbrain/custom/data_adaptor.h" | |||||
#include "gtest/gtest.h" | |||||
#include "megbrain_build_config.h" | |||||
#define OP_TEST_LOG 0 | |||||
using namespace mgb; | |||||
namespace custom { | |||||
TEST(TestCustomOp, TestCustomOpInfoSetter) { | |||||
CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||||
test.set_description("Test Op") | |||||
.add_input("lhs", "lhs of test op", {"float32", "int32"}, 2) | |||||
.add_inputs(2) | |||||
.add_input("rhs", "rhs of test op", {"float32", "int32"}, 2) | |||||
.add_outputs(1) | |||||
.add_output("out", "out of test op", {"float32", "int32"}, 2) | |||||
.add_outputs(3); | |||||
ASSERT_TRUE(test.op_type() == "TestOp"); | |||||
ASSERT_TRUE(test.op_desc() == "Test Op"); | |||||
ASSERT_TRUE(test.input_num() == 4); | |||||
ASSERT_TRUE(test.output_num() == 5); | |||||
#if OP_TEST_LOG | |||||
for (auto input: test.inputs_info()) { | |||||
std::cout << input.str() << std::endl; | |||||
} | |||||
for (auto output: test.outputs_info()) { | |||||
std::cout << output.str() << std::endl; | |||||
} | |||||
#endif | |||||
test.add_param("param1", "param1 - float", 1.23f) | |||||
.add_param("param2", "param2 - float list", {2.34f, 3.45f}) | |||||
.add_param("param3", "param3 - string", "test-string") | |||||
.add_param("param4", {"test", "string", "list"}) | |||||
.add_param("param5", 1); | |||||
#if OP_TEST_LOG | |||||
ParamInfo pinfo = test.param_info(); | |||||
for (auto kv: pinfo.meta()) { | |||||
std::cout << kv.str() << std::endl; | |||||
} | |||||
#endif | |||||
} | |||||
void device_infer(const std::vector<Device> &inputs, const Param ¶ms, | |||||
std::vector<Device> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
outputs[0] = inputs[1]; | |||||
outputs[1] = inputs[0]; | |||||
} | |||||
void shape_infer(const std::vector<Shape> &inputs, const Param ¶ms, | |||||
std::vector<Shape> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
outputs[0] = inputs[1]; | |||||
outputs[1] = inputs[0]; | |||||
} | |||||
void dtype_infer(const std::vector<DType> &inputs, const Param ¶ms, | |||||
std::vector<DType> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
outputs[0] = inputs[1]; | |||||
outputs[1] = inputs[0]; | |||||
} | |||||
void format_infer(const std::vector<Format> &inputs, const Param ¶ms, | |||||
std::vector<Format> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
outputs[0] = inputs[1]; | |||||
outputs[1] = inputs[0]; | |||||
} | |||||
void cpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||||
std::vector<Tensor> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
#if OP_TEST_LOG | |||||
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>() << std::endl; | |||||
#endif | |||||
ASSERT_TRUE(params["device"] == "x86"); | |||||
} | |||||
void gpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||||
std::vector<Tensor> &outputs) { | |||||
(void)inputs; | |||||
(void)params; | |||||
(void)outputs; | |||||
#if OP_TEST_LOG | |||||
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>() << std::endl; | |||||
#endif | |||||
ASSERT_TRUE(params["device"] == "cuda"); | |||||
} | |||||
TEST(TestCustomOp, TestCustomOpFuncSetter) { | |||||
#if MGB_CUDA | |||||
CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||||
test.set_description("Test Op Forward Backward Union") | |||||
.add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2) | |||||
.add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) | |||||
.add_output("outl", "outl of Test op", {"float32", "int32"}, 2) | |||||
.add_output("outr", "outr of Test op", {"float32", "int32"}, 2) | |||||
.add_param("smooth", "smooth", 0.f) | |||||
.add_param("device", "using for judge device", "x86"); | |||||
std::vector<Device> idevices = {"x86", "cuda"}; | |||||
std::vector<Shape> ishapes = {{2, 3}, {3, 4}}; | |||||
std::vector<DType> idtypes = {"int32", "float32"}; | |||||
std::vector<Format> iformats = {"default", "default"}; | |||||
Param param(test.param_info()); | |||||
std::vector<Device> odevices = test.infer_output_device(idevices, param); | |||||
std::vector<Shape> oshapes = test.infer_output_shape (ishapes, param); | |||||
std::vector<DType> odtypes = test.infer_output_dtype (idtypes, param); | |||||
std::vector<Format> oformats = test.infer_output_format(iformats, param); | |||||
ASSERT_TRUE(odevices.size() == 2); | |||||
ASSERT_TRUE(oshapes.size() == 2); | |||||
ASSERT_TRUE(odtypes.size() == 2); | |||||
ASSERT_TRUE(oformats.size() == 2); | |||||
ASSERT_TRUE(odevices[0] == "x86"); | |||||
ASSERT_TRUE(odevices[1] == "x86"); | |||||
ASSERT_TRUE(oshapes[0] == Shape({2,3})); | |||||
ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||||
ASSERT_TRUE(odtypes[0] == "int32"); | |||||
ASSERT_TRUE(odtypes[1] == "int32"); | |||||
ASSERT_TRUE(iformats[0].is_default()); | |||||
ASSERT_TRUE(iformats[1].is_default()); | |||||
test.set_device_infer(device_infer) | |||||
.set_shape_infer(shape_infer) | |||||
.set_dtype_infer(dtype_infer) | |||||
.set_format_infer(format_infer); | |||||
odevices = test.infer_output_device(idevices, param); | |||||
oshapes = test.infer_output_shape (ishapes, param); | |||||
odtypes = test.infer_output_dtype (idtypes, param); | |||||
oformats = test.infer_output_format(iformats, param); | |||||
ASSERT_TRUE(odevices.size() == 2); | |||||
ASSERT_TRUE(oshapes.size() == 2); | |||||
ASSERT_TRUE(odtypes.size() == 2); | |||||
ASSERT_TRUE(oformats.size() == 2); | |||||
ASSERT_TRUE(odevices[0] == "cuda"); | |||||
ASSERT_TRUE(odevices[1] == "x86"); | |||||
ASSERT_TRUE(oshapes[0] == Shape({3,4})); | |||||
ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||||
ASSERT_TRUE(odtypes[0] == "float32"); | |||||
ASSERT_TRUE(odtypes[1] == "int32"); | |||||
ASSERT_TRUE(iformats[0].is_default()); | |||||
ASSERT_TRUE(iformats[1].is_default()); | |||||
test.set_compute(cpu_kernel); | |||||
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||||
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||||
DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||||
DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||||
std::vector<Tensor> cinputs = {to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; | |||||
std::vector<Tensor> coutputs ={to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; | |||||
param["device"] = "x86"; | |||||
test.compute(cinputs, param, coutputs); | |||||
test.set_compute("cuda", gpu_kernel); | |||||
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||||
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||||
DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||||
DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||||
std::vector<Tensor> ginputs = {to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; | |||||
std::vector<Tensor> goutputs ={to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; | |||||
param["device"] = "cuda"; | |||||
test.compute(ginputs, param, goutputs); | |||||
#endif | |||||
} | |||||
} |
@@ -0,0 +1,208 @@ | |||||
/** | |||||
* \file src/custom/test/param.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/param.h" | |||||
#include "gtest/gtest.h" | |||||
#include <iostream> | |||||
#define PARAM_TEST_LOG 0 | |||||
namespace custom { | |||||
#define SchemaDef \ | |||||
ParamSchema schema_bool("param_bool", true, "bool"); \ | |||||
ParamSchema schema_flt("param_flt", 2.3f, "float"); \ | |||||
ParamSchema schema_int("param_int", 4, "int"); \ | |||||
ParamSchema schema_str("param_str", "test", "string"); \ | |||||
ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \ | |||||
ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list"); \ | |||||
ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list"); \ | |||||
ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list") | |||||
#define InfoDef \ | |||||
info.meta().emplace_back(schema_bool); \ | |||||
info.meta().emplace_back(schema_flt); \ | |||||
info.meta().emplace_back(schema_int); \ | |||||
info.meta().emplace_back(schema_str); \ | |||||
info.meta().emplace_back(schema_bool_list); \ | |||||
info.meta().emplace_back(schema_flt_list); \ | |||||
info.meta().emplace_back(schema_int_list); \ | |||||
info.meta().emplace_back(schema_str_list) | |||||
TEST(TestParam, TestParamScheme) { | |||||
#if PARAM_TEST_LOG | |||||
SchemaDef; | |||||
ParamSchema new_schema = schema_int; | |||||
std::cout << schema_bool.str() << std::endl; | |||||
std::cout << schema_flt.str() << std::endl; | |||||
std::cout << schema_int.str() << std::endl; | |||||
std::cout << schema_str.str() << std::endl; | |||||
std::cout << schema_bool_list.str() << "len: "<< schema_bool_list.default_val().size() << std::endl; | |||||
std::cout << schema_flt_list.str() << "len: "<< schema_flt_list.default_val().size() << std::endl; | |||||
std::cout << schema_int_list.str() << "len: "<< schema_int_list.default_val().size() << std::endl; | |||||
std::cout << schema_str_list.str() << "len: "<< schema_str_list.default_val().size() << std::endl; | |||||
std::cout << new_schema.str() << std::endl; | |||||
#endif | |||||
} | |||||
TEST(TestParam, TestParamVal) { | |||||
ParamVal pv1 = 1.2f, pv2 = true, pv3 = "test", pv4 = {0, 1, 2}, | |||||
pv5 = {true, false, true}; | |||||
#if PARAM_TEST_LOG | |||||
ParamVal pv6 = {"test1", "test2", "test3"}; | |||||
std::cout << pv1.str() << std::endl; | |||||
std::cout << pv2.str() << std::endl; | |||||
std::cout << pv3.str() << std::endl; | |||||
std::cout << pv4.str() << std::endl; | |||||
std::cout << pv5.str() << std::endl; | |||||
std::cout << pv6.str() << std::endl; | |||||
#endif | |||||
ParamVal pv_manip = pv1; | |||||
ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||||
ASSERT_TRUE(pv_manip == pv1); | |||||
pv_manip = 1.3; | |||||
ASSERT_TRUE(pv_manip.type() != pv1.type()); | |||||
ASSERT_TRUE(pv_manip != pv1); | |||||
ASSERT_TRUE(pv_manip > pv1); | |||||
pv_manip = pv_manip + pv1; | |||||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Float64); | |||||
ASSERT_TRUE(pv_manip == 1.3 + 1.2f); | |||||
pv_manip = 1.3f + 1.2f; | |||||
ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||||
pv_manip = false; | |||||
ASSERT_TRUE(pv_manip.type() == pv2.type()); | |||||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Bool); | |||||
ASSERT_TRUE(pv_manip != pv2); | |||||
pv_manip = "test"; | |||||
ASSERT_TRUE(pv_manip.type() == pv3.type()); | |||||
ASSERT_TRUE(pv_manip.type() == ParamDynType::String); | |||||
ASSERT_TRUE(pv_manip == pv3); | |||||
pv_manip = "test1"; | |||||
ASSERT_TRUE(pv_manip > pv3); | |||||
pv_manip = pv_manip + pv3; | |||||
ASSERT_TRUE(pv_manip == "test1test"); | |||||
pv_manip = {0, 1, 2}; | |||||
ASSERT_TRUE(pv_manip.type() == pv4.type()); | |||||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Int32List); | |||||
ASSERT_TRUE(pv_manip == pv4); | |||||
pv_manip = {3, 2, 1}; | |||||
ASSERT_TRUE(pv_manip != pv4); | |||||
ASSERT_TRUE(pv_manip > pv4); | |||||
pv_manip = {true, false, true}; | |||||
ASSERT_TRUE(pv_manip.type() == pv5.type()); | |||||
ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList); | |||||
ASSERT_TRUE(pv_manip == pv5); | |||||
pv_manip = {false, true, false}; | |||||
ASSERT_TRUE(pv_manip != pv5); | |||||
} | |||||
TEST(TestParam, TestParamInfo) { | |||||
ParamInfo info; | |||||
info.set_tag("Test"); | |||||
#if PARAM_TEST_LOG | |||||
uint32_t tag = info.tag(); | |||||
std::cout << tag << std::endl; | |||||
#endif | |||||
SchemaDef; | |||||
InfoDef; | |||||
ParamInfo new_info1, new_info2; | |||||
new_info1.set_meta(info.meta()); | |||||
new_info2.meta() = info.meta(); | |||||
#if PARAM_TEST_LOG | |||||
for (auto ele: new_info1.meta()) { | |||||
std::cout << ele.str() << std::endl; | |||||
} | |||||
for (auto ele: new_info2.meta()) { | |||||
std::cout << ele.str() << std::endl; | |||||
} | |||||
#endif | |||||
} | |||||
TEST(TestParam, TestParam) { | |||||
ParamInfo info; | |||||
SchemaDef; | |||||
InfoDef; | |||||
Param param(info); | |||||
#if PARAM_TEST_LOG | |||||
std::vector<std::string> names = {"param_bool", "param_flt", "param_int", "param_str", "param_bl", "param_fl", "param_il", "param_sl"}; | |||||
for (auto &name: names) { | |||||
std::cout << param[name].str() << std::endl;; | |||||
} | |||||
#endif | |||||
ASSERT_TRUE(param["param_bool"] == true); | |||||
ASSERT_TRUE(param["param_flt"] == 2.3f); | |||||
ASSERT_TRUE(param["param_int"] == 4); | |||||
ASSERT_TRUE(param["param_str"] == "test"); | |||||
ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true})); | |||||
ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f})); | |||||
ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3})); | |||||
ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"})); | |||||
param["param_bool"] = false; | |||||
param["param_flt"] = 3.4f; | |||||
param["param_int"] = 5; | |||||
param["param_str"] = "tset"; | |||||
param["param_bl"] = {false, true, false, true}; | |||||
param["param_fl"] = {7.6f, 6.5f}; | |||||
param["param_il"] = {5, 4, 3, 2, 1}; | |||||
param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"}; | |||||
ASSERT_TRUE(param["param_bool"] != true); | |||||
ASSERT_TRUE(param["param_flt"] != 2.3f); | |||||
ASSERT_TRUE(param["param_int"] != 4); | |||||
ASSERT_TRUE(param["param_str"] != "test"); | |||||
ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true})); | |||||
ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f})); | |||||
ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3})); | |||||
ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"})); | |||||
ASSERT_TRUE(param["param_bool"] == false); | |||||
ASSERT_TRUE(param["param_flt"] == 3.4f); | |||||
ASSERT_TRUE(param["param_int"] == 5); | |||||
ASSERT_TRUE(param["param_str"] == "tset"); | |||||
ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true})); | |||||
ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f})); | |||||
ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1})); | |||||
ASSERT_TRUE(param["param_sl"] == ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"})); | |||||
#if PARAM_TEST_LOG | |||||
Param copy_param = param; | |||||
for (auto &name: names) { | |||||
std::cout << copy_param[name].str() << std::endl; | |||||
} | |||||
#endif | |||||
Param loaded_param(info); | |||||
std::string bytes = param.to_bytes(); | |||||
loaded_param.from_bytes(bytes); | |||||
#if PARAM_TEST_LOG | |||||
for (auto &kv: loaded_param.raw()) { | |||||
std::cout << kv.first << ":\n" << kv.second.str() << std::endl; | |||||
} | |||||
#endif | |||||
} | |||||
} |
@@ -0,0 +1,325 @@ | |||||
/** | |||||
* \file src/custom/test/tensor.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megbrain/custom/tensor.h" | |||||
#include "megbrain/custom/data_adaptor.h" | |||||
#include "megbrain/comp_node.h" | |||||
#include "megbrain/tensor.h" | |||||
#include "gtest/gtest.h" | |||||
#include "megbrain_build_config.h" | |||||
#define TENSOR_TEST_LOG 0 | |||||
using namespace mgb; | |||||
namespace custom { | |||||
TEST(TestDevice, TestDevice) { | |||||
#if MGB_CUDA | |||||
ASSERT_TRUE(Device::is_legal("x86")); | |||||
ASSERT_TRUE(Device::is_legal(DeviceEnum::cuda)); | |||||
ASSERT_FALSE(Device::is_legal("cpu")); | |||||
Device dev1; | |||||
ASSERT_TRUE(dev1.str() == "invalid"); | |||||
dev1 = "x86"; | |||||
ASSERT_TRUE("x86" == dev1); | |||||
Device dev2 = "cuda"; | |||||
ASSERT_TRUE(dev2 == "cuda"); | |||||
ASSERT_FALSE(dev2 == dev1); | |||||
Device dev3 = dev2; | |||||
ASSERT_TRUE(dev3 == dev2); | |||||
ASSERT_FALSE(dev3 == dev1); | |||||
Device dev4 = DeviceEnum::cuda; | |||||
ASSERT_TRUE(dev4.enumv() == DeviceEnum::cuda); | |||||
#if TENSOR_TEST_LOG | |||||
std::cout << dev1.str() << "\n" << dev2.str() << "\n" | |||||
<< dev3.str() << "\n" << dev4.str() << std::endl; | |||||
#endif | |||||
CompNode compnode = to_builtin<CompNode, Device>(dev3); | |||||
ASSERT_TRUE(compnode.to_string_logical() == "gpux:0"); | |||||
compnode = CompNode::load("cpu0:0"); | |||||
Device dev5 = to_custom<CompNode, Device>(compnode); | |||||
ASSERT_TRUE(dev5.str() == "x86"); | |||||
std::vector<Device> devs1 = {"x86", "cuda", "x86"}; | |||||
megdnn::SmallVector<CompNode> compnodes = to_builtin<CompNode, Device>(devs1); | |||||
ASSERT_TRUE(compnodes[0].to_string_logical() == "cpux:0"); | |||||
ASSERT_TRUE(compnodes[1].to_string_logical() == "gpux:0"); | |||||
ASSERT_TRUE(compnodes[2].to_string_logical() == "cpux:0"); | |||||
std::vector<Device> devs2 = to_custom<CompNode, Device>(compnodes); | |||||
ASSERT_TRUE(devs2[0] == "x86"); | |||||
ASSERT_TRUE(devs2[1].str() == "cuda"); | |||||
ASSERT_TRUE(devs2[2] == "x86"); | |||||
#endif | |||||
} | |||||
TEST(TestShape, TestShape) { | |||||
Shape shape1, shape2; | |||||
ASSERT_TRUE(shape1.ndim() == 0); | |||||
shape1 = {16, 32, 8, 8}; | |||||
shape2 = shape1; | |||||
ASSERT_TRUE(shape2.ndim() == 4); | |||||
ASSERT_TRUE(shape2[0] == 16); | |||||
ASSERT_TRUE(shape2[1] == 32); | |||||
ASSERT_TRUE(shape2[2] == 8); | |||||
ASSERT_TRUE(shape2[3] == 8); | |||||
Shape shape3 = {16, 32, 8, 8}; | |||||
const Shape shape4 = shape1; | |||||
ASSERT_TRUE(shape3 == shape4); | |||||
shape3[0] = 32; | |||||
ASSERT_FALSE(shape3 == shape4); | |||||
ASSERT_TRUE(shape3[0] == 32); | |||||
ASSERT_TRUE(shape4[0] == 16); | |||||
Shape shape5 = {2, 3, 4}; | |||||
TensorShape bshape1 = to_builtin<TensorShape, Shape>(shape5); | |||||
ASSERT_TRUE(bshape1.ndim == 3); | |||||
ASSERT_TRUE(bshape1[0] == 2); | |||||
ASSERT_TRUE(bshape1[1] == 3); | |||||
ASSERT_TRUE(bshape1[2] == 4); | |||||
bshape1 = {4, 2, 3}; | |||||
Shape shape6 = to_custom<TensorShape, Shape>(bshape1); | |||||
ASSERT_TRUE(shape6.ndim() == 3); | |||||
ASSERT_TRUE(shape6[0] == 4); | |||||
ASSERT_TRUE(shape6[1] == 2); | |||||
ASSERT_TRUE(shape6[2] == 3); | |||||
Shape shape7; | |||||
shape7.ndim(3); | |||||
shape7[1] = 4; | |||||
ASSERT_TRUE(shape7 == Shape({0, 4, 0})); | |||||
std::vector<Shape> shapes1 = {{2, 3, 4}, {6}, {5, 7}}; | |||||
megdnn::SmallVector<TensorShape> bshapes = to_builtin<TensorShape, Shape>(shapes1); | |||||
ASSERT_TRUE(bshapes[0].total_nr_elems() == 2*3*4); | |||||
ASSERT_TRUE(bshapes[1].total_nr_elems() == 6); | |||||
ASSERT_TRUE(bshapes[2].total_nr_elems() == 35); | |||||
std::vector<Shape> shapes2 = to_custom<TensorShape, Shape>(bshapes); | |||||
ASSERT_TRUE(shapes2[0] == Shape({2, 3, 4})); | |||||
ASSERT_TRUE(shapes2[1] == Shape({6})); | |||||
ASSERT_TRUE(shapes2[2] == Shape({5, 7})); | |||||
} | |||||
TEST(TestDType, TestDType) { | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
ASSERT_TRUE(DType::is_legal("uint8")); | |||||
ASSERT_TRUE(DType::is_legal(DTypeEnum::bfloat16)); | |||||
DType dtype1, dtype2; | |||||
ASSERT_TRUE(dtype1.str() == "invalid"); | |||||
dtype1 = "float32"; | |||||
ASSERT_TRUE(dtype1.str() == "float32"); | |||||
dtype2 = dtype1; | |||||
DType dtype3 = dtype2; | |||||
ASSERT_TRUE(dtype3 == dtype1); | |||||
ASSERT_TRUE(dtype3 == "float32"); | |||||
dtype3 = "int8"; | |||||
ASSERT_FALSE("float32" == dtype3.str()); | |||||
ASSERT_FALSE(dtype3 == dtype2); | |||||
DType dtype4 = DTypeEnum::int8, dtype5 = dtype3; | |||||
ASSERT_TRUE(dtype4 == dtype5); | |||||
ASSERT_TRUE(dtype4.is_compatible<int8_t>()); | |||||
ASSERT_FALSE(dtype4.is_compatible<uint8_t>()); | |||||
DType dtype6 = "int32"; | |||||
megdnn::DType bdtype1 = to_builtin<megdnn::DType, DType>(dtype6); | |||||
ASSERT_TRUE(bdtype1.name() == std::string("Int32")); | |||||
bdtype1 = megdnn::DType::from_enum(megdnn::DTypeEnum::BFloat16); | |||||
DType dtype7 = to_custom<megdnn::DType, DType>(bdtype1); | |||||
ASSERT_TRUE(dtype7.enumv() == DTypeEnum::bfloat16); | |||||
std::vector<DType> dtypes1 = {"int8", "uint8", "float16"}; | |||||
megdnn::SmallVector<megdnn::DType> bdtypes | |||||
= to_builtin<megdnn::DType, DType>(dtypes1); | |||||
ASSERT_TRUE(bdtypes[0].name() == std::string("Int8")); | |||||
ASSERT_TRUE(bdtypes[1].name() == std::string("Uint8")); | |||||
ASSERT_TRUE(bdtypes[2].name() == std::string("Float16")); | |||||
std::vector<DType> dtypes2 = to_custom<megdnn::DType, DType>(bdtypes); | |||||
ASSERT_TRUE(dtypes2[0] == "int8"); | |||||
ASSERT_TRUE(dtypes2[1] == "uint8"); | |||||
ASSERT_TRUE(dtypes2[2] == "float16"); | |||||
#endif | |||||
} | |||||
TEST(TestDType, TestDTypeQuantized) { | |||||
DType quint8_1("quint8", 3.2, 15); | |||||
DType quint8_2("quint8", 3.2, 15); | |||||
DType quint8_3("quint8", 3.2, 16); | |||||
DType quint8_4("quint8", 3.1, 15); | |||||
ASSERT_TRUE(quint8_1 == quint8_2); | |||||
ASSERT_FALSE(quint8_1 == quint8_3); | |||||
ASSERT_FALSE(quint8_1 == quint8_4); | |||||
ASSERT_TRUE(quint8_1.scale() == 3.2f); | |||||
ASSERT_TRUE(quint8_1.zero_point() == 15); | |||||
DType qint8("qint8", 3.3f); | |||||
DType qint16("qint16", 3.4f); | |||||
DType qint32("qint32", 3.5f); | |||||
ASSERT_TRUE(qint8.scale() == 3.3f); | |||||
ASSERT_TRUE(qint16.scale() == 3.4f); | |||||
ASSERT_TRUE(qint32.scale() == 3.5f); | |||||
ASSERT_TRUE(qint8.enumv() == DTypeEnum::qint8); | |||||
ASSERT_TRUE(qint8.str() == "qint8"); | |||||
} | |||||
TEST(TestFormat, TestFormat) { | |||||
Format format1, format2("default"); | |||||
ASSERT_TRUE(format1.is_default()); | |||||
ASSERT_TRUE(format2.is_default()); | |||||
Format format3 = format1; | |||||
ASSERT_TRUE(format3.is_default()); | |||||
} | |||||
TEST(TestTensor, TestTensor) { | |||||
CompNode builtin_device = CompNode::load("cpux:0"); | |||||
TensorShape builtin_shape = {3, 2, 4}; | |||||
megdnn::DType builtin_dtype = dtype::Int32{}; | |||||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||||
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||||
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||||
Device device = tensor1.device(); | |||||
Shape shape = tensor1.shape(); | |||||
DType dtype = tensor1.dtype(); | |||||
ASSERT_TRUE(device == "x86"); | |||||
ASSERT_TRUE(shape.ndim() == 3); | |||||
ASSERT_TRUE(shape[0] == 3); | |||||
ASSERT_TRUE(shape[1] == 2); | |||||
ASSERT_TRUE(shape[2] == 4); | |||||
ASSERT_TRUE(shape == std::vector<size_t>({3, 2, 4})); | |||||
ASSERT_TRUE(dtype == "int32"); | |||||
int *raw_ptr1 = tensor1.data<int>(); | |||||
for (size_t i=0; i<tensor1.size(); i++) | |||||
raw_ptr1[i] = i; | |||||
int *raw_ptr2 = tensor2.data<int>(); | |||||
for (size_t i=0; i<tensor2.size(); i++) | |||||
ASSERT_TRUE(raw_ptr2[i] == static_cast<int>(i)); | |||||
Tensor tensor3 = tensor2; | |||||
int *raw_ptr3 = tensor3.data<int>(); | |||||
for (size_t i=0; i<tensor3.size(); i++) | |||||
ASSERT_TRUE(raw_ptr3[i] == static_cast<int>(i)); | |||||
ASSERT_TRUE(raw_ptr1 == raw_ptr2); | |||||
ASSERT_TRUE(raw_ptr1 == raw_ptr3); | |||||
for (size_t i=0; i<tensor3.size(); i++) { | |||||
raw_ptr3[i] = -static_cast<int>(i); | |||||
} | |||||
for (size_t i=0; i<tensor1.size(); i++) { | |||||
ASSERT_TRUE(raw_ptr1[i] == -static_cast<int>(i)); | |||||
} | |||||
DeviceTensorND new_dev_tensor = to_builtin<DeviceTensorND, Tensor>(tensor3); | |||||
int *builtin_ptr = new_dev_tensor.ptr<int>(); | |||||
for (size_t i=0; i<new_dev_tensor.shape().total_nr_elems(); i++) { | |||||
ASSERT_TRUE(builtin_ptr[i] == -static_cast<int>(i)); | |||||
} | |||||
} | |||||
TEST(TestTensor, TestTensorQuantized) { | |||||
#if MGB_CUDA | |||||
CompNode builtin_device = CompNode::load("gpux:0"); | |||||
TensorShape builtin_shape = {3, 2, 4}; | |||||
megdnn::DType builtin_dtype = dtype::Quantized8Asymm{3.2f, uint8_t(15)}; | |||||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||||
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||||
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||||
Device device1 = tensor1.device(), device2 = tensor2.device(); | |||||
Shape shape1 = tensor1.shape(), shape2 = tensor2.shape(); | |||||
DType dtype1 = tensor1.dtype(), dtype2 = tensor2.dtype(); | |||||
ASSERT_TRUE(device1 == "cuda"); | |||||
ASSERT_TRUE(shape1.ndim() == 3); | |||||
ASSERT_TRUE(shape1[0] == 3); | |||||
ASSERT_TRUE(shape1[1] == 2); | |||||
ASSERT_TRUE(shape1[2] == 4); | |||||
ASSERT_TRUE(shape1 == std::vector<size_t>({3, 2, 4})); | |||||
ASSERT_TRUE(dtype1 == "quint8"); | |||||
ASSERT_TRUE(dtype1.scale() == 3.2f); | |||||
ASSERT_TRUE(dtype1.zero_point() == 15); | |||||
ASSERT_TRUE(device1 == device2); | |||||
ASSERT_TRUE(shape1 == shape2); | |||||
ASSERT_TRUE(dtype1 == dtype2); | |||||
#endif | |||||
} | |||||
TEST(TestTensor, TestTensorAccessorND) { | |||||
size_t N = 2, C = 4, H = 6, W = 8; | |||||
CompNode builtin_device = CompNode::load("cpux"); | |||||
TensorShape builtin_shape = {N, C, H, W}; | |||||
megdnn::DType builtin_dtype = dtype::Int32{}; | |||||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||||
int *builtin_ptr = dev_tensor.ptr<int>(); | |||||
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||||
builtin_ptr[i] = i; | |||||
} | |||||
Tensor tensor = to_custom_tensor(dev_tensor); | |||||
auto accessor = tensor.accessor<int32_t, 4>(); | |||||
for (size_t n=0; n<N; ++n) { | |||||
for (size_t c=0; c<C; ++c) { | |||||
for (size_t h=0; h<H; ++h) { | |||||
for (size_t w=0; w<W; ++w) { | |||||
int32_t idx = n*C*H*W + c*H*W + h*W + w; | |||||
ASSERT_TRUE(accessor[n][c][h][w] == idx); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
TEST(TestTensor, TestTensorAccessor1D) { | |||||
CompNode builtin_device = CompNode::load("cpux"); | |||||
TensorShape builtin_shape = {32}; | |||||
megdnn::DType builtin_dtype = dtype::Float32{}; | |||||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||||
float *builtin_ptr = dev_tensor.ptr<float>(); | |||||
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||||
builtin_ptr[i] = i; | |||||
} | |||||
Tensor tensor = to_custom_tensor(dev_tensor); | |||||
auto accessor = tensor.accessor<float, 1>(); | |||||
for (size_t n=0; n<32; ++n) { | |||||
ASSERT_TRUE(accessor[n] == n); | |||||
} | |||||
} | |||||
} |
@@ -18,7 +18,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode); | |||||
void CustomOpNode::infer_output_comp_node(void) { | void CustomOpNode::infer_output_comp_node(void) { | ||||
SmallVector<CompNode> input_comp_nodes(input_num()); | SmallVector<CompNode> input_comp_nodes(input_num()); | ||||
for (int i=0; i<input_num(); ++i) { | |||||
for (size_t i=0; i<input_num(); ++i) { | |||||
input_comp_nodes[i] = input(i)->comp_node(); | input_comp_nodes[i] = input(i)->comp_node(); | ||||
} | } | ||||
@@ -28,7 +28,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||||
) | ) | ||||
); | ); | ||||
for (int i=0; i<output_num(); ++i) { | |||||
for (size_t i=0; i<output_num(); ++i) { | |||||
mgb_assert(output_comp_nodes[i] == output_comp_nodes[0], | mgb_assert(output_comp_nodes[i] == output_comp_nodes[0], | ||||
"only single comp node operator is supported"); | "only single comp node operator is supported"); | ||||
output(i)->comp_node(output_comp_nodes[i]); | output(i)->comp_node(output_comp_nodes[i]); | ||||
@@ -39,7 +39,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||||
void CustomOpNode::infer_output_dtype(void) { | void CustomOpNode::infer_output_dtype(void) { | ||||
SmallVector<DType> input_dtypes(input_num()); | SmallVector<DType> input_dtypes(input_num()); | ||||
for (int i=0; i<input_num(); ++i) { | |||||
for (size_t i=0; i<input_num(); ++i) { | |||||
input_dtypes[i] = input(i)->dtype(); | input_dtypes[i] = input(i)->dtype(); | ||||
} | } | ||||
@@ -49,14 +49,14 @@ void CustomOpNode::infer_output_dtype(void) { | |||||
) | ) | ||||
); | ); | ||||
for (int i=0; i<output_num(); ++i) { | |||||
for (size_t i=0; i<output_num(); ++i) { | |||||
output(i)->dtype(output_dtypes[i]); | output(i)->dtype(output_dtypes[i]); | ||||
} | } | ||||
} | } | ||||
void CustomOpNode::infer_output_format(void) { | void CustomOpNode::infer_output_format(void) { | ||||
SmallVector<TensorFormat> input_formats(input_num()); | SmallVector<TensorFormat> input_formats(input_num()); | ||||
for (int i=0; i<input_num(); ++i) { | |||||
for (size_t i=0; i<input_num(); ++i) { | |||||
input_formats[i] = input(i)->format(); | input_formats[i] = input(i)->format(); | ||||
} | } | ||||
@@ -66,14 +66,14 @@ void CustomOpNode::infer_output_format(void) { | |||||
) | ) | ||||
); | ); | ||||
for (int i=0; i<output_num(); ++i) { | |||||
for (size_t i=0; i<output_num(); ++i) { | |||||
output(i)->format(output_formats[i]); | output(i)->format(output_formats[i]); | ||||
} | } | ||||
} | } | ||||
void CustomOpNode::infer_output_shape(void) { | void CustomOpNode::infer_output_shape(void) { | ||||
SmallVector<TensorShape> input_shapes(input_num()); | SmallVector<TensorShape> input_shapes(input_num()); | ||||
for (int i=0; i<input_num(); ++i) { | |||||
for (size_t i=0; i<input_num(); ++i) { | |||||
input_shapes[i] = input(i)->shape(); | input_shapes[i] = input(i)->shape(); | ||||
} | } | ||||
@@ -83,7 +83,7 @@ void CustomOpNode::infer_output_shape(void) { | |||||
) | ) | ||||
); | ); | ||||
for (int i=0; i<output_num(); ++i) { | |||||
for (size_t i=0; i<output_num(); ++i) { | |||||
output(i)->shape(output_shapes[i]); | output(i)->shape(output_shapes[i]); | ||||
} | } | ||||
} | } | ||||
@@ -235,10 +235,10 @@ CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op, | |||||
const OperatorNodeConfig &config): | const OperatorNodeConfig &config): | ||||
OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) { | OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) { | ||||
mgb_assert(input_num() == inputs.size(), "wrong input tensors list length"); | mgb_assert(input_num() == inputs.size(), "wrong input tensors list length"); | ||||
for (int i=0; i < input_num(); ++i) | |||||
for (size_t i=0; i < input_num(); ++i) | |||||
add_input({inputs[i]}); | add_input({inputs[i]}); | ||||
for (int i=0; i<output_num(); ++i) | |||||
for (size_t i=0; i<output_num(); ++i) | |||||
add_output(output_info(i).name()); | add_output(output_info(i).name()); | ||||
if (!std::is_empty<custom::Param>::value) { | if (!std::is_empty<custom::Param>::value) { | ||||
@@ -306,11 +306,11 @@ std::string CustomOpNode::op_desc(void) const { | |||||
return m_op->op_desc(); | return m_op->op_desc(); | ||||
} | } | ||||
int CustomOpNode::input_num(void) const { | |||||
size_t CustomOpNode::input_num(void) const { | |||||
return m_op->input_num(); | return m_op->input_num(); | ||||
} | } | ||||
int CustomOpNode::output_num(void) const { | |||||
size_t CustomOpNode::output_num(void) const { | |||||
return m_op->output_num(); | return m_op->output_num(); | ||||
} | } | ||||
@@ -93,8 +93,8 @@ public: | |||||
custom::Param param(void) const; | custom::Param param(void) const; | ||||
std::string op_type(void) const; | std::string op_type(void) const; | ||||
std::string op_desc(void) const; | std::string op_desc(void) const; | ||||
int input_num(void) const; | |||||
int output_num(void) const; | |||||
size_t input_num(void) const; | |||||
size_t output_num(void) const; | |||||
custom::ArgInfo input_info(size_t idx) const; | custom::ArgInfo input_info(size_t idx) const; | ||||
custom::ArgInfo output_info(size_t idx) const; | custom::ArgInfo output_info(size_t idx) const; | ||||
}; | }; | ||||
@@ -1,7 +1,7 @@ | |||||
include_directories("./src/include") | include_directories("./src/include") | ||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | ||||
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp) | |||||
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp ../src/custom/test/*.cpp) | |||||
if(MGE_WITH_JIT) | if(MGE_WITH_JIT) | ||||
file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp) | file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp) | ||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||