import numpy as np
import json
import argparse
from sklearn.model_selection import train_test_split
from utils.dataset import GEO, KNN_LM, Text2Data
from utils.model import FC, Embed
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm
from scipy import stats
import os
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.nist_score import sentence_nist
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
import pandas as pd
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForMaskedLM
import random
from sklearn.model_selection import KFold
import time
import re

os.environ["TOKENIZERS_PARALLELISM"] = "false" 
os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6'

def quantile_normalization(gene_tensor, num_gene=20000, num_stats=5):
    """Quantile-normalize each statistic channel of ``gene_tensor``.

    For every channel ``s`` in ``range(num_stats)`` the 2-D slice
    ``gene_tensor[:, :, s]`` (one row per paper) is handled like this:

    1. every row is sorted in ascending order,
    2. the sorted rows are averaged position-wise, giving one reference
       value per rank,
    3. every original value is replaced by the reference value that
       belongs to its rank inside its own row.

    Args:
        gene_tensor: array of shape [#paper, #gene, #stats].
        num_gene: kept only for backward compatibility. The gene count is
            read from ``gene_tensor.shape[1]``, so tensors whose gene axis
            is not 20000 no longer fail.
        num_stats: number of leading statistic channels to normalize.
            Channels at index ``num_stats`` and above are left as zeros.

    Returns:
        A new array with the shape and dtype of ``gene_tensor``; the input
        is not modified.
    """
    processed_gene_tensor = np.zeros_like(gene_tensor)
    for stat_idx in range(num_stats):
        data = gene_tensor[:, :, stat_idx]  # [#paper, #gene]
        sort_index = np.argsort(data)  # ascending, per paper
        # Reference distribution: mean of the k-th smallest value over papers.
        avg = np.mean(np.sort(data), axis=0)  # [#gene]
        # sort_index is a permutation of every row, so arg-sorting it yields
        # the inverse permutation, i.e. the rank of each original entry.
        ranks = np.argsort(sort_index)  # [#paper, #gene]
        processed_gene_tensor[:, :, stat_idx] = avg[ranks]
    return processed_gene_tensor

def load_data(args):
    """Load the imputed gene tensor and flatten every paper into one vector.

    The tensor is read from ``args.data_dir + args.platform +
    '_imputate.npy'`` and passed through ``quantile_normalization``. Each
    paper's matrix is then transposed and flattened, so all values of the
    first statistic come first, then all values of the second, and so on.

    Args:
        args: namespace providing ``data_dir`` and ``platform``.

    Returns:
        Array with one flattened row of Python floats per paper
        ([#paper, 100000] for a [#paper, 20000, 5] input, as the original
        comments describe the data).
    """
    tensor_path = args.data_dir + args.platform + '_imputate.npy'
    normalized = quantile_normalization(np.load(tensor_path))
    flat_length = normalized.shape[1] * normalized.shape[2]

    rows = []
    for paper_idx, matrix in enumerate(normalized):
        # Progress report: number of papers flattened so far.
        if paper_idx % 100 == 0:
            print(len(rows))
        flattened = matrix.T.reshape(flat_length)
        rows.append([float(value) for value in flattened])
    return np.array(rows)

# Pair every sample's text with the texts of its nearest training neighbours
def post_process(data, train_vec, args, k=3):
    """Attach the texts of the closest training samples to every sample.

    Samples are indexable as ``sample[0]`` (vector) and ``sample[1]``
    (text). Euclidean distances between every vector in ``data`` and every
    vector in ``train_vec`` decide the neighbour order.

    Args:
        data: samples that need neighbour texts.
        train_vec: samples that may serve as neighbours.
        args: unused; kept so existing callers keep working.
        k: number of neighbour texts collected per sample.

    Returns:
        A list with one ``(text, [neighbour_text, ...])`` tuple per entry
        of ``data``, neighbours ordered from closer to farther.

    Raises:
        IndexError: if ``train_vec`` holds fewer than ``k + 1`` samples.
    """
    train_vectors = [entry[0] for entry in train_vec]
    query_vectors = [entry[0] for entry in data]
    distances = euclidean_distances(query_vectors, train_vectors)
    neighbour_order = np.argsort(distances)  # closest training sample first

    paired = []
    for order, sample in zip(neighbour_order, data):
        # Rank 0 is skipped -- presumably because a training sample's closest
        # match is itself. NOTE(review): for samples that are not part of
        # train_vec this also drops their true nearest neighbour; confirm.
        neighbour_texts = [train_vec[order[rank]][1] for rank in range(1, k + 1)]
        paired.append((sample[1], neighbour_texts))
    return paired

def create_text_vec(gene_tensor, args):
    """Pair every gene vector with its title, summary and a text embedding.

    Titles and summaries are read from ``args.data_dir + args.platform +
    '_text.json'`` (keys ``'title'`` and ``'summary'``) and are matched to
    the rows of ``gene_tensor`` by position. For every row:

    * a summary longer than 64 space-separated words is cut to the text
      before its first ``'.'``,
    * bracketed spans ``(...)``, ``{...}`` and ``[...]`` are removed from
      summary and title,
    * rows whose summary is the SuperSeries placeholder sentence are
      skipped,
    * ``title + sep_token + summary`` is encoded and the hidden state of
      the first token is kept as the embedding.

    Args:
        gene_tensor: 2-D array, one row per paper.
        args: namespace providing ``device``, ``model`` (a key of the model
            table below), ``data_dir``, ``platform`` and ``max_length``.

    Returns:
        List of ``(vec, title, summary, embedding)`` tuples, where
        ``embedding`` is a NumPy array.

    Raises:
        KeyError: if ``args.model`` is not a known model key.
    """
    model_dict = {
        'specter': 'allenai/specter',
        'scibert':'allenai/scibert_scivocab_cased',
        'sentbert':'deepset/sentence_bert',
        'biobert':'dmis-lab/biobert-base-cased-v1.2',
        'bert': 'bert-base-cased',
    }
    model_name = model_dict[args.model]
    device_0 = torch.device(args.device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    encoder = AutoModel.from_pretrained(model_name)
    encoder.to(device_0)

    # Loading the titles and summaries of the papers
    with open(args.data_dir + args.platform + '_text.json', 'r') as f:
        text_dict = json.load(f)
    summary_list = text_dict['summary']
    title_list = text_dict['title']

    # Compiled once instead of on every iteration.
    bracket_pattern = re.compile(u"\\(.*?\\)|\\{.*?}|\\[.*?]")
    placeholder = 'This SuperSeries is composed of the SubSeries listed below.'

    data = []
    # The encoder is only used for feature extraction, so no autograd graph
    # is needed; building one per sample only wastes memory.
    with torch.no_grad():
        for i in range(gene_tensor.shape[0]):
            if i % 100 == 0:
                print(len(data))
            vec = gene_tensor[i, :]
            summary = summary_list[i]
            if len(summary.split(' ')) > 64:
                summary = summary.split('.')[0]
            summary = bracket_pattern.sub("", summary)
            title = bracket_pattern.sub("", title_list[i])
            if summary == placeholder:
                continue
            text = title + tokenizer.sep_token + summary
            inputs = tokenizer([text], padding='max_length', truncation=True,
                               return_tensors="pt", max_length=args.max_length)
            result = encoder(inputs['input_ids'].to(device_0),
                             attention_mask=inputs['attention_mask'].to(device_0))
            # Hidden state of the first token, batch dimension removed.
            embeddings = torch.squeeze(result.last_hidden_state[:, 0, :], 0)
            data.append((vec, title, summary, embeddings.cpu().numpy()))
    print(len(data))
    return data

def create_data(args):
    """Build the raw dataset and store it on disk.

    Runs ``load_data`` and ``create_text_vec`` and saves the resulting
    samples as an object array to ``args.data_dir + args.platform +
    '_raw_data.npy'``.

    Args:
        args: namespace forwarded to ``load_data`` and ``create_text_vec``;
            ``data_dir`` and ``platform`` also form the output path.
    """
    print('Creating the data')
    gene_vectors = load_data(args)
    samples = np.array(create_text_vec(gene_vectors, args), dtype=object)
    output_path = args.data_dir + args.platform + '_raw_data.npy'
    np.save(output_path, samples)

def _keep_by_median_bleu(raw_data, threshold):
    """Keep samples whose median unigram BLEU against all others exceeds ``threshold``.

    Args:
        raw_data: indexable collection of ``(vec, title, summary,
            embedding)`` samples.
        threshold: a sample is kept when the median of its BLEU-1 scores,
            taken with every other summary as the single reference, is
            strictly greater than this value.

    Returns:
        Object array holding the kept ``(vec, title, summary, embedding)``
        tuples.
    """
    # Split every summary once instead of once per pair.
    tokenized = [sample[2].split(' ') for sample in raw_data]
    kept = []
    for i, hypothesis in enumerate(tokenized):
        if i % 20 == 0:
            print(i, len(kept))
        bleu_list = np.array([
            sentence_bleu([reference], hypothesis, weights=(1,0,0,0))
            for j, reference in enumerate(tokenized)
            if j != i
        ])
        if np.median(bleu_list) > threshold:
            sample = raw_data[i]
            kept.append((sample[0], sample[1], sample[2], sample[3]))
    return np.array(kept, dtype=object)


def bleu_filter(args):
    """Filter the raw dataset twice by median pairwise BLEU-1.

    Reads ``<data_dir><platform>_raw_data.npy``, keeps the samples whose
    median BLEU-1 against all other summaries is above 0.09 and writes
    them to ``<data_dir><platform>_bleu.npy``. That file is then filtered
    again with threshold 0.13 (scores are recomputed among the survivors)
    and the result is written to ``<data_dir><platform>_bleu_0.13.npy``.

    Args:
        args: namespace providing ``data_dir`` and ``platform``.
    """
    prefix = args.data_dir + args.platform

    raw_data = np.load(prefix + '_raw_data.npy', allow_pickle=True)
    np.save(prefix + '_bleu.npy', _keep_by_median_bleu(raw_data, 0.09))

    raw_data = np.load(prefix + '_bleu.npy', allow_pickle=True)
    print(len(raw_data))
    np.save(prefix + '_bleu_0.13.npy', _keep_by_median_bleu(raw_data, 0.13))

def _average_source_embeddings(model, tokenizer, source, num_sources, max_length, device):
    """Return the mean input-embedding tensor of ``num_sources`` neighbour texts.

    ``source[idx]`` is expected to be an iterable of strings (one per batch
    element). Each group is tokenized to ``max_length`` tokens, looked up
    in the model's input embedding layer, and the lookups are averaged.
    """
    embed = model.get_input_embeddings()
    summed = 0
    for idx in range(num_sources):
        source_token = tokenizer.batch_encode_plus(
            list(source[idx]), max_length=max_length, add_special_tokens=True,
            padding='max_length', truncation=True, return_tensors='pt')
        summed = summed + embed(source_token['input_ids'].to(device))
    return summed / num_sources


def _score_prediction(scorer, prediction, reference):
    """Score one generated text against its reference; returns a dict of metrics."""
    pred = prediction.split(' ')
    gt = reference.split(' ')
    rouge = scorer.score(reference, prediction)
    return {
        'bleu_1': sentence_bleu([gt], pred, weights=(1,0,0,0)),
        'bleu_2': sentence_bleu([gt], pred, weights=(0,1,0,0)),
        'bleu_3': sentence_bleu([gt], pred, weights=(0,0,1,0)),
        'bleu_4': sentence_bleu([gt], pred, weights=(0,0,0,1)),
        'rouge_1': rouge['rouge1'].recall,
        'rouge_2': rouge['rouge2'].recall,
        'rouge_L': rouge['rougeL'].recall,
        'meteor': meteor_score([reference], prediction),
        'nist': sentence_nist([gt], pred),
    }


def data2text(args):
    """Fine-tune T5 to generate summaries from nearest-neighbour texts (5-fold CV).

    Samples ``(vec, summary)`` are read from
    ``<data_dir><platform>_bleu_0.13.npy``. In every fold, ``post_process``
    pairs each summary with the texts of its K nearest training samples.
    The averaged input embeddings of those K texts are fed to a fresh
    ``t5-base`` model, which is trained for a fixed number of epochs with
    the summary as target. A checkpoint is written to ``model/t5_<epoch>``
    after every epoch and each checkpoint is evaluated on the fold's test
    split. For every epoch the fold-averaged BLEU-1..4, ROUGE-1/2/L recall,
    METEOR and NIST scores are printed at the end.

    Changes compared with the previous implementation: the training loss is
    accumulated as a Python float (the epoch average prints as a plain
    number instead of a tensor), the encoder forward pass whose result was
    never used has been dropped from the training step, and the ``model``
    directory is created when missing.

    Args:
        args: namespace providing ``device``, ``data_dir``, ``platform``
            and ``max_length``.
    """
    device_0 = torch.device(args.device)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    raw_data = np.load(args.data_dir + args.platform + '_bleu_0.13.npy', allow_pickle=True)
    raw_data = np.array([(sample[0], sample[2]) for sample in raw_data], dtype=object)
    print(len(raw_data))
    tokenizer = AutoTokenizer.from_pretrained("t5-base")

    K = 4                   # neighbour texts averaged per sample
    n_splits = 5
    epoch = 10
    generation_max_length = 64
    metric_names = ('bleu_1', 'bleu_2', 'bleu_3', 'bleu_4',
                    'rouge_1', 'rouge_2', 'rouge_L', 'meteor', 'nist')
    # metric_sums[name][e]: sum over folds of the test-set mean after epoch e.
    metric_sums = {name: [0] * epoch for name in metric_names}

    # torch.save does not create missing directories.
    os.makedirs('model', exist_ok=True)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for train_index, test_index in kf.split(raw_data):
        print('Fold Starting')
        train_vec = raw_data[train_index]
        # Neighbours are searched among the training samples of this fold only.
        data = np.array(post_process(raw_data, train_vec, args, k=K), dtype=object)
        print('Loading the model')
        # NOTE(review): from_pretrained returns the model in eval mode and
        # train() is never called, so dropout stays off while fine-tuning --
        # confirm this is intended.
        model = AutoModelWithLMHead.from_pretrained("t5-base")
        model.to(device_0)
        optimizer = torch.optim.Adam(model.parameters())
        train_set = KNN_LM(tokenizer, data[train_index])
        test_set = KNN_LM(tokenizer, data[test_index])
        train_loader = DataLoader(train_set, batch_size=16, drop_last=True, shuffle=True, num_workers=1)
        test_loader = DataLoader(test_set, batch_size=16, drop_last=False, shuffle=True, num_workers=1)

        print('Training')
        for i in range(epoch):
            total_loss = 0.0
            for batch in tqdm(train_loader):
                source = batch['source']
                target = batch['target']
                input_emb = _average_source_embeddings(
                    model, tokenizer, source, K, args.max_length, device_0)
                target_token = tokenizer.batch_encode_plus(
                    target, max_length=args.max_length, add_special_tokens=True,
                    padding='max_length', truncation=True, return_tensors='pt')
                y_ids = target_token['input_ids']
                # Padding positions must not contribute to the loss.
                lm_labels = y_ids.clone()
                lm_labels[y_ids == tokenizer.pad_token_id] = -100
                optimizer.zero_grad()
                outputs = model(
                    inputs_embeds=input_emb,
                    decoder_attention_mask=target_token['attention_mask'].to(device_0),
                    labels=lm_labels.to(device_0)
                )
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                # .item() keeps the running sum detached from the autograd graph.
                total_loss += loss.item()
            print(total_loss / len(train_loader))
            torch.save(model, 'model/t5_{0}'.format(i))

        with torch.no_grad():
            print('Short Inferencing')
            for i in range(epoch):
                totals = dict.fromkeys(metric_names, 0)
                count = 0
                model = torch.load('model/t5_{0}'.format(i))
                model.to(device_0)
                for batch in tqdm(test_loader):
                    source = batch['source']
                    target = batch['target']
                    input_emb = _average_source_embeddings(
                        model, tokenizer, source, K, args.max_length, device_0)
                    encoder_outputs = model.get_encoder()(inputs_embeds=input_emb)
                    generated_ids = model.generate(
                        encoder_outputs=encoder_outputs,
                        max_length=generation_max_length,
                        num_beams=2,
                        repetition_penalty=2.5,
                        length_penalty=1.0,
                        early_stopping=True
                    )
                    preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
                    for pred_text, target_text in zip(preds, target):
                        scores = _score_prediction(scorer, pred_text, target_text)
                        for name in metric_names:
                            totals[name] += scores[name]
                        count += 1
                for name in metric_names:
                    metric_sums[name][i] += totals[name] / count
                torch.cuda.empty_cache()

    for name in metric_names:
        print(name + ': ', np.array(metric_sums[name]) / n_splits)

def text2data(args):
    """Train a network that predicts the gene vector from the text embedding (5-fold CV).

    Samples ``(vec, embedding)`` are read from
    ``<data_dir><platform>_bleu_0.13.npy``. In every fold a fresh
    ``FC(768, 100000)`` model is trained for 20 epochs with the square root
    of the batch MSE as loss. A checkpoint is written to
    ``model/text2data_<epoch>`` after every epoch and each checkpoint is
    evaluated on the fold's test split (mean of the per-batch RMSE). The
    smallest fold-averaged test RMSE over all epochs is printed at the end.

    Changes compared with the previous implementation: the device is
    derived from ``args.device`` inside the function (it used to rely on a
    global that only exists when the file runs as a script), the training
    loss is accumulated as a Python float, and the ``model`` directory is
    created when missing.

    Args:
        args: namespace providing ``device``, ``data_dir``, ``platform``
            and ``model`` (the latter is only echoed in the final print).
    """
    device_0 = torch.device(args.device)
    raw_data = np.load(args.data_dir + args.platform + '_bleu_0.13.npy', allow_pickle=True)
    raw_data = np.array([(sample[0], sample[3]) for sample in raw_data], dtype=object)

    n_splits = 5
    epoch = 20
    # rmse_sums[e]: sum over folds of the mean test RMSE after epoch e.
    rmse_sums = [0] * epoch

    # torch.save does not create missing directories.
    os.makedirs('model', exist_ok=True)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    for train_index, test_index in kf.split(raw_data):
        model = FC(768,100000)
        model.to(device_0)
        criteria = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        train_set = Text2Data(raw_data[train_index])
        test_set = Text2Data(raw_data[test_index])
        train_loader = DataLoader(train_set, batch_size=16, drop_last=True, shuffle=True, num_workers=1)
        test_loader = DataLoader(test_set, batch_size=16, drop_last=False, shuffle=True, num_workers=1)

        for i in range(epoch):
            total_loss = 0.0
            for batch in tqdm(train_loader):
                embed = batch['emb'].to(device_0)
                label = batch['vec'].float().to(device_0)
                output = model(embed)
                loss = torch.sqrt(criteria(output, label))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # .item() keeps the running sum detached from the autograd graph.
                total_loss += loss.item()
            avg_loss = total_loss / len(train_loader)
            print('Epoch: ' + str(i) + ' with avg_loss ' + str(avg_loss))
            torch.save(model, 'model/text2data_{0}'.format(i))

        with torch.no_grad():
            for i in range(epoch):
                total_rmse = 0
                count = 0
                model = torch.load('model/text2data_{0}'.format(i))
                model.to(device_0)
                for batch in tqdm(test_loader):
                    embed = batch['emb'].to(device_0)
                    label = batch['vec'].float().to(device_0)
                    output = model(embed)
                    loss = torch.sqrt(criteria(output, label))
                    total_rmse += loss.item()
                    count += 1
                print(total_rmse/count)
                rmse_sums[i] += total_rmse/count

    rmse_sums = np.array(rmse_sums)
    print(args.platform, min(rmse_sums / n_splits), args.model)

def text2data_baseline(args):
    """Mean-vector baseline for the text-to-data task (5-fold CV).

    Samples ``(vec, embedding)`` are read from
    ``<data_dir><platform>_bleu_0.13.npy``. In every fold the prediction
    for each test sample is the element-wise mean of the training gene
    vectors; the RMSE between that mean vector and each test vector is
    averaged over the test split and printed. The fold-averaged RMSE is
    printed at the end.

    Bug fix: the previous code computed ``train_vec[j] - label[j]`` with
    ``j`` being the position inside the batch, so each test vector was
    compared against a single element of the mean vector instead of the
    whole mean vector.

    Args:
        args: namespace providing ``data_dir`` and ``platform``.
    """
    raw_data = np.load(args.data_dir + args.platform + '_bleu_0.13.npy', allow_pickle=True)
    print(len(raw_data))
    raw_data = np.array([(sample[0], sample[3]) for sample in raw_data], dtype=object)

    n_splits = 5
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    corr_list = [0]
    for train_index, test_index in kf.split(raw_data):
        train = raw_data[train_index]
        test = raw_data[test_index]
        # Baseline prediction: element-wise mean of the training gene vectors.
        mean_vec = np.mean(np.array([sample[0] for sample in train]), axis=0)
        test_loader = DataLoader(Text2Data(test), batch_size=16, drop_last=False, shuffle=True, num_workers=1)

        total_rmse = 0
        count = 0
        for batch in tqdm(test_loader):
            # assumes batch['vec'] has one row per sample with the same
            # length as the stored gene vectors -- TODO confirm in Text2Data
            label = batch['vec'].float().cpu().numpy()
            for j in range(label.shape[0]):
                total_rmse += np.sqrt(np.mean((mean_vec - label[j])**2))
                count += 1
        print(total_rmse/count)
        corr_list[0] += total_rmse/count

    corr_list = np.array(corr_list)
    print(corr_list / n_splits)

if __name__ == '__main__':
    torch.multiprocessing.set_sharing_strategy('file_system')

    # Seed every random number generator used by the script.
    seed = 42
    for set_seed in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
        np.random.seed,
    ):
        set_seed(seed)

    # Command-line options, declared as (flag, keyword arguments) pairs.
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--data_dir", dict(type=str, default='')),
        ("--device", dict(type=str, default='cuda:1', help='device')),
        ("--emb_dim", dict(type=int, default=768, help='specter dim')),
        ("--max_length", dict(type=int, default=64, help='max_length')),
        ("--platform", dict(type=str, default='GPL570', help='platform')),
        ("--model", dict(type=str, default='specter', help='model of text2data')),
        ("--batch_size", dict(type=int, default=16)),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Module-level device; text2data reads this global.
    device_0 = torch.device(args.device)

    # Pipeline stages -- uncomment the one(s) to run.
    # (correlation is not defined in the part of the file shown here.)
    # create_data(args)
    # bleu_filter(args)
    # correlation(args)
    # data2text(args)
    # text2data(args)