import json
import pickle
import gzip
from glob import glob
from tqdm import tqdm
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
import re
import string

from datasets import Dataset
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoModelForTokenClassification


# Hand-made label files consulted by ``remove_merged``, in priority order.
# Each entry is (year, path, is_allow_list): an allow-list keeps only the ids it
# names, a deny-list drops the ids it names.
_MERGED_LABEL_FILES = (
    (1955, '/mnt/data01/wire_clusters/fa_v_bb_labs/one_day_sample_labels/updated_good.json', True),
    (1974, '/mnt/data01/wire_clusters/fa_v_bb_labs/one_day_sample_labels/1974_updated_bads.json', False),
    (1930, '/mnt/data01/wire_clusters/fa_v_bb_labs/one_day_sample_labels/1930_updated_bads.json', False),
)


def _load_articles(path_pattern, year_tokens, date_list, topic_format, fps_only):
    """Read every article whose id/scan name contains one of ``date_list``."""

    def wanted(art):
        # Front pages are recognised by the "-p-1.jpg" marker in the id.
        return not fps_only or "-p-1.jpg" in art['id']

    corpus_dict = {}

    for base in glob(path_pattern):

        # Cheap pre-filter: only open paths that mention a requested year.
        if not any(token in base for token in year_tokens):
            continue

        if topic_format:
            # Topic files: one json file holding a flat list of articles.
            with open(base) as f:
                data = json.load(f)

            for art in tqdm(data):
                if any(date in art["id"] for date in date_list) and wanted(art):
                    corpus_dict[art['id']] = art

        else:
            # Otherwise: a directory tree of ocr_text.json files, each mapping
            # a scan name to the list of articles on that scan.
            for path in tqdm(glob(base + '/**/ocr_text.json')):
                with open(path) as f:
                    data = json.load(f)

                for scan, arts in data.items():
                    if any(date in scan for date in date_list):
                        for art in arts:
                            if wanted(art):
                                corpus_dict[art['id']] = art

    return corpus_dict


def _subset_to_dailies(corpus_dict, years):
    """Keep only articles whose paper is listed for its year in model_estimates.csv."""

    dailies = pd.read_csv('/mnt/data01/wire_clusters/model_estimates.csv')

    # Sets give O(1) membership tests in the per-article loop below.
    papers_by_year = {
        str(year): set(dailies.loc[dailies['year'] == year, 'paper'])
        for year in years
    }

    daily_corpus = {}
    for art_id in tqdm(list(corpus_dict.keys())):
        parts = art_id.split("-")
        year = parts[-3]
        paper = "-".join(parts[1:-5])
        if paper in papers_by_year.get(year, ()):
            daily_corpus[art_id] = corpus_dict[art_id]

    return daily_corpus


def _filter_merged(corpus_dict, years):
    """Apply the first matching hand-made label file; no-op if none matches."""

    for year, label_path, is_allow_list in _MERGED_LABEL_FILES:
        if year not in years:
            continue

        with open(label_path) as f:
            labelled = set(json.load(f))

        return {
            art_id: art for art_id, art in corpus_dict.items()
            if (art_id in labelled) == is_allow_list
        }

    print("No merged-article labels available for the requested years; corpus left unchanged")
    return corpus_dict


def pull_data_for_dates(
        path_pattern,
        start_date,
        end_date,
        type="all_arts",
        fps_only=False,
        editorials_only=False,
        sort_bylines=True,
        remove_weather=True,
        remove_merged=False,
        dailies_only=False,
        vietnam_only=False
):
    """
    Load all articles dated between start_date and end_date (inclusive) and
    apply the requested filters.

    Input:
        path_pattern: glob pattern. For type="topic" it matches json files that
            each hold a list of articles; otherwise it matches directories that
            contain '**/ocr_text.json' files keyed by scan name.
        start_date, end_date: dates formatted like "Jun-20-1955" ("%b-%d-%Y").
        type: "topic" or anything else (default "all_arts"). With "all_arts" a
            path must contain "<year>_" to be read, otherwise just "<year>".
        fps_only: keep only articles whose id contains "-p-1.jpg".
        editorials_only: keep only articles classified as editorials.
        sort_bylines: re-detect bylines with the byline model.
        remove_weather: drop articles classified as weather.
        remove_merged: filter with the hand-made label files (1955, 1974, 1930).
            When none of those years is requested the corpus is left unchanged.
        dailies_only: keep only papers listed for that year in model_estimates.csv.
        vietnam_only: keep only articles classified as being about Vietnam.

    Output: dictionary {article id: article dict}.

    Raises ValueError if a date does not match "%b-%d-%Y".
    """

    # List of dates
    start = datetime.strptime(start_date, "%b-%d-%Y")
    end = datetime.strptime(end_date, "%b-%d-%Y")

    days = [start + timedelta(days=offset) for offset in range((end - start).days + 1)]
    date_list = [day.strftime("%b-%d-%Y") for day in days]
    years = sorted({day.year for day in days})

    # Tokens that a path has to contain to be considered for a given year.
    suffix = "_" if type == "all_arts" else ""
    year_tokens = [str(year) + suffix for year in years]

    # Pull data from those dates
    print(f"\n Pulling all articles from {start_date} to {end_date} ...")

    corpus_dict = _load_articles(
        path_pattern, year_tokens, date_list,
        topic_format=(type == "topic"), fps_only=fps_only)

    print(f"{len(corpus_dict)} articles in corpus")

    if dailies_only:
        print("Subsetting to dailies ...")
        corpus_dict = _subset_to_dailies(corpus_dict, years)
        print(f"{len(corpus_dict)} articles in corpus")

    if vietnam_only:
        corpus_dict = remove_non_vietnam(corpus_dict)
        print(f"{len(corpus_dict)} articles in corpus after non-vietnam removed")

    if editorials_only:
        corpus_dict = remove_non_editorials(corpus_dict)

    if remove_weather:
        corpus_dict = remove_weather_articles(corpus_dict)
        print(f"{len(corpus_dict)} articles in corpus after weather removed")

    if remove_merged:
        corpus_dict = _filter_merged(corpus_dict, years)

    if sort_bylines:
        corpus_dict = update_bylines(corpus_dict)

    print(f"{len(corpus_dict)} articles in corpus")

    return corpus_dict


def pull_data_for_year(
        path_pattern,
        remove_weather=True
):
    """
    Load every json file matching path_pattern and group its articles by day.

    Input:
        path_pattern: glob pattern of json files, each mapping a scan name to a
            list of article dicts (every article dict has an 'id').
        remove_weather: if True, run remove_weather_articles on each day.

    Output: dictionary {date string: {article id: article dict}}. The date
    string is pieces [-5:-2] of the scan name when split on "-".
    """

    print("\n Pulling all articles ...")

    corpora_by_day = {}

    for json_path in tqdm(glob(path_pattern)):

        with open(json_path) as handle:
            scans = json.load(handle)

            for scan_name, articles in scans.items():

                day = "-".join(scan_name.split("-")[-5:-2])
                day_corpus = corpora_by_day.setdefault(day, {})

                for article in articles:
                    day_corpus[article['id']] = article

    print(f"{len(corpora_by_day)} days in corpus")

    if not remove_weather:
        return corpora_by_day

    # Only existing keys are reassigned, so iterating the dict here is safe.
    for day in tqdm(corpora_by_day):
        corpora_by_day[day] = remove_weather_articles(corpora_by_day[day])

    return corpora_by_day


def _predict_labels(corpus_dict, tokenizer_name, model_path):
    """
    Run a two-label sequence classifier over every article in corpus_dict.

    Returns the argmax label per article, in the iteration order of
    corpus_dict. Callers are expected to handle the empty corpus themselves.
    """

    # Instantiate tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Featurize data
    sep = find_sep_token(tokenizer)

    corpus = [
        featurize_text(
            headline=art['headline'],
            byline=art['byline'],
            text=art['article'],
            sep=sep)
        for art in corpus_dict.values()
    ]

    dataset = Dataset.from_dict({'corpus': corpus})

    # Tokenize datasets
    def tokenize_function(batch):
        return tokenizer(batch['corpus'], padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Predict
    model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2)

    inference_args = TrainingArguments(output_dir="save", per_device_eval_batch_size=512)

    trainer = Trainer(model=model, args=inference_args)

    preds = trainer.predict(tokenized_dataset)

    return np.argmax(preds.predictions, axis=-1)


def remove_non_editorials(corpus_dict):
    """
    Keep only the articles that the editorial classifier labels 1.

    Input/Output: dictionary {article id: article dict}; a new dictionary is
    returned. An empty corpus is returned as {} without loading any model.
    """

    print("Removing non editorials ...")

    if not corpus_dict:
        return {}

    predictions = _predict_labels(
        corpus_dict,
        tokenizer_name='roberta-large',
        model_path='/mnt/data01/editorials/trained_models/2022-08-22_11-42-25/checkpoint-270')

    # Subset to positives only
    return {
        art_id: art
        for (art_id, art), label in zip(corpus_dict.items(), predictions)
        if label == 1
    }


def remove_non_vietnam(corpus_dict):
    """
    Keep only the articles that the Vietnam topic classifier labels 1.

    Input/Output: dictionary {article id: article dict}; a new dictionary is
    returned. An empty corpus is returned as {} without loading any model.
    """

    print("Removing non vietnam ...")

    if not corpus_dict:
        return {}

    predictions = _predict_labels(
        corpus_dict,
        tokenizer_name='distilroberta-base',
        model_path='/mnt/data01/topic/finetuning/trained_models/Vietnam/2022-07-25_17-28-39/checkpoint-320')

    # Subset to positives only
    return {
        art_id: art
        for (art_id, art), label in zip(corpus_dict.items(), predictions)
        if label == 1
    }


def remove_weather_articles(corpus_dict):
    """
    Drop the articles that the weather classifier labels 1, plus a short list
    of manually excluded ids.

    Input/Output: dictionary {article id: article dict}; a new dictionary is
    returned. An empty corpus is returned as {} without loading any model.
    """

    # Todo: eventually to put in initial text cleaning (after faro) and remove from here

    print("Removing weather articles ...")

    if not corpus_dict:
        return {}

    predictions = _predict_labels(
        corpus_dict,
        tokenizer_name='distilroberta-base',
        model_path='/mnt/data01/abhishek_topic_class/weather_classification/2022-09-08_11-43-29/checkpoint-80/')

    # A couple of articles that were not labelled as thought to be about weather, but actually
    # turn out to be horoscopes so the classifier doesn't pull them out
    manual_exclusions = {
        '1_56989018-kingsport-times-Jun-20-1955-p-1.jpg',
        '16_156139745-rocky-mount-evening-telegram-Jun-20-1955-p-1.jpg',
        '23_246625881-winchester-evening-star-Jun-20-1955-p-1.jpg',
        '3_8689940-waukesha-daily-freeman-Jun-20-1955-p-1.jpg',
        '10_300337714-seymour-daily-tribune-Jun-20-1955-p-1.jpg',
        '3_51983333-albuquerque-tribune-Jun-20-1955-p-1.jpg',
    }

    # Subset to negatives only
    return {
        art_id: art
        for (art_id, art), label in zip(corpus_dict.items(), predictions)
        if label == 0 and art_id not in manual_exclusions
    }


def detect_bylines(text, model_pipe, tokenizer):
    """
    Split text into the words flagged as bylines and everything else.

    Input:
        text: string to analyse.
        model_pipe: callable returning, for text, a list of dicts that each
            carry an 'index' key. Indices are compared against token position
            + 1 (presumably because the pipeline counts a leading special
            token - confirm against the model in use).
        tokenizer: object whose tokenize(text) returns the list of tokens.

    Output: tuple (bylines, remainder). bylines is the detected bylines joined
    by single spaces; remainder is the unflagged words, each preceded by a
    space (so a non-empty remainder starts with " ").
    """

    # get tokens detected as bylines (a set, so the per-token lookups are O(1))
    flagged = {pred['index'] for pred in model_pipe(text)}
    tokens = tokenizer.tokenize(text)

    # store text/tokens for use in byline detection
    all_bylines = []
    byline_tokens = set()
    cur_byline = ""
    cur_word = ""
    byline = False

    # iterate through all tokens
    for index, token in enumerate(tokens):
        has_alnum = any(ch.isalnum() for ch in token)

        # keep track of tokens detected as bylines
        if (index + 1) in flagged:
            byline_tokens.add(index)
            # a flagged token that is not punctuation starts/continues a byline
            if token not in string.punctuation:
                byline = True

        # Sub-word piece: strip the markers and keep accumulating the word.
        # NOTE(review): the alphanumeric check only applies to "@@" tokens;
        # that is how the unparenthesised `or`/`and` has always evaluated.
        # NOTE(review): every "#" and "@" is stripped, not just the markers.
        if "##" in token or ("@@" in token and has_alnum):
            cur_word = cur_word + re.sub("[#@]", "", token)
            if byline:
                byline_tokens.add(index)
        else:
            # otherwise, this token completes the current word
            if has_alnum:
                cur_word = cur_word + token
            # if this is a byline, add it to our current byline
            if byline and any(ch.isalnum() for ch in cur_word):
                cur_byline = cur_byline + " " + cur_word
            # otherwise, the previous byline is done, and we add it to our list
            elif not byline and cur_byline != "":
                all_bylines.append(cur_byline.strip())
                cur_byline = ""
            cur_word = ""
            byline = False

    # add the final byline to our list
    if cur_byline != "":
        all_bylines.append(cur_byline.strip())

    # finally, get the text that remains after removing all the bylines
    remaining_words = []
    cur_word = ""
    for index, token in enumerate(tokens):
        if index in byline_tokens:
            continue
        if "##" in token or "@@" in token:
            # Pieces are glued onto the NEXT complete token; pieces left over
            # at the very end of the text are dropped.
            cur_word = cur_word + re.sub("[#@]", "", token)
        else:
            remaining_words.append(cur_word + token)
            cur_word = ""

    remainder = "".join(" " + word for word in remaining_words)

    # return (bylines, rest of text) as tuple
    return " ".join(all_bylines), remainder


def update_bylines(corpus):
    """
    Input: a dictionary of articles:
    {art_id_1: {
        'headline': headline text,
        'byline': byline text,
        'article': article text,
        ... other keys are left untouched
        },
    art_id_2: ...
    }

    Output: the same dictionary (modified in place and returned), with 'byline'
    replaced by the bylines detected in "byline + ' ' + article" and 'article'
    replaced by the remaining text. 'headline' is not modified. An empty corpus
    is returned as-is without loading the model.
    """

    # Todo: eventually to put in initial text cleaning (before faro) and remove from here

    print("Updating bylines ...")

    if not corpus:
        return corpus

    # load the byline detection model
    model_path = "/mnt/data01/wire_clusters/byline_detection"
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model_pipe = pipeline('ner', model=model, tokenizer=tokenizer, device=0)

    # modify bylines, article in-place; tqdm wraps the loop that does the work
    # so the progress bar tracks the model calls
    for article in tqdm(corpus.values()):
        full_text = article['byline'] + " " + article['article']
        bylines, rest = detect_bylines(full_text, model_pipe, tokenizer)

        article['byline'] = bylines
        article['article'] = rest

    return corpus


def find_sep_token(tokenizer):

    """
    Builds the separator string for the given tokenizer: the eos token (when
    the tokenizer defines one) followed by the sep token, space-delimited and
    padded with one space on each side.
    """

    special_tokens = tokenizer.special_tokens_map

    pieces = [special_tokens['eos_token']] if 'eos_token' in special_tokens else []
    pieces.append(special_tokens['sep_token'])

    return " " + " ".join(pieces) + " "


def featurize_text(byline, text, sep, headline=None):
    """
    Concatenates an article's fields into one classifier input string:
    headline + byline + sep + text, or byline + sep + text when there is no
    headline (None or empty string).

    Missing values are replaced by a single space. For byline and text,
    "missing" means None, the string "nan" or a float NaN; for headline it
    means the string "nan" or a float NaN (None means "no headline").
    """

    def is_nan(value):
        # NaN is the only float that differs from itself.
        return value == "nan" or (isinstance(value, float) and value != value)

    if headline is not None and is_nan(headline):
        headline = " "
    if byline is None or is_nan(byline):
        byline = " "
    if text is None or is_nan(text):
        text = " "

    if headline:
        return headline + byline + sep + text

    return byline + sep + text


def open_realnews():
    """
    Loads the "text" field of every record in the C4 realnewslike shards (the
    512 train shards plus every file matching c4-validation**).

    Each shard is a gzipped file with one JSON object per line. Records are
    decoded with json.loads, so JSON escape sequences in the text are
    resolved rather than kept as raw backslash sequences.

    Output: list of article texts.
    """

    start = datetime.now()
    file_list = [f'/mnt/data01/wire_clusters/llm_data/c4/realnewslike/c4-train.{i:05d}-of-00512.json.gz' for i in range(512)]
    file_list.extend(glob('/mnt/data01/wire_clusters/llm_data/c4/realnewslike/c4-validation**'))

    corpus = []
    print("Loading data ...")
    for file in tqdm(file_list):

        # Stream line by line instead of holding the whole decompressed shard
        # (and a split copy of it) in memory. newline='\n' keeps the original
        # "split on \n only" record boundaries.
        with gzip.open(file, 'rt', encoding='utf-8', newline='\n') as fin:
            for line in fin:
                if line.strip():
                    corpus.append(json.loads(line)["text"])

    print(len(corpus), "files in corpus")
    print("Time taken:", datetime.now() - start)

    return corpus


def open_c4_by_url(pattern="patents.google.com", name="patents"):
    """
    Collects the text of every C4 (en) record whose url contains pattern.

    For each of the 'train' and 'validation' splits the matching texts are
    pickled to /mnt/data01/wire_clusters/llm_data/c4/{name}_{split}.pkl.

    Input:
        pattern: substring that a record's "url" must contain.
        name: prefix used for the pickle file names.

    Output: list with the matching texts of both splits (train first).
    """

    start = datetime.now()

    full_corpus = []
    for split in ['train', 'validation']:
        file_list = glob(f'/mnt/data01/wire_clusters/llm_data/c4/en/c4-{split}**.json.gz')

        corpus = []
        print("Loading data ...")
        for file in tqdm(file_list):

            # One JSON object per line; stream rather than read the whole shard.
            with gzip.open(file, 'rt', encoding='utf-8', newline='\n') as fin:
                for line in fin:
                    if not line.strip():
                        continue

                    record = json.loads(line)
                    if pattern in record["url"]:
                        corpus.append(record["text"])

        print(len(corpus), f"files in {split} set")
        print("Time taken:", datetime.now() - start)

        with open(f"/mnt/data01/wire_clusters/llm_data/c4/{name}_{split}.pkl", "wb") as f:
            pickle.dump(corpus, f, protocol=4)

        full_corpus.extend(corpus)

    # Return both splits, not just the last one processed.
    return full_corpus


def pull_loads_of_newspaper_data(target_size=10000000, start_year=1950, end_year=None):
    """
    Pulls front-page articles from daily papers one calendar year at a time,
    until at least target_size articles have been collected or end_year has
    been processed.

    Each year's articles are also written to
    /mnt/data01/wire_clusters/super_large_inf_data_{year}.json.

    Input:
        target_size: stop once this many articles have been pulled.
        start_year: first year to pull.
        end_year: last year to pull (inclusive). Defaults to the current year,
            which keeps the loop from running forever when the data runs out
            before target_size is reached.

    Output: list of article texts (the 'article' field of every pulled article).
    """

    if end_year is None:
        end_year = datetime.now().year

    data_count = 0
    year = start_year

    corpus = []

    while data_count < target_size and year <= end_year:

        print(f"********************{year}********************")
        year_data = pull_data_for_dates(
            path_pattern='/mnt/data01/rule_based_outputs/**/',  # '/mnt/data01/editorials/updated_vietnam_topic_*, for topic
            start_date=f"Jan-01-{year}",
            end_date=f"Dec-31-{year}",
            fps_only=True,
            sort_bylines=False,
            remove_weather=True,
            remove_merged=False,
            dailies_only=True
        )

        print(f"Articles pulled: {len(year_data)}")

        with open(f'/mnt/data01/wire_clusters/super_large_inf_data_{year}.json', 'w') as of:
            json.dump(year_data, of, indent=4)

        data_count += len(year_data)
        year += 1

        corpus.extend(art['article'] for art in year_data.values())

    return corpus
