from . import shap
import torch
from torch import nn


class BertForSequenceClassificationWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper that calls ``base_model`` positionally.

    The wrapper fills in default attention masks / segment ids so the
    wrapped model can be invoked with only its first input (as an explainer
    typically does).
    """

    def __init__(self, base_model):
        super().__init__()
        self.base = base_model

    def forward(self, word_vecs, attention_masks=None, segment_ids=None):
        """Call ``self.base(word_vecs, attention_masks, segment_ids)``.

        Args:
            word_vecs: tensor whose first two dims are read as
                (batch, seq_len). Presumably (batch, seq_len, hidden)
                embeddings -- TODO confirm against the base model.
            attention_masks: optional (batch, seq_len) mask; defaults to all
                ones (every position attended).
            segment_ids: optional (batch, seq_len) ids; defaults to all zeros.
                A caller-supplied value is always honoured, even when
                ``attention_masks`` is omitted.

        Returns:
            Whatever ``self.base`` returns.
        """
        batch_size, seq_len = word_vecs.size(0), word_vecs.size(1)
        # Build defaults on the input's device rather than a hard-coded
        # ``.cuda()``, so CPU / non-default-GPU inputs work too.
        if attention_masks is None:
            attention_masks = torch.ones(
                batch_size, seq_len, dtype=torch.long, device=word_vecs.device)
        if segment_ids is None:
            segment_ids = torch.zeros(
                batch_size, seq_len, dtype=torch.long, device=word_vecs.device)
        return self.base(word_vecs, attention_masks, segment_ids)


class ShapExplainerForBert:
    """Word-level SHAP explanations for a BERT-style sequence classifier.

    Runs the package-local ``shap.GradientExplainer`` over the output of the
    model's embedding layer and writes one formatted line per explained
    example to ``output_path``. The output file stays open until
    :meth:`close` is called; the explainer may also be used as a context
    manager (``with ShapExplainerForBert(...) as ex: ...``).
    """

    def __init__(self, model, tokenizer, configs, train_iterator, output_path):
        """
        Args:
            model: classifier whose embedding layer is reachable as
                ``model.bert.embeddings``.
            tokenizer: object providing ``convert_ids_to_tokens``.
            configs: config object; only ``configs.batch_size`` is read.
            train_iterator: iterable of 4-tuples
                ``(input_ids, input_mask, segment_ids, label_ids)``; only
                ``input_ids`` is used, to build the SHAP background.
            output_path: path of the results file (opened for writing, so an
                existing file is truncated).
        """
        self.model = model
        self.model_wrapper = BertForSequenceClassificationWrapper(self.model)
        self.tokenizer = tokenizer
        self.max_seq_length = 128
        self.batch_size = configs.batch_size
        self.train_iterator = train_iterator
        self.bg = self.compute_bg()
        self.output_path = output_path

        # Explicit UTF-8: tokens may be non-ASCII and the platform default
        # encoding is not guaranteed to be able to write them.
        self.output_file = open(self.output_path, 'w', encoding='utf-8')

    def close(self):
        """Flush and close the results file. Safe to call more than once."""
        self.output_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _embedding_device(self):
        """Return the device of the embedding layer's first parameter.

        Falls back to the default CUDA device (the original hard-coded
        behaviour) if the layer has no parameters.
        """
        param = next(self.model_wrapper.base.bert.embeddings.parameters(), None)
        return param.device if param is not None else torch.device('cuda')

    def compute_bg(self, num_examples=50):
        """Embed training inputs to serve as the SHAP background.

        Whole batches are drawn from ``self.train_iterator`` until at least
        ``num_examples`` rows have been collected, so the result can hold
        more rows than ``num_examples``.

        Returns:
            The embedding layer's output for the concatenated ``input_ids``.

        Raises:
            ValueError: if no batch was collected (e.g. empty iterator).
        """
        examples = []
        seen = 0
        for input_ids, _input_mask, _segment_ids, _label_ids in self.train_iterator:
            if seen >= num_examples:
                break
            examples.append(input_ids)
            seen += input_ids.size(0)
        if not examples:
            raise ValueError(
                'train_iterator yielded no batches; cannot build a SHAP background')
        embeddings = self.model_wrapper.base.bert.embeddings
        # Move ids to wherever the embedding weights live instead of a
        # hard-coded ``.cuda()``.
        input_ids = torch.cat(examples, 0).to(self._embedding_device())  # [B, T]
        return embeddings(input_ids)

    def word_level_explanation_bert(self, input_ids, attention_mask, segment_ids, gt_label=None):
        """Explain one example and append the result line to the output file.

        Only row 0 of ``input_ids`` is reported, and the per-token scores are
        flattened with ``reshape(-1)``, so a batch of size 1 is presumably
        expected -- TODO confirm with callers.

        Args:
            input_ids: token ids, indexed as ``input_ids[0, i]``.
            attention_mask: forwarded to ``explainer.shap_values``.
            segment_ids: passed to the embedding layer and to
                ``explainer.shap_values``.
            gt_label: currently unused; kept for interface compatibility.
        """
        embed_layer = self.model_wrapper.base.bert.embeddings
        word_vecs = embed_layer(input_ids, segment_ids)

        if self.bg is None:
            # The original discarded this return value, leaving bg as None.
            self.bg = self.compute_bg()
        # NOTE(review): the kwarg here is ``attention_mask`` while the
        # wrapper's forward names it ``attention_masks`` -- confirm how the
        # local shap module forwards these kwargs.
        explainer = shap.GradientExplainer(self.model_wrapper, self.bg)
        shap_values = explainer.shap_values(word_vecs, attention_mask=attention_mask, segment_ids=segment_ids)

        # Output index 1 minus output index 0; presumably a binary
        # classifier's positive-vs-negative contribution -- TODO confirm.
        shap_values = shap_values[1] - shap_values[0]
        # Sum over the last (embedding) dimension -> one score per token.
        shap_by_word = shap_values.sum(-1).reshape(-1)  # [T]

        print(shap_by_word)

        # Start at position 1 (presumably skipping [CLS]) and stop at the
        # first id 0 (presumably [PAD]).
        spans, scores = [], []
        i = 1
        while i < input_ids.size(1) and input_ids[0, i] != 0:
            spans.append((i, i))
            scores.append(shap_by_word[i])
            i += 1
        inp = input_ids.view(-1).cpu().numpy()
        s = self.repr_result_region(inp, spans, scores)
        self.output_file.write(s + '\n')

    def repr_result_region(self, inp, spans, contribs, gt_label=None):
        """Format token spans and their scores as a single line.

        Each span becomes ``'<tokens> <score>\\t'`` (score as ``%.6f``) and
        the pieces are joined with single spaces.

        Args:
            inp: token ids, converted via ``tokenizer.convert_ids_to_tokens``.
            spans: inclusive ``(start, end)`` token-index pairs.
            contribs: one score per span.
            gt_label: optional label prefixed to the line followed by a tab.
                Mapped through ``self.label_vocab.itos`` when a ``label_vocab``
                attribute has been set on the instance (this class never sets
                one); otherwise written with ``str()``.

        Returns:
            The formatted line, without a trailing newline.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(inp)
        assert (len(spans) == len(contribs))
        outputs = [(' '.join(tokens[span[0]:span[1] + 1]), contrib)
                   for span, contrib in zip(spans, contribs)]
        output_str = ' '.join(['%s %.6f\t' % (x, y) for x, y in outputs])
        if gt_label is not None:
            # ``label_vocab`` is never assigned in __init__; reading it
            # unconditionally raised AttributeError.
            label_vocab = getattr(self, 'label_vocab', None)
            label = label_vocab.itos[gt_label] if label_vocab is not None else str(gt_label)
            output_str = label + '\t' + output_str
        return output_str