import torch
import random
import numpy as np
import re
import operator
import argparse
from tqdm import tqdm
from nltk.stem import PorterStemmer
from setproctitle import setproctitle
from transformers import *
from itertools import permutations, combinations
# English Porter stemmer; main() stems each concept word before looking for it
# in a generated sentence.
stemmer = PorterStemmer()

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Give the process a recognisable name in process listings.
setproctitle("GPT2_Korean_CommonGEN")

# Extra special tokens registered with the tokenizer. The surrounding spaces are
# part of each token string; presumably they must match the fine-tuning setup
# of the checkpoint being loaded (TODO confirm against the training script).
ATTR_TO_SPECIAL_TOKEN = [
    '[SOS] ',
    ' [EOS]',
    ' = ',
    ' [SEP] ',
]

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random generators with *seed*.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Fix the seed at import time so repeated runs draw the same random numbers.
set_seed(42)

# Korean declarative sentences end in "다." -- the period is escaped so that the
# syllable "다" inside ordinary words (e.g. "바다", "다른") does not end a match.
_SENTENCE_END = re.compile(r'다\.')


def _load_model_and_tokenizer(model_name):
    """Return ``(tokenizer, model)`` for *model_name*.

    Any name not listed explicitly falls back to Transformer-XL classes.
    """
    if model_name == 'openai-gpt':
        return (OpenAIGPTTokenizer.from_pretrained(model_name),
                OpenAIGPTLMHeadModel.from_pretrained(model_name))
    if model_name == 'kogpt2':
        return (PreTrainedTokenizerFast.from_pretrained('skt/kogpt2-base-v2'),
                GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2'))
    if model_name in ('gpt2-medium', 'gpt2-large'):
        return (GPT2Tokenizer.from_pretrained(model_name),
                GPT2LMHeadModel.from_pretrained(model_name))
    if model_name == 'xlnet-base-cased':
        return (XLNetTokenizer.from_pretrained(model_name),
                XLNetLMHeadModel.from_pretrained(model_name))
    if model_name == 'xlm-mlm-en-2048':
        return (XLMTokenizer.from_pretrained(model_name),
                XLMWithLMHeadModel.from_pretrained(model_name))
    return (TransfoXLTokenizer.from_pretrained(model_name),
            TransfoXLLMHeadModel.from_pretrained(model_name))


def _extract_answer(decoded_text, stop_token='[EOS]'):
    """Turn one decoded generation into a single-line answer sentence.

    Keeps the text between the first and second ' = ' separators (or the whole
    text when no separator was generated), cuts it at *stop_token*, replaces
    '[SEP]' with spaces, collapses all whitespace (including newlines) to single
    spaces, and finally truncates after the first "다." sentence ending.
    """
    parts = decoded_text.split(' = ')
    answer = parts[1] if len(parts) > 1 else decoded_text
    if stop_token:
        stop_at = answer.find(stop_token)
        if stop_at != -1:
            answer = answer[:stop_at]
    # split()/join strips the ends and removes newlines, so each answer occupies
    # exactly one line of the output file.
    answer = ' '.join(answer.replace('[SEP]', ' ').split())
    match = _SENTENCE_END.search(answer)
    if match:
        answer = answer[:match.end()]
    return answer


def _count_concepts(concepts, text):
    """Return how many *concepts* (after stemming) occur as literal substrings of *text*."""
    return sum(1 for concept in concepts if stemmer.stem(concept) in text)


def main(args):
    """Generate one sentence per line of ``args.test_file`` and save them.

    For every input line, the prompt ``[SOS] <concepts> = `` is fed to the
    fine-tuned model. ``args.num_sentences`` candidates are generated with beam
    sampling, and the candidate containing the most concept words (after
    stemming) is kept; ties go to the earliest candidate. The kept sentences are
    written one per line to ``args.output_dir``, which is an output *file* path.

    Args:
        args: parsed CLI namespace; reads ``checkpoint_path``, ``model_name``,
            ``test_file``, ``output_dir``, ``beam_size``, ``ngram``,
            ``num_sentences`` and ``max_len``.
    """
    tokenizer, model = _load_model_and_tokenizer(args.model_name)
    tokenizer.add_special_tokens({'additional_special_tokens': ATTR_TO_SPECIAL_TOKEN})

    # The extra +1 row is kept as-is: the embedding size must match the checkpoint.
    model.resize_token_embeddings(len(tokenizer) + 1)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    state_dict = torch.load(args.checkpoint_path, map_location='cpu')
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False would otherwise hide a wholesale key mismatch (e.g. a
    # "module." prefix) and silently run the un-finetuned model.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'Warning: checkpoint/model key mismatch - '
              f'{len(load_result.missing_keys)} missing, '
              f'{len(load_result.unexpected_keys)} unexpected')
    model.to(device)
    model.eval()

    generated_sequences = []
    with open(args.test_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            # The prompt is everything before the first ' = ' separator; the
            # optional '[SOS] ' prefix and the trailing newline are dropped.
            prompt = line.replace('[SOS] ', '').rstrip('\r\n').split(' = ')[0]
            concepts = prompt.split()

            tokens = ['[SOS] '] + tokenizer.tokenize(prompt) + [' = ']
            input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens)).unsqueeze(0).to(device)

            output_sequences = model.generate(
                input_ids=input_ids,
                num_beams=args.beam_size,
                no_repeat_ngram_size=args.ngram,
                do_sample=True,
                num_return_sequences=args.num_sentences,
                # input_ids has shape (1, prompt_len): len() would always be 1,
                # so use the prompt length to give max_len tokens of room.
                max_length=input_ids.shape[-1] + args.max_len,
            )
            # Guarantee a 2-D (num_candidates, seq_len) layout.
            output_sequences = output_sequences.reshape(-1, output_sequences.shape[-1])

            candidates = [
                _extract_answer(tokenizer.decode(seq, clean_up_tokenization_spaces=True))
                for seq in output_sequences[:args.num_sentences]
            ]
            if not candidates:
                # Keep the output aligned one line per input line.
                generated_sequences.append('')
                continue

            scores = [_count_concepts(concepts, text) for text in candidates]
            # max() returns the first maximal index, i.e. earliest candidate on ties.
            best = max(range(len(scores)), key=scores.__getitem__)
            print([(best, scores[best])])
            print(candidates[best])
            generated_sequences.append(candidates[best])

    with open(args.output_dir, 'w', encoding='utf-8') as f:
        f.writelines(f'{sentence}\n' for sentence in generated_sequences)

def build_arg_parser():
    """Return the command-line parser for this generation script.

    Flags the script cannot run without are marked required, so omitting one
    fails immediately with a usage message instead of deep inside model loading
    or generation. ``--beam_size`` and ``--ngram`` stay optional; when omitted
    they are passed to ``model.generate`` as None.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="Fine-tuned state_dict file to load into the model.")
    parser.add_argument("--beam_size", type=int,
                        help="num_beams passed to model.generate.")
    parser.add_argument("--num_sentences", type=int, required=True,
                        help="Candidate sequences generated per input line (num_return_sequences).")
    parser.add_argument("--ngram", type=int,
                        help="no_repeat_ngram_size passed to model.generate.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Path of the output text FILE (one sentence per line).")
    parser.add_argument("--max_len", type=int, required=True,
                        help="Generation length budget used to compute max_length.")
    parser.add_argument("--model_name", type=str, required=True,
                        help="Base model name, e.g. 'kogpt2', 'gpt2-medium', 'openai-gpt'.")
    parser.add_argument("--test_file", type=str, required=True,
                        help="UTF-8 input file with one concept set per line.")
    return parser


if __name__ == '__main__':
    main(build_arg_parser().parse_args())
