import logging
import os

from airflow.decorators import dag
from airflow.operators.bash import BashOperator
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from airflow.utils.trigger_rule import TriggerRule
from airflow.api.common.trigger_dag import trigger_dag
from cosmos import DbtTaskGroup, ProfileConfig, ProjectConfig, RenderConfig
from mrds.utils.security_utils import get_verified_run_id, verify_run_id
from mrds.utils import oraconn

# DAG id is derived from the file name, so renaming the file renames the DAG.
DAG_NAME = os.path.splitext(os.path.basename(__file__))[0]
ENV_NAME = os.getenv("MRDS_ENV", "").lower()

DATABASE_NAME_MAP = {
    "dev": "MOPDB",
    "test": "MOPDB_TEST",
}
# Unknown/unset environments fall back to the dev database.
DATABASE_NAME = DATABASE_NAME_MAP.get(ENV_NAME, "MOPDB")

dbt_root_path = "/opt/dbt"
dbt_profiles_dir = "/opt/dbt/profiles.yml"
dbt_profiles_dir_parent = "/opt/dbt"

# Environment handed to every dbt invocation (both the BashOperator
# run-operations and the cosmos-rendered model tasks).
dbt_env = {
    "DBT_PROFILES_DIR": dbt_profiles_dir_parent,
    "DBT_TARGET": ENV_NAME,
    "MRDS_LOADER_DB_USER": os.getenv("MRDS_LOADER_DB_USER"),
    "MRDS_LOADER_DB_PASS": os.getenv("MRDS_LOADER_DB_PASS"),
    "MRDS_LOADER_DB_TNS": os.getenv("MRDS_LOADER_DB_TNS", "XE"),
    "MRDS_SCHEMA": os.getenv("MRDS_SCHEMA", "CT_MRDS"),
    "MRDS_PROTOCOL": os.getenv("MRDS_PROTOCOL", "tcps"),
    "MRDS_THREADS": os.getenv("MRDS_THREADS", "4"),
    "DBT_LOG_PATH": "/opt/dbt/logs",
    "DBT_TARGET_PATH": "/opt/dbt/target",
    "PYTHONUNBUFFERED": "1",
}


def retrieve_run_id(**kwargs):
    """Resolve the verified orchestration run id and publish it via XCom.

    The id is validated by the security utilities before being shared with
    downstream dbt tasks (pulled with key ``run_id``).
    """
    run_id = get_verified_run_id(kwargs)
    kwargs["ti"].xcom_push(key="run_id", value=run_id)
    return run_id


def check_dag_status(**kwargs):
    """Fail the DAG run if any other task instance in it has failed.

    Intended to run last with ``TriggerRule.ALL_DONE`` so the run's final
    state reflects upstream failures even when cleanup tasks succeeded.

    Raises:
        Exception: naming the first failed task found.
    """
    for ti in kwargs["dag_run"].get_task_instances():
        if ti.state == "failed" and ti.task_id != kwargs["task_instance"].task_id:
            raise Exception(f"Task {ti.task_id} failed. Failing this DAG run")


def get_rqsd_tables_to_replicate(**kwargs):
    """Fetch the RQSD tables to replicate from a_devo_replica_mgmt_rqsd.

    Tables whose names literally end in ``_COPY`` are excluded. Results are
    pushed to XCom (key ``rqsd_tables``) as ``OWNER.TABLE_NAME`` strings.

    Returns:
        list[str]: the ``OWNER.TABLE_NAME`` entries, ordered by owner/table.

    Raises:
        Exception: re-raised after logging any database error.
    """
    oracle_conn = None
    try:
        oracle_conn = oraconn.connect('MRDS_LOADER')
        cursor = oracle_conn.cursor()
        # NOTE: '_' is a single-character wildcard in LIKE, so '%_COPY' would
        # also exclude names like XCOPY. Escape the underscore so only names
        # literally ending in "_COPY" are filtered out.
        sql = """
            SELECT OWNER, TABLE_NAME
            FROM CT_MRDS.a_devo_replica_mgmt_rqsd
            WHERE TABLE_NAME NOT LIKE '%\\_COPY' ESCAPE '\\'
            ORDER BY OWNER, TABLE_NAME
        """
        cursor.execute(sql)
        tables = cursor.fetchall()
        cursor.close()

        logging.info(
            "Found %d RQSD tables to replicate (excluding _COPY versions)",
            len(tables),
        )

        table_list = [f"{owner}.{table_name}" for owner, table_name in tables]

        # Share the list with the trigger task.
        kwargs["ti"].xcom_push(key="rqsd_tables", value=table_list)
        return table_list
    except Exception as e:
        logging.error("Error getting RQSD tables: %s", e)
        raise
    finally:
        if oracle_conn:
            oracle_conn.close()


def trigger_rqsd_replication(**kwargs):
    """Trigger ``devo_replicator_trigger_rqsd`` once per replicated table.

    Best effort: a failure to trigger one table is logged and recorded but
    does not stop the remaining triggers. Counts are pushed to XCom
    (``triggered_count``, ``failed_triggers``).

    Returns:
        dict | None: summary of triggered/total/failed, or ``None`` when the
        upstream task produced no tables.
    """
    ti = kwargs["ti"]
    table_list = ti.xcom_pull(task_ids="get_rqsd_tables", key="rqsd_tables")

    if not table_list:
        logging.warning("No RQSD tables found to replicate")
        return

    logging.info("Triggering replication for %d tables", len(table_list))

    triggered_count = 0
    failed_triggers = []

    for owner_table in table_list:
        try:
            conf = {"owner_table": owner_table}
            trigger_dag(
                dag_id='devo_replicator_trigger_rqsd',
                conf=conf,
                execution_date=None,
                replace_microseconds=False,
            )
            triggered_count += 1
            logging.info("Successfully triggered replication for %s", owner_table)
        except Exception as e:
            logging.error(
                "Failed to trigger replication for %s: %s", owner_table, e
            )
            failed_triggers.append(owner_table)

    logging.info(
        "Replication triggered for %d/%d tables", triggered_count, len(table_list)
    )
    if failed_triggers:
        logging.warning(
            "Failed to trigger replication for: %s", ", ".join(failed_triggers)
        )

    ti.xcom_push(key="triggered_count", value=triggered_count)
    ti.xcom_push(key="failed_triggers", value=failed_triggers)

    return {
        "triggered_count": triggered_count,
        "total_tables": len(table_list),
        "failed_triggers": failed_triggers,
    }


@dag(
    dag_id=DAG_NAME,
    schedule_interval=None,
    start_date=days_ago(2),
    catchup=False,
)
def run_dag():
    """MRDS RQSD workflow: dbt run bracketed by control-run bookkeeping,
    followed by fan-out replication triggers and a final status check."""

    def read_vars(**context):
        # Sanity-log the bucket configuration at the start of the run.
        BUCKET = os.getenv("INBOX_BUCKET")
        BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE")
        print("========= DBT ENV =========")
        print(f"BUCKET_NAMESPACE: {BUCKET_NAMESPACE}, BUCKET : {BUCKET}")
        return 1

    read_vars_task = PythonOperator(
        task_id="read_vars",
        python_callable=read_vars,
        provide_context=True,
    )

    retrieve_run_id_task = PythonOperator(
        task_id="retrieve_run_id",
        python_callable=retrieve_run_id,
        provide_context=True,
    )

    # Record the start of this external run in the control tables.
    control_external_run_start = BashOperator(
        task_id="control_external_run_start",
        params={"db": DATABASE_NAME, "wf": DAG_NAME},
        env=dbt_env,
        bash_command="""
set -euxo pipefail
cd /opt/dbt
dbt --log-format json --log-level debug --debug --log-path /opt/dbt/logs \
    run-operation control_external_run_start \
    --vars '{{ { "orchestration_run_id": ti.xcom_pull(task_ids="retrieve_run_id", key="run_id"), "input_service_name": params.db, "workflow_name": params.wf } | tojson }}'
""",
    )

    common_profile = ProfileConfig(
        profiles_yml_filepath=dbt_profiles_dir,
        profile_name="mrds",
        target_name=ENV_NAME,
    )
    common_project = ProjectConfig(dbt_project_path=dbt_root_path)
    common_vars = {
        "orchestration_run_id": "{{ ti.xcom_pull(task_ids='retrieve_run_id', key='run_id') }}",
        "input_service_name": DATABASE_NAME,
        "workflow_name": DAG_NAME,
    }
    common_operator_args = {
        "vars": common_vars,
        "env": dbt_env,
    }

    def make_dbt_group(name):
        """Build a cosmos task group running all models tagged ``name``;
        the group id equals the tag name."""
        return DbtTaskGroup(
            group_id=name,
            project_config=common_project,
            profile_config=common_profile,
            render_config=RenderConfig(select=[f"tag:{name}"]),
            operator_args=common_operator_args,
        )

    # Annex groups run in parallel; the output group runs after all of them.
    annex_groups = [
        make_dbt_group(name)
        for name in (
            "m_MOPDB_RQSD_ANNEX_1_1_ALL_ODS_RQSD_OBSERVATIONS",
            "m_MOPDB_RQSD_ANNEX_1_2_ALL_ODS_RQSD_OBSERVATIONS",
            "m_MOPDB_RQSD_ANNEX_1_1_FIN_ALL_ODS_RQSD_OBSERVATIONS",
            "m_MOPDB_RQSD_ANNEX_1_2_FIN_ALL_ODS_RQSD_OBSERVATIONS",
            "m_MOPDB_RQSD_ANNEX_2_ALL_ODS_RQSD_OBSERVATIONS",
        )
    ]
    output_group = make_dbt_group("m_MOPDB_RQSD_OUTPUT_CURR_RQSD_NCB_SUBA")

    # Close out the control run even if upstream dbt tasks failed (ALL_DONE).
    control_external_run_end = BashOperator(
        task_id="control_external_run_end",
        params={"db": DATABASE_NAME, "wf": DAG_NAME},
        env=dbt_env,
        bash_command="""
set -euxo pipefail
cd /opt/dbt
dbt --log-format json --log-level debug --debug --log-path /opt/dbt/logs \
    run-operation control_external_run_end \
    --vars '{{ { "orchestration_run_id": ti.xcom_pull(task_ids="retrieve_run_id", key="run_id"), "input_service_name": params.db, "workflow_name": params.wf } | tojson }}'
""",
        trigger_rule=TriggerRule.ALL_DONE,
    )

    # Get list of RQSD tables to replicate.
    get_rqsd_tables = PythonOperator(
        task_id="get_rqsd_tables",
        python_callable=get_rqsd_tables_to_replicate,
        provide_context=True,
    )

    # Trigger replication for all RQSD tables.
    trigger_rqsd_replication_task = PythonOperator(
        task_id="trigger_rqsd_replication",
        python_callable=trigger_rqsd_replication,
        provide_context=True,
    )

    dag_status = PythonOperator(
        task_id="dag_status",
        provide_context=True,
        python_callable=check_dag_status,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    # Dependency chain.
    (
        read_vars_task
        >> retrieve_run_id_task
        >> control_external_run_start
        >> annex_groups
        >> output_group
        >> control_external_run_end
        >> get_rqsd_tables
        >> trigger_rqsd_replication_task
        >> dag_status
    )


# Register the DAG under its file-derived name so Airflow discovers it.
globals()[DAG_NAME] = run_dag()