Skip to content

Commit

Permalink
Create util to instanciate objects and validate arguments (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
reluzita committed Feb 1, 2024
1 parent 489233c commit 69f7cff
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
27 changes: 4 additions & 23 deletions src/aequitas/flow/methods/preprocessing/label_flipping.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .preprocessing import PreProcessing

from ...utils import create_logger
from ...utils.imports import import_object
from ...utils.imports import instantiate_object

import inspect
import pandas as pd
import math
from typing import Optional, Tuple, Literal, Union, Callable
Expand Down Expand Up @@ -104,26 +103,8 @@ def __init__(

self.bagging_max_samples = bagging_max_samples

if isinstance(bagging_base_estimator, str):
bagging_base_estimator = import_object(bagging_base_estimator)
signature = inspect.signature(bagging_base_estimator)
if (
signature.parameters[list(signature.parameters.keys())[-1]].kind
== inspect.Parameter.VAR_KEYWORD
):
args = (
base_estimator_args # Estimator takes **kwargs, so all args are valid
)
else:
args = {
arg: value
for arg, value in base_estimator_args.items()
if arg in signature.parameters
}
self.bagging_base_estimator = bagging_base_estimator(**args)
self.logger.info(
f"Created base estimator {self.bagging_base_estimator} with params {args}, "
f"discarded args:{list(set(base_estimator_args.keys()) - set(args.keys()))}"
self.bagging_base_estimator = instantiate_object(
bagging_base_estimator, **base_estimator_args
)
self.bagging_n_estimators = bagging_n_estimators

Expand Down Expand Up @@ -317,7 +298,7 @@ def transform(
The transformed input, X, y, and s.
"""
super().transform(X, y, s)

self.logger.info("Transforming data with LabelFlipping.")

if s is None and self.fair_ordering:
Expand Down
33 changes: 32 additions & 1 deletion src/aequitas/flow/utils/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
from typing import Callable, Union
import inspect


def import_object(import_path: str) -> Union[object, Callable]:
Expand All @@ -20,7 +21,37 @@ def import_object(import_path: str) -> Union[object, Callable]:
"""
separator_idx = import_path.rindex(".")
module_path = import_path[:separator_idx]
obj_name = import_path[separator_idx + 1 :]
obj_name = import_path[separator_idx + 1:]

module = importlib.import_module(module_path)
return getattr(module, obj_name)


def instantiate_object(class_object: Union[str, Callable], **kwargs) -> object:
"""Instantiates an object by their classpath.
Parameters
----------
class_object : Union[str, Callable]
The classpath of the object to instantiate.
**kwargs
The keyword arguments to pass to the class constructor.
Returns
-------
object
The instantiated object.
"""

class_object = import_object(class_object)
signature = inspect.signature(class_object)
if (
signature.parameters[list(signature.parameters.keys())[-1]].kind
== inspect.Parameter.VAR_KEYWORD
):
args = kwargs
else:
args = {
arg: value for arg, value in kwargs.items() if arg in signature.parameters
}
return class_object(**args)

0 comments on commit 69f7cff

Please sign in to comment.