import jsonlines
import zstandard
import tarfile
import string
import re
import _pickle as cPickle
import random
from nltk import word_tokenize, pos_tag
from os import path, listdir

def exceeds_tokenizer_max_lengths(line, tokenizers):
    """Report whether ``line`` is too long for any of the given tokenizers.

    Each tokenizer must provide ``encode(text)`` returning a sized sequence and a
    ``model_max_length`` attribute. Evaluation stops at the first tokenizer whose
    encoding of ``line`` is longer than its limit.

    Returns:
        True if at least one tokenizer's encoding exceeds its ``model_max_length``,
        otherwise False (including when ``tokenizers`` is empty).
    """
    return any(
        len(tok.encode(line)) > tok.model_max_length
        for tok in tokenizers
    )

def read_targz(targz_file, dest_dir='D:\\datasets\\'):
    """Extract every member of the gzip-compressed tarball ``targz_file``.

    Args:
        targz_file: Path of the ``.tar.gz`` archive.
        dest_dir: Directory to extract into. Defaults to the location that was
            previously hard-coded, so existing callers are unaffected.
    """
    with tarfile.open(targz_file, 'r:gz') as tar:
        # Dataset archives are presumably third-party downloads. Where the running
        # Python supports extraction filters, the 'data' filter refuses members
        # that would land outside dest_dir (absolute paths, '..', escaping links).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_dir, filter='data')
        else:
            tar.extractall(dest_dir)

def read_tarxz(tarxz_file, dest_dir='D:\\datasets\\openwebtext'):
    """Extract every member of the xz-compressed tarball ``tarxz_file``.

    Args:
        tarxz_file: Path of the ``.tar.xz`` archive.
        dest_dir: Directory to extract into. Defaults to the location that was
            previously hard-coded, so existing callers are unaffected.
    """
    with tarfile.open(tarxz_file, 'r:xz') as tar:
        # Dataset archives are presumably third-party downloads. Where the running
        # Python supports extraction filters, the 'data' filter refuses members
        # that would land outside dest_dir (absolute paths, '..', escaping links).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_dir, filter='data')
        else:
            tar.extractall(dest_dir)

#Reads in Reddit corpus
def index_contexts(zstd_file, save_dir, contexts_per_file, file_prefix='Reddit_Contexts_2019-12'):
    """Build pickled, word-indexed context shards from a zstd-compressed JSON-lines dump.

    Every record's ``'body'`` string is a candidate context.

    * Bodies equal to ``'[deleted]'`` or ``'[removed]'`` are skipped, as are bodies
      containing one of the moderator boilerplate phrases below.
    * The listed HTML entities and text matching the URL pattern ``(\\S+)?http\\S+``
      are replaced by spaces.
    * Whitespace runs are collapsed and the result is stripped.
    * Contexts with fewer than 5 words are dropped.

    Kept contexts are buffered in ``context_list``. Alongside it, ``term_dict`` is an
    inverted index mapping each non-punctuation ``word_tokenize`` token to the
    positions (in ``context_list``) of the contexts containing it. Each position is
    listed once per token.

    Whenever ``contexts_per_file`` contexts are buffered, the tuple
    ``(context_list, term_dict)`` is pickled to
    ``save_dir/{file_prefix}_{file_num}.pkl`` and the buffer is reset. A final
    partial shard uses the same naming scheme, so every shard lands in ``save_dir``.

    Args:
        zstd_file: Path of the zstd-compressed JSON-lines file.
        save_dir: Directory the pickle shards are written to.
        contexts_per_file: Number of contexts per shard.
        file_prefix: Shard filename prefix. Defaults to the previously hard-coded name.

    Raises:
        LookupError: Propagated from ``word_tokenize`` (e.g. missing NLTK data)
            instead of being silently swallowed for every record.
    """
    import io  # stdlib io; zstandard.io was only an incidental re-export of it

    removal_tokens = ['&gt;', '&amp;', '&lt;', '&quot;', '&apos;']
    removed_bodies = ('[deleted]', '[removed]')
    moderation_phrases = ('post has been removed', 'this post violates', 'contact the moderators')
    url_pattern = re.compile(r'(\S+)?http\S+')
    whitespace_pattern = re.compile(r'\s+')

    def _write_shard(shard_num, shard_contexts, shard_index):
        """Pickle one (contexts, index) shard into save_dir using the shared naming scheme."""
        shard_path = path.join(save_dir, f'{file_prefix}_{shard_num}.pkl')
        with open(shard_path, 'wb') as pickle_file:
            cPickle.dump((shard_contexts, shard_index), pickle_file)
        print(f'Saved zstd pickle {shard_num} to {save_dir}')

    context_list = []
    term_dict = {}
    file_num = 0

    with open(zstd_file, 'rb') as zstd_reader:
        decompressor = zstandard.ZstdDecompressor()
        stream_reader = decompressor.stream_reader(zstd_reader)
        text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8')
        json_reader = jsonlines.Reader(text_stream)

        for line in json_reader:
            # Skip records that are not objects, lack 'body', or have a non-string body.
            try:
                context = line['body']
            except (KeyError, TypeError):
                continue
            if not isinstance(context, str):
                continue

            if context in removed_bodies or any(phrase in context for phrase in moderation_phrases):
                continue

            for token in removal_tokens:
                context = context.replace(token, ' ')
            context = url_pattern.sub(' ', context)
            # Strip so stray leading/trailing spaces are not counted as words below.
            context = whitespace_pattern.sub(' ', context).strip()

            if len(context.split()) < 5:
                continue

            # Tokenize BEFORE storing, so a failure can never leave a context in
            # context_list without matching index entries.
            try:
                tokenized_context = word_tokenize(context)
            except LookupError:
                # Missing NLTK data is a setup problem, not a bad record.
                raise
            except Exception:
                # Best-effort: skip the rare record the tokenizer cannot handle.
                continue

            context_idx = len(context_list)
            context_list.append(context)
            # dict.fromkeys de-duplicates while preserving token order, so each
            # context position is recorded at most once per term.
            for word in dict.fromkeys(tokenized_context):
                if word not in string.punctuation:
                    term_dict.setdefault(word, []).append(context_idx)

            # Checked right after every append, so a shard boundary is never missed.
            if len(context_list) == contexts_per_file:
                _write_shard(file_num, context_list, term_dict)
                file_num += 1
                context_list = []
                term_dict = {}

    if context_list:
        _write_shard(file_num, context_list, term_dict)

#Gathers contexts for a word
def gather_term_contexts(term, num_contexts, tokenizers, load_dir, target_pos = None, ordered_dict = None, random_load = True, detect_names = True):
    """Collect up to ``num_contexts`` distinct example contexts for ``term``.

    Context pickles are read from ``load_dir``. Each must hold a tuple whose first item
    is a list of context strings and whose second item maps a token to indices into
    that list (the layout ``index_contexts`` writes). A candidate context is accepted
    only if its ``word_tokenize`` tokens:

    * number at least 7 and contain ``term`` exactly once;
    * do not have ``'not'``, ``'no'``, or a token ending in ``"n't"`` directly before ``term``;
    * (if ``detect_names``) have no capitalized token directly before or after ``term``;
    * do not exceed any tokenizer's limit (``exceeds_tokenizer_max_lengths``);
    * (if ``target_pos`` is given) tag ``term`` with that POS tag via ``pos_tag``.

    Args:
        term: Token to find.
        num_contexts: Maximum number of contexts to return.
        tokenizers: Objects exposing ``encode`` and ``model_max_length``.
        load_dir: Directory holding the context pickles.
        target_pos: Required POS tag for ``term``, or None to skip the check.
        ordered_dict: Optional mapping ``term -> sequence``. The first element of each
            item is a pickle filename under ``load_dir``, and these selects the files
            to read. When None, every entry of ``load_dir`` is read.
        random_load: Shuffle the file order before reading.
        detect_names: Reject likely proper-name usages (capitalized neighbors).

    Returns:
        A list of at most ``num_contexts`` distinct contexts, in acceptance order.
    """
    if ordered_dict is not None:
        context_dir = [entry[0] for entry in ordered_dict[term]]
    else:
        context_dir = listdir(load_dir)

    if random_load:
        random.shuffle(context_dir)

    term_contexts = []
    # Contexts already evaluated. Every filter below is deterministic, so a repeated
    # context would get the same verdict. Skipping it also keeps the result distinct.
    examined = set()

    for pkl_name in context_dir:
        if len(term_contexts) >= num_contexts:
            break

        with open(path.join(load_dir, pkl_name), 'rb') as pickle_file:
            context_tup = cPickle.load(pickle_file)
        context_list = context_tup[0]
        context_dict = context_tup[1]

        potential_term_contexts = [context_list[i] for i in context_dict.get(term, [])]
        random.shuffle(potential_term_contexts)

        for context in potential_term_contexts:
            if context in examined:
                continue
            examined.add(context)

            tokenized_line = word_tokenize(context)
            if len(tokenized_line) < 7 or tokenized_line.count(term) != 1:
                continue

            term_position = tokenized_line.index(term)
            preceding_term = tokenized_line[term_position - 1] if term_position > 0 else None
            following_term = (
                tokenized_line[term_position + 1]
                if term_position + 1 < len(tokenized_line)
                else None
            )

            # Skip negated usages ("not X", "no X", and "... n't X" since contractions
            # are tokenized with a separate "n't" token).
            if preceding_term is not None and (
                preceding_term in ('not', 'no') or preceding_term.endswith("n't")
            ):
                continue

            # A capitalized neighbor suggests the term is part of a proper name.
            if detect_names and any(
                neighbor is not None and neighbor[0].isupper()
                for neighbor in (preceding_term, following_term)
            ):
                continue

            # The two expensive checks run last.
            if exceeds_tokenizer_max_lengths(context, tokenizers):
                continue

            # pos_tag must receive the token list (a raw string would be tagged
            # character by character), so tag indices line up with term_position.
            if target_pos is not None and pos_tag(tokenized_line)[term_position][1] != target_pos:
                continue

            term_contexts.append(context)
            if len(term_contexts) >= num_contexts:
                return term_contexts

    return term_contexts

def gather_context_term_list(term_list, num_contexts, tokenizers, load_dir, save_dir, target_pos = None, random_load = True, detect_names = True):
    """Collect example contexts for every term in ``term_list`` in one pass over ``load_dir``.

    Each pickle in ``load_dir`` must hold a tuple whose first item is a list of context
    strings and whose second item maps a token to indices into that list (the layout
    ``index_contexts`` writes). A candidate context is accepted for a term only if its
    ``word_tokenize`` tokens:

    * number at least 7 and contain the term exactly once;
    * do not have ``'not'``, ``'no'``, or a token ending in ``"n't"`` directly before it;
    * (if ``detect_names``) have no capitalized token directly before or after it;
    * do not exceed any tokenizer's limit (``exceeds_tokenizer_max_lengths``);
    * (if ``target_pos`` is given) tag the term with that POS tag via ``pos_tag``.

    Duplicate context strings are considered only once per term.

    When a term reaches ``num_contexts`` accepted contexts, they are pickled as a list
    to ``save_dir/{term}_contexts.pkl`` and the term is dropped from the search. After
    all files are read, every remaining term with at least one context is saved the
    same way, and terms with none are reported. Progress is printed; returns None.

    Args:
        term_list: Terms to collect contexts for (not modified).
        num_contexts: Target number of contexts per term.
        tokenizers: Objects exposing ``encode`` and ``model_max_length``.
        load_dir: Directory holding the context pickles.
        save_dir: Directory the per-term pickles are written to.
        target_pos: Required POS tag for the term, or None to skip the check.
        random_load: Shuffle the file order before reading.
        detect_names: Reject likely proper-name usages (capitalized neighbors).
    """

    def _is_usable(context, term):
        """Return True if ``context`` passes every filter for ``term``."""
        tokens = word_tokenize(context)
        if len(tokens) < 7 or tokens.count(term) != 1:
            return False

        position = tokens.index(term)
        preceding = tokens[position - 1] if position > 0 else None
        following = tokens[position + 1] if position + 1 < len(tokens) else None

        # Negated usage: "not X", "no X", or "... n't X" (contractions tokenize
        # with a separate "n't" token).
        if preceding is not None and (preceding in ('not', 'no') or preceding.endswith("n't")):
            return False

        # A capitalized neighbor suggests the term is part of a proper name.
        if detect_names and any(
            neighbor is not None and neighbor[0].isupper()
            for neighbor in (preceding, following)
        ):
            return False

        if exceeds_tokenizer_max_lengths(context, tokenizers):
            return False

        # Tag the token list, not the raw string (which would be tagged per
        # character), so tag indices line up with ``position``.
        return target_pos is None or pos_tag(tokens)[position][1] == target_pos

    def _save(term, contexts):
        """Pickle ``contexts`` to ``save_dir/{term}_contexts.pkl``."""
        with open(path.join(save_dir, f'{term}_contexts.pkl'), 'wb') as pickle_writer:
            cPickle.dump(contexts, pickle_writer)

    context_dir = listdir(load_dir)
    if random_load:
        random.shuffle(context_dir)

    # One (term, accepted contexts, examined contexts) entry per term still being
    # searched. No placeholder element, so len(contexts) is the true count.
    pending = [(term, [], set()) for term in term_list]

    for context_file in context_dir:
        with open(path.join(load_dir, context_file), 'rb') as pickle_file:
            context_tup = cPickle.load(pickle_file)
        context_list = context_tup[0]
        context_dict = context_tup[1]

        finished = set()
        for entry_idx, (term, contexts, examined) in enumerate(pending):
            if term not in context_dict:
                continue

            potential_term_contexts = [context_list[i] for i in context_dict[term]]
            random.shuffle(potential_term_contexts)

            for context in potential_term_contexts:
                if context in examined:
                    continue
                examined.add(context)
                if not _is_usable(context, term):
                    continue

                contexts.append(context)
                if len(contexts) >= num_contexts:
                    _save(term, contexts)
                    print(f'Collected {num_contexts} contexts for term {term} and wrote to {save_dir}')
                    finished.add(entry_idx)
                    break

        pending = [entry for entry_idx, entry in enumerate(pending) if entry_idx not in finished]

        if not pending:
            print(f'Finished collecting {num_contexts} contexts for full term list')
            return

    for term, contexts, _ in pending:
        if contexts:
            _save(term, contexts)
            print(f'Collected {len(contexts)} contexts for {term} and saved to {save_dir}')
        else:
            print(f'Found no contexts for {term}')

    print('Finished collecting contexts')
    return