# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)
import torch
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from .dp_atomic_model import (
DPAtomicModel,
)
[docs]
class DPDipoleAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
assert isinstance(fitting, DipoleFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
[docs]
def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
# dipole not applying bias
return ret