import os
import sys
import torch.multiprocessing as mp
os.chdir("/home/[USER]/workshop/wikihow")
sys.path.append("/home/[USER]/workshop/wikihow")
import json
import time
from nltk.tokenize import sent_tokenize
from external.autopunct.correct import correct
import spacy
from youtube_transcript_api import YouTubeTranscriptApi
import random

# Module-level switches.
# NOTE(review): neither flag is referenced by any function visible in this
# file; confirm nothing external reads them before removing.
DEBUG, FROM_SCRATCH = False, False

def download(d, task_queue, punct_model, caps_model):
    """Worker loop: fetch, punctuate and sentence-split YouTube transcripts.

    Takes video ids from ``task_queue`` until it is empty. For each id:
    1. Fetch the transcript with ``YouTubeTranscriptApi.get_transcript``.
    2. Join the ``'text'`` fields of the returned entries with spaces.
    3. Run the result through autopunct's ``correct``.
    4. Split it into sentences with ``sent_tokenize``.
    The list of sentences is stored in ``d[video_id]``. A failing video is
    counted and reported; it does not stop the worker.

    Args:
        d: shared mapping (a manager dict in this file) that receives results.
        task_queue: shared queue of video ids. It must be fully populated
            before workers start, because an empty queue ends the loop.
        punct_model: spaCy model forwarded to ``correct``.
        caps_model: spaCy model forwarded to ``correct``.
    """
    import queue  # for queue.Empty, raised by get_nowait() on an empty queue

    fail = 0
    while True:
        # Checking qsize() and then calling a blocking get() races with the
        # other workers. When two workers both see one remaining item, the
        # loser blocks in get() forever and the parent's join() hangs.
        # get_nowait() does the check and the take in one step.
        try:
            task = task_queue.get_nowait()
        except queue.Empty:
            break
        try:
            cp = YouTubeTranscriptApi.get_transcript(task)
            cp = ' '.join([x['text'] for x in cp])
            cp = correct(cp, punct_model=punct_model, caps_model=caps_model)
            d[task] = sent_tokenize(cp)
        except Exception as e:
            # Best effort: many videos have no transcript. Record the failure,
            # but let KeyboardInterrupt/SystemExit propagate (the original bare
            # `except:` swallowed them too).
            fail += 1
            print(f"failed on {task}: {e!r}", file=sys.stderr)
        print(f"there are {task_queue.qsize()} left, {fail} errors already")

    print(fail)


def extract_video_caption(video_list, num_workers=12, start_delay=1):
    """Download and punctuate YouTube transcripts for ``video_list`` in parallel.

    Every video id is placed on a manager-backed queue. ``num_workers``
    processes then run ``download`` against that queue and a shared result
    dict.

    Args:
        video_list: iterable of YouTube video ids.
        num_workers: number of worker processes to start (default 12).
        start_delay: seconds to sleep after starting each worker (default 1),
            which staggers the workers' first requests.

    Returns:
        A plain dict mapping video id -> list of sentences. It only contains
        videos whose processing did not raise inside ``download``.
    """
    # The context manager shuts the manager's server process down as soon as
    # we are done, instead of waiting for garbage collection or interpreter exit.
    with mp.Manager() as manager:
        task_queue = manager.Queue()
        for vid in video_list:
            task_queue.put(vid)
        print(f"number of videos: {task_queue.qsize()}")

        # autopunct spaCy models, loaded once here and passed to every worker.
        punct_model = spacy.load("./external/autopunct/punct-model")
        caps_model = spacy.load("./external/autopunct/caps-model")

        res = manager.dict()
        ps = []
        for _ in range(num_workers):
            p = mp.Process(target=download, args=(res, task_queue, punct_model, caps_model))
            p.start()
            time.sleep(start_delay)
            ps.append(p)

        for p in ps:
            p.join()

        # Copy out of the proxy before the manager shuts down.
        return res.copy()

def split_2k_video(num_dev=10, seed=None):
    """Build dev/test splits for the video2k tasks that are unseen in video1k.

    Selects the tasks in ``video2k/all_task_video_map.json`` that are not keys
    of the ``train`` split in ``video1k/split_task_video_map.json``. For each
    such task:
    - shuffle its video entries (in place);
    - send the first ``num_dev`` to ``dev`` and the rest to ``test``.
    Each entry's first element is used as the video id.

    The resulting split, with an empty ``train``, is written to
    ``video2k/split_task_video_map.json``. Captions for the chosen dev/test
    videos are filtered from ``all_task_video_caption.punct.json`` and written
    to ``video2k.dev.caption.json`` / ``video2k.test.caption.json``.

    Args:
        num_dev: number of videos per unseen task that go to ``dev``
            (default 10).
        seed: optional seed for a private ``random.Random``. ``None`` keeps
            the original behaviour of using the global ``random`` module.
    """
    with open("./data/howto100m/video2k/all_task_video_map.json", "r", encoding="utf-8") as f:
        d = json.load(f)
    with open("./data/howto100m/video1k/split_task_video_map.json", "r", encoding="utf-8") as f:
        seen_task = set(json.load(f)['train'])
    # Keep file order rather than set order. Set iteration order of strings
    # changes with PYTHONHASHSEED, which made the split irreproducible even
    # under a fixed seed.
    unseen_task = [x for x in d if x not in seen_task]
    print(f"all: {len(d)}, seen: {len(seen_task)}, unseen: {len(unseen_task)}")

    rng = random if seed is None else random.Random(seed)
    split = {'train': {}, 'dev': {}, 'test': {}}
    test_v = set()
    dev_v = set()
    for t in unseen_task:
        v = d[t]
        rng.shuffle(v)
        split['dev'][t] = [x[0] for x in v[:num_dev]]
        split['test'][t] = [x[0] for x in v[num_dev:]]
        # In-place update; `s = s.union(...)` rebuilt the set every iteration.
        dev_v.update(split['dev'][t])
        test_v.update(split['test'][t])

    print(len(dev_v), len(test_v))
    with open("./data/howto100m/video2k/split_task_video_map.json", "w", encoding="utf-8") as f:
        json.dump(split, f, indent=2)

    with open("./data/howto100m/video2k/all_task_video_caption.punct.json", "r", encoding="utf-8") as f:
        cap_map = json.load(f)

    dev_cap = {k: v for k, v in cap_map.items() if k in dev_v}
    test_cap = {k: v for k, v in cap_map.items() if k in test_v}
    print(len(dev_cap), len(test_cap))

    with open("./data/howto100m/video2k/video2k.dev.caption.json", "w", encoding="utf-8") as f:
        json.dump(dev_cap, f, indent=2)

    with open("./data/howto100m/video2k/video2k.test.caption.json", "w", encoding="utf-8") as f:
        json.dump(test_cap, f, indent=2)



def extract_split():
    """Fetch captions for every video in the video2k split and save them per split.

    Reads ``video2k/split_task_video_map.json``, which is expected to look
    like ``{split: {task: [video_id, ...]}}`` for ``train``/``dev``/``test``.
    Captions are fetched once for the union of all video ids via
    ``extract_video_caption``. Then ``video2k.<split>.caption.json`` is
    written for each split, holding only that split's videos that were
    fetched successfully.
    """
    splits = ("train", "dev", "test")
    with open("./data/howto100m/video2k/split_task_video_map.json", encoding="utf-8") as f:
        d = json.load(f)

    split_v = {}
    for split in splits:
        vids = set()
        for v in d[split].values():
            vids.update(v)
        split_v[split] = vids

    print(f"train: {len(split_v['train'])}, dev: {len(split_v['dev'])}, test: {len(split_v['test'])}")

    # Deduplicate across splits so a video listed in several splits is
    # queued (and downloaded) only once.
    res = extract_video_caption(list(set().union(*split_v.values())))

    for split in splits:
        with open(f"./data/howto100m/video2k/video2k.{split}.caption.json", "w", encoding="utf-8") as f:
            _res = {k: v for k, v in res.items() if k in split_v[split]}
            print(len(_res))
            json.dump(_res, f, indent=2)

def add_punct(d, task_queue, punct_model, caps_model):
    """Worker loop: punctuate and sentence-split already-downloaded captions.

    Takes tasks of the form ``{'vid': ..., 'caption': ...}`` from
    ``task_queue`` until it is empty. Each caption is run through autopunct's
    ``correct`` and split with ``sent_tokenize``, and the resulting list is
    stored in ``d[task['vid']]``. A failing task is counted and reported; it
    does not stop the worker.

    Args:
        d: shared mapping (a manager dict in this file) that receives results.
        task_queue: shared queue of task dicts. It must be fully populated
            before workers start, because an empty queue ends the loop.
        punct_model: spaCy model forwarded to ``correct``.
        caps_model: spaCy model forwarded to ``correct``.
    """
    import queue  # for queue.Empty, raised by get_nowait() on an empty queue

    fail = 0
    while True:
        # Checking qsize() and then calling a blocking get() races with the
        # other workers and can leave one blocked forever. get_nowait() does
        # the check and the take in one step.
        try:
            task = task_queue.get_nowait()
        except queue.Empty:
            break
        try:
            cp = correct(task['caption'], punct_model=punct_model, caps_model=caps_model)
            d[task['vid']] = sent_tokenize(cp)
        except Exception as e:
            # Best effort per caption, but no longer swallows
            # KeyboardInterrupt/SystemExit like the original bare `except:`.
            fail += 1
            print(f"failed on {task.get('vid')}: {e!r}", file=sys.stderr)
        print(f"there are {task_queue.qsize()} left, {fail} errors already")
    print(fail)

def punct_video_caption(caption_map, num_workers=12, start_delay=1):
    """Punctuate and sentence-split a mapping of captions in parallel.

    Each ``(vid, caption)`` pair is queued as ``{'vid': vid, 'caption':
    caption}``. ``num_workers`` processes then run ``add_punct`` against that
    queue and a shared result dict.

    Args:
        caption_map: mapping of video id -> caption, passed through to
            autopunct's ``correct``.
        num_workers: number of worker processes to start (default 12).
        start_delay: seconds to sleep after starting each worker (default 1).

    Returns:
        A plain dict mapping video id -> list of sentences. It only contains
        captions whose processing did not raise inside ``add_punct``.
    """
    # The context manager shuts the manager's server process down as soon as
    # we are done, instead of waiting for garbage collection or interpreter exit.
    with mp.Manager() as manager:
        task_queue = manager.Queue()
        for vid, cp in caption_map.items():
            task_queue.put({'vid': vid, 'caption': cp})
        print(f"number of captions: {task_queue.qsize()}")

        # autopunct spaCy models, loaded once here and passed to every worker.
        punct_model = spacy.load("./external/autopunct/punct-model")
        caps_model = spacy.load("./external/autopunct/caps-model")

        res = manager.dict()
        ps = []
        for _ in range(num_workers):
            p = mp.Process(target=add_punct, args=(res, task_queue, punct_model, caps_model))
            p.start()
            time.sleep(start_delay)
            ps.append(p)

        for p in ps:
            p.join()

        # Copy out of the proxy before the manager shuts down.
        return res.copy()

if __name__ == "__main__":
    # The pipeline is run by hand one stage at a time by editing this block:
    #   - fetch captions (commented out)       -> all_task_video_caption.json
    #   - punctuate them (below sys.exit)      -> all_task_video_caption.punct.json
    #   - split_2k_video() (currently enabled), which reads the .punct.json file
    split_2k_video()
    # Use sys.exit rather than the site-module exit() helper. exit() is meant
    # for the interactive shell and is missing when Python runs with -S.
    sys.exit(0)

    # Stage: fetch captions for every video in the task map.
    # with open("./data/howto100m/video2k/all_task_video_map.json", "r") as f:
    #     d = json.load(f)
    # video_list = []
    # for k, v in d.items():
    #     video_list += [vv[0] for vv in v]
    # video_list = list(set(video_list))
    # res = extract_video_caption(list(video_list))
    # print(len(res))
    # with open("./data/howto100m/video2k/all_task_video_caption.json", "w+") as f:
    #     json.dump(res, f, indent=2)

    # Stage: punctuate the fetched captions. Unreachable while the
    # sys.exit(0) above is in place.
    with open("./data/howto100m/video2k/all_task_video_caption.json", "r", encoding="utf-8") as f:
        d = json.load(f)
    res = punct_video_caption(caption_map=d)
    print(len(res))
    with open("./data/howto100m/video2k/all_task_video_caption.punct.json", "w", encoding="utf-8") as f:
        json.dump(res, f, indent=2)
