import os
import sys
sys.path.append('.')
from os.path import join
from common.utils import read_json, dump_json
from collections import OrderedDict
from utils import IndexedFeature, FeatureVocab
import numpy as np
import random
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from utils import normalize_raw_prediction, prepro_data

def extract_feature_of_example(ex):
    """Build the hand-crafted feature set for one example.

    The example is expected to mention the three fixed names 'Scott',
    'Wood' and 'James' in its context, with two of them also appearing in
    the question. The name absent from the question is treated as the
    distractor. The context is cut into three segments, each starting at
    one name's first occurrence, and every segment must contain either
    'American' or 'English'.

    Args:
        ex: mapping with at least the string entries 'question' and
            'context'.

    Returns:
        An IndexedFeature holding a subset of these feature names:
        'dis_pos:<0|1|2>', 'reversed_mention', '0_1_identical',
        '0_2_identical', '1_2_identical', 'Two_american',
        'Distrator_toward_yes' and 'gt:yes'.

    Raises:
        ValueError: if one of the three names is missing from the context.
        IndexError: if the question mentions all three names, or fewer
            than two of them.
        AssertionError: if some segment names neither nationality.
    """
    feat = IndexedFeature()

    question = ex['question']
    context = ex['context']

    person_names = ['Scott', 'Wood', 'James']
    asked_names = [name for name in person_names if name in question]
    # The single name the question leaves out is the distractor.
    distractor = [name for name in person_names if name not in question][0]

    # Order the names by where they first show up in the context.
    by_position = sorted(
        ((name, context.index(name)) for name in person_names),
        key=lambda pair: pair[1],
    )
    ordered_names = [name for name, _ in by_position]
    starts = [start for _, start in by_position]
    # A segment runs from one name to the next; the last one runs to the end.
    stops = starts[1:] + [len(context)]

    nations = []
    for start, stop in zip(starts, stops):
        segment = context[start:stop]
        if 'American' in segment:
            nations.append('American')
        elif 'English' in segment:
            nations.append('English')
    assert len(nations) == 3

    # Where the distractor sits among the three context segments.
    distractor_position = ordered_names.index(distractor)
    feat.add('dis_pos:' + str(distractor_position))

    # Do the two asked names appear in opposite order in question vs context?
    asked_pos_in_q = [question.index(name) for name in asked_names]
    asked_pos_in_c = [context.index(name) for name in asked_names]
    ascending_in_q = asked_pos_in_q[0] < asked_pos_in_q[1]
    ascending_in_c = asked_pos_in_c[0] < asked_pos_in_c[1]
    if ascending_in_q != ascending_in_c:
        feat.add('reversed_mention')

    # Pairwise nationality agreement between segments. With three segments
    # and two possible nationalities at least one pair always matches; the
    # last matching pair decides identity_nation.
    segment_pairs = (
        ('0_1_identical', 0, 1),
        ('0_2_identical', 0, 2),
        ('1_2_identical', 1, 2),
    )
    for feature_name, left, right in segment_pairs:
        if nations[left] == nations[right]:
            feat.add(feature_name)
            identity_nation = nations[left]

    if identity_nation == 'American':
        feat.add('Two_american')
    if nations[distractor_position] == identity_nation:
        feat.add('Distrator_toward_yes')

    # Ground truth: do the two people named in the question share a nation?
    first_nation = nations[ordered_names.index(asked_names[0])]
    second_nation = nations[ordered_names.index(asked_names[1])]
    if first_nation == second_nation:
        feat.add('gt:yes')

    return feat

def construct_features(data):
    """Extract features for every example and collect their names.

    Args:
        data: mapping from example id to the example passed on to
            extract_feature_of_example.

    Returns:
        A pair (features, feature_vocab): an OrderedDict from example id
        to its extracted feature object, in the iteration order of data,
        and a FeatureVocab to which every feature name seen in any
        example has been added.
    """
    feature_vocab = FeatureVocab()
    features = OrderedDict()

    for example_id in data:
        example_feat = extract_feature_of_example(data[example_id])
        features[example_id] = example_feat
        # Register each feature name of this example in the shared vocabulary.
        for feature_name in example_feat.data:
            feature_vocab.add(feature_name)

    return features, feature_vocab

def split_data(X, Y, ids, ratio=0.6):
    """Randomly split a dataset into a train part and a dev part.

    The row order is drawn with np.random.permutation, so the split is
    reproducible only when the global NumPy seed has been set.

    Args:
        X: array indexed by row along its first axis.
        Y: array of labels aligned with the rows of X.
        ids: sequence of example ids aligned with the rows of X.
        ratio: fraction of rows placed in the train part; the count is
            truncated to an integer.

    Returns:
        The tuple (train_x, train_y, train_ids, dev_x, dev_y, dev_ids).
    """
    total = X.shape[0]
    shuffled = np.random.permutation(total)
    cut = int(total * ratio)

    def take(rows):
        # Gather the same rows from all three aligned containers.
        return X[rows], Y[rows], [ids[row] for row in rows]

    train_x, train_y, train_ids = take(shuffled[:cut])
    dev_x, dev_y, dev_ids = take(shuffled[cut:])
    return train_x, train_y, train_ids, dev_x, dev_y, dev_ids

    
def make_classification_dataset(preds, features, indexer):
    """Turn predictions and sparse features into dense NumPy arrays.

    Args:
        preds: mapping from example id to a score; a score strictly above
            0.5 becomes label 1, anything else label 0.
        features: mapping from example id to an object whose .data maps
            feature name to value.
        indexer: object supporting len() and name -> column index lookup.

    Returns:
        The tuple (np_features, labels, ids): one dense row of width
        len(indexer) per example, the 0/1 label array, and the example
        ids, all in the iteration order of preds.
    """
    ids = []
    labels = []
    rows = []

    for example_id in preds:
        ids.append(example_id)
        labels.append(1 if preds[example_id] > 0.5 else 0)

        # Start from an all-zero row and fill in the active features.
        row = [.0] * len(indexer)
        for name, value in features[example_id].data.items():
            row[indexer[name]] = value
        rows.append(row)

    return np.array(rows), np.array(labels), ids

def extend_compositonal_features(features, feature_vocab):
    """Add pairwise composite features on top of the individual ones.

    For every unordered pair (a, b) of the feature names present in the
    vocabulary on entry, the names 'AND(a,b)', 'OR(a,b)' and 'XOR(a,b)'
    are registered in feature_vocab, in that order. Each feature object
    is then extended in place with 'OR(a,b)' whenever at least one of the
    two base features has a value above 0.

    NOTE: only the OR composites are set on the feature objects. The AND
    and XOR names are registered in the vocabulary but never added to any
    example.

    Args:
        features: mapping from example id to a feature object supporting
            feat[name] lookup and feat.add(name).
        feature_vocab: vocabulary exposing get_names() and add(name).

    Returns:
        None. Both arguments are modified in place.
    """
    # Snapshot the base names before touching the vocabulary. get_names()
    # may hand back the vocabulary's own list (its implementation is not
    # visible here); without a copy, every add() below would lengthen the
    # list being iterated and composites of composites would be generated.
    base_names = list(feature_vocab.get_names())
    num_base = len(base_names)

    for i in range(num_base):
        for j in range(i + 1, num_base):
            pair = (base_names[i], base_names[j])
            feature_vocab.add('AND({},{})'.format(*pair))
            feature_vocab.add('OR({},{})'.format(*pair))
            feature_vocab.add('XOR({},{})'.format(*pair))

    for feat in features.values():
        for i in range(num_base):
            val_i = feat[base_names[i]]
            for j in range(i + 1, num_base):
                val_j = feat[base_names[j]]
                if (val_i > 0) or (val_j > 0):
                    feat.add('OR({},{})'.format(base_names[i], base_names[j]))
        
def sp_linear_cls_exp():
    """Fit a logistic regression on the extracted features and report accuracy.

    Reads the raw predictions and the perturbed examples from the
    'lime/sp_case' directory, builds individual plus composite features,
    splits the data with split_data, trains the classifier on the train
    part and prints the feature dimension and train/dev accuracy.
    """
    def accuracy(predicted, gold):
        # Fraction of positions where prediction and label agree.
        return np.sum(predicted == gold) / predicted.size

    teacher_preds = normalize_raw_prediction(
        read_json('lime/sp_case/raw_predictions.json'))

    examples = prepro_data(read_json('lime/sp_case/nation_perturb.json'))
    features, feature_vocab = construct_features(examples)
    extend_compositonal_features(features, feature_vocab)
    print('Feature Dim', len(feature_vocab))

    all_x, all_y, all_ids = make_classification_dataset(
        teacher_preds, features, feature_vocab)
    train_x, train_y, _train_ids, dev_x, dev_y, _dev_ids = split_data(
        all_x, all_y, all_ids)

    # Train on the train part only, then score both parts.
    model = LogisticRegression(random_state=666, C=10.0, max_iter=200)
    model.fit(train_x, train_y)
    train_acc = accuracy(model.predict(train_x), train_y)
    dev_acc = accuracy(model.predict(dev_x), dev_y)
    print('Train ACC', train_acc, 'Dev Acc', dev_acc)
    
def sp_tree_cls_exp():
    """Fit a depth-3 decision tree on the individual features and render it.

    Reads the raw predictions and the perturbed examples from the
    'lime/sp_case' directory, builds the individual features (no
    composites), splits the data with split_data, trains the tree on the
    train part, prints the feature dimension, split sizes and train/dev
    accuracy, and renders the fitted tree with graphviz under the name
    "nation".
    """
    def accuracy(predicted, gold):
        # Fraction of positions where prediction and label agree.
        return np.sum(predicted == gold) / predicted.size

    teacher_preds = normalize_raw_prediction(
        read_json('lime/sp_case/raw_predictions.json'))

    examples = prepro_data(read_json('lime/sp_case/nation_perturb.json'))
    features, feature_vocab = construct_features(examples)
    print('Feature Dim', len(feature_vocab))

    all_x, all_y, all_ids = make_classification_dataset(
        teacher_preds, features, feature_vocab)
    train_x, train_y, _train_ids, dev_x, dev_y, _dev_ids = split_data(
        all_x, all_y, all_ids)

    print(train_y.size, dev_y.size)
    model = DecisionTreeClassifier(random_state=100, max_depth=3)
    model.fit(train_x, train_y)
    train_acc = accuracy(model.predict(train_x), train_y)
    dev_acc = accuracy(model.predict(dev_x), dev_y)
    print('Tree', 'Train ACC', train_acc, 'Dev Acc', dev_acc)

    # graphviz is only needed for rendering, so it is imported late.
    import graphviz
    dot_source = tree.export_graphviz(
        model,
        out_file=None,
        feature_names=feature_vocab.get_names(),
        class_names=['no', 'yes'],
    )
    graphviz.Source(dot_source).render("nation")

    # For a per-example dump of the train part, call:
    # sp_analyze_data(train_x, train_y, train_ids, feature_vocab, raw_data, norm_preds)

def sp_analyze_data(X, Y, ids, feature_vocab, raw_data, norm_preds):
    """Print the examples lacking '0_2_identical' that are not scored below 0.5.

    Examples whose '0_2_identical' feature value is at least 0.5 are
    dropped. Of the rest, every example whose score is not below 0.5 is
    printed field by field together with its yes/no label (yes only when
    the score is strictly above 0.5). A final line reports how many
    examples were kept and how many of them score above 0.5.

    Args:
        X: iterable of feature rows, indexable by column.
        Y: labels aligned with X (only used to keep the rows aligned).
        ids: example ids aligned with X.
        feature_vocab: feature name -> column index lookup.
        raw_data: mapping from example id to a mapping of field -> value.
        norm_preds: mapping from example id to a score.
    """
    kept_ids = [
        example_id
        for row, _label, example_id in zip(X, Y, ids)
        if not (row[feature_vocab['0_2_identical']] >= 0.5)
    ]

    for example_id in kept_ids:
        score = norm_preds[example_id]
        if not (score < 0.5):
            print('--------------------------')
            for field, value in raw_data[example_id].items():
                print(field + ':', value)
            verdict = 'yes' if norm_preds[example_id] > 0.5 else 'no'
            print('teacher pred: {}\n'.format(verdict))

    num_yes = sum([norm_preds[example_id] > 0.5 for example_id in kept_ids])
    print('total num:', len(kept_ids), 'Pred as yes:', num_yes)

if __name__ == "__main__":
    # Seed both RNGs; split_data draws its shuffle from np.random.
    np.random.seed(2333)
    random.seed(2333)

    # Alternative experiment in this file: sp_linear_cls_exp()
    sp_tree_cls_exp()