import json
from collections import defaultdict

def length(json_file="data/kp20k/train.json", target_token=323, interval=1):
    """Print and return a histogram of target-token counts per record.

    Every non-blank line of *json_file* is parsed as one JSON object. Per the
    schema note kept from the original code, a record looks like::

        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "target_mask": target_mask,
            "candidates_number": len(candidates_list)
        }

    Only ``labels`` is read here: the number of entries equal to
    *target_token* is counted for each record, and the records are then
    tallied by that count.

    Args:
        json_file: Path to a JSON-lines file.
        target_token: Label id to count. The original comment notes
            ``323`` as "T" and ``383`` as "F".
        interval: Bucket width; a count ``c`` is tallied under the key
            ``c // interval``. With the default of 1 the key is the count.

    Returns:
        The ``defaultdict(int)`` that is also printed, mapping bucket key to
        the number of records in that bucket.

    Raises:
        ValueError: If *interval* is not a positive number.
    """
    if interval <= 0:
        raise ValueError("interval must be positive, got %r" % (interval,))

    interval_distri = defaultdict(int)
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            # A stray blank line (e.g. at end of file) is not a record.
            if not line.strip():
                continue
            data = json.loads(line)
            len_i = sum(1 for token in data['labels'] if token == target_token)
            # Bucketing while streaming yields the same keys, counts and key
            # order as building a per-count table first and re-bucketing it.
            interval_distri[len_i // interval] += 1

    print(interval_distri)
    return interval_distri

def train_present_num(data_file="data/inspec/train_trg.txt"):
    """Print and return the distribution of present-keyphrase counts.

    Each non-blank line of *data_file* must contain exactly one ``<peos>``
    separator. The text before it is treated as a ``;``-separated list, and
    the number of items in that list is the line's count. Counts 1 to 9 each
    get their own bucket; counts of 10 or more share the last bucket.

    Args:
        data_file: Path to the target text file, one example per line.

    Returns:
        A list of 10 floats, the fraction of lines per bucket (also
        printed). All zeros when the file has no usable lines.

    Raises:
        ValueError: If a non-blank line does not contain exactly one
            ``<peos>`` separator.
    """
    bucket_count = 10
    probs = [0] * bucket_count
    with open(data_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Blank lines carry no example; skip instead of crashing.
                continue
            parts = line.split("<peos>")
            if len(parts) != 2:
                raise ValueError(
                    "%s:%d: expected exactly one '<peos>' separator, found %d"
                    % (data_file, line_no, len(parts) - 1)
                )
            present = parts[0].strip(';').split(';')
            # NOTE(review): an empty present section still yields [''] and is
            # therefore counted as one item, as in the original code —
            # confirm whether such lines should be counted at all.
            num = len(present)
            probs[min(num, bucket_count) - 1] += 1

    total = sum(probs)
    if total:
        probs = [p / total for p in probs]
    else:
        # Nothing was counted; avoid dividing by zero.
        probs = [0.0] * bucket_count
    print(probs)
    return probs

def get_all_index(nums, masks, target):
    """Find masked occurrences of *target* and the first masked index.

    *nums* and *masks* are walked in lockstep (stopping at the shorter one).
    A position is "masked in" when its mask value equals 1.

    Args:
        nums: Sequence of values to search.
        masks: Sequence of mask values aligned with *nums*.
        target: Value to look for in *nums*.

    Returns:
        A tuple ``(res, begin_pos)`` where ``res`` lists every index whose
        value equals *target* and whose mask is 1, and ``begin_pos`` is the
        index of the first position whose mask is 1 (0 when there is none).
    """
    res = []
    # None, not 0, marks "not found yet": index 0 is a legitimate first
    # masked position and must not be overwritten by a later one.
    begin_pos = None
    for i, (num, mask) in enumerate(zip(nums, masks)):
        if mask != 1:
            continue
        if begin_pos is None:
            begin_pos = i
        if num == target:
            res.append(i)
    return res, (0 if begin_pos is None else begin_pos)

def pos_distribution(json_file="data/kp20k/train.json", target_token=323):
    """Print and return where target tokens fall inside the masked span.

    Every non-blank line of *json_file* is parsed as one JSON object. Per the
    schema note kept from the original code, a record looks like::

        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "target_mask": target_mask,
            "candidates_number": len(candidates_list)
        }

    For each record, ``get_all_index`` supplies the masked positions whose
    label equals *target_token* and the start of the masked span. Each
    position's offset from that start, relative to ``sum(target_mask)``, is
    placed in one of ten equal-width buckets.

    Args:
        json_file: Path to a JSON-lines file.
        target_token: Label id to locate. The original comment notes
            ``323`` as "T" and ``383`` as "F".

    Returns:
        A list of 10 floats, the fraction of located tokens per bucket (also
        printed). All zeros when no token was located.
    """
    num_bins = 10
    distribution = [0] * num_bins
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            # A stray blank line (e.g. at end of file) is not a record.
            if not line.strip():
                continue
            data = json.loads(line)
            T_positions, begin_pos = get_all_index(
                data['labels'], data['target_mask'], target_token
            )
            # A non-empty T_positions implies at least one mask value of 1,
            # so len_scale is non-zero for 0/1 masks.
            len_scale = sum(data['target_mask'])
            for p in T_positions:
                # Integer arithmetic gives the exact floor of 10 * ratio;
                # float floor division misplaces boundaries
                # (0.3 // 0.1 == 2.0).
                bucket = int((p - begin_pos) * num_bins // len_scale)
                # Clamp so an offset outside the span (e.g. a mask with
                # gaps) neither raises IndexError nor wraps around to the
                # last bucket through a negative index.
                bucket = max(0, min(bucket, num_bins - 1))
                distribution[bucket] += 1

    num_scale = sum(distribution)
    if num_scale:
        distribution = [num / num_scale for num in distribution]
    else:
        # No target token anywhere; avoid dividing by zero.
        distribution = [0.0] * num_bins
    print(distribution)
    return distribution


if __name__ == '__main__':
    # pos_distribution()
    # Plot a hard-coded ten-bucket distribution as a bar chart and save it.
    # NOTE(review): the values below are presumably the output of an earlier
    # pos_distribution() run — confirm before reusing them.
    import matplotlib.pyplot as plt
    import numpy as np

    bucket_shares = [
        0.4229844477799153,
        0.16823175415230887,
        0.1053576587813317,
        0.06735513603788343,
        0.06687102297951504,
        0.049422517246968126,
        0.046048524054147374,
        0.0352049552882481,
        0.02958304265409523,
        0.00894094102558682,
    ]

    # Buckets are labelled 1..10 on the x axis.
    bucket_ids = np.arange(1, len(bucket_shares) + 1)
    plt.bar(bucket_ids, bucket_shares)
    plt.savefig("pos_distribution.png")