Skip to content

Commit

Permalink
Paint it black.
Browse files Browse the repository at this point in the history
  • Loading branch information
FlipperPA committed Feb 22, 2022
1 parent ec0f5bf commit e8184d2
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 35 deletions.
13 changes: 10 additions & 3 deletions drf_excel/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,19 @@ def _parse_date(self, value, setting_format, iso_parse_func):
def init_value(self, value):
# Set tzinfo to None on datetime and time types since timezones are not supported in Excel
try:
if isinstance(self.drf_field, DateTimeField) and type(value) != datetime.datetime:
return self._parse_date(value, "DATETIME_FORMAT", parse_datetime).replace(tzinfo=None)
if (
isinstance(self.drf_field, DateTimeField)
and type(value) != datetime.datetime
):
return self._parse_date(
value, "DATETIME_FORMAT", parse_datetime
).replace(tzinfo=None)
elif isinstance(self.drf_field, DateField) and type(value) != datetime.date:
return self._parse_date(value, "DATE_FORMAT", parse_date)
elif isinstance(self.drf_field, TimeField) and type(value) != datetime.time:
return self._parse_date(value, "TIME_FORMAT", parse_time).replace(tzinfo=None)
return self._parse_date(value, "TIME_FORMAT", parse_time).replace(
tzinfo=None
)
except:
return value

Expand Down
114 changes: 87 additions & 27 deletions drf_excel/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from rest_framework.renderers import BaseRenderer
from rest_framework.serializers import Serializer

from drf_excel.fields import XLSXBooleanField, XLSXDateField, XLSXField, XLSXListField, XLSXNumberField
from drf_excel.fields import (
XLSXBooleanField,
XLSXDateField,
XLSXField,
XLSXListField,
XLSXNumberField,
)
from drf_excel.utilities import XLSXStyle, set_cell_style


Expand Down Expand Up @@ -88,11 +94,15 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
if img_addr:
img = Image(img_addr)
self.ws.add_image(img, "A1")
header_style = XLSXStyle(header.get("style")) if header and "style" in header else None
header_style = (
XLSXStyle(header.get("style")) if header and "style" in header else None
)

column_header = get_attribute(drf_view, "column_header", {})
column_header_style = (
XLSXStyle(column_header.get("style")) if column_header and "style" in column_header else None
XLSXStyle(column_header.get("style"))
if column_header and "style" in column_header
else None
)
column_count = 0
row_count = 1
Expand All @@ -117,15 +127,17 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

# Set dict named column_data_styles with headers as keys and style as value. i.e.
# column_data_styles = {
# 'distance': {
# 'fill': {'fill_type': 'solid', 'start_color': 'FFCCFFCC'},
# 'alignment': {'horizontal': 'center', 'vertical': 'center', 'wrapText': True, 'shrink_to_fit': True},
# 'border_side': {'border_style': 'thin', 'color': 'FF000000'},
# 'font': {'name': 'Arial', 'size': 14, 'bold': True, 'color': 'FF000000'},
# 'format': '0.00E+00'
# },
# }
self.column_data_styles = get_attribute(drf_view, "column_data_styles", dict())
# 'distance': {
# 'fill': {'fill_type': 'solid', 'start_color': 'FFCCFFCC'},
# 'alignment': {'horizontal': 'center', 'vertical': 'center', 'wrapText': True, 'shrink_to_fit': True},
# 'border_side': {'border_style': 'thin', 'color': 'FF000000'},
# 'font': {'name': 'Arial', 'size': 14, 'bold': True, 'color': 'FF000000'},
# 'format': '0.00E+00'
# },
# }
self.column_data_styles = get_attribute(
drf_view, "column_data_styles", dict()
)

# Set dict of additional columns. Can be useful when wanting to add columns
# that don't exist in the API response. For example, you could want to
Expand All @@ -147,12 +159,17 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

self.fields_dict = self._serializer_fields(drf_view.get_serializer())

xlsx_header_dict = self._flatten_serializer_keys(drf_view.get_serializer(), use_labels=use_labels)
xlsx_header_dict = self._flatten_serializer_keys(
drf_view.get_serializer(), use_labels=use_labels
)
if self.custom_cols:
custom_header_dict = {
key: self.custom_cols[key].get("label", None) or key for key in self.custom_cols.keys()
key: self.custom_cols[key].get("label", None) or key
for key in self.custom_cols.keys()
}
self.combined_header_dict = dict(list(xlsx_header_dict.items()) + list(custom_header_dict.items()))
self.combined_header_dict = dict(
list(xlsx_header_dict.items()) + list(custom_header_dict.items())
)
else:
self.combined_header_dict = xlsx_header_dict

Expand All @@ -165,7 +182,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
else:
column_name_display = column_titles[column_count - 1]

header_cell = self.ws.cell(row=row_count, column=column_count, value=column_name_display)
header_cell = self.ws.cell(
row=row_count, column=column_count, value=column_name_display
)
set_cell_style(header_cell, column_header_style)
self.ws.row_dimensions[row_count].height = column_header.get("height", 45)

Expand Down Expand Up @@ -193,7 +212,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):

# Make body
body = get_attribute(drf_view, "body", {})
self.body_style = XLSXStyle(body.get("style")) if body and "style" in body else None
self.body_style = (
XLSXStyle(body.get("style")) if body and "style" in body else None
)
if isinstance(results, dict):
self._make_body(body, results, row_count)
elif isinstance(results, list):
Expand All @@ -220,7 +241,14 @@ def _serializer_fields(self, serializer, parent_key="", key_sep="."):
return _fields_dict

def _flatten_serializer_keys(
self, serializer, parent_key="", parent_label="", key_sep=".", list_sep=", ", label_sep=" > ", use_labels=False
self,
serializer,
parent_key="",
parent_label="",
key_sep=".",
list_sep=", ",
label_sep=" > ",
use_labels=False,
):
"""
Iterate through serializer fields recursively when field is a nested serializer.
Expand Down Expand Up @@ -248,13 +276,24 @@ def _get_label(parent_label, label_sep, obj):
if use_labels and getattr(v, "label", None):
_header_dict.update(
self._flatten_serializer_keys(
v, new_key, _get_label(parent_label, label_sep, v), key_sep, list_sep, label_sep, use_labels
v,
new_key,
_get_label(parent_label, label_sep, v),
key_sep,
list_sep,
label_sep,
use_labels,
)
)
else:
_header_dict.update(
self._flatten_serializer_keys(
v, new_key, key_sep=key_sep, list_sep=list_sep, label_sep=label_sep, use_labels=use_labels
v,
new_key,
key_sep=key_sep,
list_sep=list_sep,
label_sep=label_sep,
use_labels=use_labels,
)
)
elif isinstance(v, Field):
Expand Down Expand Up @@ -285,34 +324,55 @@ def _make_body(self, body, row, row_count):
continue
column_count += 1
field = flattened_row.get(header_key)
field.cell(self.ws, row_count, column_count) if field else self.ws.cell(row_count, column_count)
field.cell(self.ws, row_count, column_count) if field else self.ws.cell(
row_count, column_count
)
self.ws.row_dimensions[row_count].height = body.get("height", 40)
if "row_color" in row:
last_letter = get_column_letter(column_count)
cell_range = self.ws["A{}".format(row_count) : "{}{}".format(last_letter, row_count)]
cell_range = self.ws[
"A{}".format(row_count) : "{}{}".format(last_letter, row_count)
]
fill = PatternFill(fill_type="solid", start_color=row["row_color"])
for r in cell_range:
for c in r:
c.fill = fill

def _drf_to_xlsx_field(self, key, value) -> XLSXField:
field = self.fields_dict.get(key)
cell_style = XLSXStyle(self.column_data_styles.get(key)) if key in self.column_data_styles else None
cell_style = (
XLSXStyle(self.column_data_styles.get(key))
if key in self.column_data_styles
else None
)
kwargs = {
"key": key,
"value": value,
"field": field,
"style": self.body_style,
# Basically using formatter of custom col as a custom mapping
"mapping": self.custom_cols.get(key, {}).get("formatter") or self.custom_mappings.get(key),
"mapping": self.custom_cols.get(key, {}).get("formatter")
or self.custom_mappings.get(key),
"cell_style": cell_style,
}
if isinstance(field, BooleanField) or isinstance(field, NullBooleanField):
return XLSXBooleanField(boolean_display=self.boolean_display, **kwargs)
elif isinstance(field, IntegerField) or isinstance(field, FloatField) or isinstance(field, DecimalField):
elif (
isinstance(field, IntegerField)
or isinstance(field, FloatField)
or isinstance(field, DecimalField)
):
return XLSXNumberField(**kwargs)
elif isinstance(field, DateTimeField) or isinstance(field, DateField) or isinstance(field, TimeField):
elif (
isinstance(field, DateTimeField)
or isinstance(field, DateField)
or isinstance(field, TimeField)
):
return XLSXDateField(**kwargs)
elif isinstance(field, ListField) or isinstance(value, Iterable) and not isinstance(value, str):
elif (
isinstance(field, ListField)
or isinstance(value, Iterable)
and not isinstance(value, str)
):
return XLSXListField(list_sep=self.list_sep, **kwargs)
return XLSXField(**kwargs)
24 changes: 19 additions & 5 deletions drf_excel/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,23 @@ def __init__(self, style_dict=None):
if style_dict is None:
style_dict = {}
self.font = Font(**style_dict.get("font")) if "font" in style_dict else None
self.fill = PatternFill(**style_dict.get("fill")) if "fill" in style_dict else None
self.alignment = Alignment(**style_dict.get("alignment")) if "alignment" in style_dict else None
self.fill = (
PatternFill(**style_dict.get("fill")) if "fill" in style_dict else None
)
self.alignment = (
Alignment(**style_dict.get("alignment"))
if "alignment" in style_dict
else None
)
self.number_format = style_dict.get("format", None)
side = Side(**style_dict.get("border_side")) if "border_side" in style_dict else None
self.border = Border(left=side, right=side, top=side, bottom=side) if side else None
side = (
Side(**style_dict.get("border_side"))
if "border_side" in style_dict
else None
)
self.border = (
Border(left=side, right=side, top=side, bottom=side) if side else None
)


def get_setting(key, default=None):
Expand All @@ -49,7 +61,9 @@ def sanitize_value(value):
# prepend ' if value is starting with possible malicious char
if value:
str_value = str(value)
str_value = ILLEGAL_CHARACTERS_RE.sub("", str_value) # remove ILLEGAL_CHARACTERS so it doesn't crash
str_value = ILLEGAL_CHARACTERS_RE.sub(
"", str_value
) # remove ILLEGAL_CHARACTERS so it doesn't crash
return "'" + str_value if str_value.startswith(ESCAPE_CHARS) else str_value
return value

Expand Down

0 comments on commit e8184d2

Please sign in to comment.