Implement AUUC reduce correctly
maurever committed May 8, 2024
1 parent 5d5db02 commit e46bf3c
Showing 2 changed files with 77 additions and 56 deletions.
2 changes: 1 addition & 1 deletion h2o-algos/src/test/java/hex/tree/uplift/UpliftDRFTest.java
@@ -349,7 +349,7 @@ public void testSupportCVCriteo() {
p._treatment_column = "treatment";
p._response_column = "conversion";
p._seed = 0xDECAF;
p._ntrees = 10;
p._ntrees = 11;
p._score_each_iteration = true;
p._nfolds = 3;
p._auuc_nbins = 50;
131 changes: 76 additions & 55 deletions h2o-core/src/main/java/hex/AUUC.java
@@ -368,33 +368,39 @@ public AUUCImpl(double[] thresholds, int nbins, double[] probs) {
*/
public static class AUUCBuilder extends Iced {
final int _nbins;
final double[]_thresholds; // thresholds
final double[] _thresholds; // thresholds
final long[] _treatment; // number of data from treatment group
final long[] _control; // number of data from control group
final long[] _yTreatment; // number of data from treatment group with prediction = 1
final long[] _yControl; // number of data from control group with prediction = 1
final long[] _frequency; // frequency of data in each bin
double[] _probs;
int _n; // number of data
int _nUsed; // number of used bins
int _nbinsUsed; // number of used bins
int _ssx;

public AUUCBuilder(int nbins, double[] thresholds, double[] probs) {
int tlen = thresholds != null ? thresholds.length : 1;
_probs = probs;
_nbins = nbins;
_nUsed = tlen;
_thresholds = thresholds == null ? new double[]{0} : thresholds;
_treatment = new long[tlen];
_control = new long[tlen];
_yTreatment = new long[tlen];
_yControl = new long[tlen];
_frequency = new long[tlen];
_nbinsUsed = thresholds != null ? thresholds.length : 0;
int l = nbins * 2; // maximal possible builder arrays length
_thresholds = new double[l];
if (thresholds != null) {
System.arraycopy(thresholds, 0, _thresholds, 0, thresholds.length);
}
_probs = new double[l];
System.arraycopy(probs, 0, _probs, 0, probs.length);
System.arraycopy(probs, 0, _probs, probs.length-1, probs.length);
_treatment = new long[l];
_control = new long[l];
_yTreatment = new long[l];
_yControl = new long[l];
_frequency = new long[l];
_ssx = -1;
}

public void perRow(double pred, double w, double y, float treatment) {
if (w == 0) {return;}
if (w == 0 || _thresholds == null) {return;}
for(int t = 0; t < _thresholds.length; t++) {
if (pred >= _thresholds[t] && (t == 0 || pred <_thresholds[t-1])) {
_n++;
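
Note on the binning above: perRow walks the descending threshold list and drops the row into the first bin whose range contains the prediction. A minimal standalone sketch of that lookup (hypothetical class and method names, not part of AUUC.java):

public class BinLookupSketch {
    // Returns the index of the bin containing pred, or -1 if pred falls below
    // the smallest threshold. Thresholds are assumed sorted in descending
    // order, mirroring the condition used in perRow.
    static int binFor(double pred, double[] thresholds) {
        for (int t = 0; t < thresholds.length; t++) {
            // Bin 0 covers everything >= thresholds[0]; bin t > 0 covers
            // [thresholds[t], thresholds[t-1]).
            if (pred >= thresholds[t] && (t == 0 || pred < thresholds[t - 1]))
                return t;
        }
        return -1;
    }

    public static void main(String[] args) {
        double[] thresholds = {0.9, 0.7, 0.5, 0.3}; // descending
        System.out.println(binFor(0.95, thresholds)); // 0
        System.out.println(binFor(0.60, thresholds)); // 2
        System.out.println(binFor(0.10, thresholds)); // -1
    }
}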
@@ -416,20 +422,23 @@ public void perRow(double pred, double w, double y, float treatment) {
}

public void reduce(AUUCBuilder bldr) {
_n += bldr._n;
ArrayUtils.add(_treatment, bldr._treatment);
ArrayUtils.add(_control, bldr._control);
ArrayUtils.add(_yTreatment, bldr._yTreatment);
ArrayUtils.add(_yControl, bldr._yControl);
ArrayUtils.add(_frequency, bldr._frequency);
if(bldr._nbinsUsed == 0) {return;}
if(_nbinsUsed == 0 || _thresholds == bldr._thresholds){
reduceSameOrNullThresholds(bldr);
} else {
reduceDifferentThresholds(bldr);
}
}

public void reduce2(AUUCBuilder bldr) {
// Merge sort the 2 sorted lists into the double-sized arrays. The tail
// half of the double-sized array is unused, but the front half is
// probably a source. Merge into the back.
int x = _n-1;
int y = bldr._n-1;
/**
* Merge sort the 2 sorted lists into the double-sized arrays. The tail
* half of the double-sized array is unused, but the front half is
* probably a source. Merge into the back.
* @param bldr AUUC builder to reduce
*/
public void reduceDifferentThresholds(AUUCBuilder bldr){
int x = _nbinsUsed -1;
int y = bldr._nbinsUsed -1;
while( x+y+1 >= 0 ) {
boolean self_is_larger = y < 0 || (x >= 0 && _thresholds[x] >= bldr._thresholds[y]);
AUUCBuilder b = self_is_larger ? this : bldr;
@@ -440,16 +449,31 @@ public void reduce2(AUUCBuilder bldr) {
_yTreatment[x+y+1] = b._yTreatment[idx];
_yControl[x+y+1] = b._yControl[idx];
_frequency[x+y+1] = b._frequency[idx];
_probs[x+y+1] = b._probs[idx];
if( self_is_larger ) x--; else y--;
}
_n += bldr._n;
_nbinsUsed += bldr._nbinsUsed;
_ssx = -1;

// Merge elements with least squared-error increase until we get fewer
// than _nBins and no duplicates. May require many merges.
while( _n > _nbins || dups() )
while( _nbinsUsed > _nbins || dups() )
mergeOneBin();
}
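
The loop above is a back-to-front merge: it reads from the tails of the two builders' used ranges and writes into the tail of the double-sized arrays, so the front half can still be read as a source while the back is being filled. A standalone sketch of the idea on a single sorted array (ascending order and all names are assumptions for illustration, not H2O code):

import java.util.Arrays;

public class BackToFrontMergeSketch {
    // Merge two sorted arrays into one buffer sized for both, writing from the
    // back; the front of the buffer initially holds the first source.
    static double[] merge(double[] a, int aUsed, double[] b, int bUsed) {
        double[] out = Arrays.copyOf(a, aUsed + bUsed); // front half holds a
        int x = aUsed - 1, y = bUsed - 1;
        while (x + y + 1 >= 0) {
            boolean takeA = y < 0 || (x >= 0 && out[x] >= b[y]);
            out[x + y + 1] = takeA ? out[x] : b[y];
            if (takeA) x--; else y--;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] a = {0.2, 0.5, 0.9};
        double[] b = {0.1, 0.6};
        System.out.println(Arrays.toString(merge(a, 3, b, 2)));
        // [0.1, 0.2, 0.5, 0.6, 0.9]
    }
}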

public void reduceSameOrNullThresholds(AUUCBuilder bldr){
_n += bldr._n;
if(_nbinsUsed == 0) {
ArrayUtils.add(_thresholds, bldr._thresholds);
_nbinsUsed = bldr._nbinsUsed;
}
ArrayUtils.add(_treatment, bldr._treatment);
ArrayUtils.add(_control, bldr._control);
ArrayUtils.add(_yTreatment, bldr._yTreatment);
ArrayUtils.add(_yControl, bldr._yControl);
ArrayUtils.add(_frequency, bldr._frequency);
}
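
When both builders already share an identical threshold grid (or this builder has no bins yet), no merging is needed and the reduce degenerates to an elementwise addition of the per-bin counts, as in reduceSameOrNullThresholds above. A trivial sketch (hypothetical names):

public class SameThresholdReduceSketch {
    // Same bins on both sides, so counts simply accumulate per bin.
    static long[] addCounts(long[] a, long[] b) {
        long[] out = a.clone();
        for (int i = 0; i < out.length; i++) out[i] += b[i];
        return out;
    }

    public static void main(String[] args) {
        long[] treatmentA = {4, 2, 1};
        long[] treatmentB = {1, 3, 5};
        System.out.println(java.util.Arrays.toString(addCounts(treatmentA, treatmentB)));
        // [5, 5, 6]
    }
}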

static double combineCenters(double ths1, double ths0, double probs, long nrows) {
//double center = (ths0 * n0 + ths1 * n1) / (n0 + n1);
@@ -474,26 +498,22 @@ private void mergeOneBin( ) {
_yTreatment[ssx] += _yTreatment[ssx+1];
_yControl[ssx] += _yControl[ssx+1];
_frequency[ssx] += _frequency[ssx+1];
int n = _n;
int n = _nbinsUsed == 2 ? _nbinsUsed - ssx -1 : _nbinsUsed - ssx -2;
// Slide over to crush the removed bin at index (ssx+1)
System.arraycopy(_thresholds,ssx+2,_thresholds,ssx+1,n-ssx-2);
System.arraycopy(_treatment,ssx+2,_treatment,ssx+1,n-ssx-2);
System.arraycopy(_control,ssx+2,_control,ssx+1,n-ssx-2);
System.arraycopy(_yTreatment,ssx+2,_yTreatment,ssx+1,n-ssx-2);
System.arraycopy(_yControl,ssx+2,_yControl,ssx+1,n-ssx-2);
System.arraycopy(_frequency,ssx+2,_frequency,ssx+1,n-ssx-2);
_n--;
_thresholds[_n] = _treatment[_n] = _control[_n] = _yTreatment[_n] = _yControl[_n] = _frequency[_n] = 0;
System.arraycopy(_thresholds,ssx+2,_thresholds,ssx+1,n);
System.arraycopy(_treatment,ssx+2,_treatment,ssx+1,n);
System.arraycopy(_control,ssx+2,_control,ssx+1,n);
System.arraycopy(_yTreatment,ssx+2,_yTreatment,ssx+1,n);
System.arraycopy(_yControl,ssx+2,_yControl,ssx+1,n);
System.arraycopy(_frequency,ssx+2,_frequency,ssx+1,n);
_nbinsUsed--;
_ssx = -1;
}
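
mergeOneBin folds bin ssx+1 into bin ssx and slides the remaining bins left by one; only the count of bins in use shrinks, the array capacity stays the same. A standalone sketch of that removal (hypothetical names, a single count array for brevity):

import java.util.Arrays;

public class MergeBinSketch {
    // Merge bin ssx+1 into bin ssx, slide the bins to its right left by one,
    // and return the new number of used bins.
    static int mergeAt(long[] counts, double[] thresholds, int used, int ssx) {
        counts[ssx] += counts[ssx + 1];   // fold the right bin into the left
        int tail = used - ssx - 2;        // number of bins to slide left
        System.arraycopy(thresholds, ssx + 2, thresholds, ssx + 1, tail);
        System.arraycopy(counts, ssx + 2, counts, ssx + 1, tail);
        return used - 1;                  // stale tail entries are simply ignored
    }

    public static void main(String[] args) {
        double[] thresholds = {0.9, 0.7, 0.5, 0.3, 0.0, 0.0}; // extra capacity
        long[] counts = {4, 2, 1, 3, 0, 0};
        int used = mergeAt(counts, thresholds, 4, 1); // merge bins 1 and 2
        System.out.println(used + " " + Arrays.toString(thresholds) + " " + Arrays.toString(counts));
        // 3 [0.9, 0.7, 0.3, 0.3, 0.0, 0.0] [4, 3, 3, 3, 0, 0]
    }
}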

// Find the pair of bins that when combined give the smallest increase in
// squared error. Dups never increase squared error.
//
// I tried code for merging bins with keeping the bins balanced in size,
// but this leads to bad errors if the probabilities are sorted. Also
// tried the original: merge bins with the least distance between bin
// centers. Same problem for sorted data.
/**
* Find the pair of bins that when combined give the smallest difference in thresholds
* @return index of the bin where the threshold difference is the smallest
*/
private int findSmallest() {
if( _ssx == -1 ) {
_ssx = findSmallestImpl();
@@ -503,12 +523,10 @@ private int findSmallest() {
}

private int findSmallestImpl() {
if (_n == 1)
if (_nbinsUsed == 1)
return 0;
// we couldn't find any bins to merge based on SE (the math can be producing Double.Infinity or Double.NaN)
// revert to using a simple distance of the bin centers
int minI = 0;
long n = _n;
long n = _nbinsUsed;
double minDist = _thresholds[1] - _thresholds[0];
for (int i = 1; i < n - 1; i++) {
double dist = _thresholds[i + 1] - _thresholds[i];
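
findSmallestImpl returns the left index of the adjacent pair of thresholds that lie closest together, which is the cheapest pair to merge. A standalone sketch (hypothetical names; ascending thresholds assumed):

public class SmallestGapSketch {
    // Index i such that thresholds[i+1] - thresholds[i] is minimal over the
    // used bins; 0 if only one bin is in use.
    static int smallestGap(double[] thresholds, int used) {
        if (used == 1) return 0;
        int minI = 0;
        double minDist = thresholds[1] - thresholds[0];
        for (int i = 1; i < used - 1; i++) {
            double dist = thresholds[i + 1] - thresholds[i];
            if (dist < minDist) { minDist = dist; minI = i; }
        }
        return minI;
    }

    public static void main(String[] args) {
        double[] thresholds = {0.1, 0.4, 0.45, 0.9};
        System.out.println(smallestGap(thresholds, 4)); // 1 (0.45 - 0.4 is the smallest gap)
    }
}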
@@ -521,25 +539,27 @@ private int findSmallestImpl() {
}

private boolean dups() {
long n = _n;
long n = _nbinsUsed;
for( int i=0; i<n-1; i++ ) {
double derr = computeDeltaError(_thresholds[i+1],_frequency[i+1],_thresholds[i],_frequency[i]);
if( derr == 0 ) { _ssx = i; return true; }
double derr = computeDeltaError(_thresholds[i + 1], _frequency[i + 1], _thresholds[i], _frequency[i]);
if (derr == 0) {
_ssx = i;
return true;
}
}
return false;
}


/**
* If thresholds vary by less than a float ULP, treat them as the same.
* Parallel equation drawn from:
* http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
* @return delta error from two thresholds
*/
private double computeDeltaError(double ths1, double n1, double ths0, double n0 ) {
// If thresholds vary by less than a float ULP, treat them as the same.
// Some models only output predictions to within float accuracy (so a
// variance here is junk), and also it's not statistically sane to have
// a model which varies predictions by such a tiny change in thresholds.
double delta = (float)ths1-(float)ths0;
if (delta == 0)
return 0;
// Parallel equation drawn from:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
return delta*delta*n0*n1 / (n0+n1);
}
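
The delta error above is the cross-term of the standard parallel variance update: merging n0 items at threshold ths0 with n1 items at threshold ths1 increases the pooled sum of squared deviations by delta^2 * n0 * n1 / (n0 + n1). A small sketch checking that against a direct computation (hypothetical names, not part of AUUC.java):

public class DeltaErrorSketch {
    // Same form as computeDeltaError: float cast gives the ULP tolerance.
    static double deltaError(double t1, double n1, double t0, double n0) {
        double delta = (float) t1 - (float) t0;
        if (delta == 0) return 0;
        return delta * delta * n0 * n1 / (n0 + n1);
    }

    // Direct check: sum of squared deviations of n0 copies of t0 and n1 copies
    // of t1 around their combined mean.
    static double directSS(double t1, long n1, double t0, long n0) {
        double mean = (t0 * n0 + t1 * n1) / (n0 + n1);
        return n0 * (t0 - mean) * (t0 - mean) + n1 * (t1 - mean) * (t1 - mean);
    }

    public static void main(String[] args) {
        System.out.println(deltaError(0.8, 3, 0.5, 2)); // ~0.108
        System.out.println(directSS(0.8, 3, 0.5, 2));   // ~0.108
    }
}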

@@ -555,7 +575,8 @@ private static double computeLinearInterpolation(double ths1, double ths0, doubl

private String toDebugString() {
return "n =" +_n +
"; nBins = " + _nbins +
"; nbins = " + _nbins +
"; nbinsUsed = " + _nbinsUsed +
"; ths = " + Arrays.toString(_thresholds) +
"; treatment = " + Arrays.toString(_treatment) +
"; contribution = " + Arrays.toString(_control) +
