import torch
from torch import nn
import sys
from model import *
# from utils import *
import torch.optim as optim
import numpy as np
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
import pickle
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm


def multiclass_acc(preds, truths):
    """
    Fraction of samples whose rounded prediction equals the rounded groundtruth.
    :param preds: Float array representing the predictions, dimension (N,)
    :param truths: Float/int array representing the groundtruth classes, dimension (N,)
    :return: Classification accuracy (number of matches divided by len(truths))
    """
    rounded_preds = np.round(preds)
    rounded_truths = np.round(truths)
    n_matches = np.sum(rounded_preds == rounded_truths)
    return n_matches / float(len(truths))


def weighted_accuracy(test_preds_emo, test_truth_emo):
    """
    Class-balanced binary accuracy, where a value > 0 counts as the positive class.
    :param test_preds_emo: Array of predicted scores
    :param test_truth_emo: Array of groundtruth scores, same shape as the predictions
    :return: (TP * (N / P) + TN) / (2 * N), with P / N the number of positive /
             negative groundtruth samples
    """
    truth_is_pos = (test_truth_emo > 0)
    pred_is_pos = (test_preds_emo > 0)
    truth_is_neg = np.logical_not(truth_is_pos)
    pred_is_neg = np.logical_not(pred_is_pos)

    true_pos = float(np.sum(truth_is_pos & pred_is_pos))
    true_neg = float(np.sum(truth_is_neg & pred_is_neg))
    n_pos = float(np.sum(truth_is_pos))
    n_neg = float(np.sum(truth_is_neg))

    # Re-weight true positives so both classes contribute equally.
    pos_weight = n_neg / n_pos
    return (true_pos * pos_weight + true_neg) / (2 * n_neg)


def eval_mosei_senti(results, truths, exclude_zero=False):
    """
    Print and tabulate regression / binary-sentiment metrics.
    :param results: Tensor of predicted scores; flattened with ``view(-1)``
    :param truths: Tensor of groundtruth scores; flattened with ``view(-1)``
    :param exclude_zero: If True, samples whose groundtruth is exactly 0 are left
                         out of the binary F1 / accuracy (MAE, correlation and the
                         multiclass accuracies always use every sample)
    :return: Table object holding the same metrics that are printed
    """
    test_preds = results.view(-1).cpu().detach().numpy()
    test_truth = truths.view(-1).cpu().detach().numpy()

    # Integer index array of the samples used for the binary metrics. Built with
    # numpy so it stays an integer array even when nothing is selected.
    if exclude_zero:
        non_zeros = np.flatnonzero(test_truth != 0)
    else:
        non_zeros = np.arange(len(test_truth))

    test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
    test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
    test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
    test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)

    mae = np.mean(np.absolute(test_preds - test_truth))  # Average L1 distance between preds and truths
    corr = np.corrcoef(test_preds, test_truth)[0][1]
    mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7)
    mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5)

    binary_truth = (test_truth[non_zeros] > 0)
    binary_preds = (test_preds[non_zeros] > 0)
    # sklearn expects (y_true, y_pred); the weighted average uses the support of
    # the true labels, so the argument order changes the score.
    f_score = f1_score(binary_truth, binary_preds, average='weighted')
    acc = accuracy_score(binary_truth, binary_preds)

    # NOTE(review): Texttable is not imported in the visible import block;
    # presumably it is provided by `from model import *` -- confirm.
    t = Texttable()
    t.add_rows([
        ['Metric', 'value'],
        ['MAE', mae],
        ['Correlation coefficient', corr],
        ['mult_7', mult_a7],
        ['mult_5', mult_a5],
        ['f1 score', f_score],
        ['accuracy', acc],
    ])

    print("MAE: ", mae)
    print("Correlation Coefficient: ", corr)
    print("mult_acc_7: ", mult_a7)
    print("mult_acc_5: ", mult_a5)
    print("F1 score: ", f_score)
    print("Accuracy: ", acc)

    print("-" * 50)
    return t


def scores(results, truths):
    """
    Print and tabulate per-emotion weighted F1 and accuracy.
    :param results: Tensor of class scores; reshaped to (-1, 4, 2), i.e. two
                    scores per emotion, the arg-max of which is the prediction
    :param truths: Tensor of labels; reshaped to (-1, 4)
    :return: Table object with one (emotion, f1, acc) row per emotion
    """
    emos = ["Neutral", "Happy", "Sad", "Angry"]
    n_emos = len(emos)

    preds = results.view(-1, n_emos, 2).cpu().detach().numpy()
    label = truths.view(-1, n_emos).cpu().detach().numpy()

    t = Texttable()
    table_header = ['emotion', 'f1', 'acc']

    for emo_ind, emo_name in enumerate(emos):
        print(f"{emo_name}: ")
        # Predicted class for this emotion = index of the larger of its two scores.
        test_preds_i = np.argmax(preds[:, emo_ind], axis=1)
        test_truth_i = label[:, emo_ind]

        f1 = f1_score(test_truth_i, test_preds_i, average='weighted')
        acc = accuracy_score(test_truth_i, test_preds_i)

        t.add_rows([table_header, [emo_name, f1, acc]])
        print("  - F1 Score: ", f1)
        print("  - Accuracy: ", acc)

    return t
    
# dataset='iemocap'


def count_parameters(model):
    """
    Count the trainable scalar parameters of a model.
    :param model: Object exposing ``parameters()`` (e.g. an ``nn.Module``)
    :return: Total number of elements over all parameters with ``requires_grad`` set
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def train_model(hparams,train_loader, test_loader, valid_loader, length_test,length_valid,length_train):
    """
    Train a model, keep the weights with the lowest validation loss, and report
    test metrics for those weights.
    :param hparams: Dict-like of hyper-parameters. Keys read here: 'lr' (default
                    0.001), 'dataset' (default 'iemocap') and 'epochs' (default 5);
                    the whole object is also passed to the model constructor
    :param train_loader: Iterable of (batch_x, batch_y) training batches, where
                         batch_x unpacks to (sample_ind, text, audio, vision)
    :param test_loader: Same structure, test split
    :param valid_loader: Same structure, validation split
    :param length_test: Divisor used to average the summed test loss
    :param length_valid: Divisor used to average the summed validation loss
    :param length_train: Used with the batch size to derive the logging interval
    :return: None. Prints progress, writes 'model_best_perform.pt' whenever the
             validation loss improves, and prints the final metric table
    """
    import copy  # local import: only needed for the in-memory best-weights snapshot

    lr = hparams.get('lr', 0.001)
    # NOTE(review): `model1` and `device` are not defined in the visible part of
    # this file; presumably they come from `from model import *` -- confirm.
    model = model1(hparams).to(device)

    print("parameters=", count_parameters(model))
    datasetName = hparams.get('dataset', 'iemocap')
    is_iemocap = datasetName == 'iemocap'
    optimizer = optim.AdamW(model.parameters(), lr=lr, amsgrad=True)
    # iemocap is trained as classification, anything else as L1 regression.
    criterion = nn.CrossEntropyLoss() if is_iemocap else nn.L1Loss()

    # The scheduler's `verbose` flag is deprecated, so learning-rate changes are
    # reported manually after each scheduler step instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.1)

    def train(model, optimizer, criterion, batch_size):
        """Run one optimisation pass over `train_loader`, logging periodically."""
        model.train()
        # Logging interval in batches; clamped to 1 so the modulo below cannot
        # divide by zero when length_train < batch_size.
        num_batches = max(1, length_train // batch_size)
        proc_loss, proc_size, proc_batches = 0, 0, 0
        start_time = time.time()

        for i, (batch_x, batch_y) in enumerate(train_loader):
            _, text, audio, vision = batch_x

            labels = batch_y.squeeze(-1)
            model.zero_grad()
            n_samples = text.size(0)

            text = text.to(device)
            audio = audio.to(device)
            vision = vision.to(device)
            preds = model(text, audio, vision).to(device)
            if is_iemocap:
                preds = preds.view(-1, 2)
                # Cast to long as evaluate() does, so the labels are used as
                # class indices by the cross-entropy loss.
                labels = labels.long().view(-1)
            labels = labels.to(device)

            raw_loss = criterion(preds, labels)
            raw_loss.backward()
            optimizer.step()

            proc_loss += raw_loss.item() * n_samples
            proc_size += n_samples
            proc_batches += 1

            if i % num_batches == 0 and i > 0:
                avg_loss = proc_loss / proc_size
                elapsed_time = time.time() - start_time
                # `epoch` is the loop variable of the enclosing epoch loop.
                print('Epoch {:2d} | Batch {:3d}/{:3d} | Time/Batch(ms) {:5.2f} | Train Loss {:5.4f}'.
                      format(epoch, i, num_batches, elapsed_time * 1000 / proc_batches, avg_loss))
                proc_loss, proc_size, proc_batches = 0, 0, 0
                start_time = time.time()

    def evaluate(model, criterion, test=False):
        """Return (average loss, concatenated predictions, concatenated labels)
        over the test loader if `test` is True, else the validation loader."""
        model.eval()
        loader = test_loader if test else valid_loader
        total_loss = 0.0

        results = []
        truths = []

        with torch.no_grad():
            for batch_X, batch_Y in loader:
                _, text, audio, vision = batch_X
                eval_attr = batch_Y.squeeze(dim=-1)
                text = text.to(device)
                audio = audio.to(device)
                vision = vision.to(device)

                preds = model(text, audio, vision).to(device)
                if is_iemocap:
                    eval_attr = eval_attr.long().view(-1)
                    preds = preds.view(-1, 2)
                eval_attr = eval_attr.to(device)
                n_samples = text.size(0)

                total_loss += criterion(preds, eval_attr).item() * n_samples

                results.append(preds)
                truths.append(eval_attr)

        avg_loss = total_loss / (length_test if test else length_valid)
        results = torch.cat(results)
        truths = torch.cat(truths)
        return avg_loss, results, truths

    best_valid = 1e8
    best_state = None  # in-memory copy of the best weights seen in THIS run
    epochs = hparams.get('epochs', 5)
    n_no_improve = 0
    for epoch in range(1, epochs + 1):
        start = time.time()
        # NOTE(review): `batch_size` is not a parameter or local of train_model;
        # it must resolve to a module-level name (possibly from the star-import)
        # or this line raises NameError -- confirm.
        train(model, optimizer, criterion, batch_size)

        val_loss, _, _ = evaluate(model, criterion, test=False)
        test_loss, _, _ = evaluate(model, criterion, test=True)
        print("time taken:",time.time()-start,"| valid loss for epoch:",epoch,"|",val_loss,"test loss ",test_loss,'\n')

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for group_idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group['lr'] != old_lr:
                print('Epoch {:5d}: reducing learning rate of group {} to {:.4e}.'.format(
                    epoch, group_idx, group['lr']))

        n_no_improve += 1
        if val_loss < best_valid:
            n_no_improve = 0
            # The on-disk artifact keeps its original format (whole pickled model).
            torch.save(model, 'model_best_perform.pt')
            best_state = copy.deepcopy(model.state_dict())
            best_valid = val_loss
            print("model saved at epoch:", epoch, '\n')
        if n_no_improve == 10:
            print('early stopping triggered')
            break

    # Restore the best weights from memory rather than torch.load()-ing the
    # pickled model: newer PyTorch defaults torch.load to weights_only=True,
    # which rejects whole-model pickles, and a file left over from an earlier
    # run could otherwise be picked up when no epoch improved.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print('warning: validation loss never improved; evaluating the current weights')
    lo, results, truths = evaluate(model, criterion, test=True)
    print("\nfinal test loss:",lo)

    if is_iemocap:
        t = scores(results, truths)
    else:
        t = eval_mosei_senti(results, truths, True)
    print(t.draw())

    # with open('ablation_hybrid2.txt','a') as fi:
    #   fi.write('\n')
    #   fi.write(str(hparams.get('model_type')))
    #   fi.write('\n')
    #   fi.write(str(hparams.get('dataset')))
    #   fi.write('\n')
    #   fi.write('Text '+str(hparams.get('T',True))+' ')
    #   fi.write('Audio '+str(hparams.get('A',True))+' ')
    #   fi.write('Video '+str(hparams.get('V',True))+' ')
    #   fi.write('\n')
    #   fi.write(t.draw())
    #   fi.write('\n')