import argparse

import numpy
import torch

from data import PTBLoader
from hinton import plot
from tree_utiles import *
from utils import evalb, generate_ground_truth, generate_idx


def model_load(fn):
    """Load the checkpoint stored at ``fn`` into the module-level ``model``.

    The checkpoint is expected to unpack into three items; only the first
    one is kept (as the global ``model``), the other two are discarded.
    Tensors are mapped onto CUDA when the script-level ``args.cuda`` flag is
    set and onto the CPU otherwise.

    NOTE(review): ``torch.load`` unpickles the file, so only checkpoints from
    a trusted source should be passed here.
    """
    global model
    with open(fn, 'rb') as checkpoint:
        target = torch.device('cuda' if args.cuda else 'cpu')
        loaded = torch.load(checkpoint, map_location=target)
        model, _, _ = loaded


def mean(x):
    """Return the arithmetic mean of the items in the sized collection ``x``."""
    total = sum(x)
    count = len(x)
    return total / count


def _bracket_scores(model_out, std_out):
    """Return ``(precision, recall, f1)`` for predicted vs. gold bracket sets.

    The small epsilon keeps the divisions defined when a set is empty. When
    the gold set is empty, recall is forced to 1, and precision is forced to
    1 as well if the prediction is also empty.
    """
    overlap = model_out.intersection(std_out)

    prec = float(len(overlap)) / (len(model_out) + 1e-8)
    reca = float(len(overlap)) / (len(std_out) + 1e-8)
    if len(std_out) == 0:
        reca = 1.
        if len(model_out) == 0:
            prec = 1.
    f1 = 2 * prec * reca / (prec + reca + 1e-8)
    return prec, reca, f1


@torch.no_grad()
def test(model, corpus, cuda, prt=True):
    """Evaluate ``model`` on ``corpus.test`` and report bracketing scores.

    For every test sentence the model is run on a batch of size one, the
    argmax of the first element of its returned ``probs`` is used as a
    per-position depth, and a tree is built from it with ``build_tree``.
    That tree is compared against a gold tree through ``get_brackets``.

    Args:
        model: Network exposing ``eval()``, ``init_hidden(batch_size)``,
            ``nslot`` and a forward call returning ``(output, probs, hidden)``.
        corpus: Object with a ``dictionary`` (providing ``idx2word``) and a
            ``test`` attribute. ``test`` either holds four items (sentences,
            labels, raw token lists, gold trees) or two (sentences, labels);
            in the latter case gold trees are derived with
            ``generate_ground_truth``.
        cuda: When true, input tensors are moved to the GPU.
        prt: When true, per-token details and both trees are printed for
            every 100th sentence.

    Returns:
        The mean sentence-level F1 as a NumPy array of shape ``(1,)``.
    """
    model.eval()
    if len(corpus.test) == 4:
        sentence_list, label_list, sens_list, tree_list = corpus.test
    else:
        sentence_list, label_list = corpus.test
        tree_list = None
        sens_list = None
    vocab = corpus.dictionary

    prec_list = []
    reca_list = []
    f1_list = []

    pred_tree_list = []
    targ_tree_list = []

    nsens = 0
    total_structure_acc = 0
    total_length = 0
    for idx, sen in enumerate(sentence_list):
        data = torch.LongTensor([sen])
        structure_first = torch.LongTensor([label_list[idx]])
        if cuda:
            data = data.cuda()
            structure_first = structure_first.cuda()

        # Only the hidden state from init_hidden is used; the forward pass
        # output and the returned hidden state are not needed here.
        hidden, _ = model.init_hidden(1)
        _, probs, _ = model(data, hidden)

        structure_p, _, p, _ = probs
        structure = structure_p.max(dim=-1)[1]
        # The predicted structure doubles as the "previous structure" input.
        prev_structure = structure
        structure_idx = generate_idx(prev_structure, structure, structure_first, model.nslot)

        total_structure_acc += (p.max(dim=-1)[1] == structure_idx).float().sum()
        total_length += len(sen)

        # The first position is dropped from both words and depths before
        # building the tree (presumably a sentence-start symbol -- confirm).
        words = [vocab.idx2word[index] for index in sen][1:]
        depth = structure.squeeze().cpu().numpy().tolist()[1:]
        assert len(words) == len(depth)
        parse_tree = build_tree(depth, words)

        if tree_list is None:
            true_depth = generate_ground_truth(prev_structure, structure_first, model.nslot)
            true_depth = true_depth.squeeze().cpu().numpy().tolist()
            true_depth = true_depth[1:]
            assert len(words) == len(true_depth)
            gold_tree = build_tree(true_depth, words)
        else:
            # Gold trees are available: rebuild the prediction over the raw
            # tokens that accompany them instead of the vocabulary words.
            gold_tree = tree_list[idx]
            parse_tree = build_tree(depth, sens_list[idx][1:])

        pred_tree_list.append(parse_tree)
        targ_tree_list.append(gold_tree)

        nsens += 1
        if prt and nsens % 100 == 0:
            tokens = data[-1].cpu().numpy()
            indexes = structure_idx[-1].cpu().numpy()
            p0 = p[-1].cpu().numpy()
            # Distinct loop names so the sentence index ``idx`` is not clobbered.
            for token_id, ref_idx, dist in zip(tokens, indexes, p0):
                pred_idx = numpy.argmax(dist)
                print('%15s\t%s\t%2d\t%2d\t%s' % (vocab.idx2word[token_id], (ref_idx == pred_idx),
                                                  ref_idx, pred_idx, plot(dist, max_val=1.)))
            print(parse_tree)
            print(gold_tree)
            print()

        model_out, _ = get_brackets(parse_tree)
        std_out, _ = get_brackets(gold_tree)
        prec, reca, f1 = _bracket_scores(model_out, std_out)
        prec_list.append(prec)
        reca_list.append(reca)
        f1_list.append(f1)

    # Column vectors so that mean(axis=0) yields a length-1 array.
    prec_list = numpy.array(prec_list).reshape((-1, 1))
    reca_list = numpy.array(reca_list).reshape((-1, 1))
    f1_list = numpy.array(f1_list).reshape((-1, 1))

    print('-' * 80)
    numpy.set_printoptions(precision=4)
    print('Mean Prec:', prec_list.mean(axis=0),
          ', Mean Reca:', reca_list.mean(axis=0),
          ', Mean F1:', f1_list.mean(axis=0))
    print('Number of sentence: %i' % nsens)
    # Guard the division so an empty test split does not raise
    # ZeroDivisionError.
    look_ahead_acc = total_structure_acc / total_length if total_length else 0.
    print('1 look-ahead acc:', look_ahead_acc)

    evalb(pred_tree_list, targ_tree_list)

    return f1_list.mean(axis=0)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, load the checkpoint and the
    # corpus, then run the evaluation once.

    # NOTE(review): ``marks`` is not referenced anywhere in the visible file;
    # kept in case other code relies on it.
    marks = [' ', '-', '=']

    numpy.set_printoptions(precision=2, suppress=True, linewidth=5000)

    parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

    # Model parameters.
    parser.add_argument('--data', type=str, default='data/rnng/en_ptb-ud',
                        help='location of the data corpus')
    parser.add_argument('--save', type=str, default='rnng.pt',
                        help='model checkpoint to use')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    # The help text previously duplicated the --cuda description.
    parser.add_argument('--print', action='store_true',
                        help='print per-token predictions and trees for every 100th sentence')
    args = parser.parse_args()

    print('Args:', args)

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)

    # Load model (model_load stores it in the module-level ``model``).
    model_load(args.save)

    # Load data
    corpus = PTBLoader(data_path=args.data)

    print('Testing...')
    test(model, corpus, args.cuda, args.print)
    print('Done\n')
