
Rules list cutoffs are not printed in string representations of GreedyRulesListClassifier #169

Open
davidefiocco opened this issue Mar 15, 2023 · 2 comments

Comments

@davidefiocco
Contributor

davidefiocco commented Mar 15, 2023

When GreedyRuleListClassifier is trained on float features and the fitted classifier clf is printed, the cutoff values are not shown, which makes the model confusing to interpret. Here's an example:

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import RocCurveDisplay
import imodels

y = np.random.randint(2, size = 1000)
x1 = y + np.random.rand(1000)*1.2
x2 = y + np.random.rand(1000)*1.2

dataset = pd.DataFrame({"x1": x1, "x2": x2, "y": y})

X_train, X_test, y_train, y_test = train_test_split(dataset[["x1", "x2"]], dataset.y, test_size=0.2, random_state=42)

clf = imodels.GreedyRuleListClassifier(max_depth = 3, criterion = 'gini')
clf.fit(X_train, y_train, feature_names=X_train.columns)
y_pred = clf.predict(X_test)
RocCurveDisplay.from_estimator(clf, X_test, y_test, marker="o")

Rendering the model with print(clf) yields something like

> ------------------------------
> Greedy Rule List
> ------------------------------
↓
10.71% risk (800 pts)
	if x1 ==> 99.7% risk (361 pts)
↓
2.0% risk (439 pts)
	if x2 ==> 100.0% risk (39 pts)
↓
0.28% risk (400 pts)
	if x1 ==> 16.3% risk (43 pts)

which I find confusing because x1 and x2 are floats, not booleans.
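The entries of clf.rules_ (from a separate run: the data is random and unseeded, so the numbers differ from the printout above) are instead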

[{'col': 'x1',
  'index_col': 0,
  'cutoff': 1.1447782516479492,
  'val': 0.09429280397022333,
  'flip': False,
  'val_right': 0.9672544080604534,
  'num_pts': 800,
  'num_pts_right': 397},
 {'col': 'x2',
  'index_col': 1,
  'cutoff': 1.2083932757377625,
  'val': 0.01881720430107527,
  'flip': False,
  'val_right': 1.0,
  'num_pts': 403,
  'num_pts_right': 31},
 {'col': 'x1',
  'index_col': 0,
  'cutoff': 1.0007766485214233,
  'val': 0.0,
  'flip': False,
  'val_right': 0.125,
  'num_pts': 372,
  'num_pts_right': 56},
 {'val': 0.0, 'num_pts': 316}] 

and each one contains a cutoff that is useful for model interpretation. I'm not sure what the intended behavior is, since at the moment the code starting at

"""
def __str__(self):
    # s = ''
    # for rule in self.rules_:
    #     s += f"mean {rule['val'].round(3)} ({rule['num_pts']} pts)\n"
    #     if 'col' in rule:
    #         s += f"if {rule['col']} >= {rule['cutoff']} then {rule['val_right'].round(3)} ({rule['num_pts_right']} pts)\n"
    # return s
"""
def __str__(self):
    '''Print out the list in a nice way
    '''
    s = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'
    def red(s):
        # return f"\033[91m{s}\033[00m"
        return s
    def cyan(s):
        # return f"\033[96m{s}\033[00m"
        return s
    def rule_name(rule):
        if rule['flip']:
            return '~' + rule['col']
        return rule['col']
    # rule = self.rules_[0]
    # s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
    for rule in self.rules_:
        s += u'\u2193\n' + f"{cyan((100 * rule['val']).round(2))}% risk ({rule['num_pts']} pts)\n"
        # s += f"\t{'Else':>45} => {cyan((100 * rule['val']).round(2)):>6}% IwI ({rule['val'] * rule['num_pts']:.0f}/{rule['num_pts']} pts)\n"
        if 'col' in rule:
            # prefix = f"if {rule['col']} >= {rule['cutoff']}"
            prefix = f"if {rule_name(rule)}"
            val = f"{100 * rule['val_right'].round(3)}"
            s += f"\t{prefix} ==> {red(val)}% risk ({rule['num_pts_right']} pts)\n"
    # rule = self.rules_[-1]
    # s += f"{red((100 * rule['val']).round(3))}% IwI ({rule['num_pts']} pts)\n"
    return s
contains commented-out chunks (including ANSI color codes, which are currently unused).
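The rule dicts also show how the list is structured: each rule only sees the points that earlier rules did not capture, so the num_pts of one rule should equal num_pts minus num_pts_right of the rule before it. A small sketch checking this on the clf.rules_ dump above (check_rule_chain is a hypothetical helper, not part of imodels):

```python
# The clf.rules_ dump from the issue, copied verbatim.
rules = [
    {'col': 'x1', 'index_col': 0, 'cutoff': 1.1447782516479492,
     'val': 0.09429280397022333, 'flip': False, 'val_right': 0.9672544080604534,
     'num_pts': 800, 'num_pts_right': 397},
    {'col': 'x2', 'index_col': 1, 'cutoff': 1.2083932757377625,
     'val': 0.01881720430107527, 'flip': False, 'val_right': 1.0,
     'num_pts': 403, 'num_pts_right': 31},
    {'col': 'x1', 'index_col': 0, 'cutoff': 1.0007766485214233,
     'val': 0.0, 'flip': False, 'val_right': 0.125,
     'num_pts': 372, 'num_pts_right': 56},
    {'val': 0.0, 'num_pts': 316},
]


def check_rule_chain(rules):
    """Hypothetical helper: return True if every rule sees exactly the
    points that the previous rule did not capture."""
    for prev, nxt in zip(rules, rules[1:]):
        if nxt['num_pts'] != prev['num_pts'] - prev['num_pts_right']:
            return False
    return True


print(check_rule_chain(rules))  # True: 800-397=403, 403-31=372, 372-56=316
```

This is why the list reads naturally as an if / else if / else chain rather than as independent rules.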

davidefiocco changed the title from "Rules list cutoffs are not printed in string representations of GreedyRulesList classifier" to "Rules list cutoffs are not printed in string representations of GreedyRulesListClassifier" on Mar 15, 2023
@davidefiocco
Contributor Author

davidefiocco commented Mar 16, 2023

A possible improvement would be to rework the __str__ method as follows:

    def __str__(self):
        '''Print out the list in a nice way
        '''
        header = '> ------------------------------\n> Greedy Rule List\n> ------------------------------\n'
        footer = '> ------------------------------\n'
        rule_template = '> {condition} => {risk}% risk ({num_pts} pts)\n'

        s = header
        for i, rule in enumerate(self.rules_):
            # A rule without 'col' is the final catch-all "else" branch
            condition = 'else'
            risk = round(100 * rule['val'], 2)
            num_pts = rule['num_pts']

            if 'col' in rule:
                # flip=True means the rule fires when the value is below the cutoff
                predicate = '<' if rule['flip'] else '>='
                keyword = 'if' if i == 0 else 'else if'
                condition = f"{keyword} {rule['col']} {predicate} {rule['cutoff']}"

                risk = round(100 * rule['val_right'], 2)
                num_pts = rule['num_pts_right']

            s += rule_template.format(
                condition=condition,
                risk=risk,
                num_pts=num_pts
            )

        s += footer
        return s

This would render rules such as

[{'col': 'x2',
  'index_col': 1,
  'cutoff': 0.1395193189382553,
  'val': 0.04092071611253197,
  'flip': True,
  'val_right': 0.9315403422982885,
  'num_pts': 800,
  'num_pts_right': 409},
 {'col': 'x1',
  'index_col': 0,
  'cutoff': 0.0753365887212567,
  'val': 0.010554089709762533,
  'flip': False,
  'val_right': 1.0,
  'num_pts': 391,
  'num_pts_right': 12},
 {'col': 'x2',
  'index_col': 1,
  'cutoff': 0.19506534934043884,
  'val': 0.0,
  'flip': True,
  'val_right': 0.16666666666666666,
  'num_pts': 379,
  'num_pts_right': 24},
 {'val': 0.0, 'num_pts': 355}]

as

> ------------------------------
> Greedy Rule List
> ------------------------------
> if x2 < 0.1395193189382553 => 93.15% risk (409 pts)
> else if x1 >= 0.0753365887212567 => 100.0% risk (12 pts)
> else if x2 < 0.19506534934043884 => 16.67% risk (24 pts)
> else => 0.0% risk (355 pts)
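To make the if / else if / else reading concrete, here is a minimal sketch that evaluates the rendered list for a single sample: rules are tried top to bottom, and the first one whose condition holds gives the risk. It assumes the same flip convention as the proposed __str__ (flip=True means the rule fires when the value is below the cutoff); predict_risk is a hypothetical helper, not the imodels predict method.

```python
# Rules copied from the example above (index_col omitted: unused here).
rules = [
    {'col': 'x2', 'cutoff': 0.1395193189382553, 'val': 0.04092071611253197,
     'flip': True, 'val_right': 0.9315403422982885, 'num_pts': 800, 'num_pts_right': 409},
    {'col': 'x1', 'cutoff': 0.0753365887212567, 'val': 0.010554089709762533,
     'flip': False, 'val_right': 1.0, 'num_pts': 391, 'num_pts_right': 12},
    {'col': 'x2', 'cutoff': 0.19506534934043884, 'val': 0.0,
     'flip': True, 'val_right': 0.16666666666666666, 'num_pts': 379, 'num_pts_right': 24},
    {'val': 0.0, 'num_pts': 355},
]


def predict_risk(rules, sample):
    """Hypothetical helper: walk the list top to bottom and return the risk of
    the first rule whose condition holds for `sample` (a dict of feature values)."""
    for rule in rules:
        if 'col' not in rule:
            # Final catch-all "else" rule.
            return rule['val']
        value = sample[rule['col']]
        # Assumed convention, as in the proposed __str__: flip=True renders as '<'.
        fires = value < rule['cutoff'] if rule['flip'] else value >= rule['cutoff']
        if fires:
            return rule['val_right']
    raise ValueError("rule list has no final else rule")


print(predict_risk(rules, {'x1': 0.5, 'x2': 0.1}))    # 0.9315403422982885 (first rule)
print(predict_risk(rules, {'x1': 0.01, 'x2': 0.15}))  # 0.16666666666666666 (third rule)
print(predict_risk(rules, {'x1': 0.01, 'x2': 0.5}))   # 0.0 (else)
```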
> ------------------------------

@csinva
Owner

csinva commented Mar 17, 2023

Thanks, this is a nice fix!

I'll work on making it display like this when the feature is continuous-valued, while keeping the original behavior for non-continuous features. It's probably also worth rounding the cutoff value to ~3 decimal places.
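A minimal sketch of that plan, under two assumptions not taken from imodels: the helper names format_condition and infer_continuous_cols are hypothetical, and a feature counts as continuous when it has more than two distinct values in the training data. Continuous features get an explicit comparison with the cutoff rounded to 3 decimals (flip=True rendered as '<', following the proposal above); other features keep the current x / ~x style.

```python
import numpy as np


def infer_continuous_cols(X, feature_names):
    """Hypothetical heuristic: treat a column as continuous-valued if it has
    more than two distinct values in X."""
    X = np.asarray(X)
    return {name for j, name in enumerate(feature_names)
            if np.unique(X[:, j]).size > 2}


def format_condition(rule, continuous_cols, decimals=3):
    """Hypothetical helper: render the condition of one rule dict from clf.rules_."""
    col = rule['col']
    if col in continuous_cols:
        # Explicit comparison with a rounded cutoff; flip=True means "below the cutoff".
        predicate = '<' if rule['flip'] else '>='
        return f"{col} {predicate} {round(rule['cutoff'], decimals)}"
    # Non-continuous (e.g. binary) features keep the current "x" / "~x" style.
    return '~' + col if rule['flip'] else col


X = [[0.1, 0], [0.5, 1], [0.9, 0]]
# 'is_smoker' is a made-up binary feature used only for illustration.
continuous = infer_continuous_cols(X, ['x1', 'is_smoker'])
print(format_condition({'col': 'x1', 'cutoff': 0.0753365887212567, 'flip': False}, continuous))  # x1 >= 0.075
print(format_condition({'col': 'is_smoker', 'cutoff': 0.5, 'flip': True}, continuous))  # ~is_smoker
```

A distinct-value count is only one possible test; passing the set of continuous columns explicitly would avoid misclassifying a continuous feature that happens to take two values in a small training set.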
