import json
import torch
import os
import random
import numpy as np

# Upper bound on how many sentences a randomly sampled candidate may contain
# (applied in get_candidate).
max_sent_number = 8
# Common path prefix of the input files; the script appends '.doc' and '.sum'.
prefix = "/home/tiger/nlpcc/MSPM/train.spm"
# Delimiter on which each document line is split into sentences.
sent_split = '</s>'
# Destination file: one JSON object per line.
output = "/opt/tiger/sumtest/candidate.jsonl"

def n_grams(tokens, n):
    """Return every contiguous n-gram of ``tokens`` as a list of tuples.

    Args:
        tokens: Sequence of tokens (sliceable, e.g. a list of strings).
        n: Size of each n-gram; expected to be a positive integer.

    Returns:
        A list of ``len(tokens) - n + 1`` tuples in left-to-right order,
        or an empty list when ``tokens`` holds fewer than ``n`` items.
    """
    # The last valid start index is len(tokens) - n, so the final n-gram
    # (the one ending on the last token) is included.
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def cal_overlap_fscore(candidate_tokens: list, summary_tokens: list):
    """Average the unigram and bigram overlap F1 between two token lists.

    For each n in {1, 2}, every candidate n-gram occurrence that also appears
    somewhere in the summary counts as an overlap (occurrences are not
    clipped by the summary's own counts). Recall and precision divide that
    count by the number of summary and candidate n-grams respectively; a
    small epsilon keeps the divisions defined for empty inputs.

    Args:
        candidate_tokens: Tokens of the candidate text (must be hashable,
            e.g. strings).
        summary_tokens: Tokens of the reference summary (must be hashable).

    Returns:
        The mean of the unigram F1 and the bigram F1, as a float.
    """
    f1_scores = []
    for n in range(1, 3):  # n = 1 (unigrams) and n = 2 (bigrams)
        summary_ngram = n_grams(summary_tokens, n)
        candidate_ngram = n_grams(candidate_tokens, n)
        # Build the lookup once so each membership test is O(1) instead of
        # scanning the whole summary n-gram list per candidate n-gram.
        summary_lookup = set(summary_ngram)
        overlap_count = sum(1 for gram in candidate_ngram if gram in summary_lookup)
        recall = overlap_count / (len(summary_ngram) + 1e-6)
        precision = overlap_count / (len(candidate_ngram) + 1e-6)
        f1 = 2 * recall * precision / (recall + precision + 1e-6)
        f1_scores.append(f1)
    return sum(f1_scores) / float(len(f1_scores))

def get_candidate(document_sentences: list, number: int):
    """Draw ``number + 1`` random sentence subsets from a document.

    Each subset has a random size between 1 and the document length, capped
    at the module-level ``max_sent_number``, and keeps the sentences in their
    original document order.

    Args:
        document_sentences: Sentences of one document.
        number: How many negatives are wanted; one extra candidate is drawn
            so that one of them can serve as the positive.

    Returns:
        A list of dicts, each of the form ``{'sentence': [sentences...]}``.
    """
    total = len(document_sentences)
    sampled = []
    for _ in range(number + 1):  # the extra draw is for the positive candidate
        size = random.randrange(1, total + 1)
        if size > max_sent_number:
            size = max_sent_number
        picked = random.sample(list(range(total)), size)
        picked.sort()  # restore document order
        sampled.append({'sentence': [document_sentences[j] for j in picked]})
    return sampled

def get_oracle(info: dict):
    """Greedily select the sentences that best match the reference summary.

    Sentences are ranked by their individual overlap score against the
    reference and then added one at a time in that order. Selection stops at
    the first sentence that fails to raise the combined score; that sentence
    is NOT kept, so the returned sentences are exactly the ones that produce
    the returned score.

    Args:
        info: Dict with ``'document_sentences'`` (list of sentence strings)
            and ``'reference'`` (summary string).

    Returns:
        A tuple ``(oracle, max_score)``: the selected sentences in selection
        order and their combined overlap score. ``([], 0)`` when no sentence
        scores above zero or the document has no sentences.
    """
    doc_sents = info['document_sentences']
    # The first and last whitespace tokens of the reference are dropped
    # (presumably language tag / boundary markers -- confirm against the data).
    summary_tokens = info['reference'].split()[1:-1]

    # Score each sentence on its own, then rank best-first.
    sent_scores = [
        (sent, cal_overlap_fscore(sent.split(), summary_tokens))
        for sent in doc_sents
    ]
    sent_scores.sort(key=lambda pair: pair[1], reverse=True)

    max_score = 0
    oracle = []
    for sent, _ in sent_scores:
        # Evaluate the trial set first; commit the sentence only if it helps.
        trial_tokens = " ".join(oracle + [sent]).split()
        score = cal_overlap_fscore(trial_tokens, summary_tokens)
        if score <= max_score:
            break
        oracle.append(sent)
        max_score = score
    return oracle, max_score

def rank_candidate(candidates: list, target: str):
    """Split candidates into the best-scoring one and the remaining ones.

    Every candidate is scored against the reference with
    ``cal_overlap_fscore``. The first candidate with the strictly highest
    score becomes the positive; if no score exceeds 0.0, the candidate at
    index 0 is used.

    Args:
        candidates: List of dicts with a ``'sentence'`` list of strings.
        target: Reference summary string.

    Returns:
        A tuple ``(positive, negatives)`` where ``positive`` is
        ``{'sent': [...], 'score': float}`` and ``negatives`` is a list of
        dicts of the same form, in their original order. Scores are rounded
        to four decimals.
    """
    # Drop the first and last whitespace tokens of the reference
    # (original note: "ignore language and <bos>").
    reference_tokens = target.split()[1:-1]

    scores = []
    best_idx, best_score = 0, 0.0
    for idx, cand in enumerate(candidates):
        tokens = " ".join(cand['sentence']).split()
        current = cal_overlap_fscore(tokens, reference_tokens)
        scores.append(current)
        if current > best_score:
            best_idx, best_score = idx, current

    negatives = [
        {'sent': cand['sentence'], 'score': round(value, 4)}
        for idx, (cand, value) in enumerate(zip(candidates, scores))
        if idx != best_idx
    ]
    positive = {'sent': candidates[best_idx]['sentence'], 'score': round(best_score, 4)}
    return positive, negatives


if __name__ == "__main__":
    # Script entry point: read paired document/summary lines, attach an
    # oracle plus one positive and several negative random candidates to
    # each example, dump everything as JSONL and print the mean scores.
    max_examples = 300   # only the first max_examples pairs are processed
    num_negatives = 2    # negatives kept per example (one extra is sampled as positive)

    infos = []
    # Read inputs explicitly as UTF-8 so they match the UTF-8 output below
    # instead of depending on the platform's default encoding.
    with open(prefix + '.doc', 'r', encoding='utf-8') as fdoc, \
            open(prefix + '.sum', 'r', encoding='utf-8') as fsum:
        for doc, summ in zip(fdoc, fsum):
            infos.append({
                "document_sentences": doc.strip().split(sent_split),
                "reference": summ.strip()
            })
            if len(infos) >= max_examples:
                break

    for info in infos:
        oracle, max_score = get_oracle(info)
        info['oracle'] = {'sent': oracle, 'score': round(max_score, 4)}
        candidates = get_candidate(info['document_sentences'], number=num_negatives)
        positive, negatives = rank_candidate(candidates, info['reference'])
        info['positive'] = positive
        info['negatives'] = negatives

    oracle_scores = []
    positive_scores = []
    # negative_scores[k] collects the score of the k-th best negative of each example.
    negative_scores = [[] for _ in range(num_negatives)]
    with open(output, 'w', encoding='utf-8') as fout:
        for info in infos:
            fout.write(json.dumps(info, ensure_ascii=False)+'\n')
            oracle_scores.append(info['oracle']['score'])
            positive_scores.append(info['positive']['score'])
            ranked = sorted(info['negatives'], key=lambda x: x['score'], reverse=True)
            for rank, item in enumerate(ranked):
                negative_scores[rank].append(item['score'])

    print("oracle: ", np.mean(oracle_scores))
    print("positive: ", np.mean(positive_scores))
    for (i, negative) in enumerate(negative_scores):
        print("negative {}: {}".format(i, np.mean(negative)))