import argparse
import contextlib
import json
import logging
import math
import os
import random
import sys

from scipy.special import comb
from itertools import combinations

sys.path.append("..")
from utils.logConfig import Log
from utils.makeLabel import rouge_1_recall, rouge_2_recall

OPTION_NUM = 4
MAX_DOC = 50
MAX_SUM = 3
MAX_QUES = 3
MAX_TRY = 10

random.seed(666)

def readPair(prefix):
    """Load line-aligned document/summary pairs from ``prefix``.doc and ``prefix``.sum.

    The two files are read in lockstep, one example per line; reading stops
    at the end of the shorter file. On every line carriage returns are
    removed and surrounding whitespace is stripped.

    Args:
        prefix: path without extension shared by the ``.doc`` and ``.sum`` files.

    Returns:
        A list of dicts with two keys:
          * "doc"  -- the document line split on the "<q>" marker (list of str);
          * "summ" -- the summary line with every "<q>" marker turned into a
            newline (str).
    """
    def _clean(line):
        # Drop stray carriage returns, then the trailing newline / outer blanks.
        return line.replace('\r', "").strip()

    with open(prefix + '.doc') as doc_in, open(prefix + '.sum') as sum_in:
        return [
            {
                "doc": _clean(doc_line).split("<q>"),
                "summ": _clean(sum_line).replace("<q>", "\n"),
            }
            for doc_line, sum_line in zip(doc_in, sum_in)
        ]

def filter_low_r2(doc, summ, language):
    """Return the sentences of ``doc`` with a positive ROUGE-2 recall against ``summ``.

    Args:
        doc: iterable of sentences.
        summ: reference summary passed to ``rouge_2_recall``.
        language: forwarded to ``rouge_2_recall`` as its ``language`` keyword.

    Returns:
        A new list holding, in their original order, the sentences whose
        ``rouge_2_recall`` score is strictly greater than 0.0.
    """
    return [
        sentence
        for sentence in doc
        if rouge_2_recall(sentence, summ, language=language) > 0.0
    ]


def getOption(example, language, numq=MAX_QUES):
    """Build multiple-choice questions from one document/summary pair.

    Sentences with zero ROUGE-2 recall are dropped first (``filter_low_r2``).
    Every option is a group of ``opt_len`` sentence indices taken from the
    first ``MAX_DOC`` remaining sentences, where ``opt_len`` is the number of
    summary lines capped by ``MAX_SUM`` and by the number of sentences. A
    question is a random draw of ``OPTION_NUM`` distinct options; it is kept
    only if its best ROUGE-1 recall against the summary exceeds 0.1.

    Args:
        example: dict with "doc" (list of sentences) and "summ" (summary
            whose sentences are newline separated), as built by ``readPair``.
        language: forwarded to the ROUGE helpers.
        numq: maximum number of questions to produce for this example.

    Returns:
        A list of dicts, one per question, with keys:
          * "candidate" -- the option texts (sentences joined by newlines);
          * "idx"       -- the sentence-index tuples behind each option;
          * "score"     -- the ROUGE-1 recall of each option;
          * "label"     -- position of the highest-scoring option.
        The list is empty when fewer than ``OPTION_NUM`` options exist.
    """
    samples = []
    summ = example["summ"]
    sents = filter_low_r2(example["doc"], summ, language)

    doc_len = min(len(sents), MAX_DOC)
    opt_len = min(len(summ.split("\n")), doc_len, MAX_SUM)
    cand_comb_idx = list(combinations(range(doc_len), opt_len))
    if len(cand_comb_idx) < OPTION_NUM:
        return []

    # Upper bound on distinct questions: unordered choices of OPTION_NUM options.
    max_combs = comb(len(cand_comb_idx), OPTION_NUM)
    # Accepted questions, stored as unordered sets. random.sample returns the
    # options in selection order, so comparing the raw lists would let the
    # same four options through again in a different order.
    seen = set()

    cnt = 0  # number of draws that were actually scored
    while len(samples) < min(numq, max_combs):
        # The check runs before the draw, so up to MAX_TRY + 1 draws are scored.
        if cnt > MAX_TRY:
            # ``logger`` is only bound when this file runs as a script; fall
            # back to a standard logger so importing the module stays safe.
            log = globals().get("logger")
            if log is None:
                log = logging.getLogger(__name__)
            log.info("Limited by MAX_TRY=%d. The question number is %d", MAX_TRY, len(samples))
            break

        current = random.sample(cand_comb_idx, OPTION_NUM)
        key = frozenset(current)
        if key in seen:
            # Duplicate draws are retried without consuming the try budget.
            continue

        candidate, score = [], []
        for opt in current:
            # combinations() already yields each index tuple in ascending
            # order, so the sentences keep their document order.
            cand = "\n".join(sents[i] for i in opt)
            candidate.append(cand)
            score.append(rouge_1_recall(hyps=cand, refer=summ, language=language))

        # TODO: some filters
        if max(score) > 0.1:
            seen.add(key)
            samples.append({
                "candidate": candidate,
                "idx": current,
                "score": score,
                "label": score.index(max(score)),
            })
        else:
            # A rejected draw shrinks the number of questions still reachable.
            max_combs -= 1
        cnt += 1

    return samples


def saveOptions(name, examples, samples):
    """Write the generated questions as parallel, line-aligned text files.

    For every question one line is appended to each of these files:
      * ``name.input0`` -- the example's "doc" sentences joined by " </s> ";
      * ``name.label``  -- the question's "label" value;
      * ``name.input1`` .. ``name.input4`` -- one candidate per file, with the
        newlines inside the candidate replaced by " </s> ".

    Args:
        name: output path prefix (the suffixes above are appended to it).
        examples: list of dicts holding a "doc" list of sentences.
        samples: list parallel to ``examples``; each item is the list of
            question dicts (keys "candidate" and "label") for that example.
            Examples with an empty list produce no output lines.

    Raises:
        IndexError: if a question carries more than four candidates.
    """
    # ExitStack closes every file that was opened, even when a later open()
    # or a write fails part-way through.
    with contextlib.ExitStack() as stack:
        outf_context = stack.enter_context(open(name + ".input0", "w"))
        outf_label = stack.enter_context(open(name + ".label", "w"))
        qa_files = [
            stack.enter_context(open(name + ".input" + str(i + 1), "w"))
            for i in range(4)
        ]

        for ex, sample in zip(examples, samples):
            if not sample:
                continue
            # The context line is identical for all questions of one example.
            context_line = " </s> ".join(ex["doc"]) + "\n"
            for ques in sample:
                outf_context.write(context_line)
                outf_label.write(str(ques["label"]) + '\n')
                for j, cand in enumerate(ques["candidate"]):
                    qa_files[j].write(cand.replace("\n", " </s> ") + '\n')



if __name__ == "__main__":
    # Script entry point: for every requested set type, read the paired
    # .doc/.sum files, generate questions per example, dump them as JSON
    # lines and finally write the line-aligned training files.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', type=str, default="input", help='input')
    parser.add_argument('-o', type=str, default="output", help='output dir')
    parser.add_argument('-t', type=str, default="dev", help='set type')
    parser.add_argument('-l', type=str, default="en", help='language')
    args = parser.parse_args()

    # Module-level name: getOption logs through ``logger`` as well.
    logger = Log.getLogger(os.path.basename(sys.argv[0]))
    logger.info(args)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.o, exist_ok=True)

    # -t accepts a comma-separated list of set types.
    for set_type in args.t.split(','):
        input_prefix = os.path.join(args.i, "{}.{}".format(set_type, args.l))
        examples = readPair(input_prefix)
        output_prefix = os.path.join(args.o, set_type)

        sample_list = []
        # The context manager closes the JSONL file even if generation fails.
        with open(os.path.join(args.o, "{}.samples.jsonl".format(set_type)), "w") as sample_file:
            for ex in examples:
                # The "dev" set gets one question per example; every other
                # set uses getOption's default question count.
                if set_type == "dev":
                    samples = getOption(ex, args.l, numq=1)
                else:
                    samples = getOption(ex, args.l)
                sample_list.append(samples)
                sample_file.write(json.dumps(samples, ensure_ascii=False) + "\n")
                # Flush per example so partial results survive an interruption.
                sample_file.flush()

                if not len(sample_list) % 100:
                    logger.info("Processd %d example in %s", len(sample_list), set_type)

        logger.info("Save to %s", output_prefix)
        saveOptions(name=output_prefix, examples=examples, samples=sample_list)
        logger.info("Finish %s", input_prefix)
        
        

    



