import argparse

import numpy as np
import pandas as pd
from visualize import plot_bar


def magnitude_analysis(root_dir: str, top_k: int = 20) -> None:
    """Plot the largest hidden-state features of every layer as bar charts.

    Reads ``{root_dir}/hidden_states.csv``. Each row is treated as one layer
    and every column from ``"dim0"`` onwards as one feature. For each row the
    ``top_k`` largest values are passed to ``plot_bar`` together with their
    positional feature indices (as strings), and the chart is saved to
    ``{root_dir}/layer{row}.top{top_k}.png``.

    Ranking uses the values exactly as stored in the CSV (signed; no absolute
    value is taken).

    Args:
        root_dir: Directory containing ``hidden_states.csv``; the plots are
            given save paths inside the same directory.
        top_k: Number of largest features to plot per layer. Defaults to 20.

    Raises:
        ValueError: If ``top_k`` is not a positive integer.
    """
    if top_k <= 0:
        raise ValueError(f"top_k must be positive, got {top_k}")

    df = pd.read_csv(f"{root_dir}/hidden_states.csv")
    # Slice the feature columns once instead of once per row; iterating the
    # resulting 2-D array yields one 1-D feature vector per layer.
    feature_matrix = df.loc[:, "dim0":].to_numpy()

    for layer, features in enumerate(feature_matrix):
        # A single argsort gives both the labels (indices) and, via fancy
        # indexing, the matching values in descending order.
        top_indices = features.argsort()[::-1][:top_k]
        plot_bar(
            x=[str(i) for i in top_indices],
            y=features[top_indices],
            save_path=f"{root_dir}/layer{layer}.top{top_k}.png",
        )


if __name__ == "__main__":
    # Command-line entry point: run the per-layer analysis on one directory.
    parser = argparse.ArgumentParser(
        description="Plot the largest hidden-state features of each layer."
    )
    # Required: without it root_dir would be None and the CSV path would
    # silently become "None/hidden_states.csv".
    parser.add_argument(
        "--root_dir",
        type=str,
        required=True,
        help="Directory containing hidden_states.csv; plots are saved there.",
    )
    args = parser.parse_args()

    magnitude_analysis(root_dir=args.root_dir)
