from flair.embeddings import TransformerWordEmbeddings, StackedEmbeddings, CharacterEmbeddings, MuseCrosslingualEmbeddings, WordEmbeddings, BytePairEmbeddings
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from hyperopt import hp
from flair.hyperparameter.param_selection import SearchSpace, Parameter
from flair.hyperparameter.param_selection import SequenceTaggerParamSelector, OptimizationValue
import argparse
from pathlib import Path



# Command-line interface. Both paths are mandatory: the data folder is handed
# to the corpus loader and the output folder to the trainer further down.
# Marking them as required makes argparse stop with a usage message instead of
# letting a missing value travel on as None.
parser = argparse.ArgumentParser()
parser.add_argument('--data_folder', type=str, required=True,
                    help='Path to folder with training data')
parser.add_argument('--out', type=str, required=True,
                    help='Path to folder for out')

args = parser.parse_args()

# Column layout of the data files: column 0 is the token text,
# column 1 is the 'borrowing' tag.
columns = {
    0: "text",
    1: "borrowing",
}

print("Creating corpus...")
# Build the corpus in column format from the data folder given on the command
# line, naming the train, dev and test files explicitly.
corpus: Corpus = ColumnCorpus(
    args.data_folder,
    columns,
    train_file="training.conll",
    dev_file="dev.conll",
    test_file="test_background.conll",
)
print("Corpus created...")



# Embeddings fed to the tagger: a stack of character embeddings, two cased BERT
# models (a Spanish one and the English 'bert-base-cased') and byte-pair
# embeddings for English and Spanish.
#
# NOTE: an earlier revision of this script appears to have marked this
# combination as the most successful one. Alternatives that were previously
# kept here as disabled code:
#   - BytePairEmbeddings('es') or BytePairEmbeddings('multi') on their own
#   - CharacterEmbeddings() on its own
#   - CharacterEmbeddings() stacked with only the two BERT models
#   - MuseCrosslingualEmbeddings(), WordEmbeddings('es'),
#     OneHotEmbeddings(corpus=corpus)
#   - FlairEmbeddings('es-X'), FlairEmbeddings('multi-X')
#   - TransformerWordEmbeddings with 'bert-base-multilingual-cased',
#     'xlm-mlm-xnli15-1024' or 'sagorsarker/codeswitch-spaeng-lid-lince'
embeddings = StackedEmbeddings(
    [
        CharacterEmbeddings(),
        TransformerWordEmbeddings('dccuchile/bert-base-spanish-wwm-cased'),
        TransformerWordEmbeddings('bert-base-cased'),
        BytePairEmbeddings('en'),
        BytePairEmbeddings('es'),
    ]
)
print("Embeddings created...")

# Name of the tag layer the model is trained to predict.
tag_type = "borrowing"

# Build the tag dictionary for that layer from the corpus.
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print("Tagset dict created...")

# Sequence tagger over the embeddings defined earlier in this script,
# with a hidden size of 256 and the CRF option switched on.
tagger: SequenceTagger = SequenceTagger(
    embeddings=embeddings,
    hidden_size=256,
    tag_type=tag_type,
    tag_dictionary=tag_dictionary,
    use_crf=True,
)
print("Tagger created...")




# Pair the tagger with the corpus it is trained on.
trainer: ModelTrainer = ModelTrainer(tagger, corpus)
print("Trainer created...")

# Hyperparameters for the training run, kept together for easy tweaking.
training_options = {
    "learning_rate": 0.1,
    "mini_batch_size": 32,
    "max_epochs": 150,
}

# The output folder from the command line is passed as the first argument;
# best-model.pt is later loaded from that same folder.
trainer.train(args.out, **training_options)
print("Done training...")


# Reload the model stored as best-model.pt in the output folder.
output_dir = Path(args.out)
tagger: SequenceTagger = SequenceTagger.load(output_dir / 'best-model.pt')
# NOTE(review): this swaps the embeddings restored from the checkpoint for the
# in-memory object used during training. If any of those embeddings were
# updated while training, their current weights may differ from the ones the
# best checkpoint was saved with -- confirm this override is still needed.
tagger.embeddings = embeddings

# Run the evaluation procedure on the dev split and then on the test split,
# passing final_<split>.tsv in the output folder as out_path. The detailed
# results of BOTH splits are printed; previously the dev result was
# overwritten by the test result before anything was shown. The loop ends on
# the test split, so `result` and `score` still hold the test outcome.
for split_name, split_data in (("dev", corpus.dev), ("test", corpus.test)):
    result, score = tagger.evaluate(
        split_data,
        mini_batch_size=32,
        out_path=output_dir / f"final_{split_name}.tsv",
    )
    print(f"Results on {split_name} set:")
    print(result.detailed_results)

