from math import log10
from sympy import Eq
from itertools import groupby

# Key indices for the text field (consecutive, starting at 0).
TX_KEY_TOK, TX_KEY_NUM = range(2)

# Key indices for the equation field (consecutive, starting at 0).
EQ_KEY_TPL, EQ_KEY_VAR, EQ_KEY_NUM = range(3)

# Padding id; -1 lies outside every 0-based index defined here.
PAD_ID = -1

# Key indices for preprocessing and field input (consecutive, starting at 0).
PREP_KEY_EQN, PREP_KEY_ANS, PREP_KEY_MEM = range(3)

# Number token used in the text field.
NUM_TOKEN = '[N]'

# Equation-field tokens. Each *_PATTERN is a %-format that zero-pads an index
# to the decimal digit count of the matching *_MAX (EQ_VAR_MAX = 4 -> 'X_%01d',
# EQ_NUM_MAX = 32 -> 'N_%02d').

# Variables.
EQ_VAR_MAX = 4
EQ_VAR_TOKEN = EQ_VAR_PREFIX = 'X_'
EQ_VAR_PATTERN = f'{EQ_VAR_PREFIX}%0{int(log10(EQ_VAR_MAX)) + 1}d'
EQ_VAR_UNK = 'X_?'

# Numbers.
EQ_NUM_MAX = 32
EQ_NUM_TOKEN = EQ_NUM_PREFIX = 'N_'
EQ_NUM_PATTERN = f'{EQ_NUM_PREFIX}%0{int(log10(EQ_NUM_MAX)) + 1}d'
EQ_NUM_UNK = 'N_?'

# Constants and operators (the token differs from the prefix here).
EQ_CON_TOKEN, EQ_CON_PREFIX, EQ_CON_UNK = 'C', 'C_', 'C_?'
EQ_OPR_TOKEN, EQ_OPR_PREFIX, EQ_OPR_UNK = 'O', 'O_', 'O_?'

# Every "unknown" marker of the equation field.
EQ_UNK = {EQ_VAR_UNK, EQ_NUM_UNK, EQ_CON_UNK, EQ_OPR_UNK}

# Named index of the first item in a batch (readability only).
FIRST_BATCH_ITEM = 0

# String key names, grouped by prefix: IN_* (text, equation, mask),
# OUT_* (output fields and scores), REPRODUCE_* and TGT_*.

# IN_* keys for the text field.
IN_TXT, IN_TPAD, IN_TNUM = 'text', 'text_pad', 'text_num'
IN_TNPAD, IN_TNIF = 'text_numpad', 'text_numinfo'

# IN_* keys for the equation field.
IN_EQN, IN_ETID = 'equation', 'equation_tok'
IN_ENUM, IN_EVAR = 'equation_num', 'equation_var'

IN_MASK = 'self_mask'

# OUT_* keys for output fields.
OUT_GEN, OUT_NUM, OUT_VAR = 'gen', 'num', 'var'

# OUT_*_SCORE keys.
OUT_CLS_SCORE, OUT_GEN_SCORE, OUT_TPL_SCORE = 'cls_score', 'gen_score', 'tpl_score'
OUT_NUM_SCORE, OUT_VAR_SCORE = 'num_score', 'var_score'
OUT_CHK_SCORE, OUT_NXT_SCORE = 'chk_score', 'nxt_score'

# REPRODUCE_* keys.
REPRODUCE_GEN, REPRODUCE_NUM, REPRODUCE_VAR = 'reprod_gen', 'reprod_num', 'reprod_var'

# TGT_* keys.
TGT_GEN, TGT_CLS, TGT_TPL = 'target_gen', 'target_cls', 'target_tpl'
TGT_VAR, TGT_NUM, TGT_CHK = 'target_var', 'target_num', 'target_chk'
TGT_NUM_CMP = 'target_num_cmp'

# Output keys; index i in this list corresponds to index i in
# OUTPUT_READING_NAMES and OUTPUT_FIELD_INDICATORS (parallel lists).
OUTPUT_KEYS = [
    'equation_token',
    'equation_variable',
    'equation_number'
]

# Human-readable names and field indicators for the output fields
# (parallel to OUTPUT_KEYS).
OUTPUT_READING_NAMES = ['Token', 'Variable', 'Number']
OUTPUT_FIELD_INDICATORS = [None, 'var', 'num']

# The lists above are combined by position. zip() would silently drop trailing
# entries if they ever diverged in length, so fail loudly at import time instead.
if not (len(OUTPUT_KEYS) == len(OUTPUT_READING_NAMES) == len(OUTPUT_FIELD_INDICATORS)):
    raise ValueError('OUTPUT_KEYS, OUTPUT_READING_NAMES and OUTPUT_FIELD_INDICATORS '
                     'must have the same length')

# Map each output index (0, 1, ...) to its display name, output key and field
# indicator. enumerate() replaces the former hard-coded range(3), so new output
# fields are picked up automatically.
OUTPUT_INFO = {
    k: {'name': name, 'output': output, 'token': field}
    for k, (name, output, field) in enumerate(zip(OUTPUT_READING_NAMES, OUTPUT_KEYS, OUTPUT_FIELD_INDICATORS))
}

# Operator table. Each spec records arity, whether operands commute, whether
# the operator is top-level, and a 'convert' callable that combines its first
# two positional operands ('=' builds Eq(a, b, evaluate=False)).
OPERATORS = {
    '+': dict(arity=2, commutable=True, top_level=False, convert=lambda *args: args[0] + args[1]),
    '-': dict(arity=2, commutable=False, top_level=False, convert=lambda *args: args[0] - args[1]),
    '*': dict(arity=2, commutable=True, top_level=False, convert=lambda *args: args[0] * args[1]),
    '/': dict(arity=2, commutable=False, top_level=False, convert=lambda *args: args[0] / args[1]),
    '^': dict(arity=2, commutable=False, top_level=False, convert=lambda *args: args[0] ** args[1]),
    '=': dict(arity=2, commutable=True, top_level=True,
              convert=lambda *args: Eq(args[0], args[1], evaluate=False)),
}

# Class names listed as top-level; 'Eq' matches the class built by the
# top-level '=' operator above.
TOP_LEVEL_CLASSES = ['Eq']

# Operator symbols grouped by their (arity, top_level) signature. Signatures
# appear in sorted order; symbols inside a group keep OPERATORS insertion order.
ARITY_MAP = {
    signature: [symbol for symbol, spec in OPERATORS.items()
                if (spec['arity'], spec['top_level']) == signature]
    for signature in sorted({(spec['arity'], spec['top_level']) for spec in OPERATORS.values()})
}

# Equality symbols, and the symbol/word spellings of minus and plus.
EQUALITIES, MINUS, PLUS = ['='], ['-', 'subtract'], ['+', 'add']

# Float infinities.
POS_INF = float('inf')
NEG_INF = -POS_INF


# FOR TUPLE INPUT
# Each *_ID below is the position of its token in the listed vocabulary.

# Function tokens for the equation field.
FUN_NEW_EQN = '__NEW_EQN'
FUN_END_EQN = '__DONE'
FUN_NEW_VAR = '__NEW_VAR'
FUN_TOKENS = [FUN_NEW_EQN, FUN_END_EQN, FUN_NEW_VAR]
FUN_NEW_EQN_ID, FUN_END_EQN_ID, FUN_NEW_VAR_ID = map(
    FUN_TOKENS.index, (FUN_NEW_EQN, FUN_END_EQN, FUN_NEW_VAR))

# Function vocabulary extended with the equality sign (a new list).
FUN_TOKENS_WITH_EQ = [*FUN_TOKENS, '=']
FUN_EQ_SGN_ID = FUN_TOKENS_WITH_EQ.index('=')

# Argument tokens.
ARG_CON = 'CONST:'
ARG_NUM = 'NUMBER:'
ARG_MEM = 'MEMORY:'
ARG_TOKENS = [ARG_CON, ARG_NUM, ARG_MEM]
ARG_CON_ID, ARG_NUM_ID, ARG_MEM_ID = map(ARG_TOKENS.index, (ARG_CON, ARG_NUM, ARG_MEM))
ARG_UNK = 'UNK'
# NOTE(review): ARG_UNK is not part of ARG_TOKENS and ARG_UNK_ID (0) equals
# ARG_CON_ID — confirm this overlap is intended.
ARG_UNK_ID = 0

# Sequence tokens reuse the function/argument spellings; the pointer
# vocabulary appends the number and variable pointer tokens (a new list).
SEQ_NEW_EQN, SEQ_END_EQN, SEQ_UNK_TOK = FUN_NEW_EQN, FUN_END_EQN, ARG_UNK
SEQ_TOKENS = [SEQ_NEW_EQN, SEQ_END_EQN, SEQ_UNK_TOK]
SEQ_PTR_NUM, SEQ_PTR_VAR = '__NUM', '__VAR'
SEQ_PTR_TOKENS = [*SEQ_TOKENS, SEQ_PTR_NUM, SEQ_PTR_VAR]
SEQ_NEW_EQN_ID, SEQ_END_EQN_ID, SEQ_UNK_TOK_ID, SEQ_PTR_NUM_ID, SEQ_PTR_VAR_ID = map(
    SEQ_PTR_TOKENS.index, (SEQ_NEW_EQN, SEQ_END_EQN, SEQ_UNK_TOK, SEQ_PTR_NUM, SEQ_PTR_VAR))

TOK_TOK_ID, TOK_NUM_ID = 0, 1

# Maximum counts; they size the zero-padded formats below, and NUM_MAX also
# sets the offset between SEQ_GEN_NUM_ID and SEQ_GEN_VAR_ID.
VAR_MAX, NUM_MAX, MEM_MAX = 2, 32, 32

# Generation ids: numbers start at the pointer-number id and variables start
# NUM_MAX ids later (presumably one id per number slot — confirm).
SEQ_GEN_NUM_ID = SEQ_PTR_NUM_ID
SEQ_GEN_VAR_ID = SEQ_GEN_NUM_ID + NUM_MAX

# %-formats that zero-pad an index to the digit count of the matching *_MAX
# (VAR_MAX = 2 -> 'X_%01d', NUM_MAX = MEM_MAX = 32 -> two digits).
MEM_PREFIX = 'M_'
FORMAT_VAR = f'X_%0{int(log10(VAR_MAX)) + 1}d'
FORMAT_NUM = f'N_%0{int(log10(NUM_MAX)) + 1}d'
FORMAT_MEM = f'{MEM_PREFIX}%0{int(log10(MEM_MAX)) + 1}d'

# Key names for the token/tuple generation and pointer variants.
TOKEN_GEN, TOKEN_PTR = 'token_gen', 'token_ptr'
TUPLE_GEN, TUPLE_PTR = 'tuple_gen', 'tuple_ptr'
