Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added a simple neural network trained at detecting rainfall with 10mi…
…nute resolution
- Loading branch information
Showing
3 changed files
with
146 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
from numpy.lib.stride_tricks import sliding_window_view | ||
import tensorflow as tf | ||
import pkg_resources | ||
|
||
def get_model_file_path(): | ||
return pkg_resources.resource_filename( | ||
"pycomlink", "/processing/wet_dry/mlp_model_files" | ||
) | ||
|
||
model = tf.keras.models.load_model(str(get_model_file_path() + "/model_mlp.keras")) | ||
|
||
def mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=None, # 0.5 is often good, or argmax | ||
): | ||
""" | ||
Wet dry classification using a simple neural network based on channel 1 and channel 2 of a CML | ||
Parameters | ||
---------- | ||
trsl_channel_1 : iterable of float | ||
Time series of received signal level of channel 1 | ||
trsl_channel_2 : iterable of float | ||
Time series of received signal level of channel 2 | ||
threshold : float | ||
Threshold (0 - 1) for setting event as wet or dry. | ||
Returns | ||
------- | ||
iterable of float | ||
Time series of wet/dry probability or (if threshold is provided) | ||
wet dry classification | ||
References | ||
---------- | ||
""" | ||
# Normalization | ||
trsl_channel_1_norm = (trsl_channel_1 - np.nanmean(trsl_channel_1)) / np.nanstd(trsl_channel_1) | ||
trsl_channel_2_norm = (trsl_channel_2 - np.nanmean(trsl_channel_2)) / np.nanstd(trsl_channel_2) | ||
|
||
# add nan to start and end | ||
windowsize = 40 # use two channels | ||
x_start = np.ones([int(windowsize/2), windowsize*2])*np.nan | ||
x_end = np.ones([int(windowsize/2)- 1, windowsize*2])*np.nan | ||
|
||
# sliding window | ||
sliding_window_ch1 = sliding_window_view( | ||
trsl_channel_1_norm, | ||
window_shape = windowsize | ||
) | ||
|
||
sliding_window_ch2 = sliding_window_view( | ||
trsl_channel_2_norm, | ||
window_shape = windowsize | ||
) | ||
|
||
x_fts = np.vstack( | ||
[x_start, np.hstack([sliding_window_ch1, sliding_window_ch2]), x_end] | ||
) | ||
|
||
mlp_pred = np.zeros([x_fts.shape[0], 2])*np.nan | ||
indices = np.argwhere(~np.isnan(x_fts).any(axis = 1)).ravel() | ||
|
||
if indices.size > 0: # everything is nan, mlp_pred is then all nan | ||
mlp_pred_ = model.predict(x_fts[indices], verbose=0) | ||
mlp_pred[indices] = mlp_pred_ | ||
|
||
if threshold == None: | ||
return mlp_pred # | ||
else: | ||
mlp_pred = mlp_pred[:, 1] | ||
mlp_pred[indices] = mlp_pred[indices] > threshold | ||
return mlp_pred |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import unittest | ||
import numpy as np | ||
from pycomlink.processing.wet_dry.mlp import mlp_wet_dry | ||
|
||
class Testmlppred(unittest.TestCase): | ||
def test_mlppred(self): | ||
# generate random array | ||
trsl_channel_1 = np.arange(0, 60 * 8).astype(float) | ||
trsl_channel_2 = np.arange(0, 60 * 8).astype(float) | ||
|
||
trsl_channel_1[310] = np.nan # shorter window than in cnn | ||
|
||
pred_raw = mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=None, | ||
)[:, 1] | ||
|
||
pred = mlp_wet_dry( | ||
trsl_channel_1, | ||
trsl_channel_2, | ||
threshold=0.1, # low threshold for testing | ||
) | ||
|
||
# check if length of array is the same | ||
assert len(pred_raw) == 60 * 8 | ||
assert len(pred) == 60 * 8 | ||
|
||
# check if array is as expected | ||
truth_raw = np.array( | ||
[ | ||
0.08784304, | ||
0.08941595, | ||
0.09101421, | ||
0.09263814, | ||
0.09428804, | ||
0.09596423, | ||
0.09766698, | ||
0.09939668, | ||
0.10115347, | ||
0.10293788, | ||
0.10475004, | ||
np.nan, | ||
np.nan, | ||
] | ||
) | ||
truth = np.array( | ||
[ | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
0, | ||
1, | ||
1, | ||
1, | ||
np.nan, | ||
np.nan, | ||
] | ||
) | ||
|
||
np.testing.assert_almost_equal(pred[280:293], truth) | ||
np.testing.assert_almost_equal( | ||
np.round(pred_raw, decimals=7)[280:293], truth_raw | ||
) |