Skip to content

Commit

Permalink
Merge pull request #186 from dssg/rework-label-flipping-inference
Browse files Browse the repository at this point in the history
Update label flipping method
  • Loading branch information
reluzita committed Mar 25, 2024
2 parents 270da37 + 0515a2e commit a61653f
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions src/aequitas/flow/methods/preprocessing/label_flipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _score_instances(self, X: pd.DataFrame, y: pd.Series) -> pd.Series:

def _calculate_group_flips(self, y: pd.Series, s: pd.Series):
prevalence = y.mean()
group_prevalences = y.groupby(s).mean()
group_prevalences = y.groupby(s, observed=True).mean()

min_prevalence = prevalence - self.disparity_target * prevalence
max_prevalence = prevalence + self.disparity_target * prevalence
Expand Down Expand Up @@ -259,11 +259,11 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
y_flipped : pd.Series
The transformed label vector.
"""
y_flipped = y.reindex(
y_flipped = y.loc[
scores.sort_values(
ascending=(self.ordering_method == "ensemble_margin")
).index
)
]
n_flip = int(self.max_flip_rate * len(y))

if self.fair_ordering:
Expand All @@ -273,26 +273,30 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie
if self.ordering_method == "residuals"
else y_flipped.loc[scores <= 0].index
)
flip_count = 0

for i in flip_index:
if abs(scores.loc[i]) < self.score_threshold:
break

if (group_flips[s.loc[i]] > 0 and y.loc[i] == 0) or (
group_flips[s.loc[i]] < 0 and y.loc[i] == 1
):
y_flipped.loc[i] = 1 - y.loc[i]
flip_count += 1
if group_flips[s.loc[i]] > 0:
group_flips[s.loc[i]] -= 1
else:
group_flips[s.loc[i]] += 1

if flip_count == n_flip:
break

self.logger.info(f"Flipped {flip_count} instances.")
to_flip = pd.Series(index=flip_index).fillna(False).loc[y_flipped.index]
for group, flips in group_flips.items():
if flips > 0:
labels = y == 0
elif flips < 0:
labels = y == 1
else:
labels = y == -1 # To keep everything false.
group_instances = s == group
intersection = (group_instances & labels).loc[y_flipped.index]
# find and keep first "flips" instances to flip as true, rest as false
true_indices = intersection[intersection].index
if len(true_indices) > abs(flips):
intersection[true_indices[abs(flips) :]] = False
to_flip = to_flip | intersection
# Check if we are flipping more than n_flip
true_indices = to_flip[to_flip].index
if to_flip.sum() > n_flip:
to_flip[true_indices[n_flip:]] = False
y_flipped[to_flip] = 1 - y_flipped[to_flip]

differences = (y_flipped.loc[y.index] != y).sum()
# Check the indexes of instances that were flipped
self.logger.info(f"Flipped {differences} instances.")

else:
n_above_threshold = scores.loc[abs(scores) >= self.score_threshold].shape[0]
Expand All @@ -302,7 +306,7 @@ def _label_flipping(self, y: pd.Series, s: Optional[pd.Series], scores: pd.Serie

self.logger.info(f"Flipped {n_flip} instances.")

return y_flipped.reindex(y.index)
return y_flipped.loc[y.index]

def transform(
self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]
Expand Down

0 comments on commit a61653f

Please sign in to comment.