Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented a where clause that binds an individual set of values for each row. #1053

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,12 @@ class RDDSpec extends SparkCassandraITFlatSpecBase {
results should have length keys.count(_ >= 5)
}

// Exercises the new functional where clause: each left-side row supplies its own
// bind value (key * 100) for the `group = ?` predicate. Presumably the fixture
// writes group = key * 100 for every key (TODO confirm against the shared test
// setup), so the join should return exactly one row per key.
it should " support functional where clauses" in {
val someCass = sc.parallelize(keys).map(x => new KVRow(x)).joinWithCassandraTable(ks, tableName).where("group = ?", (k : KVRow) => Seq(k.key * 100))
val results = someCass.collect.map(_._2)
results should have length keys.size
}

it should " throw an exception if using a where on a column that is specified by the join" in {
val exc = intercept[IllegalArgumentException] {
val someCass = sc.parallelize(keys).map(x => (x, x * 100L))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ public <R> CassandraJavaPairRDD<T, R> joinWithCassandraTable(
Option<ClusteringOrder> clusteringOrder = Option.empty();
Option<Object> limit = Option.empty();
CqlWhereClause whereClause = CqlWhereClause.empty();
FCqlWhereClause<T> fwhereClause = FCqlWhereClause.empty();
ReadConf readConf = ReadConf.fromSparkConf(rdd.conf());

CassandraJoinRDD<T, R> joinRDD = new CassandraJoinRDD<>(
Expand All @@ -113,6 +114,7 @@ public <R> CassandraJavaPairRDD<T, R> joinWithCassandraTable(
selectedColumns,
joinColumns,
whereClause,
fwhereClause,
limit,
clusteringOrder,
readConf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {

val left: RDD[L]
val joinColumns: ColumnSelector
val fwhere : FCqlWhereClause[L]
val manualRowWriter: Option[RowWriter[L]]
implicit val rowWriterFactory: RowWriterFactory[L]

Expand Down Expand Up @@ -99,7 +100,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
//We need to make sure we get selectedColumnRefs before serialization so that our RowReader is
//built
lazy val singleKeyCqlQuery: (String) = {
val whereClauses = where.predicates.flatMap(CqlWhereParser.parse)
val whereClauses = where.predicates.flatMap(CqlWhereParser.parse) ++ fwhere.predicates.flatMap(CqlWhereParser.parse)
val joinColumns = joinColumnNames.map(_.columnName)
val joinColumnPredicates = whereClauses.collect {
case EqPredicate(c, _) if joinColumns.contains(c) => c
Expand All @@ -121,7 +122,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
val joinWhere = joinColumnNames.map(_.columnName).map(name => s"${quote(name)} = :$name")
val limitClause = limit.map(limit => s"LIMIT $limit").getOrElse("")
val orderBy = clusteringOrder.map(_.toCql(tableDef)).getOrElse("")
val filter = (where.predicates ++ joinWhere).mkString(" AND ")
val filter = (where.predicates ++ fwhere.predicates ++ joinWhere).mkString(" AND ")
val quotedKeyspaceName = quote(keyspaceName)
val quotedTableName = quote(tableName)
val query =
Expand All @@ -135,7 +136,7 @@ private[rdd] trait AbstractCassandraJoin[L, R] {
private def boundStatementBuilder(session: Session): BoundStatementBuilder[L] = {
val protocolVersion = session.getCluster.getConfiguration.getProtocolOptions.getProtocolVersion
val stmt = session.prepare(singleKeyCqlQuery).setConsistencyLevel(consistencyLevel)
new BoundStatementBuilder[L](rowWriter, stmt, where.values, protocolVersion = protocolVersion)
new BoundStatementBuilder[L](rowWriter, stmt, where.values, fwhere, protocolVersion = protocolVersion)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ import org.apache.spark.rdd.RDD

import scala.reflect.ClassTag


/**
 * A CQL WHERE clause whose bind values are computed from each left-side row,
 * so that every row of a join can issue the query with its own parameters.
 *
 * @param predicates CQL predicate strings, e.g. `"group = ?"`
 * @param values     function producing the bind values for a given left-side row
 */
case class FCqlWhereClause[L](predicates: Seq[String], values: L => Seq[Any]) {

  /** Materializes this clause for one concrete row as a plain [[CqlWhereClause]]. */
  def apply(v1: L): CqlWhereClause = CqlWhereClause(predicates, values(v1))

  /** Conjunction of two clauses: predicates are concatenated and the value functions are chained. */
  def and(other: FCqlWhereClause[L]): FCqlWhereClause[L] = {
    val combinedValues = (row: L) => values(row) ++ other.values(row)
    FCqlWhereClause(predicates ++ other.predicates, combinedValues)
  }
}

object FCqlWhereClause {
  /** The neutral clause: no predicates and no bind values for any row. */
  def empty[L]: FCqlWhereClause[L] = FCqlWhereClause[L](Nil, (_: L) => Nil)
}


/**
* An [[org.apache.spark.rdd.RDD RDD]] that will do a selecting join between `left` RDD and the specified
* Cassandra Table This will perform individual selects to retrieve the rows from Cassandra and will take
Expand All @@ -27,6 +37,7 @@ class CassandraJoinRDD[L, R] private[connector](
val columnNames: ColumnSelector = AllColumns,
val joinColumns: ColumnSelector = PartitionKeyColumns,
val where: CqlWhereClause = CqlWhereClause.empty,
val fwhere : FCqlWhereClause[L] = FCqlWhereClause.empty[L],
val limit: Option[Long] = None,
val clusteringOrder: Option[ClusteringOrder] = None,
val readConf: ReadConf = ReadConf(),
Expand All @@ -50,7 +61,7 @@ class CassandraJoinRDD[L, R] private[connector](
case None => rowReaderFactory.rowReader(tableDef, columnNames.selectFrom(tableDef))
}

override protected def copy(
protected def copy(
columnNames: ColumnSelector = columnNames,
where: CqlWhereClause = where,
limit: Option[Long] = limit,
Expand All @@ -67,12 +78,36 @@ class CassandraJoinRDD[L, R] private[connector](
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
)
}

// NOTE(review): a proper `copy` override is not possible here because the inherited
// copy signature has no `fwhere` parameter; this dedicated builder reconstructs the
// RDD with the new functional where clause while keeping every other property.
def setFWhere(
fwhere : FCqlWhereClause[L]
): Self = {

new CassandraJoinRDD[L, R](
left = left,
keyspaceName = keyspaceName,
tableName = tableName,
connector = connector,
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
)
}

// Adds a functional where clause; it is AND-ed with any functional clause already set.
def where(f : FCqlWhereClause[L]) : Self = setFWhere(fwhere = fwhere and f)
// Convenience overload: a single CQL predicate plus a function computing its bind
// values from each left-side row.
def where(clause : String, f : L => Seq[Any]) : Self = where(FCqlWhereClause(Seq(clause),f))

override def cassandraCount(): Long = {
columnNames match {
case SomeColumns(_) =>
Expand All @@ -89,6 +124,7 @@ class CassandraJoinRDD[L, R] private[connector](
columnNames = SomeColumns(RowCountRef),
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
Expand All @@ -106,13 +142,14 @@ class CassandraJoinRDD[L, R] private[connector](
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
)
}

private[rdd] def fetchIterator(
override private[rdd] def fetchIterator(
session: Session,
bsb: BoundStatementBuilder[L],
leftIterator: Iterator[L]
Expand All @@ -121,7 +158,6 @@ class CassandraJoinRDD[L, R] private[connector](
val rateLimiter = new RateLimiter(
readConf.throughputJoinQueryPerSec, readConf.throughputJoinQueryPerSec
)

def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
val resultFuture = SettableFuture.create[Iterator[(L, R)]]
val leftSide = Iterator.continually(left)
Expand All @@ -141,6 +177,7 @@ class CassandraJoinRDD[L, R] private[connector](
resultFuture
}
val queryFutures = leftIterator.map(left => {

rateLimiter.maybeSleep(1)
pairWithRight(left)
}).toList
Expand All @@ -162,6 +199,7 @@ class CassandraJoinRDD[L, R] private[connector](
columnNames,
joinColumns,
where,
fwhere,
limit,
clusteringOrder,
readConf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
val columnNames: ColumnSelector = AllColumns,
val joinColumns: ColumnSelector = PartitionKeyColumns,
val where: CqlWhereClause = CqlWhereClause.empty,
val fwhere : FCqlWhereClause[L] = FCqlWhereClause.empty[L],
val limit: Option[Long] = None,
val clusteringOrder: Option[ClusteringOrder] = None,
val readConf: ReadConf = ReadConf(),
Expand Down Expand Up @@ -67,12 +68,36 @@ class CassandraLeftJoinRDD[L, R] private[connector](
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
)
}

// NOTE(review): a proper `copy` override is not possible here because the inherited
// copy signature has no `fwhere` parameter; this dedicated builder reconstructs the
// RDD with the new functional where clause while keeping every other property.
def setFWhere(
fwhere : FCqlWhereClause[L]
): Self = {

new CassandraLeftJoinRDD[L, R](
left = left,
keyspaceName = keyspaceName,
tableName = tableName,
connector = connector,
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
)
}

// Adds a functional where clause; it is AND-ed with any functional clause already set.
def where(f : FCqlWhereClause[L]) : Self = setFWhere(fwhere = fwhere and f)
// Convenience overload: a single CQL predicate plus a function computing its bind
// values from each left-side row.
def where(clause : String, f : L => Seq[Any]) : Self = where(FCqlWhereClause(Seq(clause),f))

override def cassandraCount(): Long = {
columnNames match {
case SomeColumns(_) =>
Expand All @@ -89,6 +114,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
columnNames = SomeColumns(RowCountRef),
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
Expand All @@ -106,6 +132,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
columnNames = columnNames,
joinColumns = joinColumns,
where = where,
fwhere = fwhere,
limit = limit,
clusteringOrder = clusteringOrder,
readConf = readConf
Expand All @@ -127,6 +154,7 @@ class CassandraLeftJoinRDD[L, R] private[connector](
columnNames,
joinColumns,
where,
fwhere,
limit,
clusteringOrder,
readConf,
Expand All @@ -144,7 +172,6 @@ class CassandraLeftJoinRDD[L, R] private[connector](
val rateLimiter = new RateLimiter(
readConf.throughputJoinQueryPerSec, readConf.throughputJoinQueryPerSec
)

def pairWithRight(left: L): SettableFuture[Iterator[(L, Option[R])]] = {
val resultFuture = SettableFuture.create[Iterator[(L, Option[R])]]
val leftSide = Iterator.continually(left)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.datastax.spark.connector.writer

import com.datastax.driver.core._
import com.datastax.spark.connector.rdd.FCqlWhereClause
import com.datastax.spark.connector.types.{ColumnType, Unset}
import com.datastax.spark.connector.util.{CodecRegistryUtil, Logging}

Expand All @@ -13,6 +14,7 @@ private[connector] class BoundStatementBuilder[T](
val rowWriter: RowWriter[T],
val preparedStmt: PreparedStatement,
val prefixVals: Seq[Any] = Seq.empty,
val dependentValues : FCqlWhereClause[T] = FCqlWhereClause.empty[T],
val ignoreNulls: Boolean = false,
val protocolVersion: ProtocolVersion) extends Logging {

Expand Down Expand Up @@ -91,11 +93,21 @@ private[connector] class BoundStatementBuilder[T](
prefixConverter = ColumnType.converterToCassandra(prefixType)
} yield prefixConverter.convert(prefixVal)

/**
 * Converts the row-dependent WHERE-clause values of `dependentValues` to their
 * Cassandra representations. These values are bound immediately after the static
 * prefix values, hence the `prefixVals.length` offset when looking up each
 * variable's type on the prepared statement.
 */
private def variablesConverted(row : T): Seq[AnyRef] = {
  dependentValues.values(row).zipWithIndex.map { case (value, i) =>
    val variableType = preparedStmt.getVariables.getType(prefixVals.length + i)
    ColumnType.converterToCassandra(variableType).convert(value)
  }
}

/** Creates `BoundStatement` from the given data item */
def bind(row: T): RichBoundStatement = {
val boundStatement = new RichBoundStatement(preparedStmt)
boundStatement.bind(prefixConverted: _*)

val variables = prefixConverted ++ variablesConverted(row)
boundStatement.bind(variables: _*)
rowWriter.readColumnValues(row, buffer)
var bytesCount = 0
for (i <- 0 until columnNames.size) {
Expand Down