Program Listing for File common.h

Return to documentation for file (source/api_cc/include/common.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <iostream>
#include <string>
#include <vector>

#include "AtomMap.h"
#include "errors.h"
#include "neighbor_list.h"
#include "version.h"

namespace deepmd {

/// Floating-point type used for energy values in this API.
using ENERGYTYPE = double;

/// Backend identifiers. The enum is deliberately left unscoped so existing
/// code that names the enumerators directly keeps compiling; the explicit
/// values are exactly the implicit ones (0..3).
enum DPBackend {
  TensorFlow = 0,
  PyTorch = 1,
  Paddle = 2,
  Unknown = 3
};

struct NeighborListData {
  std::vector<int> ilist;
  std::vector<std::vector<int>> jlist;
  std::vector<int> numneigh;
  std::vector<int*> firstneigh;

 public:
  void copy_from_nlist(const InputNlist& inlist);
  void shuffle(const std::vector<int>& fwd_map);
  void shuffle(const deepmd::AtomMap& map);
  void shuffle_exclude_empty(const std::vector<int>& fwd_map);
  void make_inlist(InputNlist& inlist);
  void padding();
};

/**
 * @brief Presumably reports whether a model with the given version string
 * can be used by this build (inferred from the name — confirm against the
 * definition).
 *
 * NOTE(review): "compatable" is a misspelling and the parameter is a
 * non-const reference. Both are intentionally left unchanged: altering either
 * changes the mangled symbol and would no longer match the out-of-line
 * definition. Whether the function modifies model_version is not visible
 * here — confirm before passing a string you need unchanged.
 *
 * @param model_version Version string of the model.
 * @return Presumably true when compatible.
 */
bool model_compatable(std::string& model_version);

/**
 * @brief Declaration only; the template definition is not visible here.
 *
 * Presumably selects the atoms whose type appears in sel_type_ and builds
 * index maps between the full and the selected atom sets — confirm against
 * the definition.
 *
 * @param fwd_map Non-const reference, presumably output: forward index map.
 * @param bkw_map Non-const reference, presumably output: backward index map.
 * @param nghost_real Non-const reference, presumably output: number of ghost
 *        atoms that remain after selection (inferred from the name).
 * @param dcoord_ Input coordinates.
 * @param datype_ Input atom types.
 * @param nghost Number of ghost atoms in the input.
 * @param sel_type_ Atom types to keep.
 */
template <typename VALUETYPE>
void select_by_type(std::vector<int>& fwd_map,
                    std::vector<int>& bkw_map,
                    int& nghost_real,
                    const std::vector<VALUETYPE>& dcoord_,
                    const std::vector<int>& datype_,
                    const int& nghost,
                    const std::vector<int>& sel_type_);

/**
 * @brief Declaration only; the template definition is not visible here.
 *
 * Presumably keeps only "real" atoms, i.e. those whose type is a valid index
 * given ntypes, and builds forward/backward index maps — confirm against the
 * definition.
 *
 * @param fwd_map Non-const reference, presumably output: forward index map.
 * @param bkw_map Non-const reference, presumably output: backward index map.
 * @param nghost_real Non-const reference, presumably output: ghost atoms kept.
 * @param dcoord_ Input coordinates.
 * @param datype_ Input atom types.
 * @param nghost Number of ghost atoms in the input.
 * @param ntypes Number of atom types.
 */
template <typename VALUETYPE>
void select_real_atoms(std::vector<int>& fwd_map,
                       std::vector<int>& bkw_map,
                       int& nghost_real,
                       const std::vector<VALUETYPE>& dcoord_,
                       const std::vector<int>& datype_,
                       const int& nghost,
                       const int& ntypes);

/**
 * @brief Declaration only; the template definition is not visible here.
 *
 * Takes the full input arrays (trailing-underscore parameters) and writes
 * reduced versions into the non-const reference parameters. Presumably this
 * combines real-atom selection with remapping of coordinates, types and
 * atomic parameters across nframes frames — confirm against the definition.
 *
 * @param dcoord Output coordinates of the selected atoms.
 * @param datype Output types of the selected atoms.
 * @param aparam Output atomic parameters of the selected atoms.
 * @param nghost_real Output number of ghost atoms kept.
 * @param fwd_map Output forward index map.
 * @param bkw_map Output backward index map.
 * @param nall_real Output total atom count after selection.
 * @param nloc_real Output local atom count after selection.
 * @param dcoord_ Input coordinates.
 * @param datype_ Input atom types.
 * @param aparam_ Input atomic parameters.
 * @param nghost Number of ghost atoms in the input.
 * @param ntypes Number of atom types.
 * @param nframes Number of frames.
 * @param daparam Presumably the per-atom width of aparam_.
 * @param nall Total atom count of the input.
 * @param aparam_nall Presumably true when aparam_ covers all nall atoms rather
 *        than only local atoms — confirm (defaults to false).
 */
template <typename VALUETYPE>
void select_real_atoms_coord(std::vector<VALUETYPE>& dcoord,
                             std::vector<int>& datype,
                             std::vector<VALUETYPE>& aparam,
                             int& nghost_real,
                             std::vector<int>& fwd_map,
                             std::vector<int>& bkw_map,
                             int& nall_real,
                             int& nloc_real,
                             const std::vector<VALUETYPE>& dcoord_,
                             const std::vector<int>& datype_,
                             const std::vector<VALUETYPE>& aparam_,
                             const int& nghost,
                             const int& ntypes,
                             const int& nframes,
                             const int& daparam,
                             const int& nall,
                             const bool aparam_nall = false);

/**
 * @brief Declaration only (std::vector overload).
 *
 * Presumably copies stride-sized records of `in` into `out` at the positions
 * given by fwd_map, once per frame — confirm against the definition.
 *
 * @param out Destination vector.
 * @param in Source vector.
 * @param fwd_map Index map used for the copy.
 * @param stride Number of VT elements per atom record.
 * @param nframes Number of frames (default 1).
 * @param nall1 Presumably the per-frame atom count of `in`; ignored when
 *        nframes is 1 (see the inline comment).
 * @param nall2 Presumably the per-frame atom count of `out`; ignored when
 *        nframes is 1.
 */
template <typename VT>
void select_map(std::vector<VT>& out,
                const std::vector<VT>& in,
                const std::vector<int>& fwd_map,
                const int& stride,
                const int& nframes = 1,
                // nall will not take effect if nframes is 1
                const int& nall1 = 0,
                const int& nall2 = 0);

/**
 * @brief Declaration only (iterator overload of select_map).
 *
 * Same parameters as the std::vector overload, except that `out`/`in` are
 * iterators into existing storage. The caller presumably must ensure the
 * destination range is large enough — confirm against the definition.
 * nall1/nall2 presumably have no effect when nframes is 1, as documented on
 * the vector overload's parameter list.
 *
 * NOTE(review): VT appears only inside `typename std::vector<VT>::iterator`,
 * which is a non-deduced context. VT therefore cannot be deduced from the
 * arguments, and callers must spell it explicitly, e.g. select_map<double>(...).
 */
template <typename VT>
void select_map(typename std::vector<VT>::iterator out,
                const typename std::vector<VT>::const_iterator in,
                const std::vector<int>& fwd_map,
                const int& stride,
                const int& nframes = 1,
                const int& nall1 = 0,
                const int& nall2 = 0);

/**
 * @brief Declaration only (std::vector overload).
 *
 * Presumably the inverse of select_map: scatters stride-sized records of `in`
 * back to their original positions in `out` using fwd_map — confirm against
 * the definition. Unlike select_map, it takes no frame arguments.
 *
 * @param out Destination vector.
 * @param in Source vector.
 * @param fwd_map Index map used for the copy.
 * @param stride Number of VT elements per atom record.
 */
template <typename VT>
void select_map_inv(std::vector<VT>& out,
                    const std::vector<VT>& in,
                    const std::vector<int>& fwd_map,
                    const int& stride);

/**
 * @brief Declaration only (iterator overload of select_map_inv).
 *
 * NOTE(review): VT sits in a non-deduced context, so callers must specify it
 * explicitly, e.g. select_map_inv<float>(...).
 */
template <typename VT>
void select_map_inv(typename std::vector<VT>::iterator out,
                    const typename std::vector<VT>::const_iterator in,
                    const std::vector<int>& fwd_map,
                    const int& stride);

/**
 * @brief Writes thread-count settings into both reference parameters.
 *
 * Presumably the values are read from environment variables, as the name
 * suggests — confirm in the definition which variables are used and what
 * values are written when they are unset.
 *
 * @param num_intra_nthreads Output intra-op thread count.
 * @param num_inter_nthreads Output inter-op thread count.
 */
void get_env_nthreads(int& num_intra_nthreads, int& num_inter_nthreads);

/**
 * @brief Presumably loads the shared library that provides custom operators
 * for the backend (inferred from the name — confirm against the definition,
 * including its behavior on failure).
 */
void load_op_library();

struct tf_exception : public deepmd::deepmd_exception {
 public:
  tf_exception() : deepmd::deepmd_exception("TensorFlow Error!"){};
  tf_exception(const std::string& msg)
      : deepmd::deepmd_exception(std::string("TensorFlow Error: ") + msg){};
};

/**
 * @brief Builds a name prefix from @p name_scope.
 *
 * The exact format (for example, whether a separator is appended and what is
 * returned for an empty scope) is defined out of line — confirm against the
 * definition.
 *
 * @param name_scope Scope name to convert.
 * @return The prefix string.
 */
std::string name_prefix(const std::string& name_scope);

/**
 * @brief Presumably reads the whole file named @p model into
 * @p file_content (inferred from the name — confirm error handling in the
 * definition).
 *
 * NOTE(review): @p model is taken by value. Changing it to a const reference
 * would alter the mangled symbol and break linking against the existing
 * definition, so it is left as-is.
 *
 * @param model Path of the file to read.
 * @param file_content Output buffer receiving the file content.
 */
void read_file_to_string(std::string model, std::string& file_content);

/**
 * @brief Presumably converts a text-format protobuf file (@p fn_pb_txt) into
 * a binary protobuf file (@p fn_pb) — inferred from the names; confirm
 * against the definition.
 *
 * @param fn_pb_txt Input file name (text format).
 * @param fn_pb Output file name (binary format).
 */
void convert_pbtxt_to_pb(std::string fn_pb_txt, std::string fn_pb);

/**
 * @brief Presumably prints a summary of the build/runtime configuration,
 * with each line prefixed by @p pre — confirm the output stream and content
 * against the definition.
 *
 * @param pre Prefix string for each printed line.
 */
void print_summary(const std::string& pre);
}  // namespace deepmd