__author__ = 'Luchen'
import itertools
import json

import RBO


similar_word_list = '../data/common_word_similar_list'
common_word_list_path = '../data/CommonAllVocabulary_150406.txt'
# Load the vocabulary once at import time. The with-block closes the handle
# deterministically instead of leaving it for the garbage collector.
with open(common_word_list_path) as _reader:
    common_word_list = [item.strip() for item in _reader]
del _reader


def get_topN_word_lists(N, path=None):
    """Return the first ``N`` stripped lines of the common-word vocabulary file.

    :param N: maximum number of words to return.
    :param path: file to read; defaults to ``common_word_list_path``.
    :returns: a list of at most ``N`` words (fewer if the file is shorter).
    """
    if path is None:
        path = common_word_list_path
    with open(path) as reader:
        # islice stops cleanly at EOF, whereas next() raised StopIteration on a
        # short file (and xrange is unavailable on Python 3).
        return [line.strip() for line in itertools.islice(reader, N)]

#########
# original RBO computation
def get_rbo(wiki, tweet, top_K, round_num=None):
    """Return the RBO score of the top-``top_K`` prefixes of two ranked lists.

    :param wiki: ranked list of terms from the wiki side.
    :param tweet: ranked list of terms from the tweet side.
    :param top_K: only the first ``top_K`` items of each list are compared.
    :param round_num: if not ``None``, round the score to this many digits.
        ``0`` is honoured (round to an integer value) instead of being treated
        as "no rounding".
    :returns: ``RBO.score`` of the truncated lists, optionally rounded.
    """
    score = RBO.score(wiki[:top_K], tweet[:top_K])
    if round_num is not None:
        score = round(score, round_num)
    return score


##########
#  only consider top N terms in common set
# common_set = [term.strip() for term in open('../data/top_common_set_1000').readlines()]
# def get_rbo(wiki, tweet, top_K):
#     wiki_list = [term for term in wiki if term in common_set][:top_K]
#     tweet_list = [term for term in tweet if term in common_set][:top_K]
#     return RBO.score(wiki_list, tweet_list)


def get_topK_rbo_value(top_N, top_K, remove_single_alpha_word=False, round_num=3):
    """Compute the RBO score between each word's wiki and tweet neighbour lists.

    Reads the JSON-lines file at ``similar_word_list``. Each record has a
    ``'term'`` plus ``'wiki_list'`` / ``'tweet_list'`` arrays whose entries
    carry a ``'term'`` key.

    :param top_N: number of leading lines of the file to process; ``None``
        processes the whole file. Exactly ``top_N`` lines are read (the old
        loop read ``top_N + 1``).
    :param top_K: number of neighbours from each list compared by ``get_rbo``.
    :param remove_single_alpha_word: if true, skip terms of length <= 1.
        Skipped lines still count toward ``top_N``.
    :param round_num: digits each score is rounded to (forwarded to
        ``get_rbo``; previously accepted but ignored). ``None`` disables rounding.
    :returns: list of ``{'word': term, 'score': score}`` dicts in file order.
    """
    rbo_list = []
    with open(similar_word_list) as fp:
        for line in itertools.islice(fp, top_N):
            data = json.loads(line)
            term = data['term']
            if remove_single_alpha_word and len(term) <= 1:
                continue
            wiki_list = [item['term'] for item in data['wiki_list']]
            tweet_list = [item['term'] for item in data['tweet_list']]
            score = get_rbo(wiki_list, tweet_list, top_K, round_num)
            rbo_list.append({'word': term, 'score': score})
    return rbo_list



def split_lines(line_list, NUM_PER_LINE):
    """Chunk ``line_list`` into consecutive slices of ``NUM_PER_LINE`` items.

    Every chunk is full except possibly the last. A final chunk is always
    appended, so an empty input yields a single empty chunk.
    """
    row_count = int((len(line_list) + NUM_PER_LINE - 1) / NUM_PER_LINE)
    last_row = row_count - 1
    rows = [line_list[NUM_PER_LINE * r:NUM_PER_LINE * (r + 1)]
            for r in range(last_row)]
    # The tail chunk runs to the end of the sequence.
    rows.append(line_list[NUM_PER_LINE * last_row:])
    return rows


def sort_topK_by_score(line_list):
    """Return a new list of the entries ordered by descending ``'score'``.

    The input is not modified; entries with equal scores keep their order.
    """
    ranked = list(line_list)
    ranked.sort(key=lambda entry: entry['score'], reverse=True)
    return ranked


def get_similar_lists_by_term(term, topK, show_n=50):
    """Look up ``term`` in the similar-word file and build a side-by-side view.

    Each line of the file at ``similar_word_list`` is a JSON object with
    ``'term'``, ``'wiki_list'`` and ``'tweet_list'`` keys. List entries carry
    ``'term'``, ``'cosine'``, ``'common'`` and ``'rank'`` keys.

    :param term: the word to look up (compared with each record's ``'term'``).
    :param topK: number of leading neighbours used for the RBO score.
    :param show_n: maximum number of aligned neighbour rows placed in
        ``data['list']`` (default 50, the previously hard-coded value). Fewer
        rows are produced if either list is shorter.
    :returns: ``(data, score)``. ``data`` is the matching JSON record with an
        added ``'list'`` key; ``score`` is ``get_rbo`` of the two term lists.
    :raises KeyError: if ``term`` does not appear in the file. Previously the
        last record of the file was silently returned instead.
    """
    data = None
    with open(similar_word_list) as fp:
        for line in fp:
            record = json.loads(line)
            if term == record['term']:
                data = record
                break
    if data is None:
        raise KeyError(term)

    wiki_entries = data['wiki_list']
    tweet_entries = data['tweet_list']
    score = get_rbo([item['term'] for item in wiki_entries],
                    [item['term'] for item in tweet_entries],
                    topK)
    # zip truncates to the shorter list instead of raising IndexError.
    data['list'] = [{'wiki_term': wiki_term['term'],
                     'wiki_score': wiki_term['cosine'],
                     'wiki_tag': wiki_term['common'],
                     'wiki_rank': wiki_term['rank'],
                     'tweet_term': tweet_term['term'],
                     'tweet_score': tweet_term['cosine'],
                     'tweet_tag': tweet_term['common'],
                     'tweet_rank': tweet_term['rank']}
                    for wiki_term, tweet_term in zip(wiki_entries[:show_n],
                                                     tweet_entries[:show_n])]
    return data, score


def _load_rbo_scores(rbo_path):
    """Read a ``term<TAB>score`` file into a dict mapping term -> float score."""
    scores = {}
    with open(rbo_path) as reader:
        for line in reader:
            if not line.strip():  # tolerate blank (e.g. trailing) lines
                continue
            fields = line.strip().split('\t')
            scores[fields[0]] = float(fields[1])
    return scores


def _parse_distance_line(line, rbo_dict):
    """Convert one ``term<TAB>distance<TAB>rank`` line into a result dict."""
    fields = line.strip().split('\t')
    term = fields[0]
    rbo = rbo_dict.get(term)
    return {'term': term,
            'distance': float(fields[1]),
            'rank': int(fields[2]),
            # '??' marks terms that have no precomputed RBO value.
            'rbo': '??' if rbo is None else round(rbo, 3)}


def read_in_median_distance(method, top_N=None, rbo_path='top_10000_rbo_value',
                            data_dir='../data/'):
    """Load per-term distance-to-median rows, annotated with their RBO score.

    :param method: prefix of the distance file
        ``<data_dir><method>_distance_to_median_withnegative``.
    :param top_N: number of leading lines to read; ``None`` reads the whole
        file. ``0`` now yields an empty list instead of reading everything.
    :param rbo_path: ``term<TAB>score`` file of precomputed RBO values.
    :param data_dir: directory prefix (with trailing separator) of the
        distance file.
    :returns: list of ``{'term', 'distance', 'rank', 'rbo'}`` dicts in file
        order. ``'rbo'`` is rounded to 3 digits, or ``'??'`` if unknown.
        Blank lines are skipped.
    """
    rbo_dict = _load_rbo_scores(rbo_path)
    path = data_dir + method + '_distance_to_median_withnegative'
    with open(path) as fp:
        return [_parse_distance_line(line, rbo_dict)
                for line in itertools.islice(fp, top_N)
                if line.strip()]


def sort_topK_by_rank(line_list):
    """Return a new list of the entries ordered by ascending ``'rank'``.

    The input is not modified; entries with equal rank keep their order.
    """
    ordered = list(line_list)
    ordered.sort(key=lambda entry: entry['rank'])
    return ordered


def sort_by_distance_in_topK(list_sorted_by_rank, top_K=5000):
    """Keep entries whose ``'rank'`` is below ``top_K``, largest ``'distance'`` first.

    Returns a new list; entries with equal distance keep their input order.
    """
    within_window = [entry for entry in list_sorted_by_rank
                     if entry['rank'] < top_K]
    within_window.sort(key=lambda entry: entry['distance'], reverse=True)
    return within_window

# rbo = get_topK_rbo_value(10000, 50)
# with open('top_10000_rbo_value', 'w') as out:
#
#     out.write('\n'.join([item['word'] + '\t' + str(item['score']) for item in rbo]))
# out.close()