Skip to content

Commit

Permalink
Add row IDs to batch reader
Browse files Browse the repository at this point in the history
  • Loading branch information
elharo committed May 2, 2024
1 parent 522e095 commit c7ad8a2
Show file tree
Hide file tree
Showing 16 changed files with 166 additions and 25 deletions.
Expand Up @@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation);
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent);
}
Expand Up @@ -96,9 +96,11 @@ public HivePageSource(
if (columnMapping.getCoercionFrom().isPresent()) {
coercers[columnIndex] = createCoercer(typeManager, columnMapping.getCoercionFrom().get(), columnMapping.getHiveColumnHandle().getHiveType());
}
else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle()) && rowIdPartitionComponent.isPresent()) {
else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle())) {
// If there's no row ID partition component, then path + row numbers will be supplied for $row_id
byte[] component = rowIdPartitionComponent.orElse(new byte[0]);
String rowGroupId = Paths.get(path).getFileName().toString();
coercers[columnIndex] = new RowIDCoercer(rowIdPartitionComponent.get(), rowGroupId);
coercers[columnIndex] = new RowIDCoercer(component, rowGroupId);
}

if (columnMapping.getKind() == PREFILLED) {
Expand Down
Expand Up @@ -183,6 +183,7 @@ public ConnectorPageSource createPageSource(
return createAggregatedPageSource(aggregatedPageSourceFactories, configuration, session, hiveSplit, hiveLayout, selectedColumns, fileContext, encryptionInformation);
}
if (hiveLayout.isPushdownFilterEnabled()) {
Optional<byte[]> rowIDPartitionComponent = hiveSplit.getRowIdPartitionComponent();
Optional<ConnectorPageSource> selectivePageSource = createSelectivePageSource(
selectivePageSourceFactories,
configuration,
Expand Down Expand Up @@ -371,12 +372,14 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
.orElse(layout.getDomainPredicate());

for (HiveSelectivePageSourceFactory pageSourceFactory : selectivePageSourceFactories) {
List<HiveColumnHandle> columnHandles = toColumnHandles(columnMappings, true);
Optional<byte[]> rowIDPartitionComponent = split.getRowIdPartitionComponent();
Optional<? extends ConnectorPageSource> pageSource = pageSourceFactory.createPageSource(
configuration,
session,
split.getFileSplit(),
split.getStorage(),
toColumnHandles(columnMappings, true),
columnHandles,
prefilledValues,
coercers,
bucketAdaptation,
Expand All @@ -387,7 +390,7 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
fileContext,
encryptionInformation,
layout.isAppendRowNumberEnabled(),
split.getRowIdPartitionComponent());
rowIDPartitionComponent);
if (pageSource.isPresent()) {
return Optional.of(pageSource.get());
}
Expand Down Expand Up @@ -497,7 +500,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
effectivePredicate,
hiveStorageTimeZone,
hiveFileContext,
encryptionInformation);
encryptionInformation,
rowIdPartitionComponent);
if (pageSource.isPresent()) {
HivePageSource hivePageSource = new HivePageSource(
columnMappings,
Expand Down
Expand Up @@ -213,6 +213,14 @@ public final class HiveUtil
private static final String USE_RECORD_READER_FROM_INPUT_FORMAT_ANNOTATION = "UseRecordReaderFromInputFormat";
private static final String USE_FILE_SPLITS_FROM_INPUT_FORMAT_ANNOTATION = "UseFileSplitsFromInputFormat";

public static void checkRowIDPartitionComponent(List<HiveColumnHandle> columns, Optional<byte[]> rowIdPartitionComponent)
{
boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
if (supplyRowIDs) {
checkArgument(rowIdPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
}
}

static {
DateTimeParser[] timestampWithoutTimeZoneParser = {
DateTimeFormat.forPattern("yyyy-M-d").getParser(),
Expand Down
Expand Up @@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
Expand Down Expand Up @@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
dwrfEncryptionProvider,
session));
session,
rowIDPartitionComponent));
}
}
Expand Up @@ -48,6 +48,7 @@
import java.util.Optional;

import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA;
import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
import static com.facebook.presto.hive.orc.OrcSelectivePageSourceFactory.createOrcPageSource;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -118,6 +119,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
throw new PrestoException(HIVE_BAD_DATA, "ORC file is empty: " + fileSplit.getPath());
}

checkRowIDPartitionComponent(columns, rowIDPartitionComponent);

return Optional.of(createOrcPageSource(
session,
DWRF,
Expand Down
Expand Up @@ -24,6 +24,7 @@
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.hive.FileFormatDataSourceStats;
import com.facebook.presto.hive.HiveColumnHandle;
import com.facebook.presto.hive.RowIDCoercer;
import com.facebook.presto.orc.OrcAggregatedMemoryContext;
import com.facebook.presto.orc.OrcBatchRecordReader;
import com.facebook.presto.orc.OrcCorruptionException;
Expand Down Expand Up @@ -58,6 +59,7 @@ public class OrcBatchPageSource

private final Block[] constantBlocks;
private final int[] hiveColumnIndexes;
private final boolean[] rowIDColumnIndexes;

private int batchId;
private long completedPositions;
Expand All @@ -71,16 +73,20 @@ public class OrcBatchPageSource

private final List<Boolean> isRowPositionList;

private final RowIDCoercer coercer;

public OrcBatchPageSource(
OrcBatchRecordReader recordReader,
OrcDataSource orcDataSource,
List<HiveColumnHandle> columns,
TypeManager typeManager,
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats)
RuntimeStats runtimeStats,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
}

/**
Expand All @@ -97,7 +103,9 @@ public OrcBatchPageSource(
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats,
List<Boolean> isRowPositionList)
List<Boolean> isRowPositionList,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this.recordReader = requireNonNull(recordReader, "recordReader is null");
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
Expand All @@ -107,9 +115,12 @@ public OrcBatchPageSource(
this.stats = requireNonNull(stats, "stats is null");
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
// TODO don't create this if there's no rowID column
this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);

this.constantBlocks = new Block[size];
this.hiveColumnIndexes = new int[size];
this.rowIDColumnIndexes = new boolean[size];

ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
Expand All @@ -124,6 +135,7 @@ public OrcBatchPageSource(
typesBuilder.add(type);

hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();
rowIDColumnIndexes[columnIndex] = HiveColumnHandle.isRowIdColumnHandle(column);

if (!recordReader.isColumnPresent(column.getHiveColumnIndex())) {
constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE);
Expand Down Expand Up @@ -183,6 +195,11 @@ public Page getNextPage()
if (isRowPositionColumn(fieldId)) {
blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
}
else if (isRowIDColumn(fieldId)) {
Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
Block rowIDs = coercer.apply(rowNumbers);
blocks[fieldId] = rowIDs;
}
else {
if (constantBlocks[fieldId] != null) {
blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
Expand Down Expand Up @@ -260,6 +277,12 @@ private boolean isRowPositionColumn(int column)
return isRowPositionList.get(column);
}

private boolean isRowIDColumn(int column)
{
return this.rowIDColumnIndexes[column];
}

// TODO verify these are row numbers and rename?
private static Block getRowPosColumnBlock(long baseIndex, int size)
{
long[] rowPositions = new long[size];
Expand Down
Expand Up @@ -63,6 +63,7 @@
import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold;
import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled;
import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled;
import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles;
import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource;
import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcReader;
Expand Down Expand Up @@ -134,7 +135,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
Expand Down Expand Up @@ -169,7 +171,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
NO_ENCRYPTION,
session));
session,
rowIDPartitionComponent));
}

public static ConnectorPageSource createOrcPageSource(
Expand All @@ -191,9 +194,11 @@ public static ConnectorPageSource createOrcPageSource(
OrcReaderOptions orcReaderOptions,
Optional<EncryptionInformation> encryptionInformation,
DwrfEncryptionProvider dwrfEncryptionProvider,
ConnectorSession session)
ConnectorSession session,
Optional<byte[]> rowIDPartitionComponent)
{
checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");
checkRowIDPartitionComponent(columns, rowIDPartitionComponent);

OrcDataSource orcDataSource = getOrcDataSource(session, fileSplit, hdfsEnvironment, configuration, hiveFileContext, stats);
Path path = new Path(fileSplit.getPath());
Expand Down Expand Up @@ -235,14 +240,18 @@ public static ConnectorPageSource createOrcPageSource(
systemMemoryUsage,
INITIAL_BATCH_SIZE);

byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupID = path.getName();
return new OrcBatchPageSource(
recordReader,
reader.getOrcDataSource(),
physicalColumns,
typeManager,
systemMemoryUsage,
stats,
hiveFileContext.getStats());
hiveFileContext.getStats(),
partitionID,
rowGroupID);
}
catch (Exception e) {
try {
Expand Down
Expand Up @@ -283,9 +283,9 @@ public static ConnectorPageSource createOrcPageSource(
Path path = new Path(fileSplit.getPath());

boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
String rowGroupId = path.getName();
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupId = path.getName();

DataSize maxMergeDistance = getOrcMaxMergeDistance(session);
DataSize tinyStripeThreshold = getOrcTinyStripeThreshold(session);
Expand Down
Expand Up @@ -72,7 +72,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!PageInputFormat.class.getSimpleName().equals(storage.getStorageFormat().getInputFormat())) {
return Optional.empty();
Expand Down
Expand Up @@ -512,7 +512,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent)
{
if (!PARQUET_SERDE_CLASS_NAMES.contains(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
Expand Down
Expand Up @@ -106,7 +106,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
RcFileEncoding rcFileEncoding;
if (LazyBinaryColumnarSerDe.class.getName().equals(storage.getStorageFormat().getSerDe())) {
Expand Down

0 comments on commit c7ad8a2

Please sign in to comment.