import os
import logging
from tqdm import tqdm
import torch
import json
import random

def set_log(log_dir):
    """Configure the root logger to log INFO records to the console and to ``<log_dir>/log``.

    ``log_dir`` is created if it does not exist. The root logger is
    process-global, so handlers added here persist across calls. To avoid
    duplicated log lines, a console handler is attached only once, and a file
    handler is attached only once per log-file path.

    Args:
        log_dir: Directory in which the ``log`` file is created (or appended to).

    Returns:
        The root ``logging.Logger``.
    """
    os.makedirs(log_dir, exist_ok=True)
    final_path = os.path.join(log_dir, "log")
    logger = logging.getLogger()
    logger.setLevel('INFO')

    # Handlers created by this function carry a marker attribute so a repeated
    # call can recognise them and not attach a second copy.
    own = [h for h in logger.handlers if getattr(h, "_set_log_handler", False)]

    # FileHandler stores the absolute path in ``baseFilename``.
    abs_path = os.path.abspath(final_path)
    if not any(isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
               for h in own):
        # Explicit UTF-8 so non-ASCII messages are written regardless of the
        # platform's default locale encoding.
        fhlr = logging.FileHandler(final_path, encoding='utf-8')
        fhlr._set_log_handler = True
        logger.addHandler(fhlr)

    # FileHandler subclasses StreamHandler, so exclude it when looking for our console handler.
    if not any(not isinstance(h, logging.FileHandler) for h in own):
        control = logging.StreamHandler()
        control.setLevel('INFO')
        control._set_log_handler = True
        logger.addHandler(control)
    return logger






def generate_prepare(test_dataLoader, dataset_path, generation_dir, device="cuda"):
    """Collect de-duplicated test batches and the reference answers for each context.

    The test set (test.jsonl) can hold several references for the same query.
    Batches are therefore de-duplicated by query, so each query is generated
    only once, which reduces the amount of testing.

    Args:
        test_dataLoader: Iterable of dict batches. Every value must support
            ``.to(device)``, and the "query" entry must support ``.tolist()``
            (presumably torch tensors — TODO confirm against the caller).
        dataset_path: Directory containing ``test_formatted.jsonl``. Each non-blank
            line is a JSON object with "context" and "completion" fields.
        generation_dir: Output directory; created if it does not exist.
        device: Device that kept batches are moved to. Defaults to "cuda",
            the previously hard-coded value.

    Returns:
        A tuple ``(test_data_list, context_pred_refs_dict, context_list)``:
          - test_data_list: the first batch seen for each distinct query, moved
            to ``device``, with an added "id" key holding the batch's index in
            ``test_dataLoader``.
          - context_pred_refs_dict: maps each distinct context to
            ``{"refs": [...], "pred": "[No answer]"}``. "pred" is presumably
            filled in later by the caller.
          - context_list: every context in file order, duplicates included.
    """
    target = torch.device(device)

    test_data_list = []
    query_set = set()
    for idx, data in enumerate(tqdm(test_dataLoader)):
        # tolist() yields (possibly nested) lists, which are unhashable, so the
        # string form serves as the de-duplication key.
        query = str(data["query"].tolist())
        if query in query_set:
            continue
        query_set.add(query)
        moved = {k: v.to(target) for k, v in data.items()}
        moved["id"] = idx
        test_data_list.append(moved)

    context_list = []
    context_pred_refs_dict = {}
    formatted_path = os.path.join(dataset_path, "test_formatted.jsonl")
    with open(formatted_path, 'r', encoding='utf8') as reader:
        for line in reader:
            stripped = line.strip()
            if not stripped:
                # Blank lines would make json.loads raise; skip them.
                continue
            items = json.loads(stripped)
            context = items['context']
            completion = items['completion']
            context_list.append(context)

            entry = context_pred_refs_dict.setdefault(
                context, {"refs": [], "pred": "[No answer]"})
            # Keep only the text before the end-of-text marker, and only its
            # first paragraph, as the reference.
            new_ref = completion.split('<|endoftext|>')[0].split('\n\n')[0].strip()
            entry["refs"].append(new_ref)
    os.makedirs(generation_dir, exist_ok=True)

    return test_data_list, context_pred_refs_dict, context_list




def client_choose(R, N, C):
    """Randomly sample C of the N clients (numbered 1..N) in each of R rounds.

    Slot 0 of both returned lists is the placeholder "init_round", so that
    slot r describes round r.

    Returns:
        A tuple ``(choosen_client_ids, last_r_of_clients)``:
          - choosen_client_ids[r]: the list of client ids sampled in round r.
          - last_r_of_clients[r]: a list of length N + 1. Slot 0 is "server";
            slot k is the latest round before r in which client k was sampled,
            or 0 if client k has not been sampled yet.
    """
    picks = ["init_round"]
    history = ["init_round"]
    latest = ["server"] + [0 for _ in range(N)]

    for rnd in range(1, R + 1):
        sampled = random.sample(range(1, N + 1), C)
        picks.append(sampled)
        # Record the state *before* this round's participants are updated.
        history.append(latest)

        updated = ["server"]
        for k in range(1, N + 1):
            if k in sampled:
                updated.append(rnd)
            else:
                updated.append(latest[k])
        latest = updated

    return picks, history
"""
init_round
[11, 10, 20, 6, 17]
[20, 1, 7, 14, 24]
[9, 8, 19, 2, 7]
[6, 9, 24, 8, 25]
[9, 21, 24, 19, 11]
init_round
['server', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['server', 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0]
['server', 2, 0, 0, 0, 0, 1, 2, 0, 0, 1, 1, 0, 0, 2, 0, 0, 1, 0, 0, 2, 0, 0, 0, 2]
['server', 2, 3, 0, 0, 0, 1, 3, 3, 3, 1, 1, 0, 0, 2, 0, 0, 1, 0, 3, 2, 0, 0, 0, 2]
['server', 2, 3, 0, 0, 0, 4, 3, 4, 4, 1, 1, 0, 0, 2, 0, 0, 1, 0, 3, 2, 0, 0, 0, 4]

"""