#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import json
import random
import re
import time
import os

import numpy
import torch
from torch import nn

from data import PTBLoader
from helpers import Dictionary, Dictionary_wordpiece

# Suites whose items are written to the CSV but left out of the running
# accuracy tally printed during evaluation.
exclude_suite_re = re.compile(r"^fgd-embed[34]|^gardenpath|^nn-nv")

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/dependency/UD_English-PTB/en_ptb-ud',
                    help='location of the data corpus')
parser.add_argument('--seed', type=int, default=141,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
# Timestamp-derived default kept for CLI compatibility. This script only
# *loads* the --save path, and a fresh timestamp is unlikely to name an
# existing checkpoint, so --save should normally be passed explicitly.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str, default=randomhash + '.pt',
                    help='path of the trained model checkpoint to load and evaluate')
parser.add_argument('--device', type=int, default=0, help='select GPU')

args = parser.parse_args()
# Not read anywhere in this script's visible code; presumably kept for
# parity with the training script's argument namespace.
args.tied = True

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
random.seed(args.seed)
numpy.random.seed(args.seed)
if torch.cuda.is_available():
    # Selects the GPU index even when --cuda is not given.
    torch.cuda.set_device(args.device)
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)


###############################################################################
# Load data
###############################################################################


def model_load(fn):
    """Load the checkpoint stored at *fn* into the module-level ``model``.

    The file is expected to unpickle (via ``torch.load``) into a three-element
    sequence whose first element is the model; the other two are discarded.
    Tensors are mapped onto the CUDA device when ``--cuda`` was given, and
    onto the CPU otherwise.

    NOTE(review): ``torch.load`` unpickles arbitrary objects, so only load
    trusted checkpoints.
    """
    global model
    target = torch.device('cuda' if args.cuda else 'cpu')
    with open(fn, 'rb') as checkpoint_file:
        model, _, _ = torch.load(checkpoint_file, map_location=target)

print("loading penn data ...")
# Only the loader's vocabulary is used below (for tokenization and decoding).
corpus = PTBLoader(data_path=args.data)
args.vocab_size = len(corpus.dictionary)
vocab = corpus.dictionary

print("done loading, vocabulary size:", args.vocab_size)

###############################################################################
# Build the model
###############################################################################

# model_load binds the module-level name ``model``.
model_load(args.save)
# Switch the network to evaluation mode before scoring.
model.eval()

# reduction='none' yields one loss per target; the evaluation loop sums them.
criterion = nn.CrossEntropyLoss(reduction='none')

# SyntaxGym test suites (one JSON file per suite) and the CSV receiving one
# row per (suite, item).
test_data_path = 'data/test_suites/json'
results_path = 'results/syntaxgym_results_eSOM_bllip-sm.csv'
# Sorted so suites are processed and written in a reproducible order;
# os.listdir returns entries in arbitrary order.
test_set_files = sorted(os.listdir(test_data_path))

# A region reference inside a prediction formula, e.g. "(3;%match%)":
# group 1 is the 1-based region number, group 2 the condition name.
# ``\d+`` (not a single ``\d``) so region numbers >= 10 are substituted too.
_REGION_REF_RE = re.compile(r'\((\d+);%([a-zA-Z0-9_\-]+)%\)')

# Context fed to the model before the regions of every condition.
_PREFIX_TEXT = 'A meaningless sentence . </s>'


def _prefix_token_ids():
    """Return the token ids of ``_PREFIX_TEXT`` for the loaded ``vocab``.

    Raises:
        TypeError: if ``vocab`` is neither a Dictionary nor a Dictionary_wordpiece.
    """
    if isinstance(vocab, Dictionary):
        assert '</s>' in vocab
        return [vocab.get_idx(token, loc)
                for loc, token in enumerate(_PREFIX_TEXT.split(' '))]
    if isinstance(vocab, Dictionary_wordpiece):
        return vocab[_PREFIX_TEXT]
    raise TypeError('unsupported vocabulary type: %r' % type(vocab))


def _score_condition(condition):
    """Score every region of one SyntaxGym condition with ``model``.

    The model starts from a fresh hidden state, is fed the fixed prefix, and
    then each non-empty, space-separated token of each region in order. For
    every region, the summed ``criterion`` loss of its tokens is stored in
    ``region['metric_value']['sum']``. The condition name and decoded tokens
    are echoed to stdout.
    """
    print(condition['condition_name'])
    # Inference only: no autograd graph is needed across the recurrent steps.
    with torch.no_grad():
        hidden, _ = model.init_hidden(1)
        data = torch.tensor(_prefix_token_ids(), dtype=torch.long, device=hidden[0].device)
        output, _, hidden = model(data[None, :], hidden)
        # Keep only the output after the last prefix token; it scores the first region token.
        output = output[-1, :].unsqueeze(0)

        # Word position passed to Dictionary.get_idx; restarts at 0 after the prefix.
        loc = 0
        for region in condition['regions']:
            nll_sum = 0
            for token in region['content'].split(' '):
                if not token:
                    continue  # empty strings come from repeated/edge spaces
                if isinstance(vocab, Dictionary):
                    token_id_list = [vocab.get_idx(token, loc)]
                    loc += 1
                    print(vocab.idx2word[token_id_list[0]], end=' ')
                else:
                    # _prefix_token_ids() has already rejected any other vocab type.
                    token_id_list = vocab[token]
                    for token_id in token_id_list:
                        print(vocab.idx2word(token_id)[0], end='')

                for token_id in token_id_list:
                    data = torch.tensor([token_id], dtype=torch.long, device=hidden[0].device)
                    # Loss of this token under the current output, then advance the model.
                    nll_sum += criterion(output, data).sum().item()
                    output, _, hidden = model(data[None, :], hidden)

            region['metric_value']['sum'] = nll_sum
    print()


def _evaluate_prediction(command, item, metric):
    """Evaluate one SyntaxGym prediction formula against *item*.

    Every region reference ``(N;%cond%)`` is replaced by ``(value)``, where
    value is ``metric_value[metric]`` of region N of condition ``cond``.
    Square brackets are turned into parentheses. The resulting expression is
    printed and then evaluated.

    Raises:
        KeyError: if the formula names a condition the item does not contain.
    """
    # Later conditions win on duplicate names, as in a last-match scan.
    conditions = {c['condition_name']: c for c in item['conditions']}

    def substitute(match):
        region_number = int(match.group(1))
        condition_name = match.group(2)
        if condition_name not in conditions:
            raise KeyError('prediction references unknown condition %r' % condition_name)
        region = conditions[condition_name]['regions'][region_number - 1]
        score = region['metric_value'][metric]
        assert region['region_number'] == region_number
        return '(' + str(score) + ')'

    command = _REGION_REF_RE.sub(substitute, command)
    command = command.replace('[', '(').replace(']', ')')
    print(command)
    # NOTE(review): eval() executes whatever the suite JSON's formula contains;
    # only run this on trusted test suites.
    return eval(command)


result_list = []

os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'w') as fout:
    fout.write('model,suite,item,correct\n')

    for test_file in test_set_files:
        test_file_path = os.path.join(test_data_path, test_file)
        with open(test_file_path, 'r', encoding='utf-8') as fin:
            test_data = json.load(fin)
        test_name = test_data['meta']['name']
        print(test_name)
        # NOTE(review): _score_condition only fills the 'sum' metric; a suite
        # declaring another metric would read whatever value the JSON holds.
        metric = test_data['meta']['metric']

        for item_id, item in enumerate(test_data['items']):
            for condition in item['conditions']:
                _score_condition(condition)

            # NOTE(review): only the first prediction of each suite is evaluated.
            command = test_data['meta']['string_predictions'][0]
            print(command)
            result = _evaluate_prediction(command, item, metric)
            print(result)
            # Excluded suites are still written to the CSV but kept out of the tally.
            if exclude_suite_re.match(test_name) is None:
                result_list.append(result)
            fout.write('eSOM_bllip-sm_141,%s,%d,%s\n' % (test_name, item_id, result))
            print('Results: %d/%d is correct' % (result_list.count(True), len(result_list)))

            print()
