Source code for deepmd.pt.utils.finetune

# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from copy import (
    deepcopy,
)

import torch

from deepmd.pt.utils import (
    env,
)

log = logging.getLogger(__name__)
def change_finetune_model_params_single(
    _single_param_target,
    _model_param_pretrained,
    from_multitask=False,
    model_branch="Default",
    model_branch_from="",
):
    single_config = deepcopy(_single_param_target)
    trainable_param = {
        "descriptor": True,
        "fitting_net": True,
    }
    for net_type in trainable_param:
        if net_type in single_config:
            trainable_param[net_type] = single_config[net_type].get("trainable", True)
    if not from_multitask:
        old_type_map, new_type_map = (
            _model_param_pretrained["type_map"],
            single_config["type_map"],
        )
        assert set(new_type_map).issubset(
            old_type_map
        ), "Only support for smaller type map when finetuning or resuming."
        single_config = deepcopy(_model_param_pretrained)
        log.info(
            f"Change the '{model_branch}' model configurations according to the pretrained one..."
        )
        single_config["new_type_map"] = new_type_map
    else:
        model_dict_params = _model_param_pretrained["model_dict"]
        new_fitting = False
        if model_branch_from == "":
            model_branch_chosen = next(iter(model_dict_params.keys()))
            new_fitting = True
            single_config["bias_adjust_mode"] = (
                "set-by-statistic"  # fitting net re-init
            )
            log.warning(
                "The fitting net will be re-initialized instead of using the one in the pretrained model! "
                "The bias_adjust_mode will be set-by-statistic!"
            )
        else:
            model_branch_chosen = model_branch_from
        assert model_branch_chosen in model_dict_params, (
            f"No model branch named '{model_branch_chosen}'! "
            f"Available ones are {list(model_dict_params.keys())}."
        )
        single_config_chosen = deepcopy(model_dict_params[model_branch_chosen])
        old_type_map, new_type_map = (
            single_config_chosen["type_map"],
            single_config["type_map"],
        )
        assert set(new_type_map).issubset(
            old_type_map
        ), "Only support for smaller type map when finetuning or resuming."
        for key_item in ["type_map", "descriptor"]:
            if key_item in single_config_chosen:
                single_config[key_item] = single_config_chosen[key_item]
        if not new_fitting:
            single_config["fitting_net"] = single_config_chosen["fitting_net"]
        log.info(
            f"Change the '{model_branch}' model configurations according to the model branch "
            f"'{model_branch_chosen}' in the pretrained one..."
        )
        single_config["new_type_map"] = new_type_map
        single_config["model_branch_chosen"] = model_branch_chosen
        single_config["new_fitting"] = new_fitting
    for net_type in trainable_param:
        if net_type in single_config:
            single_config[net_type]["trainable"] = trainable_param[net_type]
        else:
            single_config[net_type] = {"trainable": trainable_param[net_type]}
    return single_config
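As a quick illustration (not part of the module), the sketch below shows what change_finetune_model_params_single does for single-task fine-tuning from a single-task pretrained model; the two dictionaries are hypothetical, stripped-down stand-ins for real model parameters:

# Illustrative sketch only; these configs are hypothetical minimal examples.
pretrained_params = {
    "type_map": ["O", "H", "C"],
    "descriptor": {"type": "se_e2_a", "sel": [46, 92, 20]},
    "fitting_net": {"neuron": [240, 240, 240]},
}
target_params = {
    "type_map": ["O", "H"],  # must be a subset of the pretrained type_map
    "descriptor": {"trainable": True},
    "fitting_net": {"trainable": False},  # freeze the fitting net during fine-tuning
}
updated = change_finetune_model_params_single(target_params, pretrained_params)
# `updated` now carries the pretrained descriptor/fitting_net settings, plus
# updated["new_type_map"] == ["O", "H"] and the requested trainable flags:
# updated["descriptor"]["trainable"] is True, updated["fitting_net"]["trainable"] is False.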
def change_finetune_model_params(finetune_model, model_config, model_branch=""):
    """
    Load model_params according to the pretrained one.

    This function modifies the fine-tuning input in different modes as follows:
    1. Single-task fine-tuning from a single-task pretrained model:
        - Updates the model parameters based on the pretrained model.
    2. Single-task fine-tuning from a multi-task pretrained model:
        - Updates the model parameters based on the selected branch in the pretrained model.
        - The chosen branch can be defined from the command line or the `finetune_head` input parameter.
        - If not defined, the model parameters in the fitting network will be randomly initialized.
    3. Multi-task fine-tuning from a single-task pretrained model:
        - Updates the model parameters of each branch based on the single branch ('Default') in the pretrained model.
        - If `finetune_head` is not set to 'Default', the model parameters in the fitting network of that branch will be randomly initialized.
    4. Multi-task fine-tuning from a multi-task pretrained model:
        - Updates the model parameters of each branch based on the selected branch in the pretrained model.
        - The chosen branches can be defined from the `finetune_head` input parameter of each model.
        - If `finetune_head` is not defined and the model_key is the same as in the pretrained model,
          it will resume from the model_key branch without fine-tuning.
        - If `finetune_head` is not defined and a new model_key is used,
          the model parameters in the fitting network of that branch will be randomly initialized.

    Parameters
    ----------
    finetune_model
        The pretrained model.
    model_config
        The fine-tuning input parameters.
    model_branch
        The model branch chosen in command-line mode, only for single-task fine-tuning.

    Returns
    -------
    model_config
        Updated model parameters.
    finetune_links
        Fine-tuning rules in a dict format, with `model_branch`: `model_branch_from` pairs.
        If `model_key` is not in this dict, it will just resume instead of fine-tuning.
    """
    multi_task = "model_dict" in model_config
    state_dict = torch.load(finetune_model, map_location=env.DEVICE)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    last_model_params = state_dict["_extra_state"]["model_params"]
    finetune_from_multi_task = "model_dict" in last_model_params
    finetune_links = {}
    if not multi_task:
        # use the command-line choice first
        if model_branch == "" and "finetune_head" in model_config:
            model_branch = model_config["finetune_head"]
        model_config = change_finetune_model_params_single(
            model_config,
            last_model_params,
            from_multitask=finetune_from_multi_task,
            model_branch="Default",
            model_branch_from=model_branch,
        )
        finetune_links["Default"] = (
            model_config["model_branch_chosen"]
            if finetune_from_multi_task
            else "Default"
        )
    else:
        assert model_branch == "", (
            "Multi-task fine-tuning does not support choosing branches from the command line! "
            "Please define the 'finetune_head' in each model params!"
        )
        target_keys = model_config["model_dict"].keys()
        if not finetune_from_multi_task:
            pretrained_keys = ["Default"]
        else:
            pretrained_keys = last_model_params["model_dict"].keys()
        for model_key in target_keys:
            if "finetune_head" in model_config["model_dict"][model_key]:
                pretrained_key = model_config["model_dict"][model_key]["finetune_head"]
                assert pretrained_key in pretrained_keys, (
                    f"The '{pretrained_key}' head chosen to finetune does not exist in the pretrained model! "
                    f"Available heads are: {list(pretrained_keys)}"
                )
                model_branch_from = pretrained_key
                finetune_links[model_key] = model_branch_from
            elif model_key in pretrained_keys:
                # do nothing if "finetune_head" is not defined for heads that already exist
                # in the pretrained model; this will just do resuming
                model_branch_from = model_key
            else:
                # if "finetune_head" is not defined for new heads, the fitting net will be randomly initialized
                model_branch_from = ""
                finetune_links[model_key] = next(iter(pretrained_keys))
            model_config["model_dict"][model_key] = change_finetune_model_params_single(
                model_config["model_dict"][model_key],
                last_model_params,
                from_multitask=finetune_from_multi_task,
                model_branch=model_key,
                model_branch_from=model_branch_from,
            )
    return model_config, finetune_links