Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(eda): make comp. and plot API consistent #922

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataprep/eda/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def plot(
cfg = Config.from_dict(display, config)

with ProgressBar(minimum=1, disable=not progress):
itmdt = compute(df, col1, col2, col3, cfg=cfg, dtype=dtype)
itmdt = compute(df, col1, col2, col3, config=cfg, dtype=dtype)

to_render = render(itmdt, cfg)

Expand Down
13 changes: 6 additions & 7 deletions dataprep/eda/distribution/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def compute(
col2: Optional[Union[str, LatLong]] = None,
col3: Optional[str] = None,
*,
cfg: Union[Config, Dict[str, Any], None] = None,
config: Union[Config, Dict[str, Any], None] = None,
display: Optional[List[str]] = None,
dtype: Optional[DTypeDef] = None,
) -> Intermediate:
Expand All @@ -36,10 +36,10 @@ def compute(
----------
df
DataFrame from which visualizations are generated
cfg: Union[Config, Dict[str, Any], None], default None
config: Union[Config, Dict[str, Any], None], default None
When a user calls plot(), the created Config object will be passed to compute().
When a user calls compute() directly, if he/she wants to customize the output,
cfg is a dictionary for configuring. If not, cfg is None and
config is a dictionary for configuring. If not, config is None and
default values will be used for parameters.
display: Optional[List[str]], default None
A list containing the names of the visualizations to display. Only exists when
Expand All @@ -60,10 +60,9 @@ def compute(

suppress_warnings()

if isinstance(cfg, dict):
cfg = Config.from_dict(display, cfg)

elif not cfg:
if isinstance(config, dict):
cfg = Config.from_dict(display, config)
else:
cfg = Config()

x, y, z = col1, col2, col3
Expand Down
31 changes: 15 additions & 16 deletions dataprep/eda/distribution/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,43 +2455,42 @@ def render_dt_num_cat(itmdt: Intermediate, cfg: Config) -> Dict[str, Any]:
}


def render(itmdt: Intermediate, cfg: Config) -> Union[LayoutDOM, Dict[str, Any]]:
def render(itmdt: Intermediate, config: Config) -> Union[LayoutDOM, Dict[str, Any]]:
"""
Render a basic plot
Parameters
----------
itmdt
The Intermediate containing results from the compute function.
cfg
config
Config instance
"""
# pylint: disable = too-many-branches

if itmdt.visual_type == "distribution_grid":
visual_elem = render_distribution_grid(itmdt, cfg)
visual_elem = render_distribution_grid(itmdt, config)
elif itmdt.visual_type == "categorical_column":
visual_elem = render_cat(itmdt, cfg)
visual_elem = render_cat(itmdt, config)
elif itmdt.visual_type == "geography_column":
visual_elem = render_geo(itmdt, cfg)
visual_elem = render_geo(itmdt, config)
elif itmdt.visual_type == "numerical_column":
visual_elem = render_num(itmdt, cfg)
visual_elem = render_num(itmdt, config)
elif itmdt.visual_type == "datetime_column":
visual_elem = render_dt(itmdt, cfg)
visual_elem = render_dt(itmdt, config)
elif itmdt.visual_type == "cat_and_num_cols":
visual_elem = render_cat_num(itmdt, cfg)
visual_elem = render_cat_num(itmdt, config)
elif itmdt.visual_type == "geo_and_num_cols":
visual_elem = render_geo_num(itmdt, cfg)
visual_elem = render_geo_num(itmdt, config)
elif itmdt.visual_type == "latlong_and_num_cols":
visual_elem = render_latlong_num(itmdt, cfg)
visual_elem = render_latlong_num(itmdt, config)
elif itmdt.visual_type == "two_num_cols":
visual_elem = render_two_num(itmdt, cfg)
visual_elem = render_two_num(itmdt, config)
elif itmdt.visual_type == "two_cat_cols":
visual_elem = render_two_cat(itmdt, cfg)
visual_elem = render_two_cat(itmdt, config)
elif itmdt.visual_type == "dt_and_num_cols":
visual_elem = render_dt_num(itmdt, cfg)
visual_elem = render_dt_num(itmdt, config)
elif itmdt.visual_type == "dt_and_cat_cols":
visual_elem = render_dt_cat(itmdt, cfg)
visual_elem = render_dt_cat(itmdt, config)
elif itmdt.visual_type == "dt_cat_num_cols":
visual_elem = render_dt_num_cat(itmdt, cfg)
visual_elem = render_dt_num_cat(itmdt, config)

return visual_elem