
[SQL] SPARK-1732 - Support for null primitive values.
I also removed a println that I bumped into.

Author: Michael Armbrust <michael@databricks.com>

Closes mesos#658 from marmbrus/nullPrimitives and squashes the following commits:

a3ec4f3 [Michael Armbrust] Remove println.
695606b [Michael Armbrust] Support for null primitives when using Scala and Java reflection.
marmbrus authored and mateiz committed May 6, 2014
1 parent a2262cd commit 3c64750
Showing 7 changed files with 122 additions and 5 deletions.
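For context, a rough sketch of what the change enables at the user level, mirroring the new test cases in the diff below. This is illustrative rather than taken from the patch: it assumes a Spark 1.0-era SQLContext with its implicits imported, and the Person class, sqlContext value, and table name are made up.

  case class Person(name: String, age: Option[Int])

  val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)  // sparkContext assumed in scope
  import sqlContext._

  // Option fields now reflect to a nullable column of the element type;
  // None comes back as SQL NULL instead of breaking schema inference.
  val people = sparkContext.parallelize(Person("alice", Some(30)) :: Person("bob", None) :: Nil)
  people.registerAsTable("people")
  sql("SELECT name, age FROM people").collect()  // ("alice", 30), ("bob", null)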
@@ -44,7 +44,8 @@ object ScalaReflection {
     case t if t <:< typeOf[Product] =>
       val params = t.member("<init>": TermName).asMethod.paramss
       StructType(
-        params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+        params.head.map(p =>
+          StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => BinaryType
     case t if t <:< typeOf[Array[_]] =>
@@ -58,6 +59,17 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => StringType
     case t if t <:< typeOf[Timestamp] => TimestampType
     case t if t <:< typeOf[BigDecimal] => DecimalType
+    case t if t <:< typeOf[Option[_]] =>
+      val TypeRef(_, _, Seq(optType)) = t
+      schemaFor(optType)
+    case t if t <:< typeOf[java.lang.Integer] => IntegerType
+    case t if t <:< typeOf[java.lang.Long] => LongType
+    case t if t <:< typeOf[java.lang.Double] => DoubleType
+    case t if t <:< typeOf[java.lang.Float] => FloatType
+    case t if t <:< typeOf[java.lang.Short] => ShortType
+    case t if t <:< typeOf[java.lang.Byte] => ByteType
+    case t if t <:< typeOf[java.lang.Boolean] => BooleanType
+    // TODO: The following datatypes could be marked as non-nullable.
     case t if t <:< definitions.IntTpe => IntegerType
     case t if t <:< definitions.LongTpe => LongType
     case t if t <:< definitions.DoubleTpe => DoubleType
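Standalone, the Option handling above amounts to destructuring the type's single type argument with a TypeRef pattern and recursing on it. A minimal sketch (not Spark's code) that runs against scala-reflect; elementType is a hypothetical name:

  import scala.reflect.runtime.universe._

  // Unwrap Option[T] (nested Options included) down to T, the way schemaFor does.
  def elementType(t: Type): Type = t match {
    case t if t <:< typeOf[Option[_]] =>
      val TypeRef(_, _, Seq(optType)) = t
      elementType(optType)
    case other => other
  }

  elementType(typeOf[Option[Int]])  // Int
  elementType(typeOf[String])       // String, unchanged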
@@ -132,6 +132,14 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
         case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
         case c: Class[_] if c == java.lang.Float.TYPE => FloatType
         case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+        case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+        case c: Class[_] if c == classOf[java.lang.Long] => LongType
+        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+        case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+        case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
       }
       // TODO: Nullability could be stricter.
       AttributeReference(property.getName, dataType, nullable = true)()
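Both blocks of cases are needed because java.lang.Integer.TYPE (the Class token for the primitive int) and classOf[java.lang.Integer] (the boxed wrapper, which can hold null) are distinct Class objects. A small illustration, with describe as a made-up helper:

  def describe(c: Class[_]): String = c match {
    case c if c == java.lang.Integer.TYPE     => "primitive int (never null)"
    case c if c == classOf[java.lang.Integer] => "boxed Integer (nullable)"
    case _                                    => "other"
  }

  describe(java.lang.Integer.TYPE)      // "primitive int (never null)"
  describe(classOf[java.lang.Integer])  // "boxed Integer (nullable)"

A bean property declared as int reports Integer.TYPE, while one declared as java.lang.Integer reports the boxed class; before this commit only the primitive TYPE tokens were handled, so beans with boxed (nullable) fields could not be mapped to a schema at all.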
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
 /**
  * A result row from a SparkSQL query.
  */
-class Row(row: ScalaRow) extends Serializable {
+class Row(private[spark] val row: ScalaRow) extends Serializable {
 
   /** Returns the number of columns present in this Row. */
   def length: Int = row.length
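The change from a plain constructor parameter to private[spark] val promotes row to a field readable anywhere under the org.apache.spark package, which is what lets the Java API suite further down assert against the underlying Catalyst row via .collect.head.row. A toy illustration of the qualifier, with made-up names:

  package org.apache.spark
  // `inner` is not public API, but any code under org.apache.spark can read it.
  class Wrapper(private[spark] val inner: Seq[Any])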
@@ -164,6 +164,7 @@ case class Sort(
 @DeveloperApi
 object ExistingRdd {
   def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
     case s: Seq[Any] => s.map(convertToCatalyst)
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
     case other => other
@@ -180,7 +181,7 @@ object ExistingRdd {
       bufferedIterator.map { r =>
         var i = 0
         while (i < mutableRow.length) {
-          mutableRow(i) = r.productElement(i)
+          mutableRow(i) = convertToCatalyst(r.productElement(i))
           i += 1
         }
 
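In isolation, the new Option case means Some(x) unwraps to its contents and None becomes null before values are copied into a mutable row. A minimal sketch of the same match, renamed toCatalyst and with the Product case omitted for brevity:

  def toCatalyst(a: Any): Any = a match {
    case o: Option[_] => o.orNull           // Some(x) => x, None => null
    case s: Seq[Any]  => s.map(toCatalyst)
    case other        => other
  }

  toCatalyst(Some(1))               // 1
  toCatalyst(None)                  // null
  toCatalyst(Seq(Some("a"), None))  // Seq("a", null)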
@@ -36,6 +36,24 @@ case class ReflectData(
     timestampField: Timestamp,
     seqInt: Seq[Int])
 
+case class NullReflectData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    floatField: java.lang.Float,
+    doubleField: java.lang.Double,
+    shortField: java.lang.Short,
+    byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean)
+
+case class OptionalReflectData(
+    intField: Option[Int],
+    longField: Option[Long],
+    floatField: Option[Float],
+    doubleField: Option[Double],
+    shortField: Option[Short],
+    byteField: Option[Byte],
+    booleanField: Option[Boolean])
+
 case class ReflectBinary(data: Array[Byte])
 
 class ScalaReflectionRelationSuite extends FunSuite {
@@ -48,6 +66,22 @@ class ScalaReflectionRelationSuite extends FunSuite {
     assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
   }
 
+  test("query case class RDD with nulls") {
+    val data = NullReflectData(null, null, null, null, null, null, null)
+    val rdd = sparkContext.parallelize(data :: Nil)
+    rdd.registerAsTable("reflectNullData")
+
+    assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null))
+  }
+
+  test("query case class RDD with Nones") {
+    val data = OptionalReflectData(None, None, None, None, None, None, None)
+    val rdd = sparkContext.parallelize(data :: Nil)
+    rdd.registerAsTable("reflectOptionalData")
+
+    assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null))
+  }
+
   // Equality is broken for Arrays, so we test that separately.
   test("query binary data") {
     val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
@@ -35,6 +35,17 @@ class PersonBean extends Serializable {
   var age: Int = _
 }
 
+class AllTypesBean extends Serializable {
+  @BeanProperty var stringField: String = _
+  @BeanProperty var intField: java.lang.Integer = _
+  @BeanProperty var longField: java.lang.Long = _
+  @BeanProperty var floatField: java.lang.Float = _
+  @BeanProperty var doubleField: java.lang.Double = _
+  @BeanProperty var shortField: java.lang.Short = _
+  @BeanProperty var byteField: java.lang.Byte = _
+  @BeanProperty var booleanField: java.lang.Boolean = _
+}
+
 class JavaSQLSuite extends FunSuite {
   val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
   val javaSqlCtx = new JavaSQLContext(javaCtx)
@@ -50,4 +61,54 @@ class JavaSQLSuite extends FunSuite {
     schemaRDD.registerAsTable("people")
     javaSqlCtx.sql("SELECT * FROM people").collect()
   }
+
+  test("all types in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField("")
+    bean.setIntField(0)
+    bean.setLongField(0)
+    bean.setFloatField(0.0F)
+    bean.setDoubleField(0.0)
+    bean.setShortField(0.toShort)
+    bean.setByteField(0.toByte)
+    bean.setBooleanField(false)
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerAsTable("allTypes")
+
+    assert(
+      javaSqlCtx.sql(
+        """
+          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+          |       booleanField
+          |FROM allTypes
+        """.stripMargin).collect.head.row ===
+        Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+  }
+
+  test("all types null in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField(null)
+    bean.setIntField(null)
+    bean.setLongField(null)
+    bean.setFloatField(null)
+    bean.setDoubleField(null)
+    bean.setShortField(null)
+    bean.setByteField(null)
+    bean.setBooleanField(null)
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerAsTable("allTypes")
+
+    assert(
+      javaSqlCtx.sql(
+        """
+          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+          |       booleanField
+          |FROM allTypes
+        """.stripMargin).collect.head.row ===
+        Seq.fill(8)(null))
+  }
 }
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.Logging
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
 
-class ColumnTypeSuite extends FunSuite {
+class ColumnTypeSuite extends FunSuite with Logging {
   val DEFAULT_BUFFER_SIZE = 512
 
   test("defaultSize") {
@@ -163,7 +164,7 @@ class ColumnTypeSuite extends FunSuite {
 
     buffer.rewind()
     seq.foreach { expected =>
-      println("buffer = " + buffer + ", expected = " + expected)
+      logger.info("buffer = " + buffer + ", expected = " + expected)
       val extracted = columnType.extract(buffer)
       assert(
         expected === extracted,
