Skip to content

Commit

Permalink
Add row IDs to batch reader
Browse files Browse the repository at this point in the history
  • Loading branch information
elharo committed Apr 25, 2024
1 parent ada2eb9 commit 18e37fa
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 11 deletions.
Expand Up @@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation);
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent);
}
Expand Up @@ -505,7 +505,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
effectivePredicate,
hiveStorageTimeZone,
hiveFileContext,
encryptionInformation);
encryptionInformation,
rowIdPartitionComponent);
if (pageSource.isPresent()) {
HivePageSource hivePageSource = new HivePageSource(
columnMappings,
Expand Down
Expand Up @@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
Expand Down Expand Up @@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
dwrfEncryptionProvider,
session));
session,
rowIDPartitionComponent));
}
}
Expand Up @@ -24,6 +24,7 @@
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.hive.FileFormatDataSourceStats;
import com.facebook.presto.hive.HiveColumnHandle;
import com.facebook.presto.hive.RowIDCoercer;
import com.facebook.presto.orc.OrcAggregatedMemoryContext;
import com.facebook.presto.orc.OrcBatchRecordReader;
import com.facebook.presto.orc.OrcCorruptionException;
Expand Down Expand Up @@ -71,16 +72,20 @@ public class OrcBatchPageSource

private final List<Boolean> isRowPositionList;

private final RowIDCoercer coercer;

public OrcBatchPageSource(
OrcBatchRecordReader recordReader,
OrcDataSource orcDataSource,
List<HiveColumnHandle> columns,
TypeManager typeManager,
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats)
RuntimeStats runtimeStats,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
}

/**
Expand All @@ -97,7 +102,9 @@ public OrcBatchPageSource(
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats,
List<Boolean> isRowPositionList)
List<Boolean> isRowPositionList,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this.recordReader = requireNonNull(recordReader, "recordReader is null");
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
Expand All @@ -107,6 +114,7 @@ public OrcBatchPageSource(
this.stats = requireNonNull(stats, "stats is null");
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);

this.constantBlocks = new Block[size];
this.hiveColumnIndexes = new int[size];
Expand Down Expand Up @@ -182,6 +190,10 @@ public Page getNextPage()
for (int fieldId = 0; fieldId < blocks.length; fieldId++) {
if (isRowPositionColumn(fieldId)) {
blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
} else if (check if this is row id column) {
Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
Block rowIDs = coercer.apply(rowNumbers);
blocks[fieldId] = rowIDs;
}
else {
if (constantBlocks[fieldId] != null) {
Expand Down Expand Up @@ -260,6 +272,7 @@ private boolean isRowPositionColumn(int column)
return isRowPositionList.get(column);
}

// TODO: verify that these values are actually row numbers and, if so, rename this method accordingly
private static Block getRowPosColumnBlock(long baseIndex, int size)
{
long[] rowPositions = new long[size];
Expand Down
Expand Up @@ -134,7 +134,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
Expand Down Expand Up @@ -169,7 +170,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
NO_ENCRYPTION,
session));
session,
rowIDPartitionComponent));
}

public static ConnectorPageSource createOrcPageSource(
Expand All @@ -191,7 +193,8 @@ public static ConnectorPageSource createOrcPageSource(
OrcReaderOptions orcReaderOptions,
Optional<EncryptionInformation> encryptionInformation,
DwrfEncryptionProvider dwrfEncryptionProvider,
ConnectorSession session)
ConnectorSession session,
Optional<byte[]> rowIDPartitionComponent)
{
checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");

Expand Down Expand Up @@ -235,14 +238,20 @@ public static ConnectorPageSource createOrcPageSource(
systemMemoryUsage,
INITIAL_BATCH_SIZE);

boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupID = path.getName();
return new OrcBatchPageSource(
recordReader,
reader.getOrcDataSource(),
physicalColumns,
typeManager,
systemMemoryUsage,
stats,
hiveFileContext.getStats());
hiveFileContext.getStats(),
partitionID,
rowGroupID);
}
catch (Exception e) {
try {
Expand Down
Expand Up @@ -285,6 +285,7 @@ public static ConnectorPageSource createOrcPageSource(
boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
// TODO: trace the code path below for the case where supplyRowIDs is false
// TODO: copy this row ID setup into the batch reader code path
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupId = path.getName();

Expand Down
Expand Up @@ -158,6 +158,9 @@ public Supplier<Optional<UpdatablePageSource>> addSplit(ScheduledSplit scheduled

this.split = split;

// TODO: if the columns include $row_id, then split.getConnectorSplit() should carry a row ID partition component.
// Check that here, do a clean install, then run locally with a throw to see where the split comes from.

Object splitInfo = split.getInfo();
Map<String, String> infoMap = split.getInfoMap();

Expand Down

0 comments on commit 18e37fa

Please sign in to comment.