Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the self.model.predict crash #9

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions sources/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import keras.backend.tensorflow_backend as backend
from keras.optimizers import Adam
from keras.models import load_model, Model
from keras.backend import set_session
sys.stdin = stdin
sys.stderr = stderr

Expand All @@ -36,6 +37,8 @@ def __init__(self, model_path=False, id=None):
# Set to show an output from Conv2D layer
self.show_conv_cam = (id + 1) in settings.CONV_CAM_AGENTS

self.sess = tf.Session()

# Main model (agent does not use target model)
self.model = self.create_model(prediction=True)

Expand All @@ -48,6 +51,8 @@ def __init__(self, model_path=False, id=None):
# Load or create a new model (loading a model is being used only when playing or by trainer class that inherits from agent)
def create_model(self, prediction=False):

set_session(self.sess)

# If there is a patht to the model set, load model
if self.model_path:
model = load_model(self.model_path)
Expand Down
4 changes: 4 additions & 0 deletions sources/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
import keras.backend.tensorflow_backend as backend
from keras.backend import set_session
sys.stdin = stdin
sys.stderr = stderr

Expand All @@ -29,6 +30,8 @@
class ARTDQNTrainer(ARTDQNAgent):
def __init__(self, model_path):

self.sess = tf.Session()

# If model path is beiong passed in - use it instead of creating a new one
self.model_path = model_path
self.model = self.create_model()
Expand Down Expand Up @@ -107,6 +110,7 @@ def train(self):
current_states.append((np.array([[transition[0][1]] for transition in minibatch]) - 50) / 50)
# We need to use previously saved graph here as this is going to be called from separate thread
with self.graph.as_default():
set_session(self.sess)
current_qs_list = self.model.predict(current_states, settings.PREDICTION_BATCH_SIZE)

# Get future states from minibatch, then query NN model for Q values
Expand Down