__author__ = 'thomas'
from argparse import ArgumentParser
import csv
import logging
import os
import sys

from sklearn.externals import joblib
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer
from transformers import BertModel
import numpy as np

from telicity.util import path_utils
from telicity.util import wittgenstein

# Directory containing this module; the experiment files are resolved relative to it.
PACKAGE_PATH = os.path.dirname(__file__)

# Command line interface of the preprocessing script. Only --action is mandatory.
parser = ArgumentParser()

# Input / output locations.
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip', '--input-path', type=str, help='input path')
parser.add_argument('-ip2', '--input-path-2', type=str, help='input path 2')
parser.add_argument('-op', '--output-path', type=str, help='output path')
parser.add_argument('-op2', '--output-path-2', type=str, help='output path 2')
parser.add_argument('-op3', '--output-path-3', type=str, help='output path 3')
parser.add_argument('-o', '--output-file', type=str, help='output file')
parser.add_argument('-o2', '--output-file-2', type=str, help='output file 2')

# What to run and how to run it.
parser.add_argument('-a', '--action', type=str, help='action to be executed', required=True)
parser.add_argument('-ef', '--experiment-file', type=str, help='experiment file to use')
parser.add_argument('-f', '--force-rewrite', action='store_true', help='force rewrite if exists')
parser.add_argument('-ll', '--log-level', type=str, default='INFO', help='logging level', choices=['CRITICAL', 'FATAL',
                                                                                                   'ERROR', 'WARNING',
                                                                                                   'WARN', 'INFO',
                                                                                                   'DEBUG'])
parser.add_argument('-xt', '--max-tokens', type=int, default=20, help='maximum number of tokens in sentence.')
parser.add_argument('-mt', '--min-tokens', type=int, default=4, help='minimum number of tokens in sentence.')
parser.add_argument('-ns', '--num-samples', type=int, default=50, help='number of samples from corpus.')
parser.add_argument('-rs', '--random-seed', type=int, default=29306, help='random seed')
parser.add_argument('-lp', '--log-path', type=str, help='path for file log')
parser.add_argument('-enc', '--encoding', type=str, default='utf-8', help='Encoding for processing embeddings.')
parser.add_argument('-sfl', '--skip-first-line', action='store_true', default=True,
                    help='Skip header line when processing embeddings.')
# '--skip-first-line' is a store_true flag that already defaults to True, so on its own it can never be
# switched off. This companion flag writes False to the same destination and makes the option usable.
parser.add_argument('--no-skip-first-line', dest='skip_first_line', action='store_false',
                    help='Do not skip the header line when processing embeddings.')
parser.add_argument('-ed', '--expected-dim', type=int, default=300, help='Expected dimension of fastText embeddings.')


def prepare_mbert_experiments_monolingual(input_file, data_load_function, data_idx, target_idx, skip_header, drop_target,
                                          out_file_numpy, out_file_labels, bert_model, out_file_le):
    """Encode every sentence of a dataset with a BERT model and store features and binary labels as text files.

    :param input_file: path of the dataset, handed to the loader function
    :param data_load_function: loader specification, resolved to a callable via `wittgenstein.create_function`
    :param data_idx: forwarded to the loader (presumably the column index of the sentence - TODO confirm)
    :param target_idx: forwarded to the loader (presumably the column index of the label - TODO confirm)
    :param skip_header: forwarded to the loader as `skip_header`
    :param drop_target: forwarded to the loader as `drop_target`
    :param out_file_numpy: path the feature matrix is written to with `np.savetxt`
    :param out_file_labels: path the label vector is written to with `np.savetxt`
    :param bert_model: model identifier passed to `BertTokenizer.from_pretrained` and `BertModel.from_pretrained`
    :param out_file_le: currently unused; no label encoder is fitted or stored by this function
    :return: None, the results are written to disk
    """
    # Imported locally so that the module-level imports stay untouched; only this function needs torch directly.
    import torch

    logging.info(f'Loading {input_file}...')

    load_fn = wittgenstein.create_function(data_load_function)
    data, targets = load_fn(input_file, data_idx, target_idx, skip_header=skip_header, drop_target=drop_target,
                            tokenize=False)

    # Labels (Manual encoding to ensure cross-lingual compatibility): every spelling of the stative class
    # maps to 0, any other label maps to 1.
    stative_labels = {'state', 'Stative', 's', 'S', 'stative'}
    y = np.array([0 if target in stative_labels else 1 for target in targets])

    # Data
    tokenizer = BertTokenizer.from_pretrained(bert_model)
    model = BertModel.from_pretrained(bert_model)

    encoded_sents = []
    # Pure inference: without no_grad() every forward pass would build an autograd graph that is
    # thrown away immediately afterwards.
    with torch.no_grad():
        for sent in data:
            enc = tokenizer(sent, return_tensors='pt')
            output = model(**enc)
            # Second element of the model output, flattened to a 1-d vector per sentence
            # (presumably the pooled sentence representation - confirm against the transformers version in use).
            encoded_sents.append(output[1].detach().numpy().reshape((-1,)))

    X = np.array(encoded_sents)
    logging.info(f'X.shape={X.shape}; y.shape={y.shape}')

    np.savetxt(out_file_numpy, X)
    np.savetxt(out_file_labels, y)
    # NOTE: nothing is written to `out_file_le`, the labels are encoded manually above.
    logging.info(f'Saved {out_file_numpy} and {out_file_labels} to disk!')


def prepare_mbert_experiments_crosslingual(train_data_path, train_data_files, train_label_files, test_data_file,
                                           test_label_file, train_data_out_file, test_label_out_file, output_path):
    """Stack several encoded training sets into one and store it together with a copy of the test set.

    All input files are read with `np.loadtxt` relative to `train_data_path`.

    :param train_data_path: directory containing the train *and* test input files
    :param train_data_files: list of feature matrix file names, stacked row-wise in the given order
    :param train_label_files: list of label file names, concatenated in the given order
    :param test_data_file: file name of the test feature matrix
    :param test_label_file: file name of the test labels
    :param train_data_out_file: name of the stacked train feature file, relative to `output_path`
    :param test_label_out_file: despite its name, the name of the stacked train *label* file, relative to
        `output_path` (this is what the caller in this script passes)
    :param output_path: base directory for the train output files
    :return: None, the results are written to disk
    """
    logging.info(f'Processing {train_data_files} from {train_data_path}...')

    # Resolve the train label target from the parameter *before* the name is rebound below. Reading a
    # module-level `train_label_out_file` instead only works when this file runs as a script.
    train_labels_out_file = os.path.join(output_path, test_label_out_file)
    train_out_file = os.path.join(output_path, train_data_out_file)
    train_out_path = os.path.split(train_out_file)[0]

    X_all = np.loadtxt(os.path.join(train_data_path, train_data_files[0]))
    # loadtxt returns a 0-d array for a single-label file, which can neither be concatenated nor saved.
    y_all = np.atleast_1d(np.loadtxt(os.path.join(train_data_path, train_label_files[0])))

    X_test = np.loadtxt(os.path.join(train_data_path, test_data_file))
    y_test = np.atleast_1d(np.loadtxt(os.path.join(train_data_path, test_label_file)))

    for train_file, label_file in zip(train_data_files[1:], train_label_files[1:]):
        X = np.loadtxt(os.path.join(train_data_path, train_file))
        y = np.atleast_1d(np.loadtxt(os.path.join(train_data_path, label_file)))

        X_all = np.vstack((X_all, X))
        y_all = np.concatenate((y_all, y))

    # The test files keep their base names and are stored next to the stacked train features.
    test_data_out_file = os.path.split(test_data_file)[1]
    test_label_out_file = os.path.split(test_label_file)[1]

    # An empty directory component means "current directory"; os.makedirs('') would raise.
    if train_out_path:
        os.makedirs(train_out_path, exist_ok=True)

    np.savetxt(train_out_file, X_all)
    np.savetxt(train_labels_out_file, y_all)
    np.savetxt(os.path.join(train_out_path, test_data_out_file), X_test)
    np.savetxt(os.path.join(train_out_path, test_label_out_file), y_test)


if __name__ == '__main__':
    args = parser.parse_args()

    # Logging: always to stdout, additionally to a file inside a timestamped folder when --log-path is given.
    log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
    root_logger = logging.getLogger()
    root_logger.setLevel(getattr(logging, args.log_level))

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)

    if args.log_path is not None:
        timestamped_foldername = path_utils.timestamped_foldername()
        log_path = os.path.join(args.log_path, timestamped_foldername)
        os.makedirs(log_path, exist_ok=True)

        file_handler = logging.FileHandler(os.path.join(log_path, f'experiments_preprocess_{args.action}.log'))
        file_handler.setFormatter(log_formatter)
        root_logger.addHandler(file_handler)

    if args.output_path is not None:
        os.makedirs(args.output_path, exist_ok=True)

    # Load experiment id file: one experiment per CSV row, looked up in the package resources.
    experiments = None
    if args.experiment_file is not None:
        # newline='' is the mode the csv module asks for, so quoted fields with line breaks are parsed correctly.
        with open(os.path.join(PACKAGE_PATH, 'resources', 'preprocessing', args.experiment_file), 'r',
                  newline='') as csv_file:
            # Blank lines come back as empty rows and would break the tuple unpacking below.
            experiments = [line for line in csv.reader(csv_file) if line]

    supported_actions = ('prepare_mbert_experiments_monolingual', 'prepare_mbert_experiments_crosslingual')
    if args.action not in supported_actions:
        raise ValueError(f'Action "{args.action}" not supported!')

    # Both actions iterate over the experiment rows; fail with a clear message instead of a NameError.
    if experiments is None:
        raise ValueError(f'Action "{args.action}" requires an experiment file (-ef/--experiment-file)!')

    if args.action == 'prepare_mbert_experiments_monolingual':
        for input_file, data_load_function, data_idx, target_idx, skip_header, drop_target, out_file_numpy, \
            out_file_labels, bert_model, out_file_le in experiments:

            # NOTE(review): skip_header is converted from its CSV string via == 'True', whereas drop_target is
            # forwarded as the raw CSV string - confirm that the load function expects it that way.
            prepare_mbert_experiments_monolingual(input_file=os.path.join(args.input_path, input_file),
                                                  data_load_function=data_load_function, data_idx=int(data_idx),
                                                  target_idx=int(target_idx), skip_header=skip_header=='True',
                                                  out_file_numpy=os.path.join(args.output_path, out_file_numpy),
                                                  out_file_labels=os.path.join(args.output_path, out_file_labels),
                                                  bert_model=bert_model, drop_target=drop_target,
                                                  out_file_le=os.path.join(args.output_path, out_file_le))
    else:
        for train_data_files, test_data_file, train_label_files, test_label_file, train_data_out_file, \
            train_label_out_file in experiments:

            # Several train files can be listed in one CSV cell, separated by '---'.
            # The keyword is called test_label_out_file but receives the train label output file name.
            prepare_mbert_experiments_crosslingual(train_data_path=args.input_path, train_data_files=train_data_files.split('---'),
                                                   train_label_files=train_label_files.split('---'), test_data_file=test_data_file,
                                                   test_label_file=test_label_file, train_data_out_file=train_data_out_file,
                                                   test_label_out_file=train_label_out_file, output_path=args.output_path)