From b1ab3646f55e87d8eca622a95239c538ac16151c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 31 Aug 2020 17:10:12 +0800
Subject: [PATCH] feat(imperative): add tensor sanity check

GitOrigin-RevId: 27243978e38416e4bea79ee875534e44ca83217a
---
 .../python/megengine/utils/tensor_sanity_check.py  |  39 ++++++
 imperative/python/src/utils.cpp                    |  15 +++
 imperative/src/impl/ops/cond_take.cpp              |  13 ++
 imperative/src/impl/tensor_sanity_check.cpp        | 160 +++++++++++++++++++++
 .../megbrain/imperative/tensor_sanity_check.h      |  66 +++++++++
 5 files changed, 293 insertions(+)
 create mode 100644 imperative/python/megengine/utils/tensor_sanity_check.py
 create mode 100644 imperative/src/impl/tensor_sanity_check.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/tensor_sanity_check.h

diff --git a/imperative/python/megengine/utils/tensor_sanity_check.py b/imperative/python/megengine/utils/tensor_sanity_check.py
new file mode 100644
index 00000000..b77bdd45
--- /dev/null
+++ b/imperative/python/megengine/utils/tensor_sanity_check.py
@@ -0,0 +1,39 @@
+from ..core._imperative_rt import TensorSanityCheckImpl
+from ..core._imperative_rt.imperative import sync
+
+
+class TensorSanityCheck:
+    r"""An object that checks whether the input tensors of each operator have changed before and after the operation.
+
+    Examples:
+
+    .. testcode::
+
+        from megengine import tensor
+        from megengine.utils.tensor_sanity_check import TensorSanityCheck
+        with TensorSanityCheck() as checker:
+            a = tensor([1, 2])
+            b = tensor([3, 4])
+            c = a + b
+            print(c.numpy())
+
+    .. testoutput::
+
+        [4 6]
+    """
+
+    def __init__(self):
+        # native checker exposed by the _imperative_rt extension module
+        self.impl = TensorSanityCheckImpl()
+
+    def __enter__(self):
+        sync()
+        self.impl.enable()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # remove the hooks even if sync() raises while leaving the block
+        try:
+            sync()
+        finally:
+            self.impl.disable()
diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp
index e61aa061..0cda89c3 100644
--- a/imperative/python/src/utils.cpp
+++ b/imperative/python/src/utils.cpp
@@ -23,6 +23,7 @@
 #include "megbrain/comp_node.h"
 #include "megbrain/imperative/blob_manager.h"
 #include "megbrain/imperative/profiler.h"
+#include "megbrain/imperative/tensor_sanity_check.h"
 #include "megbrain/serialization/helper.h"
 
 #if MGB_ENABLE_OPR_MM
@@ -225,6 +226,20 @@ void init_utils(py::module m) {
         },
         py::arg("path") = std::optional<std::string>());
 
+    // expose TensorSanityCheck to Python under the name TensorSanityCheckImpl
+    using mgb::imperative::TensorSanityCheck;
+    py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
+        .def(py::init<>())
+        .def("enable",
+            [](TensorSanityCheck& checker) -> TensorSanityCheck& {
+                checker.enable();
+                return checker;
+            })
+        .def("disable",
+            [](TensorSanityCheck& checker) {
+                checker.disable();
+            });
+
 #if MGB_ENABLE_OPR_MM
     m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
           py::arg("port") = 0);
diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp
index 5418bc23..74b3fc51 100644
--- a/imperative/src/impl/ops/cond_take.cpp
+++ b/imperative/src/impl/ops/cond_take.cpp
@@ -110,9 +110,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     return out;
 }
 
+// Output descriptors for CondTake: two entries, each with an empty layout
+// and the comp node of the first input.
+SmallVector<LogicalTensorDesc> infer_output_attrs(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    SmallVector<LogicalTensorDesc> out;
+    for (size_t i = 0; i < 2; ++ i) {
+        out.push_back({TensorLayout(), inputs[0]->comp_node()});
+    }
+    return out;
+}
+
 OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
     .apply_on_var_node(apply_on_var_node)
     .apply_on_physical_tensor(apply_on_physical_tensor)
+    .infer_output_attrs(infer_output_attrs)
     .fallback();
 
 } // namespace
diff --git a/imperative/src/impl/tensor_sanity_check.cpp b/imperative/src/impl/tensor_sanity_check.cpp
new file mode 100644
index 00000000..4d998da2
--- /dev/null
+++ b/imperative/src/impl/tensor_sanity_check.cpp
@@ -0,0 +1,160 @@
+/**
+ * \file imperative/src/impl/tensor_sanity_check.cpp
+ *
+ * This file is part of MegBrain, a deep learning framework developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ *
+ */
+
+#include "megbrain/imperative/tensor_sanity_check.h"
+
+#include "./op_trait.h"
+
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+namespace mgb {
+
+namespace imperative {
+
+/*!
+ * \brief checksum of the bytes covered by the device tensor of \p ptr
+ *
+ * A tensor without elements yields a function-local static ChecksumResult.
+ */
+TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
+    auto&& dt = ptr->dev_tensor();
+    if (!dt.layout().total_nr_elems()) {
+        static ChecksumResult empty_checksum;
+        return empty_checksum;
+    }
+
+    // describe the whole span of the tensor as a flat byte array
+    auto span = dt.layout().span();
+    megdnn::TensorND tensor;
+    tensor.raw_ptr = dt.raw_ptr() + span.low_byte;
+    tensor.layout.init_contiguous_stride({span.dist_byte()});
+    tensor.layout.dtype = dtype::Byte();
+
+    // workspace storage is kept per (calling thread, comp node); only the
+    // lookup in the cache is done under the mutex
+    DeviceTensorStorage* workspace;
+    {
+        MGB_LOCK_GUARD(m_workspace_mtx);
+        workspace = &m_workspace[std::this_thread::get_id()]
+                             .storage[ptr->comp_node()];
+    }
+    auto comp_node = ptr->comp_node();
+    comp_node.activate();
+    auto opr = opr::intl::get_megdnn_global_opr<megdnn::Checksum>(comp_node);
+    auto workspace_reqsize = opr->get_workspace_in_bytes(tensor.layout);
+    workspace->comp_node(ptr->comp_node()).ensure_size(workspace_reqsize);
+
+    megdnn::Workspace mwk;
+    if (workspace_reqsize)
+        mwk = {workspace->ptr(), workspace_reqsize};
+
+    return opr->exec(tensor, mwk);
+}
+
+//! hooks installed by TensorSanityCheck and the checksums recorded so far
+class TensorSanityCheckImpl {
+public:
+    //! type of the per-trait callback that gets wrapped
+    using ApplyFn = decltype(OpTrait::apply_on_physical_tensor);
+    //! (hooked trait, its original callback), used by disable() to restore
+    std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyFn>>> hook_list;
+    std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult>
+            tensor2chksum; // TODO: may increase device memory overhead
+    TensorSanityCheckImpl() {
+        m_calc = std::make_unique<TensorChecksumCalc>();
+    }
+    bool check(const TensorPtr& p);
+private:
+    std::unique_ptr<TensorChecksumCalc> m_calc;
+};
+
+/*!
+ * \brief compare the current checksum of \p p with the recorded one
+ * \return true if \p p was not recorded before (its checksum is recorded
+ *      now) or if the checksum is unchanged; false otherwise
+ */
+bool TensorSanityCheckImpl::check(const TensorPtr& p) {
+    auto chksum = m_calc->calc(p);
+    // emplace() does not overwrite an existing record
+    auto ins = tensor2chksum.emplace(p, chksum);
+    return ins.second || ins.first->second == chksum;
+}
+
+/*!
+ * \brief wrap apply_on_physical_tensor of every OpTrait with checksum checks
+ *
+ * Does nothing when the hooks are already installed: hooking twice would wrap
+ * the hook itself and leave it calling a released backup after disable().
+ */
+void TensorSanityCheck::enable() {
+    if (!m_checker->hook_list.empty()) {
+        return;
+    }
+    CompNode::sync_all();
+    OpTrait::for_each_trait([this](OpTrait& trait) {
+        auto backup = std::make_unique<TensorSanityCheckImpl::ApplyFn>(
+            std::move(trait.apply_on_physical_tensor));
+        // the hook only borrows the backup; hook_list keeps it alive
+        trait.apply_on_physical_tensor = [this, backup = backup.get()] (
+                const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+            for (auto&& i: inputs) {
+                if (!m_checker->check(i)) {
+                    mgb_throw(TensorChecksumCalc::Error,
+                            "tensor modified before exec %s", print_op(def).c_str());
+                }
+            }
+            auto output = (*backup)(def, inputs);
+            for (auto&& i: output) {
+                mgb_assert(m_checker->check(i));
+            }
+            for (auto&& i: inputs) {
+                if (!m_checker->check(i)) {
+                    mgb_throw(TensorChecksumCalc::Error,
+                            "tensor modified after exec %s", print_op(def).c_str());
+                }
+            }
+            return output;
+        };
+        m_checker->hook_list.emplace_back(&trait, std::move(backup));
+    });
+}
+
+//! restore the original callbacks and forget all recorded checksums
+void TensorSanityCheck::disable() {
+    for (auto&& hook : m_checker->hook_list) {
+        std::get<0>(hook)->apply_on_physical_tensor =
+                std::move(*std::get<1>(hook));
+    }
+    m_checker->tensor2chksum.clear();
+    m_checker->hook_list.clear();
+}
+
+TensorSanityCheck::TensorSanityCheck() {
+    m_checker = std::make_unique<TensorSanityCheckImpl>();
+}
+
+TensorSanityCheck::~TensorSanityCheck() {
+    // installed hooks capture `this` and pointers owned by m_checker, so they
+    // must be removed before the object goes away
+    disable();
+}
+
+//! name of \p def used in error messages
+std::string TensorSanityCheck::print_op(const OpDef& def){
+    auto* opr_attr = def.try_cast_final<const OprAttr>();
+    if(opr_attr){
+        return std::string("OprAttr:") + opr_attr->type;
+    }
+    return def.dyn_typeinfo()->name;
+}
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/tensor_sanity_check.h b/imperative/src/include/megbrain/imperative/tensor_sanity_check.h
new file mode 100644
index 00000000..b6e849d3
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/tensor_sanity_check.h
@@ -0,0 +1,66 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/tensor_sanity_check.h
+ *
+ * This file is part of MegBrain, a deep learning framework developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ *
+ */
+
+#pragma once
+
+#include "megbrain/comp_node_env.h"
+#include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/plugin/var_sanity_check.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/general.h"
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+
+namespace mgb {
+
+namespace imperative {
+
+//! computes checksums of tensors, caching the workspace it needs
+class TensorChecksumCalc {
+public:
+    using ChecksumResult = megdnn::opr_result::Checksum;
+    using Error = VarSanityCheckError;
+    struct WorkspaceCache {
+        //! var comp node to workspace
+        CompNode::UnorderedMap<DeviceTensorStorage> storage;
+    };
+    //! one cache per calling thread; accessed under m_workspace_mtx in calc()
+    ThinHashMap<std::thread::id, WorkspaceCache> m_workspace;
+    std::mutex m_workspace_mtx;
+    ChecksumResult calc(TensorPtr ptr);
+    TensorChecksumCalc() {}
+};
+
+class TensorSanityCheckImpl;
+
+/*!
+ * \brief checks that operators do not modify their input tensors
+ *
+ * While enabled, the checksum of every input is verified before and after
+ * each apply_on_physical_tensor call; a mismatch raises
+ * TensorChecksumCalc::Error.
+ */
+class TensorSanityCheck {
+public:
+    TensorSanityCheck();
+    //! also calls disable(), so an enabled checker can be destroyed safely
+    ~TensorSanityCheck();
+    void enable();
+    void disable();
+    std::string print_op(const OpDef& def);
+private:
+    std::unique_ptr<TensorSanityCheckImpl> m_checker;
+};
+
+} // namespace imperative
+} // namespace mgb