Skip to content

Commit

Permalink
Add the "map" and "fixed" avro types. (#31)
Browse files Browse the repository at this point in the history
* Add the "map" and "fixed" avro types.

* Blackened the incoming fork
  • Loading branch information
wcn00 committed Nov 1, 2021
1 parent a16b63c commit 5c7110e
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 58 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# Changelog

## 0.12.0
### Added
- Support for the "map" and "fixed" avro types.
### Changed
- Alphabetically sort generated imports.
- Gather imports from the same module onto a single line.
### Fixed
- An issue retrieving items from array types within lists of types.

## 0.11.2
### Changed
Expand Down
2 changes: 1 addition & 1 deletion avro_to_python_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.2"
__version__ = "0.12.0"
from .typed_dict_from_schema import (
typed_dict_from_schema_file,
typed_dict_from_schema_string,
Expand Down
2 changes: 2 additions & 0 deletions avro_to_python_types/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
TYPE = "type"
ITEMS = "items"
LIST = "List"
VALUES = "values"
DICT = "Dict"
149 changes: 106 additions & 43 deletions avro_to_python_types/typed_dict_from_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .constants import OPTIONAL
from .generate_typed_dict import GenerateTypedDict
from .schema_mapping import prim_to_type, logical_to_python_type
from collections import defaultdict
from enum import Enum
from fastavro.schema import (
expand_schema,
Expand All @@ -23,13 +24,17 @@
TYPE,
ITEMS,
LIST,
VALUES,
DICT,
)


class AvroSubType(Enum):
ENUM = "enum"
RECORD = "record"
ARRAY = "array"
MAP = "map"
FIXED = "fixed"


def is_nullable(field):
Expand Down Expand Up @@ -121,14 +126,30 @@ def get_array_items(array_type):
"""
if isinstance(array_type, list):
for list_type in list(array_type):
if isinstance(list_type, dict) and ITEMS in array_type:
if isinstance(list_type, dict) and ITEMS in list_type:
return list_type[ITEMS]
elif isinstance(array_type, dict) and ITEMS in array_type:
return array_type[ITEMS]
else:
raise Exception("invalid schema, array type has no items")


def get_map_values(map_type):
"""
Return the item type for the map. Either a reference to a composite obj
or a primative
"""
if isinstance(map_type, list):
for list_type in list(map_type):
if isinstance(list_type, dict) and VALUES in list_type:
return list_type[VALUES]
elif isinstance(map_type, dict) and VALUES in map_type:
return map_type[VALUES]
else:
raise Exception("invalid schema, map type has no values")


def get_logical_type(types):
"""
Logical types can be dates, datetimes, UUIDs etc.
Expand Down Expand Up @@ -185,9 +206,7 @@ def types_for_schema(schema):
tree = ast.Module(body)
body = tree.body

def type_for_schema_record(
record_schema, imports, enums, complex_types, import_flags
):
def type_for_schema_record(record_schema, imports, enums, complex_types):
type_name = "".join(
word[0].upper() + word[1:] for word in record_schema["name"].split(".")
)
Expand All @@ -206,24 +225,24 @@ def type_for_schema_record(
"""
union_field = get_union_type(field[TYPE])
nested = type_for_schema_record(
union_field, imports, enums, complex_types, import_flags
union_field, imports, enums, complex_types
)
body.append(nested.tree)
if is_nullable(field):
our_type.add_optional_element(name, nested.name)
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, nested.name)
complex_types.append(nested.name)
elif field_type_is_of_type(field[TYPE], AvroSubType.RECORD.value):
"""nested - This processes an expanded nested type recursively."""
nested = type_for_schema_record(
field[TYPE], imports, enums, complex_types, import_flags
field[TYPE], imports, enums, complex_types
)
body.append(nested.tree)
if is_nullable(field):
our_type.add_optional_element(name, nested.name)
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, nested.name)
complex_types.append(nested.name)
Expand All @@ -232,14 +251,11 @@ def type_for_schema_record(
importing packages like date, datetime, uuid and decimal hence the
imports collection"""
logical_type = logical_to_python_type[get_logical_type(field[TYPE])]
imports.append(
"from {} import {}\n".format(
logical_type.split(".")[0], logical_type.split(".")[1]
)
)
module, class_import = logical_type.split(".")
imports[module].add(class_import)
if is_nullable(field):
our_type.add_optional_element(name, logical_type.split(".")[1])
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, logical_type.split(".")[1])
elif field_type_is_of_type(field[TYPE], AvroSubType.ENUM.value):
Expand All @@ -248,9 +264,7 @@ def type_for_schema_record(
different schemas will result in that enum being duplicated, but
with a different name. We can revisit that if necessary.
"""
imports.append(
"from {} import {}\n".format(AvroSubType.ENUM.value, ENUM_CLASS)
)
imports[AvroSubType.ENUM.value].add(ENUM_CLASS)
""" The enum class name is composed the same way as the typedict
name is """
enum_class_name = "".join(
Expand All @@ -265,7 +279,7 @@ def type_for_schema_record(
enums[enum_class] = enum_class
if is_nullable(field):
our_type.add_optional_element(name, enum_class_name)
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, enum_class_name)
complex_types.append(enum_class_name)
Expand All @@ -277,17 +291,17 @@ def type_for_schema_record(
if field_type_is_of_type(items_type, AvroSubType.RECORD.value):
"""Arrays is for a complex nested type"""
nested = type_for_schema_record(
items_type, imports, enums, complex_types, import_flags
items_type, imports, enums, complex_types
)
body.append(nested.tree)
if is_nullable(field):
our_type.add_optional_element(name, f"List[{nested.name}]")
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, f"List[{nested.name}]")
complex_types.append(nested.name)
else:
"""Array is of a prmitive type"""
"""Array is of a primitive type"""
if not items_type in prim_to_type.keys():
items_type_name = "".join(
word[0].upper() + word[1:]
Expand All @@ -302,10 +316,65 @@ def type_for_schema_record(
array_type = prim_to_type[items_type]
if is_nullable(field):
our_type.add_optional_element(name, f"List[{array_type}]")
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, f"List[{array_type}]")
import_flags[LIST] = True
imports["typing"].add(LIST)
# map
elif field_type_is_of_type(field[TYPE], AvroSubType.MAP.value):
"""Map types are either primitive or complex."""
values_type = get_map_values(field[TYPE])
if field_type_is_of_type(values_type, AvroSubType.RECORD.value):
"""Map is for a complex nested type"""
nested = type_for_schema_record(
values_type, imports, enums, complex_types
)
body.append(nested.tree)
if is_nullable(field):
"""Avro map keys are always strings."""
our_type.add_optional_element(
name, f"Dict[str, {nested.name}]"
)
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(
name, f"Dict[str, {nested.name}]"
)
complex_types.append(nested.name)
else:
"""Map is of a primitive type"""
if not values_type in prim_to_type.keys():
values_type_name = "".join(
word[0].upper() + word[1:]
for word in values_type.split(".")
)
array_type = (
values_type_name
if values_type_name in complex_types
else prim_to_type[values_type]
)
else:
array_type = prim_to_type[values_type]
if is_nullable(field):
our_type.add_optional_element(
name, f"Dict[str, {array_type}]"
)
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(
name, f"Dict[str, {array_type}]"
)
imports["typing"].add(DICT)
# fixed
elif field_type_is_of_type(field[TYPE], AvroSubType.FIXED.value):
"""Fixed types are simply represented as bytes. The size
field is ignored because size checking is not possible with
Python typing."""
if is_nullable(field):
our_type.add_optional_element(name, "bytes")
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, "bytes")
# primitive
else:
"""Ths section process a primitive type or a named complex type."""
Expand All @@ -328,7 +397,7 @@ def type_for_schema_record(
reference_type = prim_to_type[field_type]
if is_nullable(field):
our_type.add_optional_element(name, reference_type)
import_flags[OPTIONAL] = True
imports["typing"].add(OPTIONAL)
else:
our_type.add_required_element(name, reference_type)
except Exception as e:
Expand All @@ -341,31 +410,25 @@ def type_for_schema_record(
)
return our_type

imports = []
imports = defaultdict(set)
imports["typing"].add("TypedDict")
enums = {}
complex_types = []
import_flags = {OPTIONAL: False, LIST: False}
main_type = type_for_schema_record(
schema, imports, enums, complex_types, import_flags
)

additional_types = []
# import the Optional type only if required
if import_flags[OPTIONAL]:
additional_types.append(OPTIONAL)
if import_flags[LIST]:
additional_types.append(LIST)
additional_types.append("TypedDict")
additional_types_as_str = ", ".join(additional_types)
main_type = type_for_schema_record(schema, imports, enums, complex_types)

imports.append(f"from typing import {additional_types_as_str}\n")
import_lines = []
for module, classes in imports.items():
classes = ", ".join(sorted(classes))
import_lines.append(f"from {module} import {classes}")
import_code = "\n".join(sorted(import_lines))

body.append(main_type.tree)
imports = sorted(list(set(imports)))
generated_code = (
"".join(imports) + resolve_enum_str(enums) + ast.unparse(_dedupe_ast(tree).body)

generated_code = "\n".join(
[import_code, resolve_enum_str(enums), ast.unparse(_dedupe_ast(tree).body)]
)
formatted_code = black.format_str(generated_code, mode=black.FileMode())

formatted_code = black.format_str(generated_code, mode=black.Mode())
return formatted_code


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "avro-to-python-types"
version = "0.11.2"
version = "0.12.0"
description = "A library for converting avro schemas to python types."
readme = "README.md"
authors = ["Dan Green-Leipciger"]
Expand Down

0 comments on commit 5c7110e

Please sign in to comment.