#!/usr/bin/python3
#-*- coding:utf-8 -*-
############################
#File Name: data.py
#Created Time:
############################
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import utils.utils

# define my data set

class MyData(Dataset):
    """Map-style dataset pairing TF-IDF feature rows with label rows.

    Args:
        tfidfs: Indexable container of features. ``tfidfs[i]`` must be a
            NumPy array because it is passed to ``torch.from_numpy``, and
            ``len(tfidfs)`` defines the dataset size.
        labels: Labels, one entry per sample. Sparse containers exposing
            ``toarray()`` and matrix-like objects exposing ``.A`` are
            converted to a dense array once, up front; anything else
            (e.g. a plain ``numpy.ndarray``) is stored unchanged.
            ``labels[i]`` must end up being a NumPy array.
    """

    def __init__(self, tfidfs, labels):
        super().__init__()
        self.tfidfs = tfidfs
        self.labels = self._to_dense(labels)

    @staticmethod
    def _to_dense(labels):
        """Return ``labels`` as a dense array when a conversion is available."""
        # Prefer the explicit method: not every container that offers
        # ``toarray()`` is guaranteed to also expose the ``.A`` shorthand.
        if hasattr(labels, "toarray"):
            return labels.toarray()
        if hasattr(labels, "A"):
            return labels.A
        # Already dense (or at least has no known densifying accessor).
        return labels

    def __getitem__(self, index):
        """Return the ``(tfidf, label)`` pair at ``index`` as torch tensors."""
        tfidf = self.tfidfs[index]
        label = self.labels[index]
        # from_numpy shares memory with the source arrays instead of copying.
        return torch.from_numpy(tfidf), torch.from_numpy(label)

    def __len__(self):
        """Return the number of samples, taken from the feature container."""
        return len(self.tfidfs)

class DataPrefetcher():
    """Stage the next ``(tfidf, label)`` batch on the GPU via a side stream.

    The constructor immediately starts loading the first batch. Each call
    to ``next()`` returns the batch that was staged previously and kicks
    off the transfer of the following one. Once the wrapped loader is
    exhausted, ``next()`` returns ``(None, None)``.

    Requires CUDA: a ``torch.cuda.Stream`` is created in ``__init__``.

    Args:
        loader: Iterable yielding ``(tfidf, label)`` pairs of tensors.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and enqueue its GPU copy on the side stream.

        Sets ``next_tfidf`` and ``next_label`` to ``None`` when the loader
        has no more batches.
        """
        try:
            self.next_tfidf, self.next_label = next(self.loader)
        except StopIteration:
            self.next_tfidf = None
            self.next_label = None
            return

        with torch.cuda.stream(self.stream):
            # NOTE(review): non_blocking copies presumably only overlap with
            # compute when the loader yields pinned host tensors -- confirm
            # the DataLoader is built with pin_memory=True.
            self.next_tfidf = self.next_tfidf.cuda(non_blocking=True)
            self.next_label = self.next_label.cuda(non_blocking=True)

            # Cast on the side stream as well so the conversion is part of
            # the prefetched work.
            self.next_tfidf = self.next_tfidf.float()
            self.next_label = self.next_label.float()

    def next(self):
        """Return the staged batch and start prefetching the following one.

        Returns:
            Tuple ``(tfidf, label)`` of CUDA float tensors, or
            ``(None, None)`` once the loader is exhausted.
        """
        current = torch.cuda.current_stream()
        # Make the consumer stream wait until the staged copies have finished.
        current.wait_stream(self.stream)
        tfidf, label = self.next_tfidf, self.next_label
        # The tensors were allocated on the side stream but will be used on
        # the current stream; tell the caching allocator so their memory is
        # not handed out again while the consumer may still be reading it.
        if tfidf is not None:
            tfidf.record_stream(current)
        if label is not None:
            label.record_stream(current)
        self.preload()
        return tfidf, label

class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    Attributes are plain fields:
        val: The value passed to the latest ``update`` call.
        sum: Accumulated ``val * n`` over all updates since the last reset.
        count: Accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count``, or 0 while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples.

        Args:
            val: The new measurement (typically a per-batch mean).
            n: Weight of the measurement, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero total weight (e.g. an empty batch as the
        # first update) keeps the average at its reset value instead of
        # raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
