Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement large array constructing through chunk concatenation #22622

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Expand Up @@ -99,6 +99,7 @@
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -1006,6 +1007,20 @@ protected RowExpression visitArrayConstructor(ArrayConstructor node, Context con
List<Type> argumentTypes = arguments.stream()
.map(RowExpression::getType)
.collect(toImmutableList());

if (arguments.size() > 200) {
List<RowExpression> concatenatedArguments = new ArrayList<>();
for (int i = 0; i < arguments.size(); i += 200) {
int end = Math.min(i + 200, arguments.size());
List<RowExpression> chunk = arguments.subList(i, end);
RowExpression chunkArray = call("ARRAY", functionResolution.arrayConstructor(argumentTypes.subList(i, end)), getType(node), chunk);
concatenatedArguments.add(chunkArray);
}
List<Type> concatenatedArgumentTypes = concatenatedArguments.stream()
.map(RowExpression::getType)
.collect(toImmutableList());
return call("concat", functionAndTypeResolver.lookupFunction("concat", fromTypes(concatenatedArgumentTypes)), getType(node), concatenatedArguments);
}
return call("ARRAY", functionResolution.arrayConstructor(argumentTypes), getType(node), arguments);
}

Expand Down
Expand Up @@ -27,10 +27,8 @@ public class TestArrayFunctions
@Test
public void testArrayConstructor()
{
tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(254, "rand()")) + "]", new ArrayType(DOUBLE));
assertNotSupported(
"array[" + Joiner.on(", ").join(nCopies(255, "rand()")) + "]",
"Too many arguments for array constructor");
tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(200, "rand()")) + "]", new ArrayType(DOUBLE));
tryEvaluateWithAll("array[" + Joiner.on(", ").join(nCopies(1000, "rand()")) + "]", new ArrayType(DOUBLE));
}

@Test
Expand Down
Expand Up @@ -35,6 +35,7 @@
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
Expand All @@ -58,7 +59,11 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
import static com.facebook.airlift.json.JsonBinder.jsonBinder;
Expand Down Expand Up @@ -146,6 +151,33 @@ public void testArrayLiteral()
assertEquals(block.getInt(2), 3);
}

@Test
public void testLargeArraySplitting()
{
// Test array with more than 200 elements
int numElements = 900;
String sql = "ARRAY " + IntStream.rangeClosed(1, numElements)
.mapToObj(Integer::toString)
.collect(Collectors.joining(", ", "[", "]"));
List<String> sqlParts = new ArrayList<>();
for (int i = 0; i < (numElements - 1) / 200 + 1; i++) {
sqlParts.add("ARRAY " + IntStream.rangeClosed(1 + 200 * i, Math.min(200 * (i + 1), numElements))
.mapToObj(Integer::toString)
.collect(Collectors.joining(", ", "[", "]")));
}
RowExpression roundTripExpression = getRoundTrip(sql, false);
List<RowExpression> rowExpressionParts = new ArrayList<>();
for (int i = 0; i < (numElements - 1) / 200 + 1; i++) {
rowExpressionParts.add(getRoundTrip(sqlParts.get(i), false));
}

assertTrue(roundTripExpression instanceof CallExpression);
CallExpression callExpression = (CallExpression) roundTripExpression;
assertEquals(callExpression.getDisplayName(), "concat");
List<RowExpression> arguments = callExpression.getArguments();
assertEquals(arguments, rowExpressionParts);
}

@Test
public void testArrayGet()
{
Expand Down