# coding=utf-8
import argparse
import csv
import sys
import glob
import json
import logging
import os
import collections
from transformers import tokenization_bert
import string
import re
from nltk.corpus import stopwords
from nltk import FreqDist
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
import logging
from nltk.stem.porter import PorterStemmer
from sklearn.feature_extraction.text import CountVectorizer
from gensim import corpora
from collections import defaultdict, Counter
from nltk import FreqDist

logger = logging.getLogger(__name__)


def _raise_csv_field_size_limit():
    """Raise the csv per-field size limit as far as the platform allows.

    ``csv.field_size_limit`` must be given a value that fits in a C ``long``.
    Where that type is narrower than ``sys.maxsize`` (for example 64-bit
    Windows) the call raises ``OverflowError``, so the candidate value is
    shrunk until the csv module accepts it.

    Returns:
        The limit that was installed.
    """
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return limit
        except OverflowError:
            limit //= 10


# Documents in the corpus can exceed the csv module's default field limit.
_raise_csv_field_size_limit()

def _read_csv(input_file, quotechar='"'):
    """Read a comma-separated file and return all of its rows.

    Args:
        input_file: Path of the UTF-8 encoded CSV file.
        quotechar: Character used to quote fields that contain delimiters
            or line breaks.

    Returns:
        A list of rows, each row being a list of field strings.
    """
    # newline="" passes raw line endings to the csv module, so line breaks
    # embedded in quoted fields come back unchanged.
    with open(input_file, "r", encoding="utf-8", newline="") as f:
        return list(csv.reader(f, dialect='excel', delimiter=',', quotechar=quotechar))

def _read_tsv(input_file, quotechar=None):
    """Load every row of a tab-delimited UTF-8 file into memory.

    Args:
        input_file: Path of the file to read.
        quotechar: Quote character handed to ``csv.reader``.

    Returns:
        A list of rows, each row being a list of field strings.
    """
    with open(input_file, "r", encoding="utf-8") as handle:
        reader = csv.reader(handle, delimiter="\t", quotechar=quotechar)
        rows = [row for row in reader]
    return rows

def stemtokens(text):
    """Reduce each word to its stem using NLTK's ``PorterStemmer``.

    Args:
        text: Iterable of word strings.

    Returns:
        A list with the stem of every word, in input order.
    """
    stem = PorterStemmer().stem
    return list(map(stem, text))

def _remove_numbers(texts):
    """Drop every token that ``str.isnumeric`` reports as a number.

    Args:
        texts: Iterable of token strings.

    Returns:
        A list of the remaining tokens, in their original order.
    """
    kept = []
    for token in texts:
        if token.isnumeric():
            continue
        kept.append(token)
    return kept

def _create_lda_examples(lines, output_file):
    """Append one ``index<TAB>label<TAB>title content`` line per input row.

    Args:
        lines: Rows whose first three fields are label, title and content
            (all strings).
        output_file: Path of the UTF-8 text file to append to; it is created
            if it does not exist.
    """
    # The context manager flushes and closes the handle even if a malformed
    # row raises part-way through the loop.
    with open(output_file, 'a', encoding='utf-8') as f:
        print(len(lines))

        for i, line in enumerate(lines):
            label = line[0]
            text_a = line[1]  # title
            text_b = line[2]  # content
            f.write(str(i) + '\t' + label + '\t' + text_a + ' ' + text_b + '\n')


def _clean_lda_examples(lines, output_file,output_vocab_file):
    """Print the most frequent non-stopword tokens of a corpus, then exit.

    NOTE(review): this function is in a debugging state. It writes nothing
    to either output file and always terminates the interpreter once the
    token counts have been printed. The vocabulary mapping it was meant to
    emit is never populated, so no vocabulary is written here.

    Args:
        lines: Rows of ``[text_id, label, text_content]``; only the content
            field (index 2) is tokenised.
        output_file: Path opened in append mode (created if missing); nothing
            is written to it.
        output_vocab_file: Path opened in append mode (created if missing);
            nothing is written to it.

    Raises:
        SystemExit: Always, after the counts have been printed.
    """
    # Both paths are still opened so the files exist afterwards; the context
    # manager closes them when SystemExit propagates.
    with open(output_file, 'a+', encoding='utf-8'), \
            open(output_vocab_file, 'a+', encoding='utf-8'):
        print('doc lengths:', len(lines))

        # Load the stopword list once and keep it as a set; reloading it for
        # every token dominates the running time on a large corpus.
        stop_words = set(stopwords.words('english'))

        corpus = []
        for line in lines:
            text_content = line[2]
            corpus.extend(
                t for t in clean_str(text_content).split(' ') if t not in stop_words
            )

        count = Counter(corpus)
        print(count.most_common(30000))

        # sys.exit is used instead of the interactive ``exit`` helper, which
        # the ``site`` module provides and which may be absent.
        sys.exit()

def _filter_lda_examples(lines, output_file,vocab_file):
    """Clean, stem and vocabulary-filter each document, appending it to a TSV.

    For every row the content field is normalised with ``clean_str``,
    English stopwords are removed, the remaining tokens are stemmed, and only
    stems present in the vocabulary are kept. One line of the form
    ``text_id<TAB>label<TAB>tokens<SPACE>`` is appended per row.

    Args:
        lines: Rows of ``[text_id, label, text_content]``.
        output_file: Path of the UTF-8 file to append to (created if missing).
        vocab_file: Path of a vocabulary file with one token per line.
    """
    print('doc lengths:', len(lines))

    # load_vocab returns a token -> index mapping, so membership tests on it
    # are constant time.
    vocab = load_vocab(vocab_file)
    # Load the stopword list once; as a set it also gives O(1) lookups.
    stop_words = set(stopwords.words('english'))

    # The context manager closes the file even if a row fails to process.
    with open(output_file, 'a+', encoding='utf-8') as f:
        for line in lines:
            text_id = line[0]
            label = line[1]
            text_content = line[2]
            tokens = [t for t in clean_str(text_content).split(' ') if t not in stop_words]
            tokens = [tok for tok in stemtokens(tokens) if tok in vocab]
            text = ' '.join(tokens)

            f.write(str(text_id) + '\t' + label + '\t' + text + ' ' + '\n')



def clean_str(string):
    """
    Normalise raw text into lower-cased, space-separated tokens.

    Characters outside letters, digits and a few punctuation marks become
    spaces, English contractions are split off as separate tokens, the
    remaining punctuation is dropped and runs of whitespace are collapsed.
    Used for all datasets except SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    # Applied strictly in this order: later patterns rely on earlier ones.
    substitutions = (
        (r"[^A-Za-z0-9(),!?\'\`]", " "),
        (r"\'s", " \'s"),
        (r"\'ve", " \'ve"),
        (r"n\'t", " n\'t"),
        (r"\'re", " \'re"),
        (r"\'d", " \'d"),
        (r"\'ll", " \'ll"),
        (r",", " "),
        (r"!", " "),
        (r"\(", " "),
        (r"\)", " "),
        (r"\?", " "),
        (r"\s{2,}", " "),
    )
    cleaned = string
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)

    return cleaned.strip().lower()

class TopicInfo(object):
    """Vocabulary lookup plus basic tokenization for topic-model text.

    Holds a token -> id mapping loaded from a vocabulary file, its inverse,
    and a ``BasicTokenizer`` from ``transformers.tokenization_bert``.
    """

    def __init__(self, vocab_file, do_lower_case=False, max_len=None,
                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Load the vocabulary and configure the tokenizer.

        Args:
            vocab_file: Path of the vocabulary file passed to ``load_vocab``.
            do_lower_case: Forwarded to ``BasicTokenizer``.
            max_len: Longest id sequence ``convert_tokens_to_ids`` accepts;
                ``None`` means effectively unlimited.
            never_split: Forwarded to ``BasicTokenizer``.
        """
        self.vocab = load_vocab(vocab_file)

        # Inverse mapping, in the same order as the vocabulary.
        self.ids_to_tokens = collections.OrderedDict(
            (index, token) for token, index in self.vocab.items())
        self.basic_tokenizer = tokenization_bert.BasicTokenizer(
            do_lower_case=do_lower_case, never_split=never_split)
        if max_len is None:
            max_len = int(1e12)
        self.max_len = max_len

    def tokenize(self, text):
        """Return the tokens the basic tokenizer produces for ``text``."""
        return [token for token in self.basic_tokenizer.tokenize(text)]

    def convert_tokens_to_ids(self, tokens):
        """Map each token to its vocabulary id.

        Raises:
            KeyError: If a token is not in the vocabulary.
            ValueError: If more than ``max_len`` ids are produced.
        """
        ids = [self.vocab[token] for token in tokens]
        if len(ids) <= self.max_len:
            return ids
        raise ValueError(
            "Token indices sequence length is longer than the specified maximum "
            " sequence length for this BERT model ({} > {}). Running this"
            " sequence through BERT will result in indexing errors".format(
                len(ids), self.max_len)
        )

    def convert_ids_to_tokens(self, ids):
        """Map each id back to its vocabulary token."""
        return [self.ids_to_tokens[i] for i in ids]


def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file into an ordered mapping.

    Each line is stripped of surrounding whitespace and mapped to its
    zero-based line number. A token that appears more than once keeps its
    first position in the mapping but takes the line number of its last
    occurrence.

    Args:
        vocab_file: Path of the UTF-8 vocabulary file.

    Returns:
        A ``collections.OrderedDict`` of token -> index.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index

    return vocab

def build_vocab(tokens,vocab_output_file):
    """Append the distinct tokens, one per line, to a vocabulary file.

    Tokens are written in sorted order. Set iteration order varies between
    interpreter runs, and the line order decides the index each token gets
    when the file is read back, so sorting makes the output reproducible.

    Args:
        tokens: Iterable of token strings; duplicates are written once.
        vocab_output_file: Path of the UTF-8 file to append to (created if
            missing).
    """
    # The context manager closes the handle even if a write fails.
    with open(vocab_output_file, 'a', encoding='utf-8') as f:
        f.writelines(token + '\n' for token in sorted(set(tokens)))
def remove_punctuation(text):
    """Return ``text`` with every ASCII punctuation character removed.

    Args:
        text: String to filter, or an iterable of strings; items found in
            ``string.punctuation`` are dropped.

    Returns:
        The remaining items joined into a single string.
    """
    return ''.join(t for t in text if t not in string.punctuation)


def main():
    """Command-line entry point: build the vocabulary-filtered topic corpus.

    Reads the pre-cleaned topic TSV under ``--data_dir``, filters each
    document through ``_filter_lda_examples`` using the vocabulary under
    ``--output_dir``, and appends the result to the cleaned TSV there.
    """
    parser = argparse.ArgumentParser()

    # Both options are optional and default to the bundled sample directory.
    parser.add_argument(
        "--data_dir",
        type=str,
        default='./tests_samples/TOPICS/tc/',
        required=False,
        help="The input data dir. Should contain the data files (or other data files) for the task.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default='./tests_samples/TOPICS/tc/',
        required=False,
        help="The output data dir directory",
    )
    args = parser.parse_args()

    # File locations are fixed to the "dbpedia" sub-directory. An "ag" variant
    # used the same naming scheme with "ag" in place of "dbpedia"/"dbp".
    topic_path = os.path.join(args.data_dir, "dbpedia/topic.no_pun_no_stop.dbp.tsv")
    cleaned_path = os.path.join(args.output_dir, "dbpedia/topic.cleaned_30000.dbp.tsv")
    vocab_path = os.path.join(args.output_dir, "dbpedia/dbp.30000.vocab.tsv")

    # Build the lda dataset.
    documents = _read_tsv(topic_path)
    _filter_lda_examples(documents, cleaned_path, vocab_path)

    print('Generated the data for lda')

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()