import csv
import torch.utils.data as data
import os
import json
import time
import random
import datetime
from transformers import TrainingArguments, Trainer
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from collections import Counter
from sklearn import metrics
from datasets import load_dataset, load_metric
from transformers import BertTokenizer, BertForSequenceClassification, BertModel, BertConfig
import torch
import torch.nn.functional as FF
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, AutoModel, TrainingArguments, Trainer, DebertaModel
from sklearn.metrics import classification_report
import logging
import subprocess
from SupCsTrainer import SupCsTrainer
import numpy as np
# Expose only GPU index "1" to this process via CUDA_VISIBLE_DEVICES.
CUDA = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA
device = torch.device("cuda")
# Known GLUE task names (kept for reference; not read elsewhere in this script).
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
task = "sst2"
model_name = "bert-base-uncased"
batch_size = 32
# "mnli-mm" is only an evaluation variant; it maps onto the "mnli" task name.
actual_task = "mnli" if task == "mnli-mm" else task
# Alternative data source (disabled): the hosted GLUE dataset.
#dataset = load_dataset("glue", actual_task)
# Load the train/validation splits from local tab-separated files instead.
dataset = load_dataset("csv", data_files={"train": 'train.tsv',
    "validation": 'dev.tsv'},delimiter='\t')
# Print the first training row as a quick sanity check of the parsed columns.
print(dataset['train'][0])
# NOTE(review): while this line stays commented out, no module-level `metric`
# exists, yet compute_metrics() below refers to `metric`.
#metric = load_metric('glue', actual_task)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Alternative tokenizer construction (disabled).
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
def flat_accuracy(preds, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        preds: 2-D array of per-class scores; the class is taken along axis 1.
        labels: Array of expected class ids; flattened before comparison.

    Returns:
        Number of matching positions divided by the number of labels.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    expected_classes = labels.flatten()
    n_correct = np.sum(predicted_classes == expected_classes)
    return n_correct / len(expected_classes)
class InputExample:
    """Plain record holding one example: ``guid``, ``text_a`` and ``label``.

    The three constructor arguments are stored unchanged as attributes of
    the same name.
    """

    def __init__(self, guid, text_a, label):
        self.guid, self.text_a, self.label = guid, text_a, label
def _read_tsv(input_file):
    """Read a tab-separated file and return its rows.

    Args:
        input_file: Path of a UTF-8 encoded TSV file.

    Returns:
        A list with one entry per record; each entry is the list of string
        fields parsed by ``csv.reader``.
    """
    # newline="" hands line-ending handling to the csv module, as its
    # documentation requires; otherwise "\r\n" inside quoted fields is
    # silently rewritten to "\n" before the parser sees it.
    with open(input_file, "r", encoding="utf-8", newline="") as f:
        return list(csv.reader(f, delimiter="\t"))




# Dataset column holding the first (or only) text passed to the tokenizer.
sentence1_key="sentence"
# Dataset column holding the second text of a pair; None means single-text input.
sentence2_key = None
def preprocess_function(examples):
    """Tokenize the configured text column(s) of ``examples`` with truncation.

    Reads ``sentence1_key`` and, when ``sentence2_key`` is not None, also
    ``sentence2_key`` from ``examples`` and passes them positionally to the
    module-level ``tokenizer``.

    Returns:
        Whatever the tokenizer call returns.
    """
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        text_columns.append(examples[sentence2_key])
    return tokenizer(*text_columns, truncation=True)
    
# Run preprocess_function over the loaded dataset in batched mode; the
# train/eval datasets used for training are taken from this result below.
encoded_dataset = dataset.map(preprocess_function, batched=True)
def compute_metrics(eval_pred):
    """Score model outputs against reference labels with the GLUE metric.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. For every task except
            ``"stsb"`` the predictions are reduced to class ids with an
            arg-max over axis 1; for ``"stsb"`` the first output column is
            used as-is.

    Returns:
        The result of ``metric.compute(predictions=..., references=...)``.
    """
    predictions, labels = eval_pred
    if task != "stsb":
        predictions = np.argmax(predictions, axis=1)
    else:
        # No trailing comma here: it would wrap the column in a 1-tuple.
        predictions = predictions[:, 0]
    try:
        scorer = metric
    except NameError:
        # The module-level `metric = load_metric(...)` assignment is commented
        # out in this script, so fall back to loading the same metric on demand.
        scorer = load_metric("glue", actual_task)
    return scorer.compute(predictions=predictions, references=labels)
# Pick the evaluation split for the configured task. MNLI has two validation
# splits ("validation_matched" / "validation_mismatched"); every other task
# uses plain "validation".
validation_key = {
    "mnli-mm": "validation_mismatched",
    "mnli": "validation_matched",
}.get(task, "validation")
train_dataset = encoded_dataset["train"]
test_dataset = encoded_dataset[validation_key]

def run_model(
    model_name,
    train_dataset,
    test_dataset,
    w_drop_out=None,
    temperature=0.05,
    lr=5e-05,
    bs=10,
    epoch=5,
    warmup_steps=500,
    logging_steps=200,
    freeze_base=False,
    evaluation_strategy='no',
    base_flag=False,
    num_labels=2,
    SupCs=False,
    sub_name="baseline",
):
    """Train a base BERT encoder with ``SupCsTrainer`` and save the result.

    Args:
        model_name: Checkpoint name/path given to ``BertModel.from_pretrained``.
        train_dataset: Dataset passed to the trainer as ``train_dataset``.
        test_dataset: Dataset passed to the trainer as ``eval_dataset``.
        w_drop_out: First argument forwarded to ``SupCsTrainer.set_views``;
            defaults to ``[0.0, 0.05, 0.2]`` when None.
        temperature: Second argument forwarded to ``SupCsTrainer.set_views``.
        lr: Learning rate for ``TrainingArguments``.
        bs: Per-device training batch size.
        epoch: Number of training epochs.
        warmup_steps: Warm-up steps for ``TrainingArguments``.
        logging_steps: Logging interval for ``TrainingArguments``.
        freeze_base: Accepted for interface compatibility; currently unused.
        evaluation_strategy: Evaluation strategy for ``TrainingArguments``.
        base_flag: Accepted for interface compatibility; currently unused.
        num_labels: Accepted for interface compatibility; currently unused.
        SupCs: Must be True; selects the ``SupCsTrainer`` training path.
        sub_name: The trained model is saved to ``'./' + sub_name``.

    Raises:
        NotImplementedError: If ``SupCs`` is false. No other trainer is
            constructed by this function, so training could not proceed.
    """
    # Fail fast: previously the SupCs=False path only crashed with an
    # UnboundLocalError on `trainer` after the model had already been loaded.
    if not SupCs:
        raise NotImplementedError(
            "run_model only implements the SupCs=True training path"
        )
    # Build the default per call instead of sharing one mutable list object
    # across all invocations.
    if w_drop_out is None:
        w_drop_out = [0.0, 0.05, 0.2]

    # Honour the requested checkpoint so the model stays in sync with the
    # tokenizer, which the script loads from the same name.
    model = BertModel.from_pretrained(model_name)
    print('Base Bert Loaded.')
    args = TrainingArguments(
        output_dir='./results8',
        save_total_limit=1,
        num_train_epochs=epoch,
        per_device_train_batch_size=bs,
        per_device_eval_batch_size=64,
        evaluation_strategy=evaluation_strategy,
        logging_steps=logging_steps,
        learning_rate=lr,
        eval_steps=200,
        warmup_steps=warmup_steps,
        weight_decay=0.01,
        logging_dir='./logs8',
    )

    print('Using SupCs')
    trainer = SupCsTrainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )
    trainer.set_views(w_drop_out, temperature)

    logging.basicConfig(level=logging.INFO)
    trainer.train()
    trainer.save_model('./' + sub_name)
# Launch training through the SupCs path; the model is saved under './' + sub_name.
run_model(   
             model_name,
             train_dataset,
             test_dataset,
             w_drop_out = [0.05],          # drop-out for views
             temperature = 0.05,          # temperature for contrastive loss
             lr = 3e-5,                    # best: 3e-5
             sub_name='cs_baseline_dropout0.10.05_8',      # name the trained model is saved under
             logging_steps = 50,
             warmup_steps = 100,
             base_flag=True,              # intended: use base BERT (should be True for contrastive loss); NOTE(review): run_model does not read this flag
             SupCs= True,                 # if True, train with the contrastive-loss trainer
             bs = 8,                      # per-device batch size (multiplied by the number of GPUs); best: 8, earlier notes mention 64 and 10
             epoch = 3
           )


