# Importing the libraries needed
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle

import transformers
from torch.utils.data import Dataset, DataLoader
import torch
import datasets

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, classification_report, roc_auc_score, accuracy_score

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import EarlyStoppingCallback, IntervalStrategy
from custom_models import MultiLabelDataset, MultiLabelBERTClass, SimpleLossCompute, MultiLabelDataset_v2

from utils import compute_metrics
import optuna
import argparse
from collections import Counter

import seaborn as sn
import matplotlib.pyplot as plt

import warnings
# Suppress every UserWarning for the rest of this run to keep the training log short.
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): sklearn's UndefinedMetricWarning appears to subclass UserWarning,
# so the filter above likely already covers it — confirm before re-enabling this.
# warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

class CustomTrainer(Trainer):
    """Trainer that optimises a multi-label objective (BCE with logits).

    Each logit is treated as an independent binary decision, which matches
    the sigmoid + threshold decoding used when scoring predictions.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the BCE-with-logits loss, plus the model outputs if requested.

        Args:
            model: the model being trained/evaluated.
            inputs: batch dict from the data collator; must contain
                ``"labels"`` with the same shape as the model's logits.
            return_outputs: if True, return ``(loss, outputs)`` instead of
                just ``loss``.
            **kwargs: extra keywords newer ``Trainer`` versions pass
                (e.g. ``num_items_in_batch``); accepted and ignored so the
                override keeps working across transformers releases.
        """
        # Pop the labels so the model does not also compute (and then
        # discard) its own built-in loss on every forward pass.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # BCE needs floating-point targets; computing it in fp32 keeps it
        # stable when training with fp16.
        loss_fct = torch.nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.float(), labels.float())
        return (loss, outputs) if return_outputs else loss

# Command-line options. All but --lower_level are required: every one of them
# is used unconditionally later (context slicing, label vector size, column
# lookup, output naming), so a missing value would only crash much later.
parser = argparse.ArgumentParser()
parser.add_argument('--num_turns', type=int, required=True,
                    help='number of preceding utterances included as context')
parser.add_argument('--num_labels', type=int, required=True,
                    help='length of the multi-hot label vector')
parser.add_argument('--indicator_label', type=str, required=True,
                    help='TSV column holding the per-utterance label codes')
parser.add_argument('--text_range', type=str, required=True,
                    help="'counselor' or 'helpseeker' keeps only that speaker's utterances; any other value keeps all")
parser.add_argument('--lower_level', action='store_true',
                    help='skip utterances whose first label code is 0')
args = parser.parse_args()

MAX_LEN = 512
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device: %s'%device)

checkpoint = "distilbert-base-uncased"
# checkpoint = "roberta-base"
# checkpoint = "roberta-large"
# checkpoint = "ydshieh/tiny-random-gptj-for-sequence-classification"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Separator inserted between context utterances. For DistilBERT use the
# tokenizer's real separator token ('[SEP]'); the former literal ' <SEP> '
# is not a special token and was tokenized as ordinary text. RoBERTa-style
# models join segments with a doubled '</s></s>'.
sep_token = f' {tokenizer.sep_token} ' if 'distilbert' in checkpoint else ' </s></s> '

config = AutoConfig.from_pretrained(checkpoint)
NUM_TURNS = args.num_turns
NUM_LABELS = args.num_labels
config.num_labels = NUM_LABELS
# Labels are multi-hot vectors: declare it so the model head's built-in loss
# (if it is ever used) is BCE rather than being inferred from the label dtype.
config.problem_type = "multi_label_classification"


def multi_label_metrics(predictions, labels, threshold=0.5):
    """Score multi-label logits against multi-hot targets.

    The logits are passed through a sigmoid and binarized: a probability
    greater than or equal to ``threshold`` counts as a positive.

    Args:
        predictions: raw logits, assumed laid out as (n_examples, n_labels).
        labels: multi-hot targets with the same layout as ``predictions``.
        threshold: probability cut-off for a positive prediction.

    Returns:
        ``(metrics, labels, y_pred)``: ``metrics`` has keys ``'f1'``
        (macro-averaged F1) and ``'accuracy'`` (sklearn ``accuracy_score``),
        ``labels`` is returned unchanged, and ``y_pred`` is a float64 array of
        0.0/1.0 values shaped like ``predictions``.
    """
    logits = torch.Tensor(predictions)
    probabilities = torch.sigmoid(logits)

    # Boolean mask -> float64 0.0/1.0 matrix.
    binarized = (probabilities >= threshold).numpy().astype(np.float64)

    scores = {}
    scores['f1'] = f1_score(y_true=labels, y_pred=binarized, average='macro')
    scores['accuracy'] = accuracy_score(labels, binarized)
    return scores, labels, binarized

def compute_metrics_multi(p: transformers.EvalPrediction):
    """``compute_metrics`` hook for ``Trainer``: return the metric dict of ``multi_label_metrics``.

    If ``p.predictions`` is a tuple (several model outputs), its first element
    is assumed to be the logits.
    """
    logits = p.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    metric_dict, _unused_labels, _unused_preds = multi_label_metrics(predictions=logits, labels=p.label_ids)
    return metric_dict

def truncate_data(df):
    """Expand dialogue rows into one (context text, multi-hot label) example per utterance.

    Each row holds a dialogue in ``row['text']`` (one utterance per ``'\\n'``
    line) and, in ``row[args.indicator_label]``, one label field per
    utterance separated by ``'___'``; a field is a comma-separated list of
    integer codes. Utterances are kept or skipped according to
    ``args.text_range`` (speaker filter) and ``args.lower_level`` (drop
    utterances whose first code is 0).

    For each kept utterance the text is the utterance plus up to
    ``NUM_TURNS`` preceding utterances, joined with ``sep_token``, with the
    speaker markers rewritten to ``'HelpSeeker: '`` / ``'Counselor: '``.

    Args:
        df: DataFrame with a ``'text'`` column and the column named by
            ``args.indicator_label``.

    Returns:
        DataFrame with columns ``'text'`` (str) and ``'label'`` (list of
        ``NUM_LABELS`` ints).
    """
    text_list, label_list = [], []
    for _, row in df.iterrows():
        utterances = row['text'].split('\n')
        indicators = row[args.indicator_label].split('___')

        # enumerate yields the true position; the former list.index(_sent)
        # returned the FIRST match, so repeated utterances (e.g. "Okay")
        # were given the context window of an earlier turn.
        for curr_sent_idx, (_sent, _indicator) in enumerate(zip(utterances, indicators)):
            if args.text_range == 'counselor' and not _sent.startswith('___Counselor___'):
                continue
            if args.text_range == 'helpseeker' and not _sent.startswith('___HelpSeeker___'):
                continue

            codes = _indicator.split(',')
            is_no_label = int(codes[0]) == 0
            if is_no_label and args.lower_level:
                continue

            # Clamp at 0: a negative slice start wraps around to the end of
            # the list and produced an empty/wrong window for early turns.
            window_start = max(0, curr_sent_idx - NUM_TURNS)
            _contextualized_sent = sep_token.join(utterances[window_start:curr_sent_idx + 1])
            text_list.append(_contextualized_sent.replace('___HelpSeeker___', 'HelpSeeker: ').replace('___Counselor___', 'Counselor: '))

            _label = [0] * NUM_LABELS
            if is_no_label:
                _label[0] = 1
            else:
                # Code k sets index k-1.
                # NOTE(review): when lower_level is off, code 0 (above) and
                # code 1 both set index 0 — confirm the intended mapping.
                for elem in codes:
                    _label[int(elem) - 1] = 1
            label_list.append(_label)

    return pd.DataFrame(data={'text': text_list, 'label': label_list})

def retrieve_data():
    """Load ``data/strategy_indicator.tsv`` (relative to the working directory) and expand it with ``truncate_data``."""
    tsv_path = os.getcwd() + '/data/strategy_indicator.tsv'
    raw_dialogues = pd.read_csv(tsv_path, sep='\t')
    return truncate_data(raw_dialogues)

def data_split_asDataset(df, train_percent, eval_percent, seed):
    """Randomly split ``df`` into train / eval / test DataFrames (despite the name, not ``Dataset`` objects).

    A fraction ``train_percent`` of the rows goes to train. Of the remaining
    rows, a fraction ``eval_percent`` goes to eval and the rest to test. Both
    splits use ``seed`` as ``random_state``. The three sizes are printed.

    Returns:
        ``(df_train, df_eval, df_test)``, each re-indexed from 0.
    """
    held_out_fraction = 1 - train_percent
    df_train, df_rest = train_test_split(df, test_size=held_out_fraction, random_state=seed)
    df_eval, df_test = train_test_split(df_rest, test_size=1 - eval_percent, random_state=seed)

    print('Train: %d, Eval: %d, Test: %d' % (len(df_train), len(df_eval), len(df_test)))
    return tuple(part.reset_index(drop=True) for part in (df_train, df_eval, df_test))

# ---------------------------------------------------------------------------
# Model training
#
# Runs 5 random train/eval/test splits (split seed = fold index 0-4), trains a
# fresh model on each, and pools every fold's test-set targets and binarized
# predictions into `tgts` / `preds` for the overall report below.
# ---------------------------------------------------------------------------
preds, tgts = [], []

# Strip any "org/" prefix so the identifier contains no '/'. This value was
# previously computed but unused (model_name was built from `checkpoint`).
checkpoint_name = checkpoint.split('/')[1] if '/' in checkpoint else checkpoint
model_name = 'multilabel_trainer_lower_'+checkpoint_name+'_turns'+str(NUM_TURNS)+'_'+args.indicator_label+'_'+args.text_range

# The expanded dataset is the same for every fold, so build it only once.
df = retrieve_data()
print('current number of instances with the given label set: %d'%len(df))
train_percent, eval_percent = 0.8, 0.5

# Roberta hyperparameters (shared by all folds)
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
LEARNING_RATE = 3.44e-5
WEIGHT_DECAY = 3.61e-6
WARMUP_STEPS = 30
EPOCHS = 10

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

for i in range(5):
    train_df, eval_df, test_df = data_split_asDataset(df, train_percent, eval_percent, i)

    training_set = MultiLabelDataset_v2(train_df, tokenizer, MAX_LEN)
    evaluation_set = MultiLabelDataset_v2(eval_df, tokenizer, MAX_LEN)
    testing_set = MultiLabelDataset_v2(test_df, tokenizer, MAX_LEN)

    # Evaluate / log / save twice per epoch. Clamp to >= 1: a small training
    # split would otherwise give 0 steps, which TrainingArguments rejects.
    steps_per_epoch = len(train_df) // TRAIN_BATCH_SIZE
    EVAL_STEPS = max(1, steps_per_epoch // 2)

    training_args = TrainingArguments(
        output_dir='Anonymized'+str(i),
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=VALID_BATCH_SIZE,
        weight_decay=WEIGHT_DECAY,
        evaluation_strategy=IntervalStrategy.STEPS,
        do_train=True,
        do_eval=True,
        learning_rate=LEARNING_RATE,
        logging_steps=EVAL_STEPS,
        save_steps=EVAL_STEPS,
        eval_steps=EVAL_STEPS,
        warmup_steps=WARMUP_STEPS,
        num_train_epochs=EPOCHS,
        overwrite_output_dir=True,
        metric_for_best_model='f1',
        load_best_model_at_end = True,
        save_total_limit=1,
        # fp16 mixed precision requires CUDA; enabling it unconditionally made
        # CPU-only runs fail before training started.
        fp16=torch.cuda.is_available(),
    )

    # Seed before building the model so the randomly initialised
    # classification head is reproducible per fold; Trainer seeds only when it
    # is constructed, i.e. after the head already exists.
    transformers.set_seed(i)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, config=config).to(device)

    trainer = CustomTrainer(
        model,
        training_args,
        train_dataset=training_set,
        eval_dataset=evaluation_set,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_multi,
        callbacks = [EarlyStoppingCallback(early_stopping_patience=6)]
    )
    trainer.train()

    # Evaluation on this fold's held-out test split (best checkpoint is
    # reloaded at the end of training because load_best_model_at_end=True).
    results = trainer.predict(testing_set)
    result_preds = results.predictions[0] if isinstance(results.predictions, tuple) else results.predictions
    result_metrics, y_true, y_pred = multi_label_metrics(predictions=result_preds, labels=results.label_ids)

    tgts += list(y_true)
    preds += list(y_pred)
    print()
    print()
    print('[Fold %d] Test f1: %.4f'%(i, result_metrics['f1']))
    print('[Fold %d] Test acc: %.4f'%(i, result_metrics['accuracy']))

# ---------------------------------------------------------------------------
# Pooled report over all folds
# ---------------------------------------------------------------------------
# Cast every pooled target/prediction row to ints so sklearn receives a clean
# binary multi-label indicator matrix (the predictions are float 0.0/1.0).
tgts_int = [[int(value) for value in row] for row in tgts]
preds_int = [[int(value) for value in row] for row in preds]

overall_f1 = f1_score(y_true=tgts_int, y_pred=preds_int, average='macro')
print()
print()
print('[Overall] Test f1: %.4f'%(overall_f1))

print('All tgt',tgts_int)
print('All pred',preds_int)

# Human-readable class names per speaker. They are used only when their count
# equals the number of labels: sklearn raises on a target_names length
# mismatch, and previously a text_range other than these two (with 3-6
# labels) printed no report at all.
report_target_names = {
    'counselor': ['No Strategy', 'Framing', 'Feelings', 'Exploration', 'Advice', 'Facilitating'],
    'helpseeker': ['No Label', 'Seeking', 'Abuse', 'Identity', 'Response'],
}.get(args.text_range)
if report_target_names is not None and len(report_target_names) == args.num_labels:
    print(classification_report(tgts_int, preds_int, target_names=report_target_names))
else:
    print(classification_report(tgts_int, preds_int))

