import numpy as np
from typing import Optional
import torch
import csv
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import tensorflow as tf
import random
import numpy as np

# --- Augmentation settings -------------------------------------------------
dimension = 36  # maximum number of tokens kept per sentence
hidden = 768  # length of one token embedding vector stored in the buffers
low = 0.3  # bart_score_low: similarity must be strictly above this
high = 0.7  # bart_score_high: similarity must be strictly below this
accept_prob = high  # chance that a token with valid candidates gets replaced

# --- Pretrained BART -------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-base')
model.eval()
# Encoder token-embedding table, indexed by token id.
encoder_embed = model.state_dict()['model.encoder.embed_tokens.weight']



# Open the archive a single time so the pickled payload is deserialized once
# and the underlying file handle is closed as soon as the splits are read.
# NOTE: allow_pickle=True runs arbitrary pickled code on load; only point this
# at data files you trust.
with np.load('data-augmentation-for-sts/data/low_data.npz', allow_pickle=True) as _archive:
    _splits = _archive['store'][()]  # 0-d object array -> the stored dict

split_data = _splits['train']
dev_data = _splits['dev']
test_data = _splits['test']
print("Example of data:", split_data[0])

# Embedding buffer indexed as [sentence, embedding component, token position].
# Positions past the end of a sentence are left at zero.
train_embed = np.zeros((len(split_data), hidden, dimension))
for i, example in enumerate(split_data):
    # Encode the sentence text, drop the first and last ids (the start- and
    # end-of-sentence markers), then keep at most `dimension` tokens.
    token_ids = tokenizer.encode(example[0], add_special_tokens=True)[1:-1][:dimension]
    for position, token_id in enumerate(token_ids):
        train_embed[i, :, position] = encoder_embed[token_id]

# L2-normalize every row of the embedding table so that a dot product with
# another unit vector is a cosine similarity.
candidates = tf.linalg.l2_normalize(encoder_embed.numpy(), axis=-1)
normalized_dictionary = candidates.numpy()

syn_data = []


def _has_space_prefix(fragment):
    """Return True if the token text starts with a space or the "Ġ" marker.

    Tokens from ``tokenizer.tokenize`` carry "Ġ" where decoded text carries a
    real space, so both spellings are treated as the same thing.
    """
    return fragment.startswith((" ", "Ġ"))


def _starts_capitalized(fragment):
    """Return True if the first character after any space marker is uppercase.

    The marker must be skipped first: "Ġ" is itself an uppercase letter as far
    as ``str.isupper`` is concerned, so testing ``fragment[0]`` directly would
    report every space-prefixed token as capitalized.
    """
    word = fragment[1:] if _has_space_prefix(fragment) else fragment
    return word[:1].isupper()


# One synthetic sentence is produced per training example and per pass.
for _ in range(1):  # 9 (low-resource) or 2 (half) or 1 (all)
    for i in range(len(split_data)):
        example = split_data[i]
        tokenized_input = tokenizer.tokenize(example[0])
        embedding = train_embed[i]

        # Unused positions are all-zero, so the first zero in row 0 gives the
        # number of stored tokens; a completely filled buffer has no zero.
        empty_positions = np.where(embedding[0] == 0)[0]
        target_len = empty_positions[0] if len(empty_positions) != 0 else dimension

        # One unit-length embedding per row, one row per stored token.
        input_data = np.transpose(embedding)[:target_len, :]
        input_data = tf.linalg.l2_normalize(input_data, axis=-1)
        # Dot product of each sentence token with every vocabulary entry.
        similarity = tf.linalg.matmul(input_data, normalized_dictionary, transpose_b=True)
        # For each token, the 50 most similar vocabulary ids, best first.
        indices = tf.argsort(similarity, axis=-1, direction='DESCENDING')[:, :50].numpy()
        results = similarity.numpy()

        synthetic_tokens = []
        for piece_id in range(target_len):
            old_fragment = tokenized_input[piece_id]
            old_has_space = _has_space_prefix(old_fragment)
            old_capitalized = _starts_capitalized(old_fragment)

            # Collect replacements for this position. `indices[piece_id]` is
            # already ordered by descending similarity, so the resulting list
            # is too and needs no further sorting.
            replacements = []
            for token_index in indices[piece_id]:
                score = results[piece_id, token_index]
                # Only keep neighbours inside the open (low, high) window.
                if not low < score < high:
                    continue
                token = tokenizer.decode(token_index, skip_special_tokens=True)
                # Never "replace" a token with itself.
                if old_fragment == token and len(token.strip()) > 0:
                    continue
                # Old and new token must agree on leading whitespace and on
                # the capitalization of their first letter.
                if _has_space_prefix(token) != old_has_space:
                    continue
                if _starts_capitalized(token) != old_capitalized:
                    continue
                replacements.append(token)

            # Keep the original token when nothing qualifies, or with
            # probability 1 - accept_prob even when something does.
            if not replacements or random.random() > accept_prob:
                synthetic_tokens.append(old_fragment)
                continue

            # Walk the ranked list, taking each entry with probability 0.5;
            # fall back to the best-ranked one if every draw fails.
            chosen_candidate = next(
                (option for option in replacements if random.random() >= 0.5),
                replacements[0],
            )
            synthetic_tokens.append(chosen_candidate)

        synthetic_text = tokenizer.convert_tokens_to_string(synthetic_tokens)
        # Carry the remaining two fields of the example over unchanged.
        syn_data.append([synthetic_text, example[1], example[2]])

# Final training set: the synthetic rows followed by the original rows.
syn_data = syn_data + split_data
print("size of train data:", len(split_data))
print("size of augmented data:", len(syn_data))

# Write the augmented train split next to the untouched dev/test splits.
store = {
    'train': syn_data,
    'dev': dev_data,
    'test': test_data,
}
np.savez('data-augmentation-for-sts/low_data_bart.npz', store=store)