

import argparse
import os
import sys
import time

import numpy as np
import psutil
import torch
from torch.utils.data import DataLoader

# NOTE(review): Multimodal_Datasets and train_model are used below but are not
# imported anywhere in this file; presumably they came from this star import —
# confirm and replace with explicit imports.
# from utils import *

def get_data(dataset, split='train'):
    """Build one split of a multimodal dataset.

    The current working directory is passed as the first argument to
    ``Multimodal_Datasets`` (presumably the data root — TODO confirm).

    Args:
        dataset: Dataset name forwarded to ``Multimodal_Datasets``
            (this script uses e.g. ``'iemocap'``).
        split: Split name; callers in this file use ``'train'``, ``'test'``
            and ``'valid'``.

    Returns:
        The ``Multimodal_Datasets`` instance for the requested split.
    """
    # NOTE(review): the trailing positional ``True`` is an option of
    # Multimodal_Datasets whose meaning is not visible here — confirm and pass
    # it by keyword.
    return Multimodal_Datasets(os.getcwd(), dataset, split, True)




# --- Reproducibility -------------------------------------------------------
seed_value = 20
# NOTE: assigning PYTHONHASHSEED here does not change hash randomization of the
# interpreter that is already running; it only affects child processes started
# afterwards. Export it in the shell before launching for full determinism.
os.environ['PYTHONHASHSEED'] = str(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# --- Command line ----------------------------------------------------------
# Usage: <script> <dataset> <batch_size> [<model_type>]
# Batch sizes used previously: 128 for mosi, 40 for mosei, 32 for iemocap.
dataset = sys.argv[1]
batch_size = int(sys.argv[2])
# model_type used to be read from argv[2] as well, which made it the
# batch-size string. Read it from its own argument; fall back to the old value
# when only two arguments are given so existing invocations keep working.
model_type = sys.argv[3] if len(sys.argv) > 3 else sys.argv[2]

# --- Data ------------------------------------------------------------------
# Loaded exactly once for the requested dataset (previously 'iemocap' was also
# loaded unconditionally and then discarded).
train_data = get_data(dataset, 'train')
test_data = get_data(dataset, 'test')
valid_data = get_data(dataset, 'valid')

# Only the training split is shuffled. Evaluation order does not need to be
# random, and shuffling it would also consume the global torch RNG.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

# Re-seed torch so the training-time RNG stream is the same regardless of any
# random draws made while the datasets were being built.
torch.manual_seed(seed_value)

# Per-run overrides merged into hparams; the T/A/V keys presumably toggle the
# text/audio/video modalities — TODO confirm against train_model.
vals = [{'T': True, 'A': True, 'V': True}]
# vals=[{'model_type':'hybrid'},{'model_type':'round_robin1'}]#,{'model_type':'late_fusion'}]

if __name__ == "__main__":
    # One training run per entry in `vals`; each entry overrides/extends the
    # base hyper-parameters below.
    for overrides in vals:
        # get_dim() is indexed as (text, audio, video) for the three shapes.
        dims = train_data.get_dim()
        hparams = {
            'dataset': dataset,
            'batch_size': batch_size,
            'model_size': 40,  # 40 for iemocap
            'num_heads': 5,
            'num_blocks': 8,
            'lr': 0.001,
            'text_shape': dims[0],
            'audio_shape': dims[1],
            'video_shape': dims[2],
            'epochs': 40,
            'model_type': model_type,
        }
        hparams.update(overrides)

        # perf_counter is monotonic, so the elapsed time is not skewed by
        # wall-clock adjustments during a long training run.
        start = time.perf_counter()
        train_model(hparams, train_loader, test_loader, valid_loader,
                    len(test_data), len(valid_data), len(train_data))
        print("total time taken:", time.perf_counter() - start)

        process = psutil.Process(os.getpid())
        # Label matches the value actually printed (the validation split).
        print("valid data length", len(valid_data))
        print("memory usage:")
        print(process.memory_info().rss)  # resident set size, in bytes
        print(process.memory_percent())  # RSS as % of total physical memory
# vals= [{'T':True,'A':True,'V':False},{'T':True,'A':False,'V':True},{'T':False,'A':True,'V':True}]#,{'T':True,'A':False,'V':False},{'T':False,'A':False,'V':True},{'T':False,'A':True,'V':False}]
# vals=[{'T':True,'A':False,'V':False},{'T':False,'A':False,'V':True},{'T':False,'A':True,'V':False}]
