Skip to content

Commit 8135110

Browse files
hawkinsptensorflower-gardener
authored andcommitted
[XLA] Switch Python bindings to use pybind11 instead of SWIG.
* Simplifies the Python code by taking advantage of pybind11's greater flexibility in exposing a Pythonic API. * Pybind11 is also better able to handle features of modern C++. SWIG is incapable of parsing many common C++ idioms in XLA's interface, which forced us to create intermediate wrapper classes with SWIG-friendly APIs that do little other than call other XLA APIs (e.g., ComputationBuilder). In many cases, we can now elide these. In combination with the preceding point, the code becomes significantly shorter and more readable. * Pybind11's numpy integration makes it easy to construct xla Literals that alias numpy arrays. This reduces the number of copies when transferring a numpy array to, say, a GPU to 1 instead of 2; previously a NumPy<->Literal conversion required a copy. PiperOrigin-RevId: 241067931
1 parent 64c2d99 commit 8135110

22 files changed

+1878
-4395
lines changed

tensorflow/compiler/xla/python/BUILD

Lines changed: 49 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,16 @@ licenses(["notice"]) # Apache 2.0
22

33
package(default_visibility = ["//tensorflow:internal"])
44

5-
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
65
load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library")
76
load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins")
7+
load("//tensorflow:tensorflow.bzl", "tf_pybind_extension")
88

99
py_library(
1010
name = "xla_client",
1111
srcs = ["xla_client.py"],
1212
srcs_version = "PY2AND3",
1313
visibility = ["//visibility:public"],
14-
deps = [
15-
":pywrap_xla",
16-
],
14+
deps = [":xla_extension"],
1715
)
1816

1917
pyx_library(
@@ -37,31 +35,59 @@ py_test(
3735
)
3836

3937
cc_library(
40-
name = "numpy_bridge",
41-
srcs = ["numpy_bridge.cc"],
42-
hdrs = ["numpy_bridge.h"],
38+
name = "types",
39+
srcs = ["types.cc"],
40+
hdrs = ["types.h"],
41+
copts = [
42+
"-fexceptions",
43+
"-fno-strict-aliasing",
44+
"-Wno-c++98-c++11-compat",
45+
],
46+
features = ["-use_header_modules"],
4347
deps = [
4448
"//tensorflow/compiler/xla:literal",
45-
"//tensorflow/compiler/xla:literal_util",
4649
"//tensorflow/compiler/xla:shape_util",
50+
"//tensorflow/compiler/xla:status",
51+
"//tensorflow/compiler/xla:status_macros",
52+
"//tensorflow/compiler/xla:statusor",
53+
"//tensorflow/compiler/xla:types",
4754
"//tensorflow/compiler/xla:xla_data_proto",
55+
"//tensorflow/compiler/xla/service:device_memory_allocator",
4856
"//tensorflow/core:lib",
49-
"//tensorflow/python:numpy_lib",
50-
"@com_google_absl//absl/strings",
51-
"@com_google_absl//absl/strings:str_format",
52-
"@com_google_absl//absl/types:span",
57+
"@com_google_absl//absl/container:flat_hash_map",
58+
"@com_google_absl//absl/types:optional",
59+
"@pybind11",
5360
],
5461
)
5562

56-
cc_library(
57-
name = "local_computation_builder",
58-
srcs = ["local_computation_builder.cc"],
59-
hdrs = ["local_computation_builder.h"],
63+
tf_pybind_extension(
64+
name = "xla_extension",
65+
srcs = [
66+
"local_client.cc",
67+
"local_client.h",
68+
"xla.cc",
69+
],
70+
copts = [
71+
"-fexceptions",
72+
"-fno-strict-aliasing",
73+
"-Wno-c++98-c++11-compat",
74+
],
75+
features = ["-use_header_modules"],
76+
module_name = "xla_extension",
6077
deps = [
78+
":types",
79+
"@com_google_absl//absl/memory",
80+
"@com_google_absl//absl/strings",
81+
"@com_google_absl//absl/types:optional",
82+
"@com_google_absl//absl/types:span",
83+
"@pybind11",
84+
"//third_party/python_runtime:headers", # buildcleaner: keep
6185
"//tensorflow/compiler/xla:executable_run_options",
6286
"//tensorflow/compiler/xla:literal",
6387
"//tensorflow/compiler/xla:literal_util",
6488
"//tensorflow/compiler/xla:shape_util",
89+
"//tensorflow/compiler/xla:status",
90+
"//tensorflow/compiler/xla:statusor",
6591
"//tensorflow/compiler/xla:util",
6692
"//tensorflow/compiler/xla:xla_data_proto",
6793
"//tensorflow/compiler/xla/client:client_library",
@@ -77,84 +103,14 @@ cc_library(
77103
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
78104
"//tensorflow/compiler/xla/service:platform_util",
79105
"//tensorflow/compiler/xla/service:shaped_buffer",
106+
"//tensorflow/compiler/xla/service:cpu_plugin",
80107
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
81108
"//tensorflow/core:lib",
82-
"//third_party/python_runtime:headers", # buildcleaner: keep
83-
"@com_google_absl//absl/memory",
84-
"@com_google_absl//absl/types:span",
85-
],
86-
)
87-
88-
cc_library(
89-
name = "xrt",
90-
srcs = ["xrt.cc"],
91-
hdrs = ["xrt.h"],
92-
deps = [
93-
"//tensorflow/cc:cc_ops",
94-
"//tensorflow/cc:client_session",
95-
"//tensorflow/cc:ops",
96-
"//tensorflow/cc:scope",
97-
"//tensorflow/compiler/xla:literal",
98-
"//tensorflow/compiler/xla:literal_util",
99-
"//tensorflow/compiler/xla:shape_util",
100-
"//tensorflow/compiler/xla:util",
101-
"//tensorflow/compiler/xla:xla_data_proto",
102-
"//tensorflow/compiler/xla/service:hlo_proto",
103-
"//tensorflow/compiler/xla/service:platform_util",
104-
"//tensorflow/compiler/xrt:xrt_proto",
105-
"//tensorflow/compiler/xrt/cc:xrt_ops",
106-
"//tensorflow/core:framework",
107-
"//tensorflow/core:lib",
108-
"@com_google_absl//absl/memory",
109-
"@com_google_absl//absl/types:span",
110-
],
111-
)
112-
113-
tf_py_wrap_cc(
114-
name = "pywrap_xla",
115-
srcs = [
116-
"xla.i",
117-
],
118-
swig_includes = [
119-
"local_computation_builder.i",
120-
"xla_data.i",
121-
"//tensorflow/python:platform/base.i",
122-
],
123-
version_script = select({
124-
"//tensorflow:macos": "pywrap_xla_exported_symbols.lds",
125-
"//tensorflow:windows": None,
126-
"//conditions:default": "pywrap_xla_version_script.lds",
127-
}),
128-
deps = [
129-
":local_computation_builder",
130-
":numpy_bridge",
131-
"//tensorflow/compiler/xla:literal",
132-
"//tensorflow/compiler/xla:shape_util",
133-
"//tensorflow/compiler/xla:xla_data_proto",
134-
"//tensorflow/compiler/xla/service:cpu_plugin",
109+
# Do NOT remove this dependency. The XLA Python extension must not
110+
# depend on any part of TensorFlow at runtime, **including**
111+
# libtensorflow_framework.so. The XLA module is deployed self-contained
112+
# without any TF dependencies as "jaxlib" on Pypi, and "jaxlib" does
113+
# not require Tensorflow.
114+
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
135115
] + xla_python_default_plugins(),
136116
)
137-
138-
tf_py_wrap_cc(
139-
name = "pywrap_xrt",
140-
srcs = [
141-
"xrt.i",
142-
],
143-
swig_includes = [
144-
"xla_data.i",
145-
"//tensorflow/python:platform/base.i",
146-
],
147-
version_script = select({
148-
"//tensorflow:macos": "pywrap_xla_exported_symbols.lds",
149-
"//tensorflow:windows": None,
150-
"//conditions:default": "pywrap_xla_version_script.lds",
151-
}),
152-
visibility = ["//visibility:public"],
153-
deps = [
154-
":numpy_bridge",
155-
":xrt",
156-
"//tensorflow/compiler/xla:literal",
157-
"//tensorflow/compiler/xla:shape_util",
158-
"//tensorflow/compiler/xla:xla_data_proto",
159-
],
160-
)

0 commit comments

Comments
 (0)