import data_provider
import os
import config as c
import helper
from transformers import BertTokenizer
from transformers import RobertaTokenizer
from transformers import XLMRobertaTokenizer

##### PARSING / MLM ######
if not c.mlm:
    # Dependency-parsing mode: reuse the relation dictionary serialized at
    # c.base_path/c.deps_dict_path (this script writes it back to the same
    # path after featurization). When no such file exists, None is passed on
    # to the treebank loader, which presumably builds a fresh dictionary.
    deps_dict_file = os.path.join(c.base_path, c.deps_dict_path)
    deps_dict = (
        helper.deserialize(deps_dict_file)
        if os.path.exists(deps_dict_file)
        else None
    )

# Select the tokenizer class from the configured checkpoint name.
# The three prefixes are mutually exclusive, so an if/elif chain is used.
# An unrecognised name fails fast here with a clear message, instead of
# surfacing later as a NameError the first time `tokenizer` is used.
if c.pretrained_transformer.startswith("bert-"):
    tokenizer = BertTokenizer.from_pretrained(c.pretrained_transformer)
elif c.pretrained_transformer.startswith("roberta-"):
    tokenizer = RobertaTokenizer.from_pretrained(c.pretrained_transformer)
elif c.pretrained_transformer.startswith("xlm-roberta"):
    tokenizer = XLMRobertaTokenizer.from_pretrained(c.pretrained_transformer)
else:
    raise ValueError(
        f"Unsupported pretrained_transformer {c.pretrained_transformer!r}: "
        "expected a name starting with 'bert-', 'roberta-' or 'xlm-roberta'"
    )

# Load the UD treebank. In MLM mode no pre-existing relation dictionary is
# passed in. Otherwise the one loaded above (or None) is passed, and the
# loader returns the dictionary actually used.
sentences, arc_labels, rel_labels, deps_dict = data_provider.load_ud_treebank(
    os.path.join(c.base_path, c.in_file),
    None if c.mlm else deps_dict,
    c.max_word_len,
)
print(f"Num. sentences: {len(sentences)}")
print(f"Num. dependency relations: {len(deps_dict)}")

out_path = os.path.join(c.base_path, c.out_file)
if c.mlm:
    # MLM pretraining data: re-join each sentence's words into a single string.
    # When c.is_zh is set (presumably Chinese), words are joined without a
    # separator; otherwise they are separated by a single space.
    separator = "" if c.is_zh else " "
    sentences = [separator.join(s) for s in sentences]
    data_provider.featurize_serialize_mlm(sentences, out_path, tokenizer)
else:
    # Dependency parsing: featurize with arc/relation labels. The is_roberta
    # flag is true for both "roberta-" and "xlm-roberta" checkpoints.
    data_provider.featurize_serialize_ud(
        out_path,
        sentences,
        arc_labels,
        rel_labels,
        tokenizer,
        is_roberta="roberta" in c.pretrained_transformer,
    )
    # Persist the (possibly extended) relation dictionary for later runs.
    helper.serialize(deps_dict, os.path.join(c.base_path, c.deps_dict_path))

###### PAWS-X ######
# if c.pretrained_transformer.startswith("bert-"):
#     tokenizer = BertTokenizer.from_pretrained(c.pretrained_transformer)
# if c.pretrained_transformer.startswith("roberta-"):
#     tokenizer = RobertaTokenizer.from_pretrained(c.pretrained_transformer)
# if c.pretrained_transformer.startswith("xlm-roberta"):
#     tokenizer = XLMRobertaTokenizer.from_pretrained(c.pretrained_transformer)

# examples = data_provider.load_paws(os.path.join(c.base_path, c.in_file))
# print("Num. sentences: " + str(len(examples)))

# data_provider.feature_serialize_paws(examples, os.path.join(c.base_path, c.out_file), tokenizer)
