import os
import math
import torch
import pickle
import random
import json
import yaml
import h5py
import numpy as np
import re
import operator
import functools

def load_files(path):
    """Load a data file, picking the parser from the file extension.

    The extension is the text after the last '.' in ``path`` and is matched
    case-insensitively:

    - ``json``            -> ``json.load``
    - ``pkl`` / ``pickle`` -> ``pickle.load``
    - ``yaml`` / ``yml``   -> ``yaml.safe_load``
    - ``hdf5``            -> ``h5py.File(path, "r")`` (returned open; the
      caller is responsible for closing it)
    - ``npz``             -> ``np.load``

    Args:
        path: Path (string) of the file to load.

    Returns:
        Whatever the selected parser returns for the file.

    Raises:
        ValueError: If the extension is not one of the supported ones.
            (Previously this surfaced as an UnboundLocalError.)
        yaml.YAMLError: If a YAML file cannot be parsed. (Previously the
            error was printed and then masked by an UnboundLocalError.)
    """
    ext = path.rsplit('.', 1)[-1].lower()

    if ext == 'json':
        with open(path, 'r') as f:
            return json.load(f)

    if ext in ('pkl', 'pickle'):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    if ext in ('yaml', 'yml'):
        with open(path, 'r') as f:
            return yaml.safe_load(f)

    if ext == 'hdf5':
        return h5py.File(path, "r")

    if ext == 'npz':
        return np.load(path)

    raise ValueError(
        "Unsupported file extension '{}' for path: {}".format(ext, path)
    )

def save_pickle(data, path):
    """Serialize ``data`` to ``path`` with pickle's highest protocol.

    Args:
        data: Any picklable object.
        path: Destination file path; the file is opened in binary write
            mode, so existing content is overwritten.
    """
    with open(path, 'wb') as handle:
        pickler = pickle.Pickler(handle, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(data)

def save_json(data, path):
    """Write ``data`` to ``path`` as JSON using the default encoder settings.

    Args:
        data: A JSON-serializable object.
        path: Destination file path; the file is opened in text write mode,
            so existing content is overwritten.
    """
    with open(path, 'w') as out_file:
        json.dump(data, out_file)

def clean_str(string, lower = True):
    """Tokenization-friendly cleanup of an English sentence.

    Every character other than ASCII letters, digits, ',', '!' and the
    apostrophe is replaced by a space; common contraction suffixes and
    commas are then separated from the preceding text by inserting spaces.
    Runs of spaces are not collapsed.

    Args:
        string: Text to clean.
        lower: When true (default) the result is lower-cased.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    # Applied strictly in this order; each rule sees the previous rule's output.
    rules = (
        (r"[^A-Za-z0-9,!\']", " "),
        (r"\'s", " \'s"),
        (r"\'ve", " \'ve"),
        (r"n\'t", " n\'t"),
        (r"\'re", " \'re"),
        (r"\'d", " \'d"),
        (r"\'m", " \'m"),
        (r"\'ll", " \'ll"),
        (r",", " , "),
    )
    for pattern, replacement in rules:
        string = re.sub(pattern, replacement, string)

    cleaned = string.strip()
    if lower:
        return cleaned.lower()
    return cleaned

def fix_seed(seed):
    """Seed every random number source used here and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    switches cuDNN to deterministic mode with benchmarking disabled, and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed passed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

def print_model(model, logger=None):
    """Print a parameter summary of ``model``.

    Always prints, to stdout, each named parameter with its size followed by
    the total element count. The model itself and a final 'nParams=' line go
    to ``logger.info`` when a logger is supplied, otherwise to stdout.

    Args:
        model: Object exposing ``named_parameters()`` and ``parameters()``.
        logger: Optional object with an ``info`` method.
    """
    for param_name, tensor in model.named_parameters():
        print(param_name, tensor.size())

    print(sum(p.numel() for p in model.parameters()))

    # Element count per parameter is the product of its dimensions.
    n_params = sum(
        functools.reduce(operator.mul, weight.size(), 1)
        for weight in model.parameters()
    )

    emit = print if logger is None else logger.info
    emit(model)
    emit('nParams=\t' + str(n_params))

def make_one_hot(labels, num_labels):
    """Turn integer class labels into float one-hot vectors.

    Rows of a ``num_labels`` x ``num_labels`` identity matrix are selected
    with ``labels`` (cast to int64), so the result has one extra trailing
    dimension of size ``num_labels``.

    Args:
        labels: Tensor of class indices.
        num_labels: Number of classes (width of each one-hot vector).

    Returns:
        A float32 tensor. It is a CUDA tensor when CUDA is available (the
        original behavior); on CPU-only machines it is a CPU tensor instead
        of raising.
    """
    onehot = torch.eye(num_labels)
    out = onehot[labels.long()]
    # Casting to torch.cuda.FloatTensor fails without CUDA, so only do it
    # when a CUDA device is actually usable.
    if torch.cuda.is_available():
        return out.type(torch.cuda.FloatTensor)
    return out.type(torch.FloatTensor)

def glorot(tensor):
    """Fill ``tensor`` in place with Glorot/Xavier-style uniform values.

    Values are drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(6 / (size(-2) + size(-1)))``. Does nothing when ``tensor`` is
    None.

    Args:
        tensor: Tensor with at least two dimensions, or None.
    """
    if tensor is None:
        return
    fan_sum = tensor.size(-2) + tensor.size(-1)
    bound = math.sqrt(6.0 / fan_sum)
    tensor.data.uniform_(-bound, bound)
        
def zeros(tensor):
    """Set every element of ``tensor`` to zero in place; no-op for None.

    Args:
        tensor: Tensor to clear, or None.
    """
    if tensor is None:
        return
    tensor.data.fill_(0)