Skip to content

Commit

Permalink
Fix regr_r2 result incorrect
Browse files Browse the repository at this point in the history
There are two special cases:
1. m2X() = 0, then result should be NULL
2. m2X() != 0 and m2Y() = 0, then result shoule be 1.
  • Loading branch information
8dukongjian authored and tdcmeehan committed Apr 29, 2024
1 parent aeaa0b7 commit 4bd198a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
Expand Up @@ -196,6 +196,9 @@ public static double getRegressionSyy(RegressionState state)

public static double getRegressionR2(RegressionState state)
{
if (state.getM2X() != 0 && state.getM2Y() == 0) {
return 1.0;
}
return Math.pow(state.getC2(), 2) / (state.getM2X() * state.getM2Y());
}

Expand Down
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation;

import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.testng.annotations.Test;

import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -45,6 +46,19 @@ else if (length == 1) {
}
}

@Test
public void testTwoSpecialCase()
{
// when m2x = 0, result is null
Double[] y = new Double[] {1.0, 1.0, 1.0, 1.0, 1.0};
Double[] x = new Double[] {1.0, 1.0, 1.0, 1.0, 1.0};
testAggregation(null, createDoublesBlock(y), createDoublesBlock(x));

// when m2x != 0 and m2y = 0, result is 1.0
x = new Double[] {1.0, 2.0, 3.0, 4.0, 5.0};
testAggregation(1.0, createDoublesBlock(y), createDoublesBlock(x));
}

@Override
protected void testNonTrivialAggregation(Double[] y, Double[] x)
{
Expand Down
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation;

import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.testng.annotations.Test;

import static com.facebook.presto.block.BlockAssertions.createBlockOfReals;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -45,6 +46,19 @@ else if (length == 1) {
}
}

@Test
public void testTwoSpecialCase()
{
// when m2x = 0, result is null
Float[] y = new Float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
Float[] x = new Float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
testAggregation(null, createBlockOfReals(y), createBlockOfReals(x));

// when m2x != 0 and m2y = 0, result is 1.0
x = new Float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
testAggregation(1.0f, createBlockOfReals(y), createBlockOfReals(x));
}

@Override
protected void testNonTrivialAggregation(Float[] y, Float[] x)
{
Expand Down

0 comments on commit 4bd198a

Please sign in to comment.