from flask import Flask, jsonify, request
import kiwi
from kiwi import constants as const
from flask_cors import CORS
import re
from string import punctuation

# Both the model and the web app are built once, at import time, and then
# shared by every request handled by this process.
_MODEL_PATH = 'trained_models/estimator_en_de/estimator_en_de.torch'
model = kiwi.load_model(_MODEL_PATH)  # relative path: resolved against the current working directory

app = Flask(__name__)
CORS(app)  # flask-cors with default options (no origins/resources restriction passed)


# routes

# Splits a word on whitespace or any single ASCII punctuation character.
# The capturing group keeps the delimiters in the result. The pattern is
# compiled once at import time instead of on every request.
_TOKEN_SPLIT_PATTERN = re.compile(r'(\s+|[{}])'.format(re.escape(punctuation)))


@app.route('/', methods=['POST'])
def predict():
    """Score a source/MT sentence pair and return per-token quality labels.

    Expects a JSON object body with string fields ``source`` and ``mt``.
    Both strings are lower-cased and sent to the loaded Kiwi model as a
    single example. The first row of ``model_out[const.TARGET_TAGS]`` is
    then indexed by the position of each whitespace-separated word in
    ``mt``. That row is named ``bad_probs``, presumably per-word BAD-tag
    probabilities (confirm against the Kiwi model).

    Each word is further split into punctuation characters and the text
    runs between them. The word's score is repeated once per non-empty
    piece, so the client gets one label per visual sub-token of ``mt``.

    Each request also appends its label list as one line to
    ``result_final_.txt``, a relative path in the current working directory.

    Returns:
        On success, JSON ``{"results": {"qualityLabels": [...]}}``.
        A 400 response with ``{"error": ...}`` when the body is not a JSON
        object whose ``source`` and ``mt`` fields are strings.
    """
    data = request.get_json(force=True)
    # A JSON array, string, number or null would otherwise crash on the
    # lookups below and surface as an opaque 500.
    if not isinstance(data, dict):
        return jsonify(error='request body must be a JSON object'), 400
    source = data.get('source')
    mt = data.get('mt')
    if not isinstance(source, str) or not isinstance(mt, str):
        return jsonify(error="'source' and 'mt' must be strings"), 400

    model_out = model.predict({const.SOURCE: [source.lower()], const.TARGET: [mt.lower()]})
    # Scores for the single example submitted above.
    bad_probs = model_out[const.TARGET_TAGS][0]

    # Built locally per request. Module-level globals would be shared, and
    # clobbered, across concurrent requests.
    color_array = []
    for i, word in enumerate(mt.split()):
        # Assumes the model emits at least one score per whitespace-separated
        # word of mt. An IndexError here means the two tokenizations disagree.
        word_prob = bad_probs[i]
        for part in _TOKEN_SPLIT_PATTERN.split(word):
            # split() yields empty strings around adjacent/edge delimiters.
            if part.strip():
                color_array.append(word_prob)

    print(color_array)
    with open("result_final_.txt", "a", encoding="utf-8") as f:
        f.write(str(color_array) + '\n')
    return jsonify(results={'qualityLabels': color_array})


if __name__ == '__main__':
    # Local entry point: running this file directly serves the app on port 5000.
    # NOTE(review): debug=True turns on Flask debug mode (interactive debugger
    # and reloader); do not expose this server on a public interface.
    app.run(debug=True, port=5000)