# -*- coding: utf-8 -*-

import csv
import re


def extract_insertions(filename):
    """Read ``(insertion_1, insertion_2, ground_truth)`` triples from a TSV file.

    The first row is treated as a header and skipped. For every other row,
    column 3 holds a comma-separated insertion pair: the first two
    comma-separated parts are used, each stripped of surrounding whitespace,
    and any further parts are ignored. Column 6 holds the ground-truth value,
    which is kept as the raw string.

    Args:
        filename: Path to a tab-separated file.

    Returns:
        A list of ``(insertion_1, insertion_2, ground_truth)`` string tuples,
        in file order.

    Raises:
        ValueError: If a row's column 3 does not contain a comma.
    """
    insertions = []
    # newline='' lets the csv module handle line endings and quoted fields
    # itself, as the csv documentation requires for file objects.
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            insertion_pair = row[3].split(",")
            if len(insertion_pair) < 2:
                raise ValueError("no comma-separated insertion pair found: " + row[3])
            insertion_gt = row[6]

            insertions.append((insertion_pair[0].strip(),
                               insertion_pair[1].strip(),
                               insertion_gt))

    return insertions


def extract_unique_insertions(filenamelist):
    """Collect insertions from several TSV files, dropping exact duplicates.

    Args:
        filenamelist: Iterable of file paths, each parsed with
            ``extract_insertions``.

    Returns:
        A list of distinct ``(insertion_1, insertion_2, ground_truth)`` tuples
        in order of first appearance across the files.
    """
    insertions = []
    for filename in filenamelist:
        insertions.extend(extract_insertions(filename))

    # dict.fromkeys de-duplicates while keeping first-seen order. Unlike
    # list(set(...)), the resulting order does not change between runs when
    # str hashing is randomized, so downstream sampling stays reproducible.
    return list(dict.fromkeys(insertions))


def extract_contexts(filename):
    """Read context templates and locate their ``x`` placeholder.

    The first row is treated as a header and skipped. For every other row,
    column 8 is the ground-truth value and column 9 is a template string
    containing a standalone ``x`` placeholder. The placeholder is searched in
    this priority order:

    1. ``x`` at the very start, followed by a non-word character;
    2. ``x`` at the very end, preceded by a non-word character;
    3. ``x`` surrounded by non-word characters.

    Args:
        filename: Path to a tab-separated file.

    Returns:
        A list of ``(context_template, placeholder_index, context_gt)`` tuples,
        in file order, where ``context_template[placeholder_index] == "x"``.

    Raises:
        ValueError: If a template contains no standalone ``x``.
    """
    contexts = []
    # newline='' lets the csv module handle line endings and quoted fields
    # itself, as the csv documentation requires for file objects.
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            context_gt = row[8]
            context_template = row[9]

            # Raw strings: "\W" is not a valid string escape and would warn.
            re_1 = re.search(r"^x\W", context_template)  # starts with x
            re_2 = re.search(r"\Wx$", context_template)  # ends with x
            re_3 = re.search(r"\Wx\W", context_template)  # x in the middle

            if re_1 is not None:
                placeholder_char = re_1.start()
            elif re_2 is not None:
                # Match begins at the preceding non-word char; x is one later.
                placeholder_char = re_2.start() + 1
            elif re_3 is not None:
                placeholder_char = re_3.start() + 1
            else:
                raise ValueError("no x was found in context: " + context_template)
            # Internal invariant guaranteed by the patterns above.
            assert context_template[placeholder_char] == "x"

            contexts.append((context_template,
                             placeholder_char,
                             context_gt))

    return contexts


def extract_unique_contexts(filenamelist):
    """Collect contexts from several TSV files, dropping exact duplicates.

    Args:
        filenamelist: Iterable of file paths, each parsed with
            ``extract_contexts``.

    Returns:
        A list of distinct ``(context_template, placeholder_index, context_gt)``
        tuples in order of first appearance across the files.
    """
    contexts = []
    for filename in filenamelist:
        contexts.extend(extract_contexts(filename))

    # dict.fromkeys de-duplicates while keeping first-seen order. Unlike
    # list(set(...)), the resulting order does not change between runs when
    # str hashing is randomized, so downstream sampling stays reproducible.
    return list(dict.fromkeys(contexts))


def generate_samples(context, insertions):
    """Substitute each insertion pair into a context template's placeholder.

    Args:
        context: Sequence whose item 0 is the template string and whose item 1
            is the index (an int or something ``int()`` accepts) of the
            single-character placeholder to replace.
        insertions: Iterable of indexable items; items 0 and 1 of each are the
            two strings substituted into the template.

    Returns:
        One ``(text_1, text_2, start, end_1, end_2)`` tuple per insertion.
        ``start`` is the offset where the inserted text begins, and
        ``end_1``/``end_2`` are the exclusive end offsets of the insertion
        in ``text_1``/``text_2``.
    """
    template = context[0]
    cut = int(context[1])
    # Drop exactly one character (the placeholder) between head and tail.
    head, tail = template[:cut], template[cut + 1:]
    offset = len(head)

    return [
        (head + pair[0] + tail,
         head + pair[1] + tail,
         offset,
         offset + len(pair[0]),
         offset + len(pair[1]))
        for pair in insertions
    ]

    