Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add title argument to plotting function #363

Open
wants to merge 1 commit into
base: plotting-ax
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 5 additions & 4 deletions Python-packages/covidcast-py/covidcast/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def plot(data: pd.DataFrame,
plot_type: str = "choropleth",
combine_megacounties: bool = True,
ax: axes.Axes = None,
title: str = None,
**kwargs: Any) -> axes.Axes:
"""Given the output data frame of :py:func:`covidcast.signal`, plot a choropleth or bubble map.

Expand Down Expand Up @@ -81,8 +82,8 @@ def plot(data: pd.DataFrame,
:param kwargs: Optional keyword arguments passed to ``GeoDataFrame.plot()``.
:param plot_type: Type of plot to create. Either choropleth (default) or bubble map.
:param ax: Optional matplotlib axis to plot on.
:return: Matplotlib axes object.

:param title: Plot title. If not provided, will default to "source: signal, day"
:return: Matplotlib figure object.
"""
if plot_type not in {"choropleth", "bubble"}:
raise ValueError("`plot_type` must be 'choropleth' or 'bubble'.")
Expand All @@ -93,10 +94,10 @@ def plot(data: pd.DataFrame,
day_data = data.loc[data.time_value == pd.to_datetime(day_to_plot), :]
kwargs["vmax"] = kwargs.get("vmax", meta["mean_value"] + 3 * meta["stdev_value"])
kwargs["figsize"] = kwargs.get("figsize", (12.8, 9.6))

ax = _plot_background_states(kwargs["figsize"]) if ax is None else ax
ax.axis("off")
ax.set_title(f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}")
ax.set_title(
f"{data_source}: {signal}, {day_to_plot.strftime('%Y-%m-%d')}" if title is None else title)
if plot_type == "choropleth":
_plot_choro(ax, day_data, combine_megacounties, **kwargs)
else:
Expand Down
9 changes: 7 additions & 2 deletions Python-packages/covidcast-py/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def _convert_to_array(fig: matplotlib.figure.Figure) -> np.array:
@patch("covidcast.plotting._signal_metadata")
def test_plot(mock_metadata):
mock_metadata.side_effect = [
{"mean_value": 0.5330011, "stdev_value": 0.4683431}, # county metadata
{"mean_value": 0.5330011, "stdev_value": 0.4683431},
{"mean_value": 0.5330011, "stdev_value": 0.4683431},
{"mean_value": 0.5330011, "stdev_value": 0.4683431},
{"mean_value": 0.5304083, "stdev_value": 0.235302},
{"mean_value": 0.5304083, "stdev_value": 0.235302}, # state metadata
{"mean_value": 0.5705364, "stdev_value": 0.4348706}, # msa metadata
{"mean_value": 0.5705364, "stdev_value": 0.4348706},
{"mean_value": 0.5705364, "stdev_value": 0.4348706},
]
Expand Down Expand Up @@ -92,6 +93,10 @@ def test_plot(mock_metadata):
msa_bubble_fig = plt.gcf()
assert np.allclose(_convert_to_array(msa_bubble_fig), expected["msa_bubble"], atol=2, rtol=0)

# test title
ax = plotting.plot(test_msa, title="test")
assert ax.title.get_text() == "test"


def test_get_geo_df():
test_input = pd.DataFrame({"geo_value": ["24510", "31169", "37000"],
Expand Down