import argparse
import os
import json
import random
import jieba
import math

from multiprocessing import Pool
from tqdm import tqdm
from transformers import AutoTokenizer
from unsupervised.Trie import HatTrie
from unsupervised.utils import compute_prob_from_statistics


# Command-line interface, declared as (flag, add_argument keyword options)
# pairs. Order is preserved so the generated --help output is unchanged.
_ARGUMENT_SPECS = (
    # Input corpus, read line by line.
    ("--file", dict(type=str, default="xxx")),
    # Directory that receives "<input file name>.json".
    ("--save_to", dict(type=str, default="xxx")),
    # When given, each line is split on --separator and this column is kept.
    ("--column", dict(type=int, help="take specified column")),
    ("--separator", dict(type=str, help="the separator of file")),
    # Passed to AutoTokenizer.from_pretrained.
    ("--tokenizer_name", dict(type=str, default="xxx")),
    ("--max_len", dict(type=int, default=100)),
    # Number of lines handed to a worker process per task.
    ("--chunksize", dict(type=int, default=500)),
    ("--short_sentence_prob", dict(type=float, default=0.1)),
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()


# Base name of the input corpus; reused to name the output file.
file_name = os.path.basename(args.file)

# Module-level state shared with the worker function.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)

# Token budget left after reserving room for the special tokens the
# tokenizer adds to a single (non-pair) sequence.
_special_token_count = tokenizer.num_special_tokens_to_add(pair=False)
target_length = args.max_len - _special_token_count

# NOTE(review): 'trie path' looks like a placeholder — confirm the real
# location of the serialized trie before running.
trie, _ = HatTrie.load('trie path')


def is_all_chinese(strs):
    """Return True when every character of *strs* lies in U+4E00..U+9FA5.

    Both bounds are inclusive. An empty input yields True because there is
    no character that falls outside the range.
    """
    return all('\u4e00' <= ch <= '\u9fa5' for ch in strs)


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + e**-x)`` of *x*.

    The result is computed in a numerically stable way so that very large
    negative inputs return 0.0 instead of raising ``OverflowError``.

    Args:
        x: A real number.

    Returns:
        A float in the closed interval [0.0, 1.0].
    """
    if x >= 0:
        # exp(-x) is at most 1 here, so it cannot overflow.
        return 1 / (1 + math.exp(-x))
    # For negative x, exp(-x) may overflow; exp(x) stays below 1, so use the
    # algebraically equivalent form e**x / (1 + e**x).
    exp_x = math.exp(x)
    return exp_x / (1 + exp_x)


# def encode_one_line(text):
#     if args.column is not None:
#         text = text.split(args.separator)[args.column]
#     else:
#         text = text.strip().split("$-$")[0]
#     sentence = text
#     tokenized = tokenizer(sentence, 
#                     add_special_tokens=False, 
#                     truncation=False, 
#                     return_offsets_mapping=True,
#                     return_attention_mask=False,
#                     return_token_type_ids=False)

#     ids, offsets = tokenized['input_ids'], tokenized['offset_mapping']
#     offsets_mapping = {}
#     for i in range(len(offsets)):
#         offsets_mapping[offsets[i][0]] = i
#     tokens = jieba.tokenize(sentence)
#     segs = [0] * len(ids)
#     for token in tokens:
#         if token[1] in offsets_mapping:
#             segs[offsets_mapping[token[1]]] = 1
    
#     return {'ids': ids, 'segs': segs}


def encode_one_line(text):
    """Tokenize one raw corpus line.

    The text to encode is selected from the line as follows:

    * if ``--column`` was given, the line (without its line terminator) is
      split on ``--separator`` and the requested column is used;
    * otherwise the line is stripped and everything before the first
      ``"$-$"`` marker is used.

    Args:
        text: One line read from the corpus file.

    Returns:
        A dict with keys ``'ids'`` (token ids, no special tokens, not
        truncated), ``'meta'`` (the selected text) and ``'offsets'`` (the
        start character offset of every token within ``'meta'``).

    Raises:
        IndexError: If ``--column`` is beyond the number of fields.
    """
    if args.column is not None:
        # Remove only the line terminator before splitting: otherwise the
        # last column (and therefore 'meta') carries a trailing newline.
        # A full strip() could swallow leading/trailing separator
        # characters such as tabs and shift the column indices.
        text = text.rstrip("\r\n").split(args.separator)[args.column]
    else:
        text = text.strip().split("$-$")[0]

    tokenized = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=True,
        return_attention_mask=False,
        return_token_type_ids=False,
    )
    ids = tokenized['input_ids']
    # probs, _ = compute_prob_from_statistics(text, trie)
    # for i, offset in enumerate(offsets):
    #     segs[i] = (sigmoid(probs[offset[0]][0]), sigmoid(probs[offset[0]][1]))

    # Keep only the start position of each token's character span.
    offsets = [span[0] for span in tokenized['offset_mapping']]
    res_dic = {'ids': ids, 'meta': text, 'offsets': offsets}

    # ## If use whole word mask, use the below code.
    # tokens = tokenizer.convert_ids_to_tokens(ids)
    # jieba_res = jieba.tokenize(sentence)
    # offsets_mapping = {}
    # for i in range(len(offsets)):
    #     offsets_mapping[offsets[i][0]] = i
    # for token in jieba_res:
    #     if is_all_chinese(token[0]):
    #         for i in range(token[1] + 1, token[2]):
    #             tokens[offsets_mapping[i]] = '##' + tokens[offsets_mapping[i]]
    # res_dic['tokens'] = tokens

    return res_dic


# with open(args.file, 'r') as corpus_file:
#     lines = corpus_file.readlines()

# Guard the driver so that worker processes started with the "spawn" method
# (the default on Windows and macOS), which re-import this module, do not try
# to create their own pools and reopen the output file.
if __name__ == "__main__":
    output_path = os.path.join(args.save_to, file_name + ".json")
    # Explicit UTF-8: the JSON is written with ensure_ascii=False, so relying
    # on the locale's default encoding can fail on non-ASCII text.
    with open(output_path, "w", encoding="utf-8") as tokenized_file, \
            open(args.file, "r", encoding="utf-8") as corpus_file:
        with Pool() as p:
            # Lines are streamed to the workers in chunks; results come back
            # in completion order, so output order may differ from the input.
            all_blocks = p.imap_unordered(encode_one_line, tqdm(corpus_file), chunksize=args.chunksize)
            for block in all_blocks:
                # One JSON object per line with keys 'text', 'meta', 'offsets'.
                # tokenized_file.write(json.dumps({'text': block['ids'], 'meta': block['meta'], 'tokens': block['tokens'], 'offsets': block['offsets']}, ensure_ascii=False) + "\n")
                tokenized_file.write(json.dumps({'text': block['ids'], 'meta': block['meta'], 'offsets': block['offsets']}, ensure_ascii=False) + "\n")
                # Flushed after every record, as in the original; presumably
                # so partial output survives an interrupted run.
                tokenized_file.flush()