#!/usr/bin/env python
# coding: utf-8

from utils import *
from config import Tasks




if __name__ == "__main__":
    # Command-line interface for the decomposition run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--key", type=str, default='z')
    parser.add_argument("--dataset", type=str, default='sst2', help='test dataset')
    parser.add_argument("--demo_dataset", type=str, default=None, help='demonstrations dataset')
    parser.add_argument("--format", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=50)
    parser.add_argument("--n_shots", type=int, default=4)
    parser.add_argument('--seed_list', type=int, nargs='+')
    args = parser.parse_args()

    # set_more_args is defined outside this file. Code further down reads
    # args.dtype and args.out_dir, which are not CLI flags, so they are
    # presumably filled in here -- confirm against utils.
    set_more_args(args, out_dir='DECOMP')

    # Explicit validation instead of `assert`, which is removed under
    # `python -O` and would let a zero/None n_shots through silently.
    if not args.n_shots:
        parser.error("--n_shots must be non-zero")

    # Debugging override kept for reference (forces full precision):
    #args.dtype = torch.float32


    def run_caches(sents):
        """Run the model over ``sents`` and collect one activation row per sentence.

        Sentences are processed in batches of ``args.batch_size``. For every
        sentence the activations are read at token index
        ``attention_mask.sum(1) - 1``, i.e. the last non-padding token when the
        tokenizer pads on the right (assumed -- confirm the tokenizer's padding
        side).

        Relies on the module-level ``model``, ``device``, ``args``, ``n_layers``,
        ``n_heads``, ``d_head`` and ``d_model``.

        Args:
            sents: Sequence of input strings (must support ``len`` and slicing).

        Returns:
            Dict of CPU tensors with ``len(sents)`` rows each:
              'z':          (len(sents), n_layers, n_heads, d_head), args.dtype
              'mlp_out':    (len(sents), n_layers, d_model), args.dtype
              'resid_pre':  (len(sents), d_model), args.dtype -- layer-0 resid_pre
              'resid_post': (len(sents), d_model), args.dtype -- last-layer resid_post
              'ln_scale':   (len(sents), 1), default float dtype -- ln_final.hook_scale
        """
        n_sents = len(sents)
        caches = {
            'z': torch.zeros((n_sents, n_layers, n_heads, d_head), dtype=args.dtype),
            'mlp_out': torch.zeros((n_sents, n_layers, d_model), dtype=args.dtype),
            'resid_pre': torch.zeros((n_sents, d_model), dtype=args.dtype),
            'resid_post': torch.zeros((n_sents, d_model), dtype=args.dtype),
            'ln_scale': torch.zeros((n_sents, 1)),
        }

        # Hook names and the filter do not depend on the batch, so build them once.
        res_pre0 = t_utils.get_act_name("resid_pre", 0)
        res_final = t_utils.get_act_name("resid_post", n_layers - 1)
        ln_scale = 'ln_final.hook_scale'
        exact_hooks = (res_pre0, res_final, ln_scale)

        def keep_hook(name):
            # Every per-layer 'z' / 'mlp_out' hook plus the three exact names above.
            return name.endswith(('z', 'mlp_out')) or name in exact_hooks

        # Only forward activations are needed; make sure no autograd state is
        # recorded even if gradient tracking is enabled globally.
        with torch.no_grad():
            for i in tqdm(range(0, n_sents, args.batch_size)):
                batch_sents = sents[i:i + args.batch_size]
                inputs = model.tokenizer(batch_sents, return_tensors="pt", padding=True)
                input_ids = inputs['input_ids'].to(device)
                # Index of the last attended token of each sentence.
                positions = inputs['attention_mask'].sum(1) - 1

                _, cache = model.run_with_cache(
                    input_ids,
                    names_filter=keep_hook,
                    device='cpu',  # request the cached activations on the CPU
                )

                bs = len(batch_sents)
                rows = torch.arange(bs)
                caches['resid_pre'][i:i + bs] = cache[res_pre0][rows, positions]
                caches['resid_post'][i:i + bs] = cache[res_final][rows, positions]
                caches['ln_scale'][i:i + bs] = cache[ln_scale][rows, positions]
                for key in ('z', 'mlp_out'):
                    for layer in range(n_layers):
                        caches[key][i:i + bs, layer] = cache[key, layer][rows, positions]
        return caches


    def layer_norm(x):
        """Rescale ``x`` with the cached final-layer-norm scale and learned weight.

        Each row of ``x`` is divided (in float32) by the ``ln_final.hook_scale``
        value stored for the same row in the module-level
        ``caches_dict['ln_scale']``, cast to ``args.dtype`` and multiplied by
        ``model.ln_final.w``. No mean-centering and no bias term are applied here.

        Args:
            x: Tensor whose first dimension lines up with the rows of
                ``caches_dict['ln_scale']``; callers pass ``[batch, d_model]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Follow the input's device instead of hard-coding CUDA so the two
        # operands of the division can never end up on different devices.
        scale = caches_dict['ln_scale'].to(x.device)  # [batch, 1]
        normed = (x.to(torch.float32) / scale).to(args.dtype)  # [batch, d_model]
        return normed * model.ln_final.w

    def unembed(h, label_ids, key=None):
        """Project ``h`` onto the answer-option tokens and score the predictions.

        ``h`` is passed through ``layer_norm`` and multiplied by the columns of
        ``model.W_U`` selected by the module-level ``option_ids``. The predicted
        option is the argmax over those columns.

        Args:
            h: ``[batch, d_model]`` tensor to unembed.
            label_ids: Gold option indices, compared element-wise with the
                predicted indices.
            key: Optional tag of the activation being scored. When it is
                ``'resid_post'`` and ``args.mode`` is ``'test'``, the decoded
                predictions are also written to ``args.out_dir`` (the file name
                uses the module-level ``seed``).

        Returns:
            ``(acc, projs)``: mean accuracy and the CPU ``[batch, n_options]``
            projection tensor.
        """
        normed = layer_norm(h)
        # [batch, d_model] x [d_model, n_options] -> [batch, n_options]
        option_unembed = model.W_U[:, option_ids]
        projs = torch.mm(normed, option_unembed).cpu()

        pred_idx = projs.argmax(1).numpy()
        acc = (pred_idx == label_ids).mean()

        should_dump = key == 'resid_post' and args.mode == 'test'
        if should_dump:
            # Map option indices back to vocabulary ids, then to text.
            pred_token_ids = np.array(option_ids)[pred_idx]
            decoded = model.tokenizer.batch_decode(pred_token_ids)
            out_path = os.path.join(args.out_dir, f'all_pred_labels-seed{seed}.npy')
            np.save(out_path, decoded)
        return acc, projs

    def get_head_scores(caches, label_ids, head, layer):
        """Score a single attention head by unembedding its output directly.

        Args:
            caches: Tensor indexed as ``caches[:, layer, head]``; the caller
                passes the cached 'z' activations.
            label_ids: Gold option indices forwarded to ``unembed``.
            head: Head index within the layer.
            layer: Layer index.

        Returns:
            The ``(acc, projs)`` pair produced by ``unembed``.
        """
        head_z = caches[:, layer, head].cuda()
        # [batch, d_head] @ [d_head, d_model] -> [batch, d_model]
        head_out = torch.matmul(head_z, model.W_O[layer, head])
        return unembed(head_out, label_ids)

    def get_mlp_scores(caches, label_ids, layer_idx=None):
        """Score one layer's MLP output by unembedding it directly.

        Args:
            caches: Tensor indexed as ``caches[:, layer_idx]``; the caller
                passes the cached 'mlp_out' activations.
            label_ids: Gold option indices forwarded to ``unembed``.
            layer_idx: Layer to score. When omitted, the module-level ``layer``
                variable is used, which is how this helper used to pick its
                layer; new call sites should pass it explicitly.

        Returns:
            The ``(acc, projs)`` pair produced by ``unembed``.
        """
        if layer_idx is None:
            # Legacy fallback: read the enclosing script's loop variable.
            layer_idx = layer
        mlp_out = caches[:, layer_idx].cuda()  # [batch, d_model]
        return unembed(mlp_out, label_ids)


    model = load_hooked_model(args.model_name, args.dtype)
    n_heads, n_layers, d_head, d_model = model.cfg.n_heads, model.cfg.n_layers, model.cfg.d_head, model.cfg.d_model

    # With a separate demonstrations dataset only the test split is run;
    # otherwise both the dev and the test split are evaluated.
    modes = ['test'] if args.demo_dataset else ['dev', 'test']
    for mode in modes:
        args.mode = mode
        for seed in tqdm(args.seed_list):
            test_sents, test_labels, test_label_ids, option_ids = prep_inputs(args, seed, model.tokenizer)

            N = len(test_sents)

            # Per-component accuracy and option projections for this mode/seed.
            accs_z = np.zeros((n_layers, n_heads))
            projs_z = torch.zeros((n_layers, n_heads, N, len(option_ids)), dtype=args.dtype)
            accs_mlp = np.zeros((n_layers, ))
            projs_mlp = torch.zeros((n_layers, N, len(option_ids)), dtype=args.dtype)

            # layer_norm() reads caches_dict['ln_scale'], so this must be
            # assigned before any unembed() call below.
            caches_dict = run_caches(test_sents)

            # Residual stream at the model input (layer 0) and after the last layer.
            for key in ['resid_pre', 'resid_post']:
                acc, projs = unembed(caches_dict[key].cuda(), test_label_ids, key=key)
                print(key, f"Acc: {acc:.1%}")
                np.save(os.path.join(args.out_dir, f'acc-{key}-{args.mode}-{seed}.npy'), acc)
                torch.save(projs, os.path.join(args.out_dir, f'projs-{key}-{args.mode}-{seed}.pt'))

            # Every attention head, scored on its own.
            for layer in range(n_layers):
                for head in range(n_heads):
                    acc, proj = get_head_scores(caches_dict['z'], test_label_ids, head, layer)
                    accs_z[layer, head] = acc
                    projs_z[layer, head] = proj

            # Every MLP layer, scored on its own; the layer is passed
            # explicitly rather than picked up from the loop variable.
            for layer in range(n_layers):
                acc, proj = get_mlp_scores(caches_dict['mlp_out'], test_label_ids, layer_idx=layer)
                accs_mlp[layer] = acc
                projs_mlp[layer] = proj


            print(f'TopHead: {accs_z.max():.1%}')
            np.save(os.path.join(args.out_dir, f'acc-heads-{args.mode}-{seed}.npy'), accs_z)
            torch.save(projs_z, os.path.join(args.out_dir, f'projs-heads-{args.mode}-{seed}.pt'))

            print(f'Top L{accs_mlp.argmax()}: {accs_mlp.max():.1%}')
            np.save(os.path.join(args.out_dir, f'acc-mlps-{args.mode}-{seed}.npy'), accs_mlp)
            torch.save(projs_mlp, os.path.join(args.out_dir, f'projs-mlps-{args.mode}-{seed}.pt'))
            # NOTE(review): this file name has no seed, so every seed overwrites
            # the previous one -- presumably the label ids are identical across
            # seeds; confirm.
            np.save(os.path.join(args.out_dir, f'{args.mode}_label_ids.npy'), test_label_ids)
