from pathlib import Path
import tqdm
import torch
import ipdb
import pandas as pd
from transformers import AutoTokenizer, AutoModelForMaskedLM
from datasets import load_dataset
import ipdb
import numpy as np

from sentence_transformers import SentenceTransformer

# Sentence-embedding model used by the inference loop further down to encode
# each comment. It is moved to the hard-coded device 'cuda:1', so the script
# needs a CUDA device with index 1 to be available.
sentence_model = SentenceTransformer('paraphrase-mpnet-base-v2').to('cuda:1')

from sklearn.model_selection import train_test_split
# Fixed value passed as `random_state` to every train_test_split call below.
seed = 1

## Reading all the datasets ##
# NOTE(review): these five DataFrames are not referenced anywhere else in the
# visible code; the per-file split loop below re-reads the same CSV files by
# name. Confirm nothing external relies on them before removing.
Psy_df = pd.read_csv("./data/youtube/Youtube01-Psy.csv")
KatyPerry_df = pd.read_csv("./data/youtube/Youtube02-KatyPerry.csv")
LMFAO_df = pd.read_csv("./data/youtube/Youtube03-LMFAO.csv")
Eminem_df = pd.read_csv("./data/youtube/Youtube04-Eminem.csv")
Shakira_df = pd.read_csv("./data/youtube/Youtube05-Shakira.csv")

# Target fractions for the train / validation / test partitions. They are
# applied independently to each CSV file in the split loop below.
train_ratio = 0.8
validation_ratio = 0.1
test_ratio = 0.1

# Maps each CSV file name to its own train / val / test partition:
#   dataset[fname] = {'x_train': [...], 'x_val': [...], 'x_test': [...],
#                     'y_train': [...], 'y_val': [...], 'y_test': [...]}
dataset = {}

for fname in ["Youtube01-Psy.csv", "Youtube02-KatyPerry.csv", "Youtube03-LMFAO.csv", "Youtube04-Eminem.csv", "Youtube05-Shakira.csv"]:
    df = pd.read_csv(Path('./data/youtube') / fname)

    # Pull the two columns out directly instead of walking the frame row by
    # row; .tolist() yields plain Python scalars in the original row order.
    X = df['CONTENT'].tolist()
    y = df['CLASS'].tolist()  # 1 is spam 0 is not spam

    n_spam = y.count(1)
    n_not_spam = y.count(0)
    print(f"spam: {n_spam}, not spam: {n_not_spam}")

    # Step 1: split off everything that is not training data.
    x_train, x_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=1 - train_ratio, random_state=seed)
    # Step 2: divide the held-out part between validation and test according
    # to their relative ratios.
    x_val, x_test, y_val, y_test = train_test_split(
        x_holdout, y_holdout,
        test_size=test_ratio/(test_ratio + validation_ratio),
        random_state=seed)

    dataset[fname] = {'x_train': x_train,
                      'x_val': x_val,
                      'x_test': x_test,
                      'y_train': y_train,
                      'y_val': y_val,
                      'y_test': y_test,
                     }

# aggregate
# Per-split accumulators that pool the partitions of every CSV file.
train_x = []
train_y = []
val_x = []
val_y = []
test_x = []
test_y = []

# Explicit split-name -> (texts, labels) mapping; replaces looking the lists
# up by name through eval().
split_buffers = {
    'train': (train_x, train_y),
    'val': (val_x, val_y),
    'test': (test_x, test_y),
}

for split, (split_x, split_y) in split_buffers.items():
    for key in dataset:
        split_x.extend(dataset[key][f"x_{split}"])
        split_y.extend(dataset[key][f"y_{split}"])

    # inference
    # For each split two line-aligned files are written:
    #   {split}.csv    - one comma-separated sentence embedding per line
    #   {split}_xy.txt - "<label>\t<text>" per line
    svector_savefile = f"./data/youtube/{split}.csv"
    xy_savefile = f"./data/youtube/{split}_xy.txt"

    # Check before opening so a mismatch does not truncate existing outputs.
    assert len(split_x) == len(split_y)

    # Materialise the pairs so tqdm knows the total and can show progress.
    data_items = list(zip(split_x, split_y))

    # The context manager guarantees both files are flushed and closed before
    # the next split is processed (and if encoding raises part-way through).
    # UTF-8 is set explicitly because the comment text is not ASCII-only.
    with open(svector_savefile, 'w', encoding='utf-8') as svector_file, \
            open(xy_savefile, 'w', encoding='utf-8') as xy_file:
        for raw_text, label in tqdm.tqdm(data_items):
            # Newlines inside a comment would break the one-record-per-line
            # layout shared by both output files.
            text = raw_text.replace('\n', ' ')
            sentence_embedding = sentence_model.encode(text)
            semb_str = ','.join(str(w) for w in sentence_embedding.tolist())
            svector_file.write(semb_str + '\n')
            xy_file.write(str(label) + '\t' + text + '\n')

"""



svector_savefile = './data_youtube/sentence_vectors.csv'
xy_savefile = './data_youtube/xy.txt'

vector_file = open(vector_savefile, 'w')
svector_file = open(svector_savefile, 'w')
token_file = open(token_savefile, 'w')
xy_file = open(xy_savefile, 'w')

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").to('cuda:0')
dataset = load_dataset("tweets_hate_speech_detection")



for item in tqdm.tqdm(dataset['train']):
    #lengths.append(len(tok.tokenize(item['tweet'])))

    # Sentence
    sentence_embedding = sentence_model.encode(item['tweet'])
    semb_str = ','.join([str(w) for w in sentence_embedding.tolist()])
    svector_file.write(semb_str + '\n')
    #ipdb.set_trace()

    xy_file.write(str(item['label']) + '\t' + item['tweet'] + '\n')
    #ipdb.set_trace()

vector_file.close()
svector_file.close()
token_file.close()
xy_file.close()
"""
