# Source code for deepmd.pt.model.model.dp_model

# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.pt.model.descriptor.base_descriptor import (
    BaseDescriptor,
)


class DPModelCommon:
    """A base class to implement common methods for all the Models.

    The accessor methods read ``self.atomic_model``; this class does not set
    that attribute itself, so it is presumably provided by the concrete model
    class that uses these methods (confirm against the subclasses).
    """

    @classmethod
    def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
        """Update the selection and perform neighbor statistics.

        The work is delegated to ``BaseDescriptor.update_sel``; only the
        ``"descriptor"`` entry of the local data is replaced.

        Parameters
        ----------
        global_jdata : dict
            The global data, containing the training section
        local_jdata : dict
            The local data refer to the current class

        Returns
        -------
        dict
            A shallow copy of ``local_jdata`` whose ``"descriptor"`` entry is
            the value returned by ``BaseDescriptor.update_sel``.
        """
        # Shallow copy: the caller's top-level mapping is not rebound, and
        # only the "descriptor" key is replaced in the returned copy.
        local_jdata_cpy = local_jdata.copy()
        local_jdata_cpy["descriptor"] = BaseDescriptor.update_sel(
            global_jdata, local_jdata["descriptor"]
        )
        return local_jdata_cpy

    def get_fitting_net(self):
        """Get the fitting network.

        Returns
        -------
        object
            The ``fitting_net`` attribute of ``self.atomic_model``.
        """
        return self.atomic_model.fitting_net

    def get_descriptor(self):
        """Get the descriptor.

        Returns
        -------
        object
            The ``descriptor`` attribute of ``self.atomic_model``.
        """
        return self.atomic_model.descriptor