import argparse
import json
import os
from logging import getLogger

from utils.task_config import TaskConfig

# Module-level logger named after this module; parse_arg() logs the final config through it.
logger = getLogger(__name__)

# Lookup table from a label-category name to the ordered list of label names
# in that category. parse_arg() indexes it with the --category / --category_2
# values and uses the length of the selected list as the class size.
# The label strings are runtime data and are kept exactly as spelled.
label_type = dict(
  sentiment=["sentiment"],
  emotion=[
    "sentiment",
    "happiness",
    "sadness",
    "anger",
    "fear",
    "disgust",
    "surprise",
  ],
  pure_emotion=[
    "happiness",
    "sadness",
    "anger",
    "fear",
    "disgust",
    "surprise",
    "neutral",
  ],
  emotion_class=["sentiment", "emotion"],
  meld_emotion=[
    "anger",
    "disgust",
    "fear",
    "joy",
    "neutral",
    "sadness",
    "surprise",
  ],
  expression=[
    "confident",
    "passionate",
    "voice_pleasant",
    "dominant",
    "credible",
    "vivid",
    "expertise",
    "entertaining",
    "reserved",
    "trusting",
    "relaxed",
    "outgoing",
    "thorough",
    "nervous",
    "humerous",
    "persuasion",
  ],
  sentiment_binary=["positive", "negative"],
)

def _build_parser() -> argparse.ArgumentParser:
  """Create the argument parser holding every command-line option of the task."""
  parser = argparse.ArgumentParser()

  # source configuration
  parser.add_argument("--src_dir", "-s", type=str, default="data", help="The location for input directory prefix")
  parser.add_argument("--src_dir_2", "-2", type=str, default="data", help="The location for input directory prefix, 2nd data line")
  parser.add_argument("--src_lang", "-l", type=str, default="lang", help="The location for language input directory")
  parser.add_argument("--src_audio", "-a", type=str, default="audio.pkl", help="The location for audio input file")
  parser.add_argument("--config", "-p", type=str, help="config path if exists")
  parser.add_argument("--src_audio_raw", "-A", type=str, default="audio_raw.pkl", help="The location for raw audio input file")
  parser.add_argument("--src_video", "-v", type=str, default="video.pkl", help="The location for video input file")
  parser.add_argument("--src_video_raw", "-V", type=str, default="video_raw.pkl", help="The location for raw video input file")
  parser.add_argument("--src_sema", "-S", type=str, default="sema.pkl", help="The location for semantics data (e.g. LMMS) input file")
  # Restricting the accepted values makes a mistyped category fail with a
  # usage message instead of a KeyError when label_type is indexed later.
  parser.add_argument("--category", "-c", type=str, default="sentiment", choices=list(label_type),
                      help="label type for training. sentiment, emotion, or expression")
  parser.add_argument("--category_2", "-C", type=str, default="none", choices=["none", *label_type],
                      help="label type for training. sentiment, emotion, or expression")
  parser.add_argument("--loss_type", "-L", type=str, default="regression", help="specify loss function type. regression, classify or hybrid")

  # task type
  parser.add_argument("--param_confirm", action="store_true", help="set if only parameter number confirmation")
  parser.add_argument("--online_test", action="store_true", help="Running test while training to pick best models")
  parser.add_argument("--stmt", action="store_true", help="Use STMT model")
  parser.add_argument("--no_align", action='store_true', help="Use STMT V2(no-aligned) model")
  parser.add_argument("--cw_msu", action="store_true", help="Use MSU layer")

  # dest configuration
  parser.add_argument("--dest", "-d", default='out', type=str, help="The output directory for the task")
  parser.add_argument("--no_model_save", action='store_true', help="Indicates not to save best state_dict")

  # device configuration
  parser.add_argument("--gpu_num", default=1, type=int, help="Indicates count of gpu for this task")
  parser.add_argument("--gpu_id", default=0, type=int, help="Indicates id of gpu for this task")
  parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")

  # training configs
  # The historical (misspelled) flag and its attribute name are kept so that
  # existing scripts and consumers of `ramdom_seed` continue to work;
  # `--random_seed` is accepted as an alias writing to the same attribute.
  parser.add_argument("--ramdom_seed", "--random_seed", dest="ramdom_seed", type=int,
                      help="fixing random seed number. i.e. 42")
  parser.add_argument("--learning_rate", default=2e-4, type=float, help="The initial learning rate for optimizer")
  parser.add_argument("--learning_rate_2", default=None, type=float, help="The initial learning rate for optimizer (for 2nd data stream)")
  # NOTE(review): argparse does not apply `type` to non-string defaults, so the
  # default stays the int 50 while user-supplied values become floats --
  # confirm downstream code handles both.
  parser.add_argument("--num_train_epochs", "-e", default=50, type=float, help="Total number of training epochs to perform.")
  parser.add_argument("--max_stalled_attempt", "-E", default=50, type=int, help="Early stopping trigger sequential attempt.")
  parser.add_argument("--attempt", "-t", type=int, default=1, help="Indicates trial attempt")
  parser.add_argument("--warmup_proportion", default=0.1, type=float,
                      help="Proportion of training to perform linear learning rate warmup for. e.g. 0.1 = 10%% of training.")
  parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                      help="Number of updates steps to accumulate before performing a backward/update pass.")

  # tuning related configs
  parser.add_argument("--use_bce", action='store_true', help="Indicate to use binary cross entropy loss for vae loss function")
  parser.add_argument("--no_bert_tuning", action='store_true', help="Indicates to do training without bert fine-tuning")
  parser.add_argument("--less_bert_tuning", action='store_true', help="Indicates to do training with 'lesser' bert fine-tuning")
  parser.add_argument("--bert_freeze_layers", "-F", type=int, default=0, help="layer count to freeze bert layer")
  parser.add_argument("--no_topic", action='store_true', help="Indicates to do training without topic attention")
  parser.add_argument("--no_word_topic", action='store_true', help="Indicates to do training without word topic attention")
  parser.add_argument("--later_topic", action='store_true', help="Indicates to apply topic vector after bert encoder")
  parser.add_argument("--no_sema", action='store_true', help="Indicates to do training without semantic attention")
  parser.add_argument("--priority", type=int, default=-1, help="(Only for multidata) Indicates which loss should be prioritized. Default(-1) indicates sum of both")
  parser.add_argument("--no_video", action='store_true', help="Indicates to do training without video")
  parser.add_argument("--no_audio", action='store_true', help="Indicates to do training without audio")
  parser.add_argument("--a_conv_v", action='store_true', help="vertical convolution for audio unaligned data")
  parser.add_argument("--v_conv_v", action='store_true', help="vertical convolution for video unaligned data")

  # batch size configs
  parser.add_argument('--train_batch', default=48, type=int, help="The batch size for train task")
  parser.add_argument('--eval_batch', default=24, type=int, help="The batch size for eval task")
  parser.add_argument('--test_batch', default=24, type=int, help="The batch size for test task")

  # model config
  parser.add_argument('--bert_model', default="bert-base-uncased", type=str,
                      help="The name or path of pretrained BERT model used by the task")
  parser.add_argument('--pretrained_path', default="", type=str, help="pretrained model path for fine tuning, if exists")
  parser.add_argument('--w2v2_model', default="facebook/wav2vec2-base-960h", type=str,
                      help="The name or path of pretrained Wav2Vec2.0 model used by the task")
  parser.add_argument('--case_sensitive', action="store_true", help="Indicates using case-sensitive dataset or not")
  parser.add_argument('--topic_dim', default=40, type=int, help="The dimension size for topic vector")
  parser.add_argument('--sema_dim', default=1024, type=int, help="The dimension size for semantic info vector")
  parser.add_argument('--lang_dim', default=768, type=int, help="The language modal feature vector dimension size")
  parser.add_argument('--audio_dim', default=5, type=int, help="The audio modal feature vector dimension size")
  parser.add_argument('--video_dim', default=35, type=int, help="The video modal feature vector dimension size")
  parser.add_argument('--lang_head', default=12, type=int, help="The language modal attention head size")
  parser.add_argument('--audio_head', default=1, type=int, help="The audio modal attention head size")
  parser.add_argument('--video_head', default=5, type=int, help="The video modal attention head size")
  parser.add_argument("--seq_limit", default=50, type=int,
                      help="The maximum total input sequence length after WordPiece tokenization. \n"
                           "Sequences longer than this will be truncated, and sequences shorter \n"
                           "than this will be padded.")
  parser.add_argument("--audio_raw_seq_limit", default=5000, type=int)
  parser.add_argument("--video_raw_seq_limit", default=1250, type=int)
  parser.add_argument("--audio_seq_limit", default=2, type=int)
  parser.add_argument("--sampling_rate", default=16000, type=int)

  return parser


def parse_arg(argv: "list[str] | None" = None) -> "TaskConfig":
  """Parse command-line options and build the task configuration.

  The parsed namespace and, when ``--config`` is given, the decoded JSON
  content of that file are handed to ``TaskConfig``. Derived attributes are
  then set on the resulting object: the audio sequence limit multiplied by the
  sampling rate, the input paths under ``--src_dir`` and ``--src_dir_2``, the
  category flags, the label lists and their sizes. The configuration is
  validated with ``assert_args`` and logged before being returned.

  Args:
    argv: Argument strings to parse. ``None`` (the default) makes argparse
      read ``sys.argv[1:]``, which is the previous behaviour.

  Returns:
    The populated ``TaskConfig`` instance.

  Raises:
    SystemExit: raised by argparse on invalid options (including an unknown
      category) or when help is requested.
    ValueError: propagated from ``assert_args``.
  """
  args = _build_parser().parse_args(argv)

  config = None
  if args.config:
    # JSON interchange is UTF-8; do not depend on the platform default encoding.
    with open(args.config, 'r', encoding='utf-8') as cfg:
      config = json.load(cfg)

  task_config = TaskConfig(args, config)

  # --audio_seq_limit is multiplied by the sampling rate before being stored.
  task_config.audio_seq_limit = args.audio_seq_limit * args.sampling_rate

  # Input locations for the primary data line.
  task_config.src_lang = f"{args.src_dir}/{args.src_lang}"
  task_config.src_audio = f"{args.src_dir}/{args.src_audio}"
  task_config.src_audio_raw = f"{args.src_dir}/{args.src_audio_raw}"
  task_config.src_video = f"{args.src_dir}/{args.src_video}"
  task_config.src_video_raw = f"{args.src_dir}/{args.src_video_raw}"
  task_config.src_sema = f"{args.src_dir}/{args.src_sema}"

  # Same file names, rooted at the second data line's directory.
  task_config.src_lang_2nd = f"{args.src_dir_2}/{args.src_lang}"
  task_config.src_audio_2nd = f"{args.src_dir_2}/{args.src_audio}"
  task_config.src_audio_raw_2nd = f"{args.src_dir_2}/{args.src_audio_raw}"
  task_config.src_video_2nd = f"{args.src_dir_2}/{args.src_video}"
  task_config.src_video_raw_2nd = f"{args.src_dir_2}/{args.src_video_raw}"
  task_config.src_sema_2nd = f"{args.src_dir_2}/{args.src_sema}"

  task_config.emo_cls = args.category == 'emotion_class'
  task_config.sent_bin = args.category == 'sentiment_binary'

  # 'none' for the second category means "reuse the primary category".
  second_category = args.category if args.category_2 == 'none' else args.category_2
  task_config.category_2 = label_type[second_category]
  task_config.category = label_type[args.category]
  task_config.class_size = len(task_config.category)
  task_config.class_size_2 = len(task_config.category_2)

  assert_args(task_config)

  logger.info("current config:")
  logger.info(task_config)

  return task_config


def assert_args(task_config: TaskConfig):
  """Check the task configuration for invalid values.

  Only ``gradient_accumulation_steps`` is checked at the moment.

  Args:
    task_config: configuration object exposing ``gradient_accumulation_steps``.

  Raises:
    ValueError: if ``gradient_accumulation_steps`` is smaller than 1.
  """
  accumulation_steps = task_config.gradient_accumulation_steps
  if accumulation_steps < 1:
    raise ValueError(
      f"Invalid gradient_accumulation_steps parameter: {accumulation_steps}, should be >= 1"
    )
  # NOTE: checks that src_audio, src_audio_raw, src_video, src_video_raw and
  # src_sema exist and are not directories (each raising ValueError
  # "parameter '<name>' invalid. must exists, and be a pickle file") used to
  # live here as commented-out code; they are currently disabled.
