TypeError: tf__Dirichlet_SOS() missing 1 required positional argument: 't' #9

Open
koo-ec opened this issue Aug 17, 2021 · 2 comments

koo-ec commented Aug 17, 2021

First of all, thanks for your valuable contribution. The EDL concept is interesting.

I tried EDL on a simple classification task:

import evidential_deep_learning as edl
import tensorflow as tf
import sklearn
import sklearn.datasets
import sklearn.model_selection  # train_test_split is used below

iris = sklearn.datasets.load_iris()
train, test, labels_train, labels_test = sklearn.model_selection.train_test_split(iris.data, iris.target, train_size=0.80)

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(4, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        edl.layers.DenseDirichlet(3), # Evidential distribution!
    ]
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3), 
    loss=edl.losses.Dirichlet_SOS # Evidential loss!
)

history = model.fit(train, labels_train, batch_size=1024, epochs=32, verbose=0, validation_split=0.2)

However, I got the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-26b7527e8794> in <module>
     19 )
     20 
---> 21 history = model.fit(train, labels_train, batch_size=1024, epochs=32, verbose=0, validation_split=0.2)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    869       # This is the first call of __call__, so we have to initialize.
    870       initializers = []
--> 871       self._initialize(args, kwds, add_initializers_to=initializers)
    872     finally:
    873       # At this point we know that the initialization is complete (or less

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    724     self._concrete_stateful_fn = (
    725         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 726             *args, **kwds))
    727 
    728     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2967       args, kwargs = None, None
   2968     with self._lock:
-> 2969       graph_function, _ = self._maybe_define_function(args, kwargs)
   2970     return graph_function
   2971 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3359 
   3360           self._function_cache.missed.add(call_context_key)
-> 3361           graph_function = self._create_graph_function(args, kwargs)
   3362           self._function_cache.primary[cache_key] = graph_function
   3363 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3204             arg_names=arg_names,
   3205             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206             capture_by_value=self._capture_by_value),
   3207         self._function_attributes,
   3208         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    988         _, original_func = tf_decorator.unwrap(python_func)
    989 
--> 990       func_outputs = python_func(*func_args, **func_kwargs)
    991 
    992       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    632             xla_context.Exit()
    633         else:
--> 634           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    635         return out
    636 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

TypeError: in user code:

    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)

    TypeError: tf__Dirichlet_SOS() missing 1 required positional argument: 't'

I was wondering if you could kindly help me to fix this problem.

Ali-799 commented Aug 19, 2021

Go to the Dirichlet_SOS function and remove the parameter 't' from def Dirichlet_SOS(y, alpha, t). It isn't used anywhere in the loss function, and because Keras calls the loss passed to model.compile with only two arguments (y_true, y_pred), leaving 't' in the signature gives you this error.
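As an alternative to editing the installed package, the three-argument loss could be wrapped in a two-argument function. The sketch below is only an illustration under stated assumptions: `dirichlet_sos_stub` is a hypothetical stand-in that mimics the `(y, alpha, t)` signature from the traceback (it is not the library's loss), and `as_two_arg_loss` is a helper name invented here. It first reproduces the TypeError by calling the function with two arguments, then shows the wrapped call succeeding.

```python
def dirichlet_sos_stub(y, alpha, t):
    """Hypothetical stand-in with the same (y, alpha, t) signature as in the traceback."""
    # Toy computation only: mean squared difference between the labels and
    # the normalized alpha values. `t` is accepted but never used.
    total = sum(alpha)
    return sum((yi - ai / total) ** 2 for yi, ai in zip(y, alpha)) / len(y)


def as_two_arg_loss(loss_fn, t=None):
    """Adapt a (y, alpha, t) loss to the (y_true, y_pred) call that Keras makes."""
    def loss(y_true, y_pred):
        return loss_fn(y_true, y_pred, t)
    return loss


y_true = [0.0, 0.0, 1.0]
alpha = [1.0, 1.0, 2.0]

# Calling the three-argument function with two arguments reproduces the error.
try:
    dirichlet_sos_stub(y_true, alpha)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# The wrapper supplies the unused third argument itself.
wrapped = as_two_arg_loss(dirichlet_sos_stub)
value = wrapped(y_true, alpha)
```

With the real library, the corresponding change would presumably be passing `as_two_arg_loss(edl.losses.Dirichlet_SOS)` as the `loss` argument of `model.compile`; that has not been tested in this thread.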

9527-ly commented Nov 3, 2021

After removing the parameter t, the code runs, but the prediction always gives the third class the highest probability. Does anyone know how to solve this problem? @Asad-799 @koo-ec @aamini
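One possible cause, not confirmed anywhere in this thread: the snippet in the original post passes integer labels (iris.target holds 0, 1, 2), whereas a sum-of-squares evidential loss is usually written against one-hot label vectors. If Dirichlet_SOS follows that convention (an assumption), integer labels would broadcast against the per-class outputs and could bias training toward one class. A minimal sketch of the encoding, where `one_hot` is a hypothetical helper written for illustration:

```python
def one_hot(labels, num_classes):
    """Convert integer class labels to one-hot rows (plain lists of floats)."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} outside [0, {num_classes})")
        row = [0.0] * num_classes
        row[label] = 1.0  # mark the true class
        rows.append(row)
    return rows


# Three samples with classes 0, 2, and 1 out of 3 classes.
encoded = one_hot([0, 2, 1], num_classes=3)
```

In the original script, `tf.keras.utils.to_categorical(labels_train, num_classes=3)` performs the same conversion before the labels are passed to `model.fit`.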
