import random
import os
import numpy as np
import torch
import pickle

from tqdm import tqdm
from rouge_score import rouge_scorer



def seed_everything(seed=42):
    """Seed the random number generators used by this project.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (through ``torch.manual_seed`` and ``torch.cuda.manual_seed``), stores
    the seed in the ``PYTHONHASHSEED`` environment variable, and switches
    cuDNN to deterministic mode.

    Args:
        seed: Seed value handed to every generator. Defaults to 42.
    """
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
    # assignment is only picked up by processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True


def check_data_pipe(loaders):
    """Print a preview of the first batch produced by each loader.

    For every loader, only its first batch is inspected: a separator line is
    printed, followed by the batch's language field and the first 10 columns
    of its ``input_ids``, first for the text and then for the summary.
    A loader that yields no batches prints nothing.

    Args:
        loaders: Iterable of iterables yielding dict-like batches with the
            keys ``text_lang``, ``text_inputs``, ``summary_lang`` and
            ``summary_inputs`` (the ``*_inputs`` values must contain a
            2-D-sliceable ``input_ids`` entry).
    """
    preview_keys = (
        ('text_lang', 'text_inputs'),
        ('summary_lang', 'summary_inputs'),
    )

    for loader in loaders:
        for first_batch in loader:
            print("*"*50)
            for lang_key, inputs_key in preview_keys:
                print(first_batch[lang_key])
                print(first_batch[inputs_key]["input_ids"][:,:10])
            # Only the first batch is needed for the sanity check.
            break


def display_losses(mode, losses):
    """Print the latest loss alongside the best (lowest) loss so far.

    Args:
        mode: Label inserted into the message (e.g. the name of the phase
            the losses belong to).
        losses: Non-empty sequence of loss values, one per iteration.
    """
    history = np.array(losses)
    lowest = history.min()
    # Iterations are reported 1-based.
    lowest_iter = history.argmin() + 1
    message = "Current {} loss is {:.4f}, best {} loss is {:.4f} achieved at iter {} / {}".format(
        mode, losses[-1], mode, lowest, lowest_iter, len(losses)
    )
    print(message)


def display_scores(mode, scores):
    """Print the best value recorded for every tracked metric.

    Metrics whose name contains ``"loss"`` are treated as lower-is-better;
    every other metric is treated as higher-is-better.

    Args:
        mode: Label inserted into each message.
        scores: Mapping from metric name to a non-empty sequence of values,
            one per iteration.
    """
    for metric in scores.keys():
        history = scores[metric]
        values = np.array(history)

        if "loss" in metric:
            pick_best, locate_best = np.min, np.argmin
        else:
            pick_best, locate_best = np.max, np.argmax

        best_value = pick_best(values)
        # Iterations are reported 1-based.
        best_iter = locate_best(values) + 1
        print("Best {} {} is {:.4f} achieved at iter {} / {}".format(mode, metric, best_value, best_iter, len(history)))


def compute_r1s(sents):
    """Compute a leave-one-out ROUGE-1 score for every sentence.

    Each sentence is scored against the space-joined concatenation of all
    the *other* sentences; the ROUGE-1 F-measure is scaled to 0-100.

    Args:
        sents: Sequence of sentence strings. It is sliced and concatenated
            with ``+``, so it should be a list (or tuple) of ``str``.

    Returns:
        A NumPy array with one score per sentence, in input order.
    """
    r1_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=False)

    def _leave_one_out_r1(position, sentence):
        # Everything except the sentence at `position` forms the first
        # argument to `score`; the sentence itself is the second.
        remainder = " ".join(sents[:position] + sents[(position + 1):])
        return 100 * r1_scorer.score(remainder, sentence)["rouge1"].fmeasure

    return np.array([_leave_one_out_r1(pos, sent) for pos, sent in enumerate(sents)])


def check_scores(dataset):
    """Return the mean of the per-entry maximum scores in a dataset.

    For every entry of ``dataset.scored_summaries`` the element at index 1
    is taken as that entry's scores and its maximum is recorded; the
    function returns the mean of those maxima. Progress is shown with tqdm.

    Args:
        dataset: Object exposing an indexable, sized ``scored_summaries``
            attribute whose entries hold their scores at index 1.

    Returns:
        The mean of the per-entry maxima, as computed by ``np.mean``.
    """
    per_entry_best = [
        np.max(np.array(dataset.scored_summaries[pos][1]))
        for pos in tqdm(range(len(dataset.scored_summaries)))
    ]
    return np.mean(per_entry_best)


def check_training_data(args, num_examples=3):
    """Print a few scored candidates from every configured data split.

    For the training sets, the validation set and the test set, and for
    every generation method / scoring method combination, the matching
    pickle file under ``args.scored_summaries_path`` is loaded and the first
    candidates of its first entry are printed next to their scores.

    Args:
        args: Namespace-like object providing ``train_datasets``,
            ``train_data_sizes``, ``train_model_names``, ``val_dataset``,
            ``val_data_size``, ``test_dataset``, ``test_data_size``,
            ``model_name``, ``generation_methods``, ``scoring_methods``,
            ``scored_summaries_path`` and ``num_beams``. The path is joined
            by plain string concatenation, so it must already end with a
            path separator.
        num_examples: Maximum number of candidates printed per file.
            Fewer are printed when the first entry holds fewer candidates.
            Defaults to 3.

    Raises:
        FileNotFoundError: If an expected scored-summaries file is missing
            (propagated from ``open``).
        IndexError: If the size or model lists are shorter than the
            corresponding list of dataset names.
    """
    print("\n", "*"*70, "Checking the data")
    splits = [
        (args.train_datasets, args.train_data_sizes, args.train_model_names),
        ([args.val_dataset], [args.val_data_size], [args.model_name]),
        ([args.test_dataset], [args.test_data_size], [args.model_name]),
    ]
    for set_names, set_sizes, model_names in splits:
        for set_idx, set_name in enumerate(set_names):
            # Index (rather than zip) so mismatched list lengths still fail
            # loudly instead of being silently truncated.
            set_size = set_sizes[set_idx]
            model_name = model_names[set_idx]
            print("\n", "*"*30, "Set: {}".format(set_name))
            for generation_method in args.generation_methods:
                print("\n", "*"*10, "Generation method: {}".format(generation_method))
                for scoring_method in args.scoring_methods:
                    scored_summaries_path = args.scored_summaries_path + "{}/{}/{}/{}_scored_summaries_{}_{}_beams_{}.pkl".format(
                        generation_method, scoring_method, set_name, set_name, model_name, set_size, args.num_beams
                    )
                    # NOTE(security): pickle.load can execute arbitrary code;
                    # only point this at files produced by trusted code.
                    with open(scored_summaries_path, "rb") as f:
                        scored_summaries = pickle.load(f)
                    print("\nScoring metric: {}".format(scoring_method))
                    if len(scored_summaries) == 0:
                        print("No scored summaries in {}".format(scored_summaries_path))
                        continue
                    # Presumably each entry is (candidates, scores) - only
                    # indices 0 and 1 of the first entry are read here.
                    candidates = scored_summaries[0][0]
                    candidate_scores = scored_summaries[0][1]
                    # Never index past what the file actually contains.
                    n_shown = min(num_examples, len(candidates), len(candidate_scores))
                    for rank in range(n_shown):
                        print(">"*10, candidates[rank], candidate_scores[rank])






