# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
    deepcopy,
)

from deepmd.pt.model.descriptor import (
    BaseDescriptor,
)
from deepmd.pt.model.task import (
    BaseFitting,
)
def preprocess_shared_params(model_config):
"""Preprocess the model params for multitask model, and generate the links dict for further sharing.
Args:
model_config: Model params of multitask model.
Returns
-------
model_config: Preprocessed model params of multitask model.
Those string names are replaced with real params in `shared_dict` of model params.
shared_links: Dict of link infos for further sharing.
Each item, whose key must be in `shared_dict`, is a dict with following keys:
- "type": The real class type of this item.
- "links": List of shared settings, each sub-item is a dict with following keys:
- "model_key": Model key in the `model_dict` to share this item.
- "shared_type": Type of this shard item.
- "shared_level": Shared level (int) of this item in this model.
Lower for more params to share, 0 means to share all params in this item.
This list are sorted by "shared_level".
    For example, if one has `model_config` like this:
    "model": {
        "shared_dict": {
            "my_type_map": ["foo", "bar"],
            "my_des1": {
                "type": "se_e2_a",
                "neuron": [10, 20, 40]
            },
        },
        "model_dict": {
            "model_1": {
                "type_map": "my_type_map",
                "descriptor": "my_des1",
                "fitting_net": {
                    "neuron": [100, 100, 100]
                }
            },
            "model_2": {
                "type_map": "my_type_map",
                "descriptor": "my_des1",
                "fitting_net": {
                    "neuron": [100, 100, 100]
                }
            },
            "model_3": {
                "type_map": "my_type_map",
                "descriptor": "my_des1:1",
                "fitting_net": {
                    "neuron": [100, 100, 100]
                }
            }
        }
    }
    The above config will init three model branches named `model_1`, `model_2` and `model_3`,
    in which:
    - `model_2` and `model_3` will have the same `type_map` as that in `model_1`.
    - `model_2` will share all the parameters of `descriptor` with `model_1`,
      while `model_3` will share part of the parameters of `descriptor` with `model_1`
      on the user-defined share level `1` (the default is `0`, meaning all the parameters are shared).
    - `model_1`, `model_2` and `model_3` have three different `fitting_net`s.
    The returned `model_config` automatically fills in the input `model_config` as if there were no sharing,
    and `shared_links` keeps all the sharing information, looking like:
    {
        'my_des1': {
            'type': 'DescrptSeA',
            'links': [
                {'model_key': 'model_1',
                 'shared_type': 'descriptor',
                 'shared_level': 0},
                {'model_key': 'model_2',
                 'shared_type': 'descriptor',
                 'shared_level': 0},
                {'model_key': 'model_3',
                 'shared_type': 'descriptor',
                 'shared_level': 1}
            ]
        }
    }
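
    Examples
    --------
    A minimal sketch of a call (the descriptor and fitting parameters below are
    illustrative placeholders, not a complete model configuration):

    >>> config = {
    ...     "shared_dict": {
    ...         "my_type_map": ["foo", "bar"],
    ...         "my_des1": {"type": "se_e2_a", "neuron": [10, 20, 40]},
    ...     },
    ...     "model_dict": {
    ...         "model_1": {
    ...             "type_map": "my_type_map",
    ...             "descriptor": "my_des1",
    ...             "fitting_net": {"neuron": [100, 100, 100]},
    ...         },
    ...         "model_2": {
    ...             "type_map": "my_type_map",
    ...             "descriptor": "my_des1:1",
    ...             "fitting_net": {"neuron": [100, 100, 100]},
    ...         },
    ...     },
    ... }
    >>> config, shared_links = preprocess_shared_params(config)
    >>> config["model_dict"]["model_1"]["descriptor"]["neuron"]
    [10, 20, 40]
    >>> [link["model_key"] for link in shared_links["my_des1"]["links"]]
    ['model_1', 'model_2']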
"""
assert "model_dict" in model_config, "only multi-task model can use this method!"
supported_types = ["type_map", "descriptor", "fitting_net"]
shared_dict = model_config.get("shared_dict", {})
shared_links = {}
type_map_keys = []
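
    # Resolve one string reference (e.g. "my_des1" or "my_des1:1") into the real
    # params from `shared_dict`; record a link entry for every shared item except
    # `type_map`, which is only tracked to check that all branches use the same one.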
    def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None):
        shared_type = key_type
        shared_key = key_in_dict
        shared_level = 0
        if ":" in key_in_dict:
            shared_key = key_in_dict.split(":")[0]
            shared_level = int(key_in_dict.split(":")[1])
        assert (
            shared_key in shared_dict
        ), f"The appointed {shared_type} {shared_key} is not in the shared_dict! Please check the input params."
        if index is None:
            params_dict[shared_type] = deepcopy(shared_dict[shared_key])
        else:
            params_dict[index] = deepcopy(shared_dict[shared_key])
        if shared_type == "type_map":
            if key_in_dict not in type_map_keys:
                type_map_keys.append(key_in_dict)
        else:
            if shared_key not in shared_links:
                class_name = get_class_name(shared_type, shared_dict[shared_key])
                shared_links[shared_key] = {"type": class_name, "links": []}
            link_item = {
                "model_key": model_key,
                "shared_type": shared_type + suffix,
                "shared_level": shared_level,
            }
            shared_links[shared_key]["links"].append(link_item)
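
    # Walk every model branch and replace string references among the supported
    # item types ("type_map", "descriptor", "fitting_net") with deep copies of
    # the corresponding entries from `shared_dict`.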
    for model_key in model_config["model_dict"]:
        model_params_item = model_config["model_dict"][model_key]
        for item_key in model_params_item:
            if item_key in supported_types:
                item_params = model_params_item[item_key]
                if isinstance(item_params, str):
                    replace_one_item(model_params_item, item_key, item_params)
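                # A hybrid item keeps its sub-items in a "list"; each entry may itself
                # be a string reference into `shared_dict` and is replaced in place,
                # with the shared type recorded as, e.g., "descriptor_hybrid_0".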
                elif item_params.get("type", "") == "hybrid":
                    for ii, hybrid_item in enumerate(item_params["list"]):
                        if isinstance(hybrid_item, str):
                            replace_one_item(
                                model_params_item[item_key]["list"],
                                item_key,
                                hybrid_item,
                                suffix=f"_hybrid_{ii}",
                                index=ii,
                            )
    for shared_key in shared_links:
        shared_links[shared_key]["links"] = sorted(
            shared_links[shared_key]["links"],
            key=lambda x: x["shared_level"]
            - ("spin" in model_config["model_dict"][x["model_key"]]) * 100,
        )
        # Small trick to put spin models in front as the base models,
        # because their type embeddings are more general.
    assert len(type_map_keys) == 1, "A multitask model must have only one type_map!"
    return model_config, shared_links
def get_class_name(item_key, item_params):
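    """Look up the descriptor or fitting class referred to by `item_params`.

    The returned class is stored as the "type" tag of an entry in `shared_links`;
    `item_params` should carry a "type" key naming a registered implementation,
    otherwise the default ("se_e2_a" for descriptors, "ener" for fitting nets) is used.
    """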
    if item_key == "descriptor":
        return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a"))
    elif item_key == "fitting_net":
        return BaseFitting.get_class_by_type(item_params.get("type", "ener"))
    else:
        raise RuntimeError(f"Unknown class_name type {item_key}")