GitOrigin-RevId: 4ade455897
release-1.6
@@ -613,17 +613,17 @@ void init_ops(py::module m) { | |||
} | |||
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type) \ | |||
case mgb::custom::ParamDynType::dyn_type: { \ | |||
case custom::ParamDynType::dyn_type: { \ | |||
param_val = py::handle(kv.second).cast<static_type>(); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type) \ | |||
case mgb::custom::ParamDynType::dyn_type: { \ | |||
case custom::ParamDynType::dyn_type: { \ | |||
auto pyvals = py::handle(kv.second).cast<py::list>(); \ | |||
static_type vals; \ | |||
using basic_type = \ | |||
mgb::custom::get_vector_template_arg_type<static_type>::type; \ | |||
custom::get_vector_template_arg_type<static_type>::type; \ | |||
for (auto &pyval: pyvals) { \ | |||
vals.push_back(py::handle(pyval).cast<basic_type>()); \ | |||
} \ | |||
@@ -631,7 +631,7 @@ void init_ops(py::module m) { | |||
break; \ | |||
} | |||
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyObject *kwnames) { | |||
PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs) { | |||
auto op_name = py::handle(args[0]).cast<std::string>(); | |||
auto kwargs = py::handle(args[1]).cast<py::dict>(); | |||
@@ -680,7 +680,7 @@ PyObject *make_custom_op(PyObject *self, PyObject **args, Py_ssize_t nargs, PyOb | |||
py::list install_custom(const std::string &name, const std::string &path) { | |||
py::list ret; | |||
const auto &ops_in_lib = mgb::custom::LibManager::inst()->install(name, path); | |||
const auto &ops_in_lib = custom::LibManager::inst()->install(name, path); | |||
for (const auto &op: ops_in_lib) { | |||
ret.append(op); | |||
} | |||
@@ -688,7 +688,7 @@ py::list install_custom(const std::string &name, const std::string &path) { | |||
} | |||
bool uninstall_custom(const std::string &name) { | |||
return mgb::custom::LibManager::inst()->uninstall(name); | |||
return custom::LibManager::inst()->uninstall(name); | |||
} | |||
py::list get_custom_op_list(void) { | |||
@@ -697,16 +697,28 @@ py::list get_custom_op_list(void) { | |||
for (auto &op: all_ops) { | |||
ret.append(op); | |||
} | |||
return std::move(ret); | |||
return ret; | |||
} | |||
#ifndef METH_FASTCALL | |||
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) { | |||
auto* arr = &PyTuple_GET_ITEM(args, 0); | |||
auto size = PyTuple_GET_SIZE(args); | |||
return make_custom_op(self, arr, size); | |||
}; | |||
#endif | |||
void init_custom(pybind11::module m) { | |||
m.def("_install", &install_custom); | |||
m.def("_uninstall", &uninstall_custom); | |||
m.def("_get_custom_op_list", &get_custom_op_list); | |||
static PyMethodDef method_def = { | |||
#ifdef METH_FASTCALL | |||
"_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, "" | |||
#else | |||
"_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, "" | |||
#endif | |||
}; | |||
auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr); | |||
pybind11::setattr(m, method_def.ml_name, func); | |||
@@ -70,7 +70,7 @@ void CustomOpDef::compute(const SmallVector<DeviceTensorND> &inputs, | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( | |||
const SmallVector<TensorPtr> &inputs) const { | |||
SmallVector<LogicalTensorDesc> input_descs(inputs.size()); | |||
for (int i=0; i<inputs.size(); i++) { | |||
for (size_t i=0; i<inputs.size(); i++) { | |||
input_descs[i].comp_node = inputs[i]->comp_node(); | |||
input_descs[i].layout = inputs[i]->layout(); | |||
} | |||
@@ -84,7 +84,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs | |||
SmallVector<megdnn::DType> i_dtypes(inputs.size()); | |||
SmallVector<TensorFormat> i_formats(inputs.size()); | |||
for (int i=0; i<inputs.size(); i++) { | |||
for (size_t i=0; i<inputs.size(); i++) { | |||
i_devices[i] = inputs[i].comp_node; | |||
i_shapes[i] = inputs[i].layout; // TensorLayout is derived from TensorShape | |||
i_dtypes[i] = inputs[i].layout.dtype; | |||
@@ -132,7 +132,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs | |||
} | |||
SmallVector<LogicalTensorDesc> outputs(this->output_num()); | |||
for (int i=0; i<this->output_num(); i++) { | |||
for (size_t i=0; i<this->output_num(); i++) { | |||
outputs[i].comp_node = std::move(o_devices[i]); | |||
outputs[i].layout = std::move( | |||
TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]) | |||
@@ -0,0 +1,181 @@ | |||
/** | |||
* \file src/custom/impl/manager.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/manager.h" | |||
#include "megbrain/common.h" | |||
#include <unordered_set> | |||
#ifndef _WIN32 | |||
#include <dlfcn.h> | |||
#endif | |||
using namespace mgb; | |||
namespace custom { | |||
CustomOpManager *CustomOpManager::inst(void) { | |||
static CustomOpManager op_manager; | |||
return &op_manager; | |||
} | |||
CustomOpManager::~CustomOpManager() { | |||
mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); | |||
LibManager::inst()->m_custom_libs.clear(); | |||
} | |||
std::shared_ptr<CustomOp> CustomOpManager::insert(const std::string &name, uint32_t version) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto iter = m_name2op.find(name); | |||
if (iter != m_name2op.end()) { | |||
mgb_log_warn("Register Custom Op Failed! Op %s has been registered", name.c_str()); | |||
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second); | |||
} | |||
std::shared_ptr<const CustomOp> op = std::make_shared<const CustomOp>(name, version); | |||
m_name2op[op->op_type()] = op; | |||
m_id2op[op->runtime_id()] = op; | |||
return std::const_pointer_cast<CustomOp, const CustomOp>(op); | |||
} | |||
bool CustomOpManager::erase(const std::string &name) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto iter = m_name2op.find(name); | |||
if (iter == m_name2op.end()) { | |||
mgb_log_warn("Erase Custom Op Failed! %s has not been registered", name.c_str()); | |||
return false; | |||
} | |||
std::shared_ptr<const CustomOp> op = iter->second; | |||
m_id2op.erase(op->runtime_id()); | |||
m_name2op.erase(op->op_type()); | |||
return true; | |||
} | |||
bool CustomOpManager::erase(const RunTimeId &id) { | |||
MGB_LOCK_GUARD(m_mtx); | |||
auto iter = m_id2op.find(id); | |||
if (iter == m_id2op.end()) { | |||
mgb_log_warn("Erase Custom Op Failed! The Op has not been registered"); | |||
return false; | |||
} | |||
std::shared_ptr<const CustomOp> op = iter->second; | |||
m_id2op.erase(op->runtime_id()); | |||
m_name2op.erase(op->op_type()); | |||
return true; | |||
} | |||
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(const std::string &name, uint32_t version) { | |||
auto iter = m_name2op.find(name); | |||
if (iter == m_name2op.end()) { | |||
return insert(name, version); | |||
} | |||
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second); | |||
} | |||
RunTimeId CustomOpManager::to_id(const std::string &name) const { | |||
std::shared_ptr<const CustomOp> op = find(name); | |||
return op->runtime_id(); | |||
} | |||
std::string CustomOpManager::to_name(const RunTimeId &id) const { | |||
std::shared_ptr<const CustomOp> op = find(id); | |||
return op->op_type(); | |||
} | |||
std::shared_ptr<const CustomOp> CustomOpManager::find(const std::string &name) const { | |||
auto ret = m_name2op.find(name); | |||
mgb_assert(ret != m_name2op.end(), | |||
"Find Custom Op Failed! Op %s has not been registered", name.c_str() | |||
); | |||
return ret->second; | |||
} | |||
std::shared_ptr<const CustomOp> CustomOpManager::find(const RunTimeId &id) const { | |||
auto ret = m_id2op.find(id); | |||
mgb_assert(ret != m_id2op.end(), "Find Custom Op Failed! Op has not been registered"); | |||
return ret->second; | |||
} | |||
std::vector<std::string> CustomOpManager::op_name_list(void) { | |||
std::vector<std::string> ret; | |||
for (auto kv: m_name2op) { | |||
ret.emplace_back(kv.first); | |||
} | |||
return ret; | |||
} | |||
std::vector<RunTimeId> CustomOpManager::op_id_list(void) { | |||
std::vector<RunTimeId> ret; | |||
for (auto kv: m_id2op) { | |||
ret.emplace_back(kv.first); | |||
} | |||
return ret; | |||
} | |||
#ifndef _WIN32 | |||
CustomLib::CustomLib(const std::string &path, int mode = RTLD_LAZY) | |||
: m_handle(nullptr, [](void* handle) {dlclose(handle);}) { | |||
auto op_list_before_load = CustomOpManager::inst()->op_name_list(); | |||
std::unordered_set<std::string> op_set_before_load( | |||
op_list_before_load.begin(), op_list_before_load.end()); | |||
m_handle.reset(dlopen(path.c_str(), mode)); | |||
mgb_assert(m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror()); | |||
auto op_list_after_load = CustomOpManager::inst()->op_name_list(); | |||
for (auto &op: op_list_after_load) { | |||
if (op_set_before_load.find(op) == op_set_before_load.end()) { | |||
m_ops.emplace_back(op); | |||
} | |||
} | |||
} | |||
#else | |||
CustomLib::CustomLib(const std::string &path, int mode = 0) | |||
: m_handle(nullptr, [](void* handle) {}) { | |||
mgb_assert(false, "custom op is only supported on Linux now"); | |||
} | |||
#endif | |||
const std::vector<std::string> &CustomLib::ops_in_lib(void) const { | |||
return m_ops; | |||
} | |||
CustomLib::~CustomLib() { | |||
for (auto &op: m_ops) { | |||
CustomOpManager::inst()->erase(op); | |||
} | |||
} | |||
bool CustomLib::valid() const { | |||
return m_handle != nullptr; | |||
} | |||
LibManager *LibManager::inst(void) { | |||
static LibManager custom_libs; | |||
return &custom_libs; | |||
} | |||
const std::vector<std::string> &LibManager::install(const std::string &name, const std::string &path) { | |||
MGB_LOCK_GUARD(m_mtx);; | |||
LibHandle handle = std::make_shared<CustomLib>(path); | |||
m_custom_libs.insert({name, handle}); | |||
return m_custom_libs[name]->ops_in_lib(); | |||
} | |||
bool LibManager::uninstall(const std::string &name) { | |||
MGB_LOCK_GUARD(m_mtx);; | |||
mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error"); | |||
return true; | |||
} | |||
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) { | |||
return CustomOpManager::inst()->insert(opname, version); | |||
} | |||
} |
@@ -0,0 +1,531 @@ | |||
/** | |||
* \file src/custom/impl/op.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/common.h" | |||
#include "megbrain/custom/op.h" | |||
#include "megbrain/custom/utils.h" | |||
#include <unordered_set> | |||
#include <sstream> | |||
using namespace mgb; | |||
namespace custom { | |||
class ArgInfoImpl { | |||
std::string m_name; | |||
std::string m_desc; | |||
std::unordered_set<std::string> m_dtypes; | |||
int m_ndim; // use int rather than size_t for representing m_dims = -1 | |||
std::string m_mem_stgy; | |||
friend class ArgInfo; | |||
}; | |||
CUSTOM_PIMPL_CLS_DEFINE(ArgInfo) | |||
ArgInfo::ArgInfo(const std::string &name, | |||
const std::string &desc, | |||
const std::unordered_set<std::string> &dtypes, | |||
const int &ndim, | |||
const std::string &mem_stgy): m_impl(new ArgInfoImpl(), impl_deleter<ArgInfoImpl>) { | |||
for (auto &&dtype: dtypes) { | |||
mgb_assert(DType::is_legal(dtype), "unsupported tensor data type: %s", dtype.c_str()); | |||
} | |||
mgb_assert(mem_stgy == "default", "only default mem strategy is supported now!"); | |||
TypedRef(ArgInfoImpl, m_impl.get()).m_name = name; | |||
TypedRef(ArgInfoImpl, m_impl.get()).m_desc = desc; | |||
TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes = dtypes; | |||
TypedRef(ArgInfoImpl, m_impl.get()).m_ndim = ndim; | |||
TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy = mem_stgy; | |||
} | |||
const std::string &ArgInfo::name(void) const { | |||
return TypedRef(ArgInfoImpl, m_impl.get()).m_name; | |||
} | |||
const std::string &ArgInfo::desc(void) const { | |||
return TypedRef(ArgInfoImpl, m_impl.get()).m_desc; | |||
} | |||
const std::unordered_set<std::string> &ArgInfo::dtypes(void) const { | |||
return TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes; | |||
} | |||
int ArgInfo::ndim(void) const { | |||
return TypedRef(ArgInfoImpl, m_impl.get()).m_ndim; | |||
} | |||
const std::string &ArgInfo::mem_strategy(void) const { | |||
return TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; | |||
} | |||
std::string ArgInfo::str() const { | |||
std::stringstream ss; | |||
ss << "name: " << TypedRef(ArgInfoImpl, m_impl.get()).m_name << "\n" | |||
<< "desc: " << TypedRef(ArgInfoImpl, m_impl.get()).m_desc << "\nlegal_dtypes: {"; | |||
for (auto &val: TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes) { | |||
ss << val << ", "; | |||
} | |||
if (TypedRef(ArgInfoImpl, m_impl.get()).m_dtypes.size() != 0) { | |||
ss.seekp(ss.tellp()-std::streampos(2)); | |||
} | |||
ss << "}\ndims: " << TypedRef(ArgInfoImpl, m_impl.get()).m_ndim << "\n" | |||
<< "memory_strategy: " << TypedRef(ArgInfoImpl, m_impl.get()).m_mem_stgy; | |||
return ss.str(); | |||
} | |||
#define assert_inputs_size_right(inputs_vec) \ | |||
mgb_assert( \ | |||
inputs_vec.size() == input_num(), \ | |||
"op %s need %lu inputs but given %lu", \ | |||
op_type().c_str(), static_cast<unsigned long>(input_num()), \ | |||
static_cast<unsigned long>(inputs_vec.size()) \ | |||
) | |||
#define assert_outputs_size_right(outputs_vec) \ | |||
mgb_assert( \ | |||
outputs_vec.size() == output_num(), \ | |||
"op %s have %lu outputs but given %lu", \ | |||
op_type().c_str(), static_cast<unsigned long>(output_num()), \ | |||
static_cast<unsigned long>(outputs_vec.size()) \ | |||
) | |||
#define assert_arg_shape_dim_right(real_shape, arg_info) \ | |||
mgb_assert( \ | |||
(arg_info).ndim() == -1 || static_cast<int>((real_shape).ndim()) == \ | |||
static_cast<int>((arg_info).ndim()), \ | |||
"%s's args: %s dim match error, need %d but given %d", op_type().c_str(), \ | |||
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \ | |||
static_cast<int>((real_shape).ndim()) \ | |||
) | |||
template <typename T> | |||
class Function; | |||
template<typename RType, typename... Args> | |||
class Function<RType(Args...)> { | |||
public: | |||
using Functor = RType (*)(Args...); | |||
Function() = default; | |||
Function(Functor f): m_f(f) {} | |||
Function(const Function &rhs) { | |||
m_f = rhs.m_f; | |||
} | |||
RType operator()(Args... args) { | |||
custom_assert(m_f != nullptr, "invalid function ptr\n"); | |||
return m_f(std::forward<Args>(args)...); | |||
} | |||
void operator=(const Function &rhs) { // not allowed continuous assignment | |||
m_f = rhs.m_f; | |||
} | |||
void operator=(const Functor f) { | |||
m_f = f; | |||
} | |||
private: | |||
Functor m_f = nullptr; | |||
}; | |||
template <typename Functions> | |||
class FuncWithSig: public Functions { | |||
public: | |||
using Functions::operator(); | |||
using Functions::operator=; | |||
}; | |||
class CustomOpImpl { | |||
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; | |||
const uint32_t m_version; | |||
const std::string m_op_type; | |||
std::string m_op_desc; | |||
std::vector<ArgInfo> m_input_infos; | |||
std::vector<ArgInfo> m_output_infos; | |||
ParamInfo m_param_infos; | |||
using DeviceInfer = FuncWithSig<Function<void(const std::vector<Device>&, const Param&, std::vector<Device>&)>>; | |||
using ShapeInfer = FuncWithSig<Function<void(const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>; | |||
using DTypeInfer = FuncWithSig<Function<void(const std::vector<DType>&, const Param&, std::vector<DType>&)>>; | |||
using FormatInfer = FuncWithSig<Function<void(const std::vector<Format>&, const Param&, std::vector<Format>&)>>; | |||
using Preprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||
using Postprocess = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||
using Compute = FuncWithSig<Function<void(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; | |||
DeviceInfer infer_output_device_func; | |||
ShapeInfer infer_output_shape_func; | |||
DTypeInfer infer_output_dtype_func; | |||
FormatInfer infer_output_format_func; | |||
std::unordered_map<std::string, Compute> compute_funcs; | |||
std::unordered_map<std::string, Preprocess> preprocess_funcs; | |||
std::unordered_map<std::string, Postprocess> postprocess_funcs; | |||
public: | |||
CustomOpImpl(const std::string&, uint32_t version); | |||
PREVENT_COPY_AND_ASSIGN(CustomOpImpl); | |||
friend CustomOp; | |||
}; | |||
CustomOpImpl::CustomOpImpl(const std::string &op_type, uint32_t version) | |||
: m_version(version), m_op_type(op_type) { | |||
if (m_version != CURRENT_VERSION) { | |||
mgb_log_warn( | |||
"the version of loaded custom op %s is %u, but custom op version " | |||
"of the system is %u\n", op_type.c_str(), m_version, CURRENT_VERSION | |||
); | |||
} | |||
infer_output_device_func = [](const std::vector<Device> &inputs, | |||
const Param&, | |||
std::vector<Device> &outputs) -> void { | |||
static UnImpleWarnLog log_once("output_device_infer", "device", "x86"); | |||
for (size_t i=0; i<outputs.size(); ++i) { | |||
outputs[i] = inputs.size() > 0 ? inputs[0] : Device("x86"); | |||
} | |||
}; | |||
infer_output_shape_func = [](const std::vector<Shape> &inputs, | |||
const Param&, | |||
std::vector<Shape> &outputs) -> void { | |||
static UnImpleWarnLog log_once("output_shape_infer", "shape", "{1}"); | |||
for (size_t i=0; i<outputs.size(); ++i) { | |||
outputs[i] = inputs.size() > 0 ? inputs[0] : Shape({1}); | |||
} | |||
}; | |||
infer_output_dtype_func = [](const std::vector<DType> &inputs, | |||
const Param&, | |||
std::vector<DType> &outputs) -> void { | |||
static UnImpleWarnLog log_once("output_dtype_infer", "dtype", "float32"); | |||
for (size_t i=0; i<outputs.size(); ++i) { | |||
outputs[i] = inputs.size() > 0 ? inputs[0] : DType("float32"); | |||
} | |||
}; | |||
infer_output_format_func = [](const std::vector<Format> &inputs, | |||
const Param&, | |||
std::vector<Format> &outputs) -> void { | |||
for (size_t i=0; i<outputs.size(); ++i) { | |||
outputs[i] = inputs.size() > 0 ? inputs[0] : Format("default"); | |||
} | |||
}; | |||
for (const auto &device: Device::legal_devices()) { | |||
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor> &outputs) -> void { | |||
auto device = outputs[0].device(); | |||
mgb_assert(false, "There is no forward function for your op on device `%s`. " | |||
"Please implement this function and register it.", device.str().c_str()); | |||
}; | |||
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void { | |||
return; | |||
}; | |||
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) -> void { | |||
return; | |||
}; | |||
} | |||
m_param_infos.set_tag(op_type); | |||
} | |||
CustomOp::CustomOp(const std::string &op_type, uint32_t version) | |||
: m_impl(new CustomOpImpl(op_type, version), impl_deleter<CustomOpImpl>) { | |||
} | |||
#define OpImplRef(raw_ptr) reinterpret_cast<CustomOpImpl*>(raw_ptr) | |||
CustomOp &CustomOp::set_device_infer(DeviceInferFuncPtr func) { | |||
OpImplRef(m_impl.get())->infer_output_device_func = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_shape_infer(ShapeInferFuncPtr func) { | |||
OpImplRef(m_impl.get())->infer_output_shape_func = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_dtype_infer(DTypeInferFuncPtr func) { | |||
OpImplRef(m_impl.get())->infer_output_dtype_func = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_format_infer(FormatInferFuncPtr func) { | |||
OpImplRef(m_impl.get())->infer_output_format_func = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_preprocess(PreprocessFuncPtr func) { | |||
set_preprocess("x86", func); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_preprocess(const std::string &device, PreprocessFuncPtr func) { | |||
OpImplRef(m_impl.get())->preprocess_funcs[device] = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_postprocess(PostprocessFuncPtr func) { | |||
set_postprocess("x86", func); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_postprocess(const std::string &device, PostprocessFuncPtr func) { | |||
OpImplRef(m_impl.get())->postprocess_funcs[device] = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_compute(ComputeFuncPtr func) { | |||
set_compute("x86", func); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_compute(const std::string &device, ComputeFuncPtr func) { | |||
OpImplRef(m_impl.get())->compute_funcs[device] = func; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::set_description(const std::string &op_desc) { | |||
OpImplRef(m_impl.get())->m_op_desc = op_desc; | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||
auto &ref = OpImplRef(m_impl.get())->m_input_infos; | |||
for (const auto &input: ref) { | |||
mgb_assert(input.name() != name, "input %s has been registered", name.c_str()); | |||
} | |||
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||
auto &ref = OpImplRef(m_impl.get())->m_output_infos; | |||
for (const auto &output: ref) { | |||
mgb_assert(output.name() != name, "output %s has been registered", name.c_str()); | |||
} | |||
ref.emplace_back(name, desc, legal_dtypes, dims, mem_stgy); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||
add_input(name, name, legal_dtypes, dims, mem_stgy); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes, int dims, const std::string &mem_stgy) { | |||
add_output(name, name, legal_dtypes, dims, mem_stgy); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_inputs(const size_t &num) { | |||
size_t cur_inp_num = input_num(); | |||
for (size_t i=cur_inp_num; i<cur_inp_num+num; i++) { | |||
add_input(op_type() + "_Input_" + std::to_string(i)); | |||
} | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_outputs(const size_t &num) { | |||
size_t cur_oup_num = output_num(); | |||
for (size_t i=cur_oup_num; i<cur_oup_num+num; i++) { | |||
add_output(op_type() + "_Output_" + std::to_string(i)); | |||
} | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_param(const std::string &name, const ParamVal &default_val) { | |||
add_param(name, name, default_val); | |||
return *this; | |||
} | |||
CustomOp &CustomOp::add_param(const std::string &name, const std::string &desc, const ParamVal &default_val) { | |||
auto &meta = OpImplRef(m_impl.get())->m_param_infos.meta(); | |||
for(const auto &schema: meta) { | |||
mgb_assert(name != schema.name(), "param %s has been registered\n", name.c_str()); | |||
} | |||
ParamSchema sch = ParamSchema(name, default_val, desc); | |||
meta.emplace_back(sch); | |||
return *this; | |||
} | |||
std::string CustomOp::op_type(void) const { | |||
return OpImplRef(m_impl.get())->m_op_type; | |||
} | |||
std::string CustomOp::op_desc(void) const { | |||
return OpImplRef(m_impl.get())->m_op_desc; | |||
} | |||
RunTimeId CustomOp::runtime_id(void) const { | |||
return (RunTimeId)(this); | |||
} | |||
size_t CustomOp::input_num(void) const { | |||
return OpImplRef(m_impl.get())->m_input_infos.size(); | |||
} | |||
size_t CustomOp::output_num(void) const { | |||
return OpImplRef(m_impl.get())->m_output_infos.size(); | |||
} | |||
std::string CustomOp::str(void) const { | |||
std::stringstream ss; | |||
ss << "op name: " << op_type() << "\nop desc: " << op_desc() << "\n\ninputs:\n"; | |||
for (const auto &input: inputs_info()) { | |||
ss << input.str(); | |||
ss << "\n--------------------\n"; | |||
} | |||
ss << "\noutputs:\n"; | |||
for (const auto &output: outputs_info()) { | |||
ss << output.str(); | |||
ss << "\n--------------------\n"; | |||
} | |||
ss << "\nparams:\n"; | |||
for (const auto ¶m: param_info().meta()) { | |||
ss << param.str(); | |||
ss << "\n--------------------\n"; | |||
} | |||
return ss.str(); | |||
} | |||
const ParamInfo &CustomOp::param_info(void) const { | |||
return OpImplRef(m_impl.get())->m_param_infos; | |||
} | |||
ArgInfo CustomOp::input_info(size_t idx) const { | |||
return OpImplRef(m_impl.get())->m_input_infos[idx]; | |||
} | |||
ArgInfo CustomOp::output_info(size_t idx) const { | |||
return OpImplRef(m_impl.get())->m_output_infos[idx]; | |||
} | |||
const std::vector<ArgInfo> &CustomOp::inputs_info(void) const { | |||
return OpImplRef(m_impl.get())->m_input_infos; | |||
} | |||
const std::vector<ArgInfo> &CustomOp::outputs_info(void) const { | |||
return OpImplRef(m_impl.get())->m_output_infos; | |||
} | |||
std::vector<Device> CustomOp::infer_output_device(const std::vector<Device> &inputs, const Param ¶m) const { | |||
assert_inputs_size_right(inputs); | |||
std::vector<Device> outputs(output_num()); | |||
OpImplRef(m_impl.get())->infer_output_device_func(inputs, param, outputs); | |||
assert_outputs_size_right(outputs); | |||
return outputs; | |||
} | |||
std::vector<Shape> CustomOp::infer_output_shape(const std::vector<Shape> &inputs, const Param ¶m) const { | |||
assert_inputs_size_right(inputs); | |||
for (size_t i=0; i<inputs_info().size(); i++) { | |||
assert_arg_shape_dim_right(inputs[i], input_info(i)); | |||
} | |||
std::vector<Shape> outputs(output_num()); | |||
OpImplRef(m_impl.get())->infer_output_shape_func(inputs, param, outputs); | |||
for (size_t i=0; i<outputs_info().size(); i++) { | |||
assert_arg_shape_dim_right(outputs[i], output_info(i)); | |||
} | |||
assert_outputs_size_right(outputs); | |||
return outputs; | |||
} | |||
std::vector<DType> CustomOp::infer_output_dtype(const std::vector<DType> &inputs, const Param ¶m) const { | |||
assert_inputs_size_right(inputs); | |||
for (size_t i=0; i<inputs_info().size(); i++) { | |||
std::unordered_set<std::string> legal_input_dtypes_i = input_info(i).dtypes(); | |||
mgb_assert( | |||
legal_input_dtypes_i.find(inputs[i].str()) != legal_input_dtypes_i.end(), | |||
"dtypes of input: %s(%s) is not allowed, the info of this input is:\n%s", | |||
input_info(i).name().c_str(), inputs[i].str().c_str(), | |||
input_info(i).str().c_str() | |||
); | |||
} | |||
std::vector<DType> outputs(output_num()); | |||
OpImplRef(m_impl.get())->infer_output_dtype_func(inputs, param, outputs); | |||
for (size_t i=0; i<outputs_info().size(); i++) { | |||
std::unordered_set<std::string> legal_output_dtypes_i = output_info(i).dtypes(); | |||
mgb_assert( | |||
legal_output_dtypes_i.find(outputs[i].str()) != legal_output_dtypes_i.end(), | |||
"dtypes of output: %s is %s, the info of this output is:\n%s", | |||
output_info(i).name().c_str(), outputs[i].str().c_str(), | |||
output_info(i).str().c_str() | |||
); | |||
} | |||
assert_outputs_size_right(outputs); | |||
return outputs; | |||
} | |||
std::vector<Format> CustomOp::infer_output_format(const std::vector<Format> &inputs, const Param ¶m) const { | |||
assert_inputs_size_right(inputs); | |||
for (size_t i=0; i<inputs.size(); i++) { | |||
mgb_assert( | |||
inputs[i].is_default(), | |||
"the tensor format of %s:%s is not default", | |||
op_type().c_str(), input_info(i).name().c_str() | |||
); | |||
} | |||
std::vector<Format> outputs(output_num()); | |||
OpImplRef(m_impl.get())->infer_output_format_func(inputs, param, outputs); | |||
for (size_t i=0; i<outputs.size(); i++) { | |||
mgb_assert( | |||
outputs[i].is_default(), | |||
"the tensor format of %s:%s is not default", | |||
op_type().c_str(), output_info(i).name().c_str() | |||
); | |||
} | |||
assert_outputs_size_right(outputs); | |||
return outputs; | |||
} | |||
void CustomOp::compute(const std::vector<Tensor> &inputs, const Param ¶m, std::vector<Tensor> &outputs) const { | |||
assert_inputs_size_right(inputs); | |||
assert_outputs_size_right(outputs); | |||
if (outputs.size() == 0) { | |||
return; | |||
} | |||
std::string device = outputs[0].device().str(); | |||
for (size_t i=1; i<outputs.size(); ++i) { | |||
mgb_assert( | |||
outputs[i].device().str() == device, | |||
"all output tensors should have the same device attribute" | |||
); | |||
} | |||
// need to add other input/output check | |||
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str()); | |||
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device]; | |||
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device]; | |||
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device]; | |||
preprocess_func(inputs, param, outputs); | |||
forward_func(inputs, param, outputs); | |||
postprocess_func(outputs, param, outputs); | |||
assert_outputs_size_right(outputs); | |||
} | |||
} |
@@ -0,0 +1,179 @@ | |||
/** | |||
* \file src/custom/impl/param.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/param.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/utils/hash.h" | |||
#include <limits> | |||
#include <sstream> | |||
#include <map> | |||
using namespace mgb; | |||
namespace custom { | |||
class ParamSchemaImpl { | |||
std::string m_name; | |||
std::string m_desc; | |||
ParamVal m_default; | |||
friend ParamSchema; | |||
}; | |||
class ParamInfoImpl { | |||
std::vector<ParamSchema> m_meta; | |||
uint32_t TAG; | |||
friend ParamInfo; | |||
}; | |||
class ParamImpl { | |||
std::unordered_map<std::string, ParamVal> m_vals; | |||
ParamImpl() = default; | |||
ParamImpl(const ParamImpl &rhs) = default; | |||
ParamImpl &operator=(const ParamImpl &rhs) { | |||
mgb_assert( | |||
m_vals.size() == rhs.m_vals.size(), | |||
"params of different op, assignment failed!" | |||
); | |||
for (const auto &kv: rhs.m_vals) { | |||
auto iter = m_vals.find(kv.first); | |||
mgb_assert(iter != m_vals.end(), "params of different op, assignment failed!"); | |||
iter->second = kv.second; | |||
} | |||
return *this; | |||
} | |||
friend Param; | |||
}; | |||
CUSTOM_PIMPL_CLS_DEFINE(ParamSchema) | |||
ParamSchema::ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc) | |||
: m_impl(new ParamSchemaImpl(), impl_deleter<ParamSchemaImpl>) { | |||
TypedRef(ParamSchemaImpl, m_impl.get()).m_name = name; | |||
TypedRef(ParamSchemaImpl, m_impl.get()).m_default = value; | |||
TypedRef(ParamSchemaImpl, m_impl.get()).m_desc = desc; | |||
} | |||
const std::string &ParamSchema::name(void) const { | |||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_name; | |||
} | |||
const std::string &ParamSchema::desc(void) const { | |||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_desc; | |||
} | |||
const ParamVal &ParamSchema::default_val(void) const { | |||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default; | |||
} | |||
ParamDynType ParamSchema::type(void) const { | |||
return TypedRef(ParamSchemaImpl, m_impl.get()).m_default.type(); | |||
} | |||
std::string ParamSchema::str(void) const { | |||
std::stringstream ss; | |||
ss << "name: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_name | |||
<< "\ndesc: " << TypedRef(ParamSchemaImpl, m_impl.get()).m_desc | |||
<< "\n" << TypedRef(ParamSchemaImpl, m_impl.get()).m_default.str(); | |||
return ss.str(); | |||
} | |||
CUSTOM_PIMPL_CLS_DEFINE(ParamInfo) | |||
void ParamInfo::set_tag(const std::string &hash_str) { | |||
const char *ptr = hash_str.c_str(); | |||
TypedRef(ParamInfoImpl, m_impl.get()).TAG = 0; | |||
for (size_t i=0; i<hash_str.size(); i++) { | |||
TypedRef(ParamInfoImpl, m_impl.get()).TAG = | |||
mgb::hash_pair_combine(TypedRef(ParamInfoImpl, m_impl.get()).TAG, mgb::hash(*(ptr++))) % | |||
std::numeric_limits<uint32_t>::max(); | |||
} | |||
} | |||
void ParamInfo::set_meta(const std::vector<ParamSchema> &meta) { | |||
TypedRef(ParamInfoImpl, m_impl.get()).m_meta = meta; | |||
} | |||
uint32_t ParamInfo::tag(void) const { | |||
return TypedRef(ParamInfoImpl, m_impl.get()).TAG; | |||
} | |||
std::vector<ParamSchema> &ParamInfo::meta(void) { | |||
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; | |||
} | |||
const std::vector<ParamSchema> &ParamInfo::meta(void) const { | |||
return TypedRef(ParamInfoImpl, m_impl.get()).m_meta; | |||
} | |||
CUSTOM_PIMPL_CLS_DEFINE(Param) | |||
Param::Param(const ParamInfo &info): m_impl(new ParamImpl(), impl_deleter<ParamImpl>) { | |||
for (const auto &schema: info.meta()) { | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.emplace(schema.name(), schema.default_val()); | |||
} | |||
} | |||
ParamVal &Param::operator[](const std::string &name) { | |||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; | |||
} | |||
const ParamVal &Param::operator[](const std::string &name) const { | |||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name)->second; | |||
} | |||
const std::unordered_map<std::string, ParamVal> &Param::raw() const { | |||
return TypedRef(ParamImpl, m_impl.get()).m_vals; | |||
} | |||
bool Param::exist(const std::string &name) const { | |||
return TypedRef(ParamImpl, m_impl.get()).m_vals.find(name) != | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.end(); | |||
} | |||
std::string Param::to_bytes(void) const { | |||
std::string res; | |||
std::map<std::string, ParamVal> ordered_vals( | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.end()); | |||
for (auto &&kv: ordered_vals) { | |||
res += ParamVal::to_bytes(kv.second); | |||
} | |||
return res; | |||
} | |||
void Param::from_bytes(const std::string &bytes) { | |||
std::map<std::string, ParamVal> ordered_vals( | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.begin(), | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.end()); | |||
size_t offset = 0; | |||
for (auto &kv: ordered_vals) { | |||
kv.second = ParamVal::from_bytes(bytes, offset); | |||
} | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.clear(); | |||
TypedRef(ParamImpl, m_impl.get()).m_vals.insert(ordered_vals.begin(), ordered_vals.end()); | |||
mgb_assert(offset == bytes.size(), "wrong data loader"); | |||
} | |||
bool operator==(const Param &lhs, const Param &rhs) { | |||
if (lhs.raw().size() != rhs.raw().size()) | |||
return false; | |||
for (const auto &kv: lhs.raw()) { | |||
auto riter = rhs.raw().find(kv.first); | |||
if (riter == rhs.raw().end() || !((kv.second) == riter->second)) { | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
} |
@@ -0,0 +1,400 @@ | |||
/** | |||
* \file src/custom/impl/param_val.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/param_val.h" | |||
#include "megbrain/common.h" | |||
#pragma GCC diagnostic ignored "-Wsign-compare" | |||
using namespace mgb; | |||
namespace custom { | |||
/** | |||
* Macro Callback for Case | |||
*/ | |||
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
std::unique_ptr<void, void_deleter> new_ptr( \ | |||
new static_type(TypedRef(static_type, rhs.m_ptr.get())), \ | |||
impl_deleter<static_type> \ | |||
); \ | |||
m_ptr.swap(new_ptr); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get());\ | |||
break; \ | |||
} | |||
#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr) \ | |||
mgb_assert( \ | |||
operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \ | |||
"invalid %s of operator %s of ParamVal", #operand, #opr \ | |||
) | |||
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \ | |||
mgb_assert( \ | |||
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \ | |||
type2name[lhs.m_type].c_str(), #op, \ | |||
type2name[rhs.m_type].c_str() \ | |||
) | |||
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \ | |||
case (ParamDynType::dyn_type): { \ | |||
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ | |||
return lval op rval; \ | |||
} | |||
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \ | |||
case (ParamDynType::dyn_type): { \ | |||
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ | |||
switch (rhs.m_type) { \ | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY( \ | |||
CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op) \ | |||
default: \ | |||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||
} \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \ | |||
case (ParamDynType::dyn_type): { \ | |||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||
const auto &lval = TypedRef(static_type, lhs.m_ptr.get()); \ | |||
const auto &rval = TypedRef(static_type, rhs.m_ptr.get()); \ | |||
return lval op rval; \ | |||
} | |||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type) \ | |||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||
\ | |||
switch (lhs.m_type) { \ | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||
default: \ | |||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||
} \ | |||
return {}; \ | |||
} | |||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type) \ | |||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||
\ | |||
switch (lhs.m_type) { \ | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||
CUSTOM_FOR_STRING_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||
default: \ | |||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||
} \ | |||
return {}; \ | |||
} | |||
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type) \ | |||
ret_type operator op(const ParamVal &lhs, const ParamVal &rhs) { \ | |||
CUSTOM_ASSERT_OPERAND_VALID(lhs, op); \ | |||
CUSTOM_ASSERT_OPERAND_VALID(rhs, op); \ | |||
\ | |||
switch (lhs.m_type) { \ | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op) \ | |||
CUSTOM_FOR_STRING_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||
CUSTOM_FOR_EACH_LIST_PARAMTYPE( \ | |||
CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \ | |||
default: \ | |||
CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op); \ | |||
} \ | |||
return {}; \ | |||
} | |||
#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
auto rval = TypedRef(static_type, m_ptr.get()); \ | |||
ss << rval; \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
auto rval = TypedRef(static_type, m_ptr.get()); \ | |||
ss << vec2str(rval); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
return TypedRef(static_type, m_ptr.get()).size(); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
res.resize(sizeof(ParamDynType) + sizeof(static_type)); \ | |||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ | |||
memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
auto &ref = TypedRef(static_type, value.m_ptr.get()); \ | |||
size_t len = ref.size(); \ | |||
size_t elem_size = len != 0 ? sizeof(ref[0]) : 0; \ | |||
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); \ | |||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); \ | |||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); \ | |||
memcpy(&res[sizeof(ParamDynType)+sizeof(len)], ref.data(), len*elem_size); \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
static_type val; \ | |||
memcpy(&val, &bytes[offset], sizeof(val)); \ | |||
offset += sizeof(val); \ | |||
return val; \ | |||
break; \ | |||
} | |||
#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type) \ | |||
case (ParamDynType::dyn_type): { \ | |||
size_t len = 0; \ | |||
memcpy(&len, &bytes[offset], sizeof(len)); \ | |||
offset += sizeof(len); \ | |||
static_type vals; \ | |||
vals.resize(len); \ | |||
size_t elem_size = len != 0 ? sizeof(vals[0]) : 0; \ | |||
memcpy(&vals[0], &bytes[offset], len*elem_size); \ | |||
offset += len*elem_size; \ | |||
return vals; \ | |||
break; \ | |||
} | |||
ParamVal::ParamVal(): m_ptr(nullptr, [](void*) -> void {}) { | |||
m_type = ParamDynType::Invalid; | |||
} | |||
ParamVal::ParamVal(const char *str): ParamVal(std::string(str)) { | |||
} | |||
ParamVal::ParamVal(const std::initializer_list<const char*> &strs): ParamVal(std::vector<const char*>(strs)) { | |||
} | |||
ParamVal::ParamVal(const std::vector<const char*> &strs) | |||
: m_ptr(new std::vector<std::string>(), impl_deleter<std::vector<std::string>>) { | |||
m_type = ParamDynType::StringList; | |||
for (const auto &str: strs) { | |||
TypedRef(std::vector<std::string>, m_ptr.get()).emplace_back(str); | |||
} | |||
} | |||
ParamVal::ParamVal(const ParamVal &rhs): m_ptr(nullptr, [](void*) -> void {}) { | |||
mgb_assert( | |||
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, | |||
"invalid rhs of copy constructor of ParamVal" | |||
); | |||
m_type = rhs.m_type; | |||
switch(m_type) { | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS) | |||
default: { | |||
mgb_assert(false, "invalid rhs of copy constructor of ParamVal"); | |||
} | |||
} | |||
} | |||
ParamVal &ParamVal::operator=(const char *str) { | |||
this->operator=(std::string(str)); | |||
return *this; | |||
} | |||
ParamVal &ParamVal::operator=(const std::initializer_list<const char*> &strs) { | |||
this->operator=(std::vector<const char*>(strs)); | |||
return *this; | |||
} | |||
ParamVal &ParamVal::operator=(const std::vector<const char*> &strs) { | |||
std::vector<std::string> tmp_strs; | |||
for (const auto &str: strs) { | |||
tmp_strs.emplace_back(str); | |||
} | |||
this->operator=(tmp_strs); | |||
return *this; | |||
} | |||
ParamVal &ParamVal::operator=(const ParamVal &rhs) { | |||
if (&rhs == this) | |||
return *this; | |||
mgb_assert( | |||
rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr, | |||
"invalid rhs of assignment operator of ParamVal" | |||
); | |||
if (rhs.m_type == m_type) { | |||
switch(m_type) { | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS); | |||
default: | |||
mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); | |||
} | |||
} | |||
else { | |||
m_type = rhs.m_type; | |||
switch(m_type) { | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS); | |||
default: | |||
mgb_assert(false, "invalid rhs of assignment operator of ParamVal"); | |||
} | |||
} | |||
return *this; | |||
} | |||
const void *ParamVal::raw_ptr(void) const { | |||
return m_ptr.get(); | |||
} | |||
void *ParamVal::raw_ptr(void) { | |||
return m_ptr.get(); | |||
} | |||
ParamDynType ParamVal::type(void) const { | |||
return m_type; | |||
} | |||
std::string ParamVal::str() const { | |||
std::stringstream ss; | |||
ss << "type: " << type2name[m_type] << "\n" << "value: "; | |||
switch (m_type) { | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) | |||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) | |||
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST) | |||
default: | |||
mgb_assert(false, "invalid data of assignment operator of ParamVal"); | |||
} | |||
return ss.str(); | |||
} | |||
size_t ParamVal::size(void) const { | |||
switch (m_type) { | |||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) | |||
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE) | |||
default: | |||
mgb_assert(false, "there is no size() for basic data types"); | |||
} | |||
} | |||
std::string ParamVal::to_bytes(const ParamVal &value) { | |||
std::string res; | |||
// because the specialization of std::vector<bool> | |||
if (value.type() == ParamDynType::BoolList) { | |||
std::vector<bool> &ref = TypedRef(std::vector<bool>, value.m_ptr.get()); | |||
size_t len = ref.size(); | |||
size_t elem_size = sizeof(bool); | |||
res.resize(sizeof(ParamDynType) + sizeof(len) + len*elem_size); | |||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); | |||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); | |||
size_t startpos = sizeof(ParamDynType)+sizeof(len); | |||
for (size_t idx=0; idx<len; idx++) { | |||
bool b = ref[idx]; | |||
memcpy(&res[startpos+idx*sizeof(b)], &b, sizeof(b)); | |||
} | |||
return res; | |||
} | |||
else if (value.type() == ParamDynType::StringList) { | |||
std::vector<std::string> &ref = TypedRef(std::vector<std::string>, value.m_ptr.get()); | |||
size_t len = ref.size(); | |||
res.resize(sizeof(ParamDynType) + sizeof(len)); | |||
memcpy(&res[0], &(value.m_type), sizeof(ParamDynType)); | |||
memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len)); | |||
for (size_t idx=0; idx<ref.size(); ++idx) { | |||
size_t str_len = ref[idx].size(); | |||
std::string bytes(sizeof(str_len) + str_len, ' '); | |||
memcpy(&bytes[0], &str_len, sizeof(str_len)); | |||
memcpy(&bytes[sizeof(str_len)], ref[idx].data(), str_len); | |||
res += bytes; | |||
} | |||
return res; | |||
} | |||
switch(value.type()) { | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC) | |||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST) | |||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST) | |||
default: | |||
mgb_assert(false, "invalid param type"); | |||
} | |||
return res; | |||
} | |||
ParamVal ParamVal::from_bytes(const std::string &bytes, size_t &offset) { | |||
ParamDynType data_type = ParamDynType::Invalid; | |||
memcpy(&data_type, &bytes[offset], sizeof(ParamDynType)); | |||
offset += sizeof(ParamDynType); | |||
if (data_type == ParamDynType::BoolList) { | |||
std::vector<bool> ret; | |||
size_t len = 0; | |||
memcpy(&len, &bytes[offset], sizeof(len)); | |||
offset += sizeof(len); | |||
for (size_t idx =0; idx<len; ++idx) { | |||
bool b = true; | |||
memcpy(&b, &bytes[offset], sizeof(bool)); | |||
offset += sizeof(bool); | |||
ret.push_back(b); | |||
} | |||
return ret; | |||
} | |||
else if (data_type == ParamDynType::StringList) { | |||
std::vector<std::string> ret; | |||
size_t len = 0; | |||
memcpy(&len, &bytes[offset], sizeof(len)); | |||
offset += sizeof(len); | |||
for (size_t idx =0; idx<len; ++idx) { | |||
size_t str_len = 0; | |||
memcpy(&str_len, &bytes[offset], sizeof(str_len)); | |||
offset += sizeof(str_len); | |||
std::string str(str_len, ' '); | |||
memcpy(&str[0], &bytes[offset], str_len); | |||
offset += str_len; | |||
ret.push_back(str); | |||
} | |||
return ret; | |||
} | |||
switch (data_type) { | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC) | |||
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST) | |||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST); | |||
default: | |||
mgb_assert(false, "invalid param type"); | |||
} | |||
return {}; | |||
} | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool) | |||
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool) | |||
} |
@@ -0,0 +1,486 @@ | |||
/** | |||
* \file src/custom/impl/tensor.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/tensor.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/tensor.h" | |||
#include <cctype> | |||
#include <algorithm> | |||
using namespace mgb; | |||
namespace custom { | |||
template<typename T> | |||
SmallVector<T> to_builtin_vector(const std::vector<T> &custom_data) { | |||
SmallVector<T> builtin_data(custom_data.size()); | |||
memcpy(builtin_data.data(), custom_data.data(), sizeof(T)*custom_data.size()); | |||
return builtin_data; | |||
} | |||
using DeviceImpl = CompNode; | |||
using ShapeImpl = megdnn::TensorShape; | |||
using DTypeImpl = megdnn::DType; | |||
using FormatImpl = megdnn::TensorLayout::Format; | |||
using TensorImpl = DeviceTensorND; | |||
#define DeviceImplRef(rawptr) (*reinterpret_cast<DeviceImpl*>(rawptr)) | |||
#define ShapeImplRef(rawptr) (*reinterpret_cast<ShapeImpl*>(rawptr)) | |||
#define DTypeImplRef(rawptr) (*reinterpret_cast<DTypeImpl*>(rawptr)) | |||
#define FormatImplRef(rawptr) (*reinterpret_cast<FormatImpl*>(rawptr)) | |||
#define TensorImplRef(rawptr) (*reinterpret_cast<TensorImpl*>(rawptr)) | |||
#define DeviceImplConstRef(rawptr) static_cast<const DeviceImpl&>(*reinterpret_cast<const DeviceImpl*>(rawptr)) | |||
#define ShapeImplConstRef(rawptr) static_cast<const ShapeImpl&>(*reinterpret_cast<const ShapeImpl*>(rawptr)) | |||
#define DTypeImplConstRef(rawptr) static_cast<const DTypeImpl&>(*reinterpret_cast<const DTypeImpl*>(rawptr)) | |||
#define FormatImplConstRef(rawptr) static_cast<const FormatImpl&>(*reinterpret_cast<const FormatImpl*>(rawptr)) | |||
#define TensorImplConstRef(rawptr) static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr)) | |||
static std::unordered_map<DeviceImpl::DeviceType, std::string, | |||
EnumHash<DeviceImpl::DeviceType>, | |||
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cstr; | |||
static std::unordered_map<DeviceImpl::DeviceType, DeviceEnum, | |||
EnumHash<DeviceImpl::DeviceType>, | |||
EnumCmp<DeviceImpl::DeviceType>> dev_benum2cenum; | |||
static std::unordered_map<std::string, std::string> dev_cstr2bstr; | |||
static std::unordered_map<DeviceEnum, std::string, | |||
EnumHash<DeviceEnum>, | |||
EnumCmp<DeviceEnum>> dev_cenum2bstr; | |||
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ | |||
auto be2cs##custom_impl = dev_benum2cstr.emplace( \ | |||
DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \ | |||
auto be2ce##custom_impl = dev_benum2cenum.emplace( \ | |||
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \ | |||
auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \ | |||
std::string(#custom_impl), std::string(builtin_str)); \ | |||
auto ce2bs##custom_impl = dev_cenum2bstr.emplace( \ | |||
DeviceEnum::custom_impl, std::string(builtin_str)); | |||
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE) | |||
#undef CUSTOM_BIND_DEVICE | |||
CUSTOM_PIMPL_CLS_DEFINE(Device) | |||
const void *Device::impl() const { | |||
return m_impl.get(); | |||
} | |||
Device::Device(const void *impl): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
mgb_assert(impl != nullptr, "invalid ptr"); | |||
if (!DeviceImplConstRef(impl).valid()) { | |||
m_impl.reset(new DeviceImpl()); | |||
return; | |||
} | |||
auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); | |||
mgb_assert( | |||
dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), | |||
"unsupported compnode type: %s", DeviceImplConstRef(impl).to_string().c_str() | |||
); | |||
m_impl.reset(new DeviceImpl(DeviceImplConstRef(impl))); | |||
} | |||
Device::Device(const std::string &device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); | |||
std::string builtin_device = dev_cstr2bstr[device]; | |||
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||
} | |||
// to avoid the ambiguous from Device(const void *impl) | |||
Device::Device(const char *device): Device(std::string(device)) { | |||
} | |||
Device::Device(DeviceEnum device): m_impl(nullptr, impl_deleter<DeviceImpl>) { | |||
mgb_assert(is_legal(device), "invalid device type"); | |||
std::string builtin_device = dev_cenum2bstr[device]; | |||
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); | |||
} | |||
std::string Device::str(void) const { | |||
if (!DeviceImplRef(m_impl.get()).valid()) { | |||
return "invalid"; | |||
} | |||
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||
auto iter = dev_benum2cstr.find(builtin_device_type); | |||
mgb_assert( | |||
iter != dev_benum2cstr.end(), "invalid device type %s\n", | |||
DeviceImplRef(m_impl.get()).to_string().c_str() | |||
); | |||
return iter->second; | |||
} | |||
DeviceEnum Device::enumv(void) const { | |||
mgb_assert( | |||
DeviceImplRef(m_impl.get()).valid(), | |||
"cannot get the enum value of invalid device" | |||
); | |||
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); | |||
auto iter = dev_benum2cenum.find(builtin_device_type); | |||
mgb_assert( | |||
iter != dev_benum2cenum.end(), "invalid device type %s\n", | |||
DeviceImplRef(m_impl.get()).to_string().c_str() | |||
); | |||
return iter->second; | |||
} | |||
bool Device::is_legal(const std::string &device_type) { | |||
return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); | |||
} | |||
bool Device::is_legal(DeviceEnum device_type) { | |||
return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); | |||
} | |||
std::vector<std::string> Device::legal_devices(void) { | |||
std::vector<std::string> ret; | |||
for (const auto &kv: dev_cstr2bstr) { | |||
ret.emplace_back(kv.first); | |||
} | |||
return ret; | |||
} | |||
bool operator==(const Device &lhs, const Device &rhs) { | |||
return lhs.str() == rhs.str(); | |||
} | |||
CUSTOM_PIMPL_CLS_DEFINE(Shape) | |||
const void *Shape::impl() const { | |||
return m_impl.get(); | |||
} | |||
Shape::Shape(const void *impl): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
mgb_assert(impl != nullptr, "invalid ptr"); | |||
m_impl.reset(new ShapeImpl(ShapeImplConstRef(impl))); | |||
} | |||
Shape::Shape(const std::vector<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
m_impl.reset(new ShapeImpl(to_builtin_vector<size_t>(rhs))); | |||
} | |||
Shape::Shape(const std::initializer_list<size_t> &rhs): m_impl(nullptr, impl_deleter<ShapeImpl>) { | |||
m_impl.reset(new ShapeImpl(rhs)); | |||
} | |||
size_t &Shape::operator[](size_t idx) { | |||
mgb_assert(idx < ndim(), "wrong tensor dimension idx: %lu < %lu", static_cast<unsigned long>(idx), static_cast<unsigned long>(ndim())); | |||
return ShapeImplRef(m_impl.get()).operator[](idx); | |||
} | |||
size_t Shape::operator[](size_t idx) const { | |||
return const_cast<Shape*>(this)->operator[](idx); | |||
} | |||
void Shape::ndim(size_t dim) { | |||
mgb_assert(dim < ShapeImpl::MAX_NDIM, "dimension must <= %lu", static_cast<unsigned long>(ShapeImpl::MAX_NDIM)); | |||
ShapeImplRef(m_impl.get()).ndim = dim; | |||
} | |||
size_t Shape::ndim(void) const { | |||
return ShapeImplRef(m_impl.get()).ndim; | |||
} | |||
bool operator==(const Shape &lhs, const Shape &rhs) { | |||
return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); | |||
} | |||
static std::unordered_map<std::string, megdnn::DTypeEnum> dtype_cstr2benum; | |||
static std::unordered_map<DTypeEnum, megdnn::DTypeEnum, | |||
EnumHash<DTypeEnum>, | |||
EnumCmp<DTypeEnum>> dtype_cenum2benum; | |||
static std::unordered_map<megdnn::DTypeEnum, std::string, | |||
EnumHash<megdnn::DTypeEnum>, | |||
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cstr; | |||
static std::unordered_map<megdnn::DTypeEnum, DTypeEnum, | |||
EnumHash<megdnn::DTypeEnum>, | |||
EnumCmp<megdnn::DTypeEnum>> dtype_benum2cenum; | |||
static std::unordered_map<DTypeEnum, std::string, | |||
EnumHash<DTypeEnum>, | |||
EnumCmp<DTypeEnum>> dtype_cenum2cstr; | |||
#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ | |||
auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ | |||
std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ | |||
auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ | |||
DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ | |||
auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ | |||
megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \ | |||
auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ | |||
megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \ | |||
auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \ | |||
DTypeEnum::custom_impl, std::string(#custom_impl)); | |||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) | |||
#undef CUSTOM_BIND_DTYPE | |||
CUSTOM_PIMPL_CLS_DEFINE(DType) | |||
const void *DType::impl() const { | |||
return m_impl.get(); | |||
} | |||
DType::DType(const void *impl): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
mgb_assert(impl != nullptr, "invalid ptr"); | |||
m_impl.reset(new DTypeImpl(DTypeImplConstRef(impl))); | |||
} | |||
DType::DType(const std::string &dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
auto iter = dtype_cstr2benum.find(dtype); | |||
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||
mgb_assert( | |||
dtype[0] != 'q', "can not construct quantized dtype " | |||
"%s without scale and zero_point", dtype.c_str() | |||
); | |||
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||
} | |||
DType::DType(const char *dtype): DType(std::string(dtype)) { | |||
} | |||
DType::DType(const std::string &dtype, float scale, uint8_t zero_point) | |||
: m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
auto iter = dtype_cstr2benum.find(dtype); | |||
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); | |||
mgb_assert( | |||
dtype[0] == 'q', "given scale/zero_point to construct " | |||
"non-quantized dtype: %s is not allowed", dtype.c_str() | |||
); | |||
if (dtype == "quint8") { | |||
m_impl.reset(new megdnn::ParameterizedDType< | |||
megdnn::DTypeEnum::Quantized8Asymm>(scale, zero_point)); | |||
} | |||
else { | |||
mgb_assert( | |||
zero_point == 0, "invalid zero point %d for dtype %s", | |||
zero_point, dtype.c_str() | |||
); | |||
if (dtype == "qint8") { | |||
m_impl.reset(new megdnn::ParameterizedDType< | |||
megdnn::DTypeEnum::QuantizedS8>(scale)); | |||
} | |||
else if (dtype == "qint16") { | |||
m_impl.reset(new megdnn::ParameterizedDType< | |||
megdnn::DTypeEnum::QuantizedS16>(scale)); | |||
} | |||
else if (dtype == "qint32") { | |||
m_impl.reset(new megdnn::ParameterizedDType< | |||
megdnn::DTypeEnum::QuantizedS32>(scale)); | |||
} | |||
else { | |||
mgb_assert(false, "invalid dtype %s", dtype.c_str()); | |||
} | |||
} | |||
} | |||
DType::DType(const char *dtype, float scale, uint8_t zero_point) | |||
: DType(std::string(dtype), scale, zero_point) { | |||
} | |||
DType::DType(DTypeEnum dtype): m_impl(nullptr, impl_deleter<DTypeImpl>) { | |||
auto iter = dtype_cenum2benum.find(dtype); | |||
mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); | |||
mgb_assert(dtype < DTypeEnum::quint8, | |||
"can not construct quantized dtype without scale and zero_point"); | |||
m_impl.reset(new DTypeImpl(DTypeImpl::from_enum(iter->second))); | |||
} | |||
DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point) | |||
: DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) { | |||
} | |||
std::string DType::str(void) const { | |||
if (!DTypeImplRef(m_impl.get()).valid()) | |||
return "invalid"; | |||
auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); | |||
if (iter == dtype_benum2cstr.end()) | |||
return "invalid"; | |||
return iter->second; | |||
} | |||
DTypeEnum DType::enumv(void) const { | |||
auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); | |||
mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); | |||
return iter->second; | |||
} | |||
float DType::scale() const { | |||
if (enumv() == DTypeEnum::qint8) { | |||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS8>().scale; | |||
} | |||
else if (enumv() == DTypeEnum::qint16) { | |||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS16>().scale; | |||
} | |||
else if (enumv() == DTypeEnum::qint32) { | |||
return DTypeImplRef(m_impl.get()).param<dtype::QuantizedS32>().scale; | |||
} | |||
else if (enumv() == DTypeEnum::quint8) { | |||
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().scale; | |||
} | |||
else { | |||
mgb_assert(false, "dtype %s has no scale", str().c_str()); | |||
return 0.f; | |||
} | |||
} | |||
uint8_t DType::zero_point() const { | |||
mgb_assert(enumv()==DTypeEnum::quint8, "dtype %s has no zero point", str().c_str()); | |||
return DTypeImplRef(m_impl.get()).param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
bool DType::is_legal(const std::string &dtype) { | |||
return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); | |||
} | |||
bool DType::is_legal(const DTypeEnum &dtype) { | |||
return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); | |||
} | |||
std::vector<std::string> DType::legal_dtypes(void) { | |||
std::vector<std::string> ret; | |||
for (const auto &kv: dtype_cstr2benum) | |||
ret.emplace_back(kv.first); | |||
return ret; | |||
} | |||
bool operator==(const DType &lhs, const DType &rhs) { | |||
return DTypeImplRef(lhs.m_impl.get()) == DTypeImplRef(rhs.m_impl.get()); | |||
} | |||
bool operator==(const DType &lhs, const std::string &rhs) { | |||
return lhs.str() == rhs; | |||
} | |||
bool operator==(const DType &lhs, const char *rhs) { | |||
return operator==(lhs, std::string(rhs)); | |||
} | |||
bool operator==(const std::string &lhs, const DType &rhs) { | |||
return operator==(rhs, lhs); | |||
} | |||
bool operator==(const char *lhs, const DType &rhs) { | |||
return operator==(rhs, std::string(lhs)); | |||
} | |||
CUSTOM_PIMPL_CLS_DEFINE(Format) | |||
const void *Format::impl() const { | |||
return m_impl.get(); | |||
} | |||
Format::Format(const void *impl): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||
mgb_assert(impl != nullptr, "invalid ptr"); | |||
mgb_assert(FormatImplConstRef(impl).is_default(), "only default format is supported now"); | |||
m_impl.reset(new FormatImpl(FormatImplConstRef(impl))); | |||
} | |||
Format::Format(const std::string &format): m_impl(nullptr, impl_deleter<FormatImpl>) { | |||
mgb_assert(format == "default", "only default format is supported now"); | |||
m_impl.reset(new FormatImpl()); | |||
} | |||
Format::Format(const char *format): Format(std::string(format)) { | |||
} | |||
std::string Format::str(void) const { | |||
return FormatImplRef(m_impl.get()).to_string(); | |||
} | |||
bool Format::is_default(void) const { | |||
return FormatImplRef(m_impl.get()).is_default(); | |||
} | |||
const void *Tensor::impl(void) const { | |||
return m_tensor; | |||
} | |||
Tensor::Tensor(const void *impl) { | |||
mgb_assert(impl != nullptr, "invalid ptr"); | |||
m_tensor = const_cast<void*>(impl); | |||
} | |||
const size_t *Tensor::shapes_raw(void) const { | |||
return TensorImplRef(m_tensor).shape().shape; | |||
} | |||
const ptrdiff_t *Tensor::strides_raw(void) const { | |||
return TensorImplRef(m_tensor).layout().stride; | |||
} | |||
Tensor::Tensor(const Tensor &rhs) { | |||
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for copy constructor\n"); | |||
m_tensor = rhs.m_tensor; | |||
} | |||
Tensor &Tensor::operator=(const Tensor &rhs) { | |||
mgb_assert(rhs.m_tensor != nullptr, "invalid rhs for assignment operator"); | |||
if (&rhs == this || rhs.m_tensor == m_tensor) | |||
return *this; | |||
m_tensor = rhs.m_tensor; | |||
return *this; | |||
} | |||
Shape Tensor::shape(void) const { | |||
auto builtin = TensorImplRef(m_tensor).shape(); | |||
return Shape(&builtin); | |||
} | |||
DType Tensor::dtype(void) const { | |||
auto builtin = TensorImplRef(m_tensor).dtype(); | |||
return DType(&builtin); | |||
} | |||
Format Tensor::format(void) const { | |||
auto builtin = TensorImplRef(m_tensor).format(); | |||
return Format(&builtin); | |||
} | |||
Device Tensor::device(void) const { | |||
auto builtin = TensorImplRef(m_tensor).comp_node(); | |||
return Device(&builtin); | |||
} | |||
size_t Tensor::size(void) const { | |||
return TensorImplRef(m_tensor).shape().total_nr_elems(); | |||
} | |||
std::vector<ptrdiff_t> Tensor::stride(void) const { | |||
std::vector<ptrdiff_t> ret(TensorImplRef(m_tensor).shape().ndim); | |||
for (size_t i=0; i<ret.size(); i++) | |||
ret[i] = TensorImplRef(m_tensor).layout().stride[i]; | |||
return ret; | |||
} | |||
float Tensor::scale(void) const { | |||
return dtype().scale(); | |||
} | |||
uint8_t Tensor::zero_point(void) const { | |||
return dtype().zero_point(); | |||
} | |||
void *Tensor::data(void) { | |||
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr()); | |||
} | |||
const void *Tensor::data(void) const { | |||
return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr()); | |||
} | |||
} // namespace custom |
@@ -0,0 +1,41 @@ | |||
/** | |||
* \file src/custom/impl/utils.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/utils.h" | |||
#include "megbrain/common.h" | |||
#include <sstream> | |||
using namespace mgb; | |||
namespace custom { | |||
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...) { | |||
std::string msg = ssprintf("`%s' is true at %s:%d: %s", expr, file, line, func); | |||
if (msg_fmt) { | |||
msg_fmt = convert_fmt_str(msg_fmt); | |||
va_list ap; | |||
va_start(ap, msg_fmt); | |||
msg.append("\nextra message: "); | |||
msg.append(svsprintf(msg_fmt, ap)); | |||
va_end(ap); | |||
} | |||
printf("%s\n", msg.c_str()); | |||
} | |||
UnImpleWarnLog::UnImpleWarnLog(const std::string &func, const std::string &attr, | |||
const std::string &val) { | |||
mgb_log_warn("you are using the default custom %s function, the `%s` attribute " | |||
"of all the outputs tensor will be the same with inputs tensor[0]. " | |||
"If there is no input tensor, it will be `%s`", | |||
func.c_str(), attr.c_str(), val.c_str()); | |||
} | |||
} |
@@ -0,0 +1,185 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/accessor.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include <cstdint> | |||
namespace custom { | |||
#ifdef __CUDACC__ | |||
#define CUSTOM_HOST __host__ | |||
#define CUSTOM_DEVICE __device__ | |||
#else | |||
#define CUSTOM_HOST | |||
#define CUSTOM_DEVICE | |||
#endif | |||
#define CUSTOM_HOST_DEVICE CUSTOM_HOST CUSTOM_DEVICE | |||
template <typename T> | |||
struct DefaultPtrTraits { | |||
using PtrType = T*; | |||
}; | |||
#ifdef __CUDACC__ | |||
template <typename T> | |||
struct RestrictPtrTraits { | |||
using PtrType = T* __restrict__; | |||
}; | |||
#endif | |||
template <typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
class TensorAccessorProxyBase { | |||
public: | |||
using PtrType = typename PtrTraits<T>::PtrType; | |||
protected: | |||
PtrType m_data; | |||
const index_t* m_sizes; | |||
const index_t* m_strides; | |||
public: | |||
CUSTOM_HOST_DEVICE TensorAccessorProxyBase(PtrType data, const index_t *sizes, const index_t *strides) { | |||
m_data = data; | |||
m_sizes = sizes; | |||
m_strides = strides; | |||
} | |||
CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||
return m_strides[i]; | |||
} | |||
CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||
return m_sizes[i]; | |||
} | |||
CUSTOM_HOST_DEVICE PtrType data() const { | |||
return m_data; | |||
} | |||
}; | |||
template<typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
class TensorAccessorProxy: public TensorAccessorProxyBase<T, N, PtrTraits, index_t> { | |||
public: | |||
using PtrType = typename PtrTraits<T>::PtrType; | |||
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||
: TensorAccessorProxyBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||
} | |||
CUSTOM_HOST_DEVICE TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) { | |||
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||
this->m_data + this->m_strides[0] * i, | |||
this->m_sizes + 1, | |||
this->m_strides + 1 | |||
); | |||
} | |||
CUSTOM_HOST_DEVICE const TensorAccessorProxy<T, N-1, PtrTraits, index_t> operator[](index_t i) const { | |||
return TensorAccessorProxy<T, N-1, PtrTraits, index_t>( | |||
this->m_data + this->m_strides[0] * i, | |||
this->m_sizes + 1, | |||
this->m_strides + 1 | |||
); | |||
} | |||
}; | |||
template<typename T, template <typename U> class PtrTraits, typename index_t> | |||
class TensorAccessorProxy<T, 1, PtrTraits, index_t> | |||
: public TensorAccessorProxyBase<T, 1, PtrTraits, index_t> { | |||
public: | |||
using PtrType = typename PtrTraits<T>::PtrType; | |||
CUSTOM_HOST_DEVICE TensorAccessorProxy(PtrType data, const index_t *sizes, const index_t *strides) | |||
: TensorAccessorProxyBase<T, 1, PtrTraits, index_t>(data, sizes, strides ) { | |||
} | |||
CUSTOM_HOST_DEVICE T &operator[](index_t i) { | |||
return this->m_data[this->m_strides[0]*i]; | |||
} | |||
CUSTOM_HOST_DEVICE const T &operator[](index_t i) const { | |||
return this->m_data[this->m_strides[0]*i]; | |||
} | |||
}; | |||
template<typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
class TensorAccessorBase { | |||
public: | |||
using PtrType = typename PtrTraits<T>::PtrType; | |||
protected: | |||
PtrType m_data; | |||
index_t m_sizes[N]; | |||
index_t m_strides[N]; | |||
public: | |||
CUSTOM_HOST_DEVICE TensorAccessorBase(PtrType data, const size_t *sizes, const ptrdiff_t *strides) { | |||
m_data = data; | |||
for (size_t i=0; i<N; ++i) { | |||
m_sizes[i] = sizes[i]; | |||
m_strides[i] = strides[i]; | |||
} | |||
} | |||
CUSTOM_HOST_DEVICE index_t stride(index_t i) const { | |||
return m_strides[i]; | |||
} | |||
CUSTOM_HOST_DEVICE index_t size(index_t i) const { | |||
return m_sizes[i]; | |||
} | |||
CUSTOM_HOST_DEVICE PtrType data() const { | |||
return m_data; | |||
} | |||
}; | |||
template<typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
class TensorAccessor: public TensorAccessorBase<T, N, PtrTraits, index_t> { | |||
public: | |||
using PtrType = typename PtrTraits<T>::PtrType; | |||
CUSTOM_HOST_DEVICE TensorAccessor(PtrType data, const size_t *sizes, const ptrdiff_t *strides) | |||
: TensorAccessorBase<T, N, PtrTraits, index_t>(data, sizes, strides) { | |||
} | |||
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) { | |||
return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||
this->m_data, | |||
this->m_sizes, | |||
this->m_strides | |||
)[i]; | |||
} | |||
CUSTOM_HOST_DEVICE decltype(auto) operator[](index_t i) const { | |||
return TensorAccessorProxy<T, N, PtrTraits, index_t>( | |||
this->m_data, | |||
this->m_sizes, | |||
this->m_strides | |||
)[i]; | |||
} | |||
}; | |||
} |
@@ -0,0 +1,108 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/custom.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "op.h" | |||
#include "tensor.h" | |||
#include "param.h" | |||
namespace custom { | |||
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version); | |||
} | |||
#define CUSTOM_OP_REG(OpName) CustomOp &_##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION))) | |||
#define CUSTOM_OP_REG_BEGIN(OpName) \ | |||
namespace custom { \ | |||
namespace OpName { | |||
#define CUSTOM_OP_REG_END(OpName) \ | |||
} \ | |||
} | |||
#define CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, hint, ...) \ | |||
case (case_type): { \ | |||
using hint = real_type; \ | |||
return __VA_ARGS__(); \ | |||
} | |||
#define CASE_TO_PERFORM_ON_SCALAR(name, case_type, real_type, ...) \ | |||
CASE_TO_PERFORM_USING_HINT(name, case_type, real_type, scalar_t, __VA_ARGS__) | |||
#define DISPATCH_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
[&]() { \ | |||
const auto &dtype = tensor_dtype; \ | |||
switch (dtype.enumv()) { \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
default: \ | |||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
name, dtype.str().c_str()); \ | |||
} \ | |||
}() | |||
#define DISPATCH_INT_TYPES(tensor_dtype, name, ...) \ | |||
[&]() { \ | |||
const auto &dtype = tensor_dtype; \ | |||
switch (dtype.enumv()) { \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
default: \ | |||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
name, dtype.str().c_str()); \ | |||
} \ | |||
}() | |||
#define DISPATCH_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
[&]() { \ | |||
const auto &dtype = tensor_dtype; \ | |||
switch (dtype.enumv()) { \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint8, uint8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::uint16,uint16_t, __VA_ARGS__)\ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
default: \ | |||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
name, dtype.str().c_str()); \ | |||
} \ | |||
}() | |||
#define DISPATCH_SIGN_INT_TYPES(tensor_dtype, name, ...) \ | |||
[&]() { \ | |||
const auto &dtype = tensor_dtype; \ | |||
switch (dtype.enumv()) { \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
default: \ | |||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
name, dtype.str().c_str()); \ | |||
} \ | |||
}() | |||
#define DISPATCH_SIGN_INT_AND_FLOAT_TYPES(tensor_dtype, name, ...) \ | |||
[&]() { \ | |||
const auto &dtype = tensor_dtype; \ | |||
switch (dtype.enumv()) { \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::float32, float, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int8, int8_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int16, int16_t, __VA_ARGS__) \ | |||
CASE_TO_PERFORM_ON_SCALAR(name, DTypeEnum::int32, int32_t, __VA_ARGS__) \ | |||
default: \ | |||
custom_assert(false, "no implemented %s kernel for dtype %s\n", \ | |||
name, dtype.str().c_str()); \ | |||
} \ | |||
}() |
@@ -0,0 +1,58 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/data_adaptor.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/thin/small_vector.h" | |||
namespace custom { | |||
template <typename BuiltinT, typename CustomT> | |||
BuiltinT to_builtin(const CustomT &custom) { | |||
return *reinterpret_cast<const BuiltinT*>(custom.impl()); | |||
} | |||
template <typename BuiltinT, typename CustomT> | |||
CustomT to_custom(const BuiltinT &builtin) { | |||
return std::move(CustomT(&builtin)); | |||
} | |||
template <typename BuiltinT, typename CustomT> | |||
megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT> &customs) { | |||
megdnn::SmallVector<BuiltinT> builtins; | |||
for (size_t i=0; i<customs.size(); ++i) { | |||
builtins.push_back(std::move(to_builtin<BuiltinT, CustomT>(customs[i]))); | |||
} | |||
return std::move(builtins); | |||
} | |||
template <typename BuiltinT, typename CustomT> | |||
std::vector<CustomT> to_custom( | |||
const megdnn::SmallVector<BuiltinT> &builtins) { | |||
std::vector<CustomT> customs; | |||
for (size_t i=0; i<builtins.size(); ++i) { | |||
customs.push_back(std::move(to_custom<BuiltinT, CustomT>(builtins[i]))); | |||
} | |||
return std::move(customs); | |||
} | |||
} | |||
#define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr) | |||
#define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr) | |||
#define to_custom_shape(expr) custom::to_custom<megdnn::TensorShape, custom::Shape>(expr) | |||
#define to_builtin_shape(expr) custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr) | |||
#define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr) | |||
#define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr) | |||
#define to_custom_format(expr) custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr) | |||
#define to_builtin_format(expr) custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr) | |||
#define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr) | |||
#define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr) |
@@ -0,0 +1,75 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/manager.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "custom.h" | |||
#include "megbrain/common.h" | |||
namespace custom { | |||
class CustomOpManager { | |||
std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op; | |||
std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op; | |||
MGB_MUTEX m_mtx; | |||
CustomOpManager() = default; | |||
public: | |||
PREVENT_COPY_AND_ASSIGN(CustomOpManager); | |||
static CustomOpManager *inst(void); | |||
~CustomOpManager(); | |||
std::shared_ptr<CustomOp> insert(const std::string &name, uint32_t version); | |||
bool erase(const std::string &name); | |||
bool erase(const RunTimeId &id); | |||
std::shared_ptr<CustomOp> find_or_reg(const std::string &name, uint32_t version); | |||
RunTimeId to_id(const std::string &name) const; | |||
std::string to_name(const RunTimeId &id) const; | |||
std::shared_ptr<const CustomOp> find(const std::string &name) const; | |||
std::shared_ptr<const CustomOp> find(const RunTimeId &id) const; | |||
std::vector<std::string> op_name_list(void); | |||
std::vector<RunTimeId> op_id_list(void); | |||
}; | |||
class CustomLib { | |||
std::unique_ptr<void, void_deleter> m_handle; | |||
std::vector<std::string> m_ops; | |||
public: | |||
PREVENT_COPY_AND_ASSIGN(CustomLib); | |||
CustomLib(const std::string &path, int mode); | |||
const std::vector<std::string> &ops_in_lib(void) const; | |||
~CustomLib(); | |||
bool valid(void) const; | |||
}; | |||
using LibHandle = std::shared_ptr<CustomLib>; | |||
class LibManager { | |||
std::unordered_map<std::string, LibHandle> m_custom_libs; | |||
MGB_MUTEX m_mtx; | |||
LibManager() = default; | |||
public: | |||
PREVENT_COPY_AND_ASSIGN(LibManager); | |||
static LibManager *inst(void); | |||
const std::vector<std::string> &install(const std::string &name, const std::string &path); | |||
bool uninstall(const std::string &name); | |||
friend class CustomOpManager; | |||
}; | |||
} |
@@ -0,0 +1,109 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/op.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "tensor.h" | |||
#include "param.h" | |||
#include <unordered_set> | |||
#define PREVENT_COPY_AND_ASSIGN(Cls) \ | |||
Cls(const Cls&) = delete; \ | |||
Cls(const Cls&&) = delete; \ | |||
Cls &operator=(const Cls&) = delete; \ | |||
Cls &operator=(const Cls&&) = delete | |||
#define CUSTOM_OP_MAJOR 0 | |||
#define CUSTOM_OP_MINOR 1 | |||
#define CUSTOM_OP_PATCH 0 | |||
#define CUSTOM_OP_VERSION CUSTOM_OP_MAJOR*10000 + CUSTOM_OP_MINOR*100 + CUSTOM_OP_PATCH | |||
namespace custom { | |||
using RunTimeId = uint64_t; | |||
class ArgInfo { | |||
CUSTOM_PIMPL_CLS_DECL(ArgInfo); | |||
ArgInfo(const std::string &name, | |||
const std::string &desc, | |||
const std::unordered_set<std::string> &dtypes, | |||
const int &ndim, | |||
const std::string &mem_stgy); | |||
const std::string &name(void) const; | |||
const std::string &desc(void) const; | |||
const std::unordered_set<std::string> &dtypes(void) const; | |||
int ndim(void) const; | |||
const std::string &mem_strategy(void) const; | |||
std::string str() const; | |||
}; | |||
class CustomOp { | |||
std::unique_ptr<void, void_deleter> m_impl; | |||
public: | |||
CustomOp(const std::string &op_type, uint32_t version); | |||
PREVENT_COPY_AND_ASSIGN(CustomOp); | |||
using DeviceInferFuncPtr = void(*)(const std::vector<Device>&, const Param&, std::vector<Device>&); | |||
using ShapeInferFuncPtr = void(*)(const std::vector<Shape>&, const Param&, std::vector<Shape>&); | |||
using DTypeInferFuncPtr = void(*)(const std::vector<DType>&, const Param&, std::vector<DType>&); | |||
using FormatInferFuncPtr = void(*)(const std::vector<Format>&, const Param&, std::vector<Format>&); | |||
using PreprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
using PostprocessFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
using ComputeFuncPtr = void(*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); | |||
// write for forward | |||
CustomOp &set_device_infer(DeviceInferFuncPtr func); | |||
CustomOp &set_shape_infer(ShapeInferFuncPtr func); | |||
CustomOp &set_dtype_infer(DTypeInferFuncPtr func); | |||
CustomOp &set_format_infer(FormatInferFuncPtr func); | |||
CustomOp &set_preprocess(PreprocessFuncPtr func); | |||
CustomOp &set_preprocess(const std::string &device, PreprocessFuncPtr func); | |||
CustomOp &set_postprocess(PostprocessFuncPtr func); | |||
CustomOp &set_postprocess(const std::string &device, PostprocessFuncPtr func); | |||
CustomOp &set_compute(ComputeFuncPtr func); | |||
CustomOp &set_compute(const std::string &device, ComputeFuncPtr func); | |||
CustomOp &set_description(const std::string &op_desc); | |||
CustomOp &add_input(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
CustomOp &add_output(const std::string &name, const std::string &desc, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
CustomOp &add_input(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
CustomOp &add_output(const std::string &name, const std::initializer_list<std::string> &legal_dtypes={"float32"}, int dims=-1, const std::string &mem_stgy="default"); | |||
CustomOp &add_inputs(const size_t &input_num); | |||
CustomOp &add_outputs(const size_t &output_num); | |||
CustomOp &add_param(const std::string &name, const ParamVal &default_val); | |||
CustomOp &add_param(const std::string &name, const std::string &desc, const ParamVal &default_val); | |||
// read | |||
std::string op_type(void) const; | |||
std::string op_desc(void) const; | |||
RunTimeId runtime_id(void) const; | |||
size_t input_num(void) const; | |||
size_t output_num(void) const; | |||
std::string str(void) const; | |||
const ParamInfo ¶m_info(void) const; | |||
ArgInfo input_info(size_t idx) const; | |||
ArgInfo output_info(size_t idx) const; | |||
const std::vector<ArgInfo> &inputs_info(void) const; | |||
const std::vector<ArgInfo> &outputs_info(void) const; | |||
// use | |||
std::vector<Device> infer_output_device(const std::vector<Device>&, const Param&) const; | |||
std::vector<Shape> infer_output_shape (const std::vector<Shape>&, const Param&) const; | |||
std::vector<DType> infer_output_dtype (const std::vector<DType>&, const Param&) const; | |||
std::vector<Format> infer_output_format(const std::vector<Format>&, const Param&) const; | |||
void compute(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&) const; | |||
}; | |||
} |
@@ -0,0 +1,61 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/param.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include <string> | |||
#include <unordered_map> | |||
#include "param_val.h" | |||
namespace custom { | |||
class ParamSchemaImpl; | |||
class ParamInfoImpl; | |||
class ParamImpl; | |||
// Schema of a param element | |||
class ParamSchema { | |||
CUSTOM_PIMPL_CLS_DECL(ParamSchema); | |||
ParamSchema(const std::string &name, const ParamVal &value, const std::string &desc=""); | |||
const std::string &name(void) const; | |||
const std::string &desc(void) const; | |||
const ParamVal &default_val(void) const; | |||
ParamDynType type(void) const; | |||
std::string str(void) const; | |||
}; | |||
class ParamInfo { | |||
CUSTOM_PIMPL_CLS_DECL(ParamInfo); | |||
void set_tag(const std::string&); | |||
void set_meta(const std::vector<ParamSchema> &meta); | |||
uint32_t tag(void) const; | |||
std::vector<ParamSchema> &meta(void); | |||
const std::vector<ParamSchema> &meta(void) const; | |||
}; | |||
class Param { | |||
CUSTOM_PIMPL_CLS_DECL(Param); | |||
Param(const ParamInfo&); | |||
ParamVal &operator[](const std::string&); | |||
const ParamVal &operator[](const std::string&) const; | |||
const std::unordered_map<std::string, ParamVal> &raw() const; | |||
bool exist(const std::string &name) const; | |||
std::string to_bytes(void) const; | |||
void from_bytes(const std::string&); | |||
}; | |||
bool operator==(const Param&, const Param&); | |||
} // custom |
@@ -0,0 +1,290 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/param_val.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <string> | |||
#include <vector> | |||
#include <cassert> | |||
#include <sstream> | |||
#include <memory> | |||
#include <unordered_map> | |||
#include "utils.h" | |||
namespace custom { | |||
/** | |||
* we can add a new basic data type here, basic means we can perform binary | |||
* op such as: +, -, *, /, ==, != between any two of them | |||
*/ | |||
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \ | |||
cb(Int32, int32_t, ##__VA_ARGS__) \ | |||
cb(Int64, int64_t, ##__VA_ARGS__) \ | |||
cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||
cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||
cb(Float32, float, ##__VA_ARGS__) \ | |||
cb(Float64, double, ##__VA_ARGS__) \ | |||
cb(Bool, bool, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) \ | |||
cb(String, std::string, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \ | |||
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \ | |||
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \ | |||
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \ | |||
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \ | |||
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \ | |||
cb(Float64List, std::vector<double>, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \ | |||
cb(BoolList, std::vector<bool>, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ...) \ | |||
cb(StringList, std::vector<std::string>, ##__VA_ARGS__) | |||
/** | |||
* to avoid the recursive of MACRO | |||
*/ | |||
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \ | |||
cb(Int32, int32_t, ##__VA_ARGS__) \ | |||
cb(Int64, int64_t, ##__VA_ARGS__) \ | |||
cb(Uint32, uint32_t, ##__VA_ARGS__) \ | |||
cb(Uint64, uint64_t, ##__VA_ARGS__) \ | |||
cb(Float32, float, ##__VA_ARGS__) \ | |||
cb(Float64, double, ##__VA_ARGS__) \ | |||
cb(Bool, bool, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \ | |||
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||
#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \ | |||
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ | |||
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) | |||
/** | |||
* Macro Callback for Register | |||
*/ | |||
#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, | |||
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) {ParamDynType::dyn_type, #dyn_type}, | |||
#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||
template <> \ | |||
struct get_dyn_type<static_type> { \ | |||
static constexpr ParamDynType type = ParamDynType::dyn_type;\ | |||
}; | |||
#define CUSTOM_REG_STATIC_PARAMTYPE_GETTER(dyn_type, static_type) \ | |||
template <> \ | |||
struct get_static_type<ParamDynType::dyn_type> { \ | |||
using type = static_type; \ | |||
}; | |||
enum class ParamDynType: uint32_t { | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) | |||
Invalid=255 | |||
}; | |||
static std::unordered_map<ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>> type2name = { | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME) | |||
{ParamDynType::Invalid, "Invalid"} | |||
}; | |||
/** | |||
* get the dynamic data type according to the builtin static data type | |||
* we can use it like: | |||
* ParamDynType dyn_type = get_dyn_type<int32_t>::type; | |||
* assert(dyn_type == ParamDynType::Int32) | |||
*/ | |||
template <typename T> | |||
struct get_dyn_type { | |||
static constexpr ParamDynType type = ParamDynType::Invalid; | |||
}; | |||
/** | |||
* get the static data type according to the dynamic data type | |||
* we can use it like: | |||
* get_static_type<ParamDynType::Int32>::type int_32_value; | |||
* assert(std::is_same<decltype(int_32_value), int>::value) | |||
*/ | |||
template <ParamDynType> | |||
struct get_static_type; | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) | |||
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) | |||
#undef CUSTOM_REG_DYN_PARAMTYPE | |||
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME | |||
#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER | |||
#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER | |||
template <typename T> | |||
struct get_vector_template_arg_type; | |||
template <typename T> | |||
struct get_vector_template_arg_type<std::vector<T>> { | |||
using type = std::decay_t<T>; | |||
}; | |||
template <typename T> | |||
struct is_vector { | |||
static constexpr bool value = false; | |||
}; | |||
template <typename T> | |||
struct is_vector <std::vector<T>> { | |||
static constexpr bool value = true; | |||
}; | |||
template <typename T> | |||
std::string vec2str(const std::vector<T> &vec) { | |||
std::stringstream ss; | |||
ss << "{"; | |||
for (const auto &val: vec) { | |||
ss << val << ", "; | |||
} | |||
if (vec.size() != 0) { | |||
ss.seekp(ss.tellp()-std::streampos(2)); | |||
} | |||
ss << "}"; | |||
return ss.str(); | |||
} | |||
/** | |||
* we use void* rather than template to help us realise a complete dynamic type | |||
* if we use template such as: | |||
* template <typename T> | |||
* class ParamVal { | |||
* T m_data; | |||
* } | |||
* Con1: user need to set the type explicitly when class template instantiation | |||
* Con2: ParamVal<int> can not be assigned to ParamVal<double> | |||
*/ | |||
class ParamVal { | |||
std::unique_ptr<void, void_deleter> m_ptr; | |||
ParamDynType m_type; | |||
public: | |||
template <typename T> | |||
ParamVal(const T &val); | |||
template <typename T> | |||
ParamVal(const std::initializer_list<T> &val); | |||
ParamVal(); | |||
ParamVal(const char *str); | |||
ParamVal(const std::initializer_list<const char*> &strs); | |||
ParamVal(const std::vector<const char*> &strs); | |||
ParamVal(const ParamVal &rhs); | |||
template <typename T> | |||
ParamVal &operator=(const T &rhs); | |||
template <typename T> | |||
ParamVal &operator=(const std::initializer_list<T> &val); | |||
ParamVal &operator=(const char *str); | |||
ParamVal &operator=(const std::initializer_list<const char*> &strs); | |||
ParamVal &operator=(const std::vector<const char*> &strs); | |||
ParamVal &operator=(const ParamVal &rhs); | |||
template <typename T> | |||
const T &as(void) const; | |||
template <typename T> | |||
T &as(void); | |||
const void *raw_ptr(void) const; | |||
void *raw_ptr(void); | |||
ParamDynType type(void) const; | |||
std::string str(void) const; | |||
size_t size(void) const; | |||
static std::string to_bytes(const ParamVal &value); | |||
static ParamVal from_bytes(const std::string &bytes, size_t &offset); | |||
friend ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||
friend ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||
friend ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||
friend ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||
friend bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||
}; | |||
ParamVal operator+(const ParamVal &lhs, const ParamVal &rhs); | |||
ParamVal operator-(const ParamVal &lhs, const ParamVal &rhs); | |||
ParamVal operator*(const ParamVal &lhs, const ParamVal &rhs); | |||
ParamVal operator/(const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator==(const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator!=(const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator> (const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator< (const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator>=(const ParamVal &lhs, const ParamVal &rhs); | |||
bool operator<=(const ParamVal &lhs, const ParamVal &rhs); | |||
template <typename T> | |||
ParamVal::ParamVal(const T &val): m_ptr(nullptr, impl_deleter<std::decay_t<T>>) { | |||
using DecayType = std::decay_t<T>; | |||
m_type = get_dyn_type<DecayType>::type; | |||
custom_assert(m_type != ParamDynType::Invalid, "param construct error! unsupported builtin type"); | |||
m_ptr.reset(new DecayType(val)); | |||
} | |||
template <typename T> | |||
ParamVal::ParamVal(const std::initializer_list<T> &val): ParamVal(std::vector<std::decay_t<T>>(val)) { | |||
} | |||
template <typename T> | |||
ParamVal &ParamVal::operator=(const T &rhs) { | |||
using DecayType = std::decay_t<T>; | |||
ParamDynType rhs_dyn_type = get_dyn_type<DecayType>::type; | |||
custom_assert(rhs_dyn_type != ParamDynType::Invalid, "unsupported builtin dtype"); | |||
if (rhs_dyn_type == m_type) { | |||
TypedRef(DecayType, m_ptr.get()) = rhs; | |||
} | |||
else { | |||
m_type = rhs_dyn_type; | |||
std::unique_ptr<void, void_deleter> new_ptr(new DecayType(rhs), impl_deleter<DecayType>); | |||
m_ptr.swap(new_ptr); | |||
} | |||
return *this; | |||
} | |||
template <typename T> | |||
ParamVal &ParamVal::operator=(const std::initializer_list<T> &val) { | |||
return this->operator=(std::vector<std::decay_t<T>>(val)); | |||
} | |||
template <typename T> | |||
const T &ParamVal::as(void) const { | |||
return const_cast<ParamVal*>(this)->as<T>(); | |||
} | |||
template <typename T> | |||
T &ParamVal::as(void) { | |||
using DecayType = std::decay_t<T>; | |||
ParamDynType t_dyn_type = get_dyn_type<DecayType>::type; | |||
custom_assert( | |||
t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n", | |||
type2name[m_type].c_str(), type2name[t_dyn_type].c_str() | |||
); | |||
return TypedRef(T, m_ptr.get()); | |||
} | |||
} |
@@ -0,0 +1,280 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/tensor.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include <string> | |||
#include "utils.h" | |||
#include "accessor.h" | |||
namespace custom { | |||
#define CUSTOM_DATA_ADAPTOR_FRIEND_DECL \ | |||
template <typename BuiltinT, typename CustomT> \ | |||
friend BuiltinT to_builtin(const CustomT &custom); \ | |||
template <typename BuiltinT, typename CustomT> \ | |||
friend CustomT to_custom(const BuiltinT &builtin) | |||
#define CUSTOM_FOR_EACH_DEVICE_TYPE(cb) \ | |||
cb(x86, CPU, "cpux") \ | |||
cb(cuda, CUDA, "gpux") | |||
#define CUSTOM_DEVICE_TYPE_ENUM_DECL(custom_type, builtin_type, builtin_str) custom_type, | |||
class Device { | |||
const void *impl() const; | |||
Device(const void *impl); | |||
CUSTOM_PIMPL_CLS_DECL(Device); | |||
public: | |||
enum class DeviceEnum: uint32_t { | |||
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_DEVICE_TYPE_ENUM_DECL) | |||
}; | |||
Device(const std::string &device); | |||
Device(const char *device); | |||
Device(DeviceEnum device); | |||
std::string str(void) const; | |||
DeviceEnum enumv(void) const; | |||
static bool is_legal(const std::string &device); | |||
static bool is_legal(DeviceEnum device); | |||
static std::vector<std::string> legal_devices(void); | |||
friend class Tensor; | |||
friend bool operator==(const Device &lhs, const Device &rhs); | |||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
}; | |||
using DeviceEnum = Device::DeviceEnum; | |||
bool operator==(const Device &lhs, const Device &rhs); | |||
class Shape { | |||
const void *impl() const; | |||
Shape(const void *impl); | |||
CUSTOM_PIMPL_CLS_DECL(Shape); | |||
public: | |||
Shape(const std::vector<size_t> &rhs); | |||
Shape(const std::initializer_list<size_t> &rhs); | |||
size_t &operator[](size_t idx); | |||
size_t operator[](size_t idx) const; | |||
void ndim(size_t dim); | |||
size_t ndim(void) const; | |||
friend class Tensor; | |||
friend bool operator==(const Shape &lhs, const Shape &rhs); | |||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
}; | |||
bool operator==(const Shape &lhs, const Shape &rhs); | |||
using float16_t = uint16_t; | |||
using bfloat16_t = uint16_t; | |||
#if MEGDNN_DISABLE_FLOAT16 | |||
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) | |||
#else | |||
#define fp16_wrap(cb, custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype) | |||
#endif | |||
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \ | |||
cb(float32, Float32, float) \ | |||
cb(uint8, Uint8, uint8_t) \ | |||
cb(int8, Int8, int8_t) \ | |||
cb(int16, Int16, int16_t) \ | |||
cb(int32, Int32, int32_t) \ | |||
fp16_wrap(cb, float16, Float16, float16_t) \ | |||
fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \ | |||
cb(uint16, Uint16, uint16_t) \ | |||
cb(quint8, Quantized8Asymm, uint8_t) \ | |||
cb(qint32, QuantizedS32, int32_t) \ | |||
cb(qint8, QuantizedS8, int8_t) \ | |||
cb(qint16, QuantizedS16, int16_t) | |||
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type, | |||
class DType { | |||
const void *impl() const; | |||
DType(const void *impl); | |||
CUSTOM_PIMPL_CLS_DECL(DType); | |||
public: | |||
enum class DTypeEnum: uint32_t { | |||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DTYPE_ENUM_DECL) | |||
}; | |||
DType(const std::string &dtype); | |||
DType(const char *dtype); | |||
DType(const std::string &dtype, float scale, uint8_t zero_point = 0); | |||
DType(const char *dtype, float scale, uint8_t zero_point = 0); | |||
DType(DTypeEnum dtype); | |||
DType(DTypeEnum dtype, float scale, uint8_t zero_point = 0); | |||
std::string str(void) const; | |||
DTypeEnum enumv() const; | |||
float scale(void) const; | |||
uint8_t zero_point(void) const; | |||
template<typename T> | |||
bool is_compatible(void) const; | |||
static bool is_legal(const std::string &dtype); | |||
static bool is_legal(const DTypeEnum &dtype); | |||
static std::vector<std::string> legal_dtypes(void); | |||
friend class Tensor; | |||
friend bool operator==(const DType &lhs, const DType &rhs); | |||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
}; | |||
using DTypeEnum = DType::DTypeEnum; | |||
template <DTypeEnum> | |||
struct DTypeTrait; | |||
#define CUSTOM_DEFINE_DTYPE_TRAIT(custom_type, builtin_type, ctype) \ | |||
template <> \ | |||
struct DTypeTrait<DTypeEnum::custom_type> { \ | |||
using type = ctype; \ | |||
}; | |||
#define CUSTOM_CASE_TO_COMPARE_DTYPE(custom_type, builtin_type, ctype) \ | |||
case (DTypeEnum::custom_type): { \ | |||
return std::is_same<DecayT, ctype>::value; \ | |||
} | |||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_DEFINE_DTYPE_TRAIT) | |||
template<typename T> | |||
bool DType::is_compatible(void) const { | |||
using DecayT = typename std::decay<T>::type; | |||
auto dtype_enum = enumv(); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
if (dtype_enum == DTypeEnum::float16) { | |||
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::float16>::type); | |||
} | |||
else if (dtype_enum == DTypeEnum::bfloat16) { | |||
return sizeof(DecayT) == sizeof(DTypeTrait<DTypeEnum::bfloat16>::type); | |||
} | |||
#endif | |||
switch (dtype_enum) { | |||
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_CASE_TO_COMPARE_DTYPE) | |||
default: | |||
return false; | |||
} | |||
} | |||
bool operator==(const DType &lhs, const DType &rhs); | |||
bool operator==(const DType &lhs, const std::string &rhs); | |||
bool operator==(const DType &lhs, const char *rhs); | |||
bool operator==(const std::string &lhs, const DType &rhs); | |||
bool operator==(const char *lhs, const DType &rhs); | |||
class Format { | |||
const void *impl() const; | |||
Format(const void *impl); | |||
CUSTOM_PIMPL_CLS_DECL(Format); | |||
public: | |||
Format(const std::string &format); | |||
Format(const char *format); | |||
std::string str(void) const; | |||
bool is_default(void) const; | |||
friend class Tensor; | |||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
}; | |||
class Tensor { | |||
void *m_tensor; | |||
const void *impl(void) const; | |||
Tensor(const void *impl); | |||
const size_t *shapes_raw(void) const; | |||
const ptrdiff_t *strides_raw(void) const; | |||
public: | |||
Tensor() = delete; | |||
Tensor(const Tensor &rhs); | |||
Tensor &operator=(const Tensor &rhs); | |||
Shape shape(void) const; | |||
DType dtype(void) const; | |||
Format format(void) const; | |||
Device device(void) const; | |||
size_t size(void) const; | |||
std::vector<ptrdiff_t> stride(void) const; | |||
float scale(void) const; | |||
uint8_t zero_point(void) const; | |||
void *data(void); | |||
const void *data(void) const; | |||
template <typename T> | |||
T *data(void); | |||
template <typename T> | |||
const T *data(void) const; | |||
template <typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
const TensorAccessor<T, N, PtrTraits, index_t> accessor() const; | |||
template <typename T, size_t N, | |||
template <typename U> class PtrTraits = DefaultPtrTraits, | |||
typename index_t = int64_t> | |||
TensorAccessor<T, N, PtrTraits, index_t> accessor(); | |||
CUSTOM_DATA_ADAPTOR_FRIEND_DECL; | |||
}; | |||
template <typename T> | |||
T *Tensor::data(void) { | |||
custom_assert(dtype().is_compatible<T>(), | |||
"invalid convert, tensor data type is %s", dtype().str().c_str()); | |||
return reinterpret_cast<T*>(data()); | |||
} | |||
template <typename T> | |||
const T *Tensor::data(void) const { | |||
return const_cast<Tensor*>(this)->data<T>(); | |||
} | |||
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||
const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const { | |||
return const_cast<Tensor*>(this)->accessor<T, N, PtrTraits, index_t>(); | |||
} | |||
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> | |||
TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() { | |||
custom_assert(N == shape().ndim(), | |||
"cannot get a %lu-d accessor for a tensor with dim %lu", static_cast<unsigned long>(N), static_cast<unsigned long>(shape().ndim())); | |||
custom_assert(N > 0, "cannot get 0-d accessor"); | |||
T *ptr = data<T>(); | |||
return TensorAccessor<T, N, PtrTraits, index_t>(ptr, shapes_raw(), strides_raw()); | |||
} | |||
#undef CUSTOM_DATA_ADAPTOR_FRIEND_DECL | |||
#undef CUSTOM_DEVICE_TYPE_ENUM_DECL | |||
#undef CUSTOM_DTYPE_ENUM_DECL | |||
#undef CUSTOM_DEFINE_DTYPE_TRAIT | |||
#undef CUSTOM_CASE_TO_COMPARE_DTYPE | |||
} // custom |
@@ -0,0 +1,104 @@ | |||
/** | |||
* \file src/custom/include/megbrain/custom/utils.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <vector> | |||
#include <string> | |||
#include <memory> | |||
#include <cassert> | |||
namespace custom { | |||
void assert_failed_log(const char *file, int line, const char *func, const char *expr, const char *msg_fmt, ...); | |||
#define custom_expect(expr, msg...) \ | |||
if (!(expr)) { \ | |||
assert_failed_log( \ | |||
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||
); \ | |||
} | |||
#define custom_assert(expr, msg...) \ | |||
if (!(expr)) { \ | |||
assert_failed_log( \ | |||
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg \ | |||
); \ | |||
} \ | |||
assert((expr)) | |||
class UnImpleWarnLog { | |||
public: | |||
UnImpleWarnLog(const std::string &func, const std::string &attr, | |||
const std::string &val); | |||
}; | |||
using void_deleter = void(*)(void*); | |||
template<typename Impl> | |||
void impl_deleter(void *ptr) { | |||
delete reinterpret_cast<Impl*>(ptr); | |||
} | |||
#define TypedPtr(type, raw_ptr) reinterpret_cast<type*>(raw_ptr) | |||
#define TypedRef(type, raw_ptr) (*reinterpret_cast<type*>(raw_ptr)) | |||
#define CUSTOM_PIMPL_CLS_DECL(Cls) \ | |||
std::unique_ptr<void, void_deleter> m_impl; \ | |||
public: \ | |||
Cls(); \ | |||
Cls(const Cls &rhs); \ | |||
Cls &operator=(const Cls &rhs) | |||
#define CUSTOM_PIMPL_CLS_DEFINE(Cls) \ | |||
Cls::Cls(): m_impl(new Cls##Impl(), impl_deleter<Cls##Impl>) {} \ | |||
\ | |||
Cls::Cls(const Cls &rhs): m_impl(nullptr, impl_deleter<Cls##Impl>) { \ | |||
custom_assert( \ | |||
rhs.m_impl != nullptr, \ | |||
"invalid rhs for the copy constructor of %s", #Cls \ | |||
); \ | |||
m_impl.reset(new Cls##Impl(TypedRef(Cls##Impl, rhs.m_impl.get()))); \ | |||
} \ | |||
\ | |||
Cls &Cls::operator=(const Cls &rhs) { \ | |||
custom_assert( \ | |||
m_impl != nullptr && rhs.m_impl != nullptr, \ | |||
"invalid assignment of %s, lhs or rhs is invalid", #Cls \ | |||
); \ | |||
if (&rhs == this) \ | |||
return *this; \ | |||
\ | |||
TypedRef(Cls##Impl, m_impl.get()) = TypedRef(Cls##Impl, rhs.m_impl.get()); \ | |||
return *this; \ | |||
} | |||
/** | |||
* we define this two function explicitly used for std::unordered_map | |||
* to improve the compatibility with different compiler versions | |||
*/ | |||
template <typename T> | |||
struct EnumHash { | |||
size_t operator()(const T &rhs) const { | |||
return static_cast<size_t>(rhs); | |||
} | |||
}; | |||
template <typename T> | |||
struct EnumCmp { | |||
bool operator()(const T &lhs, const T &rhs) const { | |||
return static_cast<size_t>(lhs) == static_cast<size_t>(rhs); | |||
} | |||
}; | |||
} // custom |
@@ -0,0 +1,96 @@ | |||
/** | |||
* \file src/custom/test/manager.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/manager.h" | |||
#include "megbrain/custom/custom.h" | |||
#include "gtest/gtest.h" | |||
#define MANAGER_TEST_LOG 0 | |||
namespace custom { | |||
TEST(TestOpManager, TestOpManager) { | |||
CustomOpManager *com = CustomOpManager::inst(); | |||
com->insert("Op1", CUSTOM_OP_VERSION); | |||
com->insert("Op2", CUSTOM_OP_VERSION); | |||
std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION); | |||
ASSERT_TRUE(ptr != nullptr); | |||
std::vector<std::string> op_names = com->op_name_list(); | |||
std::vector<RunTimeId> op_ids = com->op_id_list(); | |||
ASSERT_TRUE(op_names.size() == 3); | |||
ASSERT_TRUE(op_ids.size() == 3); | |||
#if MANAGER_TEST_LOG | |||
for (std::string &name: op_names) { | |||
std::cout << name << std::endl; | |||
} | |||
#endif | |||
for (std::string &name: op_names) { | |||
std::shared_ptr<const CustomOp> op = com->find(name); | |||
ASSERT_TRUE(op != nullptr); | |||
ASSERT_TRUE(op->op_type() == name); | |||
RunTimeId id = com->to_id(name); | |||
ASSERT_TRUE(com->find(id) == op); | |||
} | |||
for (RunTimeId &id: op_ids) { | |||
std::shared_ptr<const CustomOp> op = com->find(id); | |||
ASSERT_TRUE(op != nullptr); | |||
ASSERT_TRUE(op->runtime_id() == id); | |||
std::string name = com->to_name(id); | |||
ASSERT_TRUE(com->find(name) == op); | |||
} | |||
ASSERT_FALSE(com->erase("Op0")); | |||
#if MANAGER_TEST_LOG | |||
for (auto &name: com->op_name_list()) { | |||
std::cout << name << std::endl; | |||
} | |||
#endif | |||
ASSERT_TRUE(com->erase("Op1")); | |||
ASSERT_TRUE(com->erase(com->to_id("Op2"))); | |||
ASSERT_TRUE(com->op_id_list().size() == 1); | |||
ASSERT_TRUE(com->op_name_list().size() == 1); | |||
ASSERT_TRUE(com->op_name_list()[0] == "Op3"); | |||
ptr.reset(); | |||
ASSERT_TRUE(com->erase("Op3")); | |||
} | |||
TEST(TestOpManager, TestOpReg) { | |||
CUSTOM_OP_REG(Op1) | |||
.add_inputs(2) | |||
.add_outputs(3) | |||
.add_input("lhs") | |||
.add_param("param1", 1) | |||
.add_param("param2", 3.45); | |||
CUSTOM_OP_REG(Op2) | |||
.add_input("lhs") | |||
.add_input("rhs") | |||
.add_output("out") | |||
.add_param("param1", "test") | |||
.add_param("param2", true) | |||
.add_param("", "no name"); | |||
(void)_Op1; | |||
(void)_Op2; | |||
#if MANAGER_TEST_LOG | |||
for (const auto &name: CustomOpManager::inst()->op_name_list()) { | |||
std::cout << CustomOpManager::inst()->find(name)->str() << std::endl; | |||
} | |||
#endif | |||
} | |||
} |
@@ -0,0 +1,205 @@ | |||
/** | |||
* \file src/custom/test/op.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/op.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/tensor.h" | |||
#include "megbrain/custom/data_adaptor.h" | |||
#include "gtest/gtest.h" | |||
#include "megbrain_build_config.h" | |||
#define OP_TEST_LOG 0 | |||
using namespace mgb; | |||
namespace custom { | |||
TEST(TestCustomOp, TestCustomOpInfoSetter) { | |||
CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||
test.set_description("Test Op") | |||
.add_input("lhs", "lhs of test op", {"float32", "int32"}, 2) | |||
.add_inputs(2) | |||
.add_input("rhs", "rhs of test op", {"float32", "int32"}, 2) | |||
.add_outputs(1) | |||
.add_output("out", "out of test op", {"float32", "int32"}, 2) | |||
.add_outputs(3); | |||
ASSERT_TRUE(test.op_type() == "TestOp"); | |||
ASSERT_TRUE(test.op_desc() == "Test Op"); | |||
ASSERT_TRUE(test.input_num() == 4); | |||
ASSERT_TRUE(test.output_num() == 5); | |||
#if OP_TEST_LOG | |||
for (auto input: test.inputs_info()) { | |||
std::cout << input.str() << std::endl; | |||
} | |||
for (auto output: test.outputs_info()) { | |||
std::cout << output.str() << std::endl; | |||
} | |||
#endif | |||
test.add_param("param1", "param1 - float", 1.23f) | |||
.add_param("param2", "param2 - float list", {2.34f, 3.45f}) | |||
.add_param("param3", "param3 - string", "test-string") | |||
.add_param("param4", {"test", "string", "list"}) | |||
.add_param("param5", 1); | |||
#if OP_TEST_LOG | |||
ParamInfo pinfo = test.param_info(); | |||
for (auto kv: pinfo.meta()) { | |||
std::cout << kv.str() << std::endl; | |||
} | |||
#endif | |||
} | |||
void device_infer(const std::vector<Device> &inputs, const Param ¶ms, | |||
std::vector<Device> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
outputs[0] = inputs[1]; | |||
outputs[1] = inputs[0]; | |||
} | |||
void shape_infer(const std::vector<Shape> &inputs, const Param ¶ms, | |||
std::vector<Shape> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
outputs[0] = inputs[1]; | |||
outputs[1] = inputs[0]; | |||
} | |||
void dtype_infer(const std::vector<DType> &inputs, const Param ¶ms, | |||
std::vector<DType> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
outputs[0] = inputs[1]; | |||
outputs[1] = inputs[0]; | |||
} | |||
void format_infer(const std::vector<Format> &inputs, const Param ¶ms, | |||
std::vector<Format> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
outputs[0] = inputs[1]; | |||
outputs[1] = inputs[0]; | |||
} | |||
void cpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||
std::vector<Tensor> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
#if OP_TEST_LOG | |||
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>() << std::endl; | |||
#endif | |||
ASSERT_TRUE(params["device"] == "x86"); | |||
} | |||
void gpu_kernel(const std::vector<Tensor> &inputs, const Param ¶ms, | |||
std::vector<Tensor> &outputs) { | |||
(void)inputs; | |||
(void)params; | |||
(void)outputs; | |||
#if OP_TEST_LOG | |||
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>() << std::endl; | |||
#endif | |||
ASSERT_TRUE(params["device"] == "cuda"); | |||
} | |||
TEST(TestCustomOp, TestCustomOpFuncSetter) { | |||
#if MGB_CUDA | |||
CustomOp test("TestOp", CUSTOM_OP_VERSION); | |||
test.set_description("Test Op Forward Backward Union") | |||
.add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2) | |||
.add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) | |||
.add_output("outl", "outl of Test op", {"float32", "int32"}, 2) | |||
.add_output("outr", "outr of Test op", {"float32", "int32"}, 2) | |||
.add_param("smooth", "smooth", 0.f) | |||
.add_param("device", "using for judge device", "x86"); | |||
std::vector<Device> idevices = {"x86", "cuda"}; | |||
std::vector<Shape> ishapes = {{2, 3}, {3, 4}}; | |||
std::vector<DType> idtypes = {"int32", "float32"}; | |||
std::vector<Format> iformats = {"default", "default"}; | |||
Param param(test.param_info()); | |||
std::vector<Device> odevices = test.infer_output_device(idevices, param); | |||
std::vector<Shape> oshapes = test.infer_output_shape (ishapes, param); | |||
std::vector<DType> odtypes = test.infer_output_dtype (idtypes, param); | |||
std::vector<Format> oformats = test.infer_output_format(iformats, param); | |||
ASSERT_TRUE(odevices.size() == 2); | |||
ASSERT_TRUE(oshapes.size() == 2); | |||
ASSERT_TRUE(odtypes.size() == 2); | |||
ASSERT_TRUE(oformats.size() == 2); | |||
ASSERT_TRUE(odevices[0] == "x86"); | |||
ASSERT_TRUE(odevices[1] == "x86"); | |||
ASSERT_TRUE(oshapes[0] == Shape({2,3})); | |||
ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||
ASSERT_TRUE(odtypes[0] == "int32"); | |||
ASSERT_TRUE(odtypes[1] == "int32"); | |||
ASSERT_TRUE(iformats[0].is_default()); | |||
ASSERT_TRUE(iformats[1].is_default()); | |||
test.set_device_infer(device_infer) | |||
.set_shape_infer(shape_infer) | |||
.set_dtype_infer(dtype_infer) | |||
.set_format_infer(format_infer); | |||
odevices = test.infer_output_device(idevices, param); | |||
oshapes = test.infer_output_shape (ishapes, param); | |||
odtypes = test.infer_output_dtype (idtypes, param); | |||
oformats = test.infer_output_format(iformats, param); | |||
ASSERT_TRUE(odevices.size() == 2); | |||
ASSERT_TRUE(oshapes.size() == 2); | |||
ASSERT_TRUE(odtypes.size() == 2); | |||
ASSERT_TRUE(oformats.size() == 2); | |||
ASSERT_TRUE(odevices[0] == "cuda"); | |||
ASSERT_TRUE(odevices[1] == "x86"); | |||
ASSERT_TRUE(oshapes[0] == Shape({3,4})); | |||
ASSERT_TRUE(oshapes[1] == Shape({2,3})); | |||
ASSERT_TRUE(odtypes[0] == "float32"); | |||
ASSERT_TRUE(odtypes[1] == "int32"); | |||
ASSERT_TRUE(iformats[0].is_default()); | |||
ASSERT_TRUE(iformats[1].is_default()); | |||
test.set_compute(cpu_kernel); | |||
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||
DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); | |||
DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); | |||
std::vector<Tensor> cinputs = {to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; | |||
std::vector<Tensor> coutputs ={to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; | |||
param["device"] = "x86"; | |||
test.compute(cinputs, param, coutputs); | |||
test.set_compute("cuda", gpu_kernel); | |||
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||
DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); | |||
DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); | |||
std::vector<Tensor> ginputs = {to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; | |||
std::vector<Tensor> goutputs ={to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; | |||
param["device"] = "cuda"; | |||
test.compute(ginputs, param, goutputs); | |||
#endif | |||
} | |||
} |
@@ -0,0 +1,208 @@ | |||
/** | |||
* \file src/custom/test/param.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/param.h" | |||
#include "gtest/gtest.h" | |||
#include <iostream> | |||
#define PARAM_TEST_LOG 0 | |||
namespace custom { | |||
#define SchemaDef \ | |||
ParamSchema schema_bool("param_bool", true, "bool"); \ | |||
ParamSchema schema_flt("param_flt", 2.3f, "float"); \ | |||
ParamSchema schema_int("param_int", 4, "int"); \ | |||
ParamSchema schema_str("param_str", "test", "string"); \ | |||
ParamSchema schema_bool_list("param_bl", {true, false, true}, "bool list"); \ | |||
ParamSchema schema_flt_list("param_fl", {1.1f, 2.2f, 3.3f}, "float list"); \ | |||
ParamSchema schema_int_list("param_il", {1, 2, 3}, "int list"); \ | |||
ParamSchema schema_str_list("param_sl", {"test1", "test2", "test3"}, "string list") | |||
#define InfoDef \ | |||
info.meta().emplace_back(schema_bool); \ | |||
info.meta().emplace_back(schema_flt); \ | |||
info.meta().emplace_back(schema_int); \ | |||
info.meta().emplace_back(schema_str); \ | |||
info.meta().emplace_back(schema_bool_list); \ | |||
info.meta().emplace_back(schema_flt_list); \ | |||
info.meta().emplace_back(schema_int_list); \ | |||
info.meta().emplace_back(schema_str_list) | |||
TEST(TestParam, TestParamScheme) { | |||
#if PARAM_TEST_LOG | |||
SchemaDef; | |||
ParamSchema new_schema = schema_int; | |||
std::cout << schema_bool.str() << std::endl; | |||
std::cout << schema_flt.str() << std::endl; | |||
std::cout << schema_int.str() << std::endl; | |||
std::cout << schema_str.str() << std::endl; | |||
std::cout << schema_bool_list.str() << "len: "<< schema_bool_list.default_val().size() << std::endl; | |||
std::cout << schema_flt_list.str() << "len: "<< schema_flt_list.default_val().size() << std::endl; | |||
std::cout << schema_int_list.str() << "len: "<< schema_int_list.default_val().size() << std::endl; | |||
std::cout << schema_str_list.str() << "len: "<< schema_str_list.default_val().size() << std::endl; | |||
std::cout << new_schema.str() << std::endl; | |||
#endif | |||
} | |||
TEST(TestParam, TestParamVal) { | |||
ParamVal pv1 = 1.2f, pv2 = true, pv3 = "test", pv4 = {0, 1, 2}, | |||
pv5 = {true, false, true}; | |||
#if PARAM_TEST_LOG | |||
ParamVal pv6 = {"test1", "test2", "test3"}; | |||
std::cout << pv1.str() << std::endl; | |||
std::cout << pv2.str() << std::endl; | |||
std::cout << pv3.str() << std::endl; | |||
std::cout << pv4.str() << std::endl; | |||
std::cout << pv5.str() << std::endl; | |||
std::cout << pv6.str() << std::endl; | |||
#endif | |||
ParamVal pv_manip = pv1; | |||
ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||
ASSERT_TRUE(pv_manip == pv1); | |||
pv_manip = 1.3; | |||
ASSERT_TRUE(pv_manip.type() != pv1.type()); | |||
ASSERT_TRUE(pv_manip != pv1); | |||
ASSERT_TRUE(pv_manip > pv1); | |||
pv_manip = pv_manip + pv1; | |||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Float64); | |||
ASSERT_TRUE(pv_manip == 1.3 + 1.2f); | |||
pv_manip = 1.3f + 1.2f; | |||
ASSERT_TRUE(pv_manip.type() == pv1.type()); | |||
pv_manip = false; | |||
ASSERT_TRUE(pv_manip.type() == pv2.type()); | |||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Bool); | |||
ASSERT_TRUE(pv_manip != pv2); | |||
pv_manip = "test"; | |||
ASSERT_TRUE(pv_manip.type() == pv3.type()); | |||
ASSERT_TRUE(pv_manip.type() == ParamDynType::String); | |||
ASSERT_TRUE(pv_manip == pv3); | |||
pv_manip = "test1"; | |||
ASSERT_TRUE(pv_manip > pv3); | |||
pv_manip = pv_manip + pv3; | |||
ASSERT_TRUE(pv_manip == "test1test"); | |||
pv_manip = {0, 1, 2}; | |||
ASSERT_TRUE(pv_manip.type() == pv4.type()); | |||
ASSERT_TRUE(pv_manip.type() == ParamDynType::Int32List); | |||
ASSERT_TRUE(pv_manip == pv4); | |||
pv_manip = {3, 2, 1}; | |||
ASSERT_TRUE(pv_manip != pv4); | |||
ASSERT_TRUE(pv_manip > pv4); | |||
pv_manip = {true, false, true}; | |||
ASSERT_TRUE(pv_manip.type() == pv5.type()); | |||
ASSERT_TRUE(pv_manip.type() == ParamDynType::BoolList); | |||
ASSERT_TRUE(pv_manip == pv5); | |||
pv_manip = {false, true, false}; | |||
ASSERT_TRUE(pv_manip != pv5); | |||
} | |||
TEST(TestParam, TestParamInfo) { | |||
ParamInfo info; | |||
info.set_tag("Test"); | |||
#if PARAM_TEST_LOG | |||
uint32_t tag = info.tag(); | |||
std::cout << tag << std::endl; | |||
#endif | |||
SchemaDef; | |||
InfoDef; | |||
ParamInfo new_info1, new_info2; | |||
new_info1.set_meta(info.meta()); | |||
new_info2.meta() = info.meta(); | |||
#if PARAM_TEST_LOG | |||
for (auto ele: new_info1.meta()) { | |||
std::cout << ele.str() << std::endl; | |||
} | |||
for (auto ele: new_info2.meta()) { | |||
std::cout << ele.str() << std::endl; | |||
} | |||
#endif | |||
} | |||
TEST(TestParam, TestParam) { | |||
ParamInfo info; | |||
SchemaDef; | |||
InfoDef; | |||
Param param(info); | |||
#if PARAM_TEST_LOG | |||
std::vector<std::string> names = {"param_bool", "param_flt", "param_int", "param_str", "param_bl", "param_fl", "param_il", "param_sl"}; | |||
for (auto &name: names) { | |||
std::cout << param[name].str() << std::endl;; | |||
} | |||
#endif | |||
ASSERT_TRUE(param["param_bool"] == true); | |||
ASSERT_TRUE(param["param_flt"] == 2.3f); | |||
ASSERT_TRUE(param["param_int"] == 4); | |||
ASSERT_TRUE(param["param_str"] == "test"); | |||
ASSERT_TRUE(param["param_bl"] == ParamVal({true, false, true})); | |||
ASSERT_TRUE(param["param_fl"] == ParamVal({1.1f, 2.2f, 3.3f})); | |||
ASSERT_TRUE(param["param_il"] == ParamVal({1, 2, 3})); | |||
ASSERT_TRUE(param["param_sl"] == ParamVal({"test1", "test2", "test3"})); | |||
param["param_bool"] = false; | |||
param["param_flt"] = 3.4f; | |||
param["param_int"] = 5; | |||
param["param_str"] = "tset"; | |||
param["param_bl"] = {false, true, false, true}; | |||
param["param_fl"] = {7.6f, 6.5f}; | |||
param["param_il"] = {5, 4, 3, 2, 1}; | |||
param["param_sl"] = {"1tset", "2tset", "3tset", "4tset", "5tset"}; | |||
ASSERT_TRUE(param["param_bool"] != true); | |||
ASSERT_TRUE(param["param_flt"] != 2.3f); | |||
ASSERT_TRUE(param["param_int"] != 4); | |||
ASSERT_TRUE(param["param_str"] != "test"); | |||
ASSERT_TRUE(param["param_bl"] != ParamVal({true, false, true})); | |||
ASSERT_TRUE(param["param_fl"] != ParamVal({1.1f, 2.2f, 3.3f})); | |||
ASSERT_TRUE(param["param_il"] != ParamVal({1, 2, 3})); | |||
ASSERT_TRUE(param["param_sl"] != ParamVal({"test1", "test2", "test3"})); | |||
ASSERT_TRUE(param["param_bool"] == false); | |||
ASSERT_TRUE(param["param_flt"] == 3.4f); | |||
ASSERT_TRUE(param["param_int"] == 5); | |||
ASSERT_TRUE(param["param_str"] == "tset"); | |||
ASSERT_TRUE(param["param_bl"] == ParamVal({false, true, false, true})); | |||
ASSERT_TRUE(param["param_fl"] == ParamVal({7.6f, 6.5f})); | |||
ASSERT_TRUE(param["param_il"] == ParamVal({5, 4, 3, 2, 1})); | |||
ASSERT_TRUE(param["param_sl"] == ParamVal({"1tset", "2tset", "3tset", "4tset", "5tset"})); | |||
#if PARAM_TEST_LOG | |||
Param copy_param = param; | |||
for (auto &name: names) { | |||
std::cout << copy_param[name].str() << std::endl; | |||
} | |||
#endif | |||
Param loaded_param(info); | |||
std::string bytes = param.to_bytes(); | |||
loaded_param.from_bytes(bytes); | |||
#if PARAM_TEST_LOG | |||
for (auto &kv: loaded_param.raw()) { | |||
std::cout << kv.first << ":\n" << kv.second.str() << std::endl; | |||
} | |||
#endif | |||
} | |||
} |
@@ -0,0 +1,325 @@ | |||
/** | |||
* \file src/custom/test/tensor.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/custom/tensor.h" | |||
#include "megbrain/custom/data_adaptor.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megbrain/tensor.h" | |||
#include "gtest/gtest.h" | |||
#include "megbrain_build_config.h" | |||
#define TENSOR_TEST_LOG 0 | |||
using namespace mgb; | |||
namespace custom { | |||
TEST(TestDevice, TestDevice) { | |||
#if MGB_CUDA | |||
ASSERT_TRUE(Device::is_legal("x86")); | |||
ASSERT_TRUE(Device::is_legal(DeviceEnum::cuda)); | |||
ASSERT_FALSE(Device::is_legal("cpu")); | |||
Device dev1; | |||
ASSERT_TRUE(dev1.str() == "invalid"); | |||
dev1 = "x86"; | |||
ASSERT_TRUE("x86" == dev1); | |||
Device dev2 = "cuda"; | |||
ASSERT_TRUE(dev2 == "cuda"); | |||
ASSERT_FALSE(dev2 == dev1); | |||
Device dev3 = dev2; | |||
ASSERT_TRUE(dev3 == dev2); | |||
ASSERT_FALSE(dev3 == dev1); | |||
Device dev4 = DeviceEnum::cuda; | |||
ASSERT_TRUE(dev4.enumv() == DeviceEnum::cuda); | |||
#if TENSOR_TEST_LOG | |||
std::cout << dev1.str() << "\n" << dev2.str() << "\n" | |||
<< dev3.str() << "\n" << dev4.str() << std::endl; | |||
#endif | |||
CompNode compnode = to_builtin<CompNode, Device>(dev3); | |||
ASSERT_TRUE(compnode.to_string_logical() == "gpux:0"); | |||
compnode = CompNode::load("cpu0:0"); | |||
Device dev5 = to_custom<CompNode, Device>(compnode); | |||
ASSERT_TRUE(dev5.str() == "x86"); | |||
std::vector<Device> devs1 = {"x86", "cuda", "x86"}; | |||
megdnn::SmallVector<CompNode> compnodes = to_builtin<CompNode, Device>(devs1); | |||
ASSERT_TRUE(compnodes[0].to_string_logical() == "cpux:0"); | |||
ASSERT_TRUE(compnodes[1].to_string_logical() == "gpux:0"); | |||
ASSERT_TRUE(compnodes[2].to_string_logical() == "cpux:0"); | |||
std::vector<Device> devs2 = to_custom<CompNode, Device>(compnodes); | |||
ASSERT_TRUE(devs2[0] == "x86"); | |||
ASSERT_TRUE(devs2[1].str() == "cuda"); | |||
ASSERT_TRUE(devs2[2] == "x86"); | |||
#endif | |||
} | |||
TEST(TestShape, TestShape) { | |||
Shape shape1, shape2; | |||
ASSERT_TRUE(shape1.ndim() == 0); | |||
shape1 = {16, 32, 8, 8}; | |||
shape2 = shape1; | |||
ASSERT_TRUE(shape2.ndim() == 4); | |||
ASSERT_TRUE(shape2[0] == 16); | |||
ASSERT_TRUE(shape2[1] == 32); | |||
ASSERT_TRUE(shape2[2] == 8); | |||
ASSERT_TRUE(shape2[3] == 8); | |||
Shape shape3 = {16, 32, 8, 8}; | |||
const Shape shape4 = shape1; | |||
ASSERT_TRUE(shape3 == shape4); | |||
shape3[0] = 32; | |||
ASSERT_FALSE(shape3 == shape4); | |||
ASSERT_TRUE(shape3[0] == 32); | |||
ASSERT_TRUE(shape4[0] == 16); | |||
Shape shape5 = {2, 3, 4}; | |||
TensorShape bshape1 = to_builtin<TensorShape, Shape>(shape5); | |||
ASSERT_TRUE(bshape1.ndim == 3); | |||
ASSERT_TRUE(bshape1[0] == 2); | |||
ASSERT_TRUE(bshape1[1] == 3); | |||
ASSERT_TRUE(bshape1[2] == 4); | |||
bshape1 = {4, 2, 3}; | |||
Shape shape6 = to_custom<TensorShape, Shape>(bshape1); | |||
ASSERT_TRUE(shape6.ndim() == 3); | |||
ASSERT_TRUE(shape6[0] == 4); | |||
ASSERT_TRUE(shape6[1] == 2); | |||
ASSERT_TRUE(shape6[2] == 3); | |||
Shape shape7; | |||
shape7.ndim(3); | |||
shape7[1] = 4; | |||
ASSERT_TRUE(shape7 == Shape({0, 4, 0})); | |||
std::vector<Shape> shapes1 = {{2, 3, 4}, {6}, {5, 7}}; | |||
megdnn::SmallVector<TensorShape> bshapes = to_builtin<TensorShape, Shape>(shapes1); | |||
ASSERT_TRUE(bshapes[0].total_nr_elems() == 2*3*4); | |||
ASSERT_TRUE(bshapes[1].total_nr_elems() == 6); | |||
ASSERT_TRUE(bshapes[2].total_nr_elems() == 35); | |||
std::vector<Shape> shapes2 = to_custom<TensorShape, Shape>(bshapes); | |||
ASSERT_TRUE(shapes2[0] == Shape({2, 3, 4})); | |||
ASSERT_TRUE(shapes2[1] == Shape({6})); | |||
ASSERT_TRUE(shapes2[2] == Shape({5, 7})); | |||
} | |||
TEST(TestDType, TestDType) { | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
ASSERT_TRUE(DType::is_legal("uint8")); | |||
ASSERT_TRUE(DType::is_legal(DTypeEnum::bfloat16)); | |||
DType dtype1, dtype2; | |||
ASSERT_TRUE(dtype1.str() == "invalid"); | |||
dtype1 = "float32"; | |||
ASSERT_TRUE(dtype1.str() == "float32"); | |||
dtype2 = dtype1; | |||
DType dtype3 = dtype2; | |||
ASSERT_TRUE(dtype3 == dtype1); | |||
ASSERT_TRUE(dtype3 == "float32"); | |||
dtype3 = "int8"; | |||
ASSERT_FALSE("float32" == dtype3.str()); | |||
ASSERT_FALSE(dtype3 == dtype2); | |||
DType dtype4 = DTypeEnum::int8, dtype5 = dtype3; | |||
ASSERT_TRUE(dtype4 == dtype5); | |||
ASSERT_TRUE(dtype4.is_compatible<int8_t>()); | |||
ASSERT_FALSE(dtype4.is_compatible<uint8_t>()); | |||
DType dtype6 = "int32"; | |||
megdnn::DType bdtype1 = to_builtin<megdnn::DType, DType>(dtype6); | |||
ASSERT_TRUE(bdtype1.name() == std::string("Int32")); | |||
bdtype1 = megdnn::DType::from_enum(megdnn::DTypeEnum::BFloat16); | |||
DType dtype7 = to_custom<megdnn::DType, DType>(bdtype1); | |||
ASSERT_TRUE(dtype7.enumv() == DTypeEnum::bfloat16); | |||
std::vector<DType> dtypes1 = {"int8", "uint8", "float16"}; | |||
megdnn::SmallVector<megdnn::DType> bdtypes | |||
= to_builtin<megdnn::DType, DType>(dtypes1); | |||
ASSERT_TRUE(bdtypes[0].name() == std::string("Int8")); | |||
ASSERT_TRUE(bdtypes[1].name() == std::string("Uint8")); | |||
ASSERT_TRUE(bdtypes[2].name() == std::string("Float16")); | |||
std::vector<DType> dtypes2 = to_custom<megdnn::DType, DType>(bdtypes); | |||
ASSERT_TRUE(dtypes2[0] == "int8"); | |||
ASSERT_TRUE(dtypes2[1] == "uint8"); | |||
ASSERT_TRUE(dtypes2[2] == "float16"); | |||
#endif | |||
} | |||
TEST(TestDType, TestDTypeQuantized) { | |||
DType quint8_1("quint8", 3.2, 15); | |||
DType quint8_2("quint8", 3.2, 15); | |||
DType quint8_3("quint8", 3.2, 16); | |||
DType quint8_4("quint8", 3.1, 15); | |||
ASSERT_TRUE(quint8_1 == quint8_2); | |||
ASSERT_FALSE(quint8_1 == quint8_3); | |||
ASSERT_FALSE(quint8_1 == quint8_4); | |||
ASSERT_TRUE(quint8_1.scale() == 3.2f); | |||
ASSERT_TRUE(quint8_1.zero_point() == 15); | |||
DType qint8("qint8", 3.3f); | |||
DType qint16("qint16", 3.4f); | |||
DType qint32("qint32", 3.5f); | |||
ASSERT_TRUE(qint8.scale() == 3.3f); | |||
ASSERT_TRUE(qint16.scale() == 3.4f); | |||
ASSERT_TRUE(qint32.scale() == 3.5f); | |||
ASSERT_TRUE(qint8.enumv() == DTypeEnum::qint8); | |||
ASSERT_TRUE(qint8.str() == "qint8"); | |||
} | |||
TEST(TestFormat, TestFormat) { | |||
Format format1, format2("default"); | |||
ASSERT_TRUE(format1.is_default()); | |||
ASSERT_TRUE(format2.is_default()); | |||
Format format3 = format1; | |||
ASSERT_TRUE(format3.is_default()); | |||
} | |||
TEST(TestTensor, TestTensor) { | |||
CompNode builtin_device = CompNode::load("cpux:0"); | |||
TensorShape builtin_shape = {3, 2, 4}; | |||
megdnn::DType builtin_dtype = dtype::Int32{}; | |||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
Device device = tensor1.device(); | |||
Shape shape = tensor1.shape(); | |||
DType dtype = tensor1.dtype(); | |||
ASSERT_TRUE(device == "x86"); | |||
ASSERT_TRUE(shape.ndim() == 3); | |||
ASSERT_TRUE(shape[0] == 3); | |||
ASSERT_TRUE(shape[1] == 2); | |||
ASSERT_TRUE(shape[2] == 4); | |||
ASSERT_TRUE(shape == std::vector<size_t>({3, 2, 4})); | |||
ASSERT_TRUE(dtype == "int32"); | |||
int *raw_ptr1 = tensor1.data<int>(); | |||
for (size_t i=0; i<tensor1.size(); i++) | |||
raw_ptr1[i] = i; | |||
int *raw_ptr2 = tensor2.data<int>(); | |||
for (size_t i=0; i<tensor2.size(); i++) | |||
ASSERT_TRUE(raw_ptr2[i] == static_cast<int>(i)); | |||
Tensor tensor3 = tensor2; | |||
int *raw_ptr3 = tensor3.data<int>(); | |||
for (size_t i=0; i<tensor3.size(); i++) | |||
ASSERT_TRUE(raw_ptr3[i] == static_cast<int>(i)); | |||
ASSERT_TRUE(raw_ptr1 == raw_ptr2); | |||
ASSERT_TRUE(raw_ptr1 == raw_ptr3); | |||
for (size_t i=0; i<tensor3.size(); i++) { | |||
raw_ptr3[i] = -static_cast<int>(i); | |||
} | |||
for (size_t i=0; i<tensor1.size(); i++) { | |||
ASSERT_TRUE(raw_ptr1[i] == -static_cast<int>(i)); | |||
} | |||
DeviceTensorND new_dev_tensor = to_builtin<DeviceTensorND, Tensor>(tensor3); | |||
int *builtin_ptr = new_dev_tensor.ptr<int>(); | |||
for (size_t i=0; i<new_dev_tensor.shape().total_nr_elems(); i++) { | |||
ASSERT_TRUE(builtin_ptr[i] == -static_cast<int>(i)); | |||
} | |||
} | |||
TEST(TestTensor, TestTensorQuantized) { | |||
#if MGB_CUDA | |||
CompNode builtin_device = CompNode::load("gpux:0"); | |||
TensorShape builtin_shape = {3, 2, 4}; | |||
megdnn::DType builtin_dtype = dtype::Quantized8Asymm{3.2f, uint8_t(15)}; | |||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
Tensor tensor1 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
Tensor tensor2 = to_custom<DeviceTensorND, Tensor>(dev_tensor); | |||
Device device1 = tensor1.device(), device2 = tensor2.device(); | |||
Shape shape1 = tensor1.shape(), shape2 = tensor2.shape(); | |||
DType dtype1 = tensor1.dtype(), dtype2 = tensor2.dtype(); | |||
ASSERT_TRUE(device1 == "cuda"); | |||
ASSERT_TRUE(shape1.ndim() == 3); | |||
ASSERT_TRUE(shape1[0] == 3); | |||
ASSERT_TRUE(shape1[1] == 2); | |||
ASSERT_TRUE(shape1[2] == 4); | |||
ASSERT_TRUE(shape1 == std::vector<size_t>({3, 2, 4})); | |||
ASSERT_TRUE(dtype1 == "quint8"); | |||
ASSERT_TRUE(dtype1.scale() == 3.2f); | |||
ASSERT_TRUE(dtype1.zero_point() == 15); | |||
ASSERT_TRUE(device1 == device2); | |||
ASSERT_TRUE(shape1 == shape2); | |||
ASSERT_TRUE(dtype1 == dtype2); | |||
#endif | |||
} | |||
TEST(TestTensor, TestTensorAccessorND) { | |||
size_t N = 2, C = 4, H = 6, W = 8; | |||
CompNode builtin_device = CompNode::load("cpux"); | |||
TensorShape builtin_shape = {N, C, H, W}; | |||
megdnn::DType builtin_dtype = dtype::Int32{}; | |||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
int *builtin_ptr = dev_tensor.ptr<int>(); | |||
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||
builtin_ptr[i] = i; | |||
} | |||
Tensor tensor = to_custom_tensor(dev_tensor); | |||
auto accessor = tensor.accessor<int32_t, 4>(); | |||
for (size_t n=0; n<N; ++n) { | |||
for (size_t c=0; c<C; ++c) { | |||
for (size_t h=0; h<H; ++h) { | |||
for (size_t w=0; w<W; ++w) { | |||
int32_t idx = n*C*H*W + c*H*W + h*W + w; | |||
ASSERT_TRUE(accessor[n][c][h][w] == idx); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
TEST(TestTensor, TestTensorAccessor1D) { | |||
CompNode builtin_device = CompNode::load("cpux"); | |||
TensorShape builtin_shape = {32}; | |||
megdnn::DType builtin_dtype = dtype::Float32{}; | |||
DeviceTensorND dev_tensor(builtin_device, builtin_shape, builtin_dtype); | |||
float *builtin_ptr = dev_tensor.ptr<float>(); | |||
for (size_t i=0; i<dev_tensor.shape().total_nr_elems(); i++) { | |||
builtin_ptr[i] = i; | |||
} | |||
Tensor tensor = to_custom_tensor(dev_tensor); | |||
auto accessor = tensor.accessor<float, 1>(); | |||
for (size_t n=0; n<32; ++n) { | |||
ASSERT_TRUE(accessor[n] == n); | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode); | |||
void CustomOpNode::infer_output_comp_node(void) { | |||
SmallVector<CompNode> input_comp_nodes(input_num()); | |||
for (int i=0; i<input_num(); ++i) { | |||
for (size_t i=0; i<input_num(); ++i) { | |||
input_comp_nodes[i] = input(i)->comp_node(); | |||
} | |||
@@ -28,7 +28,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||
) | |||
); | |||
for (int i=0; i<output_num(); ++i) { | |||
for (size_t i=0; i<output_num(); ++i) { | |||
mgb_assert(output_comp_nodes[i] == output_comp_nodes[0], | |||
"only single comp node operator is supported"); | |||
output(i)->comp_node(output_comp_nodes[i]); | |||
@@ -39,7 +39,7 @@ void CustomOpNode::infer_output_comp_node(void) { | |||
void CustomOpNode::infer_output_dtype(void) { | |||
SmallVector<DType> input_dtypes(input_num()); | |||
for (int i=0; i<input_num(); ++i) { | |||
for (size_t i=0; i<input_num(); ++i) { | |||
input_dtypes[i] = input(i)->dtype(); | |||
} | |||
@@ -49,14 +49,14 @@ void CustomOpNode::infer_output_dtype(void) { | |||
) | |||
); | |||
for (int i=0; i<output_num(); ++i) { | |||
for (size_t i=0; i<output_num(); ++i) { | |||
output(i)->dtype(output_dtypes[i]); | |||
} | |||
} | |||
void CustomOpNode::infer_output_format(void) { | |||
SmallVector<TensorFormat> input_formats(input_num()); | |||
for (int i=0; i<input_num(); ++i) { | |||
for (size_t i=0; i<input_num(); ++i) { | |||
input_formats[i] = input(i)->format(); | |||
} | |||
@@ -66,14 +66,14 @@ void CustomOpNode::infer_output_format(void) { | |||
) | |||
); | |||
for (int i=0; i<output_num(); ++i) { | |||
for (size_t i=0; i<output_num(); ++i) { | |||
output(i)->format(output_formats[i]); | |||
} | |||
} | |||
void CustomOpNode::infer_output_shape(void) { | |||
SmallVector<TensorShape> input_shapes(input_num()); | |||
for (int i=0; i<input_num(); ++i) { | |||
for (size_t i=0; i<input_num(); ++i) { | |||
input_shapes[i] = input(i)->shape(); | |||
} | |||
@@ -83,7 +83,7 @@ void CustomOpNode::infer_output_shape(void) { | |||
) | |||
); | |||
for (int i=0; i<output_num(); ++i) { | |||
for (size_t i=0; i<output_num(); ++i) { | |||
output(i)->shape(output_shapes[i]); | |||
} | |||
} | |||
@@ -235,10 +235,10 @@ CustomOpNode::CustomOpNode(const std::shared_ptr<const custom::CustomOp> &op, | |||
const OperatorNodeConfig &config): | |||
OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) { | |||
mgb_assert(input_num() == inputs.size(), "wrong input tensors list length"); | |||
for (int i=0; i < input_num(); ++i) | |||
for (size_t i=0; i < input_num(); ++i) | |||
add_input({inputs[i]}); | |||
for (int i=0; i<output_num(); ++i) | |||
for (size_t i=0; i<output_num(); ++i) | |||
add_output(output_info(i).name()); | |||
if (!std::is_empty<custom::Param>::value) { | |||
@@ -306,11 +306,11 @@ std::string CustomOpNode::op_desc(void) const { | |||
return m_op->op_desc(); | |||
} | |||
int CustomOpNode::input_num(void) const { | |||
size_t CustomOpNode::input_num(void) const { | |||
return m_op->input_num(); | |||
} | |||
int CustomOpNode::output_num(void) const { | |||
size_t CustomOpNode::output_num(void) const { | |||
return m_op->output_num(); | |||
} | |||
@@ -93,8 +93,8 @@ public: | |||
custom::Param param(void) const; | |||
std::string op_type(void) const; | |||
std::string op_desc(void) const; | |||
int input_num(void) const; | |||
int output_num(void) const; | |||
size_t input_num(void) const; | |||
size_t output_num(void) const; | |||
custom::ArgInfo input_info(size_t idx) const; | |||
custom::ArgInfo output_info(size_t idx) const; | |||
}; | |||
@@ -1,7 +1,7 @@ | |||
include_directories("./src/include") | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") | |||
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp) | |||
file(GLOB_RECURSE SOURCES ./*.cpp ../src/core/test/*.cpp ../src/gopt/test/*.cpp ../src/opr/test/*.cpp ../src/plugin/test/*.cpp ../src/serialization/test/*.cpp ../src/custom/test/*.cpp) | |||
if(MGE_WITH_JIT) | |||
file(GLOB_RECURSE SOURCES_ ../src/jit/test/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||