"""
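This example trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) for the STSbenchmark from scratch. It generates sentence embeddings
that can be compared using cosine-similarity to measure the similarity.
Usage (all four positional arguments are read from sys.argv):
python <this_script>.py exp_i folder_name ckpt_fp model_type
where exp_i is the experiment id (e.g. 0), folder_name is the output sub-folder,
ckpt_fp is the checkpoint used to initialise the model, and model_type is "bert" or "roberta".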
"""

### This code is adapted from the Sentence-BERT project; please see
## https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py
## We slightly revised some settings on top of it.



import csv
import gzip
import json
import logging
import math
import os
import random
import shutil
import sys
import time
from datetime import datetime
from random import sample

import torch
from sentence_transformers import SentenceTransformer,  LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader

def reinit(model, ckpt_fp):
    """Initialise ``model`` from a checkpoint (our trained GR-BERT).

    Checkpoint entries whose key contains ``"sent_encoder"`` are copied into
    the model: the ``"sent_encoder."`` part of the key is removed and
    ``"0.auto_model."`` is prepended. All other checkpoint entries are
    ignored, and model weights without a matching entry keep their values.

    Args:
        model: object providing ``state_dict()`` and ``load_state_dict()``.
        ckpt_fp: path of the checkpoint file. If it is ``None``, blank, or
            does not exist, the model is returned untouched.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AssertionError: if a remapped key is not part of the model's
            state dict.
    """
    checkpoint_unavailable = (
        ckpt_fp is None
        or not ckpt_fp.strip()
        or not os.path.exists(ckpt_fp)
    )
    if checkpoint_unavailable:
        print("no pretrain model in init")
        return model

    print(ckpt_fp)
    # Always load onto the CPU, independent of the device the checkpoint was saved from.
    checkpoint = torch.load(ckpt_fp, map_location=torch.device('cpu'))
    current_weights = model.state_dict()

    remapped = {}
    for src_name, tensor in checkpoint.items():
        if "sent_encoder" in src_name:
            dst_name = "0.auto_model." + src_name.replace("sent_encoder.", "")
            assert dst_name in current_weights
            remapped[dst_name] = tensor

    # Report how many model entries exist and how many the checkpoint replaces.
    print(len(current_weights), len(remapped))
    current_weights.update(remapped)
    model.load_state_dict(current_weights)
    return model

### Data-sample ids; experiments are run on 4 seeds.
### Used in the low-resource case: we randomly shuffle the data and pick the top X% of it for training and evaluation.
def generate_random_ids(data_exps, out_fp, train_samples, dev_samples):
    """Write ``data_exps`` random orderings of the train/dev indices to a JSON file.

    For every experiment ``i`` in ``range(data_exps)`` the index lists
    ``0 .. len(train_samples) - 1`` and ``0 .. len(dev_samples) - 1`` are
    shuffled with the global ``random`` generator and stored as
    ``{"train": [...], "dev": [...]}`` under key ``i``. Because JSON object
    keys are strings, the file must be read back with ``str(i)`` as the key.

    Args:
        data_exps: number of experiments (orderings) to generate.
        out_fp: path of the JSON file to write; an existing file is replaced.
        train_samples: sized collection of training samples (only its length is used).
        dev_samples: sized collection of dev samples (only its length is used).

    Returns:
        None.
    """
    train_ids = list(range(len(train_samples)))
    dev_ids = list(range(len(dev_samples)))

    exp_ids = {}
    for exp in range(data_exps):
        random.shuffle(train_ids)
        random.shuffle(dev_ids)
        # Store snapshots: the working lists are shuffled in place again on the
        # next iteration, so storing the lists themselves would make every
        # experiment end up with the same (last) ordering.
        exp_ids[exp] = {"train": list(train_ids), "dev": list(dev_ids)}

    # Drop any previous file first so the result is always a fresh file.
    if os.path.exists(out_fp):
        os.remove(out_fp)
    with open(out_fp, "w") as f:
        f.write(json.dumps(exp_ids))
    print("data shuffle")
    return

# Wall-clock start of the run; the elapsed time is printed at the end of the script.
start = time.time()

#### Logging setup: INFO level, timestamped messages, handled by LoggingHandler
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
#### /logging setup





## Setup: dataset location, output folder, hyper-parameters and command-line arguments
# Check if the dataset exists. If not, download it.
# The place where you put the STS-B data:
sts_dataset_path = 'XXX/data/stsbenchmark.tsv.gz'
if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)
# The folder where you want to store the training output:
out_params_folder="XXX/sentence_transformer/params/"


### define the hyper-parameters here
init_model=True     ## when false, we use the original pretrain params in BERT and RoBERTa
train_batch_size = 16
num_epochs = 4
train_data_ratio=1  ## you can choose 0.5, 0.2, etc.

# All four positional arguments are required; fail with a readable usage
# message instead of an IndexError when one is missing.
if len(sys.argv) < 5:
    sys.exit("usage: python %s exp_i folder_name ckpt_fp model_type" % sys.argv[0])

# exp_i:       experiment id, selects the sample ordering used when train_data_ratio < 1 (e.g. 0)
# folder_name: name of the sub-folder in out_params_folder that stores the training output
# ckpt_fp:     checkpoint used to init the model
# model_type:  "bert" or "roberta"
exp_i, folder_name, ckpt_fp, model_type = (arg.strip() for arg in sys.argv[1:5])



###====================
### PART 1: define the model
# Map the command-line model type to the name of the pretrained transformer.
_MODEL_NAMES = {"bert": 'bert-base-uncased', "roberta": "roberta-base"}
if model_type not in _MODEL_NAMES:
    # Explicit error instead of `assert False`, which is removed under `python -O`.
    raise ValueError("model_type must be 'bert' or 'roberta', got %r" % (model_type,))
model_name = _MODEL_NAMES[model_type]
print("model name: ", model_name )
model_save_path = out_params_folder+ "/%s/"%(folder_name)

# Create the output folder (including missing parents) and empty it, without
# going through a shell: folder_name comes from the command line, so building
# `mkdir` / `rm -r` command strings from it is fragile and unsafe.
os.makedirs(model_save_path, exist_ok=True)
for entry_name in os.listdir(model_save_path):
    if entry_name.startswith("."):
        # `rm -r <dir>/*` never matched hidden entries; keep leaving them alone.
        continue
    entry_path = os.path.join(model_save_path, entry_name)
    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
        shutil.rmtree(entry_path)
    else:
        os.remove(entry_path)

# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name)
# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,   # default setting
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
# Optionally replace the pretrained transformer weights with those from ckpt_fp.
if init_model:
    model=reinit(model, ckpt_fp)



###====================
### PART 2: load the dataset, train/eval/test
all_train_samples = []
all_dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'].lower(), row['sentence2'].lower()], label=score)

        # Rows whose split is neither 'dev' nor 'test' are used for training.
        if row['split'] == 'dev':
            all_dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)
        else:
            all_train_samples.append(inp_example)

## If you train the model in the low-resource case, set a file path that records the data used,
## so that the same part of the data is used by the different models.
out_fp = "XXXX"
## Run this once to create that file:
#generate_random_ids(data_exps, out_fp, all_train_samples, all_dev_samples )

if train_data_ratio < 1:
    # Read the recorded sample orderings; the context manager closes the file.
    with open(out_fp, "r") as f:
        exp_ids = json.load(f)

    exp_i = str(exp_i)  # JSON object keys are strings
    train_size=int(len(all_train_samples)*train_data_ratio)
    dev_size=int(len(all_dev_samples)*train_data_ratio)
    train_ids = exp_ids[exp_i]["train"]
    dev_ids = exp_ids[exp_i]["dev"]
    # Keep the first train_size / dev_size entries of the recorded ordering.
    train_samples = [all_train_samples[i] for i in train_ids[:train_size] ]
    dev_samples = [all_dev_samples[i] for i in dev_ids[:dev_size] ]
    print("use model exp ", exp_i)
else:
    train_samples = all_train_samples[:]
    dev_samples = all_dev_samples[:]

print("[train data size %d][dev size %d][test size %d]"%(len(train_samples), len(dev_samples), len(test_samples)))








###====================
### PART 3: train the model and obtain the result
train_dataloader = DataLoader(train_samples, batch_size=train_batch_size, shuffle=True)
train_loss = losses.CosineSimilarityLoss(model=model)

logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Warm-up covers 10% of all training steps (batches per epoch * epochs), rounded up.
total_training_steps = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_training_steps * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model. The dev evaluator is handed to fit() together with
# evaluation_steps=1000. Scheduler and learning rate are not passed, so fit()
# uses its defaults (noted by the original author as WarmupLinear, lr 2e-5).
fit_arguments = dict(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)
model.fit(**fit_arguments)

##############################################################################
#
# Reload the model stored in model_save_path and evaluate it on the
# STS benchmark test split; the evaluator receives model_save_path as its
# output path.
#
##############################################################################

model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    test_samples,
    name='sts-test',
)
test_evaluator(model, output_path=model_save_path)

# Total wall-clock time in seconds since `start` was recorded.
elapsed_seconds = time.time()-start
print("it runs: ", elapsed_seconds)