This repo contains a library for loading and storing TensorFlow records with Apache Spark. The library implements data import from the standard TensorFlow record format ([TFRecords](https://www.tensorflow.org/how_tos/reading_data/)) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.
This is the initial release of the spark-tensorflow-connector repo. There are no known issues.
Build the library using Maven as shown below. The build produces `target/spark-tensorflow-connector-1.0-SNAPSHOT.jar`, which the examples that follow reference.

```sh
mvn clean install
```
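If you only need the jar, the test suite can be skipped with Maven's standard `-DskipTests` flag (a general Maven option, not specific to this project):

```sh
# Skip tests for a faster build (standard Maven flag).
mvn clean install -DskipTests
```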
Run this library in Spark using the `--jars` command-line option in `spark-shell` or `spark-submit`. For example:

```sh
$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar
```
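The same `--jars` list applies to `spark-submit`. Here is a minimal sketch; the application jar and main class are hypothetical placeholders, not part of this repo:

```sh
# Hypothetical application jar and main class; substitute your own.
$SPARK_HOME/bin/spark-submit \
  --class com.example.TFRecordsExample \
  --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar \
  tfrecords-example.jar
```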
The following Scala snippet demonstrates usage: it writes a DataFrame as TFRecords, then reads the records back both with an inferred schema and with a custom schema.
```scala
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = s"$TF_SANDBOX_DIR/test-output.tfr"
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(
  StructField("id", IntegerType),
  StructField("IntegerTypelabel", IntegerType),
  StructField("LongTypelabel", LongType),
  StructField("FloatTypelabel", FloatType),
  StructField("DoubleTypelabel", DoubleType),
  StructField("vectorlabel", ArrayType(DoubleType, true)),
  StructField("name", StringType)))
val rdd = spark.sparkContext.parallelize(testRows)

// Save DataFrame as TFRecords.
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tensorflow").save(path)

// Read TFRecords into a DataFrame.
// The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tensorflow").load(path)
importedDf1.show()

// Read TFRecords into a DataFrame using a custom schema.
val importedDf2: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
importedDf2.show()
```
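As a quick sanity check of the round trip, a minimal sketch using only standard DataFrame operations (nothing connector-specific beyond the reads above):

```scala
// Row counts should survive the write/read cycle.
assert(importedDf2.count() == df.count())

// Ordering is not guaranteed across the round trip, so sort before inspecting.
importedDf2.select("id", "name").orderBy("id").show()
```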