import os
# Enable the Azure ML "internal components" and "CLI private features" flags
# for this process. Both are set before the azure.ai.ml imports below;
# presumably the SDK reads them at import time, so keep these assignments
# above those imports -- TODO confirm against the SDK documentation.
os.environ["AZURE_ML_INTERNAL_COMPONENTS_ENABLED"] = "True"
os.environ["AZURE_ML_CLI_PRIVATE_FEATURES_ENABLED"] = "true"

# Add CLI to identity scope
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml import MLClient
from azure.ai.ml.dsl import pipeline


def connect_to_aml(subscription_id, resource_group, workspace_name=None, registry_name=None):
    """Create an ``MLClient`` for a workspace and/or a registry.

    Authentication first tries ``DefaultAzureCredential``; if obtaining a
    management token with it fails for any reason, the function falls back to
    ``InteractiveBrowserCredential``.

    Args:
        subscription_id: Azure subscription ID forwarded to ``MLClient``.
        resource_group: Resource group name forwarded to ``MLClient``.
        workspace_name: Workspace name forwarded to ``MLClient``, or ``None``.
        registry_name: Registry name forwarded to ``MLClient``, or ``None``.

    Returns:
        The constructed ``MLClient`` instance.

    Raises:
        ValueError: If both ``workspace_name`` and ``registry_name`` are
            ``None``.
    """
    import logging

    # Raise explicitly rather than ``assert`` so the check is not stripped
    # when Python runs with ``-O``.
    if workspace_name is None and registry_name is None:
        raise ValueError("Either workspace_name or registry_name must be provided.")

    try:
        credential = DefaultAzureCredential()
        # Check if given credential can get token successfully.
        credential.get_token("https://management.azure.com/.default")
    except Exception as ex:
        # Deliberately broad: any failure of the default credential chain
        # triggers the interactive fallback. Log it so the cause of the
        # fallback is not silently lost.
        logging.getLogger(__name__).warning(
            "DefaultAzureCredential failed (%s); falling back to InteractiveBrowserCredential.",
            ex,
        )
        credential = InteractiveBrowserCredential()

    return MLClient(
        subscription_id=subscription_id,
        resource_group_name=resource_group,
        workspace_name=workspace_name,
        credential=credential,
        registry_name=registry_name,
    )

def _singularity_resources(instance_type, instance_count):
    """Build the ``resources`` dict assigned to a pipeline step.

    Reads the module-level ``COMPUTE_TARGET`` (expected to be defined
    elsewhere in this module) as the virtual cluster ARM ID.

    Args:
        instance_type: Value for the ``"instance_type"`` key.
        instance_count: Value for the ``"instance_count"`` key.

    Returns:
        A new dict with the step's resource settings.
    """
    return {
        "instance_type": instance_type,
        "virtual_cluster_arm_id": COMPUTE_TARGET,
        "instance_count": instance_count,
        # Do not use default since it is too small and may result in failure
        "shm_size": "128G",
        "properties": {
            "singularity": {
                "slaTier": "Standard",  # Basic, Standard, Premium
                "priority": "Medium",  # Low, Medium, High
            }
        },
    }


def _build_ft_pipeline(
    ml_client_registry_oai_v2,
    ml_client_workspace,
    *,
    display_name,
    model,
    registered_model_name,
):
    """Build a data-import + fine-tune pipeline function for one model.

    Fetches the ``mbpp-train`` / ``mbpp-valid`` datasets (version ``"2"``)
    from the workspace client and the ``openai_data_import`` /
    ``openai_completions_finetune`` components from the registry client, then
    wires them into a function decorated with ``@pipeline``.

    Args:
        ml_client_registry_oai_v2: Client used to fetch the components.
        ml_client_workspace: Client used to fetch the datasets.
        display_name: Used as both display name and description of the
            pipeline.
        model: Value passed as ``model`` to the fine-tune component.
        registered_model_name: Value passed as ``registered_model_name`` to
            the fine-tune component.

    Returns:
        The decorated, zero-argument pipeline function; calling it produces
        the pipeline job to submit.
    """
    # Get the components and datasets we need
    dataset_train = ml_client_workspace.data.get(name="mbpp-train", version="2")
    # Requirement to provide validation set will be removed in the upcoming version
    dataset_valid = ml_client_workspace.data.get(name="mbpp-valid", version="2")
    oai_data_import = ml_client_registry_oai_v2.components.get(name="openai_data_import", version="0.3.6")
    ft_model = ml_client_registry_oai_v2.components.get(name="openai_completions_finetune", version="0.5.13")

    # NOTE(review): the inner function name and the step variable names below
    # are kept exactly as before, since the DSL may derive pipeline/node names
    # from them -- confirm before renaming.
    @pipeline(
        display_name=display_name,
        description=display_name,
        compute="serverless",
    )
    def ft_pipeline():
        # Feed train and validation set to data import component
        oai_data_import_step = oai_data_import(
            train_dataset=dataset_train,
            validation_dataset=dataset_valid,
        )
        # Set the compute target to the Singularity A100 IPP cluster
        oai_data_import_step.compute = COMPUTE_TARGET
        oai_data_import_step.resources = _singularity_resources("Singularity.ND12am_A100_v4", 1)

        # Finetune the model
        ft_step = ft_model(
            input_dataset=oai_data_import_step.outputs.out_dataset,
            model=model,
            task_type="chat",
            export_merged_weights=False,
            registered_model_name=registered_model_name,
            n_ctx=4096,
            lora_dim=32,
            n_epochs=-1,
            batch_size=-1,
            learning_rate_multiplier=1.0,
            weight_decay_multiplier=1e-05,
            prompt_loss_weight=0.0,
            trim_mode="right",
            shuffle_type="full",
            checkpoint_interval=200,
            n_steps=1,
        )
        # Set the compute target to the Singularity A100 IPP cluster
        ft_step.compute = COMPUTE_TARGET
        ft_step.resources = _singularity_resources("Singularity.ND96amrs_A100_v4", 3)
        return ft_step

    return ft_pipeline


def ft_pipeline_v1(ml_client_registry_oai_v2, ml_client_workspace):
    """Return the pipeline function that fine-tunes ``gpt-35-turbo-1106``.

    Args:
        ml_client_registry_oai_v2: Client used to fetch the components.
        ml_client_workspace: Client used to fetch the datasets.

    Returns:
        The decorated pipeline function (display name
        ``mbpp-ft-3.5turbo-v1``); call it to create the pipeline job.
    """
    return _build_ft_pipeline(
        ml_client_registry_oai_v2,
        ml_client_workspace,
        display_name="mbpp-ft-3.5turbo-v1",
        model="gpt-35-turbo-1106",
        registered_model_name="gpt-35-turbo-mbpp-v1",
    )


def ft_pipeline_v2(ml_client_registry_oai_v2, ml_client_workspace):
    """Return the pipeline function that fine-tunes ``gpt-4``.

    Args:
        ml_client_registry_oai_v2: Client used to fetch the components.
        ml_client_workspace: Client used to fetch the datasets.

    Returns:
        The decorated pipeline function (display name ``mbpp-ft-4-v1``);
        call it to create the pipeline job.
    """
    return _build_ft_pipeline(
        ml_client_registry_oai_v2,
        ml_client_workspace,
        display_name="mbpp-ft-4-v1",
        model="gpt-4",
        registered_model_name="gpt-4-mbpp-v1",
    )

def main():
    """Submit both fine-tuning pipelines and print their Studio URLs.

    Connects to the ``azure-openai-v2-1p`` registry and to the workspace,
    builds and submits the v1 and v2 pipelines (in that order) under the
    ``LLM_FT`` experiment, then prints the Studio endpoint of every
    submitted job.

    Relies on the module-level names ``SUBSCRIPTION_ID``, ``RESOURCE_GROUP``
    and ``WORKSPACE_NAME``; they are not defined in the visible part of this
    file and must be provided before running.
    """
    ml_client_registry_oai_v2 = connect_to_aml(
        subscription_id=SUBSCRIPTION_ID,
        resource_group=RESOURCE_GROUP,
        registry_name="azure-openai-v2-1p",
    )
    ml_client_workspace = connect_to_aml(
        subscription_id=SUBSCRIPTION_ID,
        resource_group=RESOURCE_GROUP,
        workspace_name=WORKSPACE_NAME,
    )

    # Submit the pipeline jobs, one after another, keeping the returned job
    # objects so that every job's URL can be reported below.
    submitted_jobs = []
    for build_pipeline in (ft_pipeline_v1, ft_pipeline_v2):
        pipeline_job = build_pipeline(ml_client_registry_oai_v2, ml_client_workspace)()
        submitted_jobs.append(
            ml_client_workspace.jobs.create_or_update(pipeline_job, experiment_name="LLM_FT")
        )

    print("The URLs to see your live jobs running are returned by the SDK:")
    for job_number, job in enumerate(submitted_jobs, start=1):
        print(f"Job {job_number}: {job.services['Studio'].endpoint}")


if __name__ == "__main__":
    main()
