import torch
from transformers import BartModel, BartTokenizer, BartForConditionalGeneration
from bart_with_posthoc_layer_drop import BartWithPosthocLayerDropConfig, MyBart

def main():
    """Show BART mask-filling output before and after dropping a decoder layer.

    Builds a ``MyBart`` from the ``facebook/bart-base`` configuration, copies
    the pretrained ``BartForConditionalGeneration`` weights into its ``model``
    sub-module, and prints the decoded beam-search output for one masked
    sentence twice: first with the model as loaded, then after setting
    ``bart.model.decoder.decoder_dropped_layers`` to ``[5]``.

    Returns:
        None. The two decoded generations are written to stdout.
    """
    config = BartWithPosthocLayerDropConfig.from_pretrained('facebook/bart-base', forced_bos_token_id=0)
    bart = MyBart(config)

    bart_old = BartForConditionalGeneration.from_pretrained("facebook/bart-base", forced_bos_token_id=0)
    # strict=True makes load_state_dict raise on any missing or unexpected key,
    # so the returned key report carries no extra information and is not kept.
    # NOTE(review): only the `model` sub-module is copied; this presumes
    # MyBart's output head is tied to (or derived from) those weights -- confirm.
    bart.model.load_state_dict(bart_old.model.state_dict(), strict=True)

    # A module constructed directly from a config (unlike one returned by
    # from_pretrained) starts in training mode, where dropout is active.
    # Switch to eval mode so both generations below are deterministic and the
    # only difference between them is the dropped layer.
    bart.eval()

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    example_english_phrase = "UN Chief Says There Is No <mask> in Syria"

    # Calling the tokenizer directly replaces the deprecated batch_encode_plus;
    # truncation=True makes the max_length cap explicit instead of relying on
    # the implicit fallback (and its warning).
    inputs = tokenizer(
        [example_english_phrase],
        max_length=1024,
        truncation=True,
        return_tensors='pt',
    )

    def generate_and_print():
        """Run beam search on the example input and print the decoded text."""
        generated_ids = bart.generate(inputs['input_ids'], num_beams=4, max_length=30, early_stopping=True)
        print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

    # Baseline: no layers dropped.
    generate_and_print()

    # Alternative experiment (disabled): drop encoder layers instead.
    # bart.model.encoder.encoder_dropped_layers = [2,3,4,5]
    bart.model.decoder.decoder_dropped_layers = [5]
    generate_and_print()



# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()