import sys
import time
from collections import defaultdict
import random
import statistics
import json
import pickle


# ---------------------------------------------------------------------------
# Global preprocessing settings.
# ---------------------------------------------------------------------------
data_dir = '../'  # directory holding the <domain>_user_item_timestamps.data inputs
history_i_cutoff = 10  # keep at most this many most-recent items per history window
document_length_cutoff = 30  # NOTE(review): not referenced anywhere in this script
max_hist_i = -1  # NOTE(review): not referenced anywhere in this script

random.seed(2019)  # fixed seed so negative sampling / shuffles are reproducible
start_time = time.time()  # wall-clock start, reported at the very end

# ---------------------------------------------------------------------------
# Source domain: load per-user (item, timestamp) feedback.
# Each non-blank input line looks like:
#     <user> <n_feed> <item_1> <ts_1> ... <item_n_feed> <ts_n_feed>
# ---------------------------------------------------------------------------
data_domain = 'source'  # source domain
print('---Source domain: {}---'.format(data_domain))

input_path = data_dir + data_domain + '_user_item_timestamps.data'
items_set = set()  # every item id seen in this domain
user_item_ts_src = defaultdict(list)  # user id -> [[item, ts], ...] ordered oldest first
total_feedback = 0  # sum of the declared n_feed counts
with open(input_path) as fin:
    for raw_line in fin:
        tokens = raw_line.split()
        if not tokens:  # skip blank / whitespace-only lines
            continue
        user, n_feed = int(tokens[0]), int(tokens[1])
        total_feedback += n_feed
        pairs = tokens[2:]  # flattened item/timestamp pairs
        hist = []
        for k in range(n_feed):
            item, ts = int(pairs[2 * k]), int(pairs[2 * k + 1])
            hist.append([item, ts])
            items_set.add(item)
        # sorted() is stable, so equal timestamps keep their file order.
        user_item_ts_src[user] = sorted(hist, key=lambda pair: pair[1])

# Trailing numbers are presumably the values seen on a reference run.
print('#items_set = {}'.format(len(items_set)))  # 1484
print('total_feedback = {}'.format(total_feedback))  # 299830
print('#user_item_ts_src = {}'.format(len(user_item_ts_src)))  # 5967
items_list = list(items_set)
max_item_id_src = max(items_list)
print('max_item_id_src = {}'.format(max_item_id_src))  # 3947
max_user_id_src = max(user_item_ts_src)
print('max_user_id_src = {}'.format(max_user_id_src))  # 6040

# ---------------------------------------------------------------------------
# Source domain: build the (shuffled) training set and each user's latest history.
# ---------------------------------------------------------------------------
hist_length_list = []  # untruncated length of every training history window (for stats)
train_set = []  # (user, hist_i, i, label): label 1 = observed next item, 0 = sampled negative
user_latest_hist_src = defaultdict(list)  # user -> [latest history window, last item]
for user_id, history in user_item_ts_src.items():
    if len(history) < 2:  # users in training must consume at least two items
        continue
    pos_list = [e[0] for e in history]  # history entries are [item, timestamp]; keep items
    pos_set = set(pos_list)
    if len(pos_set) >= len(items_list):
        # The user consumed every known item, so gen_neg() could never terminate.
        raise ValueError(
            'user {} consumed every source item; cannot sample negatives'.format(user_id))

    def gen_neg():
        """Return a uniformly drawn source item id that this user never consumed."""
        neg = pos_list[0]  # start from a positive so at least one draw happens
        while neg in pos_set:
            # Draw an actual item id. Item ids are sparse (max id > #items), so
            # drawing an index in [0, len(items_list)-1] would yield non-items.
            neg = random.choice(items_list)
        return neg

    # One negative per position; index 0 is drawn but never used below.
    neg_train_list = [gen_neg() for _ in range(len(pos_list))]  # train: pos:neg = 1:1
    for i in range(1, len(pos_list)):  # predict item i from the items before it
        hist = pos_list[:i]  # sliding window
        hist_length_list.append(len(hist))
        if len(hist) > history_i_cutoff:
            hist = hist[len(hist) - history_i_cutoff:]  # keep only the most recent items
        train_set.append((user_id, hist, pos_list[i], 1))
        train_set.append((user_id, hist, neg_train_list[i], 0))
    # Latest history. Truncate relative to the window itself so exactly
    # history_i_cutoff items are kept.
    # NOTE(review): [:-2] also drops the second-to-last item, whereas the training
    # window that predicts the last item is pos_list[:-1] -- confirm the gap is intended.
    hist = pos_list[:-2]
    if len(hist) > history_i_cutoff:
        hist = hist[len(hist) - history_i_cutoff:]
    user_latest_hist_src[user_id] = [hist, pos_list[-1]]

print('#train = {}'.format(len(train_set)))  # 587726
max_hist_length = max(hist_length_list)
mean_hist_length = statistics.mean(hist_length_list)
print('hist(mean,max) = {:.2f}, {}'.format(mean_hist_length, max_hist_length))  # 63.66, 693
print('#user_latest_hist_src = {}'.format(len(user_latest_hist_src)))  # 5882
random.shuffle(train_set)

# Persist source-domain artifacts: pickled training data plus id bounds, and the
# latest histories as JSON.
src_pkl_path = data_domain + '.source.prnet.pkl'
src_json_path = data_domain + '.source.user_latest_hist.json'
with open(src_pkl_path, 'wb') as ofile:
    # Three consecutive pickles; a reader must load them back in this same order.
    for payload in (train_set, max_item_id_src, max_user_id_src):
        pickle.dump(payload, ofile, pickle.HIGHEST_PROTOCOL)
with open(src_json_path, 'w') as ofile:
    # JSON object keys are strings, so the integer user ids are written as strings.
    json.dump(user_latest_hist_src, ofile)

# ---------------------------------------------------------------------------
# Target domain: load per-user (item, timestamp) feedback (same format as source).
# ---------------------------------------------------------------------------
data_domain = 'target'  # target domain
print('---Target domain: {}---'.format(data_domain))

input_path = data_dir + data_domain + '_user_item_timestamps.data'
items_set = set()  # every item id seen in this domain
user_item_ts = defaultdict(list)  # user id -> [[item, ts], ...] ordered oldest first
total_feedback = 0  # sum of the declared n_feed counts
with open(input_path) as fin:
    for raw_line in fin:
        tokens = raw_line.split()
        if not tokens:  # skip blank / whitespace-only lines
            continue
        user, n_feed = int(tokens[0]), int(tokens[1])
        total_feedback += n_feed
        pairs = tokens[2:]  # flattened item/timestamp pairs
        hist = []
        for k in range(n_feed):
            item, ts = int(pairs[2 * k]), int(pairs[2 * k + 1])
            hist.append([item, ts])
            items_set.add(item)
        # sorted() is stable, so equal timestamps keep their file order.
        user_item_ts[user] = sorted(hist, key=lambda pair: pair[1])

# Trailing numbers are presumably the values seen on a reference run.
print('#items_set = {}'.format(len(items_set)))  # 2049
print('total_feedback = {}'.format(total_feedback))  # 274115
print('#user_item_ts = {}'.format(len(user_item_ts)))  # 5967
items_list = list(items_set)
max_item_id_tgt = max(items_list)
print('max_item_id_tgt = {}'.format(max_item_id_tgt))  # 3952

users_list = list(user_item_ts)  # insertion (file) order; shuffled before the split
max_user_id_tgt = max(users_list)
print('max_user_id_tgt = {}'.format(max_user_id_tgt))  # 6040

# ---------------------------------------------------------------------------
# Split target users into train / test. Reseed so the split does not depend on how
# many random draws the source-domain sampling consumed.
# ---------------------------------------------------------------------------
random.seed(2019)
random.shuffle(users_list)

# Users in target training must have at least two items in the source domain
# (joint training with aligned users). .get() treats users absent from the source
# domain as having zero source items and avoids inserting empty entries into the
# user_item_ts_src defaultdict.
num_users_train = int(0.8 * len(users_list))
users_train_set = set()
for user in users_list:
    if len(user_item_ts_src.get(user, [])) > 1:
        users_train_set.add(user)
        if len(users_train_set) == num_users_train:
            break
# Users in training must also have at least two target items. Users dropped here stay
# in users_train_set, so they are excluded from the test split as well.
users_train_list = [user for user in users_train_set if len(user_item_ts[user]) > 1]
print('#users_train_list = {}'.format(len(users_train_list)))  # 5315

# Test users must have at least two source items: their target rep is computed from
# SOURCE knowledge.
users_test_list = [user for user in users_list
                   if user not in users_train_set
                   and len(user_item_ts[user]) > 1
                   and len(user_item_ts_src.get(user, [])) > 1]
print('#users_test_list = {}'.format(len(users_test_list)))  # 509
print('#users (train+test) = {}'.format(len(users_train_list) + len(users_test_list)))  # 5824

# ---------------------------------------------------------------------------
# Target domain: training set for the training users and their latest histories.
# ---------------------------------------------------------------------------
hist_length_list = []  # untruncated length of every training history window (for stats)
train_set = []  # (user, hist_i, i, label): label 1 = observed next item, 0 = sampled negative
user_latest_hist_tgt = defaultdict(list)  # user -> [latest history window, last item]
for user_id in users_train_list:
    history = user_item_ts[user_id]
    pos_list = [e[0] for e in history]  # history entries are [item, timestamp]; keep items
    assert len(pos_list) > 1  # users in training must consume at least two items
    pos_set = set(pos_list)
    if len(pos_set) >= len(items_list):
        # Distinct positives cover every target item, so gen_neg() could never
        # terminate. Comparing len(pos_list) would miss users with repeated items.
        raise ValueError(
            'user {} consumed every target item; cannot sample negatives'.format(user_id))

    def gen_neg():
        """Return a uniformly drawn target item id that this user never consumed."""
        neg = pos_list[0]  # start from a positive so at least one draw happens
        while neg in pos_set:
            # Draw an actual item id. Item ids are sparse (max id > #items), so
            # drawing an index in [0, len(items_list)-1] would yield non-items.
            neg = random.choice(items_list)
        return neg

    # One negative per position; index 0 is drawn but never used below.
    neg_train_list = [gen_neg() for _ in range(len(pos_list))]  # train: pos:neg = 1:1
    for i in range(1, len(pos_list)):  # predict item i from the items before it
        hist = pos_list[:i]  # sliding window
        hist_length_list.append(len(hist))
        if len(hist) > history_i_cutoff:
            hist = hist[len(hist) - history_i_cutoff:]  # keep only the most recent items
        train_set.append((user_id, hist, pos_list[i], 1))
        train_set.append((user_id, hist, neg_train_list[i], 0))
    # Latest history. Truncate relative to the window itself so exactly
    # history_i_cutoff items are kept.
    # NOTE(review): [:-2] also drops the second-to-last item, whereas the training
    # window that predicts the last item is pos_list[:-1] -- confirm the gap is intended.
    hist = pos_list[:-2]
    if len(hist) > history_i_cutoff:
        hist = hist[len(hist) - history_i_cutoff:]
    user_latest_hist_tgt[user_id] = [hist, pos_list[-1]]

print('#train = {}'.format(len(train_set)))  # 487654
max_hist_length = max(hist_length_list)
mean_hist_length = statistics.mean(hist_length_list)
print('hist(mean,max) = {:.2f}, {}'.format(mean_hist_length, max_hist_length))  # 54.14, 740
print('#user_latest_hist_tgt = {}'.format(len(user_latest_hist_tgt)))  # 4726

# ---------------------------------------------------------------------------
# Target domain: ranking test set, 1 positive : 99 sampled negatives per positive.
# ---------------------------------------------------------------------------
# Each entry is (user, item, label, anchor): anchor is -1 for a positive, and for a
# negative it is the positive item that negative was sampled against.
test_set = []
for user_id in users_test_list:
    history = user_item_ts[user_id]
    pos_list = [e[0] for e in history]  # history entries are [item, timestamp]; keep items
    assert len(pos_list) != 0
    pos_set = set(pos_list)
    if len(pos_set) >= len(items_list):
        # Distinct positives cover every target item, so gen_neg() could never
        # terminate. Comparing len(pos_list) would miss users with repeated items.
        raise ValueError(
            'user {} consumed every target item; cannot sample negatives'.format(user_id))

    def gen_neg():
        """Return a uniformly drawn target item id that this user never consumed."""
        neg = pos_list[0]  # start from a positive so at least one draw happens
        while neg in pos_set:
            # Draw an actual item id. Item ids are sparse (max id > #items), so
            # drawing an index in [0, len(items_list)-1] would yield non-items.
            neg = random.choice(items_list)
        return neg

    # For the target domain there is no history to compute a user's rep; her source
    # rep is used instead.
    for pos in pos_list:
        test_set.append((user_id, pos, 1, -1))  # 1 positive
        neg99 = [gen_neg() for _ in range(99)]  # 1:99
        for neg in neg99:
            test_set.append((user_id, neg, 0, pos))  # 99 negatives
    # NOTE(review): possibly unused downstream; the test tuples above do not use it.
    # Truncate relative to the window itself so exactly history_i_cutoff items are kept.
    hist = pos_list[:-2]
    if len(hist) > history_i_cutoff:
        hist = hist[len(hist) - history_i_cutoff:]
    user_latest_hist_tgt[user_id] = [hist, pos_list[-1]]
print('#test = {}'.format(len(test_set)))  # 2369500
print('#user_latest_hist_tgt = {}'.format(len(user_latest_hist_tgt)))  # 5824

# Persist target-domain artifacts: latest histories as JSON, then shuffled train/test
# data plus id bounds as consecutive pickles.
tgt_json_path = data_domain + '.target.user_latest_hist.json'
tgt_pkl_path = data_domain + '.target.prnet.pkl'
with open(tgt_json_path, 'w') as ofile:
    # JSON object keys are strings, so the integer user ids are written as strings.
    json.dump(user_latest_hist_tgt, ofile)

# Shuffle order matters for reproducibility: train first, then test.
for records in (train_set, test_set):
    random.shuffle(records)

with open(tgt_pkl_path, 'wb') as ofile:
    # Four consecutive pickles; a reader must load them back in this same order.
    for payload in (train_set, test_set, max_item_id_tgt, max_user_id_tgt):
        pickle.dump(payload, ofile, pickle.HIGHEST_PROTOCOL)

elapsed = time.time() - start_time
print('---{:.2f}s---'.format(elapsed))  # 18
