# coding=latin-1
import util201217 as ut
import step210125 as st
import argparse, os, re, sys, time
from unidecode import unidecode
import demoji
import datetime
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, normalize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LinearRegression, Lasso
from sklearn.metrics import accuracy_score, f1_score, log_loss, classification_report
from collections import defaultdict, Counter
from scipy import special
from scipy.sparse import csr_matrix, save_npz, load_npz
import pandas as pd
import random
import torch
import torch.nn as nn
from torch import optim
###################################################################################################
parser = argparse.ArgumentParser()
# inputs
parser.add_argument("-path_df_y_trn",            type=str,   default='0_source/train_features_labels_WASSA2021/goldstandard.tsv')
parser.add_argument("-path_df_emp_x_trn",        type=str,   default='0_source/train_features_labels_WASSA2021/messages_train_ready_for_WS.tsv')
parser.add_argument("-path_df_emo_x_trn",        type=str,   default='0_source/train_features_labels_WASSA2021/messages_train_sentencized_automatic_emotion_tags.tsv')
parser.add_argument("-path_df_y_dev",            type=str,   default='0_source/dev_features_labels_WASSA2021/goldstandard.tsv')
parser.add_argument("-path_df_emp_x_dev",        type=str,   default='0_source/dev_features_labels_WASSA2021/messages_dev_features_ready_for_WS.tsv')
parser.add_argument("-path_df_emo_x_dev",        type=str,   default='0_source/dev_features_labels_WASSA2021/messages_dev_sentencized_automatic_emotion_tags.tsv')
parser.add_argument("-path_df_y_tst",            type=str,   default='0_source/test_features_labels_EMO_WASSA2021/gold_standard_test_EMO.tsv')
parser.add_argument("-path_df_emp_x_tst",        type=str,   default='0_source/test_features_labels_EMO_WASSA2021/messages_test_features_ready_for_WS.tsv')
parser.add_argument("-path_df_emo_x_tst",        type=str,   default='0_source/test_features_labels_EMO_WASSA2021/messages_test_sentencized_automatic_emotion_tags.tsv')
# model / preprocessing options
parser.add_argument("-device",      type=str,   default='cuda:1')
parser.add_argument("-bertlang",    type=str,   default='enlarge')
parser.add_argument("-bertlayer",   type=int,   default=24)
# NOTE: -min_freq, -no_below, -no_above and -max_context are not read anywhere in this script.
parser.add_argument("-min_freq",    type=int,   default=0)
parser.add_argument("-no_below",    type=float, default=.0005, help='min word freq')
# argparse %-formats help strings, so a literal percent sign must be written as '%%';
# a bare '%' here made `-h` crash with a TypeError.
parser.add_argument("-no_above",    type=float, default=.9,    help='max %% of docs where the word is found')
# Percentile of the training essays' word counts used as the BERT pad length (see pad_size).
parser.add_argument("-pad_rate",    type=float, default=99.9,  help='rate pad wrt maxlen, per fasttext')
parser.add_argument("-max_context", type=int,  default=2)
args = parser.parse_args()

# From here on both stdout and stderr go to the ut.log object; log.pathtime is the
# run's output location and is used as the prefix of every artifact saved below.
sys.stdout = sys.stderr = log = ut.log(__file__, "")

# Snapshot this script and the util module next to the outputs, for reproducibility.
# shutil.copy avoids building a shell command (unquoted paths with spaces or shell
# metacharacters broke the former `cp`). As with the former os.system call, a failed
# copy is reported but does not abort the run.
import shutil
for _src in (__file__, ut.__file__):
    try:
        shutil.copy(_src, log.pathtime)
    except OSError as _err:
        print(f"warning: could not copy {_src} to {log.pathtime}: {_err}")

startime = ut.start()
align_size = ut.print_args(args)
print(f"{'dirout':.<{align_size}} {log.pathtime}")
###################################################################################################
# Honour the requested device only when CUDA is usable; otherwise fall back to the CPU
# and report which one was chosen.
cuda_available = torch.cuda.is_available()
if cuda_available:
    device = torch.device(args.device)
    print(f"{'GPU in use':.<{align_size}} {device}\n{'#'*80}")
else:
    device = torch.device("cpu")
    print(f"No GPU available, using the CPU.\n{'#'*80}")
###################################################################################################
def _load_gold(path):
    """Read a header-less, tab-separated gold-standard file, print it and return it."""
    gold = pd.read_csv(path, sep="\t", header=None)
    print(gold)
    return gold


# Splits are read (and printed) in train -> dev -> test order.
df_y_trn = _load_gold(args.path_df_y_trn)
df_y_dev = _load_gold(args.path_df_y_dev)
df_y_tst = _load_gold(args.path_df_y_tst)

# Encode the emotion labels (train/dev gold column 2, test gold column 0) with ONE
# LabelEncoder fitted on the labels of all three splits, so a given label string maps
# to the same integer id in every split. The former per-split fit_transform refitted
# the encoder each time, and the ids silently disagreed whenever a split was missing
# one of the classes.
labencoder = LabelEncoder()
labencoder.fit(np.concatenate([df_y_trn[2].values, df_y_dev[2].values, df_y_tst[0].values]))
emo_trn = labencoder.transform(df_y_trn[2].values)
emo_dev = labencoder.transform(df_y_dev[2].values)
emo_tst = labencoder.transform(df_y_tst[0].values)

# Integer id -> label string for the emotion classes (the variable keeps its historical name).
id2year = {i: str(l) for i, l in enumerate(labencoder.classes_)}
print(id2year)
np.save(f"{log.pathtime}y_emo_trn", emo_trn)
np.save(f"{log.pathtime}y_emo_dev", emo_dev)
np.save(f"{log.pathtime}y_emo_tst", emo_tst)

# Empathy (gold column 0) and distress (gold column 1) targets are saved for train/dev
# only. In the test gold file column 0 is read as the emotion label above, so there is
# presumably no empathy/distress gold for test.
np.save(f"{log.pathtime}y_emp_trn", df_y_trn[0].values)
np.save(f"{log.pathtime}y_emp_dev", df_y_dev[0].values)
np.save(f"{log.pathtime}y_dis_trn", df_y_trn[1].values)
np.save(f"{log.pathtime}y_dis_dev", df_y_dev[1].values)


def _read_features(path):
    """Read a tab-separated feature table whose first row holds the column names."""
    return pd.read_csv(path, sep="\t", header=0)


# "emp" tables supply the essays and the per-author columns used below. Judging by the
# default file names, "emo" tables hold sentence-level automatic emotion tags; they are
# loaded but not used further in this script.
df_emp_x_trn = _read_features(args.path_df_emp_x_trn)
df_emo_x_trn = _read_features(args.path_df_emo_x_trn)
df_emp_x_dev = _read_features(args.path_df_emp_x_dev)
df_emo_x_dev = _read_features(args.path_df_emo_x_dev)
df_emp_x_tst = _read_features(args.path_df_emp_x_tst)
df_emo_x_tst = _read_features(args.path_df_emo_x_tst)

# One-hot encode the `gender` column. Categories are learned from the training split
# only; with handle_unknown='ignore', a category unseen in training becomes an
# all-zero row in dev/test instead of raising.
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(df_emp_x_trn.gender.to_numpy().reshape(-1, 1))
x_gender_trn, x_gender_dev, x_gender_tst = (
    enc.transform(frame.gender.to_numpy().reshape(-1, 1))
    for frame in (df_emp_x_trn, df_emp_x_dev, df_emp_x_tst)
)

# The encoder output is sparse; densify before saving as .npy.
for out_path, onehot in (
    (f"{log.pathtime}x_gender_trn", x_gender_trn),
    (f"{log.pathtime}x_gender_dev", x_gender_dev),
    (f"{log.pathtime}x_gender_tst", x_gender_tst),
):
    np.save(out_path, onehot.toarray())

# print(x_gender_trn.toarray(), x_gender_trn.toarray().shape)

def _scale_by_train_norm(x_trn, *x_others):
    """Scale each column of every split by that column's L2 norm in the TRAIN split.

    This replaces per-split ``normalize(x, axis=0)``. That call scales each column to
    unit L2 norm over the rows it is given, so every split got its own scale: for
    similarly distributed data, a smaller dev/test split came out with larger values
    than train. Dividing all splits by the train norms puts them on one common scale.
    The train output is unchanged up to floating-point rounding. As with
    ``normalize``, zero-norm columns are left unscaled.

    Args:
        x_trn: 2-D numeric array for the training split (rows are examples).
        *x_others: further 2-D arrays (dev, test, ...) with the same number of columns.

    Returns:
        Tuple of float64 arrays: the scaled train array followed by the scaled others.

    Raises:
        ValueError: if any array contains NaN or infinity (``normalize`` rejected
            such input too).
    """
    arrays = [np.asarray(x, dtype=np.float64) for x in (x_trn, *x_others)]
    for arr in arrays:
        if not np.all(np.isfinite(arr)):
            raise ValueError("Input contains NaN or infinity; cannot scale features.")
    norms = np.linalg.norm(arrays[0], axis=0)
    norms[norms == 0] = 1.0  # leave all-zero columns untouched instead of dividing by 0
    return tuple(arr / norms for arr in arrays)


x_income_trn, x_income_dev, x_income_tst = _scale_by_train_norm(
    df_emp_x_trn.income.to_numpy().reshape(-1, 1),
    df_emp_x_dev.income.to_numpy().reshape(-1, 1),
    df_emp_x_tst.income.to_numpy().reshape(-1, 1),
)
np.save(f"{log.pathtime}x_income_trn", x_income_trn)
np.save(f"{log.pathtime}x_income_dev", x_income_dev)
np.save(f"{log.pathtime}x_income_tst", x_income_tst)

# NOTE(review): the positional slices assume the train table has 5 extra leading
# columns compared with dev/test (14:19 vs 9:14, and 19: vs 14: below). Selecting the
# columns by name would be less fragile. A width mismatch now fails loudly in
# _scale_by_train_norm instead of saving inconsistent arrays.
x_pers_trn, x_pers_dev, x_pers_tst = _scale_by_train_norm(
    df_emp_x_trn.iloc[:, 14:19].to_numpy(),
    df_emp_x_dev.iloc[:, 9:14].to_numpy(),
    df_emp_x_tst.iloc[:, 9:14].to_numpy(),
)

print(x_pers_trn, df_emp_x_trn.personality_conscientiousness, x_pers_trn.shape)
np.save(f"{log.pathtime}x_pers_trn", x_pers_trn)
np.save(f"{log.pathtime}x_pers_dev", x_pers_dev)
np.save(f"{log.pathtime}x_pers_tst", x_pers_tst)

x_iri_trn, x_iri_dev, x_iri_tst = _scale_by_train_norm(
    df_emp_x_trn.iloc[:, 19:].to_numpy(),
    df_emp_x_dev.iloc[:, 14:].to_numpy(),
    df_emp_x_tst.iloc[:, 14:].to_numpy(),
)

print(x_iri_trn, df_emp_x_trn.iri_perspective_taking, x_iri_trn.shape)
np.save(f"{log.pathtime}x_iri_trn", x_iri_trn)
np.save(f"{log.pathtime}x_iri_dev", x_iri_dev)
np.save(f"{log.pathtime}x_iri_tst", x_iri_tst)


proc = st.Processing(log.pathtime, device)

# One pad length shared by every split: the args.pad_rate-th percentile of the
# whitespace-token counts of the TRAINING essays, rounded to the nearest multiple of 10.
lens = [len(essay.split()) for essay in df_emp_x_trn.essay.values]
pad_size = int(round(np.percentile(lens, args.pad_rate), -1))

# Run the BERT lookup (implemented in step210125, not shown here) on each split's
# essays, tagged with the split name.
for essays_df, split_name in (
    (df_emp_x_trn, 'emp_trn'),
    (df_emp_x_dev, 'emp_dev'),
    (df_emp_x_tst, 'emp_tst'),
):
    proc.bert_lookup(
        lang=args.bertlang,
        hiddenlayer=args.bertlayer,
        X1=essays_df.essay.values,
        padsize=pad_size,
        matrix=True,
        splits=(None, None),
        name=split_name,
    )


# ut.end(startime)