import numpy as np
from os.path import join
from random import shuffle
from sari.SARI import SARIsent
import torch
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction,\
    sentence_bleu
from sys import getsizeof
from rouge_score import rouge_scorer
from torch import sigmoid
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import csv
import os
from random import randint
from torch.nn.modules.loss import BCEWithLogitsLoss


class Namespace:
    """Minimal attribute container: every keyword argument becomes an attribute.

    Used as a lightweight stand-in for an argparse namespace.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)


def glove_dict(path, dim=300):
    """Load pretrained word vectors from a GloVe-style text file.

    Each non-blank line has the form ``<word> <v1> <v2> ...``; the word is
    everything before the first single space and the remaining
    whitespace-separated fields are parsed as floats.

    Args:
        path: path to the embedding file, read as UTF-8.
        dim: currently unused; kept for backward compatibility.

    Returns:
        dict mapping each word to a 1-D numpy float array.
    """
    embeddings = {}
    # read the pretrained embeddings
    with open(path, encoding="utf8") as f:
        for line in f:
            # A blank line has no ' ' separator and would fail to unpack.
            if not line.strip():
                continue
            word, vec = line.split(' ', 1)
            embeddings[word] = np.array([float(v) for v in vec.split()])
    return embeddings


def word_index_mapping(vocab):
    """Number the words of ``vocab`` starting at 1 (index 0 is reserved for padding).

    Returns:
        ``(word2index, index2word)``. If a word occurs more than once,
        ``word2index`` keeps its last index while ``index2word`` keeps all.
    """
    # Materialize once so that ``vocab`` may be a one-shot iterator.
    numbered = list(enumerate(vocab, start=1))
    word2index = {word: idx for idx, word in numbered}
    index2word = {idx: word for idx, word in numbered}
    return word2index, index2word


def pretty_print_prediction(input, gold_output, predicted_output):
    """Print an input, the model's prediction and the gold output, preceded by blank lines."""
    print("\n" * 3)
    labelled = (("Input: ", input),
                ("Output: ", predicted_output),
                ("Gold: ", gold_output))
    for label, text in labelled:
        print(label, text)


def read_all(path):
    """Return every line of the UTF-8 text file ``path`` with surrounding whitespace stripped.

    Blank lines are kept as empty strings, so the result has one entry per
    line of the file.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere
    # (glove_dict already reads its file as utf8).
    with open(path, 'r', encoding="utf8") as f:
        return [line.strip() for line in f]


def read_file(path, fraction):
    """Return the leading ``fraction`` of the stripped lines of ``path``.

    The number of lines kept is ``int(fraction * total_lines)``, i.e. rounded down.
    """
    lines = read_all(path)
    keep = int(fraction * len(lines))
    return lines[:keep]


def _prepare_styletransfer(params):

    # load the binary classifier
    if params.binary_classifier_path:
        # load the binary classifier
        if params.binary_classifier_path == "no_eval":
            params.binary_classifier = -1  # do not eval binary accuracy
        else:
            params.binary_tokenizer = AutoTokenizer.from_pretrained(
                params.binary_classifier_path)
            params.binary_classifier = AutoModelForSequenceClassification.from_pretrained(
                params.binary_classifier_path)
    else:
        params.binary_classifier = None
    params.current_epoch = 0


def get_data(params):
    """Load the train/dev/test splits for ``params.dataset_path`` and choose its evaluator.

    The dataset is recognised by substring match on ``params.dataset_path``;
    the order of the checks matters (e.g. ``"fullyelp"`` must be tested before
    ``"yelp"``). Datasets evaluated with a binary classifier get it set up on
    ``params`` through ``_prepare_styletransfer``.

    Returns:
        ``((train, dev, test), eval_fn)`` where each split is the dict
        produced by ``_get_data_pairs``.

    Raises:
        ValueError: if no known dataset name occurs in ``params.dataset_path``.
    """
    dataset_path = params.dataset_path
    if "script" in dataset_path:
        return _get_data_pairs(params), evaluate_bleu_script
    elif "nli" in dataset_path:
        return _get_data_pairs(params), evaluate_bleu_nli
    elif "wiki" in dataset_path:
        # Shared classifier setup (previously duplicated here, with an empty
        # path being passed to from_pretrained instead of disabling the classifier).
        _prepare_styletransfer(params)
        return _get_data_pairs(params), evaluate_wiki
    elif "mt" in dataset_path:
        return _get_data_pairs(params), evaluate_bleu_mt
    elif "abssum" in dataset_path:
        params.max_rouge = 0.
        _prepare_styletransfer(params)  # also resets params.current_epoch
        return _get_data_pairs(params), evaluate_rouge_abssum
    elif "fullyelp" in dataset_path:
        _prepare_styletransfer(params)
        # NOTE(review): machine-specific evaluation data location.
        return _get_data_pairs(params), get_eval_function_by_eval_data(
            "/anon/temp/anon/fullyelp-short/")
    elif "yelp" in dataset_path:
        _prepare_styletransfer(params)
        return _get_data_pairs(params), evaluate_yelp
    elif "shakespeare" in dataset_path:
        _prepare_styletransfer(params)
        return _get_data_pairs(params), evaluate_shakespeare
    else:
        raise ValueError("Don't know dataset " + str(dataset_path))


def evaluate_bleu_script(model, mode="valid", params=None):
    """Compute corpus BLEU of ``model`` on the script-inference dev or test split.

    Args:
        model: model under evaluation (passed to ``evaluate_bleu``).
        mode: ``"valid"`` (reads the ``*.dev`` files) or ``"test"`` (``*.test``).
        params: options; reads ``test_data_fraction``, ``batch_size``,
            ``print_outputs`` and ``max_prints``.

    Returns:
        The corpus BLEU score returned by ``evaluate_bleu``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``
            (previously an UnboundLocalError).
    """
    splits = {"valid": "dev", "test": "test"}
    if mode not in splits:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")
    split = splits[mode]

    inputs = read_file("../data/script-inference/s1." + split, params.test_data_fraction)
    ref = read_file("../data/script-inference/s2." + split, params.test_data_fraction)
    ref = [[r] for r in ref]  # one reference per input

    max_prints = params.max_prints if params.print_outputs else 0
    return evaluate_bleu(model, inputs, ref, params.batch_size, max_prints)


def evaluate_rouge_abssum(model, mode="valid", params=None):
    """Evaluate abstractive summarization with ROUGE and log the metrics.

    Inputs are read from ``../data/abssum/s1.<split>`` and references from
    ``../data/abssum/s2.<split>`` (swapped when ``params.invert_style`` is
    set), truncated to ``params.test_data_fraction``. The following are
    appended to the results CSV:

    * ROUGE;
    * the average prediction length in space-separated tokens;
    * binary-classifier accuracy;
    * optional self-BLEU against the inputs;
    * the in/output distance statistics read back from ``model``.

    The predictions are then written out and ``params.current_epoch`` is
    incremented.

    Args:
        model: model under evaluation.
        mode: ``"valid"`` (``*.dev`` files) or ``"test"`` (``*.test`` files).
        params: run options.

    Returns:
        The ``"rougeL"`` entry of the score dict returned by ``evaluate_rouge``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``.
    """
    splits = {"valid": "dev", "test": "test"}
    if mode not in splits:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")
    inputs = "../data/abssum/s1." + splits[mode]
    ref = "../data/abssum/s2." + splits[mode]

    if params.invert_style:
        inputs, ref = ref, inputs

    inputs = read_file(inputs, params.test_data_fraction)
    ref = read_file(ref, params.test_data_fraction)

    # Reset the counters that are read back after prediction; presumably the
    # model accumulates them while generating (not visible here) — TODO confirm.
    model.test_out_lens_sum = 0
    model.in_output_distance_mean = 0.
    model.in_output_distance_min = 0.
    model.in_output_distance_batchsize = 0
    max_prints = params.max_prints if params.print_outputs else 0
    r_score, predictions = evaluate_rouge(model, inputs, ref, params.batch_size, max_prints)
    avg_out_lens = model.test_out_lens_sum / len(inputs)
    in_output_distance_mean = model.in_output_distance_mean / len(inputs)
    in_output_distance_min = model.in_output_distance_min / len(inputs)
    print("Average lens of Emb2emb outputs:", avg_out_lens)

    avg_pred_lens = np.array([len(s.split(" ")) for s in predictions]).mean()
    b_acc = eval_binary_accuracy(model, predictions, mode, params)

    if params.eval_self_bleu:
        self_bleu = evaluate_bleu(model, inputs, [[i] for i in inputs], params.batch_size,
                                  max_prints=0, return_predictions=False, predictions=predictions)
    else:
        self_bleu = None
    _save_to_csv(params, b_acc=b_acc, sari=None,
                 bleu=None, self_bleu=self_bleu, rouge=r_score, average_pred_lens=avg_pred_lens,
                 in_output_distance_mean=in_output_distance_mean, in_output_distance_min=in_output_distance_min)

    write_dev_set_predictions(params, inputs, predictions, ref)

    params.current_epoch = params.current_epoch + 1

    return r_score['rougeL']


def evaluate_bleu_mt(model, mode="valid", params=None):
    """Compute corpus BLEU of ``model`` on the de-en machine-translation dev or test split.

    Args:
        model: model under evaluation (passed to ``evaluate_bleu``).
        mode: ``"valid"`` (reads the ``*.dev`` files) or ``"test"`` (``*.test``).
        params: options; reads ``test_data_fraction``, ``batch_size``,
            ``print_outputs`` and ``max_prints``.

    Returns:
        The corpus BLEU score returned by ``evaluate_bleu``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``
            (previously an UnboundLocalError).
    """
    splits = {"valid": "dev", "test": "test"}
    if mode not in splits:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")
    split = splits[mode]

    inputs = read_file("../data/mt/de-en/s1." + split, params.test_data_fraction)
    ref = read_file("../data/mt/de-en/s2." + split, params.test_data_fraction)
    ref = [[r] for r in ref]  # one reference per input

    max_prints = params.max_prints if params.print_outputs else 0
    return evaluate_bleu(model, inputs, ref, params.batch_size, max_prints)


def evaluate_styletransfer(model, mode, params, predictions, inputs, ref):
    """Score a style-transfer run by self-BLEU and binary-classifier accuracy and log it.

    Args:
        model: model under evaluation; passed to ``evaluate_bleu`` and
            ``eval_binary_accuracy``.
        mode: ``"valid"`` or ``"test"``, forwarded to the metric helpers.
        params: options. Reads ``test_data_fraction``,
            ``sentence_wise_evaluation``, ``batch_size``, ``print_outputs``
            and ``max_prints``; ``current_epoch`` is incremented.
        predictions: precomputed output sentences, or None to let
            ``evaluate_bleu`` generate them.
        inputs: list of input sentences, or a file path to read them from.
        ref: list of reference lists (one list per input), or a file path
            whose lines become single-element reference lists.

    Returns:
        ``(val_score, mean_self_bleu, mean_b_acc)``.

        * Corpus mode: ``val_score`` is ``self_bleu + b_acc``.
        * Sentence-wise mode: ``val_score`` is the mean over sentences of
          ``bleu * accuracy * fluency``.
    """
    if isinstance(inputs, str):
        inputs = read_file(inputs, params.test_data_fraction)
    if isinstance(ref, str):
        ref = [[r] for r in read_file(ref, params.test_data_fraction)]

    max_prints = params.max_prints if params.print_outputs else 0

    if not params.sentence_wise_evaluation:
        self_bleu, predictions = evaluate_bleu(model, inputs, ref, params.batch_size, max_prints,
                                               return_predictions=True, predictions=predictions)
        b_acc = eval_binary_accuracy(model, predictions, mode, params)

        _save_to_csv(params, self_bleu=self_bleu, b_acc=b_acc)

        val_score = self_bleu + b_acc
        mean_self_bleu, mean_b_acc = self_bleu, b_acc
    else:
        # One BLEU score and one accuracy value per datapoint.
        self_bleu, predictions = evaluate_bleu(model, inputs, ref, params.batch_size, max_prints,
                                               return_predictions=True, predictions=predictions,
                                               sentencewise=True)
        self_bleu = np.array(self_bleu)
        b_acc = np.array(eval_binary_accuracy(
            model, predictions, mode, params, sentencewise=True))
        fluency = np.array(eval_fluency(model, predictions, mode, params))

        val_score = (self_bleu * b_acc * fluency).mean()
        mean_self_bleu, mean_b_acc = self_bleu.mean(), b_acc.mean()

        _save_to_csv(params, self_bleu=mean_self_bleu, b_acc=mean_b_acc,
                     style_transfer_score=val_score)

    write_dev_set_predictions(params, inputs, predictions, [r[0] for r in ref])
    params.current_epoch = params.current_epoch + 1
    return val_score, mean_self_bleu, mean_b_acc


def eval_fluency(model, predictions, mode, params):
    """Placeholder fluency metric: score every prediction as fully fluent (1.0).

    ``model``, ``mode`` and ``params`` are ignored.
    """
    # TODO: how do we implement fluency?
    count = len(predictions)
    return [1.0 for _ in range(count)]


def write_dev_set_predictions(params, inputs, predictions, gold):
    """
    Writes the final predictions for dev and test-set

    Rows are tab-separated ``input``/``prediction``/``gold`` triples under a
    header line. Nothing is written unless ``params.current_epoch`` has
    reached ``params.n_epochs`` and ``params.save_dev_set_predictions`` is
    set. The evaluation at exactly ``n_epochs`` goes to
    ``<outputdir>/<run_id>.outputs.dev``; any later one goes to
    ``<outputdir>/<run_id>.outputs.test``.
    """
    prefix = os.path.join(params.outputdir, str(params.run_id) + ".outputs")

    past_last_epoch = params.current_epoch >= params.n_epochs
    if not (past_last_epoch and params.save_dev_set_predictions):
        return

    print("Writing out dev set predictions...")
    # The evaluation at exactly n_epochs is the final pass on the validation
    # set (where the best model is loaded); anything later is the test pass.
    suffix = ".dev" if params.current_epoch == params.n_epochs else ".test"
    with open(prefix + suffix, 'w') as outf:
        print("input\tprediction\tgold", file=outf)
        for idx, source in enumerate(inputs):
            print(source, predictions[idx], gold[idx], sep="\t", file=outf)


def evaluate_shakespeare(model, mode="valid", params=None, predictions=None):
    """Evaluate Shakespeare style transfer with self-BLEU and classifier accuracy.

    The input file ``../data/shakespeare/s1.<split>`` (``s2`` when
    ``params.invert_style`` is set) is used both as input and as reference.

    Returns:
        ``(val_score, self_bleu, b_acc)`` from ``evaluate_styletransfer``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``
            (previously an UnboundLocalError).
    """
    splits = {"valid": "dev", "test": "test"}
    if mode not in splits:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")
    prefix = "s2." if params.invert_style else "s1."
    data = "../data/shakespeare/" + prefix + splits[mode]

    # compute bleu with input: the input file doubles as the reference
    results = evaluate_styletransfer(
        model, mode, params, predictions, data, data)

    return results[0], results[1], results[2]


def evaluate_yelp(model, mode="valid", params=None, predictions=None):
    """Evaluate Yelp sentiment transfer with BLEU and classifier accuracy.

    Modes:

    * ``"valid"``: inputs are the first ``params.test_data_fraction`` of
      ``../data/yelp/s1.dev`` (``s2.dev`` when ``params.invert_style`` is set),
      and each input is its own reference (self-BLEU).
    * ``"test"``: inputs and human rewrites come from the tab-separated file
      ``../data/yelp/sentiment.test.0.human`` (``.1.human`` when inverted):
      column 1 is the input, column 2 the reference. Blank lines are skipped.

    Returns:
        ``(val_score, self_bleu, b_acc)`` from ``evaluate_styletransfer``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``
            (previously an UnboundLocalError).
    """
    if mode == "valid":
        data_path = "../data/yelp/s2.dev" if params.invert_style else "../data/yelp/s1.dev"
        inputs = read_file(data_path, params.test_data_fraction)
        ref = [[r] for r in inputs]
    elif mode == "test":
        if not params.invert_style:
            in_file = "../data/yelp/sentiment.test.0.human"
        else:
            in_file = "../data/yelp/sentiment.test.1.human"
        inputs, ref = [], []
        with open(in_file, "r", encoding="utf8") as sentiment_test:
            for line in sentiment_test:
                if not line.strip():
                    continue  # a blank line has no reference column
                fields = line.strip().split("\t")
                inputs.append(fields[0])
                ref.append([fields[1]])
    else:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")

    results = evaluate_styletransfer(
        model, mode, params, predictions, inputs, ref)

    return results[0], results[1], results[2]


def get_eval_function_by_eval_data(eval_root_path):
    """Build a style-transfer evaluator that reads its data from ``eval_root_path``.

    The returned ``eval_f(model, mode="valid", params=None, predictions=None)``
    uses ``<eval_root_path>/s1.<split>`` (``s2`` when ``params.invert_style``
    is set) both as input and as reference. ``<split>`` is ``dev`` for
    ``"valid"`` and ``test`` for ``"test"``. It returns
    ``(val_score, self_bleu, b_acc)`` from ``evaluate_styletransfer``, and
    raises ValueError for any other mode.
    """

    def eval_f(model, mode="valid", params=None, predictions=None):
        splits = {"valid": "dev", "test": "test"}
        if mode not in splits:
            raise ValueError("Unknown evaluation mode " + repr(mode) +
                             "; expected 'valid' or 'test'")
        prefix = "s2." if params.invert_style else "s1."
        data = os.path.join(eval_root_path, prefix + splits[mode])

        # compute bleu with input: the input file doubles as the reference
        results = evaluate_styletransfer(
            model, mode, params, predictions, data, data)

        return results[0], results[1], results[2]

    return eval_f


def _save_to_csv(params, b_acc=None, sari=None, bleu=None, self_bleu=None, rouge=None,
                 average_pred_lens=None,
                 in_output_distance_mean=None,
                 in_output_distance_min=None,
                 style_transfer_score=None):
    """Append the given metrics for ``params.current_epoch`` to the results CSV.

    Metrics that were not computed stay ``None``. When ``rouge`` is given,
    its entries are added as extra columns after the fixed ones.
    """
    metric_columns = [
        ("run_id", params.run_id),
        ("epoch", params.current_epoch),
        ("bleu", bleu),
        ("sari", sari),
        ("self-bleu", self_bleu),
        ("b-acc", b_acc),
        ("avg-pred-lens", average_pred_lens),
        ("in_out_mean", in_output_distance_mean),
        ("in_out_min", in_output_distance_min),
        ("style_transfer_score", style_transfer_score),
    ]
    scores = dict(metric_columns)
    if rouge is not None:
        scores.update(rouge)
    write_to_csv(scores, params)


def write_to_csv(score, opt, escaped_keys=("binary_classifier", "binary_tokenizer", "latent_binary_classifier", "critic", "critic_optimizer", "critic_loss", "real_data"), output_file_path=None):
    """
    Writes the scores and configuration to csv file.

    The file is ';'-separated. When it is new or empty, a header is written
    first: every attribute name of ``opt`` followed by every key of
    ``score``. Each call then appends one row that follows the column order
    of the header already present in the file.

    Args:
        score: dict of metric name -> value.
        opt: options object whose attributes become columns.
        escaped_keys: option names written as empty cells instead of their ``str``.
        output_file_path: target file; defaults to ``opt.output_file``.

    Columns in the header that neither ``opt`` nor ``score`` provide (for
    example metrics written by a different evaluation) are left empty rather
    than raising KeyError.
    """
    fpath = output_file_path if output_file_path is not None else opt.output_file
    options = vars(opt)

    with open(fpath, 'a') as f:
        if os.stat(fpath).st_size == 0:
            header = "".join(key + ";" for key in options) + ";".join(score)
            f.write(header + '\n')

    with open(fpath, 'r') as f:
        column_names = next(csv.reader(f, delimiter=";"))

    def cell(key):
        if key in options:
            if key in escaped_keys:
                return ""
            # A newline inside a value would break the one-row-per-line layout.
            return str(options[key]).replace("\n", "")
        return str(score.get(key, ""))

    with open(fpath, 'a') as f:
        f.write(";".join(cell(key) for key in column_names) + '\n')


def _model_name_from_params(params):
    return ";".join([p + ":" + str(getattr(params, p)) for p in ["emb2emb", "adversarial_regularization", "adversarial_reconstruction_weight", "lambda_clfloss", "real_data_path", "lambda_schedule"]])


def eval_binary_accuracy(model, predictions, mode="valid", params=None, sentencewise=False):
    """Measure how often ``predictions`` are classified as the desired style.

    The classifier is selected by ``params.binary_classifier``:

    * a ``transformers`` sequence classifier (with ``params.binary_tokenizer``):
      a sentence is correct when the softmax probability of class ``target``
      exceeds 0.5, where ``target`` is 0 if ``params.invert_style`` else 1;
    * ``-1``: evaluation is disabled and ``0.`` is returned;
    * ``None``: ``model.loss_fn.classifier`` is applied to the outputs of
      ``model.compute_emb2emb``, and a sentence is correct when its sigmoid
      score is below 0.5 (class "0").

    Args:
        model: the emb2emb model; only used by the ``None`` (latent) path.
        predictions: list of generated sentences.
        mode: unused; kept for interface compatibility.
        params: options; reads ``binary_classifier``, ``binary_tokenizer``,
            ``invert_style``, ``batch_size`` and ``emb2emb``.
        sentencewise: transformers path only. Return one bool per sentence
            instead of the accuracy; the latent path ignores it.

    Returns:
        Accuracy in ``[0, 1]``, a list of per-sentence bools, or ``0.``
        when evaluation is disabled.

    Raises:
        ZeroDivisionError: if ``predictions`` is empty and an accuracy is computed.
    """
    if params.binary_classifier is None:
        return _latent_binary_accuracy(model, predictions, params)
    if params.binary_classifier == -1:
        return 0.  # do not eval binary accuracy
    return _hf_binary_accuracy(predictions, params, sentencewise)


def _hf_binary_accuracy(predictions, params, sentencewise):
    """Classify ``predictions`` with the transformers classifier stored on ``params``."""
    target = 0 if params.invert_style else 1
    tokenizer = params.binary_tokenizer
    classifier = params.binary_classifier
    classifier.eval()

    correct = 0.
    per_sentence = []
    # Pure inference: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for stidx in range(0, len(predictions), params.batch_size):
            batch = predictions[stidx:(stidx + params.batch_size)]
            encoded = tokenizer.batch_encode_plus(
                batch, return_tensors="pt", pad_to_max_length=True,
                max_length=512)
            # the classifier returns a tuple whose first element is the logits
            logits = classifier(**encoded)[0]
            target_prob = torch.softmax(logits, dim=1)[:, target]
            hits = target_prob > 0.5
            if sentencewise:
                per_sentence.extend(hits.cpu().numpy().tolist())
            correct = correct + hits.sum().item()

    if sentencewise:
        return per_sentence
    return correct / float(len(predictions))


def _latent_binary_accuracy(model, predictions, params):
    """Fraction of ``predictions`` that ``model.loss_fn.classifier`` scores below 0.5 after mapping."""
    model.eval()
    binary_classifier = model.loss_fn.classifier

    batch_size = params.batch_size
    correct = 0
    for stidx in range(0, len(predictions), batch_size):
        Sx_batch = predictions[stidx:stidx + batch_size]
        if params.emb2emb == "bovtobov":
            clf_predictions = model.compute_emb2emb(Sx_batch)
            preds, preds_len = clf_predictions[0], clf_predictions[2]
            clf_predictions = binary_classifier(preds, preds_len)
            # second entry is the output
            clf_predictions = sigmoid(clf_predictions[..., 1])
        else:
            clf_predictions = model.compute_emb2emb(Sx_batch)[0]
            clf_predictions = sigmoid(binary_classifier(clf_predictions))

        # we want to generate from the "fake distribution" labeled "0"
        correct = correct + (clf_predictions < 0.5).sum().item()

    return correct / float(len(predictions))


def evaluate_bleu_nli(model, mode="valid", params=None):
    """Compute corpus BLEU of ``model`` on the NLI dev or test split.

    Args:
        model: model under evaluation (passed to ``evaluate_bleu``).
        mode: ``"valid"`` (reads the ``*.dev`` files) or ``"test"`` (``*.test``).
        params: options; reads ``test_data_fraction``, ``batch_size``,
            ``print_outputs`` and ``max_prints``.

    Returns:
        The corpus BLEU score returned by ``evaluate_bleu``.

    Raises:
        ValueError: if ``mode`` is neither ``"valid"`` nor ``"test"``
            (previously an UnboundLocalError).
    """
    splits = {"valid": "dev", "test": "test"}
    if mode not in splits:
        raise ValueError("Unknown evaluation mode " + repr(mode) +
                         "; expected 'valid' or 'test'")
    split = splits[mode]

    inputs = read_file("../data/nli/s1." + split, params.test_data_fraction)
    ref = read_file("../data/nli/s2." + split, params.test_data_fraction)
    ref = [[r] for r in ref]  # one reference per input

    max_prints = params.max_prints if params.print_outputs else 0
    return evaluate_bleu(model, inputs, ref, params.batch_size, max_prints)


def bleu_tokenize(s):
    """Split ``s`` on runs of whitespace into BLEU tokens."""
    tokens = s.split()
    return tokens


def evaluate_bleu(model, input_sentences, reference_sentences, batch_size, max_prints, return_predictions=False, predictions=None, sentencewise=False):
    """Score model outputs against references with NLTK BLEU (smoothing ``method1``).

    Args:
        model: model under evaluation; put into eval mode, and used to
            generate outputs when ``predictions`` is None.
        input_sentences: source sentences (only used for generation/printing).
        reference_sentences: one list of reference strings per input.
        batch_size: generation batch size.
        max_prints: number of examples ``_get_predictions`` prints.
        return_predictions: also return the hypotheses that were scored.
        predictions: precomputed hypotheses; skips generation when given.
        sentencewise: return a list of per-sentence ``sentence_bleu`` scores
            instead of one ``corpus_bleu`` score.

    Texts are tokenized with ``bleu_tokenize`` (whitespace split).

    Returns:
        ``score`` or ``(score, hypotheses)`` when ``return_predictions``.
    """
    model.eval()

    if predictions is not None:
        pred_outputs = predictions
    else:
        pred_outputs = _get_predictions(
            model, input_sentences, reference_sentences, batch_size, max_prints)

    # NLTK wants a list (per example) of lists (references) of token lists,
    # and one token list per hypothesis.
    tokenized_refs = [[bleu_tokenize(r) for r in refs] for refs in reference_sentences]
    tokenized_hyps = [bleu_tokenize(h) for h in pred_outputs]
    smoothing = SmoothingFunction().method1

    if sentencewise:
        score = [sentence_bleu(tokenized_refs[idx], hypothesis=tokenized_hyps[idx],
                               smoothing_function=smoothing)
                 for idx in range(len(tokenized_refs))]
    else:
        score = corpus_bleu(tokenized_refs, tokenized_hyps,
                            smoothing_function=smoothing)

    return (score, pred_outputs) if return_predictions else score


def _get_predictions(model, input_sentences, reference_sentences, batch_size, max_prints):
    """Run ``model`` over ``input_sentences`` in batches (under ``torch.no_grad``) and collect its outputs.

    Args:
        model: callable as ``model(inputs, references)`` returning output strings.
        input_sentences: list of source sentences.
        reference_sentences: list of reference lists, one per input. Each
            batch's lists are flattened and passed to the model.
        batch_size: number of inputs per forward call.
        max_prints: how many examples to pretty-print (raw outputs).

    Returns:
        The outputs, each cut off at its first ``"EOS"`` substring.
    """
    model.eval()

    n_inputs = len(input_sentences)
    outputs = []
    for batch_no, start in enumerate(range(0, n_inputs, batch_size)):
        if batch_no % 10 == 0:
            print("Eval progress:", float(start) / n_inputs)

        end = start + batch_size
        source_batch = input_sentences[start:end]
        target_batch = flatten(reference_sentences[start:end])
        with torch.no_grad():
            outputs.extend(model(source_batch, target_batch))

    for k in range(min(n_inputs, max_prints)):
        pretty_print_prediction(
            input_sentences[k], reference_sentences[k][0], outputs[k])

    # FIXME: quickfix to remove wrong EOS
    return [out.partition("EOS")[0] for out in outputs]


def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into a single list."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat


def write_predictions(path, inputs, references, pred_outputs):
    """Write (input, prediction, gold) rows to ``path`` as a ';'-separated CSV.

    Fields that need quoting are quoted with ``|``. Rows stop at the shortest
    of the three sequences (they are combined with ``zip``).
    """
    # newline='' lets the csv module control line endings; without it each
    # row ends in '\r\r\n' on Windows (rendered as blank lines).
    with open(path, 'w', newline='', encoding="utf8") as csvfile:
        outwriter = csv.writer(csvfile, delimiter=';', quotechar='|')
        outwriter.writerow(['input', 'prediction', 'gold'])
        outwriter.writerows([inp, pred, reference]
                            for inp, reference, pred in zip(inputs, references, pred_outputs))


def evaluate_rouge(model, input_sentences, reference_sentences, batch_size, max_prints):
    """Generate predictions for ``input_sentences`` and score them with ROUGE (with stemming).

    Args:
        model: model under evaluation (see ``_get_predictions``).
        input_sentences: list of source texts.
        reference_sentences: list of reference texts, one per input.
        batch_size: generation batch size.
        max_prints: number of examples to pretty-print.

    Returns:
        ``(rouge_scores, pred_outputs)``. ``rouge_scores`` maps ``rougeL``,
        ``rouge1`` ... ``rouge4`` (in that order) to the mean over all examples
        of that type's ``(precision, recall, fmeasure)`` triple. This is one
        number averaging all three fields, not the F-measure alone.
    """
    pred_outputs = _get_predictions(model, input_sentences, [
                                    [r] for r in reference_sentences], batch_size, max_prints)

    rouge_types = ['rougeL', 'rouge1', 'rouge2', 'rouge3', 'rouge4']
    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)

    # RougeScorer.score takes (target, prediction): the reference is the target.
    # (The previous order swapped precision and recall; the averaged value
    # below is symmetric in the two, so the reported numbers are unchanged.)
    per_example = [scorer.score(reference, prediction)
                   for prediction, reference in zip(pred_outputs, reference_sentences)]

    # NOTE(review): each entry is a Score(precision, recall, fmeasure) tuple, so
    # the mean below averages all three — confirm this is the intended metric.
    rouge_scores = {rouge_type: np.array([s[rouge_type] for s in per_example]).mean()
                    for rouge_type in rouge_types}
    return rouge_scores, pred_outputs


def evaluate_wiki(model, mode="valid", params=None):
    """Evaluate text simplification on the WikiLarge split and log the metrics.

    Metrics:

    * SARI;
    * binary-classifier accuracy;
    * BLEU against the Turk references;
    * self-BLEU against the normal sentences when ``params.eval_self_bleu``
      is set, otherwise ``-1.``.

    Appends them to the results CSV and increments ``params.current_epoch``.

    Returns:
        ``(sari, sari, b_acc)``.
    """
    sari, predictions = evaluate_sari(model, mode, params)
    b_acc = eval_binary_accuracy(model, predictions, mode, params)

    turk_refs, norm_sentences, _ = _load_wikilarge_references(mode)
    bleu = evaluate_bleu(model, norm_sentences, turk_refs, params.batch_size,
                         max_prints=0, return_predictions=False, predictions=predictions)

    self_bleu = -1.
    if params.eval_self_bleu:
        identity_refs = [[sentence] for sentence in norm_sentences]
        self_bleu = evaluate_bleu(model, norm_sentences, identity_refs, params.batch_size,
                                  max_prints=0, return_predictions=False, predictions=predictions)

    _save_to_csv(params, b_acc=b_acc, sari=sari,
                 bleu=bleu, self_bleu=self_bleu)
    params.current_epoch += 1

    return sari, sari, b_acc


def _load_wikilarge_references(mode):
    if mode == "valid":
        base_path = "../data/simplification/valid/"
    elif mode == "test":
        base_path = "../data/simplification/test/"

    norm_sentences = read_all(join(base_path, "norm"))
    simp_sentences = read_all(join(base_path, "simp"))

    reference_sentences_sep = [
        read_all(join(base_path, "turk" + str(i))) for i in range(8)]
    reference_sentences = []
    for i in range(len(reference_sentences_sep[0])):
        reference_sentences.append(
            [reference_sentences_sep[j][i] for j in range(8)])

    return reference_sentences, norm_sentences, simp_sentences


def evaluate_sari(model, mode="valid", params=None):
    """Generate simplifications for the WikiLarge split and score them with SARI.

    The SARI of the copy baseline (output == input) is computed and printed
    for comparison.

    Returns:
        ``(mean_sari, predictions)``.
    """
    batch_size = params.batch_size

    model.eval()

    reference_sentences, norm_sentences, simp_sentences = _load_wikilarge_references(
        mode)

    predictions = []
    for start in range(0, len(norm_sentences), batch_size):
        end = start + batch_size
        with torch.no_grad():
            predictions.extend(model(norm_sentences[start:end], simp_sentences[start:end]))

    # SARI of simply copying the input, printed as a reference point.
    copy_baseline = _calc_sari(
        norm_sentences, norm_sentences, reference_sentences, params)
    model_sari = _calc_sari(
        norm_sentences, predictions, reference_sentences, params)
    print("Text Simplification Copy-Baseline:", copy_baseline)
    return model_sari, predictions


def _calc_sari(norm_sentences, pred_simple_sentences, reference_sentences, params):
    sari_scores = []
    for i, (n, s, rs) in enumerate(zip(norm_sentences, pred_simple_sentences, reference_sentences)):

        sari_scores.append(SARIsent(n, s, rs))
        if params.print_outputs and i < params.max_prints:
            pretty_print_prediction(n, rs[0], s)

    return np.array(sari_scores).mean()


def _get_data_pairs(params):
    """
    The dataset is assumed to be given as a directory containing
    the files 's1' (input sequence) and 's2' (output sequence) for each of the
    data splits, i.e. 's1.train', 's1.dev', 's1.test', and 's2.train', 's2.dev',
    's2.test'.
    Each file contains one text per line.

    The train split is truncated to ``params.data_fraction`` and dev/test to
    ``params.test_data_fraction``. With ``params.invert_style`` the roles of
    ``s1`` and ``s2`` are swapped.

    Returns:
        ``(train, dev, test)``: dicts with keys ``"Sx"`` (input) and ``"Sy"`` (output).
    """
    dataset_path = params.dataset_path

    def load_split(split):
        fraction = params.data_fraction if split == "train" else params.test_data_fraction
        first = read_file(join(dataset_path, "s1." + split), fraction)
        second = read_file(join(dataset_path, "s2." + split), fraction)
        if params.invert_style:
            first, second = second, first
        return {"Sx": first, "Sy": second}

    return tuple(load_split(split) for split in ("train", "dev", "test"))


def fast_gradient_iterative_modification(inputs, binary_classifier, weights=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0),
                                         decay_factor=0.9,
                                         t=0.001,
                                         max_steps=30,
                                         target=0, custom_loss=None, Y_embeddings=None, start_at_y=False):
    """Move embeddings by repeated gradient steps until ``binary_classifier`` assigns them ``target``.

    Several step sizes are tried in parallel: the starting point is replicated
    once per entry of ``weights`` and each replica is updated with
    ``x <- x - w * grad``. In every iteration the step sizes are first
    multiplied by ``decay_factor`` and set to zero for replicas whose
    sigmoid output is already within ``t`` of ``target``, freezing them.
    The loop stops once every example has at least one such replica, or
    when ``max_steps`` iterations have run.

    Args:
        inputs: 2-D embeddings, ``(batch, dim)``.
        binary_classifier: module applied to the replicated embeddings. Its
            output is treated as logits and squeezed, so presumably one logit
            per example — TODO confirm.
        weights: candidate step sizes; must have at least 2 entries, and the
            batch at least 2 examples, because of the ``squeeze`` above.
        decay_factor: per-iteration multiplicative step-size decay.
        t: tolerance on ``|target - sigmoid(logit)|`` that counts as success.
        max_steps: maximum number of iterations.
        target: desired label, 0 or 1.
        custom_loss: optional ``loss(x, y)`` (or a callable returning a tuple
            whose first item is the loss) used instead of the classifier BCE.
        Y_embeddings: target embeddings of the same shape as ``inputs``;
            required with ``custom_loss`` or ``start_at_y``.
        start_at_y: start from ``Y_embeddings`` instead of ``inputs``.

    Returns:
        ``(batch, dim)`` tensor taking, for each example, the first successful
        replica (in ``weights`` order). If an example never succeeded, the
        index is whatever ``max`` returns for an all-false column.

    Raises:
        ValueError: if ``custom_loss`` or ``start_at_y`` is used without ``Y_embeddings``.
    """
    if (custom_loss is not None or start_at_y) and Y_embeddings is None:
        raise ValueError(
            "Y_embeddings is required when custom_loss or start_at_y is used")

    def first_nonzero(x, axis=0):
        # index of the first positive entry along ``axis``
        nonz = (x > 0)
        return ((nonz.cumsum(axis) == 1) & nonz).max(axis)

    with torch.enable_grad():

        loss_f = BCEWithLogitsLoss()
        weights = torch.tensor(np.array(weights), device=inputs.device).float()

        if start_at_y:
            x = Y_embeddings.detach().clone()
        else:
            x = inputs.detach().clone()

        # one replica per candidate step size: (len(weights), batch, dim)
        x = x.unsqueeze(0)
        x = x.repeat(len(weights), 1, 1)
        if Y_embeddings is not None:
            y = Y_embeddings.unsqueeze(0)
            y = y.repeat(len(weights), 1, 1)

        # Never updated: replicas are frozen by zeroing their step size instead.
        finished = torch.zeros(x.size(0), x.size(1), device=inputs.device)
        weights = weights.unsqueeze(1).unsqueeze(1)
        cnt = 0

        while True:
            cnt = cnt + 1
            x = x.detach().clone()
            x.requires_grad = True
            logits = binary_classifier(x)
            preds = torch.sigmoid(logits).squeeze()
            correct_classification = (
                torch.abs(target - preds) <= t) * (1 - finished)

            if (correct_classification.sum(dim=0) >= 1).all() or cnt == max_steps:
                any_nonz, first_success = first_nonzero(
                    correct_classification, axis=0)
                break

            if custom_loss is None:

                binary_classifier.zero_grad()
                desired_label = torch.ones_like(
                    preds, device=inputs.device) if target == 1 else torch.zeros_like(preds, device=inputs.device)
                # BCEWithLogitsLoss applies the sigmoid itself, so it must be fed
                # the raw logits (feeding ``preds`` applied the sigmoid twice and
                # made the gradient vanish for confidently wrong examples).
                loss = loss_f(logits.squeeze(), desired_label)
            else:
                loss = custom_loss(x.view(-1, x.size(2)),
                                   y.view(-1, y.size(2)))
                if type(loss) == tuple:
                    loss = loss[0]

            # decay step sizes, and stop moving replicas that already succeed
            weights = weights * decay_factor * \
                (1 - correct_classification.float().unsqueeze(2))
            loss.sum().backward()
            x = x - weights * x.grad

        # pick, per example, the replica selected by first_success
        x = x.gather(0, first_success.view(1, -1, 1).expand_as(x))[0, :, :]
        return x


def main(binary_classifier_path="/anon/temp/anon/distilyelp/checkpoint-20000/",
         predictions_path="/anon/temp/anon/shen/language-style-transfer/tmp/sentiment.dev.0.tsf"):
    """Score a file of externally produced Yelp dev-set transfers.

    Each line of ``predictions_path`` is one prediction. They are evaluated
    with ``evaluate_yelp`` in ``"valid"`` mode, using the transformers
    classifier at ``binary_classifier_path``; results are appended to
    ``shen.csv``.
    """
    with open(predictions_path, 'r', encoding="utf8") as f:
        predictions = [line.strip() for line in f]
    binary_tokenizer = AutoTokenizer.from_pretrained(binary_classifier_path)
    binary_classifier = AutoModelForSequenceClassification.from_pretrained(
        binary_classifier_path)

    params = Namespace(binary_classifier=binary_classifier, binary_tokenizer=binary_tokenizer, batch_size=64,
                       max_prints=20, output_file="shen.csv", data_fraction=1.0, print_outputs=True, run_id=1337, current_epoch=0,
                       # Read by evaluate_yelp / evaluate_styletransfer /
                       # write_dev_set_predictions; previously missing, which
                       # made this script fail with AttributeError.
                       invert_style=False, test_data_fraction=1.0, sentence_wise_evaluation=False,
                       outputdir=".", n_epochs=1, save_dev_set_predictions=False)
    evaluate_yelp(binary_classifier, mode="valid",
                  params=params, predictions=predictions)


# Running this module as a script calls main().
if __name__ == "__main__":
    main()
