import torch
import os
import pickle
from torch.utils.data import TensorDataset
from torch.utils.data.dataset import random_split
import gzip
import numpy as np
import random




def create_dataset(dir_path, args, type: str):
    """Load a gzip-compressed pickle of features and wrap it in a ``TensorDataset``.

    The pickle must contain a mapping with the keys ``'input_ids'``,
    ``'attention_mask'``, ``'token_type_ids'`` and ``'labels'``; each value is
    converted to a ``torch.long`` tensor.

    Args:
        dir_path: Path of the gzip-compressed pickle file to read.
        args: Namespace-like object. Only consulted when ``type == 'train'``:
            ``args.num_samples`` (int or None) and, if that is not None,
            ``args.sample_seed`` (int).
        type: ``'train'`` to get a train/validation split; any other value
            returns the whole dataset. (The name shadows the builtin but is
            kept so keyword callers keep working.)

    Returns:
        For ``type == 'train'`` a ``(train_dataset, val_dataset)`` tuple:
        * ``args.num_samples is None``: a random 80/20 split of all data,
          drawn from the global torch RNG.
        * otherwise: two disjoint subsets of ``args.num_samples`` examples
          each, drawn with a generator seeded by ``args.sample_seed``.
        For any other ``type`` the full ``TensorDataset``.

    Raises:
        ValueError: If ``args.num_samples`` is negative or the dataset holds
            fewer than ``2 * args.num_samples`` examples.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising. Only point this at files from a trusted source.
    with gzip.open(dir_path, 'rb') as f:
        data = pickle.load(f)

    feature_keys = ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
    dataset = TensorDataset(
        *(torch.tensor(data[key], dtype=torch.long) for key in feature_keys)
    )

    if type != 'train':
        return dataset

    data_size = len(data['labels'])
    num_samples = args.num_samples

    if num_samples is None:
        train_size = int(data_size * 0.8)
        val_size = data_size - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        return train_dataset, val_dataset

    # Without this check a negative remainder length would still sum to
    # data_size and silently yield a truncated validation subset.
    if num_samples < 0 or 2 * num_samples > data_size:
        raise ValueError(
            f'num_samples={num_samples} requires at least {2 * num_samples} '
            f'examples, but the dataset has {data_size}'
        )

    # Only the seeded generator is used on this path, so the sampled subsets
    # depend on sample_seed alone and the global RNG state is left untouched.
    gnt = torch.Generator().manual_seed(args.sample_seed)
    remainder = data_size - 2 * num_samples
    train_dataset, val_dataset, _ = random_split(
        dataset, [num_samples, num_samples, remainder], generator=gnt
    )
    return train_dataset, val_dataset
    
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self,args, verbose=False, delta=0, path='checkpoint_best.bin', trace_func=print):
        """
        Args:
            args: Namespace-like object; ``args.early_stop`` (int) is the patience,
                            i.e. how many consecutive calls without improvement
                            are tolerated before ``early_stop`` is set.
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File name of the checkpoint, joined onto ``args.output_dir`` when saving.
                            Default: 'checkpoint_best.bin'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = args.early_stop
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self,args, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            args: Namespace-like object; ``args.output_dir`` is used when a
                checkpoint is written.
            val_loss: Validation loss of the current evaluation (lower is better).
            model: Model whose ``state_dict()`` is saved on improvement.

        Sets ``self.early_stop`` to True once ``self.counter`` reaches
        ``self.patience``.
        """
        # Negate so that a higher score always means a better model.
        score = -val_loss

        if self.best_score is None:
            # First call: nothing to compare against, always checkpoint.
            self.best_score = score
            self.save_checkpoint(args, val_loss, model)

        elif score < self.best_score + self.delta:
            # Not better than the best score by at least delta.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(args,val_loss, model)
            self.counter = 0

    def save_checkpoint(self, args, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # Unwrap objects exposing a ``module`` attribute (presumably
        # DataParallel-style wrappers) so the inner model's weights are saved.
        model_to_save = model.module if hasattr(model, 'module') else model
        output_model_file = os.path.join(args.output_dir, self.path)
        torch.save(model_to_save.state_dict(), output_model_file)
        self.val_loss_min = val_loss

def accuracy(out, labels):
    """Return the number of rows of ``out`` whose argmax equals the label.

    Despite the name this is a count of correct predictions, not a ratio.
    ``out`` is reduced with ``np.argmax`` along axis 1 and compared
    element-wise against ``labels``.
    """
    return np.sum(np.argmax(out, axis=1) == labels)

def warmup_linear(x, warmup=0.002):
    """Return a multiplier that ramps up linearly, then decays linearly.

    For ``x`` below ``warmup`` the result is ``x / warmup``; otherwise it is
    ``1.0 - x``.
    """
    return x / warmup if x < warmup else 1.0 - x