Program Listing for File DataModifier.h

Return to documentation for file (source/api_cc/include/DataModifier.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <memory>

#include "common.h"

namespace deepmd {
/**
 * @brief Abstract backend interface for the dipole-charge modifier.
 *
 * Concrete backends implement init(), the two computew() overloads (one for
 * double precision, one for float precision) and the model-metadata queries.
 * DipoleChargeModifier holds an implementation of this interface through a
 * std::shared_ptr.
 */
class DipoleChargeModifierBase {
 public:
  /** @brief Default-construct without loading a model (see init()). */
  DipoleChargeModifierBase() = default;
  /**
   * @brief Construct with a model.
   *
   * The definition is not in this header; presumably equivalent to default
   * construction followed by init() — TODO confirm against the implementation.
   *
   * @param[in] model Model identifier handed to the backend (presumably a file
   *                  path — TODO confirm).
   * @param[in] gpu_rank GPU rank/device index to use. Defaults to 0.
   * @param[in] name_scope Name scope for the loaded model. Defaults to "".
   */
  DipoleChargeModifierBase(const std::string& model,
                           const int& gpu_rank = 0,
                           const std::string& name_scope = "");
  /** @brief Virtual so deleting a derived backend via a base pointer is safe. */
  virtual ~DipoleChargeModifierBase() = default;
  /**
   * @brief Initialize the backend with a model.
   *
   * NOTE: default arguments on a virtual function are chosen by the static
   * type of the call expression, not the dynamic type; overriders should
   * repeat exactly these defaults to avoid surprises.
   *
   * @param[in] model Model identifier (presumably a file path — TODO confirm).
   * @param[in] gpu_rank GPU rank/device index to use. Defaults to 0.
   * @param[in] name_scope Name scope for the loaded model. Defaults to "".
   */
  virtual void init(const std::string& model,
                    const int& gpu_rank = 0,
                    const std::string& name_scope = "") = 0;
  /**
   * @brief Compute the modifier corrections (double-precision overload).
   *
   * Array layouts are not fixed by this header; the notes below are inferred
   * from parameter names and should be confirmed against an implementation.
   *
   * @param[out] dfcorr_ Output; presumably the (flattened) force correction.
   * @param[out] dvcorr_ Output; presumably the virial correction.
   * @param[in] dcoord_ Atom coordinates (presumably flattened, 3 per atom).
   * @param[in] datype_ Atom types.
   * @param[in] dbox Simulation cell (presumably 9 components).
   * @param[in] pairs Index pairs (presumably atom / charge-site associations
   *                  — TODO confirm).
   * @param[in] delef_ Presumably the external electric field — TODO confirm.
   * @param[in] nghost Number of ghost atoms.
   * @param[in] lmp_list Neighbor list.
   */
  virtual void computew(std::vector<double>& dfcorr_,
                        std::vector<double>& dvcorr_,
                        const std::vector<double>& dcoord_,
                        const std::vector<int>& datype_,
                        const std::vector<double>& dbox,
                        const std::vector<std::pair<int, int>>& pairs,
                        const std::vector<double>& delef_,
                        const int nghost,
                        const InputNlist& lmp_list) = 0;
  /**
   * @brief Compute the modifier corrections (single-precision overload).
   *
   * Same contract as the double-precision overload, with float buffers.
   */
  virtual void computew(std::vector<float>& dfcorr_,
                        std::vector<float>& dvcorr_,
                        const std::vector<float>& dcoord_,
                        const std::vector<int>& datype_,
                        const std::vector<float>& dbox,
                        const std::vector<std::pair<int, int>>& pairs,
                        const std::vector<float>& delef_,
                        const int nghost,
                        const InputNlist& lmp_list) = 0;
  /** @return The model cutoff (presumably a radius in model length units). */
  virtual double cutoff() const = 0;
  /** @return Number of atom types of the model. */
  virtual int numb_types() const = 0;
  /** @return Selected atom types (semantics defined by the backend). */
  virtual std::vector<int> sel_types() const = 0;
};

/**
 * @brief User-facing dipole-charge modifier.
 *
 * Holds a DipoleChargeModifierBase backend via std::shared_ptr. Member
 * definitions are not in this header.
 */
class DipoleChargeModifier {
 public:
  /** @brief Construct without loading a model (see init()). */
  DipoleChargeModifier();
  /**
   * @brief Construct with a model.
   *
   * @param[in] model Model identifier (presumably a file path — TODO confirm).
   * @param[in] gpu_rank GPU rank/device index to use. Defaults to 0.
   * @param[in] name_scope Name scope for the loaded model. Defaults to "".
   */
  DipoleChargeModifier(const std::string& model,
                       const int& gpu_rank = 0,
                       const std::string& name_scope = "");
  ~DipoleChargeModifier();
  /**
   * @brief Initialize with a model.
   *
   * @param[in] model Model identifier (presumably a file path — TODO confirm).
   * @param[in] gpu_rank GPU rank/device index to use. Defaults to 0.
   * @param[in] name_scope Name scope for the loaded model. Defaults to "".
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");
  /**
   * @brief Print a summary of this modifier.
   * @param[in] pre Prefix string (presumably prepended to each printed line).
   */
  void print_summary(const std::string& pre) const;

  /**
   * @brief Compute the modifier corrections.
   *
   * Defined outside this header; presumably explicitly instantiated for
   * VALUETYPE = double and float, matching the backend's computew overloads
   * — TODO confirm.
   *
   * @tparam VALUETYPE Floating-point type of the coordinate/output buffers.
   * @param[out] dfcorr_ Output; presumably the (flattened) force correction.
   * @param[out] dvcorr_ Output; presumably the virial correction.
   * @param[in] dcoord_ Atom coordinates (presumably flattened, 3 per atom).
   * @param[in] datype_ Atom types.
   * @param[in] dbox Simulation cell (presumably 9 components).
   * @param[in] pairs Index pairs (presumably atom / charge-site associations
   *                  — TODO confirm).
   * @param[in] delef_ Presumably the external electric field — TODO confirm.
   * @param[in] nghost Number of ghost atoms.
   * @param[in] lmp_list Neighbor list.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& dfcorr_,
               std::vector<VALUETYPE>& dvcorr_,
               const std::vector<VALUETYPE>& dcoord_,
               const std::vector<int>& datype_,
               const std::vector<VALUETYPE>& dbox,
               const std::vector<std::pair<int, int>>& pairs,
               const std::vector<VALUETYPE>& delef_,
               const int nghost,
               const InputNlist& lmp_list);
  /** @return The model cutoff (presumably a radius in model length units). */
  double cutoff() const;
  /** @return Number of atom types of the model. */
  int numb_types() const;
  /** @return Selected atom types (semantics defined by the backend). */
  std::vector<int> sel_types() const;

 private:
  // Presumably set by init(). The in-class initializer guarantees a defined
  // value even if a constructor does not set it; a constructor's
  // mem-initializer still takes precedence.
  bool inited = false;
  // Backend implementation; shared, so copies of this object share a backend.
  std::shared_ptr<deepmd::DipoleChargeModifierBase> dcm;
};
}  // namespace deepmd