added a simple MLP neural network for wet-dry classification #146
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
from numpy.lib.stride_tricks import sliding_window_view | ||
import tensorflow as tf | ||
import pkg_resources | ||
|
||
def get_model_file_path(): | ||
return pkg_resources.resource_filename( | ||
"pycomlink", "processing/wet_dry/mlp_model_files" | ||
) | ||
|
||
model = tf.keras.models.load_model(str(get_model_file_path() + "/model_mlp.keras")) | ||
|
||
def mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=None, # 0.5 is often good, or argmax | ||
): | ||
""" | ||
Wet dry classification using a simple neural network based on channel 1 and channel 2 of a CML | ||
|
||
Parameters | ||
---------- | ||
trsl_channel_1 : iterable of float | ||
Time series of transmitted minus received signal level (TRSL) of channel 1 | ||
trsl_channel_2 : iterable of float | ||
Time series of transmitted minus received signal level (TRSL) of channel 2 | ||
threshold : float, optional | ||
Threshold (0 - 1) for classifying a time step as wet or dry. | ||
If None (default), the raw model output is returned without thresholding. | ||
update: just saw that |
||
|
||
|
||
Returns | ||
------- | ||
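numpy.ndarray of float | ||
Time series of wet/dry probability with shape (n, 2) or, if threshold is provided, | ||
the 1-D thresholded wet/dry classification (0 or 1). Time steps without a valid | ||
input window are NaN. | ||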
|
||
References | ||
---------- | ||
|
||
|
||
""" | ||
# Normalization | ||
trsl_channel_1_norm = (trsl_channel_1 - np.nanmean(trsl_channel_1)) / np.nanstd(trsl_channel_1) | ||
trsl_channel_2_norm = (trsl_channel_2 - np.nanmean(trsl_channel_2)) / np.nanstd(trsl_channel_2) | ||
|
||
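# window length per channel; both channels are stacked, giving 2 * windowsize features | ||
windowsize = 40 | ||
# NaN rows padded at start and end so the output has the same length as the input | ||
x_start = np.full([windowsize // 2, windowsize * 2], np.nan) | ||
x_end = np.full([windowsize // 2 - 1, windowsize * 2], np.nan) | ||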
|
||
# sliding window | ||
sliding_window_ch1 = sliding_window_view( | ||
trsl_channel_1_norm, | ||
window_shape = windowsize | ||
) | ||
|
||
sliding_window_ch2 = sliding_window_view( | ||
trsl_channel_2_norm, | ||
window_shape = windowsize | ||
) | ||
|
||
x_fts = np.vstack( | ||
[x_start, np.hstack([sliding_window_ch1, sliding_window_ch2]), x_end] | ||
) | ||
|
||
mlp_pred = np.full([x_fts.shape[0], 2], np.nan) | ||
indices = np.argwhere(~np.isnan(x_fts).any(axis = 1)).ravel() | ||
|
||
if indices.size > 0: # everything is nan, mlp_pred is then all nan | ||
If I understand correctly, this if-statement is not true if the sample is all NaN, and thus we also do not do any prediction. I find the comment misleading since, if I understand correctly, it explains what happens when the if-statement is not true. Can you adjust it to make this clearer? |
||
mlp_pred_ = model.predict(x_fts[indices], verbose=0) | ||
mlp_pred[indices] = mlp_pred_ | ||
|
||
if threshold is None: | ||
return mlp_pred | ||
else: | ||
mlp_pred = mlp_pred[:, 1] | ||
mlp_pred[indices] = mlp_pred[indices] > threshold | ||
return mlp_pred |
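
The padding and sliding-window steps of `mlp_wet_dry` can be illustrated without loading the model. What follows is a minimal sketch, not the PR's code: `build_features` is a hypothetical helper, the normalization step is skipped, and the input mirrors the unit test (480 samples with one NaN at index 310). It shows that the feature matrix has one row per input time step and which rows contain NaN and would therefore be skipped before `model.predict` is called.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def build_features(ch1, ch2, windowsize=40):
    """Hypothetical helper mirroring the feature construction in the diff."""
    half = windowsize // 2
    # NaN padding: `half` rows at the start, `half - 1` rows at the end
    pad_start = np.full((half, 2 * windowsize), np.nan)
    pad_end = np.full((half - 1, 2 * windowsize), np.nan)
    # one row per window, channel 1 and channel 2 windows side by side
    windows = np.hstack(
        [
            sliding_window_view(ch1, window_shape=windowsize),
            sliding_window_view(ch2, window_shape=windowsize),
        ]
    )
    return np.vstack([pad_start, windows, pad_end])


n = 60 * 8
ch1 = np.arange(n, dtype=float)
ch2 = np.arange(n, dtype=float)
ch1[310] = np.nan  # single gap, as in the unit test

x_fts = build_features(ch1, ch2)
nan_rows = np.flatnonzero(np.isnan(x_fts).any(axis=1))
valid_rows = np.flatnonzero(~np.isnan(x_fts).any(axis=1))

print(x_fts.shape)      # one row per input sample, 2 * windowsize columns
print(nan_rows.size)    # rows that would receive no prediction
print(valid_rows.size)  # rows that would be passed to the model
```

Under these assumptions, rows 0 to 19 and 461 to 479 are padding, and rows 291 to 330 are the 40 windows that contain the NaN sample, so a single gap blanks out 40 consecutive predictions.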
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import unittest | ||
import numpy as np | ||
from pycomlink.processing.wet_dry.mlp import mlp_wet_dry | ||
|
||
class Testmlppred(unittest.TestCase): | ||
def test_mlppred(self): | ||
# generate a deterministic test signal (linear ramp) | ||
trsl_channel_1 = np.arange(0, 60 * 8).astype(float) | ||
trsl_channel_2 = np.arange(0, 60 * 8).astype(float) | ||
|
||
trsl_channel_1[310] = np.nan # shorter window than in cnn | ||
|
||
pred_raw = mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=None, | ||
)[:, 1] | ||
|
||
pred = mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=0.1, # low threshold for testing | ||
) | ||
|
||
# check if length of array is the same | ||
assert len(pred_raw) == 60 * 8 | ||
assert len(pred) == 60 * 8 | ||
|
||
# check if array is as expected | ||
truth_raw = np.array( | ||
[ | ||
0.08784304, | ||
0.08941595, | ||
0.09101421, | ||
0.09263814, | ||
0.09428804, | ||
0.09596423, | ||
0.09766698, | ||
0.09939668, | ||
0.10115347, | ||
0.10293788, | ||
0.10475004, | ||
np.nan, | ||
np.nan, | ||
] | ||
) | ||
truth = np.array( | ||
[ | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
1, | ||
1, | ||
1, | ||
np.nan, | ||
np.nan, | ||
] | ||
) | ||
|
||
np.testing.assert_almost_equal(pred[280:293], truth) | ||
np.testing.assert_almost_equal( | ||
np.round(pred_raw, decimals=7)[280:293], truth_raw | ||
|
||
) |
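
The relation between `truth_raw` and `truth` in the test above can be checked without the model. This is a minimal sketch under the assumption that thresholding is a plain `>` comparison on the second output column with NaN left untouched, as in the diff; `apply_threshold` is a hypothetical helper and not part of pycomlink.

```python
import numpy as np


def apply_threshold(prob, threshold):
    """Hypothetical helper: 1.0 where prob > threshold, 0.0 otherwise, NaN stays NaN."""
    prob = np.asarray(prob, dtype=float)
    out = np.full(prob.shape, np.nan)
    valid = ~np.isnan(prob)
    out[valid] = prob[valid] > threshold  # booleans are cast to 0.0 / 1.0
    return out


# expected raw values copied from the test (output indices 280 to 292)
truth_raw = np.array(
    [
        0.08784304, 0.08941595, 0.09101421, 0.09263814, 0.09428804,
        0.09596423, 0.09766698, 0.09939668, 0.10115347, 0.10293788,
        0.10475004, np.nan, np.nan,
    ]
)

flags = apply_threshold(truth_raw, threshold=0.1)
print(flags)
```

The first eight values lie below 0.1 and map to 0, the next three map to 1, and the two trailing NaN entries (output indices 291 and 292, the first windows that contain the NaN at input index 310) stay NaN, which matches `truth`.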
Can you state some more details here, or is there a document that you can reference?
E.g. what are the details of the network (MLP, but how many neurons and layers)? What is the sample length, i.e. what is the minimum length of the time series that has to be supplied? Explain if and how the model is applied in a sliding window. How are NaNs handled?
I know that the CNN wet-dry also has very little info in the doc string, but it has the paper with many details. (Not saying that we need a paper or something similar here...)
I will provide this somehow. Max gave me the idea of publishing it as a technical note somewhere.
It can just be 3-4 lines of text in the doc string. That will be sufficient. But right now the user has absolutely no idea what the function uses. Of course, feel free to write a "technical note" paper any time ;-)
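
On the minimum-length question raised in this thread: judging only from the diff, the series must contain at least `windowsize` (40) samples, because `sliding_window_view` raises a `ValueError` when the window is longer than the input. The sketch below illustrates this; `n_valid_predictions` is a hypothetical helper, and the count it returns assumes a gap-free series.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

WINDOWSIZE = 40  # value taken from the diff


def n_valid_predictions(n_samples, windowsize=WINDOWSIZE):
    """Hypothetical helper: number of NaN-free windows for a gap-free series."""
    return max(n_samples - windowsize + 1, 0)


# a series with exactly `windowsize` samples yields a single window
shortest_ok = sliding_window_view(np.zeros(WINDOWSIZE), window_shape=WINDOWSIZE)

# one sample fewer and the window no longer fits
try:
    sliding_window_view(np.zeros(WINDOWSIZE - 1), window_shape=WINDOWSIZE)
    raised = False
except ValueError:
    raised = True

print(shortest_ok.shape)
print(raised)
print(n_valid_predictions(60 * 8))
```

A docstring note along these lines (window length, padding at both ends, NaN propagation) would cover most of what the review asks for; the network architecture itself cannot be inferred from the diff.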