Add row IDs to batch reader
elharo committed Apr 29, 2024
1 parent ada2eb9 commit ee5d947
Showing 12 changed files with 83 additions and 18 deletions.
@@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation);
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent);
}
@@ -505,7 +505,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
effectivePredicate,
hiveStorageTimeZone,
hiveFileContext,
encryptionInformation);
encryptionInformation,
rowIdPartitionComponent);
if (pageSource.isPresent()) {
HivePageSource hivePageSource = new HivePageSource(
columnMappings,
@@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
dwrfEncryptionProvider,
session));
session,
rowIDPartitionComponent));
}
}
@@ -24,6 +24,7 @@
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.hive.FileFormatDataSourceStats;
import com.facebook.presto.hive.HiveColumnHandle;
import com.facebook.presto.hive.RowIDCoercer;
import com.facebook.presto.orc.OrcAggregatedMemoryContext;
import com.facebook.presto.orc.OrcBatchRecordReader;
import com.facebook.presto.orc.OrcCorruptionException;
@@ -58,6 +59,7 @@ public class OrcBatchPageSource

private final Block[] constantBlocks;
private final int[] hiveColumnIndexes;
private final boolean[] rowIDColumnIndexes;

private int batchId;
private long completedPositions;
@@ -71,16 +73,20 @@ public class OrcBatchPageSource

private final List<Boolean> isRowPositionList;

private final RowIDCoercer coercer;

public OrcBatchPageSource(
OrcBatchRecordReader recordReader,
OrcDataSource orcDataSource,
List<HiveColumnHandle> columns,
TypeManager typeManager,
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats)
RuntimeStats runtimeStats,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
}

/**
@@ -97,7 +103,9 @@ public OrcBatchPageSource(
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats,
List<Boolean> isRowPositionList)
List<Boolean> isRowPositionList,
byte[] rowIDPartitionComponent,
String rowGroupId)
{
this.recordReader = requireNonNull(recordReader, "recordReader is null");
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
@@ -107,9 +115,11 @@
this.stats = requireNonNull(stats, "stats is null");
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);

this.constantBlocks = new Block[size];
this.hiveColumnIndexes = new int[size];
this.rowIDColumnIndexes = new boolean[size];

ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
@@ -124,6 +134,7 @@
typesBuilder.add(type);

hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();
rowIDColumnIndexes[columnIndex] = HiveColumnHandle.isRowIdColumnHandle(column);

if (!recordReader.isColumnPresent(column.getHiveColumnIndex())) {
constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE);
@@ -183,6 +194,11 @@ public Page getNextPage()
if (isRowPositionColumn(fieldId)) {
blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
}
else if (isRowIDColumn(fieldId)) {
Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
Block rowIDs = coercer.apply(rowNumbers);
blocks[fieldId] = rowIDs;
}
else {
if (constantBlocks[fieldId] != null) {
blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
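RowIDCoercer itself is not shown in this commit; getNextPage only applies it to the block of row numbers. A minimal sketch of the per-row coercion it presumably performs, assuming each row ID concatenates the little-endian row number with the row group ID (the file name) and the partition component — the actual byte layout is defined by RowIDCoercer:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public final class RowIdLayoutSketch
{
    // Assumed layout, not taken from this diff: 8-byte little-endian row
    // number, then the UTF-8 bytes of the row group ID, then the partition
    // component bytes.
    static byte[] rowId(long rowNumber, String rowGroupId, byte[] partitionComponent)
    {
        byte[] groupBytes = rowGroupId.getBytes(StandardCharsets.UTF_8);
        ByteBuffer buffer = ByteBuffer.allocate(8 + groupBytes.length + partitionComponent.length)
                .order(ByteOrder.LITTLE_ENDIAN);
        buffer.putLong(rowNumber);
        buffer.put(groupBytes);
        buffer.put(partitionComponent);
        return buffer.array();
    }

    public static void main(String[] args)
    {
        // 8 bytes of row number + 8 bytes of "000000_0" + 2 partition bytes = 18
        System.out.println(rowId(42L, "000000_0", new byte[] {1, 2}).length);
    }
}

Whatever the layout, the coercer maps a block of row numbers to a block of row IDs position by position, which is why getNextPage can feed it the same block it builds for the $row_position column.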
@@ -260,6 +276,12 @@ private boolean isRowPositionColumn(int column)
return isRowPositionList.get(column);
}

private boolean isRowIDColumn(int column)
{
return this.rowIDColumnIndexes[column];
}

// TODO verify these are row numbers and rename?
private static Block getRowPosColumnBlock(long baseIndex, int size)
{
long[] rowPositions = new long[size];
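The tail of getRowPosColumnBlock is elided above. Presumably it numbers the next size rows starting from the reader's current file position and wraps them in a primitive block; a sketch under that assumption, using presto-common's LongArrayBlock:

import java.util.Optional;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.LongArrayBlock;

// Assumed completion of the elided body: position i gets baseIndex + i.
private static Block getRowPosColumnBlock(long baseIndex, int size)
{
    long[] rowPositions = new long[size];
    for (int position = 0; position < size; position++) {
        rowPositions[position] = baseIndex + position;
    }
    return new LongArrayBlock(size, Optional.empty(), rowPositions);
}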
@@ -134,7 +134,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -169,7 +170,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
NO_ENCRYPTION,
session));
session,
rowIDPartitionComponent));
}

public static ConnectorPageSource createOrcPageSource(
@@ -191,7 +193,8 @@ public static ConnectorPageSource createOrcPageSource(
OrcReaderOptions orcReaderOptions,
Optional<EncryptionInformation> encryptionInformation,
DwrfEncryptionProvider dwrfEncryptionProvider,
ConnectorSession session)
ConnectorSession session,
Optional<byte[]> rowIDPartitionComponent)
{
checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");

@@ -235,14 +238,20 @@ public static ConnectorPageSource createOrcPageSource(
systemMemoryUsage,
INITIAL_BATCH_SIZE);

boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupID = path.getName();
return new OrcBatchPageSource(
recordReader,
reader.getOrcDataSource(),
physicalColumns,
typeManager,
systemMemoryUsage,
stats,
hiveFileContext.getStats());
hiveFileContext.getStats(),
partitionID,
rowGroupID);
}
catch (Exception e) {
try {
@@ -285,6 +285,7 @@ public static ConnectorPageSource createOrcPageSource(
boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
// TODO follow the code path below when supplyRowIDs is false
// TODO copy this into batch
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
String rowGroupId = path.getName();

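The batch factory above and this selective-reader hunk enforce the same precondition: a scan that projects $row_id must arrive with a partition component, and the component defaults to empty bytes only when no row IDs are requested. A standalone illustration using Guava's checkArgument, as the diff does; every name below is local to the example:

import static com.google.common.base.Preconditions.checkArgument;

import java.util.Optional;

public final class RowIdGuardExample
{
    public static void main(String[] args)
    {
        boolean supplyRowIDs = true; // the projection includes $row_id
        Optional<byte[]> rowIDPartitionComponent = Optional.empty();
        // Fails fast with IllegalArgumentException rather than building
        // row IDs that silently omit the partition component.
        checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(),
                "rowIDPartitionComponent required when supplying row IDs");
        byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
    }
}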
@@ -72,7 +72,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
if (!PageInputFormat.class.getSimpleName().equals(storage.getStorageFormat().getInputFormat())) {
return Optional.empty();
@@ -512,7 +512,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent)
{
if (!PARQUET_SERDE_CLASS_NAMES.contains(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -106,7 +106,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIDPartitionComponent)
{
RcFileEncoding rcFileEncoding;
if (LazyBinaryColumnarSerDe.class.getName().equals(storage.getStorageFormat().getSerDe())) {
@@ -497,7 +497,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(Configuration co
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation)
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent)
{
return Optional.of(new MockPageSource());
}
@@ -640,7 +641,18 @@ private static class MockOrcBatchPageSourceFactory
implements HiveBatchPageSourceFactory
{
@Override
public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration, ConnectorSession session, HiveFileSplit fileSplit, Storage storage, SchemaTableName tableName, Map<String, String> tableParameters, List<HiveColumnHandle> columns, TupleDomain<HiveColumnHandle> effectivePredicate, DateTimeZone hiveStorageTimeZone, HiveFileContext hiveFileContext, Optional<EncryptionInformation> encryptionInformation)
public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration,
ConnectorSession session,
HiveFileSplit fileSplit,
Storage storage,
SchemaTableName tableName,
Map<String, String> tableParameters,
List<HiveColumnHandle> columns,
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -696,8 +708,18 @@ public Optional<? extends ConnectorPageSource> createPageSource(Configuration co
private static class MockRcBinaryBatchPageSourceFactory
implements HiveBatchPageSourceFactory
{
@Override
public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration, ConnectorSession session, HiveFileSplit fileSplit, Storage storage, SchemaTableName tableName, Map<String, String> tableParameters, List<HiveColumnHandle> columns, TupleDomain<HiveColumnHandle> effectivePredicate, DateTimeZone hiveStorageTimeZone, HiveFileContext hiveFileContext, Optional<EncryptionInformation> encryptionInformation)
public Optional<? extends ConnectorPageSource> createPageSource(Configuration configuration,
ConnectorSession session,
HiveFileSplit fileSplit,
Storage storage,
SchemaTableName tableName,
Map<String, String> tableParameters,
List<HiveColumnHandle> columns,
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
Optional<EncryptionInformation> encryptionInformation,
Optional<byte[]> rowIdPartitionComponent)
{
if (!storage.getStorageFormat().getSerDe().equals(LazyBinaryColumnarSerDe.class.getName())) {
return Optional.empty();
@@ -507,6 +507,7 @@ public static ConnectorPageSource createPageSource(
OptionalLong.of(targetFile.length()),
modificationTime,
false),
Optional.empty(),
Optional.empty())
.get();
}
@@ -158,6 +158,9 @@ public Supplier<Optional<UpdatablePageSource>> addSplit(ScheduledSplit scheduled

this.split = split;

// TODO if the columns include $row_id, then split.getConnectorSplit() should have a row ID partition component
// check that here; clean install, then run locally with a throw to see where the split comes from

Object splitInfo = split.getInfo();
Map<String, String> infoMap = split.getInfoMap();
