From 9fa3d27c4bdd534eaff88ea2c4a7119e3174dbbf Mon Sep 17 00:00:00 2001 From: Brian Lee Date: Wed, 6 Mar 2019 18:37:28 -0800 Subject: [PATCH] Add unicode string support to pyct pretty printer. PiperOrigin-RevId: 237161597 --- tensorflow/python/autograph/pyct/pretty_printer.py | 9 +++++++-- tensorflow/python/autograph/pyct/pretty_printer_test.py | 9 +++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py index a92017f4142f67..9a4f509ec36bd5 100644 --- a/tensorflow/python/autograph/pyct/pretty_printer.py +++ b/tensorflow/python/autograph/pyct/pretty_printer.py @@ -18,7 +18,9 @@ from __future__ import division from __future__ import print_function + import gast +import six import termcolor @@ -106,9 +108,12 @@ def generic_visit(self, node, name=None): self._print('%s%s=()' % (self._indent(), self._field(f))) elif isinstance(v, gast.AST): self.generic_visit(v, f) - elif isinstance(v, str): + elif isinstance(v, six.binary_type): + self._print('%s%s=%s' % (self._indent(), self._field(f), + self._value('b"%s"' % v))) + elif isinstance(v, six.text_type): self._print('%s%s=%s' % (self._indent(), self._field(f), - self._value('"%s"' % v))) + self._value('u"%s"' % v))) else: self._print('%s%s=%s' % (self._indent(), self._field(f), self._value(v))) diff --git a/tensorflow/python/autograph/pyct/pretty_printer_test.py b/tensorflow/python/autograph/pyct/pretty_printer_test.py index 1c76744547f584..26d70f2e6006fe 100644 --- a/tensorflow/python/autograph/pyct/pretty_printer_test.py +++ b/tensorflow/python/autograph/pyct/pretty_printer_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import ast +import textwrap from tensorflow.python.autograph.pyct import pretty_printer from tensorflow.python.platform import test @@ -26,6 +27,14 @@ class PrettyPrinterTest(test.TestCase): + def test_unicode_bytes(self): + source = textwrap.dedent(''' + def f(): + return b'b', u'u', 'depends_py2_py3' + ''') + node = ast.parse(source) + self.assertIsNotNone(pretty_printer.fmt(node)) + def test_format(self): node = ast.FunctionDef( name='f',