import pandas as pd
import os
import numpy as np
import random
import math
import pickle
from collections import Counter
from sklearn.model_selection import train_test_split
from config import basic_opt as opt
from matplotlib import pyplot as plt
import pickle
from utils.utils import save_output
import logging as log

def read_df(data_dir, lang):
    """Load ``<data_dir>/<lang>.csv`` and drop rows whose 'preverb' is repeated.

    Malformed CSV rows are skipped (with a warning) instead of aborting the
    load.

    :param data_dir: directory containing the csv file
    :param lang: file name without the ``.csv`` extension (converted with ``str``)
    :return: data frame keeping only the first row for every distinct 'preverb'
    """
    csv_path = os.path.join(data_dir, str(lang) + '.csv')

    def _read(**bad_line_kwargs):
        # The context manager guarantees the handle is closed, also on errors.
        with open(csv_path, 'r') as handle:
            return pd.read_csv(handle, **bad_line_kwargs)

    try:
        # pandas >= 1.3 spelling; 'warn' reports and skips malformed rows.
        df = _read(on_bad_lines='warn')
    except TypeError:
        # Older pandas does not know `on_bad_lines`; use the legacy keyword.
        df = _read(error_bad_lines=False)
    df.drop_duplicates(subset='preverb', inplace=True)
    return df

def apply_inplace(df, field, func):
    """Return a new frame in which column ``field`` holds ``func`` applied to each value.

    Despite the name, ``df`` itself is left untouched. The transformed column
    is appended after the remaining columns, so it ends up last.

    :param df: input data frame
    :param field: name of the column to transform
    :param func: callable applied to every element of that column
    :return: new data frame with the transformed column as the last column
    """
    remaining_columns = df.drop(field, axis=1)
    transformed_column = df[field].apply(func)
    return pd.concat([remaining_columns, transformed_column], axis=1)

def filtering(df, lang_mode, target_field, NUM_VERBS=50, MIN_LEN=10, MAX_LEN=50, NUM_SOURCE =100000):
    """Keep sentences ending in the most frequent verbs and build the vocabularies.

    The 'preverb' strings are split on whitespace into token lists (purely
    numeric tokens become '0'), 'postag' strings (if the column exists) are
    split the same way, and a 'lengths' column with the token count is added.

    :param df: data frame with a 'preverb' column, the ``target_field`` column
        and optionally a 'postag' column
    :param lang_mode: 'char' for a character-level source vocabulary, 'word'
        for a word-level one
    :param target_field: column holding the verb label (lemma, inflected,
        tense (japanese))
    :param NUM_VERBS: number of most frequent target values to keep
    :param MIN_LEN: minimum number of preverb tokens (inclusive)
    :param MAX_LEN: maximum number of preverb tokens (inclusive)
    :param NUM_SOURCE: maximum number of source words indexed in 'word' mode
    :return: ``(filtered_df, data_dict)``; ``data_dict`` has the keys 'source'
        and 'target', plus 'pos' in 'word' mode when a 'postag' column exists
    :raises ValueError: if ``lang_mode`` is neither 'char' nor 'word'
    """
    # Fail fast: previously an unsupported mode went through all the filtering
    # and then died with an UnboundLocalError on `data_dict`.
    if lang_mode not in ('char', 'word'):
        raise ValueError("lang_mode must be 'char' or 'word', got %r" % (lang_mode,))

    verb_freqs = df[target_field].value_counts()[:NUM_VERBS].to_dict()  # count verb freqs
    df = df[df[target_field].isin(list(verb_freqs.keys()))]  # extract sentences with verbs in range

    # tokenize and replace numbers with 0
    df = apply_inplace(df, 'preverb', lambda x: [w if not w.isdigit() else '0' for w in x.split()])
    if 'postag' in df.columns:
        df = apply_inplace(df, 'postag', lambda x: str(x).split())
    df['lengths'] = df['preverb'].apply(len)

    # obtain samples in length range
    filtered_df = df[df['lengths'].isin(range(MIN_LEN, MAX_LEN + 1))]

    ## build vocabulary
    ### build target dict (from the verb counts taken before the length filter)
    t_word2index = {v: i for i, v in enumerate(verb_freqs)}

    ### build source dict
    if lang_mode == 'char':
        # Characters are numbered from 1 in order of first appearance;
        # sentences are joined without separators, so spaces are not indexed.
        source_freqs = Counter(ch for sent in filtered_df['preverb'] for ch in "".join(sent))
        s_word2index = {ch: i for i, ch in enumerate(source_freqs, start=1)}
        s_word2index['<PAD>'] = 0

        data_dict = {'source': s_word2index,
                     'target': t_word2index}

    else:  # lang_mode == 'word'
        data_dict = dict()
        # build word_dict and pos_dict for word-level language model;
        # words are numbered from 2 by decreasing frequency, capped at NUM_SOURCE
        source_freqs = filtered_df.preverb.str.join(" ").str.split(pat=" ", expand=True).stack().value_counts().to_dict()
        s_word2index = {v: i + 2 for i, v in enumerate(list(source_freqs)[:NUM_SOURCE])}
        s_word2index['<PAD>'] = 0
        s_word2index['<UNK>'] = 1

        if 'postag' in filtered_df.columns:
            pos_freqs = filtered_df.postag.str.join(" ").str.split(pat=" ", expand=True).stack().value_counts().to_dict()
            pos2index = {v: i + 1 for i, v in enumerate(pos_freqs)}
            pos2index['<PAD>'] = 0
            data_dict['pos'] = pos2index

        data_dict['source'] = s_word2index
        data_dict['target'] = t_word2index

    return filtered_df, data_dict

def grouping_by_verb(pairs, balance, proportion=0):
    """
    Group samples by verb (the last element of each tuple) and down-sample them.

    If ``proportion`` > 0, only the first ``round(n * proportion)`` samples of
    each verb class (``n`` = class size) are kept. If ``balance`` is true, every
    class is then truncated to the size of the smallest remaining class. Both
    steps can be combined; the proportion step runs first.

    :param pairs: data tuples (sentence, target) or (sentence, pos, target)
    :param balance: truncate all classes to the size of the smallest class
    :param proportion: fraction of every class to keep; values <= 0 keep all
    :return: selected samples, ordered class by class (classes in order of
        first appearance, samples in their original order); an empty list when
        ``pairs`` is empty
    """
    # label -> positions of its samples, in input order
    label_indexes = {}
    for idx, pair in enumerate(pairs):
        label_indexes.setdefault(pair[-1], []).append(idx)

    if proportion > 0:
        label_indexes = {
            label: indexes[:round(len(indexes) * proportion)]
            for label, indexes in label_indexes.items()
        }

    # Guard against empty input: min() of an empty sequence raises ValueError.
    if balance and label_indexes:
        min_count = min(len(indexes) for indexes in label_indexes.values())
        label_indexes = {
            label: indexes[:min_count]
            for label, indexes in label_indexes.items()
        }

    return [pairs[i] for indexes in label_indexes.values() for i in indexes]

def obtain_subsentences(pairs, data_dict, lang_mode,
                        portion=(0.3, 0.5, 0.7, 0.9, 1.0),
                        group_by_length=False,
                        reverse = False):
    """
    Cut every sentence into prefixes ("subsentences") and convert them to indexes.

    For each value ``p`` in ``portion`` the first ``ceil(p * len(sentence))``
    tokens of every sentence are kept. Samples whose target is not a key of
    ``data_dict['target']`` are skipped. Source symbols missing from
    ``data_dict['source']`` are mapped to index 1.

    :param pairs: (sentence, target) or (sentence, postag, target) tuples, where
        sentence (and postag) are token lists
    :param data_dict: has the fields 'source' and 'target', and optionally 'pos'
    :param lang_mode: 'word' (prefix is a list of words) or 'char' (the prefix
        words are concatenated and indexed character by character)
    :param portion: subsentence lengths as fractions of the full sentence;
        never modified by this function
    :param group_by_length: if True, the returned subsentences are grouped by
        portion and stored in a dict with key=p
    :param reverse: iterate over ``portion`` in reversed order
    :return: ``(samples, count)``. ``samples`` is a dict ``{p: [sample, ...]}``
        when ``group_by_length`` is set, otherwise one flat list in which every
        portion group has been shuffled. A sample is
        ``(source_ids, sentence_id, target_id)``, or
        ``(source_ids, pos_ids, sentence_id, target_id)`` when ``data_dict``
        has a 'pos' field; ``sentence_id`` is the position within ``pairs``.
    :raises ValueError: if ``lang_mode`` is neither 'char' nor 'word'
    """
    if lang_mode not in ('char', 'word'):
        raise ValueError("lang_mode must be 'char' or 'word', got %r" % (lang_mode,))

    # Work on a private copy: reversing the argument in place used to flip the
    # caller's list (and the shared default) on every call with reverse=True.
    portions = list(portion)
    if reverse:
        portions.reverse()

    source_index = data_dict['source']
    target_index = data_dict['target']
    has_pos = 'pos' in data_dict

    data = {}
    count = 0
    for p in portions:
        data[p] = []
        for sent_id, pair in enumerate(pairs):
            tokens = pair[0]
            # NOTE(review): the ceil of a float product may land one token
            # higher than the exact value; kept as-is to reproduce existing data.
            cut = int(math.ceil(p * len(tokens)))

            if lang_mode == 'char':
                text = "".join(tokens[:cut])  # character model
            else:
                if len(pair) > 2:
                    # tokens and tags must be aligned one-to-one
                    assert len(tokens) == len(pair[1])
                text = tokens[:cut]

            target = pair[-1]
            if target not in target_index:
                continue

            # NOTE(review): 1 is '<UNK>' in the word-level vocabulary built by
            # `filtering`; its char-level vocabulary assigns 1 to a real
            # character, so unknown characters collide with it. Confirm intent.
            source_ids = [source_index.get(w, 1) for w in text]
            if has_pos:
                pos_ids = [data_dict['pos'][tag] for tag in pair[1][:cut]]
                sample = (source_ids, pos_ids, sent_id, target_index[target])
            else:
                sample = (source_ids, sent_id, target_index[target])
            data[p].append(sample)
            count += 1

    if group_by_length:
        return data, count

    out = []
    for samples in data.values():
        random.shuffle(samples)
        out.extend(samples)
    return out, len(out)

def get_len(data):
    """Return (longest, shortest, average) length of the first element of each sample.

    The average is truncated to an integer.

    :param data: non-empty sequence of samples whose first element has a length
    :return: tuple ``(max_len, min_len, avg_len)`` of ints
    """
    sizes = []
    for sample in data:
        sizes.append(len(sample[0]))
    longest = max(sizes)
    shortest = min(sizes)
    average = sum(sizes) / len(sizes)
    return int(longest), int(shortest), int(average)

def filtering_test(df, target_field, NUM_VERBS=50, MIN_LEN=10, MAX_LEN=50, NUM_SOURCE =100000):
    """Keep test sentences ending in the most frequent verbs; no vocabulary is built.

    The 'preverb' strings are split on whitespace into token lists (purely
    numeric tokens become '0'), 'postag' strings (if the column exists) are
    split the same way, and a 'lengths' column with the token count is added.

    :param df: data frame with a 'preverb' column, the ``target_field`` column
        and optionally a 'postag' column
    :param target_field: column holding the verb label (lemma, inflected,
        tense (japanese))
    :param NUM_VERBS: number of most frequent target values to keep
    :param MIN_LEN: minimum number of preverb tokens (inclusive)
    :param MAX_LEN: maximum number of preverb tokens (inclusive)
    :param NUM_SOURCE: accepted but not used by this function
    :return: the filtered data frame
    """
    def _normalize_tokens(sentence):
        # split into tokens and replace numbers with 0
        tokens = []
        for token in sentence.split():
            tokens.append('0' if token.isdigit() else token)
        return tokens

    # restrict to the most frequent verbs of this data frame
    top_verbs = df[target_field].value_counts()[:NUM_VERBS].to_dict()
    has_top_verb = df[target_field].isin(list(top_verbs.keys()))
    df = df[has_top_verb]

    df = apply_inplace(df, 'preverb', _normalize_tokens)
    if 'postag' in df.columns:
        df = apply_inplace(df, 'postag', lambda tags: str(tags).split())
    df['lengths'] = df['preverb'].apply(len)

    # obtain samples in length range
    allowed_lengths = range(MIN_LEN, MAX_LEN + 1)
    in_range = df['lengths'].isin(allowed_lengths)
    return df[in_range]

def load_test_data(opt, dictionary):
    """Load a separate test file and convert it into length-grouped subsentences.

    The original (un-indexed) test sentences and their targets are also written
    to ``<opt.out_dir>/test_df.csv``.

    :param opt: options object; reads ``data_dir``, ``input_test_file``,
        ``target_field``, ``num_verbs``, ``source_vocab_size``, ``out_dir``,
        ``lang_mode`` and ``portion``
    :param dictionary: vocabulary dict passed on to ``obtain_subsentences``
    :return: dict mapping each portion to its list of indexed test samples
    """
    df = read_df(opt.data_dir, opt.input_test_file)
    # `drop` returns a new frame; the result used to be discarded, so the
    # 'Unnamed...' columns were never actually removed.
    df = df.drop(columns=[c for c in df.columns.tolist() if c.startswith('Unnamed')])
    filtered_df = filtering_test(df, opt.target_field, NUM_VERBS=opt.num_verbs, NUM_SOURCE=opt.source_vocab_size)

    # Build the sample tuples column-wise instead of row by row with iterrows.
    targets = filtered_df[opt.target_field]
    if 'postag' in filtered_df.columns:
        data = list(zip(filtered_df['preverb'], filtered_df['postag'], targets))
    else:
        data = list(zip(filtered_df['preverb'], targets))

    # save test original sentences before converting them to indexes
    test_df = pd.DataFrame({'preverbs': [t[0] for t in data],
                            'targets': [t[-1] for t in data]},
                           columns=['preverbs', 'targets'])

    os.makedirs(opt.out_dir, exist_ok=True)
    test_df.to_csv(os.path.join(opt.out_dir, 'test_df.csv'))

    # split test samples into subsentences, grouped by portion
    test_data, _ = obtain_subsentences(data, dictionary, opt.lang_mode, opt.portion, group_by_length=True)

    return test_data


def load_data(opt):
    """Read the corpus, split it into train/dev/test and index all subsentences.

    Side effects: the length-grouped dev data is stored through
    ``save_output(opt.data_dir, 'eval', ...)`` and the original test sentences
    are written to ``<opt.out_dir>/test_df.csv``.

    :param opt: options object; reads ``data_dir``, ``lang``, ``lang_mode``,
        ``target_field``, ``num_verbs``, ``source_vocab_size``,
        ``use_pretrained_embed``, ``balance_flag``, ``split_p``,
        ``fix_portion``, ``portion``, ``reverse`` and ``out_dir``
    :return: ``(train_data, eval_data, test_data, dictionary)``. Test data is
        always a dict grouped by portion; train and dev data are grouped dicts
        when ``opt.fix_portion > 0`` and flat lists otherwise.
    """
    def _subsentences(samples, grouped, reverse=False):
        # Hand over a copy of the portions: obtain_subsentences may reverse the
        # list it receives, which must not leak into `opt.portion`.
        return obtain_subsentences(samples, dictionary, opt.lang_mode, list(opt.portion),
                                   group_by_length=grouped, reverse=reverse)

    # read data pairs and build vocabulary
    df = read_df(opt.data_dir, opt.lang)

    # filtering by verb vocab
    filtered_df, dictionary = filtering(df, opt.lang_mode, opt.target_field, NUM_VERBS=opt.num_verbs, NUM_SOURCE=opt.source_vocab_size)

    # update dictionary by adding pretrained_embedding
    if opt.use_pretrained_embed:
        embed_path = os.path.join(opt.data_dir, 'embed_dict.pickle')
        # NOTE(security): unpickling runs arbitrary code; only load trusted files.
        with open(embed_path, 'rb') as handle:
            dictionary['embed'] = pickle.load(handle)

    # Build the sample tuples column-wise instead of row by row with iterrows.
    targets = filtered_df[opt.target_field]
    if 'postag' in filtered_df.columns:
        data = list(zip(filtered_df['preverb'], filtered_df['postag'], targets))
    else:
        data = list(zip(filtered_df['preverb'], targets))

    log.info("total number of sentences:%d", len(data))
    ### balance class or get partial data
    if opt.balance_flag or opt.split_p > 0:
        rest, test = train_test_split(data, test_size=0.1, random_state=42)  # get fixed test set
        grouped = grouping_by_verb(rest, balance=opt.balance_flag, proportion=opt.split_p)
        train, dev = train_test_split(grouped, test_size=0.1, random_state=42)
    else:
        # separate data into train, test and dev (before splitting into subsentences)
        rest, test = train_test_split(data, test_size=0.2, random_state=42)  # get fixed test set
        train, dev = train_test_split(rest, test_size=0.2, random_state=42)  # get fixed dev set
    log.info("complete sentences in set: train(%d), eval(%d), test(%d)", len(train), len(dev), len(test))

    log.info('length info(train): max(%d), min(%d),avg(%d)', *get_len(train))
    log.info('length info(eval): max(%d), min(%d),avg(%d)', *get_len(dev))
    log.info('length info(test): max(%d), min(%d),avg(%d)', *get_len(test))

    # obtain subsentences and build index sequences
    if opt.fix_portion > 0:
        train_data, count_train = _subsentences(train, grouped=True)
        eval_data, count_eval = _subsentences(dev, grouped=True)
        save_output(opt.data_dir, 'eval', eval_data)  # save length-split eval data
    else:
        train_data, count_train = _subsentences(train, grouped=False, reverse=opt.reverse)
        eval_data, count_eval = _subsentences(dev, grouped=False)
        eval_data_copy, count_eval = _subsentences(dev, grouped=True)
        save_output(opt.data_dir, 'eval', eval_data_copy)  # save length-split eval data

    # always group test data by length
    # save test original sentences before converting them to indexes
    test_df = pd.DataFrame({'preverbs': [t[0] for t in test],
                            'targets': [t[-1] for t in test]},
                           columns=['preverbs', 'targets'])

    os.makedirs(opt.out_dir, exist_ok=True)
    test_df.to_csv(os.path.join(opt.out_dir, 'test_df.csv'))
    # split test samples into subsentences
    test_data, count_test = _subsentences(test, grouped=True)

    log.info("Total number of samples: %d train, %d dev, %d test", count_train, count_eval, count_test)
    return train_data, eval_data, test_data, dictionary

