import json
from collections import Counter
import numpy
from random import shuffle

# Output files, one review per line: s1 receives the negative reviews,
# s2 the positive (5-star) ones.
neg_review_path = "<output-dir>/s1"
pos_review_path = "<output-dir>/s2"

# Per-class caps on how many reviews are written out.
max_pos = max_neg = 500_000

# A review is kept only if its space-separated token count lies in
# [min_len, max_len].
min_len, max_len = 0, 100

# Limit on how far (by line index) the review file is scanned.
max_samples = 1_000_000_000

def _write_lines(path, lines):
    """Write the already newline-terminated strings in *lines* to *path* as UTF-8."""
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(lines)


def _print_stats(all_ratings, all_lengths, total_num_reviews):
    """Print the review count, log2-length histogram and star distribution."""
    print(f"Total reviews: {total_num_reviews}")
    if total_num_reviews == 0:
        # Nothing passed the filters; the ratios below would divide by zero.
        return

    bins = numpy.linspace(0, 10, 11)
    # digitize returns i with bins[i-1] <= log2(len) < bins[i], so the bucket
    # labelled 2**i holds reviews with 2**(i-1) <= length < 2**i.
    len_counter = Counter(numpy.digitize(numpy.array(all_lengths), bins))

    print("lens:")
    for i in range(len(bins)):
        print(f"{2 ** i} : {float(len_counter[i]) / total_num_reviews}")

    counter = Counter(all_ratings)
    print("stars:")
    for s in [1.0, 2.0, 3.0, 4.0, 5.0]:
        # Scale to an actual percentage to match the "%" label.
        print(s, ":", counter[s], 100.0 * counter[s] / total_num_reviews, "%")


def check_reviews(businesses=None, only_large_small=True, writeout=True):
    """Scan the Yelp review dump, optionally write pos/neg review files, print stats.

    A review is kept when its space-separated token count lies in
    [min_len, max_len], its business id is in *businesses* (when given), and,
    if *only_large_small* is true, its rating is not strictly between 2 and 5.
    At most ``max_samples`` lines of the review file are scanned.

    Args:
        businesses: optional iterable of business ids to restrict the scan to.
        only_large_small: keep only 5-star and <=2-star reviews.
        writeout: when true, shuffle the kept reviews and write up to
            ``max_pos`` 5-star reviews to ``pos_review_path`` and up to
            ``max_neg`` of the remaining kept reviews to ``neg_review_path``
            (UTF-8, one review per line). When false, no files are touched.
    """
    if businesses is not None:
        businesses = set(businesses)  # O(1) membership tests in the loop
    all_ratings = []
    all_lengths = []
    pos_reviews = []
    neg_reviews = []

    review_file = '<input-dir>/yelp_dataset/yelp_academic_dataset_review.json'
    with open(review_file, encoding="utf-8") as fp:
        for i, line in enumerate(fp):
            # Enforce the scan limit before any filtering, so lines skipped by
            # the filters below still count toward max_samples.
            if i >= max_samples:
                break
            if i % 10000 == 0:
                print(i)
            if not line.strip():
                continue

            review = json.loads(line)

            # Flatten line breaks/tabs so each review occupies one output line.
            review_text = (
                review["text"]
                .replace("\n", " ")
                .replace("\t", " ")
                .replace("\r", " ")
            )

            # Token count on single spaces; runs of spaces produce empty tokens,
            # which are counted as well.
            review_len = len(review_text.split(" "))
            rating = review["stars"]
            if review_len < min_len or review_len > max_len:
                continue
            if businesses is not None and review["business_id"] not in businesses:
                continue
            if only_large_small and 2.0 < rating < 5.0:
                continue

            all_ratings.append(rating)
            # split(" ") always yields at least one token, so log2 is finite.
            all_lengths.append(numpy.log2(review_len))

            if writeout:
                # NOTE(review): every kept non-5-star review is treated as
                # negative; with only_large_small=False that includes 3- and
                # 4-star reviews.
                if rating == 5.0:
                    pos_reviews.append(review_text + "\n")
                else:
                    neg_reviews.append(review_text + "\n")

    if writeout:
        shuffle(pos_reviews)
        shuffle(neg_reviews)
        _write_lines(pos_review_path, pos_reviews[:max_pos])
        _write_lines(neg_review_path, neg_reviews[:max_neg])

    _print_stats(all_ratings, all_lengths, len(all_ratings))


def _split_categories(raw):
    """Return a business's categories as a list of exact names.

    Accepts a comma-separated string ("Pizza, Restaurants") or a list;
    None yields an empty list.
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return [c.strip() for c in raw.split(",")]
    return list(raw)


def check_businesses(category=None,
                     business_file='<input-dir>/yelp_dataset/yelp_academic_dataset_business.json'):
    """Return the ids of businesses tagged with *category*, in file order.

    Args:
        category: category name to match exactly (not as a substring), or
            None to return every business id.
        business_file: path to the JSON-lines business dump.

    Returns:
        List of ``business_id`` values.
    """
    all_businesses_of_category = []
    with open(business_file, encoding="utf-8") as fp:
        for line in fp:
            if not line.strip():
                continue
            business = json.loads(line)
            # Compare against parsed category names; a raw `in` on the string
            # would also match e.g. "Pop-Up Restaurants" for "Restaurants".
            if category is None or category in _split_categories(business.get("categories")):
                all_businesses_of_category.append(business["business_id"])

    return all_businesses_of_category

if __name__ == "__main__":
    # Script entry point: restrict the review scan to businesses tagged
    # "Restaurants". Guarded so importing this module does not trigger the
    # full dataset scan or overwrite the output files.
    all_businesses_of_category = check_businesses(category = "Restaurants")
    check_reviews(businesses = all_businesses_of_category)
