'''
code to generate extended-bAbI (already generated)
'''
import random
import copy
import os
import json
from templates import TMEPLATE
from ontology import ONTO, PARAMS
from collections import defaultdict

# Per-domain simulation settings; indexed below as hyparams[<domain>][<setting>].
hyparams = PARAMS

# Domain names known to this script.
domains = [
    'restaurants',
    'hotels',
    'travels',
    'movies',
    'music',
    'buses',
    'flights',
    'weather',
]

class DialogSimulator:
    """Base class for template-driven dialog simulators of a single domain.

    Subclasses implement ``render_slots`` and ``generate_dial``; ``run``
    generates ``num_dial`` dialogs, splits them into train/dev/test files and
    writes the system-response candidates and the vocabulary next to them
    under ``data/<domain>/``.
    """

    def __init__(self,
                 domain,
                 num_dial=3000,
                 num_trn=1500,
                 num_dev=500,
                 num_tst=1000,
                 maxnum_req_slots=3,
                 slot_update_ratio=(0,1,2,3),  # weight of doing 0, 1, 2, ... slot changes
                 entity_update_ratio=(1, 2, 3),  # weight of 1, 2, ... recommendations
                 ):
        """Set up templates, ontology, output directory and sampling weights.

        Args:
            domain: key into ``TMEPLATE`` and ``ONTO``; also the name of the
                output sub-directory under ``data``.
            num_dial: total number of dialogs to generate; must equal
                ``num_trn + num_dev + num_tst``.
            num_trn: number of dialogs written to ``trn.txt``.
            num_dev: number of dialogs written to ``dev.txt``.
            num_tst: number of dialogs intended for ``tst.txt``.
            maxnum_req_slots: upper bound used by the subclasses in this file
                when sampling how many requestable slots the user asks about.
            slot_update_ratio: relative weights; the subclasses in this file
                use entry ``i`` as the weight of performing ``i`` slot changes.
            entity_update_ratio: relative weights; the subclasses in this
                file use entry ``i`` as the weight of ``i + 1`` recommendations.

        Raises:
            AssertionError: if the split sizes do not add up to ``num_dial``.
        """
        # Validate before touching the file system so a bad call leaves no
        # directories behind.
        assert num_dial == num_trn + num_dev + num_tst, \
            'num_dial must equal num_trn + num_dev + num_tst'

        self.domain = domain
        self.user_template = TMEPLATE[domain]['user']
        self.sys_template = TMEPLATE[domain]['sys']
        self.onto = ONTO[self.domain]
        # Reserved ids are contiguous: new words are numbered with
        # len(self.vocab), so a gap here would make the first word share its
        # id with "UNK".
        self.vocab = {"$u": 0, "$r": 1, "UNK": 2}
        self.data_root = 'data'
        self.save_dir = os.path.join(self.data_root, domain)
        # Creates data_root as well; no error if it already exists.
        os.makedirs(self.save_dir, exist_ok=True)
        self.max_turn = 0
        self.max_sent = 0
        self.num_total = num_dial
        self.num_trn = num_trn
        self.num_dev = num_dev
        self.num_tst = num_tst
        self.maxnum_req_slots = maxnum_req_slots
        # system sentence -> number of times it was produced
        self.candidates = defaultdict(int)

        # Ontology entries mapped to None are requestable; all others map to
        # their candidate values and are informable.
        self.requestable_slots = [s for s in self.onto if self.onto[s] is None]
        self.informable_slots = {s: self.onto[s] for s in self.onto if self.onto[s] is not None}
        self.change_times_probs = self.normalize(slot_update_ratio)
        self.entity_times_probs = self.normalize(entity_update_ratio)

    @staticmethod
    def normalize(dist):
        """Scale the weights in ``dist`` so that they sum to 1.

        Args:
            dist: any iterable of numbers (list, tuple, dict view, iterator).

        Returns:
            A list with each weight divided by the total.

        Raises:
            ZeroDivisionError: if the weights are non-empty and sum to zero.
        """
        # Materialise first: summing a one-shot iterator would exhaust it and
        # leave nothing to divide.
        weights = list(dist)
        total = sum(weights)
        return [p / total for p in weights]

    def save(self, file_name, data):
        """Write dialogs to ``<save_dir>/<file_name>``.

        Each turn becomes a line ``<turn number> <user> || <system>`` (turn
        numbers start at 1) and every dialog is followed by a blank line.

        Args:
            file_name: name of the file inside ``self.save_dir``.
            data: iterable of dialogs, each an iterable of
                ``(user_sentence, system_sentence)`` pairs.
        """
        with open(os.path.join(self.save_dir, file_name), 'w', encoding='utf-8') as f:
            for dial in data:
                for turn_no, turn in enumerate(dial, 1):
                    f.write('%d %s || %s\n' % (turn_no, turn[0], turn[1]))
                f.write('\n')

    def render_slots(self, s):
        """Return a ``%s`` phrase template for slot ``s``; subclasses override."""
        raise NotImplementedError('%s must implement render_slots' % type(self).__name__)

    def generate_dial(self):
        """Return one dialog as a list of (user, system) pairs; subclasses override."""
        raise NotImplementedError('%s must implement generate_dial' % type(self).__name__)

    def run(self):
        """Generate all dialogs and write the dataset files for this domain.

        Writes ``trn.txt``, ``dev.txt``, ``tst.txt``, ``candidates.json`` and
        ``vocab.json`` into ``self.save_dir`` and prints the longest user
        sentence and dialog length seen.
        """
        total_dial = [self.generate_dial() for _ in range(self.num_total)]
        random.shuffle(total_dial)

        dev_end = self.num_trn + self.num_dev
        self.save('trn.txt', total_dial[:self.num_trn])
        self.save('dev.txt', total_dial[self.num_trn:dev_end])
        self.save('tst.txt', total_dial[dev_end:])

        with open(os.path.join(self.save_dir, 'candidates.json'), 'w', encoding='utf-8') as f:
            json.dump(self.candidates, f, indent=1)
        # One extra '#<k>' token per possible turn number, appended after all
        # words collected during generation.
        for i in range(1, self.max_turn + 1):
            self.vocab['#%d' % i] = len(self.vocab)
        with open(os.path.join(self.save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, indent=1)
        print(self.domain, 'max sent', self.max_sent, 'max turn', self.max_turn)


class MovieSimulator(DialogSimulator):
    """Dialog simulator for the 'movies' domain.

    Simulation settings are read from ``hyparams['movies']``.
    """

    # Requestable slots with a dedicated user request template; any other
    # requestable slot is asked about through 'request_introduction'.
    _REQUEST_ACTS = {
        'free': 'request_free',
        'cast': 'request_cast',
        'length': 'request_length',
    }

    def __init__(self):
        super(MovieSimulator, self).__init__(
            domain='movies',
            maxnum_req_slots=hyparams['movies']['maxnum_req_slots'],
            slot_update_ratio=hyparams['movies']['slot_update_ratio'],
            entity_update_ratio=hyparams['movies']['entity_update_ratio'])

    def render_slots(self, s):
        """Return a randomly chosen ``%s`` phrase template for slot ``s``."""
        if s == 'director':
            return random.choice(['directed by %s', 'made by %s'])
        if s == 'actor':
            return random.choice(['acted by %s', 'starred by %s'])
        # Every other slot is phrased as a genre/type constraint.
        return random.choice(['for %s genre', 'for %s type', 'in the type of %s'])

    def _record_turn(self, dial, user_sent, sys_sent):
        """Append one turn to ``dial`` and update the corpus statistics.

        New words get the next free vocabulary id, ``max_sent`` tracks the
        longest user sentence (in whitespace tokens) and ``candidates``
        counts how often each system sentence was produced.
        """
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab:
                self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1

    def generate_dial(self):
        """Simulate one movie dialog.

        The dialog runs through these phases, driven by the flags below:
        welcome -> slot filling (ends with an ``api_call``) -> optional slot
        changes (ends with a second ``api_call``) -> entity recommendation
        (ends with ``reserve``) -> user requests about the chosen entity ->
        thanks -> byebye.

        Returns:
            list of ``(user_sentence, system_sentence)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []  # informable slots the system still has to ask for
        DS = {s: None for s in self.informable_slots}  # dialog state: slot -> value
        change_weights = {s: 1 for s in DS}

        # Random subset (1..maxnum_req_slots) of requestable slots to ask about.
        requested_slots = list(self.requestable_slots)
        random.shuffle(requested_slots)
        num_req = random.choice(range(1, 1 + self.maxnum_req_slots))
        requested_slots = requested_slots[:num_req]

        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False  # True right before the first recommendation
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0  # index of the entity currently being recommended

        while sys_act != 'byebye':
            # Each iteration first commits the turn prepared by the previous
            # one, then prepares the next user/system pair.
            self._record_turn(dial, user_sent, sys_sent)

            if user_act == 'welcome':
                # User's initial slot filling: every ontology key is sampled
                # with probability 0.5, only informable ones are kept.
                slots = [s for s in self.onto.keys()
                         if random.random() < 0.5 and self.onto[s] is not None]
                init_sv = {s: random.choice(self.onto[s]) for s in slots}
                DS.update(init_sv)
                rest_s = [s for s in self.informable_slots if s not in init_sv]
                user_act = 'inform_all'
                random.shuffle(slots)
                param_string = [self.render_slots(s) % init_sv[s] for s in slots]
                opening = random.choice(self.user_template[user_act])
                # Joining as one list avoids a trailing space when no slot
                # was sampled.
                user_sent = ' '.join([opening] + param_string)

            elif slot_fill_flag and sys_act.startswith('request'):
                # sys_act is 'request_<slot>'; split once so slot names that
                # contain '_' survive intact.
                s = sys_act.split('_', 1)[1]
                user_act = 'inform_' + s
                v = random.choice(self.onto[s])
                DS[s] = v
                user_sent = random.choice(self.user_template[user_act]) % v

            if request_flag:
                if requested_slots:
                    random.shuffle(requested_slots)
                    s = requested_slots.pop()
                    user_act = self._REQUEST_ACTS.get(s, 'request_introduction')
                    user_sent = random.choice(self.user_template[user_act])

                    sys_act = 'inform'
                    sys_sent = random.choice(self.sys_template[sys_act]) % (entityOrder, s)
                else:
                    user_act = 'thanks'
                    user_sent = random.choice(self.user_template[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(self.sys_template[sys_act])
                    request_flag = False
                    end = True
                    continue

            if entity_rec_flag:
                if start:
                    # Total number of recommendations, the first one included.
                    rec_times = random.choices(
                        range(1, len(self.entity_times_probs) + 1),
                        self.entity_times_probs)[0]
                    start = False
                    user_act = 'affirm'
                    user_sent = random.choice(self.user_template[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(self.sys_template[sys_act]) % entityOrder
                    rec_times -= 1
                    continue
                if rec_times > 0:
                    entityOrder += 1
                    user_act = 'refuse'
                    # The shuffle does not change the sampling distribution;
                    # it is kept so the sequence of random draws is unchanged.
                    tmp_list = list(self.user_template[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(self.sys_template[sys_act]) % entityOrder
                    rec_times -= 1
                else:
                    user_act = 'agree'
                    user_sent = random.choice(self.user_template[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(self.sys_template[sys_act])
                    entity_rec_flag = False
                    request_flag = True

            if slot_change_flag:
                if change_times > 0:
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    # Make an already changed slot less likely to be picked again.
                    change_weights[s] /= 3
                    candid_vs = list(self.onto[s])
                    candid_vs.remove(DS[s])  # the new value must differ
                    v = random.choice(candid_vs)
                    v_string = self.render_slots(s) % v
                    DS[s] = v
                    user_act = 'change'
                    # Shuffle kept only to preserve the sequence of random draws.
                    tmp_list = list(self.user_template[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list) % v_string
                    change_times -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(self.sys_template[sys_act])
                else:
                    user_act = 'deny'
                    user_sent = random.choice(self.user_template[user_act])
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_change_flag = False
                    entity_rec_flag = True
                    start = True
                continue

            if slot_fill_flag:
                if rest_s:
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(self.sys_template[sys_act])
                else:
                    # All informable slots are filled: query the API, then
                    # decide how many slot changes follow.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False
                    change_times = random.choices(
                        range(len(self.change_times_probs)),
                        self.change_times_probs)[0]
                    if change_times == 0:
                        entity_rec_flag = True
                        start = True
                    else:
                        slot_change_flag = True
                continue

            if end:
                user_act = 'byebye'
                sys_act = 'byebye'

        # Closing turn; both acts are 'byebye' here.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        self._record_turn(dial, user_sent, sys_sent)
        self.max_turn = max(len(dial), self.max_turn)
        return dial


class RestaurantSimulator(DialogSimulator):
    """Dialog simulator for the 'restaurants' domain.

    Simulation settings are read from ``hyparams['restaurants']``.
    """

    def __init__(self):
        super(RestaurantSimulator, self).__init__(
            domain='restaurants',
            maxnum_req_slots=hyparams['restaurants']['maxnum_req_slots'],
            slot_update_ratio=hyparams['restaurants']['slot_update_ratio'],
            entity_update_ratio=hyparams['restaurants']['entity_update_ratio'])

    def render_slots(self, s):
        """Return a randomly chosen ``%s`` phrase template for slot ``s``."""
        if s == 'food':
            return random.choice(['with %s food', 'with %s cuisine', 'with %s type'])
        if s == 'location':
            return random.choice(['in %s', 'at %s'])
        if s == 'people':
            return random.choice(['for %s', 'for %s people', 'for totally %s people'])
        # Every other slot is phrased as a price constraint.
        return random.choice(['in a %s price', 'in a %s price range'])

    def _record_turn(self, dial, user_sent, sys_sent):
        """Append one turn to ``dial`` and update the corpus statistics.

        New words get the next free vocabulary id, ``max_sent`` tracks the
        longest user sentence (in whitespace tokens) and ``candidates``
        counts how often each system sentence was produced.
        """
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab:
                self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1

    def generate_dial(self):
        """Simulate one restaurant dialog.

        The dialog runs through these phases, driven by the flags below:
        welcome -> slot filling (ends with an ``api_call``) -> optional slot
        changes (ends with a second ``api_call``) -> entity recommendation
        (ends with ``reserve``) -> user requests about the chosen entity ->
        thanks -> byebye.

        Returns:
            list of ``(user_sentence, system_sentence)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []  # informable slots the system still has to ask for
        DS = {s: None for s in self.informable_slots}  # dialog state: slot -> value
        change_weights = {s: 1 for s in DS}

        # Random subset (1..maxnum_req_slots) of requestable slots to ask about.
        requested_slots = list(self.requestable_slots)
        random.shuffle(requested_slots)
        num_req = random.choice(range(1, 1 + self.maxnum_req_slots))
        requested_slots = requested_slots[:num_req]

        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False  # True right before the first recommendation
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0  # index of the entity currently being recommended

        onto = self.onto
        user_temp = self.user_template
        sys_temp = self.sys_template

        while sys_act != 'byebye':
            # Each iteration first commits the turn prepared by the previous
            # one, then prepares the next user/system pair.
            self._record_turn(dial, user_sent, sys_sent)

            if user_act == 'welcome':
                # User's initial slot filling: every ontology key is sampled
                # with probability 0.4, only informable ones are kept.
                slots = [s for s in onto.keys()
                         if random.random() < 0.4 and onto[s] is not None]
                init_sv = {s: random.choice(onto[s]) for s in slots}
                DS.update(init_sv)
                rest_s = [s for s in onto.keys()
                          if onto[s] is not None and s not in init_sv]
                user_act = 'inform_all'
                random.shuffle(slots)
                param_string = [self.render_slots(s) % init_sv[s] for s in slots]
                opening = random.choice(user_temp[user_act])
                # Joining as one list avoids a trailing space when no slot
                # was sampled.
                user_sent = ' '.join([opening] + param_string)

            elif slot_fill_flag and sys_act.startswith('request'):
                # sys_act is 'request_<slot>'; split once so slot names that
                # contain '_' survive intact.
                s = sys_act.split('_', 1)[1]
                user_act = 'inform_' + s
                v = random.choice(onto[s])
                DS[s] = v
                user_sent = random.choice(user_temp[user_act]) % v

            if request_flag:
                if requested_slots:
                    random.shuffle(requested_slots)
                    s = requested_slots.pop()
                    user_act = 'request'
                    user_sent = random.choice(user_temp[user_act]) % s
                    sys_act = 'inform'
                    sys_sent = random.choice(sys_temp[sys_act]) % (entityOrder, s)
                else:
                    user_act = 'thanks'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(sys_temp[sys_act])
                    request_flag = False
                    end = True
                    continue

            if entity_rec_flag:
                if start:
                    # Total number of recommendations, the first one included.
                    rec_times = random.choices(
                        range(1, len(self.entity_times_probs) + 1),
                        self.entity_times_probs)[0]
                    start = False
                    user_act = 'affirm'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                    continue
                if rec_times > 0:
                    entityOrder += 1
                    user_act = 'refuse'
                    # The shuffle does not change the sampling distribution;
                    # it is kept so the sequence of random draws is unchanged.
                    tmp_list = list(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                else:
                    user_act = 'agree'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(sys_temp[sys_act])
                    entity_rec_flag = False
                    request_flag = True

            if slot_change_flag:
                if change_times > 0:
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    # Make an already changed slot less likely to be picked again.
                    change_weights[s] /= 3
                    candid_vs = list(onto[s])
                    candid_vs.remove(DS[s])  # the new value must differ
                    v = random.choice(candid_vs)
                    v_string = self.render_slots(s) % v
                    DS[s] = v
                    user_act = 'change'
                    # Shuffle kept only to preserve the sequence of random draws.
                    tmp_list = list(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list) % v_string
                    change_times -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(sys_temp[sys_act])
                else:
                    user_act = 'deny'
                    user_sent = random.choice(user_temp[user_act])

                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_change_flag = False
                    entity_rec_flag = True
                    start = True
                continue

            if slot_fill_flag:
                if rest_s:
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(sys_temp[sys_act])
                else:
                    # All informable slots are filled: query the API, then
                    # decide how many slot changes follow.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False
                    change_times = random.choices(
                        range(len(self.change_times_probs)),
                        self.change_times_probs)[0]
                    if change_times == 0:
                        entity_rec_flag = True
                        start = True
                    else:
                        slot_change_flag = True
                continue

            if end:
                user_act = 'byebye'
                sys_act = 'byebye'

        # Closing turn; both acts are 'byebye' here.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        self._record_turn(dial, user_sent, sys_sent)
        self.max_turn = max(len(dial), self.max_turn)
        return dial


class HotelSimulator(DialogSimulator):
    """Dialog simulator for the 'hotels' domain.

    Simulation settings are read from ``hyparams['hotels']``.
    """

    def __init__(self):
        super(HotelSimulator, self).__init__(
            domain='hotels',
            maxnum_req_slots=hyparams['hotels']['maxnum_req_slots'],
            slot_update_ratio=hyparams['hotels']['slot_update_ratio'],
            entity_update_ratio=hyparams['hotels']['entity_update_ratio'])

    def render_slots(self, s):
        """Return a randomly chosen ``%s`` phrase template for slot ``s``."""
        if s == 'checkin':
            return random.choice(['from %s', 'starting from %s', 'beginning at %s'])
        if s == 'checkout':
            return random.choice(['to %s', 'ending at %s', 'till %s'])
        if s == 'location':
            return random.choice(['in %s'])
        if s == 'room':
            return random.choice(['for %s rooms'])
        # Every other slot is phrased as a price constraint.
        return random.choice(['in a %s price'])

    def _record_turn(self, dial, user_sent, sys_sent):
        """Append one turn to ``dial`` and update the corpus statistics.

        New words get the next free vocabulary id, ``max_sent`` tracks the
        longest user sentence (in whitespace tokens) and ``candidates``
        counts how often each system sentence was produced.
        """
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab:
                self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1

    def generate_dial(self):
        """Simulate one hotel dialog.

        The dialog runs through these phases, driven by the flags below:
        welcome -> slot filling (ends with an ``api_call``) -> optional slot
        changes (ends with a second ``api_call``) -> entity recommendation
        (ends with ``reserve``) -> user requests about the chosen entity ->
        thanks -> byebye.

        A room count of ``'one'`` is always phrased in the singular
        ("for one room"), both in the opening sentence and when the user
        changes the slot later.

        Returns:
            list of ``(user_sentence, system_sentence)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []  # informable slots the system still has to ask for
        DS = {s: None for s in self.informable_slots}  # dialog state: slot -> value
        change_weights = {s: 1 for s in DS}

        # Random subset (1..maxnum_req_slots) of requestable slots to ask about.
        requested_slots = list(self.requestable_slots)
        random.shuffle(requested_slots)
        num_req = random.choice(range(1, 1 + self.maxnum_req_slots))
        requested_slots = requested_slots[:num_req]

        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False  # True right before the first recommendation
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0  # index of the entity currently being recommended

        onto = self.onto
        user_temp = self.user_template
        sys_temp = self.sys_template

        while sys_act != 'byebye':
            # Each iteration first commits the turn prepared by the previous
            # one, then prepares the next user/system pair.
            self._record_turn(dial, user_sent, sys_sent)

            if user_act == 'welcome':
                # User's initial slot filling: every ontology key is sampled
                # with probability 0.4, only informable ones are kept.
                slots = [s for s in onto.keys()
                         if random.random() < 0.4 and onto[s] is not None]
                init_sv = {s: random.choice(onto[s]) for s in slots}
                DS.update(init_sv)
                rest_s = [s for s in onto.keys()
                          if onto[s] is not None and s not in init_sv]
                user_act = 'inform_all'
                param_string = []
                random.shuffle(slots)
                for s in slots:
                    if s == 'room' and init_sv[s] == 'one':
                        param_string.append('for one room')
                    else:
                        param_string.append(self.render_slots(s) % init_sv[s])
                opening = random.choice(user_temp[user_act])
                # Joining as one list avoids a trailing space when no slot
                # was sampled.
                user_sent = ' '.join([opening] + param_string)

            elif slot_fill_flag and sys_act.startswith('request'):
                # sys_act is 'request_<slot>'; split once so slot names that
                # contain '_' survive intact.
                s = sys_act.split('_', 1)[1]
                user_act = 'inform_' + s
                v = random.choice(onto[s])
                DS[s] = v
                if s == 'room' and v == 'one':
                    user_sent = random.choice(user_temp['inform_room_single']) % v
                else:
                    user_sent = random.choice(user_temp[user_act]) % v

            if request_flag:
                if requested_slots:
                    random.shuffle(requested_slots)
                    s = requested_slots.pop()
                    user_act = 'request'
                    user_sent = random.choice(user_temp[user_act]) % s
                    sys_act = 'inform'
                    sys_sent = random.choice(sys_temp[sys_act]) % (entityOrder, s)
                else:
                    user_act = 'thanks'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(sys_temp[sys_act])
                    request_flag = False
                    end = True
                    continue

            if entity_rec_flag:
                if start:
                    # Total number of recommendations, the first one included.
                    rec_times = random.choices(
                        range(1, len(self.entity_times_probs) + 1),
                        self.entity_times_probs)[0]
                    start = False
                    user_act = 'affirm'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                    continue
                if rec_times > 0:
                    entityOrder += 1
                    user_act = 'refuse'
                    # The shuffle does not change the sampling distribution;
                    # it is kept so the sequence of random draws is unchanged.
                    tmp_list = list(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                else:
                    user_act = 'agree'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(sys_temp[sys_act])
                    entity_rec_flag = False
                    request_flag = True

            if slot_change_flag:
                if change_times > 0:
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    # Make an already changed slot less likely to be picked again.
                    change_weights[s] /= 3
                    candid_vs = list(onto[s])
                    candid_vs.remove(DS[s])  # the new value must differ
                    v = random.choice(candid_vs)
                    phrase = self.render_slots(s)
                    if s == 'room' and v == 'one':
                        # Same singular wording as in the opening sentence;
                        # the generic template would give "for one rooms".
                        v_string = 'for one room'
                    else:
                        v_string = phrase % v
                    DS[s] = v
                    user_act = 'change'
                    # Shuffle kept only to preserve the sequence of random draws.
                    tmp_list = list(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list) % v_string
                    change_times -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(sys_temp[sys_act])
                else:
                    user_act = 'deny'
                    user_sent = random.choice(user_temp[user_act])

                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_change_flag = False
                    entity_rec_flag = True
                    start = True
                continue

            if slot_fill_flag:
                if rest_s:
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(sys_temp[sys_act])
                else:
                    # All informable slots are filled: query the API, then
                    # decide how many slot changes follow.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False
                    change_times = random.choices(
                        range(len(self.change_times_probs)),
                        self.change_times_probs)[0]
                    if change_times == 0:
                        entity_rec_flag = True
                        start = True
                    else:
                        slot_change_flag = True
                continue

            if end:
                user_act = 'byebye'
                sys_act = 'byebye'

        # Closing turn; both acts are 'byebye' here.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        self._record_turn(dial, user_sent, sys_sent)
        self.max_turn = max(len(dial), self.max_turn)
        return dial


class TravelSimulator(DialogSimulator):
    """Dialog simulator for the 'travels' domain (settings: ``hyparams['travels']``)."""

    def __init__(self):
        cfg = hyparams['travels']
        super(TravelSimulator, self).__init__(
            domain='travels',
            maxnum_req_slots=cfg['maxnum_req_slots'],
            slot_update_ratio=cfg['slot_update_ratio'],
            entity_update_ratio=cfg['entity_update_ratio'])

    def render_slots(self, s):
        """Pick a ``%s`` phrase template for slot ``s`` at random."""
        options = ['for %s', 'for some %s'] if s == 'category' else ['in %s']
        return random.choice(options)

    def _log_turn(self, dial, user_sent, sys_sent):
        """Store the turn and fold its words/sentences into the statistics."""
        dial.append((user_sent, sys_sent))
        for token in (user_sent + ' ' + sys_sent).split():
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1

    def generate_dial(self):
        """Simulate one travel dialog and return its (user, system) turns.

        Phases: welcome, slot filling up to the first ``api_call``, optional
        slot changes, recommendations until the user agrees, user requests
        about the chosen entity, thanks, byebye.
        """
        onto = self.onto
        user_tpl = self.user_template
        sys_tpl = self.sys_template

        dial = []
        user_act = sys_act = 'welcome'
        user_sent = random.choice(user_tpl[user_act])
        sys_sent = random.choice(sys_tpl[sys_act])

        pending = []                                        # slots still to be asked for
        state = {name: None for name in self.informable_slots}   # slot -> current value
        weights = {name: 1 for name in state}               # sampling weights for changes

        # Which requestable slots the user will ask about (between 1 and
        # maxnum_req_slots of them, in random order).
        to_request = copy.deepcopy(self.requestable_slots)
        random.shuffle(to_request)
        how_many = random.choice(range(1, self.maxnum_req_slots + 1))
        to_request = to_request[:how_many]

        filling = True          # slot-filling phase
        changing = False        # slot-change phase
        recommending = False    # recommendation phase
        first_rec = False       # next recommendation is the first one
        requesting = False      # request phase
        finished = False        # say byebye on the next pass
        recs_left = 0
        changes_left = 0
        entity_idx = 0

        while sys_act != 'byebye':
            # Commit the pair prepared during the previous pass.
            self._log_turn(dial, user_sent, sys_sent)

            if user_act == 'welcome':
                # Initial user slot filling.
                chosen = [name for name in onto
                          if random.random() < 0.3 and onto[name] is not None]
                initial = {name: random.choice(onto[name]) for name in chosen}
                for name, value in initial.items():
                    state[name] = value
                pending = [name for name in onto
                           if onto[name] is not None and name not in initial]
                missing = len(pending)
                if missing == 2:
                    user_act = 'inform'
                    user_sent = random.choice(user_tpl[user_act])
                elif missing == 0:
                    user_act = 'inform_all'
                    user_sent = random.choice(user_tpl[user_act]) % (state['category'], state['country'])
                elif pending[0] == 'category':
                    user_act = 'inform_country'
                    user_sent = random.choice(user_tpl[user_act]) % state['country']
                else:
                    user_act = 'inform_category'
                    user_sent = random.choice(user_tpl[user_act]) % state['category']
            elif filling and sys_act.startswith('request'):
                # Answer the slot the system just asked for.
                slot = sys_act.split('_')[1]
                user_act = 'inform_' + slot
                value = random.choice(onto[slot])
                state[slot] = value
                user_sent = random.choice(user_tpl[user_act]) % value

            if requesting:
                if len(to_request) == 0:
                    user_act = 'thanks'
                    user_sent = random.choice(user_tpl[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(sys_tpl[sys_act])
                    requesting = False
                    finished = True
                    continue
                random.shuffle(to_request)
                slot = to_request.pop()
                if slot == 'free entry':
                    user_act = 'request_free entry'
                    user_sent = random.choice(user_tpl[user_act])
                else:
                    user_act = 'request'
                    user_sent = random.choice(user_tpl[user_act]) % slot
                sys_act = 'inform'
                sys_sent = random.choice(sys_tpl[sys_act]) % (entity_idx, slot)

            if recommending:
                if first_rec:
                    counts = range(1, len(self.entity_times_probs) + 1)
                    # The first recommendation is made right here, hence -1.
                    recs_left = random.choices(counts, self.entity_times_probs)[0] - 1
                    first_rec = False
                    user_act = 'affirm'
                    user_sent = random.choice(user_tpl[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(sys_tpl[sys_act]) % entity_idx
                    continue
                if recs_left > 0:
                    entity_idx += 1
                    user_act = 'refuse'
                    pool = copy.deepcopy(user_tpl[user_act])
                    random.shuffle(pool)
                    user_sent = random.choice(pool)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(sys_tpl[sys_act]) % entity_idx
                    recs_left -= 1
                else:
                    user_act = 'agree'
                    user_sent = random.choice(user_tpl[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(sys_tpl[sys_act])
                    recommending = False
                    requesting = True

            if changing:
                if changes_left > 0:
                    probs = self.normalize(weights.values())
                    slot = random.choices(list(state.keys()), probs)[0]
                    weights[slot] /= 2
                    alternatives = copy.deepcopy(onto[slot])
                    alternatives.remove(state[slot])
                    value = random.choice(alternatives)
                    phrase = self.render_slots(slot) % value
                    state[slot] = value
                    user_act = 'change'
                    pool = copy.deepcopy(user_tpl[user_act])
                    random.shuffle(pool)
                    user_sent = random.choice(pool) % phrase
                    changes_left -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(sys_tpl[sys_act])
                else:
                    user_act = 'deny'
                    user_sent = random.choice(user_tpl[user_act])

                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(state.values())
                    changing = False
                    recommending = True
                    first_rec = True
                continue

            if filling:
                if len(pending) > 0:
                    slot = pending.pop(0)
                    sys_act = 'request_' + slot
                    sys_sent = random.choice(sys_tpl[sys_act])
                else:
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(state.values())
                    filling = False
                    changes_left = random.choices(
                        range(len(self.change_times_probs)),
                        self.change_times_probs)[0]
                    # No change drawn: go straight to recommending.
                    changing = changes_left != 0
                    if not changing:
                        recommending = True
                        first_rec = True
                continue

            if finished:
                user_act = 'byebye'
                sys_act = 'byebye'

        # Farewell pair.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        self._log_turn(dial, user_sent, sys_sent)
        self.max_turn = max(len(dial), self.max_turn)
        return dial


class MusicSimulator(DialogSimulator):
    """Dialog simulator for the 'music' domain.

    ``generate_dial`` hard-codes the slot names 'category' and 'singer' and
    the act names built from them (e.g. 'inform_singer', 'change_category'),
    so the music ontology and templates must provide those keys.
    """

    def __init__(self):
        """Initialise the base simulator with the 'music' entries of ``hyparams``."""
        super(MusicSimulator, self).__init__(
            domain='music',
            maxnum_req_slots=hyparams['music']['maxnum_req_slots'],
            slot_update_ratio=hyparams['music']['slot_update_ratio'],
            entity_update_ratio=hyparams['music']['entity_update_ratio'])

    def render_slots(self, s):
        """Placeholder that returns None; ``generate_dial`` below does not call it."""
        pass


    def generate_dial(self):
        """Simulate one music dialog and return its turns.

        Every loop iteration first records the (user, system) pair prepared by
        the previous iteration and then prepares the next pair according to
        the phase flag that is currently set:

        * ``slot_fill_flag``   -- the system requests each unfilled informable
          slot, the user informs it, and the phase ends with an 'api_call'.
        * ``slot_change_flag`` -- the user changes a slot value
          ``change_times`` times, then denies further changes and the system
          issues a new 'api_call'.
        * ``entity_rec_flag``  -- the system recommends entities; the user
          refuses until ``rec_times`` is used up, then agrees and the system
          reserves.
        * ``request_flag``     -- the user asks about the selected requestable
          slots one by one, then thanks the system.
        * ``end``              -- both sides say 'byebye', which ends the loop.

        Side effects: unseen words are added to ``self.vocab``, every recorded
        system sentence is counted in ``self.candidates``, and
        ``self.max_sent`` / ``self.max_turn`` are updated.

        Returns:
            list of ``(user_sent, sys_sent)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []
        # Dialog state: current value of every informable slot (None = unfilled).
        DS = {s: None for s in self.informable_slots}
        # Sampling weight of each slot when choosing which slot the user changes.
        change_weights = {s: 1 for s in DS}

        # Pick a random subset (1..maxnum_req_slots) of the requestable slots
        # that the user will ask about during the request phase.
        requested_slots = copy.deepcopy(self.requestable_slots)
        random.shuffle(requested_slots)
        num_req = random.choice(range(1, 1+self.maxnum_req_slots))
        requested_slots = requested_slots[:num_req]

        # Phase flags and counters of the dialog state machine.
        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False  # True on the first turn of the recommendation phase
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0  # index of the entity currently being recommended
        turn_num = 0

        onto = self.onto
        user_temp = self.user_template
        sys_temp = self.sys_template

        while sys_act != 'byebye':
            # Record the pair prepared by the previous iteration (or the
            # initial welcome pair) and update vocabulary / statistics.
            dial.append((user_sent, sys_sent))
            for w in (user_sent + ' ' + sys_sent).split():
                if w not in self.vocab: self.vocab[w] = len(self.vocab)
            self.max_sent = max(self.max_sent, len(user_sent.split()))
            self.candidates[sys_sent] += 1
            turn_num += 1
            if user_act == 'welcome':
                # user initial slot filling: each informable slot is given
                # up front with probability 0.3
                slots = [s for s in onto.keys() if random.random() < 0.3 and onto[s] is not None]
                init_sv = {s: random.choice(onto[s]) for s in slots}
                for k, v in init_sv.items(): DS[k] = v
                rest_s = [s for s in onto.keys() if onto[s] is not None and s not in init_sv]
                # The branches below assume exactly two informable slots,
                # 'category' and 'singer'.
                if len(rest_s) == 0:
                    user_act = 'inform_all'
                    user_sent = random.choice(user_temp[user_act]) % (init_sv['category'], init_sv['singer'])
                elif len(rest_s) == 2:
                    # Nothing given up front: generic inform utterance.
                    user_act = 'inform'
                    user_sent = random.choice(user_temp[user_act])
                else:
                    # Exactly one slot is still open, so the other one was given.
                    if rest_s[0] == 'singer':
                        user_act = 'inform_category'
                        user_sent = random.choice(user_temp[user_act]) % init_sv['category']
                    else:
                        user_act = 'inform_singer'
                        user_sent = random.choice(user_temp[user_act]) % init_sv['singer']

            elif slot_fill_flag and sys_act.startswith('request'):
                # Answer the system's 'request_<slot>' with a random value.
                s = sys_act.split('_')[1]
                user_act = 'inform_' + s
                v = random.choice(onto[s])
                DS[s] = v
                user_sent = random.choice(user_temp[user_act]) % v

            if request_flag:
                if len(requested_slots) > 0:
                    # Ask about one more requestable slot of the chosen entity.
                    random.shuffle(requested_slots)
                    s = requested_slots.pop()
                    user_act = 'request'
                    # print('requestable slot:', s)
                    user_sent = random.choice(user_temp[user_act]) % s
                    sys_act = 'inform'
                    sys_sent = random.choice(sys_temp[sys_act]) % (entityOrder, s)
                else:
                    # Nothing left to ask: thank the system and schedule the
                    # closing turn for the next iteration.
                    user_act = 'thanks'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(sys_temp[sys_act])
                    request_flag = False
                    end = True
                    continue

            if entity_rec_flag:
                if start:
                    # First recommendation: sample how many recommendations
                    # (1..len(entity_times_probs)) are made in total; this
                    # turn already uses one of them.
                    rec_times = random.choices(range(1, len(self.entity_times_probs) + 1), self.entity_times_probs)[0]
                    start = False
                    # print('rec_times:', rec_times)
                    user_act = 'affirm'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                    continue
                if rec_times > 0:
                    # User refuses; system moves on to the next entity.
                    entityOrder += 1
                    user_act = 'refuse'
                    tmp_list = copy.deepcopy(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                else:
                    # User accepts the current entity; switch to the request phase.
                    user_act = 'agree'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(sys_temp[sys_act])
                    entity_rec_flag = False
                    request_flag = True

            if slot_change_flag:
                if change_times > 0:
                    # Choose the slot to change by weight, then divide its
                    # weight by 3 so it is less likely to be picked again.
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    change_weights[s] /= 3
                    # print(change_weights)
                    # The new value must differ from the current one.
                    candid_vs = copy.deepcopy(onto[s])
                    candid_vs.remove(DS[s])
                    v = random.choice(candid_vs)
                    DS[s] = v
                    if s == 'singer':
                        user_act = 'change_singer'
                    else:
                        user_act = 'change_category'
                    user_sent = random.choice(user_temp[user_act]) % v
                    change_times -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(sys_temp[sys_act])

                else:
                    # No changes left: user denies, system issues the api
                    # call with the updated state and recommendation starts.
                    user_act = 'deny'
                    user_sent = random.choice(user_temp[user_act])

                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_change_flag = False
                    entity_rec_flag = True
                    start = True
                continue

            if slot_fill_flag:
                if len(rest_s) > 0:
                    # Request the next unfilled slot (in ontology key order).
                    # random.shuffle(rest_s)
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(sys_temp[sys_act])
                elif len(rest_s) == 0:
                    # All slots filled: issue the api call and leave this phase.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False

                    # Sample the number of slot changes
                    # (0..len(change_times_probs)-1); with zero changes go
                    # straight to the recommendation phase.
                    slot_change_flag = True
                    change_times = random.choices(range(len(self.change_times_probs)), self.change_times_probs)[0]
                    if change_times == 0:
                        slot_change_flag = False
                        entity_rec_flag = True
                        start = True
                continue

            if end:
                # Setting sys_act to 'byebye' terminates the while loop.
                user_act = 'byebye'
                sys_act = 'byebye'

        # Render and record the closing 'byebye' turn.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab: self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1
        turn_num += 1
        self.max_turn = max(turn_num, self.max_turn)
        return dial


class FlightSimulator(DialogSimulator):
    """Dialog simulator for the 'flights' domain.

    ``generate_dial`` hard-codes the slot names 'from', 'to', 'leaving date',
    'ticket' and 'fare' and the act names 'request_fare', 'request_airport'
    and 'inform_ticket_single'.
    """

    def __init__(self):
        """Initialise the base simulator with the 'flights' entries of ``hyparams``."""
        super(FlightSimulator, self).__init__(
            domain='flights',
            maxnum_req_slots=hyparams['flights']['maxnum_req_slots'],
            slot_update_ratio=hyparams['flights']['slot_update_ratio'],
            entity_update_ratio=hyparams['flights']['entity_update_ratio'])

    def render_slots(self, s):
        """Return a ``%s`` phrase template for the value of slot ``s``.

        'from' -> 'from %s', 'to' -> 'to %s', 'leaving date' -> 'on %s';
        every other slot name gets 'with %s people'. Each candidate list
        currently holds a single template.
        """
        if s == 'from':
            return random.choice(['from %s'])
        elif s == 'to':
            return random.choice(['to %s'])
        elif s == 'leaving date':
            return random.choice(['on %s'])
        else:
            return random.choice(['with %s people'])

    def generate_dial(self):
        """Simulate one flight-booking dialog and return its turns.

        Every loop iteration first records the (user, system) pair prepared by
        the previous iteration and then prepares the next pair according to
        the phase flag that is currently set:

        * ``slot_fill_flag``   -- the system requests each unfilled informable
          slot, the user informs it, and the phase ends with an 'api_call'.
        * ``slot_change_flag`` -- the user changes a slot value
          ``change_times`` times, then denies further changes and the system
          issues a new 'api_call'.
        * ``entity_rec_flag``  -- the system recommends entities; the user
          refuses until ``rec_times`` is used up, then agrees and the system
          reserves.
        * ``request_flag``     -- the user asks about the selected requestable
          slots one by one, then thanks the system.
        * ``end``              -- both sides say 'byebye', which ends the loop.

        Side effects: unseen words are added to ``self.vocab``, every recorded
        system sentence is counted in ``self.candidates``, and
        ``self.max_sent`` / ``self.max_turn`` are updated.

        Returns:
            list of ``(user_sent, sys_sent)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []
        # Dialog state: current value of every informable slot (None = unfilled).
        DS = {s: None for s in self.informable_slots}
        # Sampling weight of each slot when choosing which slot the user changes.
        change_weights = {s: 1 for s in DS}

        # Pick a random subset (1..maxnum_req_slots) of the requestable slots
        # that the user will ask about during the request phase.
        requested_slots = copy.deepcopy(self.requestable_slots)
        random.shuffle(requested_slots)
        num_req = random.choice(range(1, 1+self.maxnum_req_slots))
        requested_slots = requested_slots[:num_req]

        # Phase flags and counters of the dialog state machine.
        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False  # True on the first turn of the recommendation phase
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0  # index of the entity currently being recommended
        turn_num = 0

        onto = self.onto
        user_temp = self.user_template
        sys_temp = self.sys_template

        while sys_act != 'byebye':
            # Record the pair prepared by the previous iteration (or the
            # initial welcome pair) and update vocabulary / statistics.
            dial.append((user_sent, sys_sent))
            for w in (user_sent + ' ' + sys_sent).split():
                if w not in self.vocab: self.vocab[w] = len(self.vocab)
            self.max_sent = max(self.max_sent, len(user_sent.split()))
            self.candidates[sys_sent] += 1
            turn_num += 1
            if user_act == 'welcome':
                # user initial slot filling: each informable slot is given
                # up front with probability 0.3
                slots = [s for s in onto.keys() if random.random() < 0.3 and onto[s] is not None]
                init_sv = {s: random.choice(onto[s]) for s in slots}
                for k, v in init_sv.items(): DS[k] = v
                rest_s = [s for s in onto.keys() if onto[s] is not None and s not in init_sv]
                # NOTE(review): 4 is presumably the number of informable
                # slots in the flights ontology -- confirm against ONTO.
                if len(rest_s) != 4:
                    # At least one slot was given: render every given slot
                    # and join the phrases into one utterance.
                    user_act = 'inform_all'
                    param_string = []
                    for s in slots:
                        param_string.append(self.render_slots(s) % init_sv[s])
                    user_sent = random.choice(user_temp[user_act]) % ' '.join(param_string)
                else:
                    # Nothing given up front: generic inform utterance.
                    user_act = 'inform'
                    user_sent = random.choice(user_temp[user_act])

            elif slot_fill_flag and sys_act.startswith('request'):
                # Answer the system's 'request_<slot>' with a random value.
                s = sys_act.split('_')[1]
                user_act = 'inform_' + s
                v = random.choice(onto[s])
                DS[s] = v
                if s == 'ticket' and v == 'one':
                    # A single ticket uses dedicated templates that take no
                    # format argument.
                    user_sent = random.choice(user_temp['inform_ticket_single'])
                else:
                    user_sent = random.choice(user_temp[user_act]) % v

            if request_flag:
                if len(requested_slots) > 0:
                    # Ask about one more requestable slot of the chosen
                    # entity; any slot other than 'fare' is asked with the
                    # 'request_airport' templates.
                    random.shuffle(requested_slots)
                    s = requested_slots.pop()
                    if s == 'fare':
                        user_act = 'request_fare'
                        user_sent = random.choice(user_temp[user_act])
                    else:
                        user_act = 'request_airport'
                        user_sent = random.choice(user_temp[user_act])
                    sys_act = 'inform'
                    sys_sent = random.choice(sys_temp[sys_act]) % (entityOrder, s)
                else:
                    # Nothing left to ask: thank the system and schedule the
                    # closing turn for the next iteration.
                    user_act = 'thanks'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'req_more'
                    sys_sent = random.choice(sys_temp[sys_act])
                    request_flag = False
                    end = True
                    continue

            if entity_rec_flag:
                if start:
                    # First recommendation: sample how many recommendations
                    # (1..len(entity_times_probs)) are made in total; this
                    # turn already uses one of them.
                    rec_times = random.choices(range(1, len(self.entity_times_probs) + 1), self.entity_times_probs)[0]
                    start = False
                    # print('rec_times:', rec_times)
                    user_act = 'affirm'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                    continue
                if rec_times > 0:
                    # User refuses; system moves on to the next entity.
                    entityOrder += 1
                    user_act = 'refuse'
                    tmp_list = copy.deepcopy(user_temp[user_act])
                    random.shuffle(tmp_list)
                    user_sent = random.choice(tmp_list)
                    sys_act = 'next_recommend'
                    sys_sent = random.choice(sys_temp[sys_act]) % entityOrder
                    rec_times -= 1
                else:
                    # User accepts the current entity; switch to the request phase.
                    user_act = 'agree'
                    user_sent = random.choice(user_temp[user_act])
                    sys_act = 'reserve'
                    sys_sent = random.choice(sys_temp[sys_act])
                    entity_rec_flag = False
                    request_flag = True

            if slot_change_flag:
                if change_times > 0:
                    # Choose the slot to change by weight, then divide its
                    # weight by 3 so it is less likely to be picked again.
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    change_weights[s] /= 3
                    # print(change_weights)
                    # The new value must differ from the current one.
                    candid_vs = copy.deepcopy(onto[s])
                    candid_vs.remove(DS[s])
                    v = random.choice(candid_vs)
                    DS[s] = v
                    # print(s, v)
                    v_string = self.render_slots(s) % v
                    user_act = 'change'
                    user_sent = random.choice(user_temp[user_act]) % v_string
                    change_times -= 1

                    sys_act = 'req_change'
                    sys_sent = random.choice(sys_temp[sys_act])

                else:
                    # No changes left: user denies, system issues the api
                    # call with the updated state and recommendation starts.
                    user_act = 'deny'
                    user_sent = random.choice(user_temp[user_act])

                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_change_flag = False
                    entity_rec_flag = True
                    start = True
                continue

            if slot_fill_flag:
                if len(rest_s) > 0:
                    # Request the next unfilled slot (in ontology key order).
                    # random.shuffle(rest_s)
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(sys_temp[sys_act])
                elif len(rest_s) == 0:
                    # All slots filled: issue the api call and leave this phase.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False

                    # Sample the number of slot changes
                    # (0..len(change_times_probs)-1); with zero changes go
                    # straight to the recommendation phase.
                    slot_change_flag = True
                    change_times = random.choices(range(len(self.change_times_probs)), self.change_times_probs)[0]
                    if change_times == 0:
                        slot_change_flag = False
                        entity_rec_flag = True
                        start = True
                continue

            if end:
                # Setting sys_act to 'byebye' terminates the while loop.
                user_act = 'byebye'
                sys_act = 'byebye'

        # Render and record the closing 'byebye' turn.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab: self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1
        turn_num += 1
        self.max_turn = max(turn_num, self.max_turn)
        return dial


class WeatherSimulator(DialogSimulator):
    """Dialog simulator for the 'weather' domain.

    Its ``generate_dial`` has only a slot-filling and a slot-change phase;
    there is no entity recommendation or request phase.
    """

    def __init__(self):
        """Initialise the base simulator with the 'weather' entries of ``hyparams``."""
        super(WeatherSimulator, self).__init__(
            domain='weather',
            maxnum_req_slots=hyparams['weather']['maxnum_req_slots'],
            slot_update_ratio=hyparams['weather']['slot_update_ratio'],
            entity_update_ratio=hyparams['weather']['entity_update_ratio'])

    def render_slots(self, s):
        """Return a ``%s`` phrase template for the value of slot ``s``.

        'location' -> 'in %s'; every other slot name gets the bare '%s'.
        Each candidate list currently holds a single template.
        """
        if s == 'location':
            return random.choice(['in %s'])
        else:
            return random.choice(['%s'])

    def generate_dial(self):
        """Simulate one weather dialog and return its turns.

        Every loop iteration first records the (user, system) pair prepared by
        the previous iteration and then prepares the next pair:

        * ``slot_fill_flag``   -- the system requests each unfilled informable
          slot, the user informs it, and the phase ends with an 'api_call'.
        * ``slot_change_flag`` -- the user changes a slot value
          ``change_times`` times; the system answers every change with a new
          'api_call'.
        * ``end``              -- both sides say 'byebye', which ends the loop.

        Side effects: unseen words are added to ``self.vocab``, every recorded
        system sentence is counted in ``self.candidates``, and
        ``self.max_sent`` / ``self.max_turn`` are updated.

        Returns:
            list of ``(user_sent, sys_sent)`` tuples, one per turn.
        """
        dial = []
        user_act = 'welcome'
        sys_act = 'welcome'
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        rest_s = []
        # Dialog state: current value of every informable slot (None = unfilled).
        DS = {s: None for s in self.informable_slots}
        # Sampling weight of each slot when choosing which slot the user changes.
        change_weights = {s: 1 for s in DS}


        # Phase flags and counters. entity_rec_flag, start, request_flag,
        # rec_times and entityOrder are initialised here but never used in
        # this method.
        slot_fill_flag = True
        slot_change_flag = False
        entity_rec_flag = False
        start = False
        request_flag = False
        end = False
        rec_times = 0
        change_times = 0
        entityOrder = 0
        turn_num = 0

        onto = self.onto
        user_temp = self.user_template
        sys_temp = self.sys_template

        while sys_act != 'byebye':
            # Record the pair prepared by the previous iteration (or the
            # initial welcome pair) and update vocabulary / statistics.
            dial.append((user_sent, sys_sent))
            for w in (user_sent + ' ' + sys_sent).split():
                if w not in self.vocab: self.vocab[w] = len(self.vocab)
            self.max_sent = max(self.max_sent, len(user_sent.split()))
            self.candidates[sys_sent] += 1
            turn_num += 1
            if user_act == 'welcome':
                # user initial slot filling: each informable slot is given
                # up front with probability 0.3
                slots = [s for s in onto.keys() if random.random() < 0.3 and onto[s] is not None]
                init_sv = {s: random.choice(onto[s]) for s in slots}
                for k, v in init_sv.items(): DS[k] = v
                rest_s = [s for s in onto.keys() if onto[s] is not None and s not in init_sv]
                # NOTE(review): 2 is presumably the number of informable
                # slots in the weather ontology -- confirm against ONTO.
                if len(rest_s) != 2:
                    # At least one slot was given: render every given slot
                    # and join the phrases into one utterance.
                    user_act = 'inform_all'
                    param_string = []
                    for s in slots:
                        param_string.append(self.render_slots(s) % init_sv[s])
                    user_sent = random.choice(user_temp[user_act]) % ' '.join(param_string)
                else:
                    # Nothing given up front: generic inform utterance.
                    user_act = 'inform'
                    user_sent = random.choice(user_temp[user_act])

            elif slot_fill_flag and sys_act.startswith('request'):
                # Answer the system's 'request_<slot>' with a random value.
                s = sys_act.split('_')[1]
                user_act = 'inform_' + s
                v = random.choice(onto[s])
                DS[s] = v
                user_sent = random.choice(user_temp[user_act]) % v

            if slot_change_flag:
                if change_times > 0:
                    # Choose the slot to change by weight, then divide its
                    # weight by 3 so it is less likely to be picked again.
                    p = change_weights.values()
                    s = random.choices(list(DS.keys()), self.normalize(p))[0]
                    change_weights[s] /= 3
                    # print(change_weights)
                    # The new value must differ from the current one.
                    candid_vs = copy.deepcopy(onto[s])
                    candid_vs.remove(DS[s])
                    v = random.choice(candid_vs)
                    DS[s] = v
                    v_string = self.render_slots(s) % v
                    user_act = 'change'
                    user_sent = random.choice(user_temp[user_act]) % v_string
                    change_times -= 1
                    # Every change is answered directly with a new api call.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                else:
                    # No changes left: fall through to the closing turn.
                    slot_change_flag = False
                    end = True

            if slot_fill_flag:
                if len(rest_s) > 0:
                    # Request the next unfilled slot (in ontology key order).
                    # random.shuffle(rest_s)
                    s = rest_s.pop(0)
                    sys_act = 'request_' + s
                    sys_sent = random.choice(sys_temp[sys_act])
                elif len(rest_s) == 0:
                    # All slots filled: issue the api call and leave this phase.
                    sys_act = 'api_call'
                    sys_sent = 'api_call ' + ' '.join(DS.values())
                    slot_fill_flag = False

                    # Sample the number of slot changes
                    # (0..len(change_times_probs)-1); with zero changes the
                    # dialog ends on the next iteration.
                    slot_change_flag = True
                    change_times = random.choices(range(len(self.change_times_probs)), self.change_times_probs)[0]
                    if change_times == 0:
                        slot_change_flag = False
                        end = True
                continue

            if end:
                # Setting sys_act to 'byebye' terminates the while loop.
                user_act = 'byebye'
                sys_act = 'byebye'

        # Render and record the closing 'byebye' turn.
        user_sent = random.choice(self.user_template[user_act])
        sys_sent = random.choice(self.sys_template[sys_act])
        dial.append((user_sent, sys_sent))
        for w in (user_sent + ' ' + sys_sent).split():
            if w not in self.vocab: self.vocab[w] = len(self.vocab)
        self.max_sent = max(self.max_sent, len(user_sent.split()))
        self.candidates[sys_sent] += 1
        turn_num += 1
        self.max_turn = max(turn_num, self.max_turn)
        return dial



if __name__ == '__main__':
    # Instantiate each domain simulator and call its run() method, one
    # domain after another in the order listed here.
    simulator_classes = (
        MovieSimulator,
        RestaurantSimulator,
        HotelSimulator,
        TravelSimulator,
        MusicSimulator,
        FlightSimulator,
        WeatherSimulator,
    )
    for simulator_cls in simulator_classes:
        simulator_cls().run()

    # movies        max sent 19     max turn 18
    # restaurants   max sent 20     max turn 22
    # hotels        max sent 24     max turn 22
    # travels       max sent 12     max turn 16
    # music         max sent 9      max turn 17
    # flights       max sent 17     max turn 16
    # weather       max sent 10     max turn 7