from transformers import RobertaTokenizer, RobertaModel
import torch
import csv
import torch.utils.data as data
import os
import json
import time
import random
import datetime
from LogME import LogME
from transformers import TrainingArguments, Trainer
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from collections import Counter
from sklearn import metrics
from datasets import load_dataset, load_metric
from transformers import BertTokenizer, BertForSequenceClassification, BertModel, BertConfig
import torch
import torch.nn.functional as FF
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, AutoModel, TrainingArguments, Trainer, DebertaModel
from sklearn.metrics import classification_report
import logging
import subprocess
from SupCsTrainer import SupCsTrainer
import numpy as np
# Value exported as CUDA_VISIBLE_DEVICES below (a GPU index, kept as a string).
CUDA = "5"
# Exported before any tensor or model is moved to the device.
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA
# NOTE(review): hard-coded to CUDA with no CPU fallback; `.to(device)` calls
# will presumably fail on a machine without a usable GPU -- confirm.
device = torch.device("cuda")
#tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Checkpoint name used to build the module-level tokenizer and passed to run_model().
model_name='roberta-base'
# Task identifier; it is consulted when choosing the validation split name.
task = "sst2"
# Shared tokenizer used by prepare_features(), preprocess_function() and run_model().
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
#model = RobertaModel.from_pretrained('roberta-base')
#max_len=0
#model = model.to(device)
def prepare_features(seq_1, zero_pad = False, max_seq_length = 120):
    """Encode every sequence in ``seq_1`` and track the longest token length.

    Each sequence is encoded with the module-level ``tokenizer`` (special
    tokens added, truncated to 300 tokens).  The module-level ``max_len`` is
    raised to the longest unpadded ``input_ids`` length seen.

    Args:
        seq_1: Iterable of input texts.
        zero_pad: If true, right-pad each encoding up to ``max_seq_length``.
        max_seq_length: Target length used when ``zero_pad`` is true.
            Encodings that are already at least this long are left untouched.

    Returns:
        The encoding of the last sequence in ``seq_1`` (as the original loop
        did), or ``None`` when ``seq_1`` is empty.
    """
    # `max_len` lives at module level; without this declaration the assignment
    # below makes it local and the first read raises UnboundLocalError.
    global max_len

    # Tolerate being called before the module-level `max_len` is initialised.
    longest = globals().get("max_len", 0)
    enc_text = None

    for seq in seq_1:
        enc_text = tokenizer.encode_plus(
            seq,
            add_special_tokens=True,
            max_length=300,
            truncation=True,  # make the max_length cut-off explicit
        )
        input_ids = enc_text['input_ids']
        # Measure tokens, not the number of keys in the encoding mapping.
        longest = max(longest, len(input_ids))

        if zero_pad:
            missing = max_seq_length - len(input_ids)
            if missing > 0:
                # Use the tokenizer's own pad id (not every vocabulary maps
                # id 0 to the padding token); fall back to 0 if none is set.
                pad_id = tokenizer.pad_token_id
                if pad_id is None:
                    pad_id = 0
                input_ids.extend([pad_id] * missing)
                # Keep companion sequences aligned; not every tokenizer
                # returns token_type_ids, so only extend keys that exist.
                for key in ('attention_mask', 'token_type_ids'):
                    if key in enc_text:
                        enc_text[key].extend([0] * missing)

    max_len = longest
    return enc_text
# Read the tab-separated train/dev files as the "train" and "validation" splits.
dataset = load_dataset(
    "csv",
    data_files={"train": 'train.tsv', "validation": 'dev.tsv'},
    delimiter='\t',
)
# Column names handed to the tokenizer; the second key is None for
# single-sentence inputs.
sentence1_key, sentence2_key = "sentence", None
def preprocess_function(examples):
    """Tokenize a batch of examples with truncation enabled.

    The text column named by ``sentence1_key`` is always passed to the
    module-level ``tokenizer``; the column named by ``sentence2_key`` is
    passed as the second positional argument only when that key is not None.
    """
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        text_columns.append(examples[sentence2_key])
    return tokenizer(*text_columns, truncation=True)
# Tokenize every split; batched=True hands preprocess_function batches of rows.
encoded_dataset = dataset.map(preprocess_function, batched=True)
# MNLI has two validation splits; every other task uses plain "validation".
# (Fixed typo: the matched split is "validation_matched", not "validation_matche".)
validation_key = (
    "validation_mismatched" if task == "mnli-mm"
    else "validation_matched" if task == "mnli"
    else "validation"
)
train_dataset = encoded_dataset["train"]
test_dataset = encoded_dataset[validation_key]
def run_model(
              model_name,
              train_dataset,
              test_dataset,
              w_drop_out = (0.0, 0.05, 0.2),
              temperature=0.05,
              lr = 5e-05,
              bs = 10,
              epoch = 5,
              warmup_steps=500,
              logging_steps = 200,
              freeze_base=False,
              evaluation_strategy = 'no',
              base_flag=False,
              num_labels = 2,
              SupCs=False,
              sub_name = "baseline"):
    """Load a RoBERTa encoder and extract features for ``train_dataset``.

    The model is put in eval mode, moved to the module-level ``device`` and
    wrapped in a ``SupCsTrainer``; the features come from
    ``trainer.get_feature(model, train_dataset)``.

    Args:
        model_name: Checkpoint name or path given to
            ``RobertaModel.from_pretrained``.
        train_dataset: Dataset whose features are extracted; also the
            trainer's train dataset.
        test_dataset: Dataset passed to the trainer as ``eval_dataset``.
        lr, bs, epoch, warmup_steps, logging_steps, evaluation_strategy:
            Forwarded to ``TrainingArguments``.
        SupCs: Must be true. Feature extraction is only available through
            ``SupCsTrainer.get_feature``.
        w_drop_out, temperature, freeze_base, base_flag, num_labels, sub_name:
            Accepted for call compatibility; not used by this function.

    Returns:
        The ``(feature, labels)`` pair returned by ``SupCsTrainer.get_feature``.

    Raises:
        ValueError: If ``SupCs`` is false. No trainer would be built, so
            there is nothing to extract features with.
    """
    # Fail fast, before loading the model. Previously this path crashed late
    # with UnboundLocalError because `trainer` was only bound under `if SupCs`.
    if not SupCs:
        raise ValueError(
            "run_model requires SupCs=True: features are extracted with "
            "SupCsTrainer.get_feature and no other trainer is configured."
        )

    #model = BertModel.from_pretrained('./cs_baseline')
    # Honour the `model_name` argument instead of a hard-coded checkpoint.
    model = RobertaModel.from_pretrained(model_name)
    print('Roberta Loaded.')
    model.eval()
    model.to(device)
    args = TrainingArguments(
            output_dir = './results2',
            save_total_limit = 1,
            num_train_epochs=epoch,
            per_device_train_batch_size=bs,
            per_device_eval_batch_size=32,
            evaluation_strategy = evaluation_strategy,
            logging_steps = logging_steps,
            learning_rate = lr,
            eval_steps = 200,
            warmup_steps=warmup_steps,
            weight_decay=0.01,
            logging_dir='./logs',
        )

    print('Using SupCs')
    trainer = SupCsTrainer(
                model,
                args,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                tokenizer=tokenizer
            )

    #logging.basicConfig(level = logging.INFO)
    feature, labels = trainer.get_feature(model, train_dataset)
    return feature, labels

# Module-level tracker for the longest token length (see prepare_features).
max_len = 0

# Extract features and labels for the training split with the SupCs trainer.
feature, labels = run_model(
    model_name,
    train_dataset,
    test_dataset,
    w_drop_out=[0.1],        # drop-out for views
    temperature=0.05,        # temperature for contrastive loss
    lr=3e-5,                 # best: 3e-5
    sub_name='cs_baseline',  # pretrained CS models
    logging_steps=50,
    warmup_steps=100,
    base_flag=True,          # if True uses base bert (should be True for contrastive loss)
    SupCs=True,              # if True uses contrastive loss
    bs=32,                   # multiplied by number of GPUs; tried 64, 10; best: 8
    epoch=3,
)

# Score the extracted features against the labels with LogME.
score = LogME(feature, labels)
print("score:", score)

