more gradient optimizer tests

fix evaluate

Not needed assert removed

(This commit is a squash of 3 commits: "more gradient optimizer tests", "fix evaluate", "Not needed assert removed".)
novikov-alexander committed Nov 13, 2023
1 parent 8cab730 commit 1f3f39e
Showing 2 changed files with 149 additions and 9 deletions.
113 changes: 113 additions & 0 deletions test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs
@@ -1,5 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
@@ -67,6 +68,51 @@ public void TestBasic()
TestBasic<double>();
}

private void TestMinimizeResourceVariable<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

// train.GradientDescentOptimizer is a V1-only API.
tf.Graph().as_default();
using (var sess = self.cached_session())
{
var var0 = tf.Variable(new[] { new[] { 1.0, 2.0 } }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0 }, dtype: dtype);
var x = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype);

var pred = math_ops.matmul(var0, x) + var1;
var loss = pred * pred;
// Learning rate 1.0, so the expected updates below are simply var - grad.
var sgd_op = tf.train.GradientDescentOptimizer(1.0f).minimize(loss);

var global_variables = tf.global_variables_initializer();
sess.run(global_variables);

sess.run(new[] { var0, var1 });
// Fetch params to validate initial values
self.assertAllCloseAccordingToType<T>(new[,] { { 1.0, 2.0 } }, self.evaluate<T[,]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
sgd_op.run();
// Validate updated params
var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0;
var np_grad = 2 * np_pred;
self.assertAllCloseAccordingToType(
new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } },
self.evaluate<T[,]>(var0));
self.assertAllCloseAccordingToType(
new[] { 3.0 - np_grad },
self.evaluate<T[]>(var1));
}
}
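For reference, the expected values asserted above come from a single SGD step on the squared prediction; a worked derivation, assuming the 1.0 learning rate used in this test:

\begin{aligned}
\text{pred} &= \text{var0}\,x + \text{var1} = 1\cdot 4 + 2\cdot 5 + 3 = 17,\\
\text{loss} &= \text{pred}^2,\qquad \frac{\partial\,\text{loss}}{\partial\,\text{pred}} = 2\,\text{pred} = 34,\\
\nabla_{\text{var0}}\,\text{loss} &= 2\,\text{pred}\,x^{\top} = [136,\ 170],\qquad \nabla_{\text{var1}}\,\text{loss} = 2\,\text{pred} = 34,\\
\text{var0} &\leftarrow [1,\ 2] - [136,\ 170] = [-135,\ -168],\qquad \text{var1} \leftarrow 3 - 34 = -31.
\end{aligned}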

[TestMethod]
public void TestMinimizeResourceVariable()
{
//TODO: add np.half
TestMinimizeResourceVariable<float>();
TestMinimizeResourceVariable<double>();
}

private void TestTensorLearningRate<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();
@@ -115,5 +161,72 @@ public void TestTensorLearningRate()
TestTensorLearningRate<float>();
TestTensorLearningRate<double>();
}

public void TestGradWrtRef<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

var graph = tf.Graph().as_default();
using (var sess = self.cached_session())
{
var opt = tf.train.GradientDescentOptimizer(3.0f);
var values = new[] { 1.0, 3.0 };
var vars_ = values.Select(
v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1
).ToList();
var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_);
sess.run(tf.global_variables_initializer());
foreach (var (grad, _) in grads_and_vars)
self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate<T[]>(grad));

}
}
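The asserted gradient of 1.0 per variable follows directly from differentiating the sum (the 3.0 learning rate plays no role until the gradients are applied):

\[ \frac{\partial\,(v_0 + v_1)}{\partial v_0} = \frac{\partial\,(v_0 + v_1)}{\partial v_1} = 1, \]

so compute_gradients yields [1.0] for each of the two single-element variables.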

[TestMethod]
public void TestGradWrtRef()
{
TestGradWrtRef<float>();
TestGradWrtRef<double>();
}

public void TestWithGlobalStep<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

tf.Graph().as_default();
using (var sess = self.cached_session())
{
var global_step = tf.Variable(0, trainable: false);
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
var grads_and_vars = new[] {
Tuple.Create(grads0, var0 as IVariableV1),
Tuple.Create(grads1, var1 as IVariableV1)
};
var sgd_op = tf.train.GradientDescentOptimizer(3.0f)
.apply_gradients(grads_and_vars, global_step: global_step);

sess.run(tf.global_variables_initializer());
// Fetch params to validate initial values
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
sgd_op.run();
// Validate updated params and global_step
self.assertAllCloseAccordingToType(new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate<T[]>(var1));
Assert.AreEqual(1, self.evaluate<int>(global_step));
}

}
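The expected values here are one plain SGD step, var ← var − lr·grad, with the 3.0 learning rate, plus the global_step increment that apply_gradients performs:

\begin{aligned}
\text{var0} &\leftarrow [1.0,\ 2.0] - 3.0\,[0.1,\ 0.1] = [0.7,\ 1.7],\\
\text{var1} &\leftarrow [3.0,\ 4.0] - 3.0\,[0.01,\ 0.01] = [2.97,\ 3.97],\\
\text{global\_step} &\leftarrow 0 + 1 = 1.
\end{aligned}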

[TestMethod]
public void TestWithGlobalStep()
{
TestWithGlobalStep<float>();
TestWithGlobalStep<double>();
}
}
}
45 changes: 36 additions & 9 deletions test/Tensorflow.UnitTest/PythonTest.cs
@@ -175,8 +175,8 @@ public int Compare(object? x, object? y)
return 1;
}

- var a = (double)x;
- var b = (double)y;
+ var a = Convert.ToDouble(x);
+ var b = Convert.ToDouble(y);

double delta = Math.Abs(a - b);
if (delta < _epsilon)
@@ -187,6 +187,19 @@ public int Compare(object? x, object? y)
}
}
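The switch from a direct cast to Convert.ToDouble above matters because C# unboxing must target the exact boxed type: when Compare receives a boxed float or int, (double)x throws an InvalidCastException, whereas Convert.ToDouble goes through IConvertible and accepts any boxed numeric. A minimal standalone sketch of the difference (not part of the test suite):

using System;

class UnboxingSketch
{
    static void Main()
    {
        object boxed = 1.5f;                          // a float boxed as object, as Compare receives it

        // double direct = (double)boxed;             // would throw InvalidCastException: unboxing must match float exactly
        double viaConvert = Convert.ToDouble(boxed);  // works for any boxed numeric type via IConvertible

        Console.WriteLine(viaConvert);                // 1.5
    }
}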

public void assertAllCloseAccordingToType<T>(
double[,] expected,
T[,] given,
double eps = 1e-6,
float float_eps = 1e-6f)
{
Assert.AreEqual(expected.GetLength(0), given.GetLength(0));
Assert.AreEqual(expected.GetLength(1), given.GetLength(1));

var flattenGiven = given.Cast<T>().ToArray();
assertAllCloseAccordingToType(expected, flattenGiven, eps, float_eps);
}
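The new 2-D overload verifies both dimensions and then defers to the existing 1-D overload by flattening given in row-major order; the non-generic ICollection view of the expected double[,] enumerates in the same order, so the element-wise comparison lines up. A hypothetical call site (variable names are illustrative only):

// e.g. checking a 1x2 weight matrix after an update
var expected = new[,] { { -135.0, -168.0 } };
var actual = self.evaluate<double[,]>(var0);
self.assertAllCloseAccordingToType(expected, actual);   // shape check, then per-element tolerance check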

public void assertAllCloseAccordingToType<T>(
ICollection expected,
ICollection<T> given,
@@ -267,21 +280,35 @@ public T evaluate<T>(Tensor tensor)
{
var sess = tf.get_default_session();
var ndarray = tensor.eval(sess);
- if (typeof(T) == typeof(double)
-     || typeof(T) == typeof(float)
-     || typeof(T) == typeof(int))
+
+ if (typeof(T) == typeof(int))
+ {
+ int i = ndarray;
+ result = i;
+ }
+ else if (typeof(T) == typeof(float))
+ {
+ float f = ndarray;
+ result = f;
+ }
+ else if (typeof(T) == typeof(double))
{
- result = Convert.ChangeType(ndarray, typeof(T));
+ double d = ndarray;
+ result = d;
}
- else if (typeof(T) == typeof(double[]))
+ else if (
+ typeof(T) == typeof(double[])
+ || typeof(T) == typeof(double[,]))
{
result = ndarray.ToMultiDimArray<double>();
}
- else if (typeof(T) == typeof(float[]))
+ else if (typeof(T) == typeof(float[])
+ || typeof(T) == typeof(float[,]))
{
result = ndarray.ToMultiDimArray<float>();
}
- else if (typeof(T) == typeof(int[]))
+ else if (typeof(T) == typeof(int[])
+ || typeof(T) == typeof(int[,]))
{
result = ndarray.ToMultiDimArray<int>();
}
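With this change, evaluate<T> converts the fetched NDArray according to the requested type: the scalar branches now use NDArray's implicit conversions instead of Convert.ChangeType (presumably the "fix evaluate" part of this commit), and the array branches additionally accept the 2-D element types the new optimizer tests fetch. Hypothetical calls mirroring those tests:

// scalar, 1-D and 2-D fetches handled by the branches above
var step   = self.evaluate<int>(global_step);     // implicit NDArray -> int
var bias   = self.evaluate<double[]>(var1);       // ToMultiDimArray<double>() as double[]
var weight = self.evaluate<double[,]>(var0);      // double[,], newly supported by this commit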
