# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import pickle from ..utils.max_recursion_limit import max_recursion_limit def save(obj, f, pickle_module=pickle, pickle_protocol=pickle.HIGHEST_PROTOCOL): r"""Save an object to disk file. :type obj: object :param obj: object to save. Only ``module`` or ``state_dict`` are allowed. :type f: text file object :param f: a string of file name or a text file object to which ``obj`` is saved to. :type pickle_module: :param pickle_module: Default: ``pickle``. :type pickle_protocol: :param pickle_protocol: Default: ``pickle.HIGHEST_PROTOCOL``. """ if isinstance(f, str): with open(f, "wb") as fout: save( obj, fout, pickle_module=pickle_module, pickle_protocol=pickle_protocol ) return with max_recursion_limit(): assert hasattr(f, "write"), "{} does not support write".format(f) pickle_module.dump(obj, f, pickle_protocol) def load(f, pickle_module=pickle): r"""Load an object saved with save() from a file. :type f: text file object :param f: a string of file name or a text file object from which to load. :type pickle_module: :param pickle_module: Default: ``pickle``. """ if isinstance(f, str): with open(f, "rb") as fin: return load(fin, pickle_module=pickle_module) return pickle_module.load(f)