Add row ID support to batch ORC reader #22615

Merged: 1 commit, May 2, 2024
@@ -38,5 +38,6 @@ Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation);
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIdPartitionComponent);
}
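A note on the new parameter's contract: an empty Optional means the split carries no row ID partition component, and the ORC paths below substitute an empty array in that case (rowIdPartitionComponent.orElse(new byte[0])). A minimal sketch of that defaulting, using only JDK types:

import java.util.Optional;

// Sketch of the fallback the later hunks apply when a split has no
// partition component but a byte[] is still needed.
final class PartitionComponentDefault
{
    static byte[] componentOrEmpty(Optional<byte[]> rowIdPartitionComponent)
    {
        return rowIdPartitionComponent.orElse(new byte[0]);
    }
}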
@@ -96,9 +96,11 @@ public HivePageSource(
if (columnMapping.getCoercionFrom().isPresent()) {
coercers[columnIndex] = createCoercer(typeManager, columnMapping.getCoercionFrom().get(), columnMapping.getHiveColumnHandle().getHiveType());
}
- else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle()) && rowIdPartitionComponent.isPresent()) {
+ else if (isRowIdColumnHandle(columnMapping.getHiveColumnHandle())) {
+ // If there's no row ID partition component, then path + row numbers will be supplied for $row_id
+ byte[] component = rowIdPartitionComponent.orElse(new byte[0]);
String rowGroupId = Paths.get(path).getFileName().toString();
- coercers[columnIndex] = new RowIDCoercer(rowIdPartitionComponent.get(), rowGroupId);
+ coercers[columnIndex] = new RowIDCoercer(component, rowGroupId);
}

if (columnMapping.getKind() == PREFILLED) {
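The coercer assembles a $row_id from three inputs: the row number within the file, the row group ID (the file name), and the partition component. The actual byte layout of RowIDCoercer is not shown in this diff; the following is a sketch of one plausible encoding, with the layout itself an assumption:

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical encoding for illustration only; the real layout lives in RowIDCoercer.
final class RowIdSketch
{
    static byte[] rowId(long rowNumber, String rowGroupId, byte[] partitionComponent)
    {
        byte[] group = rowGroupId.getBytes(StandardCharsets.UTF_8);
        return ByteBuffer.allocate(Long.BYTES + group.length + partitionComponent.length)
                .putLong(rowNumber)       // row position within the file
                .put(group)               // file name, per Paths.get(path).getFileName() above
                .put(partitionComponent)  // empty when the split carries no component
                .array();
    }
}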
@@ -183,6 +183,7 @@ public ConnectorPageSource createPageSource(
return createAggregatedPageSource(aggregatedPageSourceFactories, configuration, session, hiveSplit, hiveLayout, selectedColumns, fileContext, encryptionInformation);
}
if (hiveLayout.isPushdownFilterEnabled()) {
+ Optional<byte[]> rowIDPartitionComponent = hiveSplit.getRowIdPartitionComponent();
Optional<ConnectorPageSource> selectivePageSource = createSelectivePageSource(
selectivePageSourceFactories,
configuration,
@@ -371,12 +372,14 @@ private static Optional<ConnectorPageSource> createSelectivePageSource(
.orElse(layout.getDomainPredicate());

for (HiveSelectivePageSourceFactory pageSourceFactory : selectivePageSourceFactories) {
+ List<HiveColumnHandle> columnHandles = toColumnHandles(columnMappings, true);
+ Optional<byte[]> rowIDPartitionComponent = split.getRowIdPartitionComponent();
Optional<? extends ConnectorPageSource> pageSource = pageSourceFactory.createPageSource(
configuration,
session,
split.getFileSplit(),
split.getStorage(),
- toColumnHandles(columnMappings, true),
+ columnHandles,
prefilledValues,
coercers,
bucketAdaptation,
@@ -387,7 +390,7 @@
fileContext,
encryptionInformation,
layout.isAppendRowNumberEnabled(),
- split.getRowIdPartitionComponent());
+ rowIDPartitionComponent);
if (pageSource.isPresent()) {
return Optional.of(pageSource.get());
}
@@ -497,7 +500,8 @@ public static Optional<ConnectorPageSource> createHivePageSource(
effectivePredicate,
hiveStorageTimeZone,
hiveFileContext,
- encryptionInformation);
+ encryptionInformation,
+ rowIdPartitionComponent);
if (pageSource.isPresent()) {
HivePageSource hivePageSource = new HivePageSource(
columnMappings,
@@ -213,6 +213,14 @@ public final class HiveUtil
private static final String USE_RECORD_READER_FROM_INPUT_FORMAT_ANNOTATION = "UseRecordReaderFromInputFormat";
private static final String USE_FILE_SPLITS_FROM_INPUT_FORMAT_ANNOTATION = "UseFileSplitsFromInputFormat";

+ public static void checkRowIDPartitionComponent(List<HiveColumnHandle> columns, Optional<byte[]> rowIdPartitionComponent)
+ {
+ boolean supplyRowIDs = columns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
+ if (supplyRowIDs) {
+ checkArgument(rowIdPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
+ }
+ }
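checkRowIDPartitionComponent is a fail-fast precondition: the partition component may be absent only when no $row_id column is selected. Restated as a self-contained generic guard (names here are illustrative, not from the codebase):

import static com.google.common.base.Preconditions.checkArgument;

import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;

// Illustrative generalization of the guard above: require a value only
// when at least one element actually needs it.
final class RequireWhenAny
{
    static <T> void requireWhenAny(List<T> items, Predicate<T> needsValue, Optional<?> value, String message)
    {
        if (items.stream().anyMatch(needsValue)) {
            checkArgument(value.isPresent(), message);
        }
    }
}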

static {
DateTimeParser[] timestampWithoutTimeZoneParser = {
DateTimeFormat.forPattern("yyyy-M-d").getParser(),
@@ -98,7 +98,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation)
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -132,6 +133,7 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
dwrfEncryptionProvider,
- session));
+ session,
+ rowIDPartitionComponent));
}
}
@@ -48,6 +48,7 @@
import java.util.Optional;

import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA;
+ import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
import static com.facebook.presto.hive.orc.OrcSelectivePageSourceFactory.createOrcPageSource;
import static com.facebook.presto.orc.OrcEncoding.DWRF;
import static java.util.Objects.requireNonNull;
@@ -118,6 +119,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
throw new PrestoException(HIVE_BAD_DATA, "ORC file is empty: " + fileSplit.getPath());
}

+ checkRowIDPartitionComponent(columns, rowIDPartitionComponent);

return Optional.of(createOrcPageSource(
session,
DWRF,
@@ -24,6 +24,7 @@
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.hive.FileFormatDataSourceStats;
import com.facebook.presto.hive.HiveColumnHandle;
+ import com.facebook.presto.hive.RowIDCoercer;
import com.facebook.presto.orc.OrcAggregatedMemoryContext;
import com.facebook.presto.orc.OrcBatchRecordReader;
import com.facebook.presto.orc.OrcCorruptionException;
@@ -58,6 +59,7 @@ public class OrcBatchPageSource

private final Block[] constantBlocks;
private final int[] hiveColumnIndexes;
+ private final boolean[] rowIDColumnIndexes;

private int batchId;
private long completedPositions;
@@ -71,16 +73,20 @@ public class OrcBatchPageSource

private final List<Boolean> isRowPositionList;

+ private final RowIDCoercer coercer;

public OrcBatchPageSource(
OrcBatchRecordReader recordReader,
OrcDataSource orcDataSource,
List<HiveColumnHandle> columns,
TypeManager typeManager,
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
- RuntimeStats runtimeStats)
+ RuntimeStats runtimeStats,
+ byte[] rowIDPartitionComponent,
+ String rowGroupId)
{
- this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false));
+ this(recordReader, orcDataSource, columns, typeManager, systemMemoryContext, stats, runtimeStats, nCopies(columns.size(), false), rowIDPartitionComponent, rowGroupId);
}

/**
@@ -97,7 +103,9 @@ public OrcBatchPageSource(
OrcAggregatedMemoryContext systemMemoryContext,
FileFormatDataSourceStats stats,
RuntimeStats runtimeStats,
- List<Boolean> isRowPositionList)
+ List<Boolean> isRowPositionList,
+ byte[] rowIDPartitionComponent,
+ String rowGroupId)
{
this.recordReader = requireNonNull(recordReader, "recordReader is null");
this.orcDataSource = requireNonNull(orcDataSource, "orcDataSource is null");
@@ -107,9 +115,12 @@
this.stats = requireNonNull(stats, "stats is null");
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.isRowPositionList = requireNonNull(isRowPositionList, "isRowPositionList is null");
+ // TODO don't create this if there's no rowID column
+ this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId);

this.constantBlocks = new Block[size];
this.hiveColumnIndexes = new int[size];
+ this.rowIDColumnIndexes = new boolean[size];

ImmutableList.Builder<String> namesBuilder = ImmutableList.builder();
ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
@@ -124,6 +135,7 @@
typesBuilder.add(type);

hiveColumnIndexes[columnIndex] = column.getHiveColumnIndex();
+ rowIDColumnIndexes[columnIndex] = HiveColumnHandle.isRowIdColumnHandle(column);

if (!recordReader.isColumnPresent(column.getHiveColumnIndex())) {
constantBlocks[columnIndex] = RunLengthEncodedBlock.create(type, null, MAX_BATCH_SIZE);
@@ -183,6 +195,11 @@ public Page getNextPage()
if (isRowPositionColumn(fieldId)) {
blocks[fieldId] = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
}
+ else if (isRowIDColumn(fieldId)) {
+ Block rowNumbers = getRowPosColumnBlock(recordReader.getFilePosition(), batchSize);
+ Block rowIDs = coercer.apply(rowNumbers);
+ blocks[fieldId] = rowIDs;
+ }
else {
if (constantBlocks[fieldId] != null) {
blocks[fieldId] = constantBlocks[fieldId].getRegion(0, batchSize);
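Per batch, the row ID branch materializes consecutive row positions starting at the reader's current file position, then maps each position to an ID. A simplified stand-in for the getRowPosColumnBlock plus coercer.apply pair, with Presto's Block types replaced by plain arrays:

import java.util.function.LongFunction;

// Stand-in sketch: positions are consecutive from the reader's file
// position; the coercer turns each position into a row ID.
final class RowIdBatchSketch
{
    static byte[][] rowIdsForBatch(long basePosition, int batchSize, LongFunction<byte[]> coercer)
    {
        byte[][] rowIds = new byte[batchSize][];
        for (int i = 0; i < batchSize; i++) {
            rowIds[i] = coercer.apply(basePosition + i);
        }
        return rowIds;
    }
}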
@@ -260,6 +277,12 @@ private boolean isRowPositionColumn(int column)
return isRowPositionList.get(column);
}

+ private boolean isRowIDColumn(int column)
+ {
+ return this.rowIDColumnIndexes[column];
+ }

+ // TODO verify these are row numbers and rename?
private static Block getRowPosColumnBlock(long baseIndex, int size)
{
long[] rowPositions = new long[size];
@@ -63,6 +63,7 @@
import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold;
import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled;
import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled;
+ import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent;
import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles;
import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource;
import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcReader;
@@ -134,7 +135,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation)
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIDPartitionComponent)
{
if (!OrcSerde.class.getName().equals(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -169,7 +171,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
.build(),
encryptionInformation,
NO_ENCRYPTION,
- session));
+ session,
+ rowIDPartitionComponent));
}

public static ConnectorPageSource createOrcPageSource(
@@ -191,9 +194,11 @@ public static ConnectorPageSource createOrcPageSource(
OrcReaderOptions orcReaderOptions,
Optional<EncryptionInformation> encryptionInformation,
DwrfEncryptionProvider dwrfEncryptionProvider,
- ConnectorSession session)
+ ConnectorSession session,
+ Optional<byte[]> rowIDPartitionComponent)
{
checkArgument(domainCompactionThreshold >= 1, "domainCompactionThreshold must be at least 1");
+ checkRowIDPartitionComponent(columns, rowIDPartitionComponent);

OrcDataSource orcDataSource = getOrcDataSource(session, fileSplit, hdfsEnvironment, configuration, hiveFileContext, stats);
Path path = new Path(fileSplit.getPath());
@@ -235,14 +240,18 @@ public static ConnectorPageSource createOrcPageSource(
systemMemoryUsage,
INITIAL_BATCH_SIZE);

+ byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
+ String rowGroupID = path.getName();
return new OrcBatchPageSource(
recordReader,
reader.getOrcDataSource(),
physicalColumns,
typeManager,
systemMemoryUsage,
stats,
- hiveFileContext.getStats());
+ hiveFileContext.getStats(),
+ partitionID,
+ rowGroupID);
}
catch (Exception e) {
try {
@@ -283,9 +283,9 @@ public static ConnectorPageSource createOrcPageSource(
Path path = new Path(fileSplit.getPath());

boolean supplyRowIDs = selectedColumns.stream().anyMatch(column -> HiveColumnHandle.isRowIdColumnHandle(column));
- String rowGroupId = path.getName();
checkArgument(!supplyRowIDs || rowIDPartitionComponent.isPresent(), "rowIDPartitionComponent required when supplying row IDs");
byte[] partitionID = rowIDPartitionComponent.orElse(new byte[0]);
+ String rowGroupId = path.getName();

DataSize maxMergeDistance = getOrcMaxMergeDistance(session);
DataSize tinyStripeThreshold = getOrcTinyStripeThreshold(session);
@@ -72,7 +72,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation)
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIDPartitionComponent)
{
if (!PageInputFormat.class.getSimpleName().equals(storage.getStorageFormat().getInputFormat())) {
return Optional.empty();
@@ -512,7 +512,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation)
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIdPartitionComponent)
{
if (!PARQUET_SERDE_CLASS_NAMES.contains(storage.getStorageFormat().getSerDe())) {
return Optional.empty();
@@ -106,7 +106,8 @@ public Optional<? extends ConnectorPageSource> createPageSource(
TupleDomain<HiveColumnHandle> effectivePredicate,
DateTimeZone hiveStorageTimeZone,
HiveFileContext hiveFileContext,
- Optional<EncryptionInformation> encryptionInformation)
+ Optional<EncryptionInformation> encryptionInformation,
+ Optional<byte[]> rowIDPartitionComponent)
{
RcFileEncoding rcFileEncoding;
if (LazyBinaryColumnarSerDe.class.getName().equals(storage.getStorageFormat().getSerDe())) {