|
- #include "./graph_rt.h"
-
- #include "./common.h"
- #include "./helper.h"
- #include "./ops.h"
- #include "megbrain/gopt/inference.h"
- #include "megbrain/graph/cg.h"
- #include "megbrain/imperative.h"
- #include "megbrain/imperative/opr_utility.h"
- #include "megbrain/imperative/profiler_plugin.h"
- #include "megbrain/opr/basic_arith.h"
- #include "megbrain/opr/io.h"
- #include "megbrain/opr/utility.h"
- #include "megbrain/plugin/profiler.h"
- #include "megbrain/serialization/serializer.h"
-
- namespace py = pybind11;
-
- using namespace mgb;
- using namespace imperative;
- namespace ser = mgb::serialization;
-
- using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
- using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform;
- using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
- using _SerializationMetadata = mgb::serialization::Metadata;
- using _SerializationFormat = mgb::serialization::GraphDumpFormat;
-
- namespace {
- class _CompGraphProfilerImpl {
- std::shared_ptr<ComputingGraph> m_comp_graph;
- GraphProfiler m_profiler;
-
- public:
- _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg)
- : m_comp_graph{cg}, m_profiler{m_comp_graph.get()} {}
-
- std::string _get_result() {
- auto json = m_profiler.to_json_full(m_comp_graph->current_comp_seq());
- return json->to_string();
- }
- };
-
- struct WeakRendezvousArray : public std::vector<std::weak_ptr<RendezvousBase>>,
- public UserDataContainer::UserData {
- MGB_TYPEINFO_OBJ_DECL;
- };
- MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
- } // namespace
- #define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)
-
- template <typename T>
- auto def_rendezvous(py::object m, const char* name) {
- return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
- .def(py::init([]() { return Rendezvous<T>::make(); }))
- .def("set", [](Rendezvous<T>& r, T v) { r.set(std::move(v)); })
- .def(
- "get", [](Rendezvous<T>& r) { return r.get(); },
- py::call_guard<py::gil_scoped_release>())
- .def("drop", &Rendezvous<T>::drop)
- .def("reset", &Rendezvous<T>::reset)
- .def("set_exception", [](Rendezvous<T>& r, std::string&& message) {
- r.set_exception(std::make_exception_ptr(
- std::runtime_error(std::move(message))));
- });
- }
-
- using TensorAttr = LogicalTensorDesc;
- using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
-
- std::vector<mgb::cg::VarNode*> _replace_vars(
- const std::vector<mgb::cg::VarNode*>& repl_src,
- const std::vector<mgb::cg::VarNode*>& repl_dst,
- const std::vector<mgb::cg::VarNode*>& vars) {
- mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
- for (size_t i = 0; i < repl_src.size(); ++i) {
- varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
- }
- SymbolVarArray symvars(vars.begin(), vars.end());
- auto sym_result = mgb::cg::replace_vars(symvars, varmap);
- std::vector<mgb::cg::VarNode*> result;
- for (auto symvar : sym_result) {
- result.push_back(symvar.node());
- }
- return result;
- }
-
- typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
- std::vector<mgb::cg::VarNode*> _replace_oprs(
- const OperatorArray& repl_src, const OperatorArray& repl_dst,
- const std::vector<mgb::cg::VarNode*>& vars) {
- mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*> oprmap;
- for (size_t i = 0; i < repl_src.size(); ++i) {
- oprmap[repl_src[i]] = repl_dst[i];
- }
- const SymbolVarArray symvars(vars.begin(), vars.end());
- auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
- std::vector<mgb::cg::VarNode*> result;
- for (auto symvar : sym_result) {
- result.push_back(symvar.node());
- }
- return result;
- }
-
- void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
- auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
- if (opr->node_prop().attribute().priority == 0) {
- opr->node_prop().attribute().priority = opr->id();
- }
- };
- mgb::cg::DepOprIter dep_iter{on_opr};
- for (const auto& var : dest_vars) {
- dep_iter.add(SymbolVar(var));
- }
- }
-
- py::object Py_Varnode = py::none();
-
- void init_graph_rt(py::module m) {
- static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
- std::make_unique<mgb::OprFootprint>()};
-
- def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
-
- def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
-
- def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
-
- Py_Varnode =
- py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
- .def_property_readonly(
- "owner", [](cg::VarNode* v) { return v->owner_opr(); })
- .def_property_readonly(
- "graph", [](cg::VarNode* v) { return v->owner_graph(); })
- .def_property(
- "name", py::overload_cast<>(&VarNode::name, py::const_),
- py::overload_cast<std::string>(&VarNode::name))
- .def_property_readonly(
- "dtype", [](cg::VarNode* v) { return v->dtype(); })
- .def_property_readonly(
- "comp_node", [](cg::VarNode* v) { return v->comp_node(); })
- .def_property_readonly(
- "shape",
- [](cg::VarNode* v) -> const TensorShape* {
- auto&& mgr = v->owner_graph()->static_infer_manager();
- return mgr.infer_shape_fallible(v);
- })
- .def_property_readonly(
- "value",
- [](cg::VarNode* v) -> py::object {
- auto&& mgr = v->owner_graph()->static_infer_manager();
- auto&& type = mgr.get_infer_type(v);
- using InferType = cg::static_infer::InferType;
- if (!(type.value &
- (InferType::CONST | InferType::RT_STATIC))) {
- return py::none();
- }
- auto* val = mgr.infer_value_fallible(v);
- if (!val) {
- return py::none();
- }
- return py::cast(*val).attr("numpy")();
- })
- .def_property_readonly(
- "id", [](cg::VarNode* v) { return (v->id()); })
- .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); });
-
- py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(
- m, "OperatorNode")
- .def_property_readonly(
- "graph",
- [](cg::OperatorNodeBase* opr) { return opr->owner_graph(); })
- .def_property(
- "name",
- py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
- py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
- .def_property_readonly(
- "inputs",
- [](cg::OperatorNodeBase* opr) { return to_tuple(opr->input()); })
- .def_property_readonly(
- "outputs",
- [](cg::OperatorNodeBase* opr) {
- return to_tuple(opr->usable_output());
- })
- .def_property_readonly(
- "id", [](cg::OperatorNodeBase* opr) { return opr->id(); })
- .def_property_readonly(
- "params",
- [](cg::OperatorNodeBase* opr) {
- return _imperative_sm_opr_footprint_ptr->calc_footprint(opr)
- .param->to_string();
- })
- .def_property_readonly(
- "type",
- [](cg::OperatorNodeBase* opr) { return opr->dyn_typeinfo()->name; })
- .def("__repr__",
- [](cg::OperatorNodeBase* opr) { return "Opr:" + opr->name(); })
- .def_property(
- "priority",
- [](cg::OperatorNodeBase* opr) {
- return opr->node_prop().attribute().priority;
- },
- [](cg::OperatorNodeBase* opr, int priority) {
- opr->node_prop().attribute().priority = priority;
- });
-
- py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
- .def("execute", &cg::AsyncExecutable::execute,
- py::call_guard<py::gil_scoped_release>())
- .def("wait", &cg::AsyncExecutable::wait,
- py::call_guard<py::gil_scoped_release>())
- .def("get_prev_exec_time", &cg::AsyncExecutable::get_prev_exec_time,
- py::call_guard<py::gil_scoped_release>())
- .def("_to_json",
- [](cg::AsyncExecutable* exec) {
- py::call_guard<py::gil_scoped_release>();
- // dump currently compiled computing graph for debugging
- return exec->to_json()->to_string();
- })
- // only used for exception handle
- .def_property_readonly(
- "_all_rendezvous",
- [](cg::AsyncExecutable* exec) {
- auto ud =
- exec->owner_graph()
- ->options()
- .user_data.get_user_data<WeakRendezvousArray>();
- std::vector<std::shared_ptr<RendezvousBase>> ret;
- if (ud.second) {
- for (auto&& r : *ud.first[0]) {
- if (auto p = r.lock()) {
- ret.emplace_back(std::move(p));
- }
- }
- }
- return ret;
- })
- .def("get_static_memory_alloc_info",
- &cg::AsyncExecutable::get_static_memory_alloc_info,
- py::call_guard<py::gil_scoped_release>());
-
- auto PyComputingGraph =
- py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(
- m, "ComputingGraph")
- .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
- .def("compile",
- [](cg::ComputingGraph& graph,
- const std::vector<cg::VarNode*>& dest_vars) {
- mgb_assert(!dest_vars.empty());
- cg::ComputingGraph::OutputSpec spec;
- for (auto v : dest_vars) {
- spec.emplace_back(v, nullptr);
- }
- return graph.compile(spec);
- })
- .def("enable_weight_preprocess",
- [](cg::ComputingGraph& graph) {
- graph.options().graph_opt.enable_weight_preprocess();
- })
- .def_property_readonly(
- "options",
- py::overload_cast<>(&cg::ComputingGraph::options));
-
- py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(
- m, "GraphProfiler")
- .def(py::init([](std::shared_ptr<ComputingGraph> graph) {
- return std::make_shared<_CompGraphProfilerImpl>(graph);
- }))
- .def("get", [](_CompGraphProfilerImpl& profiler) {
- return profiler._get_result();
- });
-
- using interpreter::intl::ProfilerPlugin;
- py::class_<ProfilerPlugin, std::shared_ptr<ProfilerPlugin>>(m, "GraphProfiler2")
- .def(py::init<cg::ComputingGraph*>());
-
- auto GraphOptimizeOptions =
- py::class_<_OptimizeForInferenceOptions>(m, "GraphOptimizeOptions")
- .def(py::init())
- .def("serialize", &_OptimizeForInferenceOptions::serialize)
- .def_static(
- "deserialize", &_OptimizeForInferenceOptions::deserialize)
- .def_readwrite(
- "f16_io_f32_comp",
- &_OptimizeForInferenceOptions::f16_io_f32_comp)
- .def_readwrite(
- "f16_io_comp", &_OptimizeForInferenceOptions::f16_io_comp)
- .def_readwrite(
- "fuse_conv_bias_nonlinearity",
- &_OptimizeForInferenceOptions::fuse_conv_bias_nonlinearity)
- .def_readwrite(
- "fuse_conv_bias_with_z",
- &_OptimizeForInferenceOptions::fuse_conv_bias_with_z)
- .def_readwrite(
- "fuse_preprocess",
- &_OptimizeForInferenceOptions::fuse_preprocess)
- .def_readwrite(
- "layout_transform",
- &_OptimizeForInferenceOptions::layout_transform)
- .def_readwrite(
- "fuse_grain", &_OptimizeForInferenceOptions::fuse_grain);
-
- py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
- .value("DEFAULT", _LayoutTransform::DEFAULT)
- .value("NCHW4", _LayoutTransform::NCHW4)
- .value("NHWCD4", _LayoutTransform::NHWCD4)
- .value("NCHW88", _LayoutTransform::NCHW88)
- .value("NCHW44", _LayoutTransform::NCHW44)
- .value("NCHW44_DOT", _LayoutTransform::NCHW44_DOT)
- .value("NCHW32", _LayoutTransform::NCHW32)
- .value("CHWN4", _LayoutTransform::CHWN4)
- .value("NCHW64", _LayoutTransform::NCHW64)
- .export_values();
-
- py::enum_<_SerializationFormat>(m, "SerializationFormat")
- .value("FBS", _SerializationFormat::FLATBUFFERS)
- .value("FBS_V2", _SerializationFormat::FLATBUFFERS_V2)
- .export_values();
-
- m.def("optimize_for_inference",
- [](const VarNodeArray& dest_vars, const _OptimizeForInferenceOptions& opt) {
- SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
- auto res_symvars = mgb::gopt::optimize_for_inference(symvars, opt);
- VarNodeArray vars;
- for (auto& si : res_symvars)
- vars.push_back(si.node());
- return vars;
- });
-
- m.def("modify_opr_algo_strategy_inplace",
- [](const VarNodeArray& dest_vars, const _AlgoStrategy& strategy) {
- mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, strategy);
- });
-
- m.def("get_info_for_strip", [](const std::vector<VarNode*>& dest_vars) {
- std::unordered_set<const char*> opr_types, dtype_names, elemwise_modes;
- auto on_opr = [&](cg::OperatorNodeBase* opr) {
- if (ser::GraphDumper::should_remove_in_dump(opr))
- return;
- opr_types.insert(opr->dyn_typeinfo()->name);
- for (auto i : opr->output())
- dtype_names.insert(i->dtype().name());
- if (opr->same_type<opr::Elemwise>()) {
- auto mode = opr->cast_final<opr::Elemwise>().param().mode;
- elemwise_modes.insert(
- megdnn::Elemwise::ModeTrait::from_mode(mode).name);
- }
- };
- cg::DepOprIter opr_iter{on_opr};
- for (auto i : dest_vars)
- opr_iter.add(i->owner_opr());
-
- auto to_json = [](const std::unordered_set<const char*>& v) {
- std::vector<std::string> vs(v.begin(), v.end());
- std::sort(vs.begin(), vs.end());
- auto ret = json::Array::make();
- for (auto&& i : vs)
- ret->add(json::String::make(i));
- return ret;
- };
-
- return json::Object::make({
- {"opr_types", to_json(opr_types)},
- {"dtypes", to_json(dtype_names)},
- {"elemwise_modes", to_json(elemwise_modes)},
- })
- ->to_string();
- });
-
- py::class_<_SerializationMetadata>(m, "SerializationMetadata")
- .def(py::init())
- .def_property(
- "user_info",
- [](const _SerializationMetadata& meta) {
- return py::bytes(meta.get_user_info());
- },
- &_SerializationMetadata::set_user_info)
- .def_readonly(
- "optimized_for_inference",
- &_SerializationMetadata::optimized_for_inference)
- .def_property(
- "optimize_options", &_SerializationMetadata::get_optimize_options,
- &_SerializationMetadata::set_optimize_options)
- .def_readwrite("graph_modified", &_SerializationMetadata::graph_modified)
- .def_readwrite("is_valid", &_SerializationMetadata::is_valid);
-
- m.def("dump_graph",
- [](const std::vector<VarNode*>& dest_vars, int keep_var_name,
- bool keep_opr_name, bool keep_param_name, bool keep_opr_priority,
- bool no_change_graph, std::optional<_SerializationMetadata> metadata,
- std::optional<_SerializationFormat> dump_format,
- std::optional<int> model_version, py::list& stat, py::list& inputs,
- py::list& outputs, py::list& params) {
- std::vector<uint8_t> buf;
- ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2;
- int version = 2;
- if (dump_format.has_value()) {
- format = dump_format.value();
- }
- if (model_version.has_value()) {
- version = model_version.value();
- }
- auto dumper = ser::GraphDumper::make(
- ser::OutputFile::make_vector_proxy(&buf), format, version);
- SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
-
- ser::GraphDumper::DumpConfig config{
- keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name};
- config.no_change_graph = no_change_graph;
-
- ser::GraphDumper::DumpResult rst;
- if (metadata)
- rst = dumper->dump(symvars, config, *metadata);
- else
- rst = dumper->dump(symvars, config);
-
- for (auto i : rst.inputs) {
- inputs.append(py::cast(i));
- }
- for (auto i : rst.outputs) {
- outputs.append(py::cast(i));
- }
- for (auto i : rst.params) {
- params.append(py::cast(i));
- }
- auto rst_stat = std::vector{
- rst.nr_opr, rst.tot_bytes, rst.tensor_value_bytes,
- static_cast<size_t>(rst.content_hash)};
- for (auto i : rst_stat) {
- stat.append(py::cast(i));
- }
- return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
- });
-
- m.def("load_graph",
- [](std::string& buf, py::list& output_var_map, py::list& output_var_list) {
- auto file = ser::InputFile::make_mem_proxy(buf.c_str(), buf.length());
- auto format = ser::GraphLoader::identify_graph_dump_format(*file);
- auto loader = ser::GraphLoader::make(std::move(file), format.val());
- ser::GraphLoader::LoadConfig config;
- auto rst = loader->load(config);
- for (auto i : rst.output_var_map) {
- output_var_map.append(py::make_tuple(i.first, i.second.node()));
- }
- for (auto i : rst.output_var_list) {
- output_var_list.append(i.node());
- }
- std::unordered_map<HostTensorND*, const std::string*> tensor2name;
- for (const auto& pair : rst.tensor_map) {
- tensor2name[pair.second.get()] = &pair.first;
- }
- auto cb = [&tensor2name, graph = rst.graph](cg::OperatorNodeBase* opr) {
- if (!opr->same_type<opr::Host2DeviceCopy>())
- return;
- auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
- auto it = tensor2name.find(h2d.host_data().get());
- mgb_throw_if(
- it == tensor2name.end(), GraphError,
- "unbound Host2DeviceCopy in loaded graph");
- h2d.output(0)->name(*it->second);
- };
- cg::DepOprIter iter{cb};
- for (const auto& var : rst.output_var_list) {
- iter.add(var);
- }
- auto ret = py::tuple(2);
- ret[0] = py::cast(rst.graph);
- ret[1] = py::cast(rst.metadata);
- return ret;
- });
-
- #define CURRENT_CLASS cg::ComputingGraph::Options
-
- // clang-format off
- auto PyComputingGraphOptions =
- py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
- // DEF_READWRITE(opr_attribute)
- DEF_READWRITE(seq_opt)
- DEF_READWRITE(graph_opt)
- DEF_READWRITE(graph_opt_level)
- DEF_READWRITE(log_level)
- DEF_READWRITE(async_exec_level)
- DEF_READWRITE(force_dynamic_alloc)
- DEF_READWRITE(var_sanity_check_first_run)
- DEF_READWRITE(allocate_static_mem_after_graph_compile)
- DEF_READWRITE(fake_next_exec)
- DEF_READWRITE(enable_sublinear_memory_opt)
- DEF_READWRITE(enable_dtr_memory_opt)
- DEF_READWRITE(no_profiling_on_shape_change)
- DEF_READWRITE(enable_var_mem_defragment)
- DEF_READWRITE(enable_grad_var_static_reshape)
- DEF_READWRITE(enable_memory_swap)
- DEF_READWRITE(comp_node_seq_record_level)
- DEF_READWRITE(no_force_inplace)
- DEF_READWRITE(sublinear_mem_config)
- DEF_READWRITE(dtr_config)
- // DEF_READWRITE(eager_evaluation)
- // DEF_READWRITE(imperative_proxy_graph)
- // DEF_READWRITE(extra_vardeps)
- // DEF_READWRITE(user_data)
- ;
- // clang-format on
-
- #undef CURRENT_CLASS
- #define CURRENT_CLASS cg::ComputingGraph::Options::SeqOpt
-
- py::class_<cg::ComputingGraph::Options::SeqOpt>(PyComputingGraphOptions, "SeqOpt")
- DEF_READWRITE(enable_mem_plan_opt) DEF_READWRITE(enable_mem_reuse_alloc)
- DEF_READWRITE(enable_seq_comp_node_opt);
-
- #undef CURRENT_CLASS
- #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
-
- auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
- PyComputingGraphOptions, "GraphOpt") DEF_READWRITE(jit)
- DEF_READWRITE(jit_config)
- DEF_READWRITE(tensorrt);
-
- #undef CURRENT_CLASS
- #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
-
- py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(
- PyGraphOpt, "JITConfig") DEF_READWRITE(fuse_dimshuffle)
- DEF_READWRITE(fuse_reduce);
-
- #undef CURRENT_CLASS
- #define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
-
- py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(
- PyComputingGraphOptions, "SublinearMemConfig") DEF_READWRITE(thresh_nr_try)
- DEF_READWRITE(genetic_nr_iter) DEF_READWRITE(genetic_pool_size)
- DEF_READWRITE(lb_memory_mb) DEF_READWRITE(num_worker);
-
- #undef CURRENT_CLASS
-
- #define CURRENT_CLASS cg::ComputingGraph::Options::DTRConfig
-
- py::class_<cg::ComputingGraph::Options::DTRConfig>(
- PyComputingGraphOptions, "DTRConfig") DEF_READWRITE(eviction_threshold)
- DEF_READWRITE(evictee_minimum_size) DEF_READWRITE(recomp_memory_factor)
- DEF_READWRITE(recomp_time_factor);
-
- #undef CURRENT_CLASS
- auto common = rel_import("common", m, 1);
-
- common.def(
- "invoke_op",
- [](const OpDef& def, const std::vector<cg::VarNode*> inputs,
- cg::ComputingGraph* graph) {
- cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
- return to_tuple(OpDef::apply_on_var_node(def, vinputs));
- },
- py::arg(), py::arg(), py::arg("graph") = py::none());
-
- auto input_callback = [](auto callback, const CompNode& comp_node,
- const DType& dtype, const TensorShape& shape,
- const std::vector<cg::VarNode*>& inputs,
- cg::ComputingGraph* graph, bool use_static_shape) {
- if (!graph) {
- graph = inputs[0]->owner_graph();
- }
- SymbolVarArray sinputs;
- for (auto i : inputs) {
- sinputs.emplace_back(i);
- }
- static_assert(!std::is_reference<decltype(callback)>::value);
- auto soutputs = opr::InputCallback::make(
- *graph, std::move(callback), comp_node, dtype, shape, sinputs,
- use_static_shape);
- std::vector<VarNode*> outputs;
- outputs.reserve(soutputs.size());
- for (auto i : soutputs) {
- outputs.push_back(i.node());
- }
- return outputs;
- };
-
- m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
- return opr::SharedDeviceTensor::make(
- *graph, std::make_shared<DeviceTensorND>(data))
- .node();
- });
-
- m.def(
- "make_const",
- [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype,
- std::optional<std::string> name) {
- if (!cn.valid()) {
- cn = CompNode::load(get_default_device());
- }
- OperatorNodeConfig config(cn);
- if (name) {
- config.name(*name);
- }
- auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
- return opr::ImmutableTensor::make(*graph, hv, config).node();
- },
- py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());
-
- m.def(
- "make_h2d",
- [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape,
- std::optional<std::string> name) {
- if (!cn.valid()) {
- throw py::type_error("device must be valid");
- }
- if (!dtype.valid()) {
- throw py::type_error("dtype must be valid");
- }
- OperatorNodeConfig config;
- if (name) {
- config.name(*name);
- }
- return opr::Host2DeviceCopy::make(
- graph, std::make_shared<HostTensorND>(cn, shape, dtype),
- config)
- .node();
- },
- py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
- py::arg() = py::none());
-
- m.def("_replace_vars", &_replace_vars, py::arg(), py::arg(), py::arg());
- m.def("_replace_oprs", &_replace_oprs, py::arg(), py::arg(), py::arg());
- m.def("_set_priority_to_id", &_set_priority_to_id, py::arg());
-
- m.def(
- "input_callback",
- [input_callback](
- std::function<DeviceTensorND(void)> callback,
- const CompNode& comp_node, const DType& dtype,
- const TensorShape& shape, const std::vector<cg::VarNode*>& inputs,
- cg::ComputingGraph* graph, bool use_static_shape) {
- return input_callback(
- [f = std::move(callback)]() {
- py::gil_scoped_acquire _;
- return f();
- },
- comp_node, dtype, shape, inputs, graph, use_static_shape);
- },
- py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
- py::arg() = py::tuple(), py::arg("graph") = py::none(),
- py::arg("use_static_shape") = false);
-
- m.def(
- "input_callback",
- [input_callback](
- std::shared_ptr<Rendezvous<DeviceTensorND>> p,
- const CompNode& comp_node, const DType& dtype,
- const TensorShape& shape, const std::vector<cg::VarNode*>& inputs,
- cg::ComputingGraph* graph, bool use_static_shape) {
- auto f = [p]() -> DeviceTensorND { return p->get(); };
- return input_callback(
- std::move(f), comp_node, dtype, shape, inputs, graph,
- use_static_shape);
- },
- py::arg(), py::arg(), py::arg(), py::arg() = py::none(),
- py::arg() = py::tuple(), py::arg("graph") = py::none(),
- py::arg("use_static_shape") = false);
-
- auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
- std::shared_ptr<RendezvousBase> r = {},
- bool borrow = false, bool prefer_host_value = false) {
- if (r) {
- mgb_assert(inputs.size());
- auto cg = inputs[0]->owner_graph();
- cg->options()
- .user_data.get_user_data_or_create<WeakRendezvousArray>()
- ->emplace_back(r);
- }
- SymbolVarArray sinputs;
- for (auto i : inputs) {
- sinputs.emplace_back(i);
- }
- static_assert(!std::is_reference<decltype(callback)>::value);
- opr::OutputCallback::Param param{
- std::move(callback), borrow, prefer_host_value};
- auto output = opr::OutputCallback::make(std::move(param), sinputs);
- return output.node();
- };
-
- m.def("output_callback", [output_callback](
- std::function<void(DeviceTensorND)> callback,
- std::vector<cg::VarNode*> inputs) {
- auto f = [f = std::move(callback)](DeviceTensorND dv) {
- auto task = [f = std::move(f), dv = std::move(dv)]() { f(dv); };
- py_task_q.add_task(std::move(task));
- };
- return output_callback(std::move(f), std::move(inputs));
- });
-
- m.def("output_callback", [output_callback](
- std::shared_ptr<Rendezvous<DeviceTensorND>> p,
- std::vector<cg::VarNode*> inputs) {
- auto f = [p](DeviceTensorND dv) { p->set(std::move(dv)); };
- return output_callback(std::move(f), std::move(inputs), p);
- });
-
- m.def("value_output_callback",
- [output_callback](
- std::shared_ptr<Rendezvous<HostNDWithEvent>> p,
- std::vector<cg::VarNode*> inputs) {
- auto f = [p](DeviceTensorND dv) {
- HostNDWithEvent hv_with_event;
- hv_with_event.first.copy_from(dv);
- hv_with_event.second = dv.comp_node().create_event();
- hv_with_event.second->record();
- p->set(std::move(hv_with_event));
- };
- return output_callback(std::move(f), std::move(inputs), p, true, true);
- });
-
- m.def("attr_output_callback", [output_callback](
- std::shared_ptr<Rendezvous<TensorAttr>> p,
- std::vector<cg::VarNode*> inputs) {
- auto f = [p](DeviceTensorND dv) {
- p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
- };
- return output_callback(std::move(f), std::move(inputs), p, true);
- });
-
- m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) {
- auto&& graph = inputs[0]->owner_graph();
- VarNodeArray inps(inputs.begin(), inputs.end());
- cg::OperatorNodeConfig config;
- if (device.length() > 0) {
- config.comp_node(CompNode::load(device));
- }
- cg::OperatorNodeBase* opr =
- graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(inps, config));
- return opr;
- });
- }
|