import os
import torch
import json
import re
import random
import h5py
import numpy as np
import copy

random.seed(42)

from torch.utils.data import DataLoader, Dataset

from itertools import cycle

from transformers.models.bert.tokenization_bert import BertTokenizer

from sklearn.metrics.pairwise import cosine_similarity


import logging
logger = logging.getLogger(__name__)


def save_hparams(args, path):
    """Dump the attributes of ``args`` to ``path`` as ``NAME=value`` lines.

    The attributes come from ``vars(args)``, are sorted by attribute name,
    and each name is upper-cased. The file is written as UTF-8 with one
    attribute per line; an existing file at ``path`` is overwritten.
    """
    with open(path, 'w', encoding='utf-8') as out:
        rendered = [
            "{}={}\n".format(name.upper(), setting)
            for name, setting in sorted(vars(args).items())
        ]
        out.write("".join(rendered))

def check_mem(cuda_device):
    """Query ``nvidia-smi`` for the memory figures of a single GPU.

    Args:
        cuda_device: Index of the GPU (anything ``int()`` accepts); it selects
            the corresponding output line of ``nvidia-smi``.

    Returns:
        A ``(total, used)`` pair of strings: the two comma-separated fields of
        that line, exactly as split (they are not stripped or converted).
    """
    query = '"nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    report = os.popen(query).read()
    per_device = report.strip().split("\n")
    selected = per_device[int(cuda_device)]
    total, used = selected.split(',')
    return total, used

def check_mem_all():
    """Query ``nvidia-smi`` for the memory figures of every GPU.

    Returns:
        A list with one raw ``"total,used"`` CSV line (string) per line of
        ``nvidia-smi`` output, in the order the tool printed them.
    """
    query = '"nvidia-smi" --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
    report = os.popen(query).read()
    return report.strip().split("\n")

def _reserve_cuda_memory(size, device):
    """Allocate and immediately drop a float tensor of shape (256, 1024, size).

    256 * 1024 four-byte floats are 1 MiB, so ``size`` is the amount to grab in
    MiB (presumably the unit nvidia-smi reports with ``nounits`` -- confirm).

    Returns:
        True if the allocation succeeded, False if torch raised RuntimeError.
    """
    try:
        block = torch.cuda.FloatTensor(256, 1024, size, device=device)
    except RuntimeError:
        return False
    # NOTE(review): dropping the tensor presumably leaves the memory cached by
    # PyTorch's allocator, which is what keeps it "occupied" -- confirm.
    del block
    return True


def occupy_mem_new(cuda_device_list, ratio=0.6, num_devices=8):
    """Wait until GPU memory is free enough, then grab ``ratio`` of it.

    Sets ``CUDA_VISIBLE_DEVICES`` and polls ``nvidia-smi`` until a trial
    allocation of ``int(total * ratio)`` succeeds on the chosen device(s).
    A device qualifies when that amount plus the memory already in use stays
    within 95% of its total memory.

    Args:
        cuda_device_list: Device ids as strings. If the list is empty or its
            first entry is an empty string, the first qualifying device among
            the first ``num_devices`` GPUs is picked automatically; otherwise
            every listed device is reserved in turn.
        ratio: Fraction of each device's total memory to allocate.
        num_devices: Upper bound on how many GPUs are scanned in automatic
            mode; clamped to the number of lines nvidia-smi reports.

    Note:
        Automatic mode ends by blocking on ``input(">>>>>")`` until the user
        presses Enter.
    """
    import time

    poll_interval = 2  # seconds to wait before asking nvidia-smi again

    if len(cuda_device_list) == 0 or len(cuda_device_list[0]) == 0:
        while True:
            devices_info = check_mem_all()
            available_devices = []
            occupys = []
            # Clamp to what nvidia-smi reported so hosts with fewer than
            # ``num_devices`` GPUs do not raise IndexError.
            for cuda_device in range(min(num_devices, len(devices_info))):
                total, used = devices_info[cuda_device].split(',')
                total = int(total)
                used = int(used)
                occupy = int(total * ratio)
                print("Device-{}: {}/{}/{}".format(cuda_device, total, used, occupy))
                if occupy + used <= total * 0.95:
                    print('Find device-{}!'.format(cuda_device))
                    available_devices.append(cuda_device)
                    occupys.append(occupy)
            if not available_devices:
                # Nothing has room yet: back off instead of busy-polling.
                time.sleep(poll_interval)
                continue
            print(available_devices[0])
            # NOTE(review): changing CUDA_VISIBLE_DEVICES presumably has no
            # effect once CUDA is initialised in this process, so a retry
            # after a failed allocation may still hit the first device --
            # confirm.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(available_devices[0])
            if _reserve_cuda_memory(occupys[0], 'cuda:0'):
                break
            print("Failed, continue...")
            time.sleep(poll_interval)
        # NOTE(review): blocks until the user presses Enter; looks like a
        # debugging leftover but is kept to preserve behaviour.
        input(">>>>>")
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(cuda_device_list)
        # With CUDA_VISIBLE_DEVICES set, torch addresses the listed devices by
        # their position in the list, hence 'cuda:<position>'.
        for position, cuda_device in enumerate(cuda_device_list):
            while True:
                total, used = check_mem(cuda_device)
                total = int(total)
                used = int(used)
                occupy = int(total * ratio)
                print("Device-{}: {}/{}/{}".format(cuda_device, total, used, occupy))
                if occupy + used <= total * 0.95:
                    print('Find device-{}!'.format(cuda_device))
                    if _reserve_cuda_memory(occupy, 'cuda:{}'.format(position)):
                        break
                # Not enough headroom, or the allocation failed: wait before
                # polling again rather than spinning on nvidia-smi.
                time.sleep(poll_interval)


def _repeat_epochs(loader):
    """Yield batches from ``loader`` forever, re-iterating it every epoch.

    Unlike ``itertools.cycle`` this keeps no copy of earlier batches, and each
    pass starts a fresh iteration of the loader (so a shuffling loader draws a
    new order). Stops if a full pass yields nothing, to avoid spinning forever
    on an empty loader.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def get_batch_loader(dataset, collate_fn, batch_size=2, num_workers=0, is_test=True):
    """Build a DataLoader for evaluation or an endless batch iterator for training.

    Args:
        dataset: Dataset handed to ``DataLoader``.
        collate_fn: Function that merges a list of samples into a batch.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        is_test: If True, return the plain, unshuffled ``DataLoader``.
            If False, shuffling is enabled and an iterator that never runs
            out of batches (for a non-empty dataset) is returned.

    Returns:
        The ``DataLoader`` itself when ``is_test`` is True, otherwise an
        iterator that can be advanced with ``next()`` indefinitely.
    """
    loader = DataLoader(
        dataset, batch_size=batch_size,
        shuffle=(not is_test), num_workers=num_workers, collate_fn=collate_fn
    )
    if is_test:
        return loader
    return _repeat_epochs(loader)


class DailyDialogDataset(Dataset):
    """Dialogue dataset read from a JSON-lines file.

    Every non-blank line must be a JSON object with a ``context`` list of
    utterance strings and a ``response`` entry. The utterances are joined
    with ``" <#Q#> "`` into a single context string.
    """

    def __init__(self, data_path):
        """Load all examples from the UTF-8 JSON-lines file at ``data_path``."""
        self._context = []
        self._response = []
        with open(data_path, 'r', encoding='utf-8') as f:
            # Stream the file line by line instead of materialising it with
            # readlines(); blank lines (e.g. a trailing one) are skipped
            # rather than crashing json.loads.
            for line in f:
                if not line.strip():
                    continue
                example = json.loads(line)
                self._context.append(" <#Q#> ".join(example['context']))
                self._response.append(example['response'])
        self._n_data = len(self._context)

    def __len__(self):
        """Return the number of loaded examples."""
        return self._n_data

    def __getitem__(self, i):
        """Return the ``(context, response)`` pair at index ``i``."""
        context = self._context[i]  # "U1 <#Q#> U2 <#Q#> U3"
        response = self._response[i]
        return context, response

    @staticmethod
    def collate_fn(batch):
        """Turn a list of ``(context, response)`` pairs into two parallel lists."""
        context_list = [item[0] for item in batch]
        response_list = [item[1] for item in batch]
        return context_list, response_list

class ConvaiRetDataset(Dataset):
    """Retrieval dataset read from a JSON-lines file.

    Every non-blank line must be a JSON object with:
      * ``context``: a list whose first element is stored as the knowledge
        string and whose remaining elements are joined with ``" [SEP] "``
        into the context string;
      * ``response``: stored as-is;
      * ``candidates``: a list of strings joined with ``" [SEP] "``.
    """

    def __init__(self, data_path):
        """Load all examples from the UTF-8 JSON-lines file at ``data_path``."""
        self._context = []
        self._knowledge = []
        self._response = []
        self._candidates = []
        with open(data_path, 'r', encoding='utf-8') as f:
            # Stream the file line by line instead of materialising it with
            # readlines(); blank lines (e.g. a trailing one) are skipped
            # rather than crashing json.loads.
            for line in f:
                if not line.strip():
                    continue
                example = json.loads(line)
                self._knowledge.append(example['context'][0])
                self._context.append(" [SEP] ".join(example['context'][1:]))
                self._response.append(example['response'])
                self._candidates.append(" [SEP] ".join(example['candidates']))
        self._n_data = len(self._context)

    def __len__(self):
        """Return the number of loaded examples."""
        return self._n_data

    def __getitem__(self, i):
        """Return ``(knowledge, context, response, candidates)`` at index ``i``."""
        knowledge = self._knowledge[i]
        # The knowledge is kept separately and is not part of the context.
        context = self._context[i]  # "U1 [SEP] U2 [SEP] U3"
        response = self._response[i]
        candidates = self._candidates[i]  # "R1 [SEP] R2 [SEP] R3"
        return knowledge, context, response, candidates

    @staticmethod
    def collate_fn(batch):
        """Turn a list of 4-tuples from ``__getitem__`` into four parallel lists."""
        knowledge_list = [item[0] for item in batch]
        context_list = [item[1] for item in batch]
        response_list = [item[2] for item in batch]
        candidates_list = [item[3] for item in batch]
        return knowledge_list, context_list, response_list, candidates_list



class RedditDataset(Dataset):
    """Dataset backed by the ``context`` and ``response`` entries of an HDF5 file.

    The file is opened read-only in the constructor and is not closed
    explicitly by this class.
    """

    def __init__(self, data_path):
        """Open ``data_path`` with h5py and keep handles to both entries."""
        h5_file = h5py.File(data_path, 'r')
        self._context = h5_file['context']
        self._response = h5_file['response']
        self._n_data = len(self._context)

    def __len__(self):
        """Return the number of context entries found at construction time."""
        return self._n_data

    def __getitem__(self, i):
        """Return the ``(context, response)`` pair stored at index ``i``."""
        # A context presumably looks like "U1 <#Q#> U2 <#Q#> U3" -- confirm
        # against the code that writes the HDF5 file.
        return self._context[i], self._response[i]

    @staticmethod
    def collate_fn(batch):
        """Turn a list of ``(context, response)`` pairs into two parallel lists."""
        context_list = []
        response_list = []
        for sample in batch:
            context_list.append(sample[0])
            response_list.append(sample[1])
        return context_list, response_list


def init_para_frompretrained(m, pm, share_para=False):
    """Initialise the parameters of ``m`` from the pretrained model ``pm``.

    Both models must expose ``wte``, ``wpe``, ``ln_f`` and a sequence ``h`` of
    blocks with ``ln_1``, ``attn.c_attn``, ``attn.c_proj``, ``ln_2``,
    ``mlp.c_fc`` and ``mlp.c_proj`` sub-modules, each having ``weight`` and
    ``bias``. Only the first ``min(len(m.h), len(pm.h))`` blocks are
    initialised; any extra blocks of ``m`` are left untouched.

    Args:
        m: Model to initialise (modified in place).
        pm: Pretrained model to read parameters from.
        share_para: If True, the block and final layer-norm parameters of
            ``m`` are the very same objects as in ``pm``. If False, ``m``
            receives independent deep copies.

    Note:
        The token and position embeddings (``wte``/``wpe``) are always shared
        with ``pm``, regardless of ``share_para``.
    """
    def transfer(param):
        # copy.deepcopy, not copy.copy: a shallow copy of a Parameter wraps
        # the same underlying storage, so in-place updates (e.g. optimizer
        # steps) on one model would silently change the other as well.
        return param if share_para else copy.deepcopy(param)

    m.wte.weight = pm.wte.weight
    m.wpe.weight = pm.wpe.weight

    # zip stops at the shorter of the two block lists.
    for target_block, source_block in zip(m.h, pm.h):
        module_pairs = (
            (target_block.ln_1, source_block.ln_1),
            (target_block.attn.c_attn, source_block.attn.c_attn),
            (target_block.attn.c_proj, source_block.attn.c_proj),
            (target_block.ln_2, source_block.ln_2),
            (target_block.mlp.c_fc, source_block.mlp.c_fc),
            (target_block.mlp.c_proj, source_block.mlp.c_proj),
        )
        for target, source in module_pairs:
            target.weight = transfer(source.weight)
            target.bias = transfer(source.bias)

    m.ln_f.weight = transfer(pm.ln_f.weight)
    m.ln_f.bias = transfer(pm.ln_f.bias)
