Source code for deepmd.entrypoints.convert

# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.utils.convert import (
    convert_10_to_21,
    convert_012_to_21,
    convert_12_to_21,
    convert_13_to_21,
    convert_20_to_21,
    convert_pb_to_pbtxt,
    convert_pbtxt_to_pb,
    convert_to_21,
)


def convert(
    *,
    FROM: str,
    input_model: str,
    output_model: str,
    **kwargs,
):
    """Dispatch a model conversion based on the output suffix and ``FROM``.

    If ``output_model`` ends with ``.pbtxt``, the input is handed to
    ``convert_pb_to_pbtxt`` and ``FROM`` is not consulted. Otherwise
    ``FROM`` selects which converter is called with
    ``(input_model, output_model)``.

    Parameters
    ----------
    FROM : str
        Selector for the converter. Recognised values are ``"auto"``,
        ``"0.12"``, ``"1.0"``, ``"1.1"``, ``"1.2"``, ``"1.3"``, ``"2.0"``
        and ``"pbtxt"``.
    input_model : str
        Path of the model to convert.
    output_model : str
        Path of the converted model. A ``.pbtxt`` suffix switches to the
        pb -> pbtxt conversion.
    **kwargs
        Accepted but not used by this function.

    Raises
    ------
    RuntimeError
        If both ``input_model`` and ``output_model`` end with ``.pbtxt``,
        or if ``FROM`` is not one of the recognised values.
    """
    pbtxt_suffix = ".pbtxt"

    # A pbtxt target takes priority over whatever FROM says.
    if output_model.endswith(pbtxt_suffix):
        if input_model.endswith(pbtxt_suffix):
            raise RuntimeError("input model is already pbtxt")
        convert_pb_to_pbtxt(input_model, output_model)
        return

    # Pick the converter first, then invoke it once at the end.
    if FROM == "auto":
        converter = convert_to_21
    elif FROM == "0.12":
        converter = convert_012_to_21
    elif FROM == "1.0":
        converter = convert_10_to_21
    elif FROM == "1.1" or FROM == "1.2":
        # no difference between 1.1 and 1.2
        converter = convert_12_to_21
    elif FROM == "1.3":
        converter = convert_13_to_21
    elif FROM == "2.0":
        converter = convert_20_to_21
    elif FROM == "pbtxt":
        converter = convert_pbtxt_to_pb
    else:
        raise RuntimeError("unsupported model version " + FROM)

    converter(input_model, output_model)