from cgi import test
import os
import sys
import random
import unicodedata
import argparse
import numpy as np
from numpy.random import choice

def get_seg_dict(word_seg_file, encoding="utf-8"):
    """Load candidate segmentations for every word from *word_seg_file*.

    Each non-blank line holds a space-separated segmentation, optionally
    followed by a tab and a numeric score. All '▁' characters are removed
    before parsing. Lines without a tab get a score of 0.0. The dictionary
    key is the segmentation text with its spaces removed, so every candidate
    segmentation of the same surface word is grouped together.

    Args:
        word_seg_file: Path to the segmentation file.
        encoding: Text encoding of the file. Defaults to UTF-8 because the
            lines contain the non-ASCII '▁' marker; relying on the platform
            locale could mis-decode it.

    Returns:
        dict mapping word -> list of ``(segment_list, score)`` tuples, in
        file order.

    Raises:
        ValueError: If the score column cannot be parsed as a float (this
            includes lines with more than one tab).
    """
    segs_by_word = {}
    with open(word_seg_file, "r", encoding=encoding) as f:
        for line in f:  # stream instead of readlines() to avoid loading the whole file
            line = line.replace('▁', '').strip()
            if not line:
                # A blank line would otherwise register an empty word with
                # an empty segmentation.
                continue
            text, sep, raw_score = line.partition('\t')
            score = float(raw_score) if sep else 0.0
            word = text.replace(' ', '')
            segs_by_word.setdefault(word, []).append((text.split(), score))
    return segs_by_word

def get_random_weights(word_dict, temperature=None):
    """Turn each word's segmentation scores into sampling probabilities.

    For every word, the weights are ``softmax(score / temperature)`` over its
    candidate segmentations, in the same order as ``word_dict[word]``.

    Args:
        word_dict: Mapping word -> list of ``(segment_list, score)`` tuples,
            as produced by ``get_seg_dict``.
        temperature: Softmax temperature. If None, falls back to the
            module-level ``T`` (set from ``--T`` when run as a script), or
            1.0 when no such global exists.

    Returns:
        dict mapping word -> list of float probabilities summing to 1.
    """
    if temperature is None:
        temperature = globals().get("T", 1.0)
    weights = {}
    for word, segs in word_dict.items():
        if not segs:
            weights[word] = []
            continue
        scaled = np.array([score for _, score in segs], dtype=float) / temperature
        # Subtract the max before exponentiating (log-sum-exp trick): the
        # result is mathematically unchanged, but very negative scores no
        # longer all underflow to 0 (which gave 0/0 = NaN) and large
        # positive scores no longer overflow to inf.
        exps = np.exp(scaled - scaled.max())
        weights[word] = (exps / exps.sum()).tolist()
    return weights

def seg_line_1(line, seg_dict):
    """Segment *line* deterministically, using each word's first candidate.

    Words missing from *seg_dict* are reported on stdout and split into
    single characters. Each segmented word is prefixed with '▁' and its
    pieces are space-separated.

    Args:
        line: Whitespace-separated text to segment.
        seg_dict: Mapping word -> list of ``(segment_list, score)`` tuples.

    Returns:
        The segmented line as a single string.
    """
    def _pick(token):
        candidates = seg_dict.get(token)
        if candidates is None:
            print("No seg", token)
            return list(token)
        first_seg, _score = candidates[0]
        return first_seg

    return ' '.join('▁' + ' '.join(_pick(token)) for token in line.split())

def seg_line_random(line, seg_dict):
    """Segment *line*, picking a uniformly random candidate for each word.

    Words missing from *seg_dict* are split into single characters. Each
    segmented word is prefixed with '▁' and its pieces are space-separated.

    Args:
        line: Whitespace-separated text to segment.
        seg_dict: Mapping word -> list of ``(segment_list, score)`` tuples.

    Returns:
        The segmented line as a single string.
    """
    seged_line = []
    for word in line.split():
        candidates = seg_dict.get(word)
        if candidates is None:
            seg = list(word)
        else:
            # Each candidate is a (segment_list, score) tuple; joining the
            # tuple itself would raise TypeError, so unpack the segments.
            seg, _ = random.choice(candidates)
        seged_line.append('▁' + ' '.join(seg))
    return ' '.join(seged_line)

def seg_line_random_weight(line, seg_dict, weights):
    """Segment *line*, sampling each word's segmentation by its weight.

    For each known word, one candidate is drawn with ``random.choices``
    using ``weights[word]`` as relative weights. Unknown words are reported
    on stdout and split into single characters. Each segmented word is
    prefixed with '▁' and its pieces are space-separated.

    Args:
        line: Whitespace-separated text to segment.
        seg_dict: Mapping word -> list of ``(segment_list, score)`` tuples.
        weights: Mapping word -> list of sampling weights aligned with
            ``seg_dict[word]``.

    Returns:
        The segmented line as a single string.
    """
    pieces = []
    for token in line.split():
        options = seg_dict.get(token)
        if options is not None:
            seg, _ = random.choices(options, weights=weights[token])[0]
        else:
            print("not found:", token)
            seg = list(token)
        pieces.append('▁' + ' '.join(seg))
    return ' '.join(pieces)

def seg_corpus(corpus_file, output_file, seg_dict, weights, regularization=0, encoding="utf-8"):
    """Segment every line of *corpus_file* and write the result to *output_file*.

    Each input line is NFKC-normalized and stripped before segmentation.
    The output has one segmented line per input line.

    Args:
        corpus_file: Path of the input text.
        output_file: Path of the output text (overwritten).
        seg_dict: Mapping word -> list of ``(segment_list, score)`` tuples.
        weights: Mapping word -> sampling weights. Only used when
            *regularization* is truthy.
        regularization: If truthy, sample a segmentation per word with
            ``seg_line_random_weight``; otherwise use the first candidate
            via ``seg_line_1``.
        encoding: Encoding for both files. Defaults to UTF-8 because the
            output contains the non-ASCII '▁' marker, which a non-UTF-8
            locale default could fail to encode.
    """
    with open(corpus_file, "r", encoding=encoding) as src, \
            open(output_file, "w", encoding=encoding) as dst:
        for raw_line in src:  # stream instead of loading the whole corpus
            line = unicodedata.normalize("NFKC", raw_line).strip()
            if regularization:
                seged_line = seg_line_random_weight(line, seg_dict, weights)
            else:
                seged_line = seg_line_1(line, seg_dict)
            dst.write(seged_line.strip() + '\n')

def check_word_dict(word_dict):
    """Debug helper: print every word with its candidates, then summary counts.

    After listing each ``word segs`` pair, prints one final line of the form
    ``"<words with more than one candidate> <total words>"``.

    Args:
        word_dict: Mapping word -> list of candidate segmentations.
    """
    for word, segs in word_dict.items():
        print(f"{word} {segs}")
    ambiguous = sum(1 for segs in word_dict.values() if len(segs) > 1)
    print(f'{ambiguous} {len(word_dict)}')

parser = argparse.ArgumentParser(
    description="Segment a corpus using a word -> segmentation table.")
# read word_seg_file, corpus_file, output_file
# The three paths are mandatory: with a None default the script previously
# crashed later with an opaque TypeError instead of a usage error.
parser.add_argument('--word_seg_file', type=str, required=True,
                    help="file of candidate segmentations, optionally tab-followed by a score")
parser.add_argument('--corpus_file', type=str, required=True,
                    help="input text to segment")
parser.add_argument('--output_file', type=str, required=True,
                    help="where to write the segmented text")
parser.add_argument('--T', type=float, default=1.0,
                    help="softmax temperature applied to scores when computing sampling weights")
parser.add_argument('--regularization', type=int, default=0,
                    help="if nonzero, sample segmentations by weight instead of taking the first candidate")

args = parser.parse_args()
word_seg_file = args.word_seg_file
corpus_file = args.corpus_file
output_file = args.output_file
T = args.T  # read as a module global by get_random_weights
regularization = args.regularization

# os.path.dirname() is '' for a bare filename, and os.makedirs('') raises
# FileNotFoundError, so only create the directory when there is one.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

word_dict = get_seg_dict(word_seg_file)
weights = get_random_weights(word_dict)
seg_corpus(corpus_file, output_file, word_dict, weights, regularization)
