import pandas as pd
import os
import sys
from datetime import datetime
import argparse
from argparse import Namespace
import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from accelerate import Accelerator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
import ray
from ray import air, tune, serve, train
from ray.train import Checkpoint
from ray.air import session
from ray.tune import CLIReporter, ExperimentAnalysis
from ray.tune.schedulers import PopulationBasedTraining
from functools import partial

def main(args):
    """
    Run the whole pipeline for the given CLI arguments.

    Steps, in order: build the config, load and tokenize the data, run the
    PBT hyperparameter search, then export the best trial's weights.
    """
    config = initialize_config(args)
    tokenized = Tokenize_datasets(config, load_dataset(config))
    tune_and_train(config, tokenized)
    save_best_model(config)

def parse_args(args):
    """Parse command-line arguments for the fine-tuning script.

    Args:
        args: list of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with model_name, full_train, in_filepath (always a
        list of paths), run_name, pretrained_modelpath, eval_metric and
        n_gpus (an int).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        dest='model_name',
                        default='roberta-large-mnli',
                        help='Name of pretrained transformers model, see: ' +
                        'https://huggingface.co/transformers/pretrained_models.html')
    parser.add_argument('-ft',
                        dest='full_train',
                        default=False,
                        action='store_true',
                        help='Train on full training set')
    # nargs='+' produces a list when the flag is given, so the default is a
    # list too; load_dataset relies on len() counting paths, not characters.
    parser.add_argument('--fp',
                        dest='in_filepath',
                        default=['../resonance_work/data/processed_data/retraining_data/'],
                        nargs='+',
                        help='Name of CSV file containing tagged posts.')
    parser.add_argument('--rn',
                        dest='run_name',
                        default='pbt_test',
                        nargs='?',
                        help='name of trial run save directory')
    parser.add_argument('--ptmp',
                        dest='pretrained_modelpath',
                        default='',
                        nargs='?',
                        help='path to model if finetuning already tuned model')
    # Restrict to the metrics initialize_config knows how to map.
    parser.add_argument('--em',
                        dest='eval_metric',
                        default='accuracy',
                        choices=['accuracy', 'F1'],
                        nargs='?',
                        help='metric on which to optimize')
    # type=int: the value is passed straight to Ray as a GPU resource count.
    parser.add_argument('--n_gpus',
                        dest='n_gpus',
                        default=4,
                        type=int,
                        nargs='?',
                        help='How many GPUs to use for parallel training.')
    config = parser.parse_args(args)
    return config

def initialize_config(args):
    """Parse CLI args and attach run-specific derived settings.

    Adds to the parsed Namespace:
        save_filepath: directory name for the exported model.
        checkpoint_dir: timestamped checkpoint directory path.
        tuning_metric: Ray-side metric name ('mean_accuracy' or 'weighted_F1').

    Raises:
        ValueError: if ``eval_metric`` is not 'accuracy' or 'F1'.
    """
    config = parse_args(args)

    # One timestamp for both paths; two datetime.now() calls can straddle a
    # second boundary and produce mismatched names.
    timestamp = datetime.now().strftime("%H-%M-%S_%d-%m-%Y")
    config.save_filepath = f'value_resonance_finetuned_{config.model_name}_{timestamp}'
    config.checkpoint_dir = f'./modeling_checkpoints/{timestamp}_run/'

    # CLI metric name -> metric key reported by train_rvr_pbt.
    metric_names_dict = {'accuracy': 'mean_accuracy',
                         'F1': 'weighted_F1'}
    try:
        config.tuning_metric = metric_names_dict[config.eval_metric]
    except KeyError:
        raise ValueError(
            f"Unsupported eval metric {config.eval_metric!r}; "
            f"expected one of {sorted(metric_names_dict)}") from None

    return config

def _newest_dated_csv(directory):
    """Return the path of the ``*_<dd-mm-YYYY>.csv`` file in `directory` with the latest date."""
    file_dates = {}
    for filename in os.listdir(directory):
        if filename.endswith('.csv'):
            # Filenames are expected to end in "_<dd-mm-YYYY>.csv"; strip ".csv" and parse the date.
            file_dates[datetime.strptime(filename.split('_')[-1][:-4], '%d-%m-%Y')] = filename
    if not file_dates:
        raise FileNotFoundError(f"No .csv files found in {directory!r}")
    return os.path.join(directory, file_dates[max(file_dates)])


def load_dataset(config):
    """Load tagged CSV data into a `datasets.DatasetDict`.

    ``config.in_filepath`` may be:
      * one ``.csv`` path: cleaned, then split 80/20 (or kept whole with ``-ft``);
      * one directory: the CSV with the newest ``_<dd-mm-YYYY>`` suffix is used;
      * two or more paths: the first is the train CSV, the second the test CSV
        (any further paths are ignored).

    Side effects: sets ``config.in_filename``; for a single path, also
    collapses ``config.in_filepath`` to that string.
    """
    paths = config.in_filepath
    # A bare string (e.g. a non-list default or programmatic use) is one path,
    # not a sequence of single-character paths.
    if isinstance(paths, str):
        paths = [paths]

    if len(paths) == 1:
        config.in_filepath = paths[0]
        if config.in_filepath.endswith('.csv'):
            config.in_filename = config.in_filepath
        else:
            config.in_filename = _newest_dated_csv(config.in_filepath)
        model_data = clean_WVC_data(config)
        return split_and_convert_to_TF_dfs(model_data, full_train=config.full_train)

    # Set in_filename before each call so this works whether clean_WVC_data
    # reads its filename argument or config.in_filename.
    train_path, test_path = paths[0], paths[1]
    config.in_filename = train_path
    training_data = clean_WVC_data(config, train_path)
    config.in_filename = test_path
    testing_data = clean_WVC_data(config, test_path)
    return convert_to_TF_dfs(training_data, testing_data)

class CustomStopper(tune.Stopper):
    """Tune stopper: end the whole experiment once any trial beats a threshold.

    Each individual trial is also stopped after `max_iter` training iterations.

    Args:
        metric: result key to watch ('mean_accuracy' or 'weighted_F1' unless
            `threshold` is given explicitly).
        threshold: score that, once exceeded by any trial, stops all trials.
            Defaults to 0.96 for 'mean_accuracy' and 0.9 for 'weighted_F1'.
        max_iter: per-trial cap on ``training_iteration`` (default 100).

    Raises:
        ValueError: if `metric` has no default threshold and none is given.
    """

    _DEFAULT_THRESHOLDS = {'mean_accuracy': 0.96, 'weighted_F1': 0.9}

    def __init__(self, metric='mean_accuracy', threshold=None, max_iter=100):
        self.should_stop = False
        self.eval_metric = metric
        if threshold is None:
            if metric not in self._DEFAULT_THRESHOLDS:
                raise ValueError(
                    f"No default threshold for metric {metric!r}; pass threshold=...")
            threshold = self._DEFAULT_THRESHOLDS[metric]
        self.threshold = threshold
        self.max_iter = max_iter

    def __call__(self, trial_id, result):
        # Latch: once any trial crosses the threshold, every later call stops.
        if not self.should_stop and result[self.eval_metric] > self.threshold:
            self.should_stop = True
        return self.should_stop or result["training_iteration"] >= self.max_iter

    def stop_all(self):
        return self.should_stop

class RVRDataset(torch.utils.data.Dataset):
    """Map-style torch Dataset over one tokenized split.

    `dataset` must be indexable by the column names 'premise', 'hypothesis',
    'label', 'idx', 'input_ids' and 'attention_mask'. Items are dicts of
    tensors keyed 'input_ids', 'attention_mask' and 'labels'.
    """

    def __init__(self, dataset):
        # The text and idx columns are kept as attributes only; __getitem__ ignores them.
        self.premise, self.hypothesis = dataset['premise'], dataset['hypothesis']
        self.label, self.idx = dataset['label'], dataset['idx']
        self.input_ids, self.attention_mask = dataset['input_ids'], dataset['attention_mask']

    def __getitem__(self, idx):
        columns = (('input_ids', self.input_ids),
                   ('attention_mask', self.attention_mask),
                   ('labels', self.label))
        return {name: torch.tensor(values[idx]) for name, values in columns}

    def __len__(self):
        return len(self.label)

#Dataset loading
def clean_WVC_data(config, given_filename=''):
    """
    Load a WVC CSV and return a DataFrame with columns idx, premise, hypothesis, label.

    Args:
        config: Namespace; ``config.in_filename`` is read only when
            `given_filename` is empty.
        given_filename: CSV path to read; takes precedence over
            ``config.in_filename``.

    Textual labels are encoded as:
        0=Contradiction ('conflicts', 'contradiction')
        1=Neutral       (any other value, including 'neutral')
        2=Entailment    ('resonates', 'entailment')
    which matches the id2label order of roberta-large-mnli. If no known
    textual label is present, the column is only cast to int.
    premise and hypothesis are cast to str.

    Raises:
        ValueError: if any of the required columns is missing.
    """
    filename = given_filename or config.in_filename
    required = ['idx', 'premise', 'hypothesis', 'label']

    WVC_contribs = pd.read_csv(filename)
    missing = [col for col in required if col not in WVC_contribs.columns]
    if missing:
        raise ValueError(f"{filename} is missing required column(s): {missing}")

    merged_narratives = WVC_contribs.loc[:, required].copy()
    for col in ['premise', 'hypothesis']:
        merged_narratives[col] = merged_narratives[col].astype(str)

    label_codes = {'resonates': 2, 'entailment': 2,
                   'conflicts': 0, 'contradiction': 0,
                   'neutral': 1}
    # isin() is safe for numeric label columns (no str-vs-int array comparison).
    if merged_narratives['label'].isin(list(label_codes)).any():
        merged_narratives['label'] = merged_narratives['label'].map(
            lambda x: label_codes.get(x, 1))
    merged_narratives['label'] = merged_narratives['label'].astype(int)
    return merged_narratives

def split_and_convert_to_TF_dfs(model_data, full_train=False, test_size=0.2, random_state=42):
    """Turn a cleaned DataFrame into a `datasets.DatasetDict`.

    Args:
        model_data: DataFrame with premise, hypothesis, label and idx columns.
        full_train: if true, return only a "train" split holding every row.
        test_size: fraction of rows held out as "test" (default 0.2).
        random_state: seed for the shuffle in `train_test_split` (default 42).

    Returns:
        DatasetDict with "train" (and, unless `full_train`, "test") splits,
        each with columns premise, hypothesis, label, idx.
    """
    columns = ["premise", "hypothesis", "label", "idx"]

    def _as_dataset(frame):
        # Rebuild from plain lists so the source frame's index is not carried over.
        return datasets.Dataset.from_dict(
            pd.DataFrame({col: list(frame[col].values) for col in columns}))

    if full_train:
        return datasets.DatasetDict({"train": _as_dataset(model_data)})

    # Splitting the frame directly selects the same rows as splitting its
    # .values arrays with the same test_size/random_state.
    train_part, test_part = train_test_split(
        model_data, test_size=test_size, random_state=random_state)
    return datasets.DatasetDict({"train": _as_dataset(train_part),
                                 "test": _as_dataset(test_part)})

def convert_to_TF_dfs(train_df, test_df):
    """Wrap already-split train/test DataFrames as a `datasets.DatasetDict`.

    Only the premise, hypothesis, label and idx columns are kept, in that order.
    """
    wanted = ["premise", "hypothesis", "label", "idx"]
    splits = {name: datasets.Dataset.from_dict(frame[wanted])
              for name, frame in (("train", train_df), ("test", test_df))}
    return datasets.DatasetDict(splits)

def Tokenize_datasets(config, dataset, sentence1_key="premise", sentence2_key="hypothesis"):
    """Tokenize every split of `dataset` with the tokenizer for `config.model_name`.

    Args:
        config: Namespace providing ``model_name``.
        dataset: DatasetDict (or Dataset) whose `map` is applied in batched mode.
        sentence1_key: column holding the first text.
        sentence2_key: column holding the paired text, or None for single-text input.

    Returns:
        The mapped dataset with the tokenizer's output columns added
        (e.g. input_ids, attention_mask).

    Every example is padded to the tokenizer's max length and truncated.
    Fixed lengths let the default DataLoader collate stack RVRDataset items.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    def preprocess_function(examples):
        texts = (examples[sentence1_key],)
        if sentence2_key is not None:
            texts += (examples[sentence2_key],)
        # Pad in both the pair and single-text case; unpadded rows cannot be batched.
        return tokenizer(*texts, padding="max_length", truncation=True)

    return dataset.map(preprocess_function, batched=True)

def tune_and_train(config, tokenized_datasets):
    """Run a Population Based Training search over `train_rvr_pbt` with Ray Tune.

    Args:
        config: Namespace from `initialize_config`; reads ``tuning_metric``,
            ``run_name`` and ``n_gpus``. It is not modified.
        tokenized_datasets: output of `Tokenize_datasets`, passed to every trial.

    Returns:
        The Ray Tune result grid from ``tuner.fit()``.
    """
    def search_space(lr_low, lr_high):
        # Fresh sampler objects per call; only the lr bounds differ between uses.
        return {
            "alpha": tune.uniform(0.0, 1.0),
            "momentum": tune.uniform(0.001, 1),
            "num_train_epochs": tune.choice([2, 3, 4, 5, 6, 7, 8, 9, 10]),
            "seed": tune.choice(list(range(1000))),
            "lr": tune.loguniform(lr_low, lr_high),
            "batch_size": tune.choice([2, 4, 8]),
        }

    # NOTE(review): initial samples draw lr from [1e-6, 1] while PBT mutations
    # resample from [1e-4, 1e-1]; kept as-is. Confirm this is intended.
    # train_rvr_pbt never reads alpha, num_train_epochs or seed.
    scheduler = PopulationBasedTraining(
        time_attr='training_iteration',
        perturbation_interval=5,
        hyperparam_mutations=search_space(1e-4, 1e-1))
    stopper = CustomStopper(metric=config.tuning_metric)

    # NOTE(review): "loss" is never reported by train_rvr_pbt, so that column stays empty.
    reporter = CLIReporter(
        metric_columns=["loss", config.tuning_metric, "training_iteration"])

    trainable = tune.with_resources(
        partial(train_rvr_pbt, input_data=tokenized_datasets, eval_metric=config.tuning_metric),
        # n_gpus can arrive from argparse as a string; Ray expects a number.
        {'cpu': 4, 'gpu': int(config.n_gpus)})

    tuner = tune.Tuner(
        trainable,
        run_config=air.RunConfig(
            name=config.run_name,
            stop=stopper,
            verbose=1,
            progress_reporter=reporter,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_score_attribute=config.tuning_metric,
                num_to_keep=4,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            metric=config.tuning_metric,
            mode="max",
            num_samples=4,
        ),
        param_space=search_space(1e-6, 1),
    )
    results = tuner.fit()
    print("Best hyperparameters found were: ", results.get_best_result().config)
    return results
    
def train_rvr_pbt(config, input_data, eval_metric):
    """Ray Tune function trainable: fine-tune a 3-label classifier until Tune stops it.

    Each iteration does the following:
      * train one epoch on 80% of ``input_data["train"]``;
      * score the remaining 20% with `test`;
      * report ``{eval_metric: score}``.
    Every 5th step the weights are also written to ``my_model/checkpoint.pt``
    (relative to the trial's working directory) and attached to the report.
    `save_best_model` reads that file.

    Args:
        config: trial config; reads "lr" and "momentum" (with defaults) and "batch_size".
        input_data: tokenized DatasetDict; only its "train" split is used.
        eval_metric: metric name passed to `test` and used as the report key.
    """
    accelerator = Accelerator()
    step = 0

    # NOTE(review): the checkpoint name is hard-coded and ignores --model.
    # save_best_model rebuilds this same architecture, so change both together.
    model = AutoModelForSequenceClassification.from_pretrained('roberta-large-mnli', num_labels=3)
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.01),
        momentum=config.get("momentum", 0.9),
    )

    # Validation comes out of the train split with a fixed seed, so every trial
    # gets the same split. A "test" split is therefore not needed; -ft mode has none.
    trainset = RVRDataset(input_data["train"])
    n_train = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [n_train, len(trainset) - n_train],
        generator=torch.Generator().manual_seed(42))

    batch_size = int(config["batch_size"])
    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=8)
    # The validation set needs no shuffling; a fixed order keeps scores reproducible.
    valloader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=8)

    trainloader, valloader, model, optimizer = accelerator.prepare(
        trainloader, valloader, model, optimizer)

    # A checkpoint is present when Tune/PBT restarts this trial from saved state.
    # NOTE(review): only model weights are checkpointed; SGD momentum buffers
    # restart from zero after a restore.
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        print("Loading from checkpoint.")
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            path = os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            # map_location: the file may have been written from another device.
            checkpoint = torch.load(path, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])
            # The saved step had already completed; continue with the next one.
            step = checkpoint["step"] + 1

    while True:
        train(model, optimizer, trainloader, accelerator)
        score = test(model, valloader, eval_metric=eval_metric)
        checkpoint = None
        if step % 5 == 0:
            # NOTE(review): PBT perturbs every 5 training_iterations, and iteration == step + 1.
            # These saves therefore land one iteration after each perturbation point,
            # so exploited trials likely clone weights a few iterations old.
            os.makedirs("my_model", exist_ok=True)
            torch.save(
                {
                    "step": step,
                    "model_state_dict": model.state_dict(),
                },
                "my_model/checkpoint.pt",
            )
            checkpoint = Checkpoint.from_directory("my_model")

        step += 1
        session.report({eval_metric: score}, checkpoint=checkpoint)
        
def train(model, optimizer, trainloader, accelerator):
    """Run one training epoch over `trainloader`.

    Args:
        model: module called as ``model(**batch)``; its output must expose ``.loss``
            (the batch must therefore include whatever the model needs to compute it).
        optimizer: optimizer over `model`'s parameters.
        trainloader: iterable of dicts of tensors.
        accelerator: object whose ``backward(loss)`` runs the backward pass
            (an `accelerate.Accelerator` in this script).
    """
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            # Local rebind only: the caller keeps the unwrapped model, so its
            # state_dict keys carry no "module." prefix when checkpointed.
            model = nn.DataParallel(model)
    model.to(device)
    model.train()

    for batch in trainloader:
        data = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        loss = model(**data).loss
        # Under DataParallel, `loss` holds one batch-mean per replica. Averaging keeps the
        # gradient scale independent of GPU count; a scalar loss is unchanged.
        accelerator.backward(loss.mean())
        optimizer.step()


def test(model, data_loader, device=None,eval_metric='mean_accuracy'):
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
            
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            data={k: v.to(device) for k, v in batch.items()}
            outputs = model(**data)
            _, predicted = torch.max(outputs.logits, 1)
            target=data['labels']
            total += target.size(0)
            correct += (predicted == target).sum().item()
            if eval_metric=='weighted_F1':
                score=f1_score(target.tolist(), predicted.tolist(), average='weighted')
            elif eval_metric == 'mean_accuracy':
                score=correct / total
    return score

def save_best_model(config):
    """Export the best trial's weights as a Hugging Face model directory.

    The trial is looked up under ``~/ray_results/<run_name>/``. The best trial is
    the one with the highest ``config.tuning_metric``. Its
    ``my_model/checkpoint.pt`` (written by train_rvr_pbt) is loaded into a fresh
    roberta-large-mnli classifier and saved to ``config.save_filepath``.

    Raises:
        RuntimeError: if no trial in the experiment reported the metric.
    """
    experiment_dir = os.path.expanduser(f"~/ray_results/{config.run_name}/")
    analysis = ExperimentAnalysis(experiment_dir)
    model_path = analysis.get_best_logdir(metric=config.tuning_metric, mode="max")
    if model_path is None:
        raise RuntimeError(
            f"No trial reporting {config.tuning_metric!r} found in {experiment_dir}")

    # NOTE(review): must match the architecture hard-coded in train_rvr_pbt.
    model = AutoModelForSequenceClassification.from_pretrained('roberta-large-mnli', num_labels=3)
    # map_location: checkpoints are saved from GPU but export may run without one.
    checkpoint = torch.load(os.path.join(model_path, "my_model", "checkpoint.pt"),
                            map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.save_pretrained(config.save_filepath)

    print(f"Trained model saved to {config.save_filepath}")
    
if __name__ == "__main__":
    # Forward the CLI arguments (minus the program name); main()'s return value becomes the exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)