Program Listing for File commonTF.h

Return to documentation for file (source/api_cc/include/commonTF.h)

// SPDX-License-Identifier: LGPL-3.0-or-later
#include <string>
#include <vector>

#ifdef TF_PRIVATE
#include "tf_private.h"
#else
#include "tf_public.h"
#endif

namespace deepmd {
void check_status(const tensorflow::Status& status);

template <typename VT>
VT session_get_scalar(tensorflow::Session* session,
                      const std::string name,
                      const std::string scope = "");

template <typename VT>
void session_get_vector(std::vector<VT>& o_vec,
                        tensorflow::Session* session,
                        const std::string name_,
                        const std::string scope = "");

int session_get_dtype(tensorflow::Session* session,
                      const std::string name,
                      const std::string scope = "");

template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    const double& cell_size,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const std::string scope = "",
    const bool aparam_nall = false);

template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    InputNlist& dlist,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const int nghost,
    const int ago,
    const std::string scope = "",
    const bool aparam_nall = false);

template <typename MODELTYPE, typename VALUETYPE>
int session_input_tensors_mixed_type(
    std::vector<std::pair<std::string, tensorflow::Tensor>>& input_tensors,
    const int& nframes,
    const std::vector<VALUETYPE>& dcoord_,
    const int& ntypes,
    const std::vector<int>& datype_,
    const std::vector<VALUETYPE>& dbox,
    const double& cell_size,
    const std::vector<VALUETYPE>& fparam_,
    const std::vector<VALUETYPE>& aparam_,
    const deepmd::AtomMap& atommap,
    const std::string scope = "",
    const bool aparam_nall = false);

}  // namespace deepmd