Skip to content

Commit

Permalink
Improve KLL Sketch perf for non-grouped queries
Browse files Browse the repository at this point in the history
The current method for calculating memory usage has a hidden cost.
Within getEstimatedKllInMemorySize we call getSerializedSizeBytes.
The code for the serialized bytes size actually serializes the entire
internal state to a byte array first before returning the length. This
is expensive and should be avoided.

I am working on a PR to the upstream library to add a less-costly method
but until released, I would like to fix this as non-grouped execution
doesn't need the memory accounting for every sketch input.
  • Loading branch information
ZacBlanco committed Apr 26, 2024
1 parent b4af7c4 commit 2e08340
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
Expand Up @@ -34,6 +34,7 @@

import java.util.Comparator;
import java.util.Map;
import java.util.function.Supplier;

import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
Expand All @@ -60,7 +61,7 @@ public interface KllSketchAggregationState
@Nullable
<T> KllItemsSketch<T> getSketch();

void addMemoryUsage(long value);
void addMemoryUsage(Supplier<Long> usage);

Type getType();

Expand Down Expand Up @@ -115,7 +116,7 @@ public <T> void setSketch(KllItemsSketch<T> sketch)
}

@Override
public void addMemoryUsage(long value)
public void addMemoryUsage(Supplier<Long> usage)
{
// noop
}
Expand Down Expand Up @@ -161,9 +162,9 @@ public <T> KllItemsSketch<T> getSketch()
}

@Override
public void addMemoryUsage(long value)
public void addMemoryUsage(Supplier<Long> usage)
{
accumulatedSizeInBytes += value;
accumulatedSizeInBytes += usage.get();
}

@Override
Expand Down
Expand Up @@ -65,8 +65,8 @@ public void deserialize(Block block, int index, KllSketchAggregationState state)
KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type);
// use heapify over wrap in order to get a writable sketch for updates and merges
KllItemsSketch sketch = KllItemsSketch.heapify(memory, parameters.getComparator(), parameters.getSerde());
state.addMemoryUsage(-getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType()));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType()));
state.setSketch(sketch);
state.addMemoryUsage(getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType()));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(state.getSketch(), type.getJavaType()));
}
}
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.operator.aggregation.sketch.kll;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
Expand Down Expand Up @@ -54,9 +55,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql
{
initializeSketch(state, () -> Long::compareTo, ArrayOfLongsSerDe::new, k);
KllItemsSketch<Long> sketch = state.getSketch();
state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, long.class));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, long.class));
state.getSketch().update(value);
state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, long.class));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, long.class));
}

@InputFunction
Expand All @@ -65,9 +66,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql
{
initializeSketch(state, () -> Double::compareTo, ArrayOfDoublesSerDe::new, k);
KllItemsSketch<Double> sketch = state.getSketch();
state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, double.class));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, double.class));
state.getSketch().update(value);
state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, double.class));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, double.class));
}

@InputFunction
Expand All @@ -76,9 +77,9 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql
{
initializeSketch(state, () -> String::compareTo, ArrayOfStringsSerDe::new, k);
KllItemsSketch sketch = state.getSketch();
state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, Slice.class));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, Slice.class));
state.getSketch().update(value.toStringUtf8());
state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, Slice.class));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, Slice.class));
}

@InputFunction
Expand All @@ -87,22 +88,22 @@ public static void input(@AggregationState KllSketchAggregationState state, @Sql
{
initializeSketch(state, () -> Boolean::compareTo, ArrayOfBooleansSerDe::new, k);
KllItemsSketch<Boolean> sketch = state.getSketch();
state.addMemoryUsage(-getEstimatedKllInMemorySize(sketch, boolean.class));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(sketch, boolean.class));
state.getSketch().update(value);
state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, boolean.class));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, boolean.class));
}

@CombineFunction
public static void combine(@AggregationState KllSketchAggregationState state, @AggregationState KllSketchAggregationState otherState)
{
if (state.getSketch() != null && otherState.getSketch() != null) {
state.addMemoryUsage(-getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType()));
state.addMemoryUsage(() -> -getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType()));
state.getSketch().merge(otherState.getSketch());
state.addMemoryUsage(getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType()));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(state.getSketch(), state.getType().getJavaType()));
}
else if (state.getSketch() == null) {
state.setSketch(otherState.getSketch());
state.addMemoryUsage(getEstimatedKllInMemorySize(otherState.getSketch(), state.getType().getJavaType()));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(otherState.getSketch(), state.getType().getJavaType()));
}
}

Expand All @@ -125,7 +126,7 @@ private static <T> void initializeSketch(KllSketchAggregationState state, Suppli
if (state.getSketch() == null) {
KllItemsSketch<T> sketch = KllItemsSketch.newHeapInstance((int) k, comparator.get(), serdeSupplier.get());
state.setSketch(sketch);
state.addMemoryUsage(getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
state.addMemoryUsage(() -> getEstimatedKllInMemorySize(sketch, state.getType().getJavaType()));
}
}
}

0 comments on commit 2e08340

Please sign in to comment.