from transformers import *
from setproctitle import setproctitle
from nltk.stem import PorterStemmer
from tqdm import tqdm
import torch
import random
import numpy as np
import re
import operator
import argparse
import json
# Shared Porter stemmer; main() stems each concept before looking it up in a candidate.
stemmer = PorterStemmer()
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Give the process a recognisable title in process listings.
setproctitle("BART_Korean_CommonGEN")
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators with *seed*.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Seed all random number generators with a fixed value as soon as the module loads.
set_seed(42)

# Bookkeeping for the disabled test-index loader below; nothing else in this
# part of the file reads these values.
test_index, index_list = [], []
count = past_num = 0
# Disabled loader, kept for reference:
#with open('test_index.txt', 'r') as t:
#    for number in t: # 1, 2, 3, 4 / 1, 2, 3, 4, 5
#        test_index.append(int(number))
#        count += 1
#        number = int(number.strip())
#        if past_num != number:
#            index_list.append(count-1)
#            count = 1
#            past_num += 1

# Module-level accumulator: main() appends one selected sentence per input
# record and later writes the whole list to the output file.
gen_list = list()
# Upper bound on the number of beam-search hypotheses requested per concept set.
_NUM_CANDIDATES = 10
# Literal "다." (the syllable followed by a full stop). The dot is escaped so
# that words merely containing "다" (e.g. "다리") do not trigger truncation.
_SENTENCE_END = re.compile(r'다\.')


def _truncate_at_first_sentence(text):
    """Return *text* cut right after its first "다.", or unchanged if absent."""
    match = _SENTENCE_END.search(text)
    return text if match is None else text[:match.end()]


def _count_covered_concepts(concepts, text):
    """Count how many stemmed *concepts* occur as literal substrings of *text*."""
    # Plain substring test rather than re.search: a concept is data, not a
    # pattern, so regex metacharacters inside it must not be interpreted.
    return sum(1 for concept in concepts if stemmer.stem(concept) in text)


def main(args):
    """Generate one sentence per concept set and write the results to a file.

    Each non-blank line of ``args.test_file`` is parsed as a JSON object whose
    ``'concept-set'`` value is a ``'#'``-separated string of concepts. For
    every record, several beam-search hypotheses are generated, each is cut
    after its first "다.", and the hypothesis containing the most (stemmed)
    concepts is kept; ties go to the earliest hypothesis. The kept sentence is
    appended to the module-level ``gen_list``, which is finally written, one
    sentence per line, to ``args.output_dir``.

    Args:
        args: Namespace providing ``checkpoint_path`` (model to load),
            ``test_file`` (input path), ``output_dir`` (despite the name, the
            path of the output *file*), and the generation settings
            ``beam_size``, ``ngram``, ``max_len``, ``min_len`` and ``repeat``.
    """
    tokenizer = PreTrainedTokenizerFast.from_pretrained('hyunwoongko/kobart')
    model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint_path)
    model = model.to(device)

    # generate() rejects num_return_sequences > num_beams, so never ask for
    # more hypotheses than there are beams when a beam size is given.
    num_candidates = _NUM_CANDIDATES
    if args.beam_size is not None:
        num_candidates = min(_NUM_CANDIDATES, args.beam_size)

    with open(args.test_file, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
        records = [line for line in f if line.strip()]

    for record in tqdm(records):
        concept_text = json.loads(record)['concept-set']
        input_ids = tokenizer(concept_text, return_tensors="pt", max_length=64, truncation=True).input_ids
        input_ids = input_ids.to(device)
        outputs = model.generate(
            input_ids=input_ids,
            num_beams=args.beam_size,
            num_return_sequences=num_candidates,
            no_repeat_ngram_size=args.ngram,
            max_length=args.max_len,
            min_length=args.min_len,
            repetition_penalty=args.repeat,
        )

        concepts = concept_text.split('#')
        # Iterate over what generate() actually returned instead of assuming 10 rows.
        candidates = [
            _truncate_at_first_sentence(
                tokenizer.decode(sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            )
            for sequence in outputs
        ]
        scores = [_count_covered_concepts(concepts, candidate) for candidate in candidates]

        # max() returns the first index reaching the top score, matching the
        # earlier stable descending sort followed by taking the first entry.
        best_idx = max(range(len(candidates)), key=scores.__getitem__)
        print([(best_idx, scores[best_idx])])

        print("=== GENERATED SEQUENCE {} ===".format(1))
        generated_sequence = candidates[best_idx]
        print(generated_sequence)
        gen_list.append(generated_sequence)

    with open(args.output_dir, 'w', encoding='utf-8') as f:
        f.writelines(line + '\n' for line in gen_list)


if __name__ == '__main__':
    # Command-line entry point: collect the run configuration and hand it to main().
    parser = argparse.ArgumentParser()

    # main() uses these three paths unconditionally, so require them up front
    # rather than failing on None later (for --output_dir, only after the
    # whole generation pass has already run).
    parser.add_argument("--checkpoint_path", type=str, required=True,
                        help="model checkpoint to load")
    parser.add_argument("--test_file", type=str, required=True,
                        help="input file with one JSON object per line")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="path of the output text file")

    # Generation settings forwarded to model.generate(); omitted ones stay None.
    parser.add_argument("--beam_size", type=int)
    parser.add_argument("--max_len", type=int)
    parser.add_argument("--min_len", type=int)
    parser.add_argument("--ngram", type=int, default=None)
    parser.add_argument("--repeat", type=float, default=None)

    # Accepted so existing invocations keep working; main() does not read it.
    parser.add_argument("--model_name", type=str)

    args = parser.parse_args()

    main(args)
