import os
import random
import ujson

from pruner.utils.parser import Arguments
from pruner.utils.runs import Run
import pruner.utils.distributed as distributed

from pruner.utils.utils import print_message, create_directory
from pruner.scoring.encoder import CollectionEncoder

def _count_batch_outputs(output_dir):
    """Return how many per-batch score files named ``<idx>.tsv`` are in *output_dir*.

    Only files whose stem is a plain decimal integer are counted, so unrelated
    ``.tsv`` files (for example an earlier merged output) cannot inflate the
    count. Callers then read ``0.tsv`` .. ``{count-1}.tsv``; a gap in the
    numbering surfaces as ``FileNotFoundError`` when the missing file is opened.
    """
    suffix = '.tsv'
    return sum(
        1
        for name in os.listdir(output_dir)
        if name.endswith(suffix) and name[:-len(suffix)].isdecimal()
    )


def _merge_batch_outputs(output_dir, outfile_path):
    """Concatenate ``0.tsv``, ``1.tsv``, ... from *output_dir* into *outfile_path*.

    Each line's first tab-separated field (after stripping surrounding
    whitespace) is parsed as an integer ``pid``. The pids must run 0, 1, 2, ...
    contiguously across the batch files taken in index order.

    Lines are copied byte-for-byte in binary mode, so the merge does not depend
    on the platform's default text encoding.

    Raises:
        ValueError: if a pid field is not an integer or is out of sequence.
        FileNotFoundError: if the batch numbering has a gap.
    """
    num_batches = _count_batch_outputs(output_dir)
    print(f'#> Merge {num_batches} batch outputs into {outfile_path}')

    expected_pid = 0
    with open(outfile_path, 'wb') as ofile:
        for batch_idx in range(num_batches):
            path = os.path.join(output_dir, f'{batch_idx}.tsv')
            print(f'#> Read {path}')
            with open(path, 'rb') as ifile:
                for line_no, line in enumerate(ifile, start=1):
                    pid_field = line.strip().split(b'\t', 1)[0]
                    try:
                        pid = int(pid_field)
                    except ValueError:
                        raise ValueError(
                            f'{path}:{line_no}: malformed pid field {pid_field!r}'
                        ) from None
                    # Explicit check (not `assert`) so it survives `python -O`.
                    if pid != expected_pid:
                        raise ValueError(
                            f'{path}:{line_no}: expected pid {expected_pid}, found {pid}'
                        )
                    expected_pid += 1

                    # Guarantee a line terminator so a batch file lacking a final
                    # newline cannot fuse with the next file's first line.
                    ofile.write(line if line.endswith(b'\n') else line + b'\n')


def main():
    """Compute pruning scores for a collection and merge the per-batch outputs.

    Steps:
      1. Parse model / inference arguments plus ``--output`` and ``--collection``.
      2. Create the ``--output`` directory on the main process.
      3. Run ``CollectionEncoder.scoring()`` on every process with its
         ``process_idx`` / ``num_processes``. It is presumably what writes the
         ``<idx>.tsv`` batch files into ``--output``; confirm in the encoder.
      4. On the main process, save ``metadata.json`` into ``--output``.
      5. On the main process, concatenate the batch files into
         ``collection.pruner_score.tsv``, placed in the parent of ``--output``.

    "Main process" means ``args.rank < 1``: rank 0, or presumably -1 when not
    running distributed. This is inferred from ``max(0, args.rank)`` below.
    """
    random.seed(12345)

    parser = Arguments(description='Computing pruning scores using ColBERTPruner.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()  # checkpoint, bsize, amp

    parser.add_argument('--output', type=str, required=True)
    parser.add_argument('--collection', type=str, help="path to collection.tsv")

    args = parser.parse()
    is_main_process = args.rank < 1

    with Run.context():

        distributed.barrier(args.rank)

        if is_main_process:
            create_directory(args.output)

        # Other processes must not start scoring before the directory exists.
        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args, process_idx=process_idx, num_processes=args.nranks)
        encoder.scoring()

        # Wait for every process to finish scoring before writing metadata / merging.
        distributed.barrier(args.rank)

        # Save metadata.
        if is_main_process:
            metadata_path = os.path.join(args.output, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path, "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)

        # Merge batch outputs (files named 0.tsv, 1.tsv, ..., N-1.tsv).
        if is_main_process:
            # normpath strips a trailing separator; without it dirname() would
            # return --output itself and the merged file would land inside it.
            output_dir = os.path.normpath(args.output)
            outfile_path = os.path.join(os.path.dirname(output_dir), 'collection.pruner_score.tsv')
            _merge_batch_outputs(output_dir, outfile_path)

            print('#> Done!')

        distributed.barrier(args.rank)

# Run the scoring pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
