import numpy as np
import re
import itertools, csv
from collections import Counter
from nltk.tokenize import TweetTokenizer

def clean_str(string):
    """Tokenize and normalize a raw sentence into a space-separated string.

    The text is tokenized with NLTK's ``TweetTokenizer`` and the tokens are
    re-joined with single spaces. Hyphens, periods, hashes, double quotes and
    slashes are then replaced by spaces, English clitics ('s, 'm, 've, n't,
    're, 'd, 'll) are split off as separate tokens, any other apostrophe is
    replaced by a space, and runs of whitespace are collapsed.

    Args:
        string: The raw text to clean.

    Returns:
        The cleaned text with leading and trailing whitespace removed.
    """
    tokenizer = TweetTokenizer()
    string = ' '.join(tokenizer.tokenize(string))
    string = re.sub(r"[-.#\"/]", " ", string)
    # Detach clitics from the word they are attached to ("it's" -> "it 's").
    string = re.sub(r"\'s", " 's", string)
    string = re.sub(r"\'m", " 'm", string)
    string = re.sub(r"\'ve", " 've", string)
    string = re.sub(r"n\'t", " n't", string)
    string = re.sub(r"\'re", " 're", string)
    string = re.sub(r"\'d", " 'd", string)
    string = re.sub(r"\'ll", " 'll", string)
    # Remove every apostrophe that does not introduce one of the clitics
    # above. "m" has to be listed here as well: without it the "'m" token
    # produced a few lines earlier lost its apostrophe ("I'm" -> "I m")
    # while all the other clitics kept theirs.
    string = re.sub(r"\'(?!(s|m|ve|t|re|d|ll))", " ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()

def load_data_and_labels(positive_data_file, negative_data_file):
    """Load a two-class sentence dataset stored as one example per line.

    Args:
        positive_data_file: Path to the text file holding positive examples.
        negative_data_file: Path to the text file holding negative examples.

    Returns:
        A two-element list ``[x_text, y]``. ``x_text`` holds the cleaned
        sentences (positives first, then negatives) and ``y`` is an integer
        array with one row per sentence: ``[0, 1]`` for a positive example
        and ``[1, 0]`` for a negative one.
    """
    # Read through context managers so both handles are closed promptly
    # instead of being left open until garbage collection.
    with open(positive_data_file, "r") as positive_file:
        positive_examples = [s.strip() for s in positive_file]
    with open(negative_data_file, "r") as negative_file:
        negative_examples = [s.strip() for s in negative_file]
    # Split by words
    x_text = [clean_str(sent) for sent in positive_examples + negative_examples]
    # Generate labels
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    # Build the array from the joined list and force two columns so the
    # result is well-formed even when one (or both) of the files is empty;
    # concatenating an empty list with an (n, 2) list raises in NumPy.
    y = np.array(positive_labels + negative_labels, dtype=int).reshape(-1, 2)
    return [x_text, y]

def load_data_and_labels_quora(data_file):
    """Load sentence pairs from a tab-separated file with the label first.

    Each usable row has the label in column 0 (``'0'`` or ``'1'``) and the
    two sentences in columns 1 and 2. Rows with any other label, or with
    fewer than three columns, are printed together with ``'error'`` and
    skipped. Blank rows are skipped silently.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` for label ``'0'`` and ``[0, 1]`` for ``'1'``.
    """
    x1 = []
    x2 = []
    y = []
    # newline="" leaves line-ending handling to the csv module, as its
    # documentation requires for files passed to csv.reader.
    with open(data_file, newline="") as raw_data:
        reader = csv.reader(raw_data, delimiter="\t")
        for line in reader:
            if not line:
                # csv.reader yields an empty list for a blank line; indexing
                # it used to abort the whole load with an IndexError.
                continue
            has_all_columns = len(line) >= 3
            if has_all_columns and line[0] == '0':
                label = [1, 0]
            elif has_all_columns and line[0] == '1':
                label = [0, 1]
            else:
                print(line)
                print('error')
                continue
            y.append(label)
            x1.append(line[1])
            x2.append(line[2])
    return x1, x2, np.array(y)



def load_data_and_labels_cikm(data_file):
    """Load and clean sentence pairs from a tab-separated five-column file.

    Columns 0 and 2 hold the two sentences and column 4 holds an integer
    label; columns 1 and 3 are not used. Both sentences are passed through
    ``clean_str``. Blank lines are skipped.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of cleaned sentences and an array
        with one row per pair, ``[1, 0]`` when the label is 0 and ``[0, 1]``
        for any other integer.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
        ValueError: If the label column is not an integer.
    """
    with open(data_file) as raw_data:
        # Skip blank lines (e.g. a trailing empty line), which previously
        # crashed the column lookups below; the other loaders in this
        # module skip them too.
        data = [item.split('\t') for item in raw_data if item.strip()]
    x1 = [clean_str(item[0]) for item in data]
    x2 = [clean_str(item[2]) for item in data]
    y = [[1, 0] if int(item[4]) == 0 else [0, 1] for item in data]
    return x1, x2, np.array(y)

def load_data_and_labels_cikm_no_clean(data_file):
    """Load sentence pairs from a tab-separated five-column file, uncleaned.

    Columns 0 and 2 hold the two sentences and column 4 holds an integer
    label; columns 1 and 3 are not used. The sentences are returned exactly
    as they appear in the file. Blank lines are skipped.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` when the label is 0 and ``[0, 1]`` for any
        other integer.

    Raises:
        IndexError: If a non-blank line has fewer than five columns.
        ValueError: If the label column is not an integer.
    """
    with open(data_file) as raw_data:
        # Skip blank lines (e.g. a trailing empty line), which previously
        # crashed the column lookups below; the other loaders in this
        # module skip them too.
        data = [item.split('\t') for item in raw_data if item.strip()]
    x1 = [item[0] for item in data]
    x2 = [item[2] for item in data]
    y = [[1, 0] if int(item[4]) == 0 else [0, 1] for item in data]
    return x1, x2, np.array(y)

def load_data_and_labels_nli(data_file):
    """Read labelled sentence pairs from a tab-separated text file.

    Every line is stripped and split on tabs. The first field is the label
    and the next two fields are the sentences. Lines consisting of a lone
    newline, and lines whose label is neither ``'0'`` nor ``'1'``, are
    ignored.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` for label ``'0'`` and ``[0, 1]`` for ``'1'``.
    """
    one_hot = {'0': [1, 0], '1': [0, 1]}
    first_sentences, second_sentences, targets = [], [], []

    with open(data_file) as handle:
        for raw in handle:
            if raw == '\n':
                continue
            fields = raw.strip().split('\t')
            tag = fields[0]
            if tag not in one_hot:
                continue
            targets.append(list(one_hot[tag]))
            first_sentences.append(fields[1])
            second_sentences.append(fields[2])

    return first_sentences, second_sentences, np.array(targets)

def load_address_s(data_file, data_adv_file):
    """Load sentence pairs from a base file and a capped second file.

    Both files are tab-separated with the two sentences in the first two
    fields and the label in the third. ``'neutral'`` maps to ``[1, 0]`` and
    ``'entailment'`` to ``[0, 1]``; lines with any other label, and lines
    consisting of a lone newline, are ignored.

    All accepted rows of ``data_file`` are kept. Rows of ``data_adv_file``
    are appended after them, and reading stops as soon as the number of rows
    accepted from that file exceeds twice the number accepted from
    ``data_file`` (so at most ``2 * N + 1`` rows are taken from it).

    Args:
        data_file: Path to the base data file.
        data_adv_file: Path to the second data file, read with the cap above.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array of the
        corresponding two-element label rows.
    """
    label_rows = {'neutral': [1, 0], 'entailment': [0, 1]}
    left, right, targets = [], [], []

    with open(data_file) as base:
        for raw in base:
            if raw == '\n':
                continue
            fields = raw.strip().split('\t')
            tag = fields[2]
            if tag not in label_rows:
                continue
            targets.append(list(label_rows[tag]))
            left.append(fields[0])
            right.append(fields[1])

    cap = len(left) * 2
    taken = 0
    with open(data_adv_file) as extra:
        for raw in extra:
            if raw == '\n':
                continue
            fields = raw.strip().split('\t')
            tag = fields[2]
            if tag not in label_rows:
                continue
            targets.append(list(label_rows[tag]))
            left.append(fields[0])
            right.append(fields[1])
            taken += 1
            # The count only changes here, so this is the only point at
            # which the cap can newly be exceeded.
            if taken > cap:
                break

    return left, right, np.array(targets)

def load_mnli(data_file):
    """Read sentence pairs labelled 'neutral' or 'entailment' from a file.

    Every line is stripped and split on tabs: the first two fields are the
    sentences and the third is the label. Lines consisting of a lone newline
    and lines with any other label are ignored.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` for ``'neutral'`` and ``[0, 1]`` for
        ``'entailment'``.
    """
    sentences_a = []
    sentences_b = []
    targets = []

    with open(data_file) as source:
        for raw_line in source:
            if raw_line == '\n':
                continue
            columns = raw_line.strip().split('\t')
            verdict = columns[2]
            if verdict == 'neutral':
                row = [1, 0]
            elif verdict == 'entailment':
                row = [0, 1]
            else:
                continue
            targets.append(row)
            sentences_a.append(columns[0])
            sentences_b.append(columns[1])

    return sentences_a, sentences_b, np.array(targets)

def load_data_and_labels_csv(data_file):
    """Load sentence pairs from a comma-separated file.

    Each usable row has the two sentences in columns 3 and 4 and the label
    (``'0'`` or ``'1'``) in column 5. Rows with any other label, or with
    fewer than six columns, are printed together with ``'error'`` and
    skipped. Blank rows are skipped silently.

    Args:
        data_file: Path to the CSV data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` for label ``'0'`` and ``[0, 1]`` for ``'1'``.
    """
    x1 = []
    x2 = []
    y = []
    # newline="" leaves line-ending handling to the csv module, as its
    # documentation requires for files passed to csv.reader.
    with open(data_file, newline="") as raw_data:
        reader = csv.reader(raw_data, delimiter=",")
        for line in reader:
            if not line:
                # csv.reader yields an empty list for a blank line; indexing
                # it used to abort the whole load with an IndexError.
                continue
            has_all_columns = len(line) >= 6
            if has_all_columns and line[5] == '0':
                label = [1, 0]
            elif has_all_columns and line[5] == '1':
                label = [0, 1]
            else:
                print(line)
                print('error')
                continue
            y.append(label)
            x1.append(line[3])
            x2.append(line[4])

    return x1, x2, np.array(y)

def load_data_and_labels_tsv(data_file):
    """Load sentence pairs labelled 'neutral' or 'entails' from a TSV file.

    Each usable row has the two sentences in columns 0 and 1 and the label in
    column 2. Rows with any other label, or with fewer than three columns,
    are printed together with ``'tsv error'`` and skipped. Blank rows are
    skipped silently.

    Args:
        data_file: Path to the tab-separated data file.

    Returns:
        A tuple ``(x1, x2, y)``: two lists of sentences and an array with one
        row per pair, ``[1, 0]`` for ``'neutral'`` and ``[0, 1]`` for
        ``'entails'``.
    """
    x1 = []
    x2 = []
    y = []
    # newline="" leaves line-ending handling to the csv module, as its
    # documentation requires for files passed to csv.reader.
    with open(data_file, newline="") as raw_data:
        reader = csv.reader(raw_data, delimiter="\t")
        for line in reader:
            if not line:
                # csv.reader yields an empty list for a blank line; indexing
                # it used to abort the whole load with an IndexError.
                continue
            has_all_columns = len(line) >= 3
            if has_all_columns and line[2] == 'neutral':
                label = [1, 0]
            elif has_all_columns and line[2] == 'entails':
                label = [0, 1]
            else:
                print(line)
                print('tsv error')
                continue
            y.append(label)
            x1.append(line[0])
            x2.append(line[1])
    return x1, x2, np.array(y)

def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """Yield mini-batches of ``data`` together with their source indices.

    Args:
        data: Sequence of examples; it is converted with ``np.array``.
        batch_size: Maximum number of examples per batch. The last batch of
            an epoch may be smaller.
        num_epochs: Number of passes to make over ``data``.
        shuffle: If true, the examples are permuted before batching;
            otherwise they are served in their original order.

    Yields:
        Tuples ``(batch, indices)`` where ``batch`` is a slice of the
        (possibly shuffled) data array and ``indices`` holds the position in
        the original ``data`` of each example in the batch. Nothing is
        yielded when ``data`` is empty.
    """
    data = np.array(data)
    data_size = len(data)
    # Integer ceiling division. Unlike the earlier float/truncation form this
    # gives 0 batches for empty data instead of one empty batch.
    num_batches_per_epoch = (data_size - 1) // batch_size + 1
    for _ in range(num_epochs):
        if shuffle:
            # NOTE(review): re-seeding the global NumPy RNG here makes every
            # epoch use the same permutation. Kept as-is because callers may
            # rely on the reproducible order -- confirm whether that is
            # intended.
            np.random.seed(0)
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            # Identity mapping, so the yielded indices are defined in the
            # unshuffled case too (they used to be unbound, which raised
            # UnboundLocalError on the first yield).
            shuffle_indices = np.arange(data_size)
            shuffled_data = data

        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield (
                shuffled_data[start_index:end_index],
                shuffle_indices[start_index:end_index],
            )
