# Importing the libraries needed
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle

import transformers
from torch.utils.data import Dataset, DataLoader
import torch
import datasets

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, classification_report

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import EarlyStoppingCallback, IntervalStrategy
from custom_models import MultiLabelDataset, MultiLabelBERTClass, SimpleLossCompute

from collections import Counter

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

# Maximum sequence length; passed to MultiLabelDataset further down in this file.
MAX_LEN = 512
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Current device: %s'%device)

# Hugging Face checkpoint name shared by the tokenizer and the classifier.
checkpoint = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Separator string picked from the checkpoint name: ' <SEP> ' for distilbert
# checkpoints, ' </s></s> ' otherwise. Not referenced in the code visible here.
sep_token = ' <SEP> ' if 'distilbert' in checkpoint else ' </s></s> '

def process_data(df):
    """Flatten transcripts into one row per counselor line.

    Each row of ``df`` must provide an ``'id'`` and a newline-separated
    ``'text'``. Lines starting with ``'___Counselor'`` are kept, with every
    ``'___Counselor___'`` marker rewritten to ``'Counselor: '``; all other
    lines are dropped.

    Returns a DataFrame with columns ``'id'``, ``'text'`` and ``'label'``,
    where ``'label'`` is a placeholder ``[0, 0, 0, 0]`` for every row.
    """
    ids, utterances, placeholder_labels = [], [], []

    for _, record in df.iterrows():
        counselor_turns = []
        for line in record['text'].split('\n'):
            if not line.startswith('___Counselor'):
                continue
            counselor_turns.append(line.replace('___Counselor___', 'Counselor: '))

        n_turns = len(counselor_turns)
        # Repeat the transcript id once per extracted counselor line.
        ids.extend([record['id']] * n_turns)
        utterances.extend(counselor_turns)
        # A fresh list per row so the placeholders never alias each other.
        placeholder_labels.extend([0, 0, 0, 0] for _ in range(n_turns))

    return pd.DataFrame(
        data={'id': ids, 'text': utterances, 'label': placeholder_labels}
    )

def retrieve_data():
    """Read the extended transcripts TSV under ``<cwd>/data`` and return its counselor lines.

    The file is parsed as tab-separated and handed to ``process_data``.
    """
    # Alternative input, currently disabled: '/data/additional_transcripts.tsv'
    transcripts_path = os.getcwd() + '/data/extended_transcripts_allOutcome.tsv'
    transcripts = pd.read_csv(transcripts_path, sep='\t')
    return process_data(transcripts)

def run_inference(eval_loader, model, criterion=None):
    """Run ``model`` over every batch of ``eval_loader`` and threshold the outputs.

    Args:
        eval_loader: Iterable of batches. Each batch is indexed with the keys
            ``'ids'`` and ``'mask'``; both values must support ``.to(...)``.
        model: Called as ``model(ids, mask)``; its output is passed through
            ``torch.sigmoid``. It is switched to eval mode before the loop.
        criterion: Unused. Kept only so call sites that pass a loss function
            keep working.

    Returns:
        A flat list with one entry per sample across all batches. Assuming the
        model returns ``(batch, num_labels)`` logits, each entry is a list of
        booleans that are ``True`` where the sigmoid is at least 0.5.
    """
    # Send inputs to wherever the model actually lives; fall back to the
    # module-level ``device`` for models that expose no parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    # Eval mode only needs to be set once, not per batch.
    model.eval()

    all_preds = []
    # Inference needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for data in eval_loader:
            outputs = model(
                data['ids'].to(target_device),
                data['mask'].to(target_device),
            )
            outputs_onehot = (torch.sigmoid(outputs) >= 0.5).cpu()
            all_preds += outputs_onehot.tolist()

    return all_preds

# Section marker for the script part of the file (a bare string expression, not a docstring).
"""
Inference
"""

# Re-declared here with the same value as the assignment near the top of the file.
MAX_LEN = 512
# Batch size handed to the DataLoader below.
VALID_BATCH_SIZE = 16

# One row per counselor line, each carrying a placeholder label (built by process_data).
df = retrieve_data()
# MultiLabelDataset comes from custom_models; it receives the frame, the tokenizer and MAX_LEN.
testing_set = MultiLabelDataset(df, tokenizer, MAX_LEN)
# No shuffle argument is passed, so the DataLoader default ordering applies; the
# predictions are later assigned positionally to df['label'].
testing_loader = DataLoader(testing_set, batch_size=VALID_BATCH_SIZE)

# Alternative run, currently disabled:
# model_name = 'multilabel_roberta-base_turns0_fold0'
model_name = 'multilabel_roberta-base_turns0_indicator_counselor_fold1'
model_path = 'ANONYMIZED'+model_name

best_step_model = MultiLabelBERTClass(checkpoint=checkpoint).to(device)

# The step number is the last '_'-separated token of a '*.pt' file name in the
# model directory. os.listdir returns entries in arbitrary order, so the names
# are sorted to make the choice reproducible when several checkpoints exist
# (the lexicographically last one wins).
step_candidates = [
    filename.split('.pt')[0].split('_')[-1]
    for filename in sorted(os.listdir(model_path))
    if filename.endswith('.pt')
]
if not step_candidates:
    # Fail with a clear message instead of trying to open 'state_dict_.pt'.
    raise FileNotFoundError("No '.pt' checkpoint found in %s" % model_path)
best_step_num = step_candidates[-1]

# map_location lets a checkpoint saved on a GPU load when running on the CPU fallback.
best_step_model.load_state_dict(
    torch.load(model_path+'/state_dict_'+str(best_step_num)+'.pt', map_location=device)
)

# Predict one boolean vector per row of `df`.
test_preds = run_inference(testing_loader, best_step_model)
assert len(test_preds) == len(df)

# Turn each boolean vector into a comma-separated string of the active label
# positions. Position 0 is never listed explicitly; a row with no other active
# position becomes '0'.
inferred_labels = []
for flags in test_preds:
    active_positions = [
        str(position)
        for position, is_active in enumerate(flags)
        if is_active and position != 0
    ]
    inferred_labels.append(','.join(active_positions) if active_positions else '0')

# Replace the placeholder labels with the predicted label strings.
df['label'] = inferred_labels

# Alternative input/output pair, currently disabled:
#   '/data/additional_transcripts.tsv' -> '/data/additional_strategy_indicator.tsv'
df_transcripts = pd.read_csv(os.getcwd()+'/data/extended_transcripts_allOutcome.tsv', sep='\t')

# `df` holds one row per counselor line, in transcript order and then line order
# (that is how process_data builds it from this same file), so the predicted
# labels are consumed positionally. Looking a line up by its text would always
# resolve to the first matching line and mislabel utterances repeated within a
# transcript, and filtering `df` by id would mix transcripts that share an id.
turn_labels = df['label'].tolist()
cursor = 0
indicators = []
for _text in df_transcripts['text'].tolist():
    all_labels = []
    for sent in _text.split('\n'):
        if not sent.startswith('___Counselor'):
            # Non-counselor lines always get the neutral indicator.
            all_labels.append('0')
            continue
        if cursor >= len(turn_labels):
            raise ValueError('More counselor lines in the transcripts than predicted labels')
        all_labels.append(turn_labels[cursor])
        cursor += 1
    # One indicator per line of the transcript, joined with '___'.
    indicators.append('___'.join(all_labels))

if cursor != len(turn_labels):
    raise ValueError('Fewer counselor lines in the transcripts than predicted labels')

df_transcripts['indicator'] = indicators
df_transcripts.to_csv(os.getcwd()+'/data/extended_transcripts_allOutcome_strategy_indicator.tsv', sep='\t', index=False)

