-
Notifications
You must be signed in to change notification settings - Fork 26
/
train
executable file
路100 lines (80 loc) 路 3.79 KB
/
train
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python
# A sample training component that trains a simple scikit-learn decision tree model.
# This implementation works in File mode and makes no assumptions about the input file names.
# Input is specified as CSV with a data point in each row and the labels in the first column.
from __future__ import print_function
import json
import os
import pickle
import sys
import traceback
import neptune
import neptune.integrations.sklearn as npt_utils
import pandas as pd
from sklearn import tree
# These are the paths to where SageMaker mounts interesting things in your container.
prefix = "/opt/ml/"
# Consistency fix: build every path with os.path.join (input_path previously
# used raw string concatenation); resulting values are unchanged.
input_path = os.path.join(prefix, "input/data")
output_path = os.path.join(prefix, "output")
model_path = os.path.join(prefix, "model")
param_path = os.path.join(prefix, "input/config/hyperparameters.json")

# This algorithm has a single channel of input data called 'training'. Since we run in
# File mode, the input files are copied to the directory specified here.
channel_name = "training"
training_path = os.path.join(input_path, channel_name)
# The function to execute the training.
def train():
    """Train a scikit-learn decision tree on the mounted CSV training data.

    Reads hyperparameters from ``param_path``, concatenates every ``.csv``
    file in the training channel (labels in the first column), fits a
    ``DecisionTreeClassifier``, logs a classifier summary to Neptune, and
    pickles the fitted model into ``model_path``.

    On any exception, writes a ``failure`` file under ``output_path`` (which
    SageMaker surfaces as the failureReason in DescribeTrainingJob), echoes
    the traceback to stderr, and exits with code 255 so the job is marked
    Failed.
    """
    run = neptune.init_run(
        tags=["sagemaker"],
        source_files=[sys.argv[0]],
    )
    print("Starting the training.")
    try:
        # Read in any hyperparameters that the user passed with the training job
        with open(param_path, "r") as tc:
            trainingParams = json.load(tc)

        # Take the set of files and read them all into a single pandas dataframe
        input_files = [os.path.join(training_path, file) for file in os.listdir(training_path)]
        if len(input_files) == 0:
            raise ValueError(
                (
                    "There are no files in {}.\n"
                    + "This usually indicates that the channel ({}) was incorrectly specified,\n"
                    + "the data specification in S3 was incorrectly specified or the role specified\n"
                    + "does not have permission to access the data."
                ).format(training_path, channel_name)
            )
        raw_data = [pd.read_csv(file, header=None) for file in input_files if file.endswith(".csv")]
        train_data = pd.concat(raw_data)

        # labels are in the first column
        train_y = train_data.iloc[:, 0]
        train_X = train_data.iloc[:, 1:]

        # Here we only support a single hyperparameter. Note that hyperparameters are
        # always passed in as strings, so we need to do any necessary conversions.
        max_leaf_nodes = trainingParams.get("max_leaf_nodes", None)
        if max_leaf_nodes is not None:
            max_leaf_nodes = int(max_leaf_nodes)

        # Now use scikit-learn's decision tree classifier to train the model.
        clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
        clf = clf.fit(train_X, train_y)

        # NOTE(review): the training set is passed as both the train and the
        # "test" split here, so the logged metrics are in-sample only.
        run["cls_summary"] = npt_utils.create_classifier_summary(
            clf, train_X, train_X, train_y, train_y
        )

        # save the model
        with open(os.path.join(model_path, "decision-tree-model.pkl"), "wb") as out:
            pickle.dump(clf, out)
        print("Training complete.")
    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(os.path.join(output_path, "failure"), "w") as s:
            s.write("Exception during training: " + str(e) + "\n" + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print("Exception during training: " + str(e) + "\n" + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)
    finally:
        # Fix: the Neptune run was never stopped, so queued metadata could be
        # lost, especially on the failure path (sys.exit raises SystemExit,
        # which skips the except clause above but still runs this finally).
        run.stop()
# Script entry point: run the training, then exit 0 so the job is marked Succeeded.
if __name__ == "__main__":
    train()
    sys.exit(0)