Commit c1bfc50
Add row IDs to batch reader
elharo committed May 1, 2024
1 parent ada2eb9

Showing 15 changed files with 123 additions and 28 deletions.
@@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation);
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIdPartitionComponent);
 }
@@ -73,6 +73,7 @@
 import static com.facebook.presto.hive.HiveErrorCode.HIVE_UNSUPPORTED_FORMAT;
 import static com.facebook.presto.hive.HivePageSourceProvider.ColumnMapping.toColumnHandles;
 import static com.facebook.presto.hive.HiveSessionProperties.isUseRecordPageSourceForCustomSplit;
+import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
 import static com.facebook.presto.hive.HiveUtil.getPrefilledColumnValue;
 import static com.facebook.presto.hive.HiveUtil.parsePartitionValue;
 import static com.facebook.presto.hive.HiveUtil.shouldUseRecordReaderFromInputFormat;
@@ -183,10 +184,8 @@ public ConnectorPageSource createPageSource(
             return createAggregatedPageSource(aggregatedPageSourceFactories, configuration, session, hiveSplit, hiveLayout, selectedColumns, fileContext, encryptionInformation);
         }
         if (hiveLayout.isPushdownFilterEnabled()) {
-            // TODO from stack trace we come into here so the pushdown filter is enabled
-            boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
             Optional<byte[]> rowIDPartitionComponent = hiveSplit.getRowIdPartitionComponent();
-            checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+            checkRowIDPartitionComponent(selectedColumns, rowIDPartitionComponent);
             Optional<ConnectorPageSource> selectivePageSource = createSelectivePageSource(
                     selectivePageSourceFactories,
                     configuration,
@@ -376,9 +375,8 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
 
         for (HiveSelectivePageSourceFactory pageSourceFactory : selectivePageSourceFactories) {
             List<HiveColumnHandle> columnHandles = toColumnHandles(columnMappings, true);
-            boolean supplyRowIDs = columnHandles.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
             Optional<byte[]> rowIDPartitionComponent = split.getRowIdPartitionComponent();
-            checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+            checkRowIDPartitionComponent(columnHandles, rowIDPartitionComponent);
             Optional<? extends ConnectorPageSource> pageSource = pageSourceFactory.createPageSource(
                     configuration,
                     session,
@@ -505,7 +503,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
                 effectivePredicate,
                 hiveStorageTimeZone,
                 hiveFileContext,
-                encryptionInformation);
+                encryptionInformation,
+                rowIdPartitionComponent);
         if (pageSource.isPresent()) {
             HivePageSource hivePageSource = new HivePageSource(
                     columnMappings,
@@ -213,6 +213,14 @@ public final class HiveUtil
     private static final String USE_RECORD_READER_FROM_INPUT_FORMAT_ANNOTATION = "UseRecordReaderFromInputFormat";
     private static final String USE_FILE_SPLITS_FROM_INPUT_FORMAT_ANNOTATION = "UseFileSplitsFromInputFormat";
 
+    public static void checkRowIDPartitionComponent(List<HiveColumnHandle> columns, Optional<byte[]> rowIdPartitionComponent)
+    {
+        boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
+        if (supplyRowIDs) {
+            checkArgument(rowIdPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+        }
+    }
+
     static {
         DateTimeParser[] timestampWithoutTimeZoneParser = {
                 DateTimeFormat.forPattern("yyyy-M-d").getParser(),
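Note: the new HiveUtil helper centralizes a precondition that previously appeared inline at each call site: if any selected column is the synthetic row-ID column, the split must carry a partition component. A minimal, self-contained sketch of that contract, with a hypothetical Column record standing in for HiveColumnHandle:

    import java.util.List;
    import java.util.Optional;

    public class RowIdCheckSketch
    {
        // Stand-in for HiveColumnHandle and HiveColumnHandle.isRowIdColumnHandle.
        record Column(String name, boolean isRowId) {}

        static void checkRowIdPartitionComponent(List<Column> columns, Optional<byte[]> component)
        {
            boolean supplyRowIds = columns.stream().anyMatch(Column::isRowId);
            if (supplyRowIds && !component.isPresent()) {
                throw new IllegalArgumentException("rowIDPartitionComponent required when supplying row IDs");
            }
        }

        public static void main(String[] args)
        {
            List<Column> columns = List.of(new Column("$row_id", true), new Column("c1", false));
            checkRowIdPartitionComponent(columns, Optional.of(new byte[] {1})); // passes
            checkRowIdPartitionComponent(columns, Optional.empty());            // throws IllegalArgumentException
        }
    }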
@@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
                         .build(),
                 encryptionInformation,
                 dwrfEncryptionProvider,
-                session));
+                session,
+                rowIDPartitionComponent));
     }
 }
@@ -48,9 +48,9 @@
 import java.util.Optional;
 
 import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA;
+import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
 import static com.facebook.presto.hive.orc.OrcSelectivePageSourceFactory.createOrcPageSource;
 import static com.facebook.presto.orc.OrcEncoding.DWRF;
-import static com.google.common.base.Preconditions.checkArgument;
 import static java.util.Objects.requireNonNull;
 
 public class DwrfSelectivePageSourceFactory
@@ -119,8 +119,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             throw new PrestoException(HIVE_BAD_DATA, "ORC file is empty: " + fileSplit.getPath());
         }
 
-        boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
-        checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+        checkRowIDPartitionComponent(columns, rowIDPartitionComponent);
 
         return Optional.of(createOrcPageSource(
                 session,
@@ -24,6 +24,7 @@
 import com.facebook.presto.common.type.TypeManager;
 import com.facebook.presto.hive.FileFormatDataSourceStats;
 import com.facebook.presto.hive.HiveColumnHandle;
+import com.facebook.presto.hive.RowIDCoercer;
 import com.facebook.presto.orc.OrcAggregatedMemoryContext;
 import com.facebook.presto.orc.OrcBatchRecordReader;
 import com.facebook.presto.orc.OrcCorruptionException;
@@ -58,6 +59,7 @@ public class OrcBatchPageSource
 
     private final Block[] constantBlocks;
     private final int[] hiveColumnIndexes;
+    private final boolean[] rowIDColumnIndexes;
 
     private int batchId;
     private long completedPositions;
@@ -71,16 +73,20 @@ public class OrcBatchPageSource
 
     private final List<Boolean> isRowPositionList;
 
+    private final RowIDCoercer coercer;
+
     public OrcBatchPageSource(
             OrcBatchRecordReader recordReader,
             OrcDataSource orcDataSource,
             List<HiveColumnHandle> columns,
             TypeManager typeManager,
             OrcAggregatedMemoryContext systemMemoryContext,
             FileFormatDataSourceStats stats,
-            RuntimeStats runtimeStats)
+            RuntimeStats runtimeStats,
+            byte[] rowIDPartitionComponent,
+            String rowGroupId)
     {
-        this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
+        this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
     }
 
     /**
@@ -97,7 +103,9 @@ public OrcBatchPageSource(
             OrcAggregatedMemoryContext systemMemoryContext,
             FileFormatDataSourceStats stats,
             RuntimeStats runtimeStats,
-            List<Boolean> isRowPositionList)
+            List<Boolean> isRowPositionList,
+            byte[] rowIDPartitionComponent,
+            String rowGroupId)
     {
         this.recordReader = requireNonNull(recordReader, "recordReader is null");
         this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
@@ -107,9 +115,11 @@ public OrcBatchPageSource(
         this.stats = requireNonNull(stats, "stats is null");
         this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
         this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
+        this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);
 
         this.constantBlocks = new Block[size];
         this.hiveColumnIndexes = new int[size];
+        this.rowIDColumnIndexes = new boolean[size];
 
         ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
         ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
@@ -124,6 +134,7 @@ public OrcBatchPageSource(
             typesBuilder.add(type);
 
             hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();
+            rowIDColumnIndexes[columnIndex] = HiveColumnHandle.isRowIdColumnHandle(column);
 
             if (!recordReader.isColumnPresent(column.getHiveColumnIndex())) {
                 constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE);
@@ -183,6 +194,11 @@ public Page getNextPage()
             if (isRowPositionColumn(fieldId)) {
                 blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
             }
+            else if (isRowIDColumn(fieldId)) {
+                Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
+                Block rowIDs = coercer.apply(rowNumbers);
+                blocks[fieldId] = rowIDs;
+            }
             else {
                 if (constantBlocks[fieldId] != null) {
                     blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
@@ -260,6 +276,12 @@ private boolean isRowPositionColumn(int column)
         return isRowPositionList.get(column);
     }
 
+    private boolean isRowIDColumn(int column)
+    {
+        return this.rowIDColumnIndexes[column];
+    }
+
+    // TODO verify these are row numbers and rename?
     private static Block getRowPosColumnBlock(long baseIndex, int size)
     {
         long[] rowPositions = new long[size];
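Note: RowIDCoercer itself is not part of this diff. Conceptually it maps a Block of row numbers to a Block of row IDs scoped by the row-group ID and partition component fixed at construction. The per-row byte layout below is an illustrative assumption, not the actual encoding:

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.charset.StandardCharsets;

    // Hypothetical per-row coercion: rowNumber | rowGroupId | partitionComponent.
    public class RowIdSketch
    {
        private final byte[] partitionComponent;
        private final byte[] rowGroupId;

        public RowIdSketch(byte[] partitionComponent, String rowGroupId)
        {
            this.partitionComponent = partitionComponent.clone();
            this.rowGroupId = rowGroupId.getBytes(StandardCharsets.UTF_8);
        }

        // Builds one variable-width row ID for a single row position.
        public byte[] toRowId(long rowNumber)
        {
            return ByteBuffer.allocate(Long.BYTES + rowGroupId.length + partitionComponent.length)
                    .order(ByteOrder.LITTLE_ENDIAN)
                    .putLong(rowNumber)
                    .put(rowGroupId)
                    .put(partitionComponent)
                    .array();
        }
    }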
@@ -63,6 +63,7 @@
 import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold;
 import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled;
 import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled;
+import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
 import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles;
 import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource;
 import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcReader;
@@ -134,7 +135,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -169,7 +171,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
                         .build(),
                 encryptionInformation,
                 NO_ENCRYPTION,
-                session));
+                session,
+                rowIDPartitionComponent));
     }
 
     public static ConnectorPageSource createOrcPageSource(
@@ -191,7 +194,8 @@ public static ConnectorPageSource createOrcPageSource(
             OrcReaderOptions orcReaderOptions,
             Optional<EncryptionInformation> encryptionInformation,
             DwrfEncryptionProvider dwrfEncryptionProvider,
-            ConnectorSession session)
+            ConnectorSession session,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");
 
@@ -235,14 +239,20 @@ public static ConnectorPageSource createOrcPageSource(
                     systemMemoryUsage,
                     INITIAL_BATCH_SIZE);
 
+            checkRowIDPartitionComponent(columns, rowIDPartitionComponent);
+
+            byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
+            String rowGroupID = path.getName();
             return new OrcBatchPageSource(
                     recordReader,
                     reader.getOrcDataSource(),
                     physicalColumns,
                     typeManager,
                     systemMemoryUsage,
                     stats,
-                    hiveFileContext.getStats());
+                    hiveFileContext.getStats(),
+                    partitionID,
+                    rowGroupID);
         }
         catch (Exception e) {
             try {
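Note: an absent partition component is legitimate when no selected column is the row-ID column, so the batch factory defaults it to an empty array and uses the split's file name as the row-group ID, mirroring the selective ORC path below. A toy wiring of those defaults through the hypothetical RowIdSketch above (the file name is a stand-in):

    import java.util.Optional;

    public class RowIdWiringSketch
    {
        public static void main(String[] args)
        {
            // Absent component (row IDs not requested, or an unpartitioned table)
            // becomes an empty byte array, as in the diff above.
            Optional<byte[]> rowIDPartitionComponent = Optional.empty();
            byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
            String rowGroupID = "000000_0"; // in the real code: path.getName()

            RowIdSketch coercer = new RowIdSketch(partitionID, rowGroupID);
            byte[] rowId = coercer.toRowId(42);
            System.out.println(rowId.length); // 8 (row number) + 8 (file name) + 0 (component)
        }
    }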
@@ -284,7 +284,6 @@ public static ConnectorPageSource createOrcPageSource(
 
         boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
         checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
-        // TODO follow code path for !supplyRowIDs == true following below
         byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
         String rowGroupId = path.getName();
 
@@ -72,7 +72,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         if (!PageInputFormat.class.getSimpleName().equals(storage.getStorageFormat().getInputFormat())) {
             return Optional.empty();
@@ -512,7 +512,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIdPartitionComponent)
     {
         if (!PARQUET_SERDE_CLASS_NAMES.contains(storage.getStorageFormat().getSerDe())) {
             return Optional.empty();
@@ -106,7 +106,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIDPartitionComponent)
     {
         RcFileEncoding rcFileEncoding;
         if (LazyBinaryColumnarSerDe.class.getName().equals(storage.getStorageFormat().getSerDe())) {
@@ -497,7 +497,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(Configuration co
             TupleDomain<HiveColumnHandle> effectivePredicate,
             DateTimeZone hiveStorageTimeZone,
             HiveFileContext hiveFileContext,
-            Optional<EncryptionInformation> encryptionInformation)
+            Optional<EncryptionInformation> encryptionInformation,
+            Optional<byte[]> rowIdPartitionComponent)
     {
         return Optional.of(new MockPageSource());
     }
@@ -640,7 +641,18 @@ private static class MockOrcBatchPageSourceFactory
             implements HiveBatchPageSourceFactory
     {
         @Override
-        public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration, ConnectorSession session, HiveFileSplit fileSplit, Storage storage, SchemaTableName tableName, Map<String, String> tableParameters, List<HiveColumnHandle> columns, TupleDomain<HiveColumnHandle> effectivePredicate, DateTimeZone hiveStorageTimeZone, HiveFileContext hiveFileContext, Optional<EncryptionInformation> encryptionInformation)
+        public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration,
+                ConnectorSession session,
+                HiveFileSplit fileSplit,
+                Storage storage,
+                SchemaTableName tableName,
+                Map<String, String> tableParameters,
+                List<HiveColumnHandle> columns,
+                TupleDomain<HiveColumnHandle> effectivePredicate,
+                DateTimeZone hiveStorageTimeZone,
+                HiveFileContext hiveFileContext,
+                Optional<EncryptionInformation> encryptionInformation,
+                Optional<byte[]> rowIdPartitionComponent)
         {
             if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
                 return Optional.empty();
@@ -696,8 +708,18 @@ public Optional<? extends ConnectorPageSource> createPageSource(Configuration co
     private static class MockRcBinaryBatchPageSourceFactory
             implements HiveBatchPageSourceFactory
     {
         @Override
-        public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration, ConnectorSession session, HiveFileSplit fileSplit, Storage storage, SchemaTableName tableName, Map<String, String> tableParameters, List<HiveColumnHandle> columns, TupleDomain<HiveColumnHandle> effectivePredicate, DateTimeZone hiveStorageTimeZone, HiveFileContext hiveFileContext, Optional<EncryptionInformation> encryptionInformation)
+        public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration,
+                ConnectorSession session,
+                HiveFileSplit fileSplit,
+                Storage storage,
+                SchemaTableName tableName,
+                Map<String, String> tableParameters,
+                List<HiveColumnHandle> columns,
+                TupleDomain<HiveColumnHandle> effectivePredicate,
+                DateTimeZone hiveStorageTimeZone,
+                HiveFileContext hiveFileContext,
+                Optional<EncryptionInformation> encryptionInformation,
+                Optional<byte[]> rowIdPartitionComponent)
         {
             if (!storage.getStorageFormat().getSerDe().equals(LazyBinaryColumnarSerDe.class.getName())) {
                 return Optional.empty();
