import torch
import numpy as np
import os
from transformers import BertModel
from utils.path_utils import get_output_path
import random
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from hfdataset import HFDataset
from head_layer import HeadLayer
from trainers import HeadLayerTrainer
from evaluators import ClassifierEvaluator
from tqdm import tqdm
import time
import pickle
from utils.logging_utils import get_logger
from utils.load_model_utils import load_model
from utils.model_utils import create_base_model, get_pooled_output
from config import PREDICTORS_BASE_DIR, PREDICTOR_EVAL_BASE_DIR, TRANSFORMATION_NETS_FILEPATH, GPU_DEVICE,\
    PREDICTOR_OPTIONAL_LAYER_DIMS, TRANSFORMATION_NETS_OPTIONAL_LAYER_DIMS

# Module-level loggers for this file. The second argument to get_logger appears to be the
# log file name (assumption — confirm against utils.logging_utils.get_logger).
# NOTE(review): both loggers are requested with the same name (__name__). If get_logger wraps
# logging.getLogger(name), both names resolve to the SAME logger object carrying two handlers,
# so training and validation records would land in both files — confirm and consider
# distinct names such as f'{__name__}.training' / f'{__name__}.validation'.
# Earlier per-directory log locations, kept for reference:
# train_logger = get_logger(__name__, os.path.join(LOGGING_PREDICTORS_DIR, 'training.log'))
# validate_logger = get_logger(__name__, os.path.join(LOGGING_PREDICTORS_DIR, 'validation.log'))
train_logger = get_logger(__name__, 'predictor_training.log')
validate_logger = get_logger(__name__, 'predictor_validation.log')


def _set_seeds(seed):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) global RNGs with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def train_predictors(source_dataset_names,
                     target_dataset_name,
                     num_epochs=20,
                     batch_size=32,
                     weight_decay=0.01,
                     classifier_optional_layer_dims=PREDICTOR_OPTIONAL_LAYER_DIMS,
                     transformation_nets_optional_layer_dims=TRANSFORMATION_NETS_OPTIONAL_LAYER_DIMS,
                     num_train_samples=None,
                     num_source_samples=None,
                     overwrite=False,
                     seed=42):
    """Train the standard and the transferred predictor heads for one target dataset.

    A "standard" ``HeadLayer`` is trained on pooled base-model embeddings of the target's train
    split. For every source dataset an additional ``HeadLayer`` is trained whose trainer is given
    that source's transformation network, loaded from disk. The base model is used as a frozen
    feature extractor: each batch is embedded once under ``torch.no_grad`` and the embeddings are
    shared by every head.

    Args:
        source_dataset_names: A single source dataset name, or an iterable of names.
        target_dataset_name: Dataset whose ``'train'`` split the heads are trained on.
        num_epochs: Number of passes over the training data.
        batch_size: Batch size of the training ``DataLoader``.
        weight_decay: Forwarded to every ``HeadLayerTrainer``.
        classifier_optional_layer_dims: ``optional_layer_dims`` of every ``HeadLayer``; also part
            of the output path.
        transformation_nets_optional_layer_dims: Selects the directory the transformation
            networks are loaded from.
        num_train_samples: Forwarded to ``HFDataset`` as ``max_num_examples``; part of the
            output path.
        num_source_samples: Selects the transformation-network directory; part of the output path.
        overwrite: When False, source heads whose ``.pt`` file already exists are skipped, and the
            function returns immediately if the standard head and every source head exist. Note
            that the standard head is still retrained (and its file rewritten) whenever at least
            one source head is missing.
        seed: Forwarded to ``HFDataset``, used to seed the global RNGs (head initialisation,
            batch order, training), and part of the output path.

    Side effects:
        Writes ``state_dict``s to ``<output dir>/<target_dataset_name>.pt`` and
        ``<output dir>/<source_dataset_name>.pt``.
    """
    if isinstance(source_dataset_names, str):
        source_dataset_names = [source_dataset_names]

    # NOTE: num_epochs is deliberately not part of the output path, so runs that differ only in
    # num_epochs resolve to the same files.
    classifier_filepath = get_output_path(PREDICTORS_BASE_DIR,
                                          optional_layers=classifier_optional_layer_dims,
                                          num_train_samples=num_train_samples,
                                          num_source_samples=num_source_samples,
                                          seed=seed,
                                          target_name=target_dataset_name)

    standard_classifier_filepath = os.path.join(classifier_filepath, f'{target_dataset_name}.pt')
    source_classifier_filepaths = {
        source_dataset_name: os.path.join(classifier_filepath, f'{source_dataset_name}.pt')
        for source_dataset_name in source_dataset_names}

    if not overwrite:
        # Only train the source heads that are missing on disk; bail out if nothing is missing.
        source_dataset_names = [name for name in source_dataset_names
                                if not os.path.isfile(source_classifier_filepaths[name])]
        if os.path.isfile(standard_classifier_filepath) and not source_dataset_names:
            train_logger.info('Standard classifier and all transformed classifiers were found on disk.')
            return

    dataset = HFDataset(target_dataset_name, split='train', max_num_examples=num_train_samples, seed=seed)
    label_dim = dataset.label_dim

    device = torch.device(GPU_DEVICE) if torch.cuda.is_available() else torch.device("cpu")
    base_model = create_base_model()
    base_model.to(device)
    # The base model is only a frozen feature extractor here (no gradients flow into it), so run
    # it in eval mode, exactly as validate_predictors does, to get the same embeddings at train
    # and validation time.
    base_model.eval()

    # Seed with the caller's seed *before* the heads are built so that head initialisation,
    # shuffling and training are reproducible and actually vary with ``seed``.
    _set_seeds(seed)

    train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size,
                                  collate_fn=dataset.collate_fn)
    num_batches = len(train_dataloader)
    total_steps = num_batches * num_epochs

    standard_classifier = HeadLayer(output_dim=label_dim, optional_layer_dims=classifier_optional_layer_dims)
    standard_trainer = HeadLayerTrainer(model=standard_classifier,
                                        output_dim=label_dim,
                                        num_train_steps=total_steps,
                                        weight_decay=weight_decay)

    transferred_trainers = {}
    for dataset_name in source_dataset_names:
        transferred_classifier = HeadLayer(output_dim=label_dim,
                                           optional_layer_dims=classifier_optional_layer_dims)
        transformation_net_filepath = os.path.join(get_output_path(TRANSFORMATION_NETS_FILEPATH,
                                                                   num_train_samples=num_source_samples,
                                                                   optional_layers=transformation_nets_optional_layer_dims),
                                                   f'{dataset_name}.pt')
        transformation_net = load_model(transformation_net_filepath,
                                        model_type='transformation_network',
                                        device=device)
        transferred_trainers[dataset_name] = HeadLayerTrainer(model=transferred_classifier,
                                                              output_dim=label_dim,
                                                              num_train_steps=total_steps,
                                                              transformation_nets=transformation_net,
                                                              weight_decay=weight_decay)

    train_logger.info(f'Starting predictor training with {len(source_dataset_names)} sources for target '
                      f'{target_dataset_name}.'
                      f'\nSettings: num_epochs={num_epochs}, num_train_samples={num_train_samples}, '
                      f'num_source_samples={num_source_samples}, seed={seed}, '
                      f'optional_layer_dims={classifier_optional_layer_dims}, weight_decay={weight_decay}')

    # Standard head first, then the transferred heads in insertion order.
    all_trainers = [standard_trainer, *transferred_trainers.values()]

    start_time = time.time()
    with tqdm(range(num_epochs), unit='epoch') as pbar:
        for _ in pbar:
            for trainer in all_trainers:
                trainer.reset_loss()

            for batch in train_dataloader:
                b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

                # Frozen base model: no autograd graph is needed for the shared embeddings.
                with torch.no_grad():
                    standard_embeddings = get_pooled_output(base_model, b_input_ids, b_input_mask)

                for trainer in all_trainers:
                    trainer.train_step(standard_embeddings, b_labels)

            # Mean of the heads' accumulated total_loss, divided by the number of batches.
            avg_train_loss = np.mean([trainer.total_loss for trainer in all_trainers]) / num_batches
            pbar.set_postfix(avg_train_loss=avg_train_loss)

    end_time = time.time()
    train_logger.info(f'Training took {time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))}')

    os.makedirs(classifier_filepath, exist_ok=True)

    torch.save(standard_trainer.model.state_dict(), standard_classifier_filepath)
    for source_dataset_name, transferred_trainer in transferred_trainers.items():
        torch.save(transferred_trainer.model.state_dict(), source_classifier_filepaths[source_dataset_name])


def validate_predictors(source_dataset_names,
                        target_dataset_name,
                        batch_size=32,
                        num_epochs=20,
                        num_train_samples=None,
                        num_source_samples=None,
                        classifier_optional_layer_dims=PREDICTOR_OPTIONAL_LAYER_DIMS,
                        transformation_nets_optional_layer_dims=TRANSFORMATION_NETS_OPTIONAL_LAYER_DIMS,
                        seed=42):
    """Evaluate the standard and transferred predictor heads on the target's validation split.

    Loads the heads stored under ``PREDICTORS_BASE_DIR`` (resolved with the same path settings
    ``train_predictors`` uses) together with each source's transformation network. It runs every
    head on the pooled base-model embeddings of the ``'validation'`` split and pickles each
    evaluator's ``evaluation_results`` to ``<results dir>/<name>.pkl``.

    Args:
        source_dataset_names: A single source dataset name, or an iterable of names.
        target_dataset_name: Dataset whose ``'validation'`` split is evaluated.
        batch_size: Batch size of the validation ``DataLoader``.
        num_epochs: Currently unused; it is not part of the output paths.
        num_train_samples: Part of the classifier and results paths.
        num_source_samples: Part of the classifier and results paths; also selects the
            transformation-network directory.
        classifier_optional_layer_dims: Part of the classifier and results paths.
        transformation_nets_optional_layer_dims: Selects the transformation-network directory.
        seed: Part of the classifier and results paths (no RNG is seeded here).

    Side effects:
        Creates the results directory and writes one pickle per evaluated head.
    """
    if isinstance(source_dataset_names, str):
        source_dataset_names = [source_dataset_names]

    # Results and classifiers are keyed on the same settings, so results line up with the heads.
    path_kwargs = dict(optional_layers=classifier_optional_layer_dims,
                       num_train_samples=num_train_samples,
                       num_source_samples=num_source_samples,
                       seed=seed,
                       target_name=target_dataset_name)
    results_dir = get_output_path(PREDICTOR_EVAL_BASE_DIR, **path_kwargs)
    os.makedirs(results_dir, exist_ok=True)
    classifier_filepath = get_output_path(PREDICTORS_BASE_DIR, **path_kwargs)

    dataset = HFDataset(target_dataset_name, split='validation')

    device = torch.device(GPU_DEVICE) if torch.cuda.is_available() else torch.device("cpu")
    base_model = create_base_model()
    base_model.to(device)
    base_model.eval()

    standard_classifier = load_model(os.path.join(classifier_filepath, f'{target_dataset_name}.pt'),
                                     model_type='head_layer',
                                     device=device)
    standard_evaluator = ClassifierEvaluator(model=standard_classifier, metric=dataset.metric)

    transferred_evaluators = {}
    for dataset_name in source_dataset_names:
        transformation_net_filepath = os.path.join(get_output_path(TRANSFORMATION_NETS_FILEPATH,
                                                                   num_train_samples=num_source_samples,
                                                                   optional_layers=transformation_nets_optional_layer_dims),
                                                   f'{dataset_name}.pt')
        transformation_net = load_model(transformation_net_filepath,
                                        model_type='transformation_network',
                                        device=device)
        transferred_classifier = load_model(os.path.join(classifier_filepath, f'{dataset_name}.pt'),
                                            model_type='head_layer',
                                            device=device)
        transferred_evaluators[dataset_name] = ClassifierEvaluator(model=transferred_classifier,
                                                                   metric=dataset.metric,
                                                                   transformation_net=transformation_net)

    validation_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size,
                                       collate_fn=dataset.collate_fn)

    validate_logger.info(f'Running validation for target {target_dataset_name} '
                         f'with {len(source_dataset_names)} sources.')

    # Evaluation never needs gradients: run the base model AND the evaluators (heads and
    # transformation nets) under no_grad so that no autograd graph is built for any of them.
    with torch.no_grad():
        for batch in validation_dataloader:
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)
            standard_embeddings = get_pooled_output(base_model, b_input_ids, b_input_mask)

            standard_evaluator.evaluate_step(standard_embeddings, b_labels)
            for transferred_evaluator in transferred_evaluators.values():
                transferred_evaluator.evaluate_step(standard_embeddings, b_labels)

    with open(os.path.join(results_dir, f'{target_dataset_name}.pkl'), 'wb') as f:
        pickle.dump(standard_evaluator.evaluation_results, f)

    for source_dataset_name, transferred_evaluator in transferred_evaluators.items():
        with open(os.path.join(results_dir, f'{source_dataset_name}.pkl'), 'wb') as f:
            pickle.dump(transferred_evaluator.evaluation_results, f)
