
joyeshmishra/spark-tensorflow-connector

 
 


spark-tensorflow-connector

This repo contains a library for loading and storing TensorFlow records with Apache Spark. The library imports data in the standard TensorFlow record format, TFRecords (https://www.tensorflow.org/how_tos/reading_data/), into Spark SQL DataFrames, and exports DataFrames back to TFRecords.
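For orientation, a TFRecord file is a sequence of length-prefixed, checksummed records. Each record is laid out as a little-endian uint64 length, a uint32 masked CRC-32C of that length, the payload bytes (typically a serialized tf.train.Example), and a uint32 masked CRC-32C of the payload. The connector handles this framing for you, presumably through the tensorflow-hadoop JAR it ships with. The stand-alone sketch below only illustrates the on-disk layout and is not the connector's code; `TFRecordFraming`, `frame` and `unframe` are hypothetical names.

```scala
import java.nio.{ ByteBuffer, ByteOrder }

// Hypothetical sketch of TFRecord framing (not the connector's implementation):
//   uint64 length | uint32 masked_crc32c(length) | data | uint32 masked_crc32c(data)
// All integers are little-endian.
object TFRecordFraming {
  // Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78.
  def crc32c(bytes: Array[Byte]): Int = {
    var crc = 0xFFFFFFFF
    for (b <- bytes) {
      crc ^= (b & 0xFF)
      for (_ <- 0 until 8) {
        crc = if ((crc & 1) != 0) (crc >>> 1) ^ 0x82F63B78 else crc >>> 1
      }
    }
    ~crc
  }

  // TensorFlow stores a "masked" CRC: rotate right by 15 bits, then add a constant.
  def maskedCrc(bytes: Array[Byte]): Int = {
    val crc = crc32c(bytes)
    ((crc >>> 15) | (crc << 17)) + 0xa282ead8
  }

  // Wrap one serialized record in TFRecord framing.
  def frame(data: Array[Byte]): Array[Byte] = {
    val lengthBytes = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
      .putLong(data.length.toLong).array()
    ByteBuffer.allocate(8 + 4 + data.length + 4).order(ByteOrder.LITTLE_ENDIAN)
      .put(lengthBytes)
      .putInt(maskedCrc(lengthBytes))
      .put(data)
      .putInt(maskedCrc(data))
      .array()
  }

  // Read one framed record back, verifying both checksums.
  def unframe(record: Array[Byte]): Array[Byte] = {
    val buf = ByteBuffer.wrap(record).order(ByteOrder.LITTLE_ENDIAN)
    val lengthBytes = new Array[Byte](8)
    buf.get(lengthBytes)
    require(buf.getInt() == maskedCrc(lengthBytes), "corrupt length checksum")
    val length = ByteBuffer.wrap(lengthBytes).order(ByteOrder.LITTLE_ENDIAN).getLong()
    val data = new Array[Byte](length.toInt)
    buf.get(data)
    require(buf.getInt() == maskedCrc(data), "corrupt data checksum")
    data
  }
}
```

The per-record checksums are what let a reader detect a truncated or corrupted file instead of silently decoding garbage.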

What's new

This is the initial release of the spark-tensorflow-connector repo.

Known issues

None.

Prerequisites

  1. Apache Spark 2.0 (or later)

  2. Apache Maven

Building the library

Build the library using Maven as shown below.

mvn clean install

Using Spark Shell

To use the library, pass the connector JAR and its tensorflow-hadoop dependency to spark-shell or spark-submit with the --jars command-line option. For example:

$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector-1.0-SNAPSHOT.jar,target/lib/tensorflow-hadoop-1.0-01232017-SNAPSHOT-shaded-protobuf.jar

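The following snippet, run in spark-shell (where `spark` is the predefined SparkSession), writes a DataFrame to TFRecords and reads it back, first with an inferred schema and then with an explicit one.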

import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.types._

// Output path. TF_SANDBOX_DIR is a placeholder: set it to any writable directory.
val TF_SANDBOX_DIR = "/tmp/tf-sandbox"
val path = s"$TF_SANDBOX_DIR/test-output.tfr"

// Two sample rows with int, long, float, double, array and string columns
val testRows: Array[Row] = Array(
  Row(11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1"),
  Row(21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2"))
val schema = StructType(List(StructField("id", IntegerType),
                             StructField("IntegerTypelabel", IntegerType),
                             StructField("LongTypelabel", LongType),
                             StructField("FloatTypelabel", FloatType),
                             StructField("DoubleTypelabel", DoubleType),
                             StructField("vectorlabel", ArrayType(DoubleType, true)),
                             StructField("name", StringType)))

val rdd = spark.sparkContext.parallelize(testRows)

// Save the DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tensorflow").save(path)

// Read the TFRecords back into a DataFrame.
// Without a custom schema, the DataFrame schema is inferred from the TFRecords.
val importedDf1: DataFrame = spark.read.format("tensorflow").load(path)
importedDf1.show()

// Read the TFRecords back using the original schema
val importedDf2: DataFrame = spark.read.format("tensorflow").schema(schema).load(path)
importedDf2.show()
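Why read with an explicit schema at all? A tf.train.Example stores every feature as one of three list types defined in TensorFlow's feature.proto: Int64List, FloatList (32-bit floats) or BytesList. An inferred schema therefore cannot tell that `id` was an IntegerType or that `DoubleTypelabel` was a DoubleType. The sketch below models this with plain type-name strings. The exact mapping, and the rule that a single value becomes a scalar while several become an array, are assumptions for illustration rather than the connector's documented behavior; `FeatureKindSketch`, `kindFor` and `inferredType` are hypothetical names.

```scala
// Hypothetical model: how Spark column types could map onto the three
// tf.train.Feature list kinds, and what a schema-less reader could infer back.
object FeatureKindSketch {
  sealed trait FeatureKind
  case object Int64List extends FeatureKind
  case object FloatList extends FeatureKind
  case object BytesList extends FeatureKind

  // Write side (assumed mapping): Spark type name -> feature list kind.
  def kindFor(sparkType: String): FeatureKind = sparkType match {
    case "IntegerType" | "LongType" => Int64List
    case "FloatType" | "DoubleType" => FloatList
    case "StringType"               => BytesList
    case t if t.startsWith("ArrayType(") && t.endsWith(")") =>
      // Arrays reuse the element's list kind, just with more than one value.
      kindFor(t.stripPrefix("ArrayType(").stripSuffix(")"))
    case other =>
      throw new IllegalArgumentException(s"unsupported type in this sketch: $other")
  }

  // Read side (assumed rule): only the list kind and value count are known.
  def inferredType(kind: FeatureKind, numValues: Int): String = {
    val scalar = kind match {
      case Int64List => "LongType"
      case FloatList => "FloatType"
      case BytesList => "StringType"
    }
    if (numValues == 1) scalar else s"ArrayType($scalar)"
  }
}
```

Under these assumptions an IntegerType column comes back as LongType and a DoubleType column as FloatType, so the original types are only restored when the schema is supplied, as with `importedDf2`.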
