import torch
import load_data
import transformers
import model
import pickle
import torch.nn as F 
import time
import os
import sys
import datetime
import torch.nn as nn
from transformers import AdamW
from torch.optim.swa_utils import SWALR
from sklearn.metrics import f1_score
from transformers import XLNetModel, XLNetTokenizer
from tqdm import tqdm


class TrainModel():
    """Train a ``model.MIM`` scorer with a contrastive loss and evaluate it.

    Training batches come from ``load_data.LoadConnData``. During training the
    model is evaluated on the dev file every ``eval_interval`` steps. After
    training it is evaluated once on the test file, and those weights are
    checkpointed to ``saved_models/``.
    """

    def _unwrapped_model(self):
        """Return the underlying model, stripping an ``nn.DataParallel`` wrapper if present."""
        if isinstance(self.xlnet_model, nn.DataParallel):
            return self.xlnet_model.module
        return self.xlnet_model

    def save_model(self, step, accuracy):
        """Write the model's ``state_dict`` under ``saved_models/``.

        The file name encodes the run description, seed, batch size, learning
        rate, step, accuracy and model type.

        Args:
            step: training step at which the checkpoint is taken.
            accuracy: evaluation accuracy recorded in the file name.
        """
        save_dir = "saved_models"
        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, "{}_seed-{}_bs-{}_lr-{}_step-{}_acc-{}_type-{}.cont".format(self.desc, self.seed, self.batch_size, self.learning_rate, step, accuracy, self.model_type))

        # Save the inner module so the keys carry no DataParallel "module." prefix.
        torch.save(self._unwrapped_model().state_dict(), model_path)

    def __init__(self, model_type, learning_rate, batch_size, negs, margin, eval_interval, train_file, dev_file, test_file, desc, seed, datatype, pretrained_path=None):
        """Build the model, optionally load pretrained weights, and set up optimisation.

        Args:
            model_type, negs, margin: forwarded to ``model.MIM``.
            learning_rate: initial AdamW learning rate.
            batch_size: batch size for the training and evaluation loaders.
            eval_interval: run a dev evaluation every this many steps.
            train_file, dev_file, test_file: data files for ``load_data.LoadConnData``.
            desc, seed: recorded in logs and checkpoint file names.
            datatype: forwarded to ``load_data.LoadConnData``.
            pretrained_path: optional checkpoint to load. When None, falls back
                to ``sys.argv[5]`` if that argument was given. An empty
                string means no checkpoint.
        """
        self.batch_size = batch_size
        self.model_type = model_type
        self.learning_rate = learning_rate
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.train_file = train_file
        self.dev_file = dev_file
        self.test_file = test_file
        self.negs = negs
        self.margin = margin
        self.desc = desc
        self.seed = seed
        self.datatype = datatype
        self.bestacc = 0.0

        self.xlnet_model = model.MIM(self.model_type, self.negs, self.batch_size, self.margin, self.device)

        if pretrained_path is None and len(sys.argv) > 5:
            pretrained_path = sys.argv[5]
        if pretrained_path:
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            # Load errors propagate: silently training from scratch would be worse.
            pretrained = torch.load(pretrained_path, map_location=self.device)
            # strict=False: missing/unexpected keys in the checkpoint are tolerated.
            self.xlnet_model.load_state_dict(pretrained, strict=False)
            print("Loaded pretrained model")

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.xlnet_model = nn.DataParallel(self.xlnet_model)
        self.xlnet_model = self.xlnet_model.to(self.device)

        self.optimizer = AdamW(self.xlnet_model.parameters(), lr=self.learning_rate)
        # Linearly anneals the LR toward 1e-6 over 10 scheduler steps.
        # scheduler.step() is called once per dev evaluation in train_xlnet_model.
        self.scheduler = SWALR(self.optimizer, anneal_strategy="linear", anneal_epochs=10, swa_lr=1e-6)
        self.total_loss = 0.0

        self.eval_interval = eval_interval

    def train_xlnet_model(self):
        """Run one pass over the training file, then evaluate on the test file.

        Every ``eval_interval`` steps the model is evaluated on the dev file and
        the LR scheduler is advanced. A loss line is logged every 1000 steps.
        """
        train_data = load_data.LoadConnData(self.train_file, self.batch_size, self.model_type, self.device, self.datatype, self.negs)
        train_loader = train_data.data_loader()
        start = time.time()
        self.xlnet_model.train()
        # Initialised so the final evaluation still works if the loader is empty.
        step = 0

        for step, data in enumerate(train_loader):
            self.optimizer.zero_grad()

            try:
                pos_input, neg_input = data
            except (TypeError, ValueError) as e:
                # Skip batches that do not unpack into (positive, negatives).
                print(e)
                continue

            pos_score, neg_scores = self.xlnet_model(pos_input, neg_input)
            loss = self._unwrapped_model().contrastiveLoss(pos_score, neg_scores)

            loss.backward()
            self.optimizer.step()

            self.total_loss += loss.item()

            if step % self.eval_interval == 0 and step > 0:
                self.eval_model(self.dev_file, step, start)
                self.scheduler.step()

            if step % 1000 == 0:
                end = time.time()
                full_time = time.asctime(time.localtime(end))
                print("LOG Time: {} Elapsed: {} Steps: {} Loss: {}".format(full_time, end-start, step, loss.item()))

        self.eval_model(self.test_file, step, start)

    def eval_model(self, data_file, step, start):
        """Measure ranking accuracy on ``data_file`` and log it.

        An example counts as correct when its positive score exceeds the
        highest negative score. For the test file, when ``step > 0``, the
        accuracy is also stored in ``self.bestacc`` and the model is saved.

        Args:
            data_file: file to evaluate (dev or test).
            step: current training step, used for logging and checkpoint naming.
            start: training start time from ``time.time()``, used for elapsed time.
        """
        self.xlnet_model.eval()
        test_data = load_data.LoadConnData(data_file, self.batch_size, self.model_type, self.device, self.datatype, self.negs)
        test_loader = test_data.data_loader()

        correct = 0.0
        total = 0.0

        with torch.no_grad():
            for data in test_loader:
                try:
                    pos_input, neg_inputs = data
                except (TypeError, ValueError) as e:
                    print(e)
                    continue

                pos_score, neg_scores = self.xlnet_model(pos_input, neg_inputs)
                max_neg_score = torch.max(neg_scores, -1).values

                # NOTE(review): this truth test assumes one example per batch
                # (the script uses batch_size=1). With a multi-element comparison
                # tensor, torch raises on the ambiguous bool.
                if pos_score > max_neg_score:
                    correct += 1.0
                total += 1.0

        self.xlnet_model.train()
        end = time.time()
        full_time = time.asctime(time.localtime(end))
        # An empty (or fully skipped) evaluation file would otherwise divide by zero.
        acc = correct / total if total else 0.0
        if data_file == self.test_file:
            print(self.desc, self.seed)
            print("TEST EVAL Time: {} Elapsed: {} Steps: {} Acc: {}".format(full_time, end-start, step,  acc))
            if step > 0:
                self.bestacc = acc
                self.save_model(step, acc)
        else:
            # Dev accuracy was previously computed and then discarded; log it.
            print("DEV EVAL Time: {} Elapsed: {} Steps: {} Acc: {}".format(full_time, end-start, step, acc))
            
                



if __name__ == "__main__":
    # Command line: TRAIN_FILE DEV_FILE TEST_FILE DESC [PRETRAINED_CHECKPOINT]
    # The optional 5th argument is read inside TrainModel.__init__.
    if len(sys.argv) < 5:
        sys.exit("usage: {} TRAIN_FILE DEV_FILE TEST_FILE DESC [PRETRAINED_CHECKPOINT]".format(sys.argv[0]))

    train_file = sys.argv[1]
    dev_file = sys.argv[2]
    test_file = sys.argv[3]
    desc = sys.argv[4]  # free-form run description, used in logs and checkpoint names

    batch_size = 1
    negs = 5  # number of negatives for each positive
    model_type = 'base'
    eval_interval = 5000  # run a dev evaluation every this many training steps

    seed = 100
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Initial LR. The SWALR scheduler in TrainModel.__init__ anneals it linearly toward 1e-6.
    learning_rate = 5e-6
    margin = 0.1  # contrastive loss margin
    datatype = 'pair'

    trainer = TrainModel(model_type, learning_rate, batch_size, negs, margin, eval_interval, train_file, dev_file, test_file, desc, seed, datatype)
    trainer.train_xlnet_model()
