import os
import json
import glob
import tqdm
import random
import argparse
import logging
from unilm.tokenization_unilm import UniLMAutoTokenizer

# Configure the root logger at import time: timestamped records at INFO level.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class TokenAttr(object):
    """Running statistics on the probabilities and ranks observed for one token.

    Probabilities are histogrammed into fixed-width buckets and ranks are
    counted cumulatively against a fixed list of top-k cut-offs.  All ratios
    reported by ``to_dict`` are relative to the number of observations.
    """

    def __init__(self, token, token_id):
        """Create an empty accumulator.

        Args:
            token: Token string; stored and echoed back by ``to_dict``.
            token_id: Token id; stored and echoed back by ``to_dict``.
        """
        # A probability is scaled by ``prob_x`` and then grouped into buckets of
        # ``prob_range`` scaled units, i.e. every bucket is one percent wide.
        self.prob_range = 100
        self.prob_x = 10000
        # One extra bucket at the end so that a probability of exactly 1.0 has
        # a slot of its own.
        self.prob_counter = [0] * (self.prob_x // self.prob_range + 1)
        # Cut-offs for the cumulative "rank <= k" counters.
        self.top_ranges = [1, 5, 10, 20, 50, 100]
        self.top_counters = {key: 0 for key in self.top_ranges}

        self.token = token
        self.token_id = token_id
        self.count = 0
        self.sum_of_prob = 0

        # Number of decimal places kept by ``get_number``.
        self.ndigits = 4

    def update(self, info):
        """Fold a batch of observations into the running statistics.

        Args:
            info: Mapping with parallel sequences ``"probs"`` and ``"rank"``.
                The sequences are paired with ``zip``, so surplus entries in
                the longer one are ignored.
        """
        last_bucket = len(self.prob_counter) - 1
        for prob, rank in zip(info["probs"], info["rank"]):
            self.sum_of_prob += prob
            self.count += 1
            bucket = int(prob * self.prob_x) // self.prob_range
            # Clamp so that an out-of-range probability cannot wrap around to
            # the end of the list (negative index) or raise IndexError.
            bucket = min(max(bucket, 0), last_bucket)
            self.prob_counter[bucket] += 1

            # Cumulative counters: one observation counts for every cut-off
            # that is at least as large as its rank.
            for top_range in self.top_ranges:
                if rank <= top_range:
                    self.top_counters[top_range] += 1

    def get_number(self, x):
        """Return ``x`` as a fraction of all observations, rounded to ``ndigits``.

        Returns ``0.0`` when nothing has been observed yet instead of dividing
        by zero.
        """
        if not self.count:
            return 0.0
        return round(x / self.count, self.ndigits)

    def to_dict(self):
        """Return the collected statistics as a JSON-serialisable dict.

        Returns:
            A dict with the keys ``token``, ``token_id``, ``probs`` (bucket
            label -> fraction of observations), ``ranks`` (``"top-k"`` ->
            fraction of observations with rank <= k), ``mean_probs`` and
            ``count``.  With no observations every fraction and the mean are
            ``0.0``.
        """
        probs = {}
        for i, bucket_count in enumerate(self.prob_counter):
            # Bucket bounds expressed in percent, e.g. "50~51".
            lower = (self.prob_range * i * 100) // self.prob_x
            upper = (self.prob_range * (i + 1) * 100) // self.prob_x
            probs["%d~%d" % (lower, upper)] = self.get_number(bucket_count)

        ranks = {
            "top-%d" % k: self.get_number(self.top_counters[k])
            for k in self.top_ranges
        }

        # The mean is deliberately left unrounded, unlike the fractions above.
        mean_probs = self.sum_of_prob / self.count if self.count else 0.0

        return {
            "token": self.token,
            "token_id": self.token_id,
            "probs": probs,
            "ranks": ranks,
            "mean_probs": mean_probs,
            "count": self.count,
        }


def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace carrying ``input`` and ``output``, both of which
        are mandatory string options.
    """
    # Required parameters
    required_options = (
        ("--input", "Input files. "),
        ("--output", "Output file for save output. "),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, description in required_options:
        arg_parser.add_argument(flag, type=str, required=True, help=description)
    return arg_parser.parse_args()


def main():
    """Aggregate per-token statistics from JSON dumps into a single report.

    Expands the ``--input`` glob pattern, processes a random sample of at most
    80 of the matching files, and folds every record into one ``TokenAttr``
    per token id.  The result is written to ``--output`` as a JSON list
    ordered by token id.
    """
    args = get_args()
    input_files = list(glob.glob(args.input))
    if not input_files:
        logger.warning("No input file matches %s", args.input)

    collector = {}
    tokenizer = UniLMAutoTokenizer.from_pretrained('unilm3-base-cased')

    # Only a random sample of the inputs is processed.
    random.shuffle(input_files)
    input_files = input_files[:80]

    for thisfile in tqdm.tqdm(input_files):
        logger.info("Load %s", thisfile)
        with open(thisfile, mode="r", encoding="utf-8") as reader:
            data = json.load(reader)

        # NOTE(review): assumes each file holds a JSON object mapping a token
        # id to the record consumed by TokenAttr.update - confirm against the
        # producer of these dumps.
        for raw_key, info in data.items():
            # JSON object keys are always strings, while the tokenizer lookup
            # and the sort of the report need the integer token id.  Using the
            # record's own key also keeps the statistics separate per token
            # instead of merging everything under a single fixed id.
            key = int(raw_key)
            token_attr = collector.get(key)
            if token_attr is None:
                token_attr = collector[key] = TokenAttr(
                    token=tokenizer.convert_ids_to_tokens([key])[0],
                    token_id=key,
                )

            token_attr.update(info)

    # Build the full report before touching the output file, so that a failure
    # while serialising does not leave a truncated file behind.
    all_tokens = [collector[token_id].to_dict() for token_id in sorted(collector)]

    with open(args.output, mode="w", encoding="utf-8") as writer:
        writer.write(json.dumps(all_tokens, indent=2))


# Run the aggregation only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
