import os
import time
import datetime
import re
import argparse
import sys
import random

import json
import numpy as np

sys.path.append("..")
from utils.logConfig import Log

# Seed the standard-library `random` generator so runs are repeatable.
# NOTE(review): selectCand() draws from `np.random`, which this call does
# not seed -- confirm whether numpy's generator should be seeded as well.
random.seed(666)


def readTxt(name):
    """Load a UTF-8 text file as a list of cleaned lines.

    The file is opened in binary mode, so lines are split on ``\\n`` only.
    Each line is decoded as UTF-8, has every U+2028 (line separator) and
    carriage-return character removed, and is stripped of surrounding
    whitespace. Blank lines are kept as empty strings, so the result has one
    entry per line of the file.

    Args:
        name: Path of the file to read.

    Returns:
        List of cleaned lines in file order.
    """
    logger.info("Reading file %s", name)
    with open(name, 'rb') as fin:
        return [
            raw.decode('utf-8').replace(u'\u2028', '').replace('\r', '').strip()
            for raw in fin
        ]
    
def saveTxt(name, data):
    """Write each string in *data* to *name*, one per line.

    Args:
        name: Path of the output file; it is created or truncated.
        data: Sized collection of strings. Each item is written followed by
            a newline.
    """
    logger.info("Saving data with size %d to file %s", len(data), name)
    # Write UTF-8 explicitly: readTxt() decodes UTF-8, and relying on the
    # platform default encoding can raise UnicodeEncodeError (or produce a
    # different encoding) for non-ASCII text on non-UTF-8 locales.
    with open(name, 'w', encoding='utf-8') as fout:
        for d in data:
            fout.write(d + "\n")


def readJson(name):
    """Read a JSON-lines file and group its text pairs by ``src_lang``.

    Every non-blank line must be a JSON object containing the keys
    ``src_lang``, ``src_text`` and ``trg_text``.

    Args:
        name: Path of the JSON-lines file.

    Returns:
        Dict mapping each ``src_lang`` value to a list of single-entry dicts
        ``{src_text: trg_text}``, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    logger.info("Reading file %s", name)
    data = {}
    # Decode as UTF-8 explicitly rather than with the locale default, so
    # non-ASCII text loads the same way on every platform.
    with open(name, encoding='utf-8') as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # letting json.loads() fail on them.
            if not line.strip():
                continue
            e = json.loads(line)
            data.setdefault(e['src_lang'], []).append({e['src_text']: e['trg_text']})
    return data

def readJson2(name):
    """Read a JSON-lines file into a flat ``src_text -> trg_text`` mapping.

    Every non-blank line must be a JSON object containing the keys
    ``src_text`` and ``trg_text``. When the same ``src_text`` appears more
    than once, the last occurrence wins.

    Args:
        name: Path of the JSON-lines file.

    Returns:
        Dict mapping each ``src_text`` to its ``trg_text``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    logger.info("Reading file %s", name)
    data = {}
    # Decode as UTF-8 explicitly rather than with the locale default, so
    # non-ASCII text loads the same way on every platform.
    with open(name, encoding='utf-8') as fin:
        for line in fin:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # letting json.loads() fail on them.
            if not line.strip():
                continue
            e = json.loads(line)
            data[e['src_text']] = e['trg_text']
    return data

def saveJson(name, data):
    """Write each item of *data* to *name* as one JSON document per line.

    Args:
        name: Path of the output file; it is created or truncated.
        data: Sized collection of JSON-serializable objects.
    """
    logger.info("Saving data with size %d to file %s", len(data), name)
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly; the locale default encoding may be unable
    # to represent them.
    with open(name, 'w', encoding='utf-8') as fout:
        for d in data:
            fout.write(json.dumps(d, ensure_ascii=False) + "\n")

def selectCand(line, seed, delimiter="<q>"):
    """Pick one non-empty candidate at random from a ``|``-separated line.

    The line is split on ``'|'``, each piece is stripped, and empty pieces
    are dropped. One of the remaining pieces is chosen with
    ``np.random.randint``.

    Args:
        line: String holding candidates separated by ``'|'``.
        seed: Accepted for interface compatibility; not used by this
            function (the draw comes from numpy's global generator).
        delimiter: Accepted for interface compatibility; not used by this
            function (the split character is always ``'|'``).

    Returns:
        The chosen candidate, or ``""`` when the line has no non-empty
        candidate.
    """
    candidates = []
    for piece in line.split('|'):
        piece = piece.strip()
        if piece != '':
            candidates.append(piece)
    if not candidates:
        return ""
    return candidates[np.random.randint(len(candidates))]

if __name__ == '__main__':
    # Script entry point: look up every line of the reference file (-r) in
    # the src_text -> trg_text mapping loaded from -i (a JSON-lines file or
    # a directory of such files) and write the matching translations, one
    # per line, to <output dir>/<reference file name>.<language suffix>.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', type=str, default='input.txt', help='input path')
    parser.add_argument('-o', type=str, default='output', help='output dir')
    parser.add_argument('-r', type=str, default='reference.txt', help='reference dir')

    args = parser.parse_args()

    # The helper functions above log through this module-level name, so it
    # has to be bound before any of them is called.
    logger = Log.getLogger(os.path.basename(sys.argv[0]), None)
    logger.info(args)

    # Alternative mode (disabled): split the input by source language.
    # data = readJson(args.i)
    # for key, value in data.items():
    #     filename = args.i.split('/')[-1]
    #     saveJson(os.path.join(args.o, filename + '.' + key), value)

    if os.path.isdir(args.i):
        # Merge every file into ONE mapping. Collecting the per-file dicts
        # in a list made the string-keyed lookup below raise TypeError for
        # every reference line, which was silently swallowed and produced an
        # all-empty output. Sorting makes it deterministic which file wins
        # when the same src_text appears in several files (the last one).
        data = {}
        for f in sorted(os.listdir(args.i)):
            data.update(readJson2(os.path.join(args.i, f)))
    else:
        data = readJson2(args.i)

    reference = readTxt(args.r)
    # Reference lines without a translation map to an empty line, so the
    # output stays line-aligned with the reference file.
    trans = [data.get(r, '') for r in reference]
    missing = sum(1 for r in reference if r not in data)
    if missing:
        logger.info("%d of %d reference lines had no translation", missing, len(reference))

    # Language suffix: the last dot-separated part of the input name; when
    # it starts with '2', the two characters after the '2' are used.
    lg = os.path.basename(args.i).split('.')[-1]
    if lg.startswith('2'):
        lg = lg[1:3]
    filename = os.path.basename(args.r)
    saveTxt(os.path.join(args.o, filename + '.' + lg), trans)

    # Alternative mode (disabled): pick one random candidate per input line.
    # lines = readTxt(args.i)
    # candidate = [selectCand(line, 666) for line in lines]
    # saveTxt(args.o, candidate)
    



    

    