# Importing the libraries needed
import os
import time
import pandas as pd
import numpy as np
import shutil

import transformers
from torch.utils.data import Dataset, DataLoader
import torch
import datasets

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, classification_report

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import EarlyStoppingCallback, IntervalStrategy

from utils import compute_metrics
import optuna
import argparse
from collections import Counter

from load_data import convert_data, combine_codes
from custom_models import DatasetsForFeaturedBERTclf, FeaturedBERTclf

# ---------------------------------------------------------------------------
# Command-line options. The boolean switches select which input streams are
# handed to convert_data() and whether the label is binarized.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--num_turns', type=int)
parser.add_argument('--output_label', type=str, default='positive')
for _flag in ('binarize_label', 'conv', 'code', 'feature', 'summary', 'objective'):
    parser.add_argument('--' + _flag, action='store_true')
del _flag
args = parser.parse_args()

NUM_TURNS, OUTPUT_LABEL = args.num_turns, args.output_label
IS_BINARY = args.binarize_label
USE_CONV, USE_CODE, USE_FEATURE = args.conv, args.code, args.feature
USE_SUMMARY, USE_OBJECTIVE = args.summary, args.objective

MAX_LEN = 512  # tokenizer truncation length (tokens)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Current device: {device}')

# Alternative backbone: checkpoint = "roberta-base"
checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Separator string passed to convert_data().
# NOTE(review): for distilbert-base-uncased the tokenizer's own separator is
# tokenizer.sep_token ("[SEP]"); the literal "<SEP>" is presumably tokenized
# as ordinary word pieces. Confirm whether that is intended before changing
# it — downstream code may depend on the literal string.
if 'distilbert' in checkpoint:
    SEP_TOKEN = ' <SEP> '
else:
    SEP_TOKEN = ' </s></s> '  # RoBERTa-style pair separator

# num_labels is filled in later by retrieve_data().
config = AutoConfig.from_pretrained(checkpoint)

def retrieve_data():
    """Load every source TSV and build the model-input DataFrame.

    Reads, in order, from ``os.getcwd() + '/data/'``:
    the transcripts, the annotated and the labeled strategy-indicator codes
    (merged with ``combine_codes``), and the ``chatgpt_*`` feature, summary
    and objective-summary files. All of them are passed to ``convert_data``
    together with the command-line switches.

    Side effect: sets ``config.num_labels`` to 2 when ``IS_BINARY`` else 3.

    Research questions behind the input-type switches:

    - conv: baseline
        - v. + code: Does utterance-level code help?
            - v. + LLM_feat: Can utterance-level and session-level features work together?
            - v. + summary: Can summary better work together with utterance-level code?
                - v. + LLM_feat: Can everything combined together give the best result?
        - v. + LLM_feat: Is session-level feature more helpful than utterance-level code?
        - v. + summary: Is summary more helpful than utterance-level code?
            - v. + LLM_feat: Is it beneficial to have both summary and session-level feature?
    - LLM_feat: Can LLMs generate guided features that can better characterize conversation?
    - summary: Can LLMs generate summaries that can better characterize conversation?
        - v. + LLM_feat: Does session-level feature STILL help?

    Returns:
        The DataFrame produced by ``convert_data``.
    """
    def _read_tsv(filename):
        # Same path construction as before: cwd + '/data/' + filename.
        return pd.read_csv(os.getcwd() + '/data/' + filename, sep='\t')

    transcripts = _read_tsv('extended_transcripts_allOutcome.tsv')
    codes_annotated = _read_tsv('strategy_indicator.tsv')
    codes_labeled = _read_tsv('extended_transcripts_allOutcome_strategy_indicator.tsv')
    codes = combine_codes(codes_annotated, codes_labeled)
    features = _read_tsv('chatgpt_convToFeature.tsv')
    summaries = _read_tsv('chatgpt_convToSummary.tsv')
    objective_summaries = _read_tsv('chatgpt_convToObjectiveSummary.tsv')

    converted = convert_data(
        transcripts, codes, features, summaries, objective_summaries,
        IS_BINARY, NUM_TURNS, OUTPUT_LABEL, SEP_TOKEN,
        USE_CONV, USE_CODE, USE_FEATURE, USE_SUMMARY, USE_OBJECTIVE,
    )
    config.num_labels = 2 if IS_BINARY else 3
    return converted

def data_split(df, train_percent, eval_percent, seed):
    """Split *df* into train / eval / test frames by its ``id`` column.

    The unique ids are split, not the rows, so every row sharing an id ends
    up in exactly one split. A ``train_percent`` fraction of the ids goes to
    train. Of the remaining ids, an ``eval_percent`` fraction goes to eval
    and the rest to test. ``seed`` is used as ``random_state`` for both
    ``train_test_split`` calls.

    Prints the split sizes and, when the ``label`` column is countable, the
    per-split label distribution.

    Returns:
        ``(df_train, df_eval, df_test)``, each a row subset of *df*.
    """
    # De-duplicate before splitting. With repeated ids, copies of one id
    # could land in different splits, and ``isin`` would then put all of its
    # rows into several splits (train/test leakage). When ids are already
    # unique this produces exactly the same list as before.
    all_ids = sorted(set(df['id'].tolist()))
    _train, _eval_test = train_test_split(all_ids, test_size=1-train_percent, random_state=seed)
    _eval, _test = train_test_split(_eval_test, test_size=1-eval_percent, random_state=seed)

    df_train = df[df['id'].isin(_train)]
    df_eval = df[df['id'].isin(_eval)]
    df_test = df[df['id'].isin(_test)]

    print('Train: %d, Eval: %d, Test: %d'%(len(df_train), len(df_eval), len(df_test)))
    for _split, _df in zip(['Train', 'Eval', 'Test'], [df_train, df_eval, df_test]):
        try:
            _cnt = Counter(_df['label'].tolist())
        except (KeyError, TypeError) as err:
            # The label report is informational only. A missing column
            # (KeyError) or unhashable labels such as vectors (TypeError)
            # must not abort the split, but must not be silently hidden either.
            print('%s label distribution unavailable: %r' % (_split, err))
            continue
        print(_split, _cnt)
    return df_train, df_eval, df_test

def tokenize_function(data):
    """Tokenize ``data["text"]`` for ``Dataset.map(..., batched=True)``.

    Uses the module-level ``tokenizer``. Adds special tokens, truncates to
    ``MAX_LEN``, pads within the batch, and also returns token type ids.
    """
    texts = data["text"]
    return tokenizer(
        texts,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        truncation=True,
        padding=True,
        return_token_type_ids=True,
    )

def _df_to_text_label_dict(frame):
    """Return ``{'text': [str, ...], 'label': [int, ...]}`` read column-wise from *frame*."""
    # Read the columns directly instead of using iterrows(). iterrows() casts
    # each row to a single common dtype, so e.g. a numeric 'text' value could
    # be upcast to float and stringified as '5.0'.
    return {
        'text': [str(text) for text in frame['text'].tolist()],
        'label': [int(label) for label in frame['label'].tolist()],
    }


def convert_df_to_dict(train_df, eval_df, test_df):
    """Convert the three split DataFrames into plain ``{'text', 'label'}`` dicts.

    Texts are coerced to ``str`` and labels to ``int``, preserving row order.

    Returns:
        ``(train_dict, eval_dict, test_dict)``.
    """
    return (
        _df_to_text_label_dict(train_df),
        _df_to_text_label_dict(eval_df),
        _df_to_text_label_dict(test_df),
    )

def convert_to_Dataset(df_train, df_eval, df_test):
    """Build tokenized ``datasets.Dataset`` objects for the three splits.

    Each split gets a ``{'text': string, 'label': ClassLabel}`` schema sized
    by ``config.num_labels``. Each is then tokenized with ``tokenize_function``
    in batched mode.

    Returns:
        ``(train_data, eval_data, test_data)`` after tokenization.
    """
    train_dict, eval_dict, test_dict = convert_df_to_dict(df_train, df_eval, df_test)

    def _make_features():
        # One schema definition instead of three copies. ClassLabel(num_classes=n)
        # auto-generates the string names '0'..'n-1'. ``names`` is documented as
        # a list of str, which the previous integer names only matched incidentally.
        return datasets.Features({
            'text': datasets.Value('string'),
            'label': datasets.ClassLabel(num_classes=config.num_labels),
        })

    train_data = datasets.Dataset.from_dict(train_dict, features=_make_features())
    eval_data = datasets.Dataset.from_dict(eval_dict, features=_make_features())
    test_data = datasets.Dataset.from_dict(test_dict, features=_make_features())

    return (
        train_data.map(tokenize_function, batched=True),
        eval_data.map(tokenize_function, batched=True),
        test_data.map(tokenize_function, batched=True),
    )

def generate_summary(batch):
    """Add argmax class predictions for ``batch["text"]`` as ``batch["pred"]``.

    Despite its name, this does not generate summaries. It tokenizes the
    texts with the module-level ``tokenizer`` and runs the module-level
    ``model`` (assigned in the training loop) on ``device``. It then stores
    the per-example argmax of the logits. No caller is visible in this
    section; presumably it is meant for ``Dataset.map(batched=True)``.

    Returns:
        The same *batch* mapping with a ``"pred"`` list added.
    """
    # ``torch.no_grad()`` does not disable dropout; the module must be in eval
    # mode for deterministic predictions. Restore the caller's mode afterwards
    # so calling this mid-training does not silently freeze dropout.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            inputs = tokenizer(batch["text"], None, add_special_tokens=True,
                    max_length=MAX_LEN, padding=True, return_token_type_ids=True, truncation=True, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            output = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = output.logits.cpu()
            preds = logits.argmax(dim=1).tolist()
    finally:
        model.train(was_training)

    batch["pred"] = preds
    return batch

"""
Model Training
"""
preds, tgts = [], []

checkpoint_name = checkpoint.split('-')[0]
binary_indicator_text = 'binary' if IS_BINARY else 'multi'
input_stream_text = ''
if USE_CONV:
    input_stream_text += 'Conv'
if USE_CODE:
    input_stream_text += 'Code'
if USE_FEATURE:
    input_stream_text += 'Feature'
if USE_SUMMARY:
    input_stream_text += 'Summary'
if USE_OBJECTIVE:
    input_stream_text += 'Objective'

model_name = f"PLMClf_{input_stream_text}_{checkpoint_name}_{OUTPUT_LABEL}_{binary_indicator_text}"
if USE_CONV and (USE_FEATURE or USE_SUMMARY):
    IS_CUSTOM_MODEL = True
else:
    IS_CUSTOM_MODEL = False

# Hyper-parameters shared by every fold (hoisted: they never vary per fold).
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 3.44e-5
WEIGHT_DECAY = 3.61e-6
WARMUP_STEPS = 30

# The input data is the same for every fold, so read and convert the TSVs once
# instead of on every iteration. retrieve_data() also sets config.num_labels,
# which the dataset/model construction inside the loop relies on.
df = retrieve_data()
train_percent, eval_percent = 0.6, 0.5
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# fp16 mixed precision targets CUDA (the original comment said "for CUDA").
# Hard-coding True breaks runs on the CPU fallback selected by `device`.
use_fp16 = device == 'cuda'

for i in range(10):
    # Fold i: a different random id split, seeded with i.
    df_train, df_eval, df_test = data_split(df, train_percent, eval_percent, i)
    if IS_CUSTOM_MODEL:
        training_set = DatasetsForFeaturedBERTclf(df_train.reset_index(drop=True), tokenizer, MAX_LEN)
        evaluation_set = DatasetsForFeaturedBERTclf(df_eval.reset_index(drop=True), tokenizer, MAX_LEN)
        testing_set = DatasetsForFeaturedBERTclf(df_test.reset_index(drop=True), tokenizer, MAX_LEN)
    else:
        training_set, evaluation_set, testing_set = convert_to_Dataset(df_train, df_eval, df_test)

    # Log / evaluate / checkpoint about twice per epoch. Size this from the
    # actual training split, and clamp to >= 1: on a small dataset the old
    # int(len(df)*train_percent / batch) / 2 arithmetic could give 0, which is
    # not a valid step interval for step-based evaluation.
    steps_per_epoch = max(1, len(df_train) // TRAIN_BATCH_SIZE)
    EVAL_STEPS = max(1, steps_per_epoch // 2)

    training_args = TrainingArguments(
        output_dir='ANONYMIZED'+str(i),
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=VALID_BATCH_SIZE,
        weight_decay=WEIGHT_DECAY,
        evaluation_strategy=IntervalStrategy.STEPS,
        do_train=True,
        do_eval=True,
        learning_rate=LEARNING_RATE,
        logging_steps=EVAL_STEPS,
        # Checkpoints must line up with evaluations for load_best_model_at_end.
        save_steps=EVAL_STEPS,
        eval_steps=EVAL_STEPS,
        warmup_steps=WARMUP_STEPS,
        num_train_epochs=EPOCHS,
        overwrite_output_dir=True,
        metric_for_best_model='f1',
        load_best_model_at_end=True,
        save_total_limit=2,
        fp16=use_fp16,
    )

    if IS_CUSTOM_MODEL:
        model = FeaturedBERTclf(checkpoint=checkpoint, config=config).to(device)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint, config=config).to(device)

    trainer = Trainer(
        model,
        training_args,
        train_dataset=training_set,
        eval_dataset=evaluation_set,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        # Optional: callbacks=[EarlyStoppingCallback(early_stopping_patience=6)]
    )
    trainer.train()

    print("Done training")

    # --- Evaluation on the test set ----------------------------------------
    # When predictions come back as a tuple, the logits are the first entry.
    results = trainer.predict(testing_set)
    result_preds = results.predictions[0] if isinstance(results.predictions, tuple) else results.predictions

    # Also keep train / eval logits so the folds can be ensembled later.
    tr_results = trainer.predict(training_set)
    tr_preds = tr_results.predictions[0] if isinstance(tr_results.predictions, tuple) else tr_results.predictions

    ev_results = trainer.predict(evaluation_set)
    ev_preds = ev_results.predictions[0] if isinstance(ev_results.predictions, tuple) else ev_results.predictions

    logit_path = f'ANONYMIZED'
    os.makedirs(logit_path, exist_ok=True)
    for _split, _pred_results in zip(['train', 'eval', 'test'], [tr_preds, ev_preds, result_preds]):
        with open(f'{logit_path}/{_split}_{i}.npy', 'wb') as f:
            np.save(f, _pred_results)

    y_pred = np.argmax(result_preds, axis=-1)
    # The custom model's label_ids are argmax-reduced here, i.e. presumably
    # one-hot / per-class vectors rather than class indices.
    y_true = np.argmax(results.label_ids, axis=-1) if IS_CUSTOM_MODEL else results.label_ids

    tgts += y_true.tolist()
    preds += y_pred.tolist()

    print()
    print(y_true)
    print(y_pred)
    print('[Fold %d] Test f1: %.6f'%(i, f1_score(y_true=y_true.tolist(), y_pred=y_pred.tolist(), average='macro')))

# Macro-F1 over the test predictions pooled from every fold.
overall_f1 = f1_score(y_true=tgts, y_pred=preds, average='macro')
print()
print()
print('[Overall] Test f1: %.6f'%(overall_f1))

if IS_BINARY:
    target_names = ['Negative', 'Neut and Posi']
else:
    target_names = ['Not at all', 'A little', 'A lot']

# Pass the label ids explicitly. Without `labels`, classification_report
# infers the classes from tgts/preds and raises when a class never occurs in
# either, because the class count then disagrees with len(target_names).
print(classification_report(tgts, preds, labels=list(range(len(target_names))), target_names=target_names))