"""
Clean the template sentences to make them consistent with the word embedding results

For example, change
i 'll take care of this for you in cancelling 's account
to
i ll take care of this for you in cancelling s account

Also remove the Response and Mode columns

"""

import pandas as pd
import argparse
import re
import pdb
from string import punctuation
import code

# code.interact(local=vars())
# pdb.set_trace()


# Field separator used for both the input and the output file.
DELIMITER = '\t'

argparser = argparse.ArgumentParser(
    description="Clean the Template column of a tab-separated template file.")
# Both paths are required: an empty default would only fail later with an
# obscure file error from pandas instead of a clear usage message.
argparser.add_argument("--input_file", action="store", dest='input_file', required=True,
                       help="tab-separated input file with a header row and a "
                            "Template column (an optional Gold column is kept)")
argparser.add_argument("--output_file", action="store", dest='output_file', required=True,
                       help="tab-separated output file with a header row and "
                            "the cleaned Template and Gold columns")

args = argparser.parse_args()

df_in = pd.read_csv(args.input_file, skip_blank_lines=True, sep=DELIMITER)
if 'Template' not in df_in.columns:
    # Fail with a usage-style message rather than a bare KeyError.
    argparser.error("input file %r has no 'Template' column" % args.input_file)

# Keep only rows with a non-empty template. Missing values have a NaN length,
# which compares False against 0, so they are dropped as well. The explicit
# copy makes df_in an independent frame that can safely be modified later.
df_in = df_in[df_in['Template'].str.len() > 0].copy()

# clean the templates
def clean_text(s):
    """Normalize a template sentence for lookup against the embedding vocab.

    The text is lower-cased, every ASCII punctuation character and every
    digit is replaced by a space, the phrase ``uci token`` (what
    ``UCI_TOKEN`` turns into after lower-casing and underscore removal) is
    rewritten to ``UCI_TOKEN``, the substring ``agentturn`` is removed, and
    runs of whitespace are collapsed to a single space.

    Args:
        s: Raw template text.

    Returns:
        The cleaned text without leading or trailing whitespace.

    Example:
        >>> clean_text("i 'll take care of this for you in cancelling 's account")
        'i ll take care of this for you in cancelling s account'
    """
    # Replace (rather than delete) punctuation and digits so that the tokens
    # they separate are not glued together.
    s_no_punc = ''.join(
        ' ' if c in punctuation or c.isdigit() else c for c in s.lower()
    )
    # Collapse whitespace first so the two-word 'uci token' match is reliable.
    s_clean = re.sub(r'\s+', ' ', s_no_punc)
    s_clean = s_clean.replace('uci token', 'UCI_TOKEN')
    s_clean = s_clean.replace('agentturn', '')
    # Dropping 'agentturn' from the middle of a sentence leaves two adjacent
    # spaces behind, so collapse whitespace once more before trimming.
    return re.sub(r'\s+', ' ', s_clean).strip()


# Clean every template; str() guards against cells that are not strings.
template_pools = df_in['Template'].apply(lambda x: clean_text(str(x)))

# Guarantee a Gold column so the output always has the same two columns.
if 'Gold' not in df_in.columns:
    # assign() broadcasts None to every row and returns a new frame. This
    # avoids writing into a filtered slice and avoids aligning a freshly
    # indexed Series against df_in's (possibly gapped) row index.
    df_in = df_in.assign(Gold=None)

"""
# check how many input are not in the embedding vocab
glove_embed = pd.read_csv('/home/jxchen/efs/1Projects/E2E/template-cluster/input/all-text-agent-ALL-clean-norm-cl-vectors.glove.txt',
                          header=None, sep=' ')
vocab_dict = set()

for i in glove_embed[0]:
    vocab_dict.add(i)

unknown_chars = []
for i in template_pools:
    for j in i.split():
        if j not in vocab_dict and j not in unknown_chars:
            unknown_chars.append(j)

print('The number of unknown tokens is: ' + len(unknown_chars))
"""


# Put the cleaned templates next to their gold labels (aligned on the row
# index) and write them out as a delimited file with a header and no index.
output_df = pd.concat(
    [template_pools, df_in['Gold']],
    axis=1,
    keys=['Template', 'Gold'],
)
output_df.to_csv(args.output_file, sep=DELIMITER, index=False)



# tail -n+2 templates-cancel-prime-clean > tmp
# cat gold500-clean tmp > cp-clust-clean
# rm tmp