import sys
import json 
import regex
import unicodedata

import re
import string
import argparse

def normalize_answer(s):
    """Normalize an answer string for exact-match comparison.

    The steps are applied in this order: lowercase, drop every ASCII
    punctuation character, replace the standalone words a/an/the with a
    space, then collapse all whitespace runs into single spaces.
    """
    punctuation = set(string.punctuation)
    article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)

    lowered = s.lower()
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    without_articles = article_pattern.sub(' ', without_punct)
    # split() with no argument discards leading/trailing/repeated whitespace.
    return ' '.join(without_articles.split())

def get_tokens(s):
    """Return the whitespace-separated tokens of the normalized form of *s*.

    Falsy input (``None``, empty string) yields an empty list.
    """
    if s:
        return normalize_answer(s).split()
    return []

def compute_exact(a_gold, a_pred):
    """Return 1 when gold and predicted answers match after normalization, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0


class SimpleTokenizer(object):
    """Regex-based tokenizer.

    A token is either a run of letter/number/mark characters or, failing
    that, one single character that is neither a separator nor in the
    Unicode "Other" category.
    """

    # One or more Unicode letters (L), numbers (N) or combining marks (M).
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    # Exactly one character outside the separator (Z) and other (C) categories.
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self):
        """Compile the token pattern; the tokenizer takes no configuration."""
        pattern = '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS)
        flags = regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        self._regexp = regex.compile(pattern, flags=flags)

    def tokenize(self, text, uncased=False):
        """Return the list of token strings found in *text*, in order.

        Args:
            text: String to tokenize.
            uncased: When true, every token is lowercased.
        """
        tokens = []
        for found in self._regexp.finditer(text):
            token = found.group()
            tokens.append(token.lower() if uncased else token)
        return tokens


def read_json(f):
    """Parse and return the JSON document stored at path *f*.

    The file is opened in a ``with`` block so the handle is closed as soon
    as parsing finishes or fails, instead of being left to the garbage
    collector.
    """
    with open(f, 'r') as fr:
        return json.load(fr)

def write_json(filename, data):
    """Serialize *data* as JSON into *filename*, replacing any existing content."""
    with open(filename, 'w') as out_handle:
        json.dump(data, out_handle)

def _normalize(text):
    """Return *text* in Unicode NFD (canonical decomposition) form."""
    decomposed = unicodedata.normalize('NFD', text)
    return decomposed

def match(table_name, dump, dumpfile):
    """Report entries whose retrieved passages never contain a gold answer.

    Args:
        table_name: Path to a JSON list of entries. Each entry is read for
            its ``'answers'`` list and its ``'ctxs'`` list, whose items
            provide a ``'text'`` field.
        dump: When truthy, the entries with no answer-bearing passage are
            written as JSON to ``dumpfile``.
        dumpfile: Output path; only used when ``dump`` is truthy.

    Side effects:
        Prints the number of entries without a matching passage (twice:
        the counter and the length of the index list) followed by the total
        number of entries, and writes the indices of those entries, one per
        line, to ``not100_idx.txt`` in the current directory.
    """
    table = read_json(table_name)
    tokenizer = SimpleTokenizer()

    # Always defined so the dump branch below can never hit an unbound name.
    new_table = []
    # NOTE: despite its name, em_cnt counts entries with *no* match.
    em_cnt = 0
    not100_idx = []
    for idx, t in enumerate(table):
        # any() stops at the first passage that contains an answer.
        em = any(
            has_answer(t['answers'], context['text'], tokenizer)
            for context in t['ctxs']
        )
        if em:
            continue
        em_cnt += 1
        not100_idx.append(idx)
        if dump:
            new_table.append(t)

    if dump:
        write_json(dumpfile, new_table)

    print(em_cnt)
    print(len(not100_idx))
    print(len(table))

    # The context manager guarantees the index file is flushed and closed.
    with open('not100_idx.txt', 'w') as f:
        for l in not100_idx:
            f.write(str(l))
            f.write('\n')
		
def has_answer(answers, text, tokenizer) -> bool:
    """Check if a document contains an answer string.

    Both the document and each answer are NFD-normalized and tokenized
    case-insensitively; an answer matches when its token sequence occurs as
    a contiguous run inside the document's token sequence.

    Args:
        answers: Iterable of candidate answer strings.
        text: Document text to search.
        tokenizer: Object exposing ``tokenize(text, uncased=True)`` that
            returns a list of tokens.

    Returns:
        True if at least one non-empty answer occurs in the document.
    """
    text = _normalize(text)
    text = tokenizer.tokenize(text, uncased=True)

    for answer in answers:
        answer = _normalize(answer)
        answer = tokenizer.tokenize(answer, uncased=True)
        if not answer:
            # An empty token list equals the empty slice text[0:0], which
            # would make every document "contain" this answer.
            continue
        for i in range(0, len(text) - len(answer) + 1):
            if answer == text[i: i + len(answer)]:
                return True
    return False


def cal_incorr_ret_corr_ans(data_file, prediction_file, incorr_index_file):
    """Count exact-match predictions among a selected subset of entries.

    Args:
        data_file: Path to a JSON list of entries with an ``'answers'`` list.
        prediction_file: Path to a text file with one tab-separated line per
            entry; the prediction is the second field.
        incorr_index_file: Path to a text file with one integer entry index
            per line. Blank lines are ignored.

    Side effects:
        Prints the number of selected entries whose prediction exactly
        matches one of the gold answers, then the number of selected
        entries that fall inside the data range.
    """
    with open(incorr_index_file) as f_idx:
        # A set gives O(1) membership tests in the loop below.
        incorr_index = {int(line) for line in f_idx if line.strip()}
    with open(data_file) as f_data:
        data = json.load(f_data)
    with open(prediction_file) as f_pred:
        predictions = [line.strip('\n').split('\t')[1] for line in f_pred]

    assert len(predictions) == len(data)
    corr_ans_cnt = 0
    total = 0
    for i, (entry, pred) in enumerate(zip(data, predictions)):
        if i not in incorr_index:
            continue
        total += 1
        if any(compute_exact(ans, pred) for ans in entry['answers']):
            corr_ans_cnt += 1
    print(corr_ans_cnt)
    print(total)

def find_first_incorrect_passage(data_file):
    """For every entry, keep only the first passage lacking a gold answer.

    Args:
        data_file: Path to a JSON list of entries, each with an
            ``'answers'`` list and a non-empty ``'ctxs'`` list whose items
            provide a ``'text'`` field.

    Returns:
        A list with one dict per input entry. Each dict copies every key of
        the entry except ``'ctxs'``; ``'ctxs'`` is then set to a one-element
        list holding the first passage that does not contain an answer. If
        every passage contains an answer, the ``'ctxs'`` key is absent.

    Raises:
        AssertionError: If the first passage of an entry does not contain
            any of its answers.
    """
    # Close the input file as soon as it has been parsed.
    with open(data_file) as f_data:
        data = json.load(f_data)
    tokenizer = SimpleTokenizer()

    new_data = []
    for d in data:
        new_d = {k: v for k, v in d.items() if k != 'ctxs'}

        # The top-ranked passage is expected to contain an answer.
        assert has_answer(d['answers'], d['ctxs'][0]['text'], tokenizer)
        for ctx in d['ctxs']:
            if not has_answer(d['answers'], ctx['text'], tokenizer):
                new_d['ctxs'] = [ctx]
                break
        new_data.append(new_d)

    return new_data


# evaluation for FiD
def evaluate(data_file, prediction_file, outputfile_path):
    """Compute exact-match accuracy of predictions against gold answers.

    Args:
        data_file: Path to a JSON list of entries with an ``'answers'`` list.
        prediction_file: Path to a text file with one tab-separated line per
            entry; the prediction is the second field.
        outputfile_path: Path that receives one line per prediction,
            ``True`` or ``False``, telling whether it matched a gold answer.

    Side effects:
        Writes ``outputfile_path`` and prints the match count, the total
        and the exact-match ratio (reported as 0.0 when there are no
        predictions).

    Raises:
        AssertionError: If the number of predictions differs from the
            number of data entries.
    """
    with open(data_file) as f_data:
        data = json.load(f_data)
    with open(prediction_file) as f_pred:
        # The trailing newline is left in place; answer normalization
        # collapses whitespace before comparing.
        predictions = [line.split('\t')[1] for line in f_pred]

    em = 0
    total = len(predictions)
    assert len(data) == total

    # The context manager guarantees the per-example results are flushed.
    with open(outputfile_path, 'w') as fw:
        for entry, pred in zip(data, predictions):
            correct = any(compute_exact(ans, pred) for ans in entry['answers'])
            em += int(correct)
            fw.write(str(correct))
            fw.write('\n')

    # Avoid ZeroDivisionError on an empty prediction file.
    em_ratio = float(em) / total if total else 0.0
    print('match/total: %d/%d | em = %f'%(em, total, em_ratio))




if __name__ == '__main__':
    # match(sys.argv[1], True, 'NQ/dev_not100.json')

    # cal_incorr_ret_corr_ans(sys.argv[1], sys.argv[2], sys.argv[3])

    # Usage: <script> <input_json> <output_json>
    new_data = find_first_incorrect_passage(sys.argv[1])
    # Write through a context manager so the output is flushed and closed.
    with open(sys.argv[2], 'w') as fw:
        fw.write(json.dumps(new_data, indent=4))

	# evaluate(sys.argv[1], sys.argv[2], sys.argv[3])



	# table = read_json(sys.argv[1])
	# top1 = [int(l.strip('\n')) for l in open('top1_idx.txt')]
	# top5 = [int(l.strip('\n')) for l in open('top5_not1_idx.txt')]
	# top20 = [int(l.strip('\n')) for l in open('top20_not5_idx.txt')]

	# top20_all = top1+top5+top20
	# not20_idx = []
	# new_table = []
	# for idx, t in enumerate(table):
	# 	if idx not in top20_all:
	# 		new_table.append(t)
	# 		not20_idx.append(idx)
	# print(len(new_table))
	# write_json('NQ/dev_ret_fail.json', new_table)

	# f = open('not20_idx.txt', 'w')
	# for l in not20_idx:
	# 	f.write(str(l))
	# 	f.write('\n')


