Program Listing for File DeepPotPT.h

Return to documentation for file (source/api_cc/include/DeepPotPT.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <cassert>
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

#include "DeepPot.h"

namespace deepmd {
/**
 * @brief PyTorch (TorchScript) backend of the DeepPot interface.
 *
 * Holds a `torch::jit::script::Module` together with model metadata
 * (type counts, cutoff, fparam/aparam dimensions) and scratch state used
 * when evaluating with an external neighbor list.
 *
 * NOTE(review): model loading and evaluation are implemented in the
 * corresponding source file, which is not visible here.
 */
class DeepPotPT : public DeepPotBase {
 public:
  /**
   * @brief Construct an uninitialized object.
   *
   * The metadata accessors below assert `inited`; presumably init() is what
   * sets it, so call init() before querying.
   */
  DeepPotPT();
  ~DeepPotPT();
  /**
   * @brief Construct and initialize in one step.
   * @param[in] model Path of the model file (presumably a TorchScript
   *                  archive loaded into `module` — confirm in the .cpp).
   * @param[in] gpu_rank Rank used to choose a GPU, if any.
   * @param[in] file_content Optional in-memory model content.
   */
  DeepPotPT(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");
  /**
   * @brief Initialize an object created with the default constructor.
   * @param[in] model Path of the model file.
   * @param[in] gpu_rank Rank used to choose a GPU, if any.
   * @param[in] file_content Optional in-memory model content.
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");

 private:
  /**
   * @brief Evaluate the model without an external neighbor list.
   *
   * Outputs: energy, force, virial, per-atom energy and per-atom virial.
   * Inputs: coordinates, atom types, box, and optional frame / atomic
   * parameters (empty vectors by default).
   */
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute(ENERGYVTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Evaluate the model with an externally supplied neighbor list.
   *
   * @param[in] nghost Number of ghost atoms (presumably the trailing entries
   *                   of `coord`/`atype` — TODO confirm).
   * @param[in] lmp_list External neighbor list.
   * @param[in] ago Presumably the number of steps since the neighbor list
   *                was rebuilt — TODO confirm against the implementation.
   */
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute(ENERGYVTYPE& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
               const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Mixed-type evaluation over `nframes` frames (energy, force and
   * virial only).
   */
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute_mixed_type(
      ENERGYVTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Mixed-type evaluation over `nframes` frames, including per-atom
   * energy and virial.
   */
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void compute_mixed_type(
      ENERGYVTYPE& ener,
      std::vector<VALUETYPE>& force,
      std::vector<VALUETYPE>& virial,
      std::vector<VALUETYPE>& atom_energy,
      std::vector<VALUETYPE>& atom_virial,
      const int& nframes,
      const std::vector<VALUETYPE>& coord,
      const std::vector<int>& atype,
      const std::vector<VALUETYPE>& box,
      const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
      const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());

 public:
  /** @brief Model cutoff (`rcut`). Asserts the object is initialized. */
  double cutoff() const {
    assert(inited);
    return rcut;
  }
  /** @brief Number of atom types. Asserts the object is initialized. */
  int numb_types() const {
    assert(inited);
    return ntypes;
  }
  /** @brief Value of `ntypes_spin`. Asserts the object is initialized. */
  int numb_types_spin() const {
    assert(inited);
    return ntypes_spin;
  }
  /** @brief Dimension of the frame parameter. Asserts initialization. */
  int dim_fparam() const {
    assert(inited);
    return dfparam;
  }
  /** @brief Dimension of the atomic parameter. Asserts initialization. */
  int dim_aparam() const {
    assert(inited);
    return daparam;
  }
  /** @brief Write the model's type map into `type_map`. */
  void get_type_map(std::string& type_map);

  /** @brief Value of the `aparam_nall` flag. Asserts initialization. */
  bool is_aparam_nall() const {
    assert(inited);
    return aparam_nall;
  }

  // Non-template entry points that forward to the private `compute` /
  // `compute_mixed_type` templates. Note that the float overloads still
  // report energies as `std::vector<double>`.
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const std::vector<double>& fparam = std::vector<double>(),
                const std::vector<double>& aparam = std::vector<double>());
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const std::vector<float>& fparam = std::vector<float>(),
                const std::vector<float>& aparam = std::vector<float>());
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<double>& fparam = std::vector<double>(),
                const std::vector<double>& aparam = std::vector<double>());
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<float>& fparam = std::vector<float>(),
                const std::vector<float>& aparam = std::vector<float>());
  void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<double>& force,
      std::vector<double>& virial,
      std::vector<double>& atom_energy,
      std::vector<double>& atom_virial,
      const int& nframes,
      const std::vector<double>& coord,
      const std::vector<int>& atype,
      const std::vector<double>& box,
      const std::vector<double>& fparam = std::vector<double>(),
      const std::vector<double>& aparam = std::vector<double>());
  void computew_mixed_type(
      std::vector<double>& ener,
      std::vector<float>& force,
      std::vector<float>& virial,
      std::vector<float>& atom_energy,
      std::vector<float>& atom_virial,
      const int& nframes,
      const std::vector<float>& coord,
      const std::vector<int>& atype,
      const std::vector<float>& box,
      const std::vector<float>& fparam = std::vector<float>(),
      const std::vector<float>& aparam = std::vector<float>());

 private:
  // Default member initializers give every scalar a defined value even if a
  // constructor does not set it; constructor init lists still take priority.
  // Declaration order is unchanged, so the object layout is unchanged.
  int num_intra_nthreads{0};
  int num_inter_nthreads{0};
  bool inited{false};  // checked by the accessor asserts above
  int ntypes{0};
  int ntypes_spin{0};
  int dfparam{0};
  int daparam{0};
  bool aparam_nall{false};
  torch::jit::script::Module module;
  double rcut{0.0};
  // copy neighbor list info from host
  NeighborListData nlist_data;
  int max_num_neighbors{0};
  int gpu_id{0};
  int do_message_passing{0};  // 1:dpa2 model 0:others
  bool gpu_enabled{false};
  at::Tensor firstneigh_tensor;
  torch::Dict<std::string, torch::Tensor> comm_dict;
};

}  // namespace deepmd