-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
CSVMixedDataTypesLocal.java
197 lines (161 loc) · 9.74 KB
/
CSVMixedDataTypesLocal.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
/*******************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datapipelineexamples.transform.basic;
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.writer.RecordWriter;
import org.datavec.api.records.writer.impl.csv.CSVRecordWriter;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.Partitioner;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.ConditionOp;
import org.datavec.api.transform.condition.column.CategoricalColumnCondition;
import org.datavec.api.transform.condition.column.DoubleColumnCondition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.time.DeriveColumnsFromTimeTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.datapipelineexamples.utils.DownloaderUtility;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import java.io.File;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
/**
* Basic DataVec example for preprocessing operations on some simple CSV data. If you just want to load CSV data
 * and pass it on for learning take a look at {@link org.deeplearning4j.examples.dataExample.CSVExample}.
*
* The premise here is that some data regarding transactions is available in CSV format, and we want to do some
* operations on this data, including:
* 1. Removing some unnecessary columns
* 2. Filtering examples to keep only examples with values "USA" or "CAN" for the "MerchantCountryCode" column
* 3. Replacing some invalid values in the "TransactionAmountUSD" column
* 4. Parsing the date string, and extracting the hour of day from it to create a new "HourOfDay" column
*
* @author Alex Black
*/
public class CSVMixedDataTypesLocal {

    //Directory containing the downloaded example data; assigned in main() before use
    public static String dataLocalPath;

    /**
     * Runs the example end-to-end: defines the input schema, builds the transform
     * process, executes it locally on the example CSV, and prints the before/after
     * file contents.
     *
     * @param args unused
     * @throws Exception if the data download, file I/O, or transform execution fails
     */
    public static void main(String[] args) throws Exception {
        dataLocalPath = DownloaderUtility.BASICDATAVECEXAMPLE.Download();

        //=====================================================================
        // Step 1: Define the input data schema
        //=====================================================================
        Schema inputDataSchema = buildInputSchema();

        //Print out the schema:
        System.out.println("Input data schema details:");
        System.out.println(inputDataSchema);

        System.out.println("\n\nOther information obtainable from schema:");
        System.out.println("Number of columns: " + inputDataSchema.numColumns());
        System.out.println("Column names: " + inputDataSchema.getColumnNames());
        System.out.println("Column types: " + inputDataSchema.getColumnTypes());

        //=====================================================================
        // Step 2: Define the operations we want to do
        //=====================================================================
        TransformProcess tp = buildTransformProcess(inputDataSchema);

        //After executing all of these operations, we have a new and different schema:
        Schema outputSchema = tp.getFinalSchema();
        System.out.println("\n\n\nSchema after transforming data:");
        System.out.println(outputSchema);

        //=====================================================================
        // Step 3: Load our data and execute the operations locally
        //=====================================================================
        //Define input and output paths. Fail loudly if the output file cannot be
        //(re)created, rather than silently ignoring the boolean results of
        //delete()/createNewFile() and failing later with a confusing writer error.
        File inputFile = new File(dataLocalPath, "exampledata.csv");
        File outputFile = new File("BasicDataVecExampleLocalOut.csv");
        if (outputFile.exists() && !outputFile.delete()) {
            throw new IllegalStateException("Could not delete existing output file: " + outputFile.getAbsolutePath());
        }
        if (!outputFile.createNewFile()) {
            throw new IllegalStateException("Could not create output file: " + outputFile.getAbsolutePath());
        }

        //Read all input records into memory (the example dataset is tiny).
        //Close the reader once we're done so the underlying stream is released.
        RecordReader rr = new CSVRecordReader(0, ',');
        rr.initialize(new FileSplit(inputFile));
        List<List<Writable>> originalData = new ArrayList<>();
        while (rr.hasNext()) {
            originalData.add(rr.next());
        }
        rr.close();

        //Execute the transform process locally (single machine, no Spark required)
        List<List<Writable>> processedData = LocalTransformExecutor.execute(originalData, tp);

        //Write the transformed records back out as CSV; close the writer even if
        //writeBatch throws, so the output file handle is not leaked
        RecordWriter rw = new CSVRecordWriter();
        Partitioner p = new NumberOfRecordsPartitioner();
        rw.initialize(new FileSplit(outputFile), p);
        try {
            rw.writeBatch(processedData);
        } finally {
            rw.close();
        }

        //Print before + after:
        System.out.println("\n\n---- Original Data File ----");
        System.out.println(FileUtils.readFileToString(inputFile, Charset.defaultCharset()));

        System.out.println("\n\n---- Processed Data File ----");
        System.out.println(FileUtils.readFileToString(outputFile, Charset.defaultCharset()));

        System.out.println("\n\nDONE");
    }

    /**
     * Defines the schema of the raw CSV input.
     * The order in which columns are defined here must match the order in which
     * they appear in the input data.
     *
     * @return the input {@link Schema}
     */
    private static Schema buildInputSchema() {
        return new Schema.Builder()
            //We can define a single column
            .addColumnString("DateTimeString")
            //Or for convenience define multiple columns of the same type
            .addColumnsString("CustomerID", "MerchantID")
            //We can define different column types for different types of data:
            .addColumnInteger("NumItemsInTransaction")
            .addColumnCategorical("MerchantCountryCode", Arrays.asList("USA", "CAN", "FR", "MX"))
            //Some columns have restrictions on the allowable values, that we consider valid:
            //$0.0 or more, no maximum limit, no NaN and no Infinite values
            .addColumnDouble("TransactionAmountUSD", 0.0, null, false, false)
            .addColumnCategorical("FraudLabel", Arrays.asList("Fraud", "Legit"))
            .build();
    }

    /**
     * Builds the sequence of preprocessing operations to apply to the input data:
     * drop unneeded columns, keep only USA/CAN transactions, clamp negative dollar
     * amounts to 0.0, parse the timestamp, and derive an "HourOfDay" column.
     *
     * @param inputDataSchema schema of the raw input data
     * @return the configured {@link TransformProcess}
     */
    private static TransformProcess buildTransformProcess(Schema inputDataSchema) {
        //At each step, we identify columns by the name we gave them in the input data schema
        return new TransformProcess.Builder(inputDataSchema)
            //Remove some columns we don't need
            .removeColumns("CustomerID", "MerchantID")
            //Keep only transactions involving merchants in USA or Canada:
            //a ConditionFilter REMOVES every example matching its condition, so the
            //condition is "MerchantCountryCode is NOT one of {USA, CAN}"
            .filter(new ConditionFilter(
                new CategoricalColumnCondition("MerchantCountryCode", ConditionOp.NotInSet,
                    new HashSet<>(Arrays.asList("USA", "CAN")))))
            //The data source isn't perfect: replace invalid negative dollar amounts
            //with 0.0, leaving positive amounts untouched
            .conditionalReplaceValueTransform(
                "TransactionAmountUSD",                                                 //Column to operate on
                new DoubleWritable(0.0),                                                //New value when the condition is satisfied
                new DoubleColumnCondition("TransactionAmountUSD", ConditionOp.LessThan, 0.0)) //Condition: amount < 0.0
            //Parse the date/time column, which is in a format like "2016-01-01 17:00:00.000".
            //Joda-Time pattern letters are case-sensitive: "yyyy" = year, "dd" = day-of-month.
            //(The previous pattern "YYYY-MM-DD" used Joda's year-of-era and DAY-OF-YEAR
            //fields, which only happened to parse correctly for January dates.)
            //Pattern reference: https://www.joda.org/joda-time/apidocs/org/joda/time/format/DateTimeFormat.html
            .stringToTimeTransform("DateTimeString", "yyyy-MM-dd HH:mm:ss.SSS", DateTimeZone.UTC)
            //The column is no longer a String; rename it to something better:
            .renameColumn("DateTimeString", "DateTime")
            //The date/time is now stored as a long (Unix epoch milliseconds).
            //We only care about the hour of day, so derive a new column for it:
            .transform(new DeriveColumnsFromTimeTransform.Builder("DateTime")
                .addIntegerDerivedColumn("HourOfDay", DateTimeFieldType.hourOfDay())
                .build())
            //We've extracted what we need from "DateTime", so remove it
            .removeColumns("DateTime")
            .build();
    }
}