/**
 * \file src/opr-mm/test/mock_client.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain/opr/group_manager.h"

namespace mgb {
namespace test {

/*!
 * \brief opr::GroupClient implementation used by tests
 *
 * Every overridden method hands its arguments, unchanged and in the same
 * order, to the opr::GroupManager stored as a member of this object and
 * returns whatever that manager returns.
 *
 * The address given to the constructor is only stored and handed back by
 * get_addr(); no other method of this class reads it.
 */
class MockGroupClient final : public opr::GroupClient {
public:
    using RegisterInfo = opr::GroupManager::RegisterInfo;

    //! \param server_addr string later reported by get_addr()
    MockGroupClient(const std::string& server_addr = "mock_addr")
            : m_addr{server_addr} {}

    ~MockGroupClient() override = default;

    //! the address string supplied at construction
    const std::string& get_addr() const {
        return m_addr;
    }

    //! forward to GroupManager::opr_register() and return its result
    RegisterInfo opr_register(const std::string& key, size_t nr_devices,
                              bool is_root, int rank,
                              uint64_t comp_node_hash) override {
        return m_mgr.opr_register(key, nr_devices, is_root, rank,
                                  comp_node_hash);
    }

    /*!
     * \brief forward to GroupManager::bcast_addr()
     *
     * \p master_ip and \p port are passed on by non-const reference, so the
     * manager may modify them.
     * NOTE(review): which ranks read vs. write them is decided inside
     * GroupManager and is not visible here -- confirm there.
     */
    void bcast_addr(std::string& master_ip, int& port, const std::string& key,
                    uint32_t size, uint32_t rank, uint32_t root) override {
        m_mgr.bcast_addr(master_ip, port, key, size, rank, root);
    }

    //! forward to GroupManager::set_output_shape()
    void set_output_shape(const std::string& key,
                          const TensorShape& shape) override {
        m_mgr.set_output_shape(key, shape);
    }

    //! forward to GroupManager::get_output_shape() and return its result
    TensorShape get_output_shape(const std::string& key) override {
        return m_mgr.get_output_shape(key);
    }

    //! forward to GroupManager::group_barrier() and return its result
    uint32_t group_barrier(uint32_t size, uint32_t rank) override {
        return m_mgr.group_barrier(size, rank);
    }

private:
    //! value returned by get_addr(); fixed at construction
    const std::string m_addr;
    //! manager that actually services every forwarded call
    opr::GroupManager m_mgr;
};

}  // namespace test
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}