Program Listing for File DeepTensorTF.h
↰ Return to documentation for file (source/api_cc/include/DeepTensorTF.h)
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once
#include "DeepTensor.h"
#include "common.h"
#include "commonTF.h"
#include "neighbor_list.h"
namespace deepmd {
/**
 * @brief Tensor-model evaluator derived from DeepTensorBase that keeps a
 * TensorFlow session and graph definition as its backing state.
 *
 * The public surface consists of the constructors / init(), a handful of
 * read-only accessors guarded by `assert(inited)`, get_type_map() and the
 * non-template computew() overloads for `double` and `float`.  The templated
 * compute(), compute_inner() and run_model() members are private helpers.
 *
 * NOTE(review): the class holds raw `tensorflow::Session*` and
 * `tensorflow::GraphDef*` members and declares a destructor, but no copy or
 * move operations.  Whether these pointers are owned is not visible from this
 * header; if they are, copying an instance would be unsafe -- confirm against
 * the implementation file.
 */
class DeepTensorTF : public DeepTensorBase {
 public:
  /**
   * @brief Construct an object without loading a model.
   */
  DeepTensorTF();

  ~DeepTensorTF();

  /**
   * @brief Construct an object from a model.
   * @param[in] model Model identifier (presumably the path of a frozen model
   * file -- confirm against the implementation).
   * @param[in] gpu_rank GPU rank to use; defaults to 0.
   * @param[in] name_scope Name scope string; defaults to empty.
   */
  DeepTensorTF(const std::string& model,
               const int& gpu_rank = 0,
               const std::string& name_scope = "");

  /**
   * @brief Initialise the object from a model.
   *
   * Takes the same parameters as the three-argument constructor.
   *
   * @param[in] model Model identifier (presumably the path of a frozen model
   * file -- confirm against the implementation).
   * @param[in] gpu_rank GPU rank to use; defaults to 0.
   * @param[in] name_scope Name scope string; defaults to empty.
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");

 private:
  /**
   * @brief Evaluate the tensor for a configuration (no neighbour list).
   * @param[out] value Result vector, written by the call.
   * @param[in] coord Coordinates (layout presumably flat x,y,z per atom --
   * TODO confirm).
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell (layout not visible here -- TODO confirm).
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);

  /**
   * @brief Evaluate the tensor using a caller-supplied neighbour list.
   * @param[out] value Result vector, written by the call.
   * @param[in] coord Coordinates.
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell.
   * @param[in] nghost Number of ghost atoms (presumably counted within
   * `coord`/`atype` -- TODO confirm).
   * @param[in] inlist Input neighbour list.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& value,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);

  /**
   * @brief Evaluate the global tensor together with force, virial and
   * per-atom quantities (no neighbour list).
   * @param[out] global_tensor Global tensor output.
   * @param[out] force Force output.
   * @param[out] virial Virial output.
   * @param[out] atom_tensor Per-atom tensor output.
   * @param[out] atom_virial Per-atom virial output.
   * @param[in] coord Coordinates.
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box);

  /**
   * @brief Evaluate the global tensor together with force, virial and
   * per-atom quantities using a caller-supplied neighbour list.
   * @param[out] global_tensor Global tensor output.
   * @param[out] force Force output.
   * @param[out] virial Virial output.
   * @param[out] atom_tensor Per-atom tensor output.
   * @param[out] atom_virial Per-atom virial output.
   * @param[in] coord Coordinates.
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell.
   * @param[in] nghost Number of ghost atoms.
   * @param[in] inlist Input neighbour list.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& global_tensor,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_tensor,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& inlist);

 public:
  /**
   * @brief Return the stored cutoff radius `rcut`.
   *
   * Asserts that the object has been initialised; the check is compiled out
   * when NDEBUG is defined.
   */
  double cutoff() const {
    assert(inited);
    return rcut;
  }

  /**
   * @brief Return the stored number of atom types `ntypes`.
   *
   * Asserts that the object has been initialised.
   */
  int numb_types() const {
    assert(inited);
    return ntypes;
  }

  /**
   * @brief Return the stored output dimension `odim`.
   *
   * Asserts that the object has been initialised.
   */
  int output_dim() const {
    assert(inited);
    return odim;
  }

  /**
   * @brief Return a reference to the stored list of selected types.
   *
   * Asserts that the object has been initialised.  The reference refers to a
   * member of this object and is therefore only valid while it is alive.
   */
  const std::vector<int>& sel_types() const {
    assert(inited);
    return sel_type;
  }

  /**
   * @brief Retrieve the type map.
   * @param[out] type_map String receiving the type map (format not visible
   * from this header -- TODO confirm).
   */
  void get_type_map(std::string& type_map);

  /**
   * @brief Double-precision evaluation entry point (no neighbour list).
   * @param[out] global_tensor Global tensor output.
   * @param[out] force Force output.
   * @param[out] virial Virial output.
   * @param[out] atom_tensor Per-atom tensor output.
   * @param[out] atom_virial Per-atom virial output.
   * @param[in] coord Coordinates.
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell.
   * @param[in] request_deriv Presumably selects whether derivative outputs
   * (force/virial) are computed -- confirm against the implementation.
   */
  void computew(std::vector<double>& global_tensor,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_tensor,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const bool request_deriv);

  /**
   * @brief Single-precision counterpart of the overload above; identical
   * parameter list with `float` element type.
   */
  void computew(std::vector<float>& global_tensor,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_tensor,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const bool request_deriv);

  /**
   * @brief Double-precision evaluation entry point taking a neighbour list.
   * @param[out] global_tensor Global tensor output.
   * @param[out] force Force output.
   * @param[out] virial Virial output.
   * @param[out] atom_tensor Per-atom tensor output.
   * @param[out] atom_virial Per-atom virial output.
   * @param[in] coord Coordinates.
   * @param[in] atype Integer type of each atom.
   * @param[in] box Simulation cell.
   * @param[in] nghost Number of ghost atoms.
   * @param[in] inlist Input neighbour list.
   * @param[in] request_deriv Presumably selects whether derivative outputs
   * (force/virial) are computed -- confirm against the implementation.
   */
  void computew(std::vector<double>& global_tensor,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_tensor,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const bool request_deriv);

  /**
   * @brief Single-precision counterpart of the neighbour-list overload above;
   * identical parameter list with `float` element type.
   */
  void computew(std::vector<float>& global_tensor,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_tensor,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const int nghost,
                const InputNlist& inlist,
                const bool request_deriv);

 private:
  // The two pointers and the `inited` flag carry default member initializers
  // so that they never hold indeterminate values, whichever constructor runs;
  // a constructor's own mem-initializer still takes precedence.
  tensorflow::Session* session = nullptr;
  std::string name_scope;
  int num_intra_nthreads;
  int num_inter_nthreads;
  tensorflow::GraphDef* graph_def = nullptr;
  // Read by every accessor's assert(); must start out false.
  bool inited = false;
  double rcut;        // value returned by cutoff()
  int dtype;
  double cell_size;
  int ntypes;         // value returned by numb_types()
  std::string model_type;
  std::string model_version;
  int odim;           // value returned by output_dim()
  std::vector<int> sel_type;  // value returned by sel_types()

  /**
   * @brief Fetch a scalar of type VT identified by `name`
   * (presumably from the loaded graph -- confirm against the implementation).
   */
  template <class VT>
  VT get_scalar(const std::string& name) const;

  /**
   * @brief Fetch a vector of VT identified by `name` into `vec`
   * (presumably from the loaded graph -- confirm against the implementation).
   */
  template <class VT>
  void get_vector(std::vector<VT>& vec, const std::string& name) const;

  /**
   * @brief Run the model on prepared input tensors, producing one tensor.
   * @param[out] d_tensor_ Result vector.
   * @param[in] session Session to run with.
   * @param[in] input_tensors Named input tensors.
   * @param[in] atommap Atom map (presumably used to map results back to the
   * caller's atom ordering -- TODO confirm).
   * @param[in] sel_fwd Selection/forward index map (semantics not visible
   * here -- TODO confirm).
   * @param[in] nghost Number of ghost atoms; defaults to 0.
   */
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& d_tensor_,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const std::vector<int>& sel_fwd,
                 const int nghost = 0);

  /**
   * @brief Run the model on prepared input tensors, producing the global
   * tensor, force, virial and per-atom outputs.
   * @param[out] dglobal_tensor_ Global tensor output.
   * @param[out] dforce_ Force output.
   * @param[out] dvirial_ Virial output.
   * @param[out] datom_tensor_ Per-atom tensor output.
   * @param[out] datom_virial_ Per-atom virial output.
   * @param[in] session Session to run with.
   * @param[in] input_tensors Named input tensors.
   * @param[in] atommap Atom map.
   * @param[in] sel_fwd Selection/forward index map (TODO confirm semantics).
   * @param[in] nghost Number of ghost atoms; defaults to 0.
   */
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& dglobal_tensor_,
                 std::vector<VALUETYPE>& dforce_,
                 std::vector<VALUETYPE>& dvirial_,
                 std::vector<VALUETYPE>& datom_tensor_,
                 std::vector<VALUETYPE>& datom_virial_,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const std::vector<int>& sel_fwd,
                 const int nghost = 0);

  /**
   * @brief Inner helper with the same parameter list as the single-output
   * compute() overload without a neighbour list.
   */
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& value,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box);

  /**
   * @brief Inner helper with the same parameter list as the single-output
   * compute() overload that takes a neighbour list.
   */
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& value,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box,
                     const int nghost,
                     const InputNlist& inlist);

  /**
   * @brief Inner helper with the same parameter list as the multi-output
   * compute() overload without a neighbour list.
   */
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& global_tensor,
                     std::vector<VALUETYPE>& force,
                     std::vector<VALUETYPE>& virial,
                     std::vector<VALUETYPE>& atom_tensor,
                     std::vector<VALUETYPE>& atom_virial,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box);

  /**
   * @brief Inner helper with the same parameter list as the multi-output
   * compute() overload that takes a neighbour list.
   */
  template <typename VALUETYPE>
  void compute_inner(std::vector<VALUETYPE>& global_tensor,
                     std::vector<VALUETYPE>& force,
                     std::vector<VALUETYPE>& virial,
                     std::vector<VALUETYPE>& atom_tensor,
                     std::vector<VALUETYPE>& atom_virial,
                     const std::vector<VALUETYPE>& coord,
                     const std::vector<int>& atype,
                     const std::vector<VALUETYPE>& box,
                     const int nghost,
                     const InputNlist& inlist);
};
}  // namespace deepmd