import torch
import nltk
import string
import re

def normalize_answer(s):
    """Canonicalize an answer string so that trivially different answers compare equal.

    The transformations run in this order:
      1. underscores become spaces;
      2. the text is lowercased;
      3. every ASCII punctuation character, plus the quote-like characters
         ‘ ’ ´ and `, is replaced by a space;
      4. the standalone words "a", "an" and "the" are replaced by a space;
      5. whitespace runs are collapsed to a single space and the ends trimmed.
    """
    punctuation = set(string.punctuation) | {u"‘", u"’", u"´", u"`"}

    text = s.replace('_', ' ')
    text = text.lower()
    # Substitute a space rather than deleting, so "foo,bar" splits into two words.
    text = ''.join(' ' if ch in punctuation else ch for ch in text)
    # \b keeps words like "theater" or "another" intact.
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    collapsed = ' '.join(text.split())
    return collapsed.strip()

def customized_tokenize(tokenizer, words_lst):
    """Tokenize a pre-split list of words and record where each word's subtokens land.

    Each word is tokenized on its own (``add_special_tokens=False``). The
    resulting ids are concatenated after a single leading
    ``tokenizer.bos_token_id``.

    Args:
        tokenizer: callable that takes a word and returns a mapping with an
            ``"input_ids"`` list; it must also expose ``bos_token_id``.
        words_lst: sequence of words to tokenize.

    Returns:
        A tuple ``(ids, spans)``:
        - ``ids`` is a ``torch.int`` tensor of shape ``(1, num_subtokens)``
          whose first entry is the BOS id.
        - ``spans`` maps each word index to ``(start, end)``, the inclusive
          positions of that word's subtokens in ``ids``. A word that produces
          no subtokens gets ``end == start - 1``.
    """
    ids = [tokenizer.bos_token_id]
    spans = {}
    for position, word in enumerate(words_lst):
        pieces = tokenizer(word, add_special_tokens=False)["input_ids"]
        first = len(ids)
        ids += pieces
        spans[position] = (first, len(ids) - 1)

    batch = torch.tensor(ids, dtype=torch.int)
    return batch.unsqueeze(0), spans

# Built once at import time. The original code rebuilt a ~180-element list on
# every call and scanned it linearly. The entries below are kept verbatim; the
# frozenset drops the duplicates ("a" and "'" each appear twice).
_STOPWORDS = frozenset([
    "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers",
    "herself", "it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this", "that", "these", "those", "am", "is",
    "are", "was", "were", "be", "been", "being", "have", "has", "had", "having", "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because",
    "as", "until", "while", "of", "at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", "after", "above", "below", "to",
    "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any",
    "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can", "will",
    "just", "don", "should", "now", '!', "@", '#', "$", '%', "^", '&', "*", '(', ")", '-', "_", '+', "=", '[', "]", '{', "}", "|", "\'", ';', "\"", "\'", "<", ">", ",",
    ".", "?", "/", "\n", "answer", "query", "q", "a", "document", "documents", "according", "following",
])


def is_stopword(word):
    """Return True if ``word``, compared case-insensitively, is in the stopword set.

    The set contains common English function words, single punctuation
    characters, a newline, and task-specific filler words such as "answer",
    "query" and "document".

    Args:
        word: the token to test; it must support ``.lower()``.

    Returns:
        bool: whether the lowercased token is a stopword.
    """
    return word.lower() in _STOPWORDS
    
def flatten(xss):
    """Concatenate the items of every inner iterable into one flat list.

    Only one level of nesting is removed, e.g. ``[[1, 2], [3]] -> [1, 2, 3]``.
    """
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat

def get_peak_memory(device):
    """Return the CUDA allocator's peak allocated bytes for ``device``, divided by 1,000,000.

    The value is read from
    ``torch.cuda.memory_stats_as_nested_dict(device)['allocated_bytes']['all']['peak']``
    and floor-divided, so the unit is decimal megabytes (not MiB). Returns 0
    when the stats contain no ``'allocated_bytes'`` entry. Presumably this
    happens when the CUDA allocator has not been initialized; TODO confirm
    against the torch version in use.
    """
    nested_stats = torch.cuda.memory_stats_as_nested_dict(device=device)
    if 'allocated_bytes' in nested_stats:
        peak_bytes = nested_stats['allocated_bytes']['all']['peak']
        return peak_bytes // 1_000_000
    return 0

def get_flops(model, input):
    """Return twice the multiply-accumulate count reported by ``profile_macs``.

    ``input`` is converted to a tuple and passed to
    ``profile_macs(model, inputs)``; the result is doubled.

    NOTE(review): ``profile_macs`` is not imported anywhere in this file.
    Presumably it should be ``torchprofile.profile_macs``. TODO add the
    import, otherwise this call raises NameError.
    """
    # `input` shadows the builtin; the name is kept because callers may pass it by keyword.
    example_inputs = tuple(input)
    macs = profile_macs(model, example_inputs)
    return macs * 2