import argparse
import random
import string


# Error types whose edits read_m2 drops while parsing an M2 entry.
_IGNORE_TYPE = {"noop", "UNK", "Um"}
# Positions of the fields inside an edit tuple
# ``(start, end, error_type, replacement)`` as built by read_m2.
_EDIT_START = 0
_EDIT_END = 1
_EDIT_TYPE = 2
_EDIT_COR = 3


def perturb_spell_word(word):
    """Return ``word`` with a single random character-level typo.

    A position is drawn uniformly from every index except the last one,
    then one of three operations is drawn uniformly:

    * insert a random lowercase ASCII letter in front of the position,
    * delete the character at the position,
    * swap the character with its right-hand neighbour.

    Returns ``None`` when ``word`` is title-cased or has at most one
    character; in that case no random numbers are consumed.
    """
    # Title-cased words are left alone, and a position before the last
    # character only exists when there are at least two characters.
    if len(word) <= 1 or word.istitle():
        return None

    pos = random.randrange(len(word) - 1)
    operation = random.randrange(3)
    chars = list(word)
    if operation == 0:
        chars.insert(pos, random.choice(string.ascii_lowercase))
    elif operation == 1:
        del chars[pos]
    else:
        chars[pos], chars[pos + 1] = chars[pos + 1], chars[pos]

    return ''.join(chars)


def perturb_spelling(text, num_words):
    """Inject spelling errors into up to ``num_words`` words of ``text``.

    ``text`` is split on single spaces and ``num_words`` distinct word
    positions are sampled. Each sampled word is passed to
    ``perturb_spell_word``. When that returns ``None`` (the word cannot be
    perturbed), one fallback candidate is tried instead: the word halfway
    between the end of the current run of consecutive sampled positions and
    the next sampled position (or the end of the sentence). If the fallback
    cannot be perturbed either, nothing is changed for that sample.

    Args:
        text: Space-separated sentence.
        num_words: Number of words to perturb; values larger than the
            number of words in ``text`` are clamped to that number.

    Returns:
        The sentence with the perturbed words, joined by single spaces.
    """
    words = text.split(' ')
    # random.sample raises ValueError when asked for more items than exist,
    # which used to crash on sentences shorter than num_words.
    num_words = min(num_words, len(words))
    sampled = sorted(random.sample(range(len(words)), num_words))
    last_sample = len(sampled) - 1

    # Index into `sampled` of the last position whose following gap has
    # already been used as a fallback source.
    gap_idx = -1
    for loop_idx, word_id in enumerate(sampled):
        rep_idx = word_id
        new_word = perturb_spell_word(words[word_id])
        if new_word is None:
            gap_idx = max(loop_idx, gap_idx + 1)
            if gap_idx > last_sample:
                # Every gap has been consumed by earlier fallbacks; indexing
                # `sampled` here would be out of range.
                continue
            # Skip over sampled positions that are directly adjacent, so the
            # fallback comes from a stretch after the run.
            while gap_idx < last_sample and \
                    sampled[gap_idx + 1] == sampled[gap_idx] + 1:
                gap_idx += 1
            end_idx = len(words) if gap_idx >= last_sample else sampled[gap_idx + 1]
            rep_idx = (sampled[gap_idx] + end_idx) // 2
            new_word = perturb_spell_word(words[rep_idx])

        if new_word is not None:
            words[rep_idx] = new_word

    return ' '.join(words)


def apply_edits(source, edits, offset=0):
    """Apply ``edits`` to ``source`` and return the result as one string.

    Args:
        source: Either a space-separated sentence or a sequence of tokens.
        edits: Edit tuples ``(start, end, error_type, replacement)``, in the
            order they should be applied.
        offset: Initial shift added to every edit span.

    Returns:
        The edited tokens joined by single spaces.
    """
    if isinstance(source, str):
        tokens = source.split(' ')
    else:
        # apply_edits_list splices its argument in place; copy so that the
        # caller's token list is not modified by a function that only
        # returns a string.
        tokens = list(source)
    result, _ = apply_edits_list(tokens, edits, offset)
    return ' '.join(result)


def apply_edits_list(source, edits, offset=0):
    """Splice each edit into the token list ``source``, in place.

    Every edit replaces ``source[start:end]`` (shifted by the running
    offset) with the whitespace-split tokens of its replacement string; an
    empty replacement deletes the span. The running offset tracks how much
    earlier edits have grown or shrunk the list.

    Returns:
        A ``(source, offset)`` pair: the same list object that was passed
        in, and the offset after the last edit.
    """
    shift = offset
    for edit in edits:
        begin = edit[_EDIT_START] + shift
        finish = edit[_EDIT_END] + shift
        replacement = edit[_EDIT_COR].split()
        source[begin:finish] = replacement
        # Tokens added minus tokens removed by this edit.
        shift += len(replacement) - (finish - begin)
    return source, shift


def read_m2(filepath, filter_idx=None):
    """Parse an M2 file into a list of sentences with their edits.

    Entries are separated by a blank line. The first line of an entry is the
    sentence (its first two characters are stripped); each following line
    must start with ``A`` and holds ``|||``-separated fields, of which the
    span, the error type and the replacement text are kept. Edits whose
    error type is in ``_IGNORE_TYPE`` are skipped.

    Args:
        filepath: Path to the UTF-8 encoded M2 file.
        filter_idx: Optional iterable of entry indices; when given, only
            those entries are parsed, in the given order.

    Returns:
        A list of ``{'source': str, 'edits': [(start, end, type, rep)]}``
        dicts. An empty (or whitespace-only) file yields an empty list.

    Raises:
        ValueError: If a line after the sentence line does not start with
            ``A``.
    """
    with open(filepath, encoding='utf-8') as f:
        content = f.read().strip()

    # ''.split('\n\n') is [''], which would turn an empty file into one
    # phantom sentence with an empty source.
    m2_entries = content.split('\n\n') if content else []

    if filter_idx is not None:
        m2_entries = [m2_entries[i] for i in filter_idx]

    parsed_data = []
    for m2_entity in m2_entries:
        sentence_line, *edit_lines = m2_entity.split('\n')
        source = sentence_line[2:]
        edits = []
        for m2_line in edit_lines:
            if not m2_line.startswith("A"):
                raise ValueError("{} is not an m2 edit".format(m2_line))
            features = m2_line[2:].split("|||")
            span = features[0].split()
            start, end = int(span[0]), int(span[1])
            error_type = features[1].strip()
            if error_type in _IGNORE_TYPE:
                continue
            edits.append((start, end, error_type, features[2]))
        parsed_data.append({'source': source, 'edits': edits})

    return parsed_data


def main(args):
    """Build parallel source/target files from an M2 file for one error type.

    For every sentence, the edits whose type (after dropping the
    two-character operation prefix) starts with ``args.type`` are selected.
    Edits of type ``UNK`` or starting with ``OTHER`` are additionally
    selected with probability ``args.unk_thres`` / ``args.oth_thres``.

    Each sentence is then handled as follows:

    * No selected edit, and drawn with probability ``args.perturb_perc``:
      a spelling-perturbed copy becomes the source and the original text
      the target (only supported when ``args.type`` is ``SPELL``).
    * Otherwise, if it has selected edits or is drawn with probability
      ``args.identity_perc``: with ``args.apply_side == "target"`` the
      source is kept and the target gets only the selected edits; with any
      other value the source gets the non-selected edits and the target
      gets all edits.
    * Otherwise the sentence is dropped.

    Sources are written to ``args.output + '.ori'`` and targets to
    ``args.output + '.cor'``; summary statistics are printed.

    Raises:
        ValueError: If a perturbation is drawn for an error type other than
            ``SPELL``.
    """
    data = read_m2(args.data)
    sources = []
    final_sent = []
    sent_count = []
    perturbed_sent_count = 0
    normal_edited_sent_count = 0
    for ent in data:
        fil_edits = []
        for edit in ent['edits']:
            error_type = edit[_EDIT_TYPE][2:]  # ignore the operation type
            # random.random() is in [0, 1), so `< p` fires with probability
            # exactly p and never when p is 0.
            if error_type.startswith(args.type):
                fil_edits.append(edit)
            # NOTE(review): a bare 'UNK' type is already dropped by read_m2,
            # so this only matches prefixed types such as 'R:UNK'.
            elif error_type == 'UNK' and random.random() < args.unk_thres:
                fil_edits.append(edit)
            elif error_type.startswith('OTHER') and random.random() < args.oth_thres:
                fil_edits.append(edit)
        edit_count = len(fil_edits)

        if edit_count == 0 and random.random() < args.perturb_perc:
            if args.type in ["SPELL"]:
                sources.append(perturb_spelling(ent['source'], args.perturb_word))
                final_sent.append(ent['source'])
                perturbed_sent_count += 1
            else:
                raise ValueError("Perturbation with error type {} is not found."
                                 " Please set --perturb_perc to 0".format(args.type))

        elif fil_edits or random.random() < args.identity_perc:
            if args.apply_side == "target":
                sources.append(ent['source'])
                final_sent.append(apply_edits(ent['source'], fil_edits))
            else:
                src_edits = [e for e in ent['edits'] if e not in fil_edits]
                sources.append(apply_edits(ent['source'], src_edits))
                final_sent.append(apply_edits(ent['source'], ent['edits']))

        if edit_count > 0:
            sent_count.append(edit_count)
        if len(ent['edits']) > 0:
            normal_edited_sent_count += 1

    num_edited_line = len(sent_count)
    # Guard the average: with no selected edit anywhere the division would
    # raise ZeroDivisionError before any output file is written.
    if num_edited_line:
        edits_per_sentence = sum(sent_count) / float(num_edited_line)
    else:
        edits_per_sentence = 0.0
    print('Edit per sentence: ', edits_per_sentence)
    print('Num sentences edit applied: {}/{}'.format(num_edited_line, len(sources)))
    print('Num sentences perturbed: {}/{}'.format(perturbed_sent_count, len(sources)))
    print('Num erroneous sentences: {}/{}'.format(perturbed_sent_count + num_edited_line, len(sources)))
    print('Normal edited sent: {}/{}'.format(normal_edited_sent_count, len(data)))

    ori_path = args.output + '.ori'
    with open(ori_path, 'w', encoding='utf-8') as out:
        out.write('\n'.join(sources))
    cor_path = args.output + '.cor'
    with open(cor_path, 'w', encoding='utf-8') as out:
        out.write('\n'.join(final_sent))


def get_arguments():
    """Parse the command-line options of the script.

    ``--data``, ``--type`` and ``--output`` are required so that a missing
    option is reported by argparse up front instead of surfacing later as a
    ``TypeError``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True, help='path to the input M2 file')
    parser.add_argument('--apply_side',
        help='"target" keeps the source sentence as is and applies only the '
             'selected edits to the target; any other value applies the '
             'non-selected edits to the source and all edits to the target')
    parser.add_argument('--type', required=True, help='error type to be included')
    parser.add_argument('--perturb_word', type=int, default=0,
        help="number of words to be perturbed within a sentence")
    parser.add_argument('--perturb_perc', type=float, default=0,
        help="probability (0-1) of perturbing a sentence without selected edits")
    parser.add_argument('--identity_perc', type=float, default=1,
        help="probability (0-1) of keeping a sentence without selected edits")
    parser.add_argument('--unk_thres', type=float, default=0.2,
        help="probability (0-1) of including an unknown (UNK) error")
    parser.add_argument('--oth_thres', type=float, default=0,
        help="probability (0-1) of including an OTHER error")
    parser.add_argument('--output', required=True,
        help="output path prefix; '.ori' and '.cor' are appended")
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the command line and run the conversion.
    main(get_arguments())
