
How to visualize attentions #675

Open
pratikchhapolika opened this issue Mar 23, 2022 · 0 comments
pratikchhapolika commented Mar 23, 2022

Here is the code.

import sys
from absl import app
from absl import flags
from absl import logging

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils

from transformers import BertTokenizer, BertForSequenceClassification
import pandas as pd
import torch
import transformers



df = pd.read_excel("data.xlsx",sheet_name='master_data')
print(df.shape)
df = df[df['train'] == 1]
df = df.head(100)
df = df[['UTTERANCE','label']]
df['label'] = df['label'].astype(int)
print(df.head(2))




def load_examples(frame):
    """Convert DataFrame rows to [sentence, label] pairs.

    The original demo materialized 'glue/sst2' from TFDS here, e.g.
    ds = tfds.load('glue/sst2', split='train', shuffle_files=True, download=True),
    then sorted by 'idx' to recover the original order. This version reads
    from the DataFrame loaded above instead.
    """
    ret = frame.values.tolist()
    print(ret)
    return ret



class SST2Data(lit_dataset.Dataset):
    """Stanford Sentiment Treebank, binary version (SST-2).
    See https://www.tensorflow.org/datasets/catalog/glue#gluesst2.
    """

    LABELS = ['0', '1']

    def __init__(self, data):
        self._examples = []
        for ex in load_examples(data):
            self._examples.append({
                'sentence': ex[0],
                'label': self.LABELS[ex[1]],
            })

        print(self._examples)

    def spec(self):
        return {
            'sentence': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=self.LABELS)
        }



FLAGS = flags.FLAGS

FLAGS.set_default("development_demo", True)

flags.DEFINE_string(
    "model_path",
    "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
    "Path to trained model, in standard transformers format, e.g. as "
    "saved by model.save_pretrained() and tokenizer.save_pretrained()")


def _from_pretrained(cls, *args, **kw):
    """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
    try:
        return cls.from_pretrained(*args, **kw)
    except OSError as e:
        logging.warning("Caught OSError loading model: %s", e)
        logging.warning(
            "Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
        return cls.from_pretrained(*args, from_tf=True, **kw)


class SimpleSentimentModel(lit_model.Model):
    """Simple sentiment analysis model."""

    LABELS = ["0", "1"]  # negative, positive

    def __init__(self, model_name_or_path):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # This is just a regular PyTorch model.
        self.model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=2,
            output_hidden_states=True,
            output_attentions=True)
        self.model.eval()
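        # NOTE: model_name_or_path and the _from_pretrained() helper above are
        # unused as written: both the tokenizer and the model hardcode
        # 'bert-base-uncased', so the --model_path flag has no effect here.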

    ##
    # LIT API implementation
    def max_minibatch_size(self):
        # This tells lit_model.Model.predict() how to batch inputs to
        # predict_minibatch().
        # Alternately, you can just override predict() and handle batching yourself.
        return 32

    def predict_minibatch(self, inputs):
        # Preprocess to ids and masks, and make the input batch.
        encoded_input = self.tokenizer.batch_encode_plus(
            [ex["sentence"] for ex in inputs],
            return_tensors="pt",
            add_special_tokens=True,
            max_length=256,
            padding="longest",
            truncation="longest_first")

        # Check and send to cuda (GPU) if available
        if torch.cuda.is_available():
            self.model.cuda()
            for tensor in encoded_input:
                encoded_input[tensor] = encoded_input[tensor].cuda()
        # Run a forward pass.
        with torch.no_grad():  # remove this if you need gradients.
            out: transformers.modeling_outputs.SequenceClassifierOutput = (
                self.model(**encoded_input))
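        # Because the model was loaded with output_attentions=True,
        # out.attentions is a tuple with one tensor of shape
        # (batch_size, num_heads, seq_len, seq_len) per layer. As written,
        # these are never copied into batched_outputs, so LIT gets no
        # attention to render; see the sketch below the listing.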

        # Post-process outputs.
        batched_outputs = {
            "probas": torch.nn.functional.softmax(out.logits, dim=-1),
            "input_ids": encoded_input["input_ids"],
            "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
            "cls_emb": out.hidden_states[-1][:, 0],  # last layer, first token
        }
        # Return as NumPy for further processing.
        detached_outputs = {k: v.cpu().numpy() for k, v in batched_outputs.items()}
        # Unbatch outputs so we get one record per input example.
        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(
                output.pop("input_ids")[1:ntok - 1])
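            # Note: [CLS]/[SEP] and padding are stripped here, so any
            # per-token output (e.g. an attention map) must be trimmed to the
            # same [1:ntok - 1] span to stay aligned with "tokens".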
            yield output

    def input_spec(self) -> lit_types.Spec:
        return {
            "sentence": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False)
        }

    def output_spec(self) -> lit_types.Spec:
        return {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                                null_idx=0),
            "cls_emb": lit_types.Embeddings()
        }


def get_wsgi_app():
    """Returns a LitApp instance for consumption by gunicorn."""
    FLAGS.set_default("server_type", "external")
    FLAGS.set_default("demo_mode", True)
    # Parse flags without calling app.run(main), to avoid conflict with
    # gunicorn command line flags.
    unused = flags.FLAGS(sys.argv, known_only=True)
    return main(unused)


def main(_):
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    model_path = FLAGS.model_path
    if model_path.endswith(".tar.gz"):
        model_path = transformers.file_utils.cached_path(
            model_path, extract_compressed_file=True)

    # Load the model we defined above.
    models = {"sst": SimpleSentimentModel(model_path)}
    # Load SST-2 validation set from TFDS.
    datasets = {"sst_dev": SST2Data(df)}

    # Start the LIT server. See server_flags.py for server options.
    lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
    return lit_demo.serve()


if __name__ == "__main__":
    app.run(main)



[Screenshot attached: LIT UI, 2022-03-23 10:53 PM]
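For attention to show up in LIT, the model has to both return the attention tensors and declare them in output_spec(). Below is a minimal sketch of the missing pieces, assuming a lit_nlp release where AttentionHeads takes align_in/align_out keyword arguments (older releases use align=("tokens", "tokens")); the "layer_{i}/attention" field names are illustrative, not anything LIT requires.

# 1) In predict_minibatch(), right after the forward pass (before the tensors
#    are detached to NumPy), keep the per-layer attention maps:
for i, layer_attention in enumerate(out.attentions):
    batched_outputs[f"layer_{i}/attention"] = layer_attention

# 2) In the per-example loop, trim each map to the same [1:ntok - 1] span as
#    output["tokens"], so rows and columns line up with the tokens:
for i in range(len(out.attentions)):
    key = f"layer_{i}/attention"
    output[key] = output[key][:, 1:ntok - 1, 1:ntok - 1].copy()

# 3) Declare one AttentionHeads field per layer, aligned to "tokens":
def output_spec(self) -> lit_types.Spec:
    spec = {
        "tokens": lit_types.Tokens(),
        "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                            null_idx=0),
        "cls_emb": lit_types.Embeddings(),
    }
    for i in range(self.model.config.num_hidden_layers):
        spec[f"layer_{i}/attention"] = lit_types.AttentionHeads(
            align_in="tokens", align_out="tokens")
    return spec

This mirrors the approach in LIT's bundled glue_models example: once the spec fields and the returned arrays match, the Attention module in the UI should render a head-by-head view.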
