Prediction problem #37

Open
Alek-dr opened this issue Jan 27, 2019 · 3 comments

Alek-dr commented Jan 27, 2019

Amazing work!
But there is one strange thing that I can't figure out. I'm training on my own dataset - almost MNIST, but extended to 21 symbols. The records are written exactly like MNIST, the loader too, and so on, but one seemingly random class is never predicted (for example, class 7). I have no idea why. The output contains 21 labels, but the 7th label always has very, very small values.
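
To make the symptom concrete: averaging the activations over a test batch, every class gets a plausible value except the 7th, which stays near zero. A minimal sketch (probs_batch and the .npy dump are hypothetical stand-ins for one batch of the model's [batch_size, 21] output):

import numpy as np

# hypothetical dump of one test batch's per-class activations, shape [batch, 21]
probs_batch = np.load('probs_batch.npy')
mean_probs = probs_batch.mean(axis=0)
for cls, p in enumerate(mean_probs, start=1):
    print('class %2d: mean activation %.6f' % (cls, p))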

naturomics (Owner) commented:

I didn't encounter such a problem. Could you tell me:

  1. Except for the data reading pipeline, did you modify the code, and if so, which part?
  2. Is the number of training samples for each class balanced? (See the sketch below for a quick check.)
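
If (2) isn't obvious, something like this untested sketch can count samples per class directly from the records (it assumes the label is stored as an int64 feature named 'label', as in your parse function; the training file path is a placeholder):

import collections
import tensorflow as tf

counts = collections.Counter()
# placeholder path; point it at your training records
for rec in tf.python_io.tf_record_iterator('data/symbols/train_symbols.tfrecord'):
    example = tf.train.Example.FromString(rec)
    counts[example.features.feature['label'].int64_list.value[0]] += 1
print(sorted(counts.items()))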


Alek-dr commented Jan 28, 2019

The dataset is balanced. Also, I wrote my own test code:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework.errors_impl import OutOfRangeError
from sklearn.metrics import confusion_matrix
from seaborn import heatmap
import capslayer as cl
import os

chars_map = {
    1: '1',
    2: '2',
    3: '3',
    4: '4',
    5: '5',
    6: '6',
    7: '7',
    8: '8',
    9: '9',
    10: '0',
    11: "A",
    12: "B",
    13: "C",
    14: "E",
    15: "H",
    16: "K",
    17: "M",
    18: "P",
    19: "T",
    20: "X",
    21: "Y"
}

WIDTH = 20
HEIGHT = 30

def elem_conv(elem):
    image = elem['images']
    image = np.reshape(image, newshape=(HEIGHT,WIDTH))
    label = elem['labels']
    return image, label

def parse_fn(serialized_example):
    features = tf.parse_single_example(serialized_example,
                                       features={'image': tf.FixedLenFeature([], tf.string),
                                                 'label': tf.FixedLenFeature([], tf.int64),
                                                 'height': tf.FixedLenFeature([], tf.int64),
                                                 'width': tf.FixedLenFeature([], tf.int64),
                                                 'depth': tf.FixedLenFeature([], tf.int64)})
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.decode_raw(features['image'], tf.float32)
    image = tf.reshape(image, shape=[height * width * depth])
    image.set_shape([HEIGHT * WIDTH * 1])
    image = tf.cast(image, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.int32)
    features = {'images': image, 'labels': label}
    return features

def get_model():
    # Vector CapsNet
    num_label = 21
    in_images = tf.placeholder(tf.float32, [None, HEIGHT*WIDTH])

    with tf.variable_scope('Conv1_layer'):
        # Conv1, returns shape [batch_size, 22, 12, 256] for 30x20 inputs
        inputs = tf.reshape(in_images, shape=[-1, HEIGHT, WIDTH, 1])
        conv1 = tf.layers.conv2d(inputs,
                                 filters=256,
                                 kernel_size=9,
                                 strides=1,
                                 padding='VALID',
                                 activation=tf.nn.relu)

    with tf.variable_scope('PrimaryCaps_layer'):
        primaryCaps, activation = cl.layers.primaryCaps(conv1,
                                                        filters=32,
                                                        kernel_size=9,
                                                        strides=2,
                                                        out_caps_dims=[8, 1],
                                                        method="norm")

    with tf.variable_scope('DigitCaps_layer'):
        routing_method = "EMRouting"
        num_inputs = np.prod(cl.shape(primaryCaps)[1:4])
        primaryCaps = tf.reshape(primaryCaps, shape=[-1, num_inputs, 8, 1])
        activation = tf.reshape(activation, shape=[-1, num_inputs])
        poses, probs = cl.layers.dense(primaryCaps,
                                       activation,
                                       num_outputs=num_label,
                                       out_caps_dims=[16, 1],
                                       routing_method=routing_method)

    # Decoder structure:
    # reconstruct the inputs with 3 FC layers
    with tf.variable_scope('Decoder'):
        logits_idx = tf.to_int32(tf.argmax(cl.softmax(probs, axis=1), axis=1))
        labels = tf.one_hot(logits_idx, depth=num_label, axis=-1, dtype=tf.float32)
        labels_one_hoted = tf.reshape(labels, (-1, num_label, 1, 1))
        masked_caps = tf.multiply(poses, labels_one_hoted)
        num_inputs = np.prod(masked_caps.get_shape().as_list()[1:])
        active_caps = tf.reshape(masked_caps, shape=(-1, num_inputs))
        fc1 = tf.layers.dense(active_caps, units=512, activation=tf.nn.relu)
        fc2 = tf.layers.dense(fc1, units=1024, activation=tf.nn.relu)
        num_outputs = HEIGHT * WIDTH * 1
        recon_imgs = tf.layers.dense(fc2,
                                     units=num_outputs,
                                     activation=tf.sigmoid)
        recon_imgs = tf.reshape(recon_imgs, shape=[-1, HEIGHT, WIDTH, 1])

    return in_images, recon_imgs, probs

def show_reconstruct(original, reconstruct, true_lbl, pred_lbl, lbl=None):
    if lbl is not None:
        ind = np.where(true_lbl==lbl)[0][0]
    else:
        ind = 0
    original_image = original[ind]
    reconstruct_image = reconstruct[ind]  # use the same index as the original image
    original_image = np.reshape(original_image, newshape=(HEIGHT,WIDTH))
    true_lbl = chars_map[true_lbl[ind]]
    pred_lbl = chars_map[pred_lbl[ind]]
    title = true_lbl + 20*' ' + pred_lbl
    res = np.hstack((original_image, reconstruct_image))
    plt.imshow(res, cmap='gray')
    plt.title(title)
    plt.show()

def test(records):
    """
    param records: list of .record files
    """
    batch_size = 128
    dataset = tf.data.TFRecordDataset(records)
    # note: shuffle() after batch() only shuffles whole batches; for a full-set
    # confusion matrix the shuffle and repeat(1) are optional
    dataset = dataset.map(parse_fn).batch(batch_size).repeat(1).shuffle(buffer_size=5000, seed=3)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    inputs, recon_imgs, probs = get_model()  # third return value is probs, not one-hot labels

    saver = tf.train.Saver()

    true_labels, predicted = [], []
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('../models/models/results/logdir/model.ckpt-6600'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        while True:
            try:
                elem = sess.run(next_element)

                raw_images = elem['images']
                true_lbls = elem['labels'] + 1  # stored labels are 0-based; chars_map is 1-based

                reconstructed, pred_lbls = sess.run([recon_imgs, probs], feed_dict={inputs: raw_images})

                reconstructed = np.squeeze(reconstructed)
                pred_lbls = np.argmax(pred_lbls, axis=1) + 1

                #show_reconstruct(raw_images, reconstructed, true_lbls, pred_lbls, lbl=9)

                predicted.extend(pred_lbls)
                true_labels.extend(true_lbls)

            except OutOfRangeError:
                break

    labels_id = np.arange(1, 22, 1).astype(np.int16)  # tick labels for all 21 classes
    labels = [chars_map[lbl] for lbl in labels_id]
    conf_matr = confusion_matrix(true_labels, predicted)
    ax = heatmap(conf_matr, annot=True, fmt='d')
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    plt.show()

if __name__ == '__main__':
    records = ['data/symbols/eval_symbols.tfrecord']
    test(records=records)

Checkpoint 6600 doesn't matter; I trained for 50000 steps and the problem was the same.

naturomics (Owner) commented:

The code looks fine. I'm not sure what the problem is. My guess is that CapsNet might have a bias problem, but before drawing that conclusion, I suggest:

  1. Visualize the input images and print their corresponding labels (for both the training and validation sets) to make sure the dataset is right (see the sketch after this list);
  2. then remove one or more classes from the dataset (not the 7th class), train the model from scratch, and test again, to see whether the 7th or any other class is never predicted.
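
For point 1, an untested sketch along these lines should do, reusing your parse_fn and eval records (the path and the number of samples shown are up to you):

import matplotlib.pyplot as plt
import tensorflow as tf

dataset = tf.data.TFRecordDataset(['data/symbols/eval_symbols.tfrecord']).map(parse_fn)
next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for _ in range(5):  # inspect a handful of samples
        elem = sess.run(next_element)
        plt.imshow(elem['images'].reshape(30, 20), cmap='gray')  # HEIGHT=30, WIDTH=20
        plt.title('label: %d' % elem['labels'])
        plt.show()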

Of course, it might be an implementation problem; I will check my code again.
