import numpy as np
import math
import sys
import random


class DataInputTrainTarget:  # (user, hist_i, i, label)
    """Batch iterator over target-domain training samples.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` history
    sequence, ``t[2]`` item, ``t[3]`` label.

    Every call to ``__next__`` yields ``(batch_no, (users, hist_i, items, ys, sl))``:

    * ``batch_no`` -- 1-based index of the batch just produced;
    * ``hist_i`` -- ``int64`` array of shape ``(batch, max_len)`` holding each
      history right-padded with zeros;
    * ``sl`` -- unpadded length of every history.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_tgt, batch_size):
        """
        Args:
            train_set_tgt: sliceable sequence of samples (layout above).
            batch_size: samples per batch; the last batch may be smaller.
        """
        self.batch_size = batch_size
        self.data = train_set_tgt
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" rather than "==" so the iterator stays exhausted even if the
        # counter ever overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        users = [t[0] for t in ts]
        items = [t[2] for t in ts]
        ys = [t[3] for t in ts]
        sl = [len(t[1]) for t in ts]  # unpadded history lengths

        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, t in enumerate(ts):
            hist_i[row, :sl[row]] = t[1]  # copy history; tail stays zero-padded

        return self.iter, (users, hist_i, items, ys, sl)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTrainSource:  # (user, hist_i, i, label)
    """Batch iterator over source-domain training samples.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` history
    sequence, ``t[2]`` item, ``t[3]`` label. The batch size comes from
    ``config['train_batch_size']``.

    Every call to ``__next__`` yields ``(batch_no, (users, hist_i, items, ys, sl))``
    where ``batch_no`` is the 1-based batch index, ``hist_i`` is an ``int64``
    array of shape ``(batch, max_len)`` with zero right-padding and ``sl``
    holds the unpadded history lengths.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_src, config):
        """
        Args:
            train_set_src: sliceable sequence of samples (layout above).
            config: mapping providing ``'train_batch_size'``.
        """
        self.data = train_set_src
        self.batch_size = config['train_batch_size']
        self.num_batches = math.ceil(len(self.data) / self.batch_size)
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        users = [t[0] for t in ts]
        items = [t[2] for t in ts]
        ys = [t[3] for t in ts]
        sl = [len(t[1]) for t in ts]  # unpadded history lengths

        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, t in enumerate(ts):
            hist_i[row, :sl[row]] = t[1]  # copy history; tail stays zero-padded

        return self.iter, (users, hist_i, items, ys, sl)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTrainAttackerReal:  # (user, hist_i, i, label, gender,age,occupation)
    """Batch iterator over attacker training samples with one-hot attribute targets.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` history
    sequence, ``t[2]`` item, ``t[3]`` label, and ``t[4]``/``t[5]``/``t[6]``
    the integer class indices for gender, age and occupation.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation))``:
    ``hist_i`` is an ``int64`` ``(batch, max_len)`` array with zero
    right-padding, ``sl`` the unpadded lengths, and each ``y_*`` a float
    one-hot matrix of shape ``(batch, num_classes_*)``.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_attack_real, config):
        """
        Args:
            train_set_attack_real: sliceable sequence of samples (layout above).
            config: mapping providing ``'num_classes_gender'``,
                ``'num_classes_age'``, ``'num_classes_occupation'`` and
                ``'train_attack_batch_size'``.
        """
        self.data = train_set_attack_real
        self.num_classes_gender = config['num_classes_gender']
        self.num_classes_age = config['num_classes_age']
        self.num_classes_occupation = config['num_classes_occupation']
        self.batch_size = config['train_attack_batch_size']
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        n_rows = len(ts)
        users = [t[0] for t in ts]
        items = [t[2] for t in ts]
        ys = [t[3] for t in ts]  # passed through unchanged (original note: always 1, unused)
        sl = [len(t[1]) for t in ts]  # unpadded history lengths

        hist_i = np.zeros([n_rows, max(sl)], np.int64)
        y_gender = np.zeros([n_rows, self.num_classes_gender])
        y_age = np.zeros([n_rows, self.num_classes_age])
        y_occupation = np.zeros([n_rows, self.num_classes_occupation])
        for row, t in enumerate(ts):
            hist_i[row, :sl[row]] = t[1]  # copy history; tail stays zero-padded
            # One-hot encode the attribute class indices.
            y_gender[row, t[4]] = 1
            y_age[row, t[5]] = 1
            y_occupation[row, t[6]] = 1

        return self.iter, (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTrainAttackerFakeSource:  # (user, hist_i, i, label, gender,age,occupation)
    """Batch iterator over attacker samples built from the fake source data.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` history
    sequence, ``t[2]`` item, ``t[3]`` label, and ``t[4]``/``t[5]``/``t[6]``
    the integer class indices for gender, age and occupation.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation))``:
    ``hist_i`` is an ``int64`` ``(batch, max_len)`` array with zero
    right-padding, ``sl`` the unpadded lengths, and each ``y_*`` a float
    one-hot matrix of shape ``(batch, num_classes_*)``.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_attack_fake_source, config):
        """
        Args:
            train_set_attack_fake_source: sliceable sequence of samples (layout above).
            config: mapping providing ``'num_classes_gender'``,
                ``'num_classes_age'``, ``'num_classes_occupation'`` and
                ``'train_attack_batch_size'``.
        """
        self.data = train_set_attack_fake_source
        self.num_classes_gender = config['num_classes_gender']
        self.num_classes_age = config['num_classes_age']
        self.num_classes_occupation = config['num_classes_occupation']
        self.batch_size = config['train_attack_batch_size']
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        n_rows = len(ts)
        users = [t[0] for t in ts]
        items = [t[2] for t in ts]
        ys = [t[3] for t in ts]
        sl = [len(t[1]) for t in ts]  # unpadded history lengths

        hist_i = np.zeros([n_rows, max(sl)], np.int64)
        y_gender = np.zeros([n_rows, self.num_classes_gender])
        y_age = np.zeros([n_rows, self.num_classes_age])
        y_occupation = np.zeros([n_rows, self.num_classes_occupation])
        for row, t in enumerate(ts):
            hist_i[row, :sl[row]] = t[1]  # copy history; tail stays zero-padded
            # One-hot encode the attribute class indices.
            y_gender[row, t[4]] = 1
            y_age[row, t[5]] = 1
            y_occupation[row, t[6]] = 1

        return self.iter, (users, hist_i, items, ys, sl, y_gender, y_age, y_occupation)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTest:  # u,i,j
    """Batch iterator over target-domain test samples, paired with source histories.

    Each test sample ``t`` is read as ``t[0]`` user id, ``t[1]`` target
    item and ``t[3]`` parent marker (``t[2]`` is not read). For every user
    the latest source-domain record is looked up in
    ``user_latest_hist_src[user]`` and indexed as ``[0]`` history, ``[1]`` item.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i, items, sl, items_tgt, j_parent))`` where
    ``hist_i`` is an ``int64`` ``(batch, max_len)`` array of the source
    histories with zero right-padding and ``sl`` their unpadded lengths.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, test_set_tgt, user_latest_hist_src, batch_size):
        """
        Args:
            test_set_tgt: sliceable sequence of test samples (layout above).
            user_latest_hist_src: mapping user id -> indexable ``(history, item, ...)``.
            batch_size: samples per batch; the last batch may be smaller.
        """
        self.batch_size = batch_size
        self.data = test_set_tgt
        self.user_latest_hist_src = user_latest_hist_src
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        users, items, sl = [], [], []
        items_tgt = []
        # Original note: parent of a positive is -1; parent of a negative is
        # its corresponding positive.
        j_parent = []
        src_hists = []
        for t in ts:
            user = t[0]
            # Look each user up once and reuse the record for padding below.
            history = self.user_latest_hist_src[user]
            hist = history[0]
            users.append(user)
            items.append(history[1])
            sl.append(len(hist))  # unpadded source-history length
            src_hists.append(hist)
            items_tgt.append(t[1])
            j_parent.append(t[3])

        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, hist in enumerate(src_hists):
            hist_i[row, :len(hist)] = hist  # tail stays zero-padded

        return self.iter, (users, hist_i, items, sl, items_tgt, j_parent)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTransferTarget:  # train_set_transfer_tgt: [user,hist_i,item,label, hist_i_src,i_src]
    """Batch iterator over merged target/source samples keyed on the target side.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` target history,
    ``t[2]`` target item, ``t[3]`` target label, ``t[4]`` source history and
    ``t[5]`` source item.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i_src, items_src, sl_src, ys_src,
    hist_i_tgt, items_tgt, sl_tgt, ys_tgt))``. Both ``hist_i_*`` arrays are
    ``int64`` ``(batch, max_len)`` with zero right-padding, and ``ys_src`` is
    a placeholder list of 1s.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_transfer_tgt, batch_size):
        """
        Args:
            train_set_transfer_tgt: sliceable sequence of merged samples (layout above).
            batch_size: samples per batch; the last batch may be smaller.
        """
        self.batch_size = batch_size
        self.data = train_set_transfer_tgt
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        n_rows = len(ts)
        users = [t[0] for t in ts]
        # Target domain: t[1] history, t[2] item, t[3] label.
        items_tgt = [t[2] for t in ts]
        ys_tgt = [t[3] for t in ts]
        sl_tgt = [len(t[1]) for t in ts]
        # Source domain: t[4] history, t[5] item.
        items_src = [t[5] for t in ts]
        sl_src = [len(t[4]) for t in ts]
        ys_src = [1] * n_rows  # placeholder labels, always 1

        hist_i_tgt = np.zeros([n_rows, max(sl_tgt)], np.int64)
        hist_i_src = np.zeros([n_rows, max(sl_src)], np.int64)
        for row, t in enumerate(ts):
            # Copy both histories; the tails stay zero-padded.
            hist_i_tgt[row, :sl_tgt[row]] = t[1]
            hist_i_src[row, :sl_src[row]] = t[4]

        return self.iter, (users, hist_i_src, items_src, sl_src, ys_src,
                           hist_i_tgt, items_tgt, sl_tgt, ys_tgt)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTransferSource:  # train_set_transfer_src: [user,hist_i,item,label, hist_i_tgt,i_tgt]
    """Batch iterator over merged source/target samples keyed on the source side.

    Each sample ``t`` is read as ``t[0]`` user id, ``t[1]`` source history,
    ``t[2]`` source item, ``t[3]`` source label, ``t[4]`` target history and
    ``t[5]`` target item.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i_src, items_src, sl_src, ys_src,
    hist_i_tgt, items_tgt, sl_tgt, ys_tgt))``. Both ``hist_i_*`` arrays are
    ``int64`` ``(batch, max_len)`` with zero right-padding, and ``ys_tgt`` is
    a placeholder list of 1s.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, train_set_transfer_src, batch_size):
        """
        Args:
            train_set_transfer_src: sliceable sequence of merged samples (layout above).
            batch_size: samples per batch; the last batch may be smaller.
        """
        self.batch_size = batch_size
        self.data = train_set_transfer_src
        self.num_batches = math.ceil(len(self.data) / self.batch_size)  # num of batches
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        n_rows = len(ts)
        users = [t[0] for t in ts]
        # Source domain: t[1] history, t[2] item, t[3] label.
        items_src = [t[2] for t in ts]
        ys_src = [t[3] for t in ts]
        sl_src = [len(t[1]) for t in ts]
        # Target domain: t[4] history, t[5] item.
        items_tgt = [t[5] for t in ts]
        sl_tgt = [len(t[4]) for t in ts]
        ys_tgt = [1] * n_rows  # placeholder labels, always 1

        hist_i_src = np.zeros([n_rows, max(sl_src)], np.int64)
        hist_i_tgt = np.zeros([n_rows, max(sl_tgt)], np.int64)
        for row, t in enumerate(ts):
            # Copy both histories; the tails stay zero-padded.
            hist_i_src[row, :sl_src[row]] = t[1]
            hist_i_tgt[row, :sl_tgt[row]] = t[4]

        return self.iter, (users, hist_i_src, items_src, sl_src, ys_src,
                           hist_i_tgt, items_tgt, sl_tgt, ys_tgt)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()


class DataInputTestAttack:
    """Batch iterator over user ids for evaluating the attribute attacker.

    ``test_set_users_attack`` is a sequence of user ids. For each user the
    latest source-domain record ``user_latest_hist_src[user]`` is indexed as
    ``[0]`` history, ``[1]`` item, and the integer class indices for gender,
    age and occupation are read from ``user_private_map``.

    Every call to ``__next__`` yields
    ``(batch_no, (users, hist_i, items, ys_dummy, sl, y_gender, y_age, y_occupation))``:
    ``hist_i`` is an ``int64`` ``(batch, max_len)`` array with zero
    right-padding, ``ys_dummy`` a list of 1s, and each ``y_*`` a float
    one-hot matrix of shape ``(batch, num_classes_*)``.

    ``__iter__`` returns ``self`` and the counter is never reset, so an
    instance can be iterated through only once.
    """

    def __init__(self, test_set_users_attack, user_latest_hist_src, user_private_map, config):
        """
        Args:
            test_set_users_attack: sliceable sequence of user ids.
            user_latest_hist_src: mapping user id -> indexable ``(history, item, ...)``.
            user_private_map: mapping with ``'gender'``, ``'age'`` and
                ``'occupation'`` sub-mappings from user id to class index.
            config: mapping providing ``'num_classes_gender'``,
                ``'num_classes_age'``, ``'num_classes_occupation'`` and
                ``'test_batch_size_attack'``.
        """
        self.data = test_set_users_attack
        self.user_latest_hist_src = user_latest_hist_src
        self.user_gender_map = user_private_map['gender']
        self.user_age_map = user_private_map['age']
        self.user_occupation_map = user_private_map['occupation']
        self.num_classes_gender = config['num_classes_gender']
        self.num_classes_age = config['num_classes_age']
        self.num_classes_occupation = config['num_classes_occupation']
        self.batch_size = config['test_batch_size_attack']
        self.num_batches = math.ceil(len(self.data) / self.batch_size)
        self.iter = 0  # batches produced so far

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; raise ``StopIteration`` when all are produced."""
        # ">=" keeps the iterator exhausted even if the counter overshoots.
        if self.iter >= self.num_batches:
            raise StopIteration
        start = self.iter * self.batch_size
        ts = self.data[start:start + self.batch_size]  # slicing clamps at the end
        self.iter += 1

        n_rows = len(ts)
        users, items, sl, src_hists = [], [], [], []
        for user in ts:
            # Look each user up once and reuse the record for padding below.
            history = self.user_latest_hist_src[user]
            hist = history[0]
            users.append(user)
            items.append(history[1])
            sl.append(len(hist))  # unpadded source-history length
            src_hists.append(hist)
        ys_dummy = [1] * n_rows

        hist_i = np.zeros([n_rows, max(sl)], np.int64)
        y_gender = np.zeros([n_rows, self.num_classes_gender])
        y_age = np.zeros([n_rows, self.num_classes_age])
        y_occupation = np.zeros([n_rows, self.num_classes_occupation])
        for row, (user, hist) in enumerate(zip(users, src_hists)):
            hist_i[row, :len(hist)] = hist  # tail stays zero-padded
            # One-hot encode the user's private attribute class indices.
            y_gender[row, self.user_gender_map[user]] = 1
            y_age[row, self.user_age_map[user]] = 1
            y_occupation[row, self.user_occupation_map[user]] = 1

        return self.iter, (users, hist_i, items, ys_dummy, sl, y_gender, y_age, y_occupation)

    def next(self):
        """Alias for ``__next__``."""
        return self.__next__()
