Source code for deepmd.tf.utils.update_sel

# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Type,
)

from deepmd.env import (
    GLOBAL_ENER_FLOAT_PRECISION,
)
from deepmd.tf.env import (
    tf,
)
from deepmd.tf.utils.neighbor_stat import (
    NeighborStat,
)
from deepmd.utils.update_sel import (
    BaseUpdateSel,
)


class UpdateSel(BaseUpdateSel):
    """TensorFlow-backend variant of :class:`BaseUpdateSel`.

    Provides the TensorFlow ``NeighborStat`` class through
    :attr:`neighbor_stat` and creates named constants for the neighbor
    statistics in :meth:`hook`.
    """

    @property
    def neighbor_stat(self) -> Type[NeighborStat]:
        """Return the ``NeighborStat`` class itself (not an instance)."""
        return NeighborStat

    def hook(self, min_nbor_dist, max_nbor_size):
        """Create named TensorFlow constants holding the neighbor statistics.

        Parameters
        ----------
        min_nbor_dist
            Value stored under the name ``train_attr/min_nbor_dist`` with
            dtype ``GLOBAL_ENER_FLOAT_PRECISION``.
        max_nbor_size
            Value stored under the name ``train_attr/max_nbor_size`` with
            dtype ``tf.int32``.

        Notes
        -----
        The tensors returned by ``tf.constant`` are discarded; the calls are
        made only for the named nodes they create.
        NOTE(review): presumably these nodes land in the current default
        graph so the values are saved with the model -- confirm with callers.
        """
        # moved from trainer.py as duplicated
        constant_specs = (
            (
                min_nbor_dist,
                "train_attr/min_nbor_dist",
                GLOBAL_ENER_FLOAT_PRECISION,
            ),
            (max_nbor_size, "train_attr/max_nbor_size", tf.int32),
        )
        # Creation order is kept: min_nbor_dist first, then max_nbor_size.
        for value, node_name, node_dtype in constant_specs:
            tf.constant(value, name=node_name, dtype=node_dtype)