import os
import socket
import sys


from coli.torch_hrg.hrg_parser_greedy import UdefQParserGreedy


from coli.basic_tools.common_utils import ensure_dir
from coli.hrgguru.extra_labels import EPVariablesLabeler
from coli.hrgguru.extract_sync_grammar import ExtractionParams, extract_dataset, datasets
from coli.hrgguru.hrg import HRGDerivation
from coli.hrgguru.strip_utils import STRIP_ALL_LABELS
from coli.parser_tools.training_scheduler import TrainingScheduler
from coli.torch_hrg.internal_tagger import InternalTagger
from coli.torch_hrg.pred_tagger import PredTagger
from coli.torch_span.parser import SpanParser
from coli.torch_hrg.hrg_parser_graph_embedding import GraphEmbeddingUdefQParser


# Parser class used by the semantic interpretation stages. Pick exactly one:
#   * UdefQParserGreedy         -- the count-based model (active)
#   * GraphEmbeddingUdefQParser -- the subgraph-based model
parser_class = UdefQParserGreedy
# parser_class = GraphEmbeddingUdefQParser


# Home directory of the current user; the output directory is built from it.
real_home = os.path.expanduser("~")
home = real_home

# NOTE: YOUR_PATH_HERE is a placeholder that is not defined in this script;
# replace every occurrence with a real path before running.
# These two paths are described in https://github.com/draplater/hrg-parser
main_dir_base = YOUR_PATH_HERE
deepbank_export_dir = YOUR_PATH_HERE
# A flat copy of deepbank_export_dir.
deepbank_export_flat = YOUR_PATH_HERE

# A folder containing lm_weights.hdf5, options.json and vocab.txt.
elmo_path = YOUR_PATH_HERE


# One project name is shared by every stage; the span parser, the EP tagger
# and the internal tagger models are all named after it.
project_name = "deepbank1.1-lfrg-submission"
span_name = tagger_name = internal_tagger_name = project_name

# Index of the first stage to run. Each stage below is guarded by
# `start_stage <= <stage index>`, so raise this value to resume the pipeline
# after a stage that broke.
start_stage = 0

# Root of the output tree; every project gets its own sub-directory.
outdir_prefix = f"{home}/work/hrg_chain/"
project_dir = f"{outdir_prefix}{project_name}"

# Per-stage output directories (each one ends with a slash).
grammar_dir = f"{project_dir}/grammar/"
span_dir = f"{project_dir}/span/"
tagger_dir = f"{project_dir}/tagger/"


# Stage 0: grammar extraction.
# Runs extract_dataset() on the "wsj" dataset and directs its output to
# grammar_dir. The later stages read their training files, grammar and
# derivations from that directory (presumably produced here -- confirm
# against extract_dataset).
if start_stage <= 0:
    ensure_dir(grammar_dir)

    params = ExtractionParams(
        strip_tree=STRIP_ALL_LABELS,
        detect_func=HRGDerivation.detect_lfrg,
        graph_type="lfrg",
        punct_hyphen_fixer="none",
        extra_labels=EPVariablesLabeler,
        extra={"include_qeq": True},
    )
    wsj_dataset = datasets["wsj"]

    extract_dataset(
        project_name,
        wsj_dataset,
        params,
        java_out_dir=main_dir_base,
        deepbank_export_path=deepbank_export_dir,
        output_dir=grammar_dir,
    )

# Stage 1: phrase structure parsing.
# Configures a SpanParser on the extracted "fulllabel" train/dev files and
# hands the options to a TrainingScheduler, which is started with
# run_parallel().
if start_stage <= 1:
    op = SpanParser.Options()
    hparams = op.hparams
    # Common prefix of the train/dev files inside grammar_dir.
    fulllabel_prefix = f"{grammar_dir}/{project_name}.fulllabel"

    op.debug_cache = True
    hparams.evaluate_every = 300
    hparams.train_batch_size = 2000
    hparams.max_sentence_batch_size = 250
    hparams.predict_postags = True
    hparams.d_label_hidden = 768
    hparams.pretrained_contextual.elmo_options.path = elmo_path
    op.train = f"{fulllabel_prefix}.train"
    op.dev = [f"{fulllabel_prefix}.dev"]
    op.gpu = True
    op.use_rules = True
    op.restrict_root_rule = True
    op.use_exception_handler = True
    hparams.bucket_type = "length_group"

    trainer = TrainingScheduler(SpanParser)
    trainer.add_options(span_name, op, span_dir)
    trainer.run_parallel()


# Stage 2: EP tagger.
# Configures a PredTagger that uses the grammar mapping from grammar_dir and
# the span model saved under span_dir, then starts it through a
# TrainingScheduler with run_parallel().
if start_stage <= 2:
    scheduler = TrainingScheduler(PredTagger)

    op = PredTagger.Options()
    hparams = op.hparams
    # Directory of the span parser model configured in the previous stage.
    span_model_dir = f"{span_dir}/model-{span_name}"

    op.grammar = f"{grammar_dir}/cfg_hrg_mapping-{project_name}.pickle"
    op.debug_cache = True
    op.use_exception_handler = True
    op.bilm_path = elmo_path
    op.train = f"{grammar_dir}/{project_name}.fulllabel.train"
    # The dev file comes from the span model directory, not from grammar_dir
    # (presumably the span parser's best-epoch dev output -- confirm).
    op.dev = [f"{span_model_dir}/{project_name}.fulllabel_epoch_best.dev"]
    op.span_model = f"{span_model_dir}/model"
    hparams.evaluate_every = 300

    # ELMo as the pretrained contextual encoder.
    pretrained = hparams.pretrained_contextual
    pretrained.type = "elmo"
    pretrained.elmo_options.path = elmo_path

    hparams.max_sentence_batch_size = 128

    # Attachment encoder: 2-layer LSTM with hidden size 512.
    attachment = hparams.attachment_contextual
    attachment.type = "lstm"
    attachment.lstm_options.num_layers = 2
    attachment.lstm_options.hidden_size = 512
    attachment.lstm_options.recurrent_keep_prob = 1
    hparams.attachment_dims_hidden = [512]

    # Main encoder: same LSTM settings as the attachment encoder.
    contextual = hparams.contextual
    contextual.type = "lstm"
    contextual.lstm_options.num_layers = 2
    contextual.lstm_options.hidden_size = 512
    contextual.lstm_options.recurrent_keep_prob = 1
    hparams.dims_hidden = [512]
    hparams.use_crf = False
    op.gpu = True
    hparams.train_iters = 40000
    hparams.stop_grad = True

    scheduler.add_options(tagger_name, op, tagger_dir)
    scheduler.run_parallel()


# Stage 3: semantic interpretation.
# Trains `parser_class` on top of the span model and the EP tagger model.
#
# hrg_dir and hrg_name are set unconditionally (outside the stage guard)
# because the test-set stage further down reads both of them as well.
hrg_dir = outdir_prefix + project_name + "/hrg/"
# Model name for this stage, named after the project like span_name and
# tagger_name. Without this definition the script fails with a NameError.
hrg_name = project_name
if start_stage <= 3:
    dev_file = f"{span_dir}/model-{span_name}/{project_name}.fulllabel_epoch_best.dev"
    # Fail early if the span parser's dev file is missing. An explicit raise
    # is used instead of `assert` so the check is not stripped by `python -O`.
    if not os.path.exists(dev_file):
        raise FileNotFoundError(f"{dev_file} not exist")

    op = parser_class.Options.get_default()
    op.span_model = f"{span_dir}/model-{span_name}/model"
    op.pred_tagger = f"{tagger_dir}/model-{tagger_name}/model"
    op.graph_type = "lf"
    op.bilm_path = elmo_path
    op.deepbank_dir = deepbank_export_flat
    op.debug_cache = True
    op.hparams.disable_span_dropout = True
    op.gpu = True
    op.hparams.stop_grad = True
    op.derivations = f"{grammar_dir}/derivations-{project_name}.pickle"
    op.grammar = f"{grammar_dir}/cfg_hrg_mapping-{project_name}.pickle"
    op.train = f"{grammar_dir}/{project_name}.fulllabel.train"
    # Same file that was checked above.
    op.dev = [dev_file]
    op.hparams.print_every = 2
    op.hparams.evaluate_every = 600
    op.hparams.max_sentence_batch_size = 128
    op.hparams.train_batch_size = 200
    op.use_exception_handler = True

    if issubclass(parser_class, GraphEmbeddingUdefQParser):
        # Options that only apply to the subgraph-based model.
        op.hparams.graph_embedding.hrg_batch_size = 128
        op.hparams.loss_type = "cross_entropy"
        op.beam_size = 1
        op.hparams.graph_embedding.graph_embedding_type = "lf"
        op.hparams.graph_embedding.graph_encoder.use_attention = False
    else:
        # The count-based model uses the "count" scorer.
        op.hparams.scorer.type = "count"

    trainer = TrainingScheduler(parser_class)

    trainer.add_options(hrg_name, op, hrg_dir)

    trainer.run_parallel()


# Stage 4: phrase structure parsing (test set).
# Schedules the span model in "predict" mode on the extracted test file; the
# output is written next to the model.
if start_stage <= 4:
    scheduler = TrainingScheduler(SpanParser)
    # Directory holding the span parser model and its outputs.
    span_model_dir = f"{span_dir}/model-{span_name}"
    op = {
        "model": f"{span_model_dir}/model",
        "test": f"{grammar_dir}/{project_name}.fulllabel.test",
        "output": f"{span_model_dir}/{project_name}.fulllabel_epoch_best.test",
        "bilm-path": elmo_path,
        "eval": "True",
        "gpu": "True",
        # Optional switch, left disabled:
        # "use-exception-handler": "True",
    }
    scheduler.add_options(",", op, "", mode="predict")
    scheduler.run()


# Stage 5: semantic interpretation (test set).
# Schedules `parser_class` in "predict" mode, taking the span parser's test
# output from the previous stage as its input.
if start_stage <= 5:
    scheduler = TrainingScheduler(parser_class)

    # Model directories of the three models involved.
    span_model_dir = f"{span_dir}/model-{span_name}"
    tagger_model_dir = f"{tagger_dir}/model-{tagger_name}"
    hrg_model_dir = f"{hrg_dir}/model-{hrg_name}"

    op = {
        "span-model": f"{span_model_dir}/model",
        "pred-tagger": f"{tagger_model_dir}/model",
        "model": f"{hrg_model_dir}/model",
        "test": f"{span_model_dir}/{project_name}.fulllabel_epoch_best.test",
        "output": f"{hrg_model_dir}/test",
        "beam-size": 1,
        "deepbank-dir": deepbank_export_flat,
        "bilm-path": elmo_path,
        "eval": "True",
        "gpu": "True",
    }

    scheduler.add_options(",", op, "", mode="predict")
    scheduler.run()
