# -*- coding: utf-8 -*-

"""PyTorch Utility Functions"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from efficientnet_pytorch import EfficientNet
import matplotlib.pyplot as plt
import time
import os
import copy
import shutil
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import pandas as pd
from tqdm import tqdm
from decimal import Decimal
from sklearn import metrics
from datetime import datetime
import warnings
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix as compute_confusion_matrix,
    precision_recall_fscore_support,
)
import json

print(f"PyTorch Version: {torch.__version__}", flush=True)
print(f"Torchvision Version: {torchvision.__version__}", flush=True)


# Per-channel (R, G, B) normalisation statistics used by every transform
# pipeline in this module and reversed again by ``imshow``.
IMGNET_MEAN = [0.485, 0.456, 0.406]
IMGNET_STD = [0.229, 0.224, 0.225]


class MyDataset(Dataset):
    """Image-classification dataset whose samples are listed in a delimited text file.

    The first line of the file is treated as a header and skipped. Every
    following non-blank line is stripped and split on ``sep``: column 1 is the
    image path and column 2 is the class label (column 0 is ignored). Labels
    are kept as strings and mapped to integer indices in sorted label order.

    Arguments:
        file_path (str): path of the delimited sample list
        sep (str): column separator used in ``file_path``
        root_dir (str or None): directory the image paths are relative to;
            when None or empty the paths are used as given
        transform (callable, optional): applied to each RGB PIL image
        check_imgs (bool): if True, raise FileNotFoundError when any listed
            image file does not exist
    """

    def __init__(self, file_path, sep, root_dir, transform=None, check_imgs=False):

        # Silence PIL warnings about odd EXIF/metadata/palette content.
        warnings.filterwarnings("ignore", "Possibly corrupt EXIF data", UserWarning)
        warnings.filterwarnings("ignore", "Corrupt EXIF data", UserWarning)
        warnings.filterwarnings("ignore", "Metadata Warning", UserWarning)
        warnings.filterwarnings("ignore", "Palette images with Transparency", UserWarning)

        self.file_path = file_path
        self.root_dir = root_dir
        self.transform = transform

        self.X, self.y = self._read_samples(file_path, sep)
        print(set(self.y))

        if check_imgs:
            missing = [p for p in self.X if not os.path.isfile(self._full_path(p))]
            if missing:
                raise FileNotFoundError(
                    "Some images referenced in the file were not found: {}".format(missing)
                )

        print("data: " + str(len(self.X)))
        print("labels: " + str(len(self.y)))

        self.classes, self.class_to_idx = self._find_classes()
        self.samples = list(zip(self.X, [self.class_to_idx[i] for i in self.y]))
        self.targets = [s[1] for s in self.samples]

    @staticmethod
    def _read_samples(file_path, sep):
        """Parse ``file_path`` into parallel lists of image paths and labels."""
        paths = []
        labels = []
        with open(file_path, 'r') as f:
            # Skip the header; the default avoids StopIteration on an empty file.
            next(f, None)
            for line in f:
                if line.strip():
                    cols = line.strip().split(sep)
                    paths.append(cols[1])
                    labels.append(cols[2])
        return paths, labels

    def _full_path(self, path):
        """Resolve an image path from the sample file against ``root_dir``."""
        return os.path.join(self.root_dir, path) if self.root_dir else path

    def __getitem__(self, index):
        path, label = self.samples[index]
        # Use a context manager so the file handle is released; convert()
        # forces PIL to read the pixels before the file is closed, and also
        # normalises every image to RGB (replacing the unreliable
        # `img.mode is not 'RGB'` identity check).
        with open(self._full_path(path), 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.samples)

    def _find_classes(self):
        """Return the sorted distinct labels and a label -> index mapping."""
        classes = sorted(set(self.y))
        class_to_idx = {cls: i for i, cls in enumerate(classes)}
        print("class labels: " + str(classes))
        return classes, class_to_idx


def load_transform_data(data_dir="images", img_resize=256, input_size=224, batch_size=32, num_workers=10):
    """Build ImageFolder datasets and shuffled DataLoaders for training.

    Expects ``data_dir/train`` and ``data_dir/val`` in the
    ``torchvision.datasets.ImageFolder`` layout. Training images are resized,
    randomly cropped and randomly flipped horizontally; validation images are
    resized and centre-cropped. Both are normalised with the ImageNet
    statistics.
    Source: https://discuss.pytorch.org/t/questions-about-imagefolder/774/6

    Parameters
    ----------
    data_dir : str (default is "images")
        Directory containing the ``train`` and ``val`` image folders
    img_resize : int (default is 256)
        Side length of the square resize applied before cropping
    input_size : int (default is 224)
        Side length of the square crop fed to the model
    batch_size : int (default is 32)
        Batch size
    num_workers : int (default is 10)
        Number of worker processes for each DataLoader

    Returns
    -------
    dict
        "train"/"val" -> DataLoader (both shuffled)
    dict
        "train"/"val" -> number of images in the split
    list
        Class names of the training split
    """
    splits = ["train", "val"]
    data_transforms = {
        "train": transforms.Compose([
            transforms.Resize((img_resize, img_resize)),
            transforms.RandomCrop((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
        ]),
        "val": transforms.Compose([
            transforms.Resize((img_resize, img_resize)),
            transforms.CenterCrop((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
        ]),
    }
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in splits
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        for x in splits
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in splits}
    class_names = image_datasets["train"].classes

    return dataloaders, dataset_sizes, class_names


def load_transform_data_for_test(data_dir="images", img_resize=256, input_size=224, batch_size=32, num_workers=10):
    """Build ImageFolder datasets and ordered DataLoaders for evaluation.

    Expects ``data_dir/train`` and ``data_dir/val`` in the
    ``torchvision.datasets.ImageFolder`` layout. Both splits use the same
    deterministic pipeline (resize, centre crop, ImageNet normalisation) and
    are NOT shuffled, so predictions stay aligned with the dataset order.
    Source: https://discuss.pytorch.org/t/questions-about-imagefolder/774/6

    Parameters
    ----------
    data_dir : str (default is "images")
        Directory containing the ``train`` and ``val`` image folders
    img_resize : int (default is 256)
        Side length of the square resize applied before cropping
    input_size : int (default is 224)
        Side length of the square centre crop fed to the model
    batch_size : int (default is 32)
        Batch size
    num_workers : int (default is 10)
        Number of worker processes for each DataLoader

    Returns
    -------
    dict
        "train"/"val" -> DataLoader (not shuffled)
    dict
        "train"/"val" -> number of images in the split
    list
        Class names of the training split
    """
    splits = ["train", "val"]
    eval_transform = transforms.Compose([
        transforms.Resize((img_resize, img_resize)),
        transforms.CenterCrop((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
    ])
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), eval_transform)
        for x in splits
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=False,  # keep the image order fixed while testing and saving predictions
            num_workers=num_workers,
        )
        for x in splits
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in splits}
    class_names = image_datasets["train"].classes

    return dataloaders, dataset_sizes, class_names


def load_transform_data_from_file(file_path, root_dir, img_resize, input_size, sep='\t', batch_size=32, num_workers=10):
    """Build shuffled training DataLoaders for splits listed in text files.

    Reads ``file_path + ".train"`` and ``file_path + ".val"`` with
    ``MyDataset``. Training images are resized, randomly cropped and randomly
    flipped horizontally; validation images are resized and centre-cropped.
    Both are normalised with the ImageNet statistics.

    Parameters
    ----------
    file_path : str
        Common prefix of the split files, without the ".train"/".val"
        extension; image paths inside are relative to ``root_dir``
    root_dir : str
        Path to the root directory of images
    img_resize : int
        Side length of the square resize applied before cropping
    input_size : int
        Side length of the square crop fed to the model
    sep : str (default is '\\t')
        Column separator used in the split files
    batch_size : int (default is 32)
        Batch size
    num_workers : int (default is 10)
        Number of worker processes for each DataLoader

    Returns
    -------
    dict
        "train"/"val" -> DataLoader (both shuffled)
    dict
        "train"/"val" -> number of samples in the split
    list
        Class names of the training split
    """
    splits = ["train", "val"]
    data_transforms = {
        "train": transforms.Compose([
            transforms.Resize((img_resize, img_resize)),
            transforms.RandomCrop((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
        ]),
        "val": transforms.Compose([
            transforms.Resize((img_resize, img_resize)),
            transforms.CenterCrop((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
        ]),
    }
    image_datasets = {
        x: MyDataset(file_path + "." + x, sep, root_dir, data_transforms[x])
        for x in splits
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        for x in splits
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in splits}
    class_names = image_datasets["train"].classes

    return dataloaders, dataset_sizes, class_names


def load_transform_data_from_file_for_test(file_path, root_dir, img_resize, input_size, sep='\t', batch_size=32, num_workers=10):
    """Build ordered evaluation DataLoaders for splits listed in text files.

    Reads ``file_path + ".train"`` and ``file_path + ".val"`` with
    ``MyDataset``. Both splits use the same deterministic pipeline (resize,
    centre crop, ImageNet normalisation) and are NOT shuffled, so predictions
    line up with the order of the lines in the split files.

    Parameters
    ----------
    file_path : str
        Common prefix of the split files, without the ".train"/".val"
        extension; image paths inside are relative to ``root_dir``
    root_dir : str
        Path to the root directory of images
    img_resize : int
        Side length of the square resize applied before cropping
    input_size : int
        Side length of the square centre crop fed to the model
    sep : str (default is '\\t')
        Column separator used in the split files
    batch_size : int (default is 32)
        Batch size
    num_workers : int (default is 10)
        Number of worker processes for each DataLoader

    Returns
    -------
    dict
        "train"/"val" -> DataLoader (not shuffled)
    dict
        "train"/"val" -> number of samples in the split
    list
        Class names of the training split
    """
    splits = ["train", "val"]
    # One deterministic pipeline shared by both splits (uses CenterCrop; the
    # misspelt ``CentorCrop`` does not exist in torchvision).
    eval_transform = transforms.Compose([
        transforms.Resize((img_resize, img_resize)),
        transforms.CenterCrop((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMGNET_MEAN, IMGNET_STD)
    ])
    image_datasets = {
        x: MyDataset(file_path + "." + x, sep, root_dir, eval_transform)
        for x in splits
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(
            image_datasets[x],
            batch_size=batch_size,
            shuffle=False,  # keep the image order fixed while testing and saving predictions
            num_workers=num_workers,
        )
        for x in splits
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in splits}
    class_names = image_datasets["train"].classes

    return dataloaders, dataset_sizes, class_names


def imshow(inp, title=None, size=(20, 20)):
    """Display an ImageNet-normalised image tensor in a new matplotlib figure.

    The normalisation is undone (``x * std + mean``) and the result is clipped
    to [0, 1] before drawing.

    Parameters
    ----------
    inp : torch.Tensor
        Image tensor in (C, H, W) layout with 3 channels; may be on any
        device and may require grad
    title : str (default is None)
        Title of the image
    size : tuple (default is (20, 20))
        Size of image: (width, height)

    """
    plt.figure(figsize=size)
    # detach()/cpu() let GPU tensors and tensors that require grad be
    # converted; matplotlib wants (H, W, C).
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(IMGNET_STD) * img + np.asarray(IMGNET_MEAN)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)


def save_plot(fig_dir, file_name,is_best,losses, accuracies):
    """Save a plot of the train/val loss and accuracy curves vs. epoch.

    Losses are drawn on the left axis (blue), accuracies on a twin right axis
    (green); dashed lines are training, solid lines validation. The figure is
    written to ``fig_dir/<file_name>.png`` (300 dpi), then shown and closed.

    Parameters
    ----------
    fig_dir : str
        Directory the figure is written to (created if missing)
    file_name : str
        Figure file name without the ".png" extension
    is_best : bool
        Currently unused; kept for interface compatibility
    losses : dict
        "train"/"val" -> list of per-epoch loss values
    accuracies : dict
        "train"/"val" -> list of per-epoch accuracies (0-d tensors, numpy
        scalars or floats)
    """
    from mpl_toolkits.axes_grid1 import host_subplot

    # float() accepts 0-d tensors on any device as well as plain numbers.
    val_acc = [float(h) for h in accuracies["val"]]
    val_loss = [float(h) for h in losses["val"]]
    train_acc = [float(h) for h in accuracies["train"]]
    train_loss = [float(h) for h in losses["train"]]

    epochs = range(1, len(val_acc) + 1)

    plt.figure()
    host = host_subplot(111)
    par1 = host.twinx()

    host.set_xlabel("epoch")
    host.set_ylabel("loss")
    par1.set_ylabel("accuracy")

    p1, = host.plot(epochs, train_loss, 'b--', label="training loss")
    host.plot(epochs, val_loss, 'b', label="validation loss")
    p3, = par1.plot(epochs, train_acc, 'g--', label="training accuracy")
    par1.plot(epochs, val_acc, 'g', label="validation accuracy")

    host.legend(loc='upper center', bbox_to_anchor=(0.5, 1.125), fancybox=True, shadow=True, ncol=4)

    host.axis["left"].label.set_color(p1.get_color())
    par1.axis["right"].label.set_color(p3.get_color())

    plt.draw()

    os.makedirs(fig_dir, exist_ok=True)
    out_file = os.path.join(fig_dir, file_name + ".png")
    print(out_file)
    # Save BEFORE show(): on interactive backends show() blocks and the
    # figure can be gone once its window closes, so a later savefig would
    # write an empty canvas.
    plt.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()



def save_checkpoint(state, is_best, filename, checkpoint_dir):
    """Write ``state`` to ``checkpoint_dir/filename`` with ``torch.save``.

    When ``is_best`` is true, the file just written is also copied to
    ``checkpoint_dir/best_state.pth``. The directory is created on demand.

    Parameters
    ----------
    state : dict
        Object to serialise (e.g. weights, optimizer state, history)
    is_best : boolean
        Whether this checkpoint should also become ``best_state.pth``
    filename : str
        File name of the checkpoint inside ``checkpoint_dir``
    checkpoint_dir : str
        Directory the checkpoint files are written to
    """
    target_path = os.path.join(checkpoint_dir, filename)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    torch.save(state, target_path)
    if not is_best:
        return
    shutil.copyfile(target_path, os.path.join(checkpoint_dir, "best_state.pth"))


def save_best_checkpoint(state, is_best, filename, checkpoint_dir):
    """Write ``state`` to ``checkpoint_dir/best_state.pth`` only when ``is_best``.

    Nothing is written (and no directory is created) otherwise.

    Parameters
    ----------
    state : dict
        Object to serialise with ``torch.save``
    is_best : boolean
        Whether the current model is the best one so far
    filename : str
        Currently unused; the output name is always ``best_state.pth``
    checkpoint_dir : str
        Directory the checkpoint is written to (created if missing)
    """
    if not is_best:
        return
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    best_path = os.path.join(checkpoint_dir, "best_state.pth")
    torch.save(state, best_path)


def load_checkpoint(best_state_path, model=None, optimizer=None, scheduler=None, epoch=0, evals=(None, None), device="cpu"):
    """Restore model/optimizer state and training history from a checkpoint.

    Parameters
    ----------
    best_state_path : str
        Path of a checkpoint dict with the keys 'state_dict', 'optimizer',
        'epoch', 'losses' and 'accuracies' (the layout written by
        ``train_model`` in this module)
    model : torch.nn.Module, optional
        Receives the saved weights. A DataParallel model is loaded through its
        ``.module`` when the saved keys lack the "module." prefix (this
        module's ``train_model`` saves unwrapped weights).
    optimizer : torch.optim.Optimizer, optional
        Receives the saved optimizer state
    scheduler : optional
        Unused; the checkpoint holds no scheduler state. Kept for
        compatibility.
    epoch : int (default is 0)
        Returned unchanged when no checkpoint exists
    evals : tuple (default is (None, None))
        Returned unchanged when no checkpoint exists
    device : str or torch.device (default is "cpu")
        ``map_location`` passed to ``torch.load``

    Returns
    -------
    tuple
        (model, optimizer, epoch, evals). After a successful load, epoch is
        the saved epoch + 1 (the next epoch to run) and evals is
        (losses, accuracies).
    """
    if not os.path.isfile(best_state_path):
        print("No checkpoint found.",flush=True)
        return model, optimizer, epoch, evals

    checkpoint = torch.load(best_state_path, map_location=device)

    if model is not None:
        state_dict = checkpoint['state_dict']
        wrapped = isinstance(model, nn.DataParallel)
        if wrapped and not any(k.startswith('module.') for k in state_dict):
            model.module.load_state_dict(state_dict)
        else:
            model.load_state_dict(state_dict)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    epoch = checkpoint['epoch']
    evals = (checkpoint['losses'], checkpoint['accuracies'])
    print("Loaded checkpoint '{}' (epoch {}) successfully.".format(best_state_path, epoch),flush=True)
    return model, optimizer, epoch + 1, evals

######################################################################
# Model Training and Validation Code
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The ``train_model`` function handles the training and validation of a
# given model. As input, it takes a PyTorch model, a dictionary of
# dataloaders, a loss function, an optimizer, a learning-rate scheduler, a
# specified number of epochs to train and validate for, and a boolean flag
# for when the model is an Inception model. The *is_inception* flag is used
# to accommodate the
# *Inception v3* model, as that architecture uses an auxiliary output and
# the overall model loss respects both the auxiliary output and the final
# output, as described
# `here <https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958>`__.
# The function trains for the specified number of epochs and after each
# epoch runs a full validation step. It also keeps track of the best
# performing model (in terms of validation accuracy), and at the end of
# training returns the best performing model. After each epoch, the
# training and validation accuracies are printed.
#
def train_model(model, dataloaders, criterion, optimizer, scheduler, device="cuda", num_epochs=25, curr_epoch=0, curr_evals=(None,None), class_names=None, fig_dir="output/figures", checkpoint_dir="output/models", train_file="train",is_inception=False):
    """Train and validate a model, keeping the weights with the best val accuracy.

    Each epoch runs a training pass over ``dataloaders["train"]`` and then a
    validation pass over ``dataloaders["val"]``. After every validation pass
    the checkpoint is saved via ``save_best_checkpoint`` (written only when
    validation accuracy improved), the curves are saved via ``save_plot``, and
    the scheduler is stepped.

    Parameters
    ----------
    model
        The pretrained model to be fine-tuned (may be wrapped in DataParallel)
    dataloaders : dict
        "train"/"val" -> DataLoader yielding (inputs, labels)
    criterion
        Loss function, e.g. cross entropy loss
    optimizer
        Optimization algorithm, e.g. SGD, Adam
    scheduler
        Learning rate scheduler. ReduceLROnPlateau is stepped with the
        validation loss, any other scheduler without arguments; None skips
        stepping.
    device : str (default is "cuda")
        Device inputs and labels are moved to
    num_epochs : int (default is 25)
        Epochs ``curr_epoch`` .. ``num_epochs - 1`` are run
    curr_epoch : int (default is 0)
        First epoch to run, e.g. the value returned by ``load_checkpoint``
    curr_evals : tuple (default is (None, None))
        (losses, accuracies) history dicts keyed by phase, to be continued
    class_names : list, optional
        Stored in the checkpoint
    fig_dir : str (default is "output/figures")
        Directory for the training-curve plot
    checkpoint_dir : str (default is "output/models")
        Directory for the best checkpoint
    train_file : str (default is "train")
        Base name used for the plot file and the checkpoint filename
    is_inception : bool (default is False)
        If True, in training mode the model returns (outputs, aux_outputs)
        and the loss is ``loss(outputs) + 0.4 * loss(aux_outputs)``

    Returns
    -------
    model
        The model loaded with the weights of the best validation epoch
    """

    since = time.time()

    phases = ['train', 'val']
    losses, accuracies = curr_evals
    if not losses:
        losses = {phase: [] for phase in phases}
    if not accuracies:
        accuracies = {phase: [] for phase in phases}

    best_model_wts = copy.deepcopy(model.state_dict())
    # When resuming, start from the best validation accuracy already recorded
    # so a worse epoch cannot overwrite the existing best checkpoint.
    best_acc = max((float(a) for a in accuracies.get('val', [])), default=0.0)

    for epoch in range(curr_epoch, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1),flush=True)
        print('-' * 10,flush=True)

        # Each epoch has a training and validation phase
        for phase in phases:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            loader = dataloaders[phase]
            desc = "[{}/{}] {} Iteration".format(epoch, num_epochs - 1, phase.upper())
            for inputs, labels in tqdm(loader, total=len(loader), desc=desc, ncols=10):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # track history only in train
                with torch.set_grad_enabled(is_train):
                    # Inception has an auxiliary output in training mode; its
                    # loss is added with weight 0.4. In eval only the final
                    # output is used.
                    # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                    if is_inception and is_train:
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4 * loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics (loss is a batch mean, so weight by batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            epoch_loss = running_loss / len(loader.dataset)
            epoch_acc = running_corrects.double() / len(loader.dataset)

            accuracies[phase].append(epoch_acc)
            losses[phase].append(epoch_loss)

            # print progress
            learning_rate = optimizer.param_groups[0]["lr"]
            print(' Loss: {:.4f} Acc: {:.4f} lr: {:.4E}'.format(epoch_loss, epoch_acc, Decimal(learning_rate)),flush=True)

            if phase == 'val':
                # Save unwrapped weights so they load into a bare model too.
                if isinstance(model,torch.nn.DataParallel):
                    model_state_dict = model.module.state_dict()
                else:
                    model_state_dict = model.state_dict()

                is_best = epoch_acc > best_acc
                if is_best:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

                state = {
                    "epoch": epoch,
                    "lr": learning_rate,
                    "state_dict": model_state_dict,
                    "optimizer": optimizer.state_dict(),
                    "losses": losses,
                    "accuracies": accuracies,
                    "class_names": class_names
                }

                filename = "{}_{:.3f}_best_state.pth".format(train_file, epoch_acc)

                # save model checkpoint (only written when is_best)
                save_best_checkpoint(state, is_best, filename, checkpoint_dir)

                # save progress plot
                save_plot(fig_dir, train_file, is_best, losses, accuracies)

                if isinstance(scheduler,optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(epoch_loss)
                elif scheduler is not None:
                    scheduler.step()

        print(flush=True)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60),flush=True)
    print('Best Val Accuracy: {:.4f}'.format(best_acc),flush=True)

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


def test_model(model, dataloader, target_labels=None, target_names=None, device="cuda"):
    """Evaluate a classifier on one DataLoader and print its metrics.

    Prints accuracy, scikit-learn's classification report and confusion
    matrix, and a tab-separated summary line.

    Parameters
    ----------
    model
        Model returning one score per class (argmax is the prediction)
    dataloader
        DataLoader yielding (inputs, integer labels)
    target_labels : list, optional
        ``labels`` argument for the report and confusion matrix
    target_names : list, optional
        ``target_names`` argument for the classification report
    device : str (default is "cuda")
        Device inputs are moved to

    Returns
    -------
    str
        The printed summary line: accuracy, weighted precision, weighted
        recall and weighted F1 (3 decimals), tab-separated, ending in "\\n"
    """
    model.eval()

    preds_ = []
    labels_ = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader), desc="Progress", ncols=10):
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)

            preds_.extend(preds.cpu().numpy().tolist())
            labels_.extend(labels.cpu().numpy().tolist())

    acc = metrics.accuracy_score(labels_, preds_)

    print('-' * 10,flush=True)
    print('Performance results:',flush=True)
    print('Accuracy: {:.4f}'.format(acc),flush=True)
    print(metrics.classification_report(labels_, preds_, labels=target_labels, target_names=target_names, digits=3),flush=True)
    print(metrics.confusion_matrix(labels_, preds_, labels=target_labels),flush=True)
    print('-' * 10,flush=True)
    print(flush=True)

    precision = metrics.precision_score(labels_, preds_, average="weighted")
    recall = metrics.recall_score(labels_, preds_, average="weighted")
    f1_score = metrics.f1_score(labels_, preds_, average="weighted")

    result = "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}\n".format(acc, precision, recall, f1_score)
    print(result)
    return result




def compute_aggregate_scores(all_labels, all_predictions, all_classes):
    """Compute per-class and averaged precision/recall/F1 for one evaluation.

    Parameters
    ----------
    all_labels : sequence
        Gold labels
    all_predictions : sequence
        Predicted labels, aligned with ``all_labels``
    all_classes : sequence or None
        Labels to score, in order (scikit-learn's ``labels`` argument)

    Returns
    -------
    dict
        "prf_per_class": (precision, recall, f1) arrays, one entry per class;
        "prf_per_class_labels": ``all_classes``;
        "prf_micro" / "prf_macro" / "prf_weighted": (precision, recall, f1)
        averaged with the corresponding strategy
    """
    def prf(average):
        # Drop the 4th element (support); it is None for averaged scores.
        return precision_recall_fscore_support(
            all_labels, all_predictions, labels=all_classes, average=average
        )[:-1]

    return {
        "prf_per_class": prf(None),
        "prf_per_class_labels": all_classes,
        "prf_micro": prf('micro'),
        "prf_macro": prf('macro'),
        "prf_weighted": prf('weighted'),
    }

def convert_to_json(obj):
    """Return a JSON-serialisable copy of a (possibly nested) dictionary.

    Nested dicts are converted recursively. Every other value goes through
    ``np.array(value).tolist()``, so numpy arrays/scalars, tuples and lists
    become plain Python lists and scalars. Ragged sequences, which recent
    numpy versions refuse to turn into an array, are converted element by
    element instead.

    Parameters
    ----------
    obj : dict
        Dictionary to convert; it is not modified

    Returns
    -------
    dict
        New dictionary with the same keys and converted values
    """
    return {key: _json_value(value) for key, value in obj.items()}


def _json_value(value):
    """Convert a single value stored in a dict given to ``convert_to_json``."""
    if isinstance(value, dict):
        return convert_to_json(value)
    try:
        return np.array(value).tolist()
    except ValueError:
        # Ragged nesting such as [[1, 2], [3]]: convert each item separately.
        return [_json_value(item) for item in value]

def test_model2(model, dataloaders, class_names, target_labels=None, target_names=None, device="cuda"):
    """Evaluate a classifier on the "train" and "val" splits.

    For each split the accuracy, classification report and confusion matrix
    are printed, and a summary dict is collected.

    Parameters
    ----------
    model
        Model returning one score per class (argmax is the prediction)
    dataloaders : dict
        "train"/"val" -> DataLoader yielding (inputs, integer labels)
    class_names : list
        Maps integer class indices to class names
    target_labels : list, optional
        ``labels`` argument for the report and confusion matrix
    target_names : list, optional
        ``target_names`` for the report; also the class list used by
        ``compute_aggregate_scores``
    device : str (default is "cuda")
        Device inputs are moved to

    Returns
    -------
    dict
        "train"/"val" -> dict with the keys "accuracy", "results" (tab-separated
        acc/weighted P/R/F1 line), "classification_report",
        "confusion_matrix", "conf_mat_str", "gold" and "pred" (class names),
        and "agr_met"
    """
    print("\nTesting the model on both data splits:",flush=True)

    model.eval()

    all_phases = {}
    for phase in ['train', 'val']:
        print("\n\n"+phase+"\n\n")
        preds_ = []
        labels_ = []

        loader = dataloaders[phase]
        with torch.no_grad():
            for inputs, labels in tqdm(loader, total=len(loader), desc="{} Iteration".format(phase.upper()), ncols=10):
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)

                preds_.extend(preds.cpu().numpy().tolist())
                labels_.extend(labels.cpu().numpy().tolist())

        acc = metrics.accuracy_score(labels_, preds_)
        class_report = metrics.classification_report(labels_, preds_, labels=target_labels, target_names=target_names, digits=3)
        conf_mat = metrics.confusion_matrix(labels_, preds_, labels=target_labels)

        print('-' * 10,flush=True)
        print('Performance on {} set:'.format(phase.upper()),flush=True)
        print('Accuracy: {:.4f}'.format(acc),flush=True)
        print(class_report,flush=True)
        print(conf_mat,flush=True)
        print('-' * 10,flush=True)
        print(flush=True)

        # Integer indices -> class names for the text confusion table and
        # the per-class aggregate scores.
        label_y = [class_names[i] for i in labels_]
        label_pred = [class_names[i] for i in preds_]

        precision = metrics.precision_score(labels_, preds_, average="weighted")
        recall = metrics.recall_score(labels_, preds_, average="weighted")
        f1_score = metrics.f1_score(labels_, preds_, average="weighted")

        result = "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}\n".format(acc, precision, recall, f1_score)
        print(result)

        all_phases[phase] = {
            "accuracy": acc,
            "results": result,
            "classification_report": class_report,
            "confusion_matrix": conf_mat,
            "conf_mat_str": format_conf_mat(label_y, label_pred),
            "gold": label_y,
            "pred": label_pred,
            "agr_met": compute_aggregate_scores(label_y, label_pred, target_names),
        }

    return all_phases


def format_conf_mat(y_true,y_pred):
    """Render a gold-vs-predicted contingency table as tab-separated text.

    The table is ``pd.crosstab`` of the two label sequences with "All"
    margins. The first output line starts with the two-line corner cell
    "Pred"/"Gold" followed by the predicted-label columns. Each further line
    holds a gold label and its counts. Every line ends with a newline.
    """
    table = pd.crosstab(np.array(y_true), np.array(y_pred),
                        rownames=['gold'], colnames=['pred'], margins=True)

    header_cells = ["Pred\nGold"] + [str(col) for col in table.columns.tolist()]
    out_lines = ["\t".join(header_cells)]

    gold_labels = table.index.tolist()
    for gold_label, (_, counts) in zip(gold_labels, table.iterrows()):
        out_lines.append("\t".join([str(gold_label)] + [str(v) for v in counts]))

    return "".join(line + "\n" for line in out_lines)

def test_model_save_results(outfile, sep, model, dataloader, test_images, class_names, target_labels=None, target_names=None, device="cuda"):
    """Evaluate ``model`` on ``dataloader``, print metrics and save per-image predictions.

    Every line written to ``outfile`` is
    ``<image path><sep><predicted class name><sep><max softmax prob:.4f>``.
    Image paths are looked up positionally in ``test_images.samples``, so this
    assumes ``dataloader`` iterates ``test_images`` in order without
    shuffling -- TODO confirm in callers.

    Args:
        outfile: path of the prediction file to (over)write.
        sep: field separator used in ``outfile``.
        model: classifier returning one logit per class.
        dataloader: yields ``(inputs, labels)`` batches.
        test_images: dataset exposing ``samples`` whose items have the image
            path at index 0.
        class_names: maps a class index to its name.
        target_labels / target_names: forwarded to sklearn's
            ``classification_report`` / ``confusion_matrix``.
        device: device the inputs are moved to.

    Returns:
        dict with keys "accuracy", "results" (tab-separated
        acc/precision/recall/F1 string), "classification_report",
        "confusion_matrix", "conf_mat_str", "gold", "pred" (class names) and
        "agr_met".

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    print("\nTesting the model and saving the results:",flush=True)

    model.eval()

    probs_ = []
    preds_ = []
    labels_ = []

    start = datetime.now().replace(microsecond=0)
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader), desc="Test Progress", ncols=10):
            outputs = torch.nn.functional.softmax(model(inputs.to(device)), 1)
            probs, preds = torch.max(outputs, 1)
            preds_.extend(preds.cpu().tolist())
            probs_.extend(probs.cpu().tolist())
            labels_.extend(labels.cpu().tolist())
    end = datetime.now().replace(microsecond=0)
    print ("\ntime taken for inference:\t{}".format(str(end - start)))

    if not labels_:
        # Previously this surfaced as an opaque AttributeError / ZeroDivisionError.
        raise ValueError("dataloader yielded no samples; nothing to evaluate")

    overall_acc = sum(p == l for p, l in zip(preds_, labels_)) / len(labels_)
    print('\nOverall accuracy: {:.4f}'.format(overall_acc),flush=True)

    print('-' * 10,flush=True)
    print('Performance results:',flush=True)
    acc = metrics.accuracy_score(labels_, preds_)
    print('Accuracy: {:.4f}'.format(acc),flush=True)
    class_report = metrics.classification_report(labels_, preds_, labels=target_labels, target_names=target_names,digits=3)
    conf_mat = metrics.confusion_matrix(labels_, preds_, labels=target_labels)
    print(class_report,flush=True)
    print(conf_mat,flush=True)
    print('-' * 10,flush=True)
    print(flush=True)

    precision = metrics.precision_score(labels_, preds_, average="weighted")
    recall = metrics.recall_score(labels_, preds_, average="weighted")
    f1_score = metrics.f1_score(labels_, preds_, average="weighted")

    result = "\t".join("{0:.3f}".format(v) for v in (acc, precision, recall, f1_score)) + "\n"
    print(result)

    test_y = [class_names[label] for label in labels_]
    test_pred = [class_names[pred] for pred in preds_]
    with open(outfile,'w') as f:
        for i, (pred_name, prob) in enumerate(zip(test_pred, probs_)):
            f.write('{}{}{}{}{:.4f}\n'.format(test_images.samples[i][0], sep, pred_name, sep, prob))

    agr_met = compute_aggregate_scores(test_y, test_pred, target_names)
    conf_mat_str = format_conf_mat(test_y, test_pred)
    print(conf_mat_str)
    phase_object = {
        "accuracy": acc,
        "results": result,
        "classification_report": class_report,
        "confusion_matrix": conf_mat,
        "conf_mat_str": conf_mat_str,
        "gold":test_y,
        "pred": test_pred,
        "agr_met": agr_met
    }

    return phase_object

######################################################################
# Set Model Parameters’ .requires_grad attribute
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This helper function sets the ``.requires_grad`` attribute of the
# parameters in the model to False when we are feature extracting. By
# default, when we load a pretrained model all of the parameters have
# ``.requires_grad=True``, which is fine if we are training from scratch
# or finetuning. However, if we are feature extracting and only want to
# compute gradients for the newly initialized layer then we want all of
# the other parameters to not require gradients. This will make more sense
# later.
#

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched, so all
    parameters keep whatever ``requires_grad`` value they already had.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


######################################################################
# Initialize and Reshape the Networks
# -----------------------------------
#
# Now to the most interesting part. Here is where we handle the reshaping
# of each network. Note, this is not an automatic procedure and is unique
# to each model. Recall, the final layer of a CNN model, which is often
# times an FC layer, has the same number of nodes as the number of output
# classes in the dataset. Since all of the models have been pretrained on
# Imagenet, they all have output layers of size 1000, one node for each
# class. The goal here is to reshape the last layer to have the same
# number of inputs as before, AND to have the same number of outputs as
# the number of classes in the dataset. In the following sections we will
# discuss how to alter the architecture of each model individually. But
# first, there is one important detail regarding the difference between
# finetuning and feature-extraction.
#
# When feature extracting, we only want to update the parameters of the
# last layer, or in other words, we only want to update the parameters for
# the layer(s) we are reshaping. Therefore, we do not need to compute the
# gradients of the parameters that we are not changing, so for efficiency
# we set the .requires_grad attribute to False. This is important because
# by default, this attribute is set to True. Then, when we initialize the
# new layer and by default the new parameters have ``.requires_grad=True``
# so only the new layer’s parameters will be updated. When we are
# finetuning we can leave all of the ``.requires_grad`` attributes set to
# the default of True.
#
# Finally, notice that inception_v3 requires the input size to be
# (299,299), whereas all of the other models expect (224,224).
#
# Resnet
# ~~~~~~
#
# Resnet was introduced in the paper `Deep Residual Learning for Image
# Recognition <https://arxiv.org/abs/1512.03385>`__. There are several
# variants of different sizes, including Resnet18, Resnet34, Resnet50,
# Resnet101, and Resnet152, all of which are available from torchvision
# models. This module supports Resnet18, Resnet50 and Resnet101. When we
# print Resnet18, we see that the last layer is a fully connected layer as
# shown below:
#
# ::
#
#    (fc): Linear(in_features=512, out_features=1000, bias=True)
#
# Thus, we must reinitialize ``model.fc`` to be a Linear layer with 512
# input features and ``num_classes`` output features with:
#
# ::
#
#    model.fc = nn.Linear(512, num_classes)
#
# Alexnet
# ~~~~~~~
#
# Alexnet was introduced in the paper `ImageNet Classification with Deep
# Convolutional Neural
# Networks <https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`__
# and was the first very successful CNN on the ImageNet dataset. When we
# print the model architecture, we see the model output comes from the 6th
# layer of the classifier
#
# ::
#
#    (classifier): Sequential(
#        ...
#        (6): Linear(in_features=4096, out_features=1000, bias=True)
#     )
#
# To use the model with our dataset we reinitialize this layer as
#
# ::
#
#    model.classifier[6] = nn.Linear(4096,num_classes)
#
# VGG
# ~~~
#
# VGG was introduced in the paper `Very Deep Convolutional Networks for
# Large-Scale Image Recognition <https://arxiv.org/pdf/1409.1556.pdf>`__.
# Torchvision offers eight versions of VGG with various lengths and some
# that have batch normalization layers. This module supports VGG-11
# (``"vgg"``) and VGG-16 (``"vgg16"``), both with batch normalization. The
# output layer is similar to Alexnet, i.e.
#
# ::
#
#    (classifier): Sequential(
#        ...
#        (6): Linear(in_features=4096, out_features=1000, bias=True)
#     )
#
# Therefore, we use the same technique to modify the output layer
#
# ::
#
#    model.classifier[6] = nn.Linear(4096,num_classes)
#
# Squeezenet
# ~~~~~~~~~~
#
# The SqueezeNet architecture is described in the paper `SqueezeNet:
# AlexNet-level accuracy with 50x fewer parameters and <0.5MB model
# size <https://arxiv.org/abs/1602.07360>`__ and uses a different output
# structure than any of the other models shown here. Torchvision has two
# versions of Squeezenet, we use version 1.0. The output comes from a 1x1
# convolutional layer which is the 1st layer of the classifier:
#
# ::
#
#    (classifier): Sequential(
#        (0): Dropout(p=0.5)
#        (1): Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
#        (2): ReLU(inplace)
#        (3): AvgPool2d(kernel_size=13, stride=1, padding=0)
#     )
#
# To modify the network, we reinitialize the Conv2d layer to have an
# output feature map of depth ``num_classes`` as
#
# ::
#
#    model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
#
# Densenet
# ~~~~~~~~
#
# Densenet was introduced in the paper `Densely Connected Convolutional
# Networks <https://arxiv.org/abs/1608.06993>`__. Torchvision has four
# variants of Densenet but here we only use Densenet-121. The output layer
# is a linear layer with 1024 input features:
#
# ::
#
#    (classifier): Linear(in_features=1024, out_features=1000, bias=True)
#
# To reshape the network, we reinitialize the classifier’s linear layer as
#
# ::
#
#    model.classifier = nn.Linear(1024, num_classes)
#
# Inception v3
# ~~~~~~~~~~~~
#
# Finally, Inception v3 was first described in `Rethinking the Inception
# Architecture for Computer
# Vision <https://arxiv.org/pdf/1512.00567v1.pdf>`__. This network is
# unique because it has two output layers when training. The second output
# is known as an auxiliary output and is contained in the AuxLogits part
# of the network. The primary output is a linear layer at the end of the
# network. Note, when testing we only consider the primary output. The
# auxiliary output and primary output of the loaded model are printed as:
#
# ::
#
#    (AuxLogits): InceptionAux(
#        ...
#        (fc): Linear(in_features=768, out_features=1000, bias=True)
#     )
#     ...
#    (fc): Linear(in_features=2048, out_features=1000, bias=True)
#
# To finetune this model we must reshape both layers. This is accomplished
# with the following
#
# ::
#
#    model.AuxLogits.fc = nn.Linear(768, num_classes)
#    model.fc = nn.Linear(2048, num_classes)
#
# Notice, many of the models have similar output structures, but each must
# be handled slightly differently. Also, check out the printed model
# architecture of the reshaped network and make sure the number of output
# features is the same as the number of classes in the dataset.
#

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a classification network whose output layer has ``num_classes`` units.

    Args:
        model_name: one of "resnet18", "resnet50", "resnet101", "alexnet",
            "vgg" (VGG11_bn), "vgg16" (VGG16_bn), "squeezenet", "densenet"
            (DenseNet-121), "inception" (Inception v3), "mobilenet_v2",
            "efficientnet-b1", "efficientnet-b7".
        num_classes: number of outputs of the freshly created head.
        feature_extract: if True, freeze all existing parameters before the
            new head is created, so only the head remains trainable.
        use_pretrained: passed as ``pretrained=`` to the torchvision builders.
            The EfficientNet variants ignore it and always load pretrained weights.

    Returns:
        tuple ``(model, img_resize, input_size)``. The sizes are 256/224 for
        every model except Inception v3, which uses 299/299.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    input_size = 224
    img_resize = 256

    # Backbones whose output layer is ``model.fc``.
    fc_models = {
        "resnet18": models.resnet18,
        "resnet50": models.resnet50,
        "resnet101": models.resnet101,
    }
    # Backbones whose output layer is ``model.classifier[6]``.
    classifier6_models = {
        "alexnet": models.alexnet,
        "vgg": models.vgg11_bn,
        "vgg16": models.vgg16_bn,
    }

    if model_name in fc_models:
        model = fc_models[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif model_name in classifier6_models:
        model = classifier6_models[model_name](pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        model = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        # SqueezeNet's output layer is a 1x1 convolution, not a Linear layer.
        model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model.num_classes = num_classes

    elif model_name == "densenet":
        model = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    elif model_name == "inception":
        # Inception v3 expects 299x299 inputs and has an auxiliary classifier
        # whose head must be resized as well as the primary one.
        model = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        input_size = 299
        img_resize = 299

    elif model_name == "mobilenet_v2":
        model = models.mobilenet_v2(pretrained=use_pretrained)
        set_parameter_requires_grad(model, feature_extract)
        # MobileNetV2 has no ``fc``: its output layer is ``classifier[1]``
        # (``classifier[0]`` is Dropout). Assigning ``model.fc`` would attach a
        # module forward() never calls and leave the 1000-way head in place.
        model.classifier[1] = nn.Linear(model.last_channel, num_classes)

    elif model_name in ("efficientnet-b1", "efficientnet-b7"):
        # NOTE(review): use_pretrained is not honoured here; weights are always loaded.
        model = EfficientNet.from_pretrained(model_name, num_classes=num_classes)
        if feature_extract:
            # from_pretrained already built a fresh head (efficientnet_pytorch
            # names it ``_fc``); freeze the backbone and keep the head trainable.
            set_parameter_requires_grad(model, True)
            for param in model._fc.parameters():
                param.requires_grad = True

    else:
        # exit() used to terminate with status 0, signalling success to the shell.
        raise ValueError("Invalid model name: {!r}".format(model_name))

    return model, img_resize, input_size


def feature_extract(outfile, sep, model, dataloader, test_images, class_names, target_labels=None,
                            target_names=None, device="cuda"):
    """Write the truncated network's features for every sample to a tab-separated file.

    The network is truncated by dropping its last top-level child
    (``list(model.children())[:-1]``), and the remaining children are applied in
    order. NOTE(review): for models whose forward() does more than call their
    children in sequence, this may not reproduce the true penultimate
    features. Confirm for the architecture in use.

    Output (always tab-separated, regardless of ``sep``):
        header: ``F0 .. F{d-1}`` then ``class_label``
        rows:   the ``d`` flattened feature values, then the class name
    ``d`` is taken from the first batch's flattened feature size. Previously
    it was hard-coded to 512, which mislabels columns for other backbones.

    Args:
        outfile: path of the feature file to (over)write.
        sep, test_images, target_labels, target_names: accepted for signature
            compatibility; unused.
        model: network to truncate.
        dataloader: yields ``(inputs, labels)`` batches.
        class_names: maps a class index to its name.
        device: device the inputs are moved to.
    """
    print("\nTesting the model and saving the results:", flush=True)

    model.eval()
    # Built once rather than per batch: wraps the same child modules every time.
    feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])

    header_written = False
    with open(outfile, "w") as out_file, torch.no_grad():
        for inputs, labels in tqdm(dataloader, total=len(dataloader), desc="Test Progress", ncols=10):
            output = feature_extractor(inputs.to(device))
            # One row per sample: flatten every non-batch dimension.
            rows = torch.flatten(output, start_dim=1).cpu().tolist()

            if not header_written:
                num_features = len(rows[0]) if rows else 0
                header = "\t".join("F" + str(i) for i in range(num_features))
                out_file.write(header + "\tclass_label\n")
                header_written = True

            for vec, label in zip(rows, labels.cpu().tolist()):
                out_file.write("\t".join(str(v) for v in vec) + "\t" + class_names[label] + "\n")

        if not header_written:
            # Empty dataloader: feature count is unknown, emit only the label column.
            out_file.write("class_label\n")
