import os
import sys
import pickle
import numpy as np
import torch
import tqdm
import argparse

from transformers import BertTokenizer
from utils.character_cnn import CharacterIndexer
from modeling.character_bert import CharacterBertModel

# Run on the GPU when PyTorch reports one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): both paths below look like placeholders -- point them at the
# real tokenizer / CharacterBERT checkpoint directories before running.
tokenizer = BertTokenizer.from_pretrained('path/to/bert-tokenizer')
model = CharacterBertModel.from_pretrained(
    'path/to/character-bert/pretrained-models/general_character_bert'
)
model.to(device)

# Converts batches of token strings into the padded id tensor fed to `model`.
indexer = CharacterIndexer()

def read_lines(file_path):
    """Read a text file and return its lines with surrounding whitespace removed.

    The file is decoded explicitly as UTF-8 so that the result does not depend
    on the platform's default locale encoding.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one stripped string per line of the file. Blank lines are
        kept as empty strings.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterating the file object yields lines directly; no need to build
        # an intermediate list with readlines() first.
        return [line.strip() for line in f]

def extract_frequency_table_from_lines(lines):
    """Count whitespace-separated words and return them ordered by frequency.

    Args:
        lines: Iterable of strings; each one is split on whitespace.

    Returns:
        A dict mapping each word to the number of times it occurs across all
        lines. Entries are inserted from most to least frequent; words with
        equal counts keep the order in which they were first seen.
    """
    frequency_table = {}
    for line in lines:
        # str.split() without arguments already ignores leading/trailing
        # whitespace, so no separate strip() is needed.
        for word in line.split():
            frequency_table[word] = frequency_table.get(word, 0) + 1
    # sorted() is stable, so ties retain first-seen order; rebuilding a dict
    # from the sorted pairs gives a table whose iteration order is by count.
    sorted_frequency_table = sorted(
        frequency_table.items(), key=lambda item: item[1], reverse=True
    )
    return dict(sorted_frequency_table)

def cosine_similarity(vector1, vector2):
    """Return the cosine of the angle between two vectors.

    The value is the dot product of the inputs divided by the product of
    their Euclidean norms. Zero-norm inputs are not special-cased.
    """
    inner_product = np.dot(vector1, vector2)
    norm_product = np.linalg.norm(vector1) * np.linalg.norm(vector2)
    return inner_product / norm_product

def get_start_end_idx(toked_line1, toked_line2):
    """Map each whitespace-split word to its token span in the tokenized line.

    Args:
        toked_line1: Words of the line (callers pass ``line.split()``).
        toked_line2: Token list for the same line; a leading "[CLS]" is
            skipped and "[SEP]" ends the scan.

    Returns:
        A list of ``(start, end)`` index pairs into ``toked_line2``, one per
        processed word, with ``end`` inclusive. A word that the basic
        tokenizer turns into zero tokens gets ``end == start - 1``. The list
        is shorter than ``toked_line1`` when "[SEP]" is reached at the last
        word.

    Raises:
        AssertionError: If "[SEP]" is reached before the last word.
    """
    spans = []
    cursor = 0
    final_word_pos = len(toked_line1) - 1
    for word_pos, raw_word in enumerate(toked_line1):
        if toked_line2[cursor] == "[CLS]":
            cursor += 1
        if toked_line2[cursor] == "[SEP]":
            assert word_pos == final_word_pos
            break
        # Each word occupies as many slots as the basic tokenizer splits it into.
        piece_count = len(tokenizer.basic_tokenizer.tokenize(raw_word))
        spans.append((cursor, cursor + piece_count - 1))
        cursor += piece_count
    return spans

def get_avg_embedding(embeddings):
    """Average the collected vectors of every word.

    Args:
        embeddings: Dict mapping a word to a list of equally-shaped vectors.

    Returns:
        Dict mapping each word to the element-wise mean of its vectors.

    Raises:
        AssertionError: If any averaged vector contains NaN.
    """
    averaged = {}
    for token in embeddings:
        mean_vector = np.mean(embeddings[token], axis=0)
        averaged[token] = mean_vector
        assert not np.isnan(mean_vector).any()
    return averaged

def get_embedding_for_tokenized_line(tokenized_line_with_cls_sep):
    """Run the module-level model on one tokenized line.

    Args:
        tokenized_line_with_cls_sep: List of token strings; callers in this
            file pass tokens wrapped in "[CLS]" ... "[SEP]".

    Returns:
        The entry for this line taken from the model's first output
        (presumably a ``(num_tokens, hidden_dim)`` tensor -- TODO confirm
        against CharacterBertModel). Computed without gradient tracking.
    """
    batch = [tokenized_line_with_cls_sep]
    batch_ids = indexer.as_padded_tensor(batch).to(device)
    # Inference only: disabling autograd avoids building and retaining a
    # computation graph for every line that is embedded.
    with torch.no_grad():
        embeddings_for_batch, _ = model(batch_ids)
    return embeddings_for_batch[0]

def get_embeddings_from_lines(lines, max_words=300, max_samples_per_word=20):
    """Collect several contextual embeddings for every word in ``lines``.

    Each line is split on whitespace, embedded as a whole, and the token
    vectors belonging to each word are mean-pooled into one vector.

    Args:
        lines: Iterable of text lines.
        max_words: Lines with more whitespace-separated words than this are
            skipped.
        max_samples_per_word: Upper bound on the number of vectors kept per
            word; further occurrences are ignored.

    Returns:
        Dict mapping each word to a non-empty list of NumPy vectors. Words
        whose pooled vector contains NaN (e.g. an empty token span) are
        printed and not stored.
    """
    print ("Generating embeddings for lines...")
    embeddings = {}
    for line in tqdm.tqdm(lines):
        tokenized_line_by_split = line.split()
        # Decide whether to skip before doing any tokenization or model work.
        if not tokenized_line_by_split or len(tokenized_line_by_split) > max_words:
            continue

        tokenized_line = tokenizer.basic_tokenizer.tokenize(line)
        tokenized_line_with_cls_sep = ['[CLS]', *tokenized_line, '[SEP]']
        embeddings_for_line = get_embedding_for_tokenized_line(tokenized_line_with_cls_sep)

        idx = get_start_end_idx(tokenized_line_by_split, tokenized_line_with_cls_sep)
        # get_start_end_idx may return fewer spans than there are words (it
        # stops at "[SEP]"); zip() simply ignores the words left without a
        # span, so no catch-all exception handler is needed.
        for word, (start_idx, end_idx) in zip(tokenized_line_by_split, idx):
            word_embedding = torch.mean(embeddings_for_line[start_idx:end_idx+1], dim=0)
            if torch.isnan(word_embedding).any():
                # Check before touching the dict so a word never ends up
                # with an empty sample list.
                print (word)
                continue
            word_embedding = word_embedding.detach().cpu().numpy()

            samples = embeddings.setdefault(word, [])
            if len(samples) < max_samples_per_word:
                samples.append(word_embedding)
    return embeddings

def get_embeddings_from_lines_one(lines, max_words=300):
    """Collect a single contextual embedding per word (first occurrence wins).

    Lines are only run through the model when they contain at least one word
    that has no embedding yet.

    Args:
        lines: Iterable of text lines.
        max_words: Lines with more whitespace-separated words than this are
            skipped.

    Returns:
        Dict mapping each word to one NumPy vector: the mean of the token
        vectors of its first usable occurrence. Words whose pooled vector
        contains NaN are reported on stdout and not stored.
    """
    print ("Generating embeddings for lines...")
    embeddings = {}
    for line in tqdm.tqdm(lines):
        line = line.strip()
        tokenized_line_by_split = line.split()
        if not tokenized_line_by_split or len(tokenized_line_by_split) > max_words:
            continue

        # Skip the (expensive) forward pass when the line has nothing new.
        if all(word in embeddings for word in tokenized_line_by_split):
            continue

        tokenized_line = tokenizer.basic_tokenizer.tokenize(line)
        tokenized_line_with_cls_sep = ['[CLS]', *tokenized_line, '[SEP]']

        embeddings_for_line = get_embedding_for_tokenized_line(tokenized_line_with_cls_sep)

        idx = get_start_end_idx(tokenized_line_by_split, tokenized_line_with_cls_sep)
        # get_start_end_idx may return fewer spans than there are words (it
        # stops at "[SEP]"); zip() avoids indexing past the end of `idx`.
        for word, (start_idx, end_idx) in zip(tokenized_line_by_split, idx):
            if word in embeddings:
                continue
            word_embedding = torch.mean(embeddings_for_line[start_idx:end_idx+1], dim=0)
            if torch.isnan(word_embedding).any():
                print ("NAN", word)
                continue
            embeddings[word] = word_embedding.detach().cpu().numpy()
    return embeddings

def save_embeddings(embedding_dict, output_file):
    """Pickle ``embedding_dict`` into ``output_file``, overwriting the file.

    Args:
        embedding_dict: Object to serialize (callers pass a word -> vector dict).
        output_file: Destination path; opened in binary write mode.
    """
    with open(output_file, 'wb') as sink:
        pickle.Pickler(sink).dump(embedding_dict)

# Command-line driver: read the input text, build word embeddings, pickle them.
parser = argparse.ArgumentParser(
    description='Generate per-word embeddings for a text file and pickle them.'
)
# Both paths are mandatory so that a missing --output_file is rejected up
# front instead of failing in open() after all embeddings were computed.
parser.add_argument('--input_file', type=str, required=True,
                    help='text file to embed, one sentence per line')
parser.add_argument('--output_file', type=str, required=True,
                    help='path of the pickle file to write')
parser.add_argument('--avg', type=int, default=1,
                    help='1 (default): average several occurrences per word; '
                         'any other value: keep one embedding per word')
args = parser.parse_args()
input_file = args.input_file
output_file = args.output_file
avg = args.avg

lines = read_lines(input_file)
if avg == 1:
    # Gather multiple contextual vectors per word, then average them.
    embeddings = get_embeddings_from_lines(lines)
    res = get_avg_embedding(embeddings)
else:
    # Keep only one contextual vector per word.
    res = get_embeddings_from_lines_one(lines)
save_embeddings(res, output_file)
