Source code for backend.find_tensorflow

# SPDX-License-Identifier: LGPL-3.0-or-later
import os
import site
from functools import (
    lru_cache,
)
from importlib.machinery import (
    FileFinder,
)
from importlib.util import (
    find_spec,
)
from pathlib import (
    Path,
)
from sysconfig import (
    get_path,
)
from typing import (
    List,
    Optional,
    Tuple,
    Union,
)

from packaging.specifiers import (
    SpecifierSet,
)


@lru_cache()
def find_tensorflow() -> Tuple[Optional[str], List[str]]:
    """Locate an installed TensorFlow, or report what must be installed.

    Candidate locations are probed in this order; the first hit wins:

    1. cibuildwheel macOS arm64 cross builds only (``CIBUILDWHEEL == "1"`` and
       ``CIBW_BUILD`` ending in ``macosx_arm64``): the ``tensorflow``
       directory under ``RUNNER_TEMP``.
    2. The parent directory of the path given by the environment variable
       ``TENSORFLOW_ROOT``, if set.
    3. The current Python environment.
    4. The user site-packages directory, if enabled.
    5. The system site-packages directory (purelib).

    When nothing is found, TensorFlow is added as a requirement instead (see
    ``get_tf_requirement``) so that pip can install it.

    Returns
    -------
    str or None
        TensorFlow library path if found, otherwise None.
    list of str
        TensorFlow requirement if not found. Empty if found.

    Raises
    ------
    RuntimeError
        If TensorFlow is not found under cibuildwheel and ``CUDA_VERSION``
        is neither empty, 12.x nor 11.x.
    """
    env = os.environ

    def spec_in(directory):
        # Look for a "tensorflow" entry inside exactly this directory.
        return FileFinder(str(directory)).find_spec("tensorflow")

    requires = []
    tf_spec = None

    is_macos_arm64_cross_build = env.get("CIBUILDWHEEL", "0") == "1" and env.get(
        "CIBW_BUILD", ""
    ).endswith("macosx_arm64")
    if is_macos_arm64_cross_build:
        # cibuildwheel cross build
        tf_spec = spec_in(Path(env.get("RUNNER_TEMP")) / "tensorflow")

    if not tf_spec and env.get("TENSORFLOW_ROOT") is not None:
        tf_spec = spec_in(Path(env.get("TENSORFLOW_ROOT")).parent.absolute())

    # get tensorflow spec
    # note: isolated build will not work for backend
    if not tf_spec:
        tf_spec = find_spec("tensorflow")

    if not tf_spec and site.ENABLE_USER_SITE:
        # first search TF from user site-packages before global site-packages
        user_site = site.getusersitepackages()
        if user_site:
            tf_spec = spec_in(user_site)

    if not tf_spec:
        # purelib gets site-packages path
        purelib = get_path("purelib")
        if purelib:
            tf_spec = spec_in(purelib)

    # get install dir from spec
    try:
        tf_install_dir = tf_spec.submodule_search_locations[0]  # type: ignore
    except (AttributeError, TypeError, IndexError):
        # AttributeError if tf_spec is None
        # TypeError if submodule_search_locations are None
        # IndexError if submodule_search_locations is an empty list
        if env.get("CIBUILDWHEEL", "0") == "1":
            cuda_version = env.get("CUDA_VERSION", "12.2")
            if cuda_version == "" or cuda_version in SpecifierSet(">=12,<13"):
                # CUDA 12.2
                requires.append(
                    "tensorflow-cpu>=2.15.0rc0; platform_machine=='x86_64' and platform_system == 'Linux'"
                )
            elif cuda_version in SpecifierSet(">=11,<12"):
                # CUDA 11.8
                requires.append(
                    "tensorflow-cpu>=2.5.0rc0,<2.15; platform_machine=='x86_64' and platform_system == 'Linux'"
                )
            else:
                raise RuntimeError("Unsupported CUDA version")
        requires.extend(get_tf_requirement()["cpu"])
        # setuptools will re-find tensorflow after installing setup_requires
        tf_install_dir = None
    return tf_install_dir, requires
@lru_cache()
def get_tf_requirement(tf_version: str = "") -> dict:
    """Build the TensorFlow requirement lists used when TF is not installed.

    If `tf_version` is empty, the environment variable `TENSORFLOW_VERSION`
    is consulted; if that is empty or unset too, the requirements are left
    unpinned.

    Parameters
    ----------
    tf_version : str, optional
        TF version

    Returns
    -------
    dict
        TensorFlow requirement. The keys ``"cpu"`` and ``"gpu"`` map to lists
        of requirement strings; the key ``"mpi"`` maps to the MPI extras
        (empty for TF versions older than 1.15).
    """
    if tf_version == "":
        tf_version = os.environ.get("TENSORFLOW_VERSION", "")
    unpinned = tf_version == ""

    def satisfies(spec: str) -> bool:
        # Pre-releases (e.g. "2.15.0rc0") count as matching versions.
        return tf_version in SpecifierSet(spec, prereleases=True)

    # A pinned TF older than 2.12 additionally gets an upper bound on protobuf.
    extra_requires = []
    if not unpinned and not satisfies(">=2.12"):
        extra_requires.append("protobuf<3.20")

    # MPI extras are offered for unpinned TF and for TF >= 1.15 only.
    if unpinned or satisfies(">=1.15"):
        mpi_requires = ["horovod", "mpi4py"]
    else:
        mpi_requires = []
    extra_select = {"mpi": mpi_requires}

    if unpinned:
        return {
            "cpu": [
                "tensorflow-cpu; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')",
                "tensorflow; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')",
                # https://github.com/tensorflow/tensorflow/issues/61830
                "tensorflow-cpu<2.15; platform_system=='Windows'",
                *extra_requires,
            ],
            "gpu": [
                "tensorflow",
                "tensorflow-metal; platform_machine=='arm64' and platform_system == 'Darwin'",
                *extra_requires,
            ],
            **extra_select,
        }

    # TF < 1.15 and TF 2.0.x are requested as `tensorflow` (CPU) and
    # `tensorflow-gpu` (GPU) rather than `tensorflow-cpu` / `tensorflow`.
    if satisfies("<1.15") or satisfies(">=2.0,<2.1"):
        return {
            "cpu": [
                f"tensorflow=={tf_version}",
                *extra_requires,
            ],
            "gpu": [
                f"tensorflow-gpu=={tf_version}; platform_machine!='aarch64'",
                f"tensorflow=={tf_version}; platform_machine=='aarch64'",
                *extra_requires,
            ],
            **extra_select,
        }

    return {
        "cpu": [
            f"tensorflow-cpu=={tf_version}; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')",
            f"tensorflow=={tf_version}; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')",
            *extra_requires,
        ],
        "gpu": [
            f"tensorflow=={tf_version}",
            "tensorflow-metal; platform_machine=='arm64' and platform_system == 'Darwin'",
            *extra_requires,
        ],
        **extra_select,
    }
@lru_cache()
def get_tf_version(tf_path: Union[str, Path]) -> str:
    """Get TF version from a TF Python library path.

    The version is read from the ``TF_MAJOR_VERSION``, ``TF_MINOR_VERSION``,
    ``TF_PATCH_VERSION`` and (optionally) ``TF_VERSION_SUFFIX`` macros in
    ``include/tensorflow/core/public/version.h`` under `tf_path`.

    Parameters
    ----------
    tf_path : str or Path
        TF Python library path

    Returns
    -------
    str
        version, formatted as ``major.minor.patch`` followed by the suffix;
        an empty string if `tf_path` is None or empty.

    Raises
    ------
    RuntimeError
        If the major, minor or patch macro is missing from the header.
    """
    if tf_path is None or tf_path == "":
        return ""
    version_file = (
        Path(tf_path) / "include" / "tensorflow" / "core" / "public" / "version.h"
    )
    major = minor = patch = None
    # The suffix macro is optional: default to "no suffix" so that a header
    # without TF_VERSION_SUFFIX does not leave the name unbound at the return.
    suffix = ""
    with open(version_file) as f:
        for line in f:
            if line.startswith("#define TF_MAJOR_VERSION"):
                major = line.split()[-1]
            elif line.startswith("#define TF_MINOR_VERSION"):
                minor = line.split()[-1]
            elif line.startswith("#define TF_PATCH_VERSION"):
                patch = line.split()[-1]
            elif line.startswith("#define TF_VERSION_SUFFIX"):
                # The value is a quoted C string literal, e.g. "-rc1" or "".
                suffix = line.split()[-1].strip('"')
    if None in (major, minor, patch):
        raise RuntimeError("Failed to read TF version")
    return ".".join((major, minor, patch)) + suffix