from torch.utils.data import DataLoader
import random
import math
from sentence_transformers import models, losses
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import logging
from datetime import datetime
import sys
import re
import os

print("Loading the models...")

# Parameters
# The base model is taken from the first command-line argument when one is given;
# otherwise it falls back to 'gsarti/biobert-nli'.
# (Alternatives noted by the author: 'emilyalsentzer/Bio_ClinicalBERT', 'bert-base-uncased'.)
if len(sys.argv) > 1:
    model_name = sys.argv[1]
else:
    model_name = 'gsarti/biobert-nli'

# Turn the model name/path into a label usable inside a directory name:
# drop the local output prefix, the surrounding slashes and the '/0_Transformer'
# sub-folder, then replace the remaining slashes by dashes.
friendly_model_name = model_name.replace('/home/fremy/Downloads/azdelta/output/', '')
friendly_model_name = friendly_model_name.strip('/')
friendly_model_name = friendly_model_name.replace('/0_Transformer', '')
friendly_model_name = friendly_model_name.replace("/", "-")

# Each run is saved in its own timestamped folder
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = 'output/defs3ft_{}-{}'.format(friendly_model_name, run_timestamp)
train_batch_size = 64

# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings,
# followed by mean pooling over the token embeddings (CLS and max pooling are disabled)
word_embedding_model = models.Transformer(model_name, max_seq_length=128)
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(
    embedding_dim,
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

print("Reading the data...")

# Matches standalone all-caps tokens of two or more letters (e.g. "DNA").
# Compiled once because it is applied to every line of every data file.
_ABBREVIATION_RE = re.compile(r'\b[A-Z][A-Z]+\b')


def _dot_abbreviations(text):
    """Insert a dot after every letter of each all-caps token in `text`.

    Example: "DNA repair" becomes "D.N.A. repair".

    Returns:
        A `(new_text, match_count)` tuple; `match_count` is 0 when `text`
        contains no all-caps token (and `new_text` is then equal to `text`).
    """
    return _ABBREVIATION_RE.subn(lambda match: ".".join(match.group()) + '.', text)


def _append_contrast_pairs(samples, filename, max_samples=None):
    """Read tab-separated text pairs from `filename` and append them to `samples`.

    Each line contributes one `InputExample` (label 1) built from its first two
    tab-separated fields. When the first field contains an all-caps token, a
    second example using the dotted-abbreviation variant of that field is
    appended as well.

    Args:
        samples: list extended in place with the new examples.
        filename: path of the tab-separated data file.
        max_samples: when not None, reading stops as soon as `samples` holds
            more than this many examples (checked after each line).

    Raises:
        IndexError: if a line has fewer than two tab-separated fields.
    """
    # The `with` block guarantees the file handle is closed, even on error.
    # NOTE(review): the file is decoded with the platform default encoding;
    # consider passing an explicit encoding once the data encoding is confirmed.
    with open(filename) as data_file:
        for line in data_file:

            # Split the data file into components
            line_data = line.rstrip('\r\n').split('\t')
            w1 = line_data[0]
            w2 = line_data[1]

            # Add the base training example
            samples.append(InputExample(texts=[w1, w2], label=1))

            # Add the abbreviation augmented example
            w1_abb, abbreviation_count = _dot_abbreviations(w1)
            if abbreviation_count:
                samples.append(InputExample(texts=[w1_abb, w2], label=1))

            if max_samples is not None and len(samples) > max_samples:
                break


train_samples = []

_append_contrast_pairs(train_samples, './data/umls/defs_contrast.txt')
_append_contrast_pairs(train_samples, './data/umls/defs_contrast_2.txt')
# The size cap is only enforced while reading the relations file
_append_contrast_pairs(train_samples, './data/umls/rels_contrast.txt', max_samples=100 * 1000 * 1000)
    
print("Starting the training...")

# Use a standard contrastive learning strategy:
# shuffled batches of training pairs scored by a multiple-negatives ranking loss
train_dataloader = DataLoader(
    train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# Training schedule: a single pass over the data,
# with the first 5% of the batches used for warm-up
num_epochs = 1
batches_per_epoch = len(train_dataloader)
warmup_steps = math.ceil(batches_per_epoch * 0.05)

# We do not evaluate because the training loss is as good as a development loss given we only see each sample once
def dummy_evaluator(model, output_path='', epoch=0, steps=0):
    """Placeholder evaluator that derives a score from training progress only.

    No evaluation is actually performed: `model` and `output_path` are
    accepted but ignored.

    Returns:
        `epoch` plus `steps` divided by one million, so the value grows with
        `steps` for a given `epoch`.

    NOTE(review): presumably `fit` uses this value to decide which checkpoint
    to keep, so an increasing score keeps the most recent one; confirm
    against the sentence-transformers documentation.
    """
    step_fraction = steps / 1000000.0
    return epoch + step_fraction

# Launch the fine-tuning run; every argument not listed here keeps its default value.
# The dummy evaluator is invoked every 1000 training steps and mixed precision is enabled.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dummy_evaluator,
    evaluation_steps=1000,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    use_amp=True,
    output_path=model_save_path,
)