'''Explore representations

Commands to generate the results in the paper, feel free to adapt
:-). The tokenization here differs slightly vs. the tokenization in
the prediction experiments because we needed to maintain subtoken
identity for identifiability/self-attention distance.

BERT

for test in {0,1}; do CUDA_VISIBLE_DEVICES=0 python
explore_representations.py --model bert-base-uncased --lowercase 1
--test $test; done;

for test in {0,1}; do CUDA_VISIBLE_DEVICES=0 python
explore_representations.py --model bert-base-cased --lowercase 0
--test $test; done;

RoBERTa

for case in {0,1}; do for test in {0,1}; do CUDA_VISIBLE_DEVICES=0
python explore_representations.py --model roberta-base --lowercase
$case --test $test; done; done;

'''
import argparse
import torch
import numpy as np
import scipy.stats
import copy
import time
import json

from transformers import AutoTokenizer, AutoModel

import sklearn.metrics
import scipy.optimize
import itertools
import collections

from nltk.tokenize import TweetTokenizer
# Module-level tokenizer instance, created once at import time and used by
# tweet_tokenizer(). NOTE(review): a `global` statement at module scope has no
# effect (the name is already module-level); it is kept here unchanged.
global TOKENIZER
TOKENIZER = TweetTokenizer()

import matplotlib.pyplot as plt
import tqdm
import seaborn as sns

import scipy.special

# Switch on seaborn's default theme at import time so the matplotlib figures
# produced by this script pick up seaborn styling.
sns.set()


import sklearn.preprocessing

def tweet_tokenizer(cap_in, lower=False):
    """Split ``cap_in`` into tokens with the module-level ``TOKENIZER``.

    Args:
        cap_in: the text to tokenize.
        lower: when truthy, every returned token is lower-cased.

    Returns:
        The list of tokens produced by ``TOKENIZER.tokenize``.
    """
    tokens = TOKENIZER.tokenize(cap_in)
    if not lower:
        return tokens
    return [tok.lower() for tok in tokens]


def parse_args():
    """Parse and validate the command-line options of this script.

    Returns:
        An ``argparse.Namespace`` with the attributes ``corpus``, ``model``,
        ``compute_token_identifiability``, ``compute_deshuffled_attention``,
        ``test``, ``sent_limit`` and ``lowercase``.

    Exits (via ``parser.error``, i.e. usage message on stderr and exit
    status 2) when lower-casing is requested together with a model whose name
    contains ``-cased``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--corpus',
        default='representation_corpus.txt',
        type=str)
    parser.add_argument(
        '--model',
        default='bert-base-cased')
    parser.add_argument(
        '--compute_token_identifiability',
        type=int,
        default=1)
    parser.add_argument(
        '--compute_deshuffled_attention',
        type=int,
        default=1)
    parser.add_argument(
        '--test',
        type=int,
        default=0,
        help='if test, run the code without shuffling, and check some additional assertions.'
        ' ACL RESPONSE: this is the baseline that people asked for.')
    parser.add_argument(
        '--sent_limit',
        type=int,
        default=-1,
        # main() applies the limit whenever the value is > 0.
        help='if > 0, then compute only over the first this-many sentences.')
    parser.add_argument(
        '--lowercase',
        type=int,
        default=0,
        help='should we lowercase everything?')
    args = parser.parse_args()

    # Lower-casing the input of a cased model is a usage error: report it
    # through argparse so the process exits with a non-zero status instead of
    # relying on the interactive-only quit() helper.
    if '-cased' in args.model and args.lowercase:
        parser.error(
            'if you want lower case, use a lowercase model. '
            '(got --model {} --lowercase {})'.format(args.model, args.lowercase))

    return args


def plot_self_attention(self_attention, toks, filename):
    """Render an attention matrix as an image and save it to ``filename``.

    Args:
        self_attention: 2-D array-like passed straight to ``plt.imshow``.
        toks: token strings used as tick labels on both axes (the x labels
            are rotated by 45 degrees).
        filename: output path handed to ``plt.savefig``.

    The current axes and figure are cleared afterwards so the next plot
    starts from a clean state. This helper is not called by the code shown
    in this file; it is meant for ad-hoc inspection of a single head's
    original / shuffled / deshuffled attention maps.
    """
    tick_positions = range(len(toks))

    plt.imshow(self_attention)
    plt.xticks(tick_positions, toks, rotation=45)
    plt.yticks(tick_positions, toks)
    plt.tight_layout()
    plt.savefig(filename)

    # Reset matplotlib's implicit state: axes first, then the whole figure.
    plt.cla()
    plt.clf()


def check_tokens_same_id(tok1, tok2, model_name):
    """Return True when two subword tokens are the same token.

    For models whose name contains ``'roberta'`` the ``'Ġ'`` marker is
    removed from both tokens before comparing, so a token compares equal
    regardless of whether it carries that marker.
    """
    if 'roberta' in model_name:
        tok1, tok2 = (tok.replace('Ġ', '') for tok in (tok1, tok2))
    return tok1 == tok2


def get_all_valid_permutations(toks1, toks2, model_name, exact_cutoff=32):
    """Find index permutations that reorder ``toks2`` into ``toks1``.

    A permutation ``p`` is valid when ``[toks2[i] for i in p]`` equals
    ``toks1``. Repeated tokens make several permutations valid; their number
    is the product of the factorials of the token counts.

    Args:
        toks1: reference token sequence.
        toks2: a reordering of ``toks1`` (must have the same length).
        model_name: if it contains ``'roberta'``, the ``'Ġ'`` marker is
            stripped from every token before comparing.
        exact_cutoff: when the number of valid permutations exceeds this
            value, exactly ``exact_cutoff`` of them are sampled instead of
            enumerating all of them.

    Returns:
        A 2-D integer array with one valid permutation per row. In the
        exhaustive case the rows are unique and in the order produced by
        ``np.unique``; in the sampled case they are drawn with a fixed-seed
        ``RandomState(1)`` (so the result is reproducible) and may repeat.

    Raises:
        AssertionError: if the sequences differ in length or a produced
            permutation fails the final consistency check.
    """
    # np.math was only an alias of the stdlib module and is gone in NumPy 2.0.
    import math

    if 'roberta' in model_name:
        toks1 = np.array([t.replace('Ġ', '') for t in toks1])
        toks2 = np.array([t.replace('Ġ', '') for t in toks2])

    assert len(toks1) == len(toks2), (toks1, toks2)

    # Expected number of valid permutations, kept as an exact Python int so
    # that large counts neither overflow nor lose precision.
    bow = collections.Counter(t.replace('Ġ', '') for t in toks1)
    expected = 1
    for count in bow.values():
        expected *= math.factorial(count)

    # alignment[i, j] is True when toks1[i] and toks2[j] are the same token.
    alignment = np.expand_dims(toks1, 1) == np.expand_dims(toks2, 0)

    # Group the positions of toks1 by the set of toks2 positions they may be
    # mapped to (i.e. by token identity). Insertion order is the order of
    # first occurrence in toks1, which fixes the RNG consumption order below.
    groups = collections.defaultdict(list)
    for idx in range(len(alignment)):
        locs = list(np.where(alignment[idx])[0])
        groups[frozenset(locs)].append(idx)

    if expected > exact_cutoff:
        # Too many permutations to enumerate: sample exact_cutoff of them.
        m_rng = np.random.RandomState(1)
        pairings = {tuple(sorted(src)): sorted(tgt) for tgt, src in groups.items()}

        sampled = []
        for _ in range(exact_cutoff):
            assignment = []
            for src_positions, tgt_positions in pairings.items():
                shuffled = list(tgt_positions)  # leave the template untouched
                m_rng.shuffle(shuffled)
                assignment.extend(zip(src_positions, shuffled))
            assignment.sort(key=lambda pair: pair[0])
            sampled.append(np.array([tgt for _, tgt in assignment]))

        valid_idxs = np.vstack(sampled)
        assert valid_idxs.shape[0] == exact_cutoff
    else:
        # Enumerate every combination of per-group permutations.
        key_order = []
        per_group_options = []
        for tgt_positions, src_positions in groups.items():
            key_order.extend(sorted(src_positions))
            per_group_options.append(list(itertools.permutations(tgt_positions)))

        enumerated = []
        for choice in itertools.product(*per_group_options):
            targets = [tgt for group_perm in choice for tgt in group_perm]
            ordered = sorted(zip(key_order, targets), key=lambda pair: pair[0])
            enumerated.append([tgt for _, tgt in ordered])

        enumerated = np.array(enumerated)
        valid_idxs = np.unique(enumerated, axis=0)

        # The enumeration must be duplicate-free and complete.
        assert enumerated.shape == valid_idxs.shape
        assert enumerated.shape[0] == expected

    # Sanity check: every returned permutation really rebuilds toks1.
    reference = ' '.join(toks1)
    for p in valid_idxs:
        candidate = ' '.join([toks2[idx] for idx in p])
        assert candidate == reference, (candidate, reference)

    return valid_idxs


def my_jsd(p, q, base=None):
    """Jensen-Shannon distance between two distributions.

    Both inputs are normalised along axis 0 before comparison, so
    unnormalised weights are accepted.

    Args:
        p: first distribution (array-like).
        q: second distribution (array-like, same shape as ``p``).
        base: logarithm base of the divergence; natural log when ``None``.

    Returns:
        ``sqrt(JS / 2)`` where ``JS = KL(p || m) + KL(q || m)`` and ``m`` is
        the mean of the two normalised distributions.
    """
    dist_p = np.asarray(p)
    dist_q = np.asarray(q)
    dist_p = dist_p / dist_p.sum(axis=0)
    dist_q = dist_q / dist_q.sum(axis=0)

    mixture = (dist_p + dist_q) / 2.0

    kl_p = np.sum(scipy.special.rel_entr(dist_p, mixture), axis=0)
    kl_q = np.sum(scipy.special.rel_entr(dist_q, mixture), axis=0)
    js = kl_p + kl_q

    # Round-off can push the sum marginally below zero; report and clamp.
    if js < 0:
        print('numerical warning: {} is less than zero!'.format(js))
        js = 0.0

    if base is not None:
        js /= np.log(base)

    return np.sqrt(js / 2.0)


def my_vectorized_jsd(p, q, base=None, eps=1e-7):
    """Row-wise Jensen-Shannon distance between two batches of distributions.

    Row ``i`` of ``p`` is compared with row ``i`` of ``q``. Uniform noise
    drawn from ``[0, eps)`` with the global NumPy RNG is added to every entry
    before each row is normalised to sum to one.

    The inputs are never modified: the noise is added to private copies.
    (Adding it in place would perturb the caller's data a little more on
    every call and would fail outright for integer arrays.)

    Args:
        p: 2-D array-like, one distribution per row.
        q: 2-D array-like with the same shape as ``p``.
        base: logarithm base of the divergence; natural log when ``None``.
        eps: upper bound of the additive uniform noise.

    Returns:
        1-D array with ``sqrt(JS / 2)`` for every row, where negative ``JS``
        values caused by round-off are clipped to zero.
    """
    def noisy_copy(dist):
        # np.array always copies, so the caller's buffer stays untouched.
        dist = np.array(dist)
        if not np.issubdtype(dist.dtype, np.floating):
            dist = dist.astype(np.float64)
        dist += np.random.uniform(low=0.0, high=eps, size=dist.shape)
        return dist

    # Keep the RNG consumption order of the original: p first, then q.
    p = noisy_copy(p)
    q = noisy_copy(q)

    p = p / p.sum(axis=1)[:, np.newaxis]
    q = q / q.sum(axis=1)[:, np.newaxis]

    m = (p + q) / 2.0

    left = scipy.special.rel_entr(p, m)
    right = scipy.special.rel_entr(q, m)

    js = np.sum(left, axis=1) + np.sum(right, axis=1)
    js = np.clip(js, 0, None)

    if base is not None:
        js /= np.log(base)

    return np.sqrt(js / 2.0)


def row_jsd(att1, att2):
    """Compare two attention matrices row by row.

    Thin wrapper around ``my_vectorized_jsd`` with its default ``base`` and
    ``eps``; returns that function's result unchanged.
    """
    return my_vectorized_jsd(p=att1, q=att2)


def get_token_identifiability(sent, tokenizer, model, args, n_shuff=32, test=False):
    """Measure how well tokens can be re-identified after shuffling a sentence.

    The sentence is tokenized with ``tweet_tokenizer``, shuffled ``n_shuff``
    times at that token level and fed through ``model`` together with the
    unshuffled version. For every hidden-state layer and every shuffled copy,
    the token vectors of the original and shuffled sentence are matched with
    a minimum-cost assignment on cosine distance; the accuracy is the
    fraction of matched pairs that are the same token. A baseline accuracy is
    computed the same way from a uniformly random cost matrix.

    Args:
        sent: raw sentence string.
        tokenizer: model tokenizer providing ``encode`` and
            ``convert_ids_to_tokens``.
        model: callable returning a mapping with a ``'hidden_states'`` entry.
        args: namespace providing ``lowercase``, ``device`` and ``model``.
        n_shuff: number of shuffled copies.
        test: when truthy, the copies are not shuffled and the matching is
            asserted to be the identity with accuracy 1.0.

    Returns:
        ``(token_id_acc, token_id_acc_random)``: two dicts mapping layer
        index to a list with one accuracy per shuffled copy.

    Note:
        Reseeds the global NumPy RNG with 1 on every call.
    """
    np.random.seed(1)
    sent_toks = tweet_tokenizer(sent, lower=args.lowercase)
    sent = ' '.join(sent_toks)
    shuff_versions = []
    for _ in range(n_shuff):
        shuff_sentence = sent_toks.copy()
        if not test:  # test doesn't shuffle
            np.random.shuffle(shuff_sentence)
        shuff_versions.append(' '.join(shuff_sentence))

    sent_toks_model = tokenizer.encode(sent)
    shuff_toks_model = [tokenizer.encode(s) for s in shuff_versions]

    sent_toks_text = tokenizer.convert_ids_to_tokens(sent_toks_model)
    shuff_toks_text = [tokenizer.convert_ids_to_tokens(s) for s in shuff_toks_model]

    # Shuffling must not change the bag of subword tokens; otherwise the
    # rows below would not be comparable (or even of equal length).
    for cur_shuff_toks in shuff_toks_text:
        assert collections.Counter(sent_toks_text) == collections.Counter(cur_shuff_toks)

    input_ids = torch.tensor([sent_toks_model] + shuff_toks_model).to(args.device)

    # Pure inference: skip autograd bookkeeping. Only the hidden states are
    # needed, so no other model output is required to exist.
    with torch.no_grad():
        hiddens = model(input_ids)['hidden_states']

    def assignment_accuracy(cost, cur_shuff_toks):
        # Solve the assignment and score how many matched pairs are the same token.
        rows, cols = scipy.optimize.linear_sum_assignment(cost)
        hits = [
            check_tokens_same_id(sent_toks_text[r], cur_shuff_toks[c], args.model)
            for r, c in zip(rows, cols)
        ]
        return rows, cols, np.mean(hits)

    token_id_acc = collections.defaultdict(list)
    token_id_acc_random = collections.defaultdict(list)

    for layer_idx, h in enumerate(hiddens):
        h = h.cpu().detach().numpy()
        # h[0] is the original sentence, h[1:] are the shuffled copies.
        for sample_idx in range(len(h) - 1):
            cur_shuff_toks = shuff_toks_text[sample_idx]

            pairwise = sklearn.metrics.pairwise_distances(h[0], h[sample_idx + 1], metric='cosine')
            row, col, acc = assignment_accuracy(pairwise, cur_shuff_toks)
            if test:
                assert np.all(row == col)
                assert acc == 1.0
            token_id_acc[layer_idx].append(acc)

            # Chance baseline: same procedure on a random cost matrix.
            pairwise_random = np.random.random(size=pairwise.shape)
            _, _, acc_random = assignment_accuracy(pairwise_random, cur_shuff_toks)
            token_id_acc_random[layer_idx].append(acc_random)

    return token_id_acc, token_id_acc_random


def get_deshuffled_attention(sent, tokenizer, model, args, n_shuff=20, exact_cutoff=32):
    """Compare attention on shuffled input, mapped back, with the original.

    The sentence is tokenized with ``tweet_tokenizer``, shuffled ``n_shuff``
    times (unless ``args.test`` is set) and fed through ``model`` together
    with the unshuffled version. For every layer, shuffled copy and head, the
    shuffled attention matrix is re-indexed by each valid token permutation
    (see ``get_all_valid_permutations``) and compared to the original
    attention with ``row_jsd``; the smallest mean distance is recorded. The
    same is done with a freshly drawn random permutation per candidate to
    obtain a baseline.

    Args:
        sent: raw sentence string (also printed to stdout).
        tokenizer: model tokenizer providing ``encode`` and
            ``convert_ids_to_tokens``.
        model: callable returning a mapping with an ``'attentions'`` entry.
        args: namespace providing ``lowercase``, ``test``, ``device`` and
            ``model``.
        n_shuff: number of shuffled copies.
        exact_cutoff: forwarded to ``get_all_valid_permutations``.

    Returns:
        ``(layer2head2jsds, layer2head2random_jsds)``: nested dicts indexed
        by layer then head, each holding a list with one value per shuffled
        copy.

    Note:
        Reseeds the global NumPy RNG with 1 on every call.
    """
    print(sent)
    np.random.seed(1)
    sent_toks = tweet_tokenizer(sent, lower=args.lowercase)
    sent = ' '.join(sent_toks)
    shuff_versions = []
    for _ in range(n_shuff):
        shuff_sentence = sent_toks.copy()
        if not args.test:
            np.random.shuffle(shuff_sentence)
        shuff_versions.append(' '.join(shuff_sentence))

    sent_toks_model = tokenizer.encode(sent)
    shuff_toks_model = [tokenizer.encode(s) for s in shuff_versions]

    sent_toks_text = tokenizer.convert_ids_to_tokens(sent_toks_model)
    shuff_toks_text = [tokenizer.convert_ids_to_tokens(s) for s in shuff_toks_model]

    # Shuffling must not change the bag of subword tokens.
    for cur_shuff_toks in shuff_toks_text:
        assert collections.Counter(sent_toks_text) == collections.Counter(cur_shuff_toks)

    poss_perms = [
        get_all_valid_permutations(sent_toks_text, s, args.model, exact_cutoff=exact_cutoff)
        for s in shuff_toks_text
    ]
    print('poss perms: {}'.format(list(map(len, poss_perms))))

    input_ids = torch.tensor([sent_toks_model] + shuff_toks_model).to(args.device)

    # Pure inference: skip autograd bookkeeping. Only the attention maps are
    # needed, so no other model output is required to exist.
    with torch.no_grad():
        attentions = model(input_ids)['attentions']

    layer2head2jsds = collections.defaultdict(lambda: collections.defaultdict(list))
    layer2head2random_jsds = collections.defaultdict(lambda: collections.defaultdict(list))

    # Show a progress bar unless the first copy has a single valid permutation.
    layers = enumerate(attentions)
    if len(poss_perms[0]) != 1:
        layers = tqdm.tqdm(layers, total=len(attentions))

    for layer_idx, a in layers:
        a = a.cpu().detach().numpy()
        # a[0] is the original sentence, a[1:] are the shuffled copies.
        for sample_idx in range(len(a) - 1):
            for head_idx in range(a.shape[1]):
                original_a = a[0, head_idx]
                shuffled_a = a[sample_idx + 1, head_idx]

                smallest_jsd = np.inf
                smallest_jsd_random = np.inf
                for p in poss_perms[sample_idx]:
                    p_random = np.random.permutation(len(p))

                    # Re-index columns, then rows, to undo the shuffle.
                    candidate_a = shuffled_a[:, p][p, :]
                    candidate_random_a = shuffled_a[:, p_random][p_random, :]

                    mean_jsd = np.mean(row_jsd(original_a, candidate_a))
                    if mean_jsd < smallest_jsd:
                        smallest_jsd = mean_jsd

                    mean_jsd_random = np.mean(row_jsd(original_a, candidate_random_a))
                    if mean_jsd_random < smallest_jsd_random:
                        smallest_jsd_random = mean_jsd_random

                layer2head2jsds[layer_idx][head_idx].append(smallest_jsd)
                layer2head2random_jsds[layer_idx][head_idx].append(smallest_jsd_random)

    return layer2head2jsds, layer2head2random_jsds


def _run_token_identifiability(sents, tokenizer, model, args):
    """Run the token-identifiability experiment; save its plot and JSON file."""
    token_id_accs, token_id_accs_random = [], []
    for s in tqdm.tqdm(sents):
        acc, acc_random = get_token_identifiability(s, tokenizer, model, args, test=args.test, n_shuff=32)
        token_id_accs.append(acc)
        token_id_accs_random.append(acc_random)

    n_layers = len(token_id_accs[0])
    xs = list(range(n_layers))

    # ydists[layer] holds one accuracy / random-accuracy ratio per sentence.
    ydists = [
        [
            np.mean(acc[layer_idx]) / np.mean(acc_random[layer_idx])
            for acc, acc_random in zip(token_id_accs, token_id_accs_random)
        ]
        for layer_idx in xs
    ]
    ys = [np.mean(ratios) for ratios in ydists]

    plt.plot(xs, ys)
    plt.xlabel('Layer index')
    plt.ylabel('Acc / Random Acc')
    plt.savefig('acc_ratio.pdf')

    with open('token_identifiability_statistics_{}_limit_{}_test_{}_lowercase_{}.json'.format(args.model, args.sent_limit, args.test, args.lowercase), 'w') as f:
        f.write(json.dumps({'xs': xs, 'ydists': ydists}))

    # Leave a clean canvas for whatever is plotted next.
    plt.cla()
    plt.clf()


def _run_deshuffled_attention(sents, tokenizer, model, args):
    """Run the deshuffled-attention experiment; save its plot and JSON file."""
    jsds, jsds_random = [], []
    for s in tqdm.tqdm(sents):
        # each result is indexed as [layer][head] -> list over shuffled copies
        jsd, jsd_random = get_deshuffled_attention(s, tokenizer, model, args, n_shuff=32, exact_cutoff=16)
        jsds.append(jsd)
        jsds_random.append(jsd_random)

    n_layers = len(jsds[0])
    n_heads = len(jsds[0][0])

    xs = list(range(n_layers))
    ys = []
    ydists = []

    # per head results
    layer2means = collections.defaultdict(list)

    for l_idx in xs:
        layer_ratios = []
        for h_idx in range(n_heads):
            # One JSD / random-JSD ratio per sentence for this head.
            head_ratios = [
                float(np.mean(jsd[l_idx][h_idx]) / np.mean(jsd_random[l_idx][h_idx]))
                for jsd, jsd_random in zip(jsds, jsds_random)
            ]
            layer_ratios.extend(head_ratios)
            layer2means[l_idx].append(float(np.mean(head_ratios)))
        ys.append(float(np.mean(layer_ratios)))
        ydists.append(layer_ratios)

    plt.plot(xs, ys)

    per_head_xs, per_head_ys = [], []
    for l_idx, head_means in layer2means.items():
        for head_mean in head_means:
            per_head_xs.append(l_idx)
            per_head_ys.append(head_mean)

    plt.scatter(per_head_xs, per_head_ys)

    plt.xlabel('Layer index')
    plt.ylabel('JSD / Random JSD')
    plt.savefig('JSD_ratio.pdf')

    with open('deshuffled_attn_statistics_{}_limit_{}_test_{}_lowercase_{}.json'.format(args.model, args.sent_limit, args.test, args.lowercase), 'w') as f:
        f.write(json.dumps({'xs': xs, 'ydists': ydists, 'per_head_xs': per_head_xs, 'per_head_ys': per_head_ys}))


def main():
    """Script entry point.

    Loads the tokenizer and model named by ``--model``, reads the corpus
    (one sentence per line), shuffles it with a fixed seed, optionally
    truncates it to ``--sent_limit`` sentences and runs whichever of the two
    experiments are enabled. Each experiment writes a PDF plot and a JSON
    statistics file into the current working directory.
    """
    args = parse_args()

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    np.random.seed(1)

    if 'roberta' in args.model:
        tokenizer = AutoTokenizer.from_pretrained(args.model, add_prefix_space=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model)

    model = AutoModel.from_pretrained(args.model, output_hidden_states=True, output_attentions=True, return_dict=True)
    model.eval()
    model = model.to(args.device)

    with open(args.corpus) as f:
        sents = [line.strip() for line in f]

    # Seeded above, so the sentence order (and any --sent_limit prefix) is
    # reproducible across runs.
    np.random.shuffle(sents)

    if args.sent_limit > 0:
        sents = sents[:args.sent_limit]

    if args.compute_token_identifiability:
        _run_token_identifiability(sents, tokenizer, model, args)

    if args.compute_deshuffled_attention:
        _run_deshuffled_attention(sents, tokenizer, model, args)


if __name__ == '__main__':
    main()
