Skip to content

Commit

Permalink
Merge pull request #1182 from lingbai-kong/imdbfix
Browse files Browse the repository at this point in the history
fix: adjust imdb dataset loader for faster loading speed
  • Loading branch information
Oceania2018 committed Sep 24, 2023
2 parents 3811e4e + 9fb8479 commit 8e02682
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
29 changes: 17 additions & 12 deletions src/TensorFlowNET.Keras/Datasets/Imdb.cs
Expand Up @@ -112,35 +112,39 @@ public class Imdb

if (start_char != null)
{
int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
int[,] new_x_train_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_train_array[i, 0] = (int)start_char;
Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2);
}
int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
int[,] new_x_test_array = new int[d1, d2 + 1];
for (var i = 0; i < d1; i++)
{
new_x_test_array[i, 0] = (int)start_char;
Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2);
}
x_train_array = new_x_train_array;
x_test_array = new_x_test_array;
}
else if (index_from != 0)
{
for (var i = 0; i < x_train_array.GetLength(0); i++)
var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_train_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_train_array[i, j] == 0)
break;
x_train_array[i, j] += index_from;
}
}
for (var i = 0; i < x_test_array.GetLength(0); i++)
(d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
for (var i = 0; i < d1; i++)
{
for (var j = 0; j < x_test_array.GetLength(1); j++)
for (var j = 0; j < d2; j++)
{
if (x_test_array[i, j] == 0)
break;
Expand Down Expand Up @@ -169,9 +173,10 @@ public class Imdb

if (num_words == null)
{
var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1));
num_words = 0;
for (var i = 0; i < xs_array.GetLength(0); i++)
for (var j = 0; j < xs_array.GetLength(1); j++)
for (var i = 0; i < d1; i++)
for (var j = 0; j < d2; j++)
num_words = max((int)num_words, (int)xs_array[i, j]);
}

Expand Down
8 changes: 5 additions & 3 deletions src/TensorFlowNET.Keras/Utils/data_utils.cs
Expand Up @@ -53,15 +53,17 @@ public static (int[,], long[]) _remove_long_seq(int maxlen, int[,] seq, long[] l
new_seq, new_label: shortened lists for `seq` and `label`.
*/
var nRow = seq.GetLength(0);
var nCol = seq.GetLength(1);
List<int[]> new_seq = new List<int[]>();
List<long> new_label = new List<long>();

for (var i = 0; i < seq.GetLength(0); i++)
for (var i = 0; i < nRow; i++)
{
if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
if (maxlen < nCol && seq[i, maxlen] != 0)
continue;
int[] sentence = new int[maxlen];
for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
for (var j = 0; j < maxlen && j < nCol; j++)
{
sentence[j] = seq[i, j];
}
Expand Down

0 comments on commit 8e02682

Please sign in to comment.