init
This commit is contained in:
0
python/mrds_common/mrds/utils/__init__.py
Normal file
0
python/mrds_common/mrds/utils/__init__.py
Normal file
69
python/mrds_common/mrds/utils/csv_utils.py
Normal file
69
python/mrds_common/mrds/utils/csv_utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import csv
|
||||
import os
|
||||
|
||||
TASK_HISTORY_MULTIPLIER = 1_000_000_000
|
||||
|
||||
|
||||
def read_csv_file(csv_filepath, encoding_type="utf-8"):
    """Read a CSV file and return its header row and data rows.

    Args:
        csv_filepath: Path of the CSV file to read.
        encoding_type: Text encoding used to open the file (default "utf-8").

    Returns:
        Tuple (headers, data_rows): the first row, and a list of the
        remaining rows.
    """
    with open(csv_filepath, "r", newline="", encoding=encoding_type) as csvfile:
        all_rows = [row for row in csv.reader(csvfile)]
    return all_rows[0], all_rows[1:]
|
||||
|
||||
|
||||
def write_data_to_csv_file(csv_filepath, data, encoding_type="utf-8"):
    """Atomically write headers and rows to a CSV file.

    Content is first written to a ``<path>.tmp`` sibling and then moved
    over the target with os.replace, so readers never see a partially
    written file. All fields are quoted.

    Args:
        csv_filepath: Destination path of the CSV file.
        data: Mapping with "headers" (list of column names) and "rows"
            (iterable of row lists).
        encoding_type: Text encoding for the output (default "utf-8").
    """
    temp_csv_filepath = csv_filepath + ".tmp"
    with open(temp_csv_filepath, "w", newline="", encoding=encoding_type) as csvfile:
        csv_writer = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
        csv_writer.writerow(data["headers"])
        csv_writer.writerows(data["rows"])
    os.replace(temp_csv_filepath, csv_filepath)
|
||||
|
||||
|
||||
def add_static_columns(data_rows, headers, static_entries):
    """Apply constant-valued columns to tabular data in place.

    Each (column_header, value) pair either appends a new column filled
    with ``value`` or, when the header already exists, overwrites that
    column with ``value`` in every row.

    Args:
        data_rows: List of row lists; mutated in place.
        headers: List of column names; mutated in place.
        static_entries: Iterable of (column_header, value) pairs.
    """
    for column_header, value in static_entries:
        if column_header in headers:
            col = headers.index(column_header)
            for row in data_rows:
                row[col] = value
        else:
            headers.append(column_header)
            for row in data_rows:
                row.append(value)
|
||||
|
||||
|
||||
def add_a_key_columns(data_rows, headers, a_key_entries, task_history_key):
    """Populate surrogate-key columns derived from the task history key.

    For each requested header, row i (1-based) receives
    ``str(int(task_history_key) * TASK_HISTORY_MULTIPLIER + i)`` —
    appended as a new column when the header is absent, otherwise
    written over the existing column.

    Args:
        data_rows: List of row lists; mutated in place.
        headers: List of column names; mutated in place.
        a_key_entries: Iterable of column headers to fill.
        task_history_key: Key whose scaled value seeds each row's key.
    """
    # Invariant per call: the scaled base does not change between rows.
    base = int(task_history_key) * TASK_HISTORY_MULTIPLIER
    for column_header in a_key_entries:
        existing = column_header in headers
        if not existing:
            headers.append(column_header)
        col = headers.index(column_header)
        for row_number, row in enumerate(data_rows, start=1):
            key_text = str(base + row_number)
            if existing:
                row[col] = key_text
            else:
                row.append(key_text)
|
||||
|
||||
|
||||
def add_workflow_key_columns(data_rows, headers, workflow_key_entries, workflow_key):
    """Stamp the workflow key into the requested columns in place.

    Headers not yet present are appended (and the key appended to every
    row); existing headers have their column overwritten with the key.

    Args:
        data_rows: List of row lists; mutated in place.
        headers: List of column names; mutated in place.
        workflow_key_entries: Iterable of column headers to fill.
        workflow_key: Value written into every row of each column.
    """
    for column_header in workflow_key_entries:
        try:
            col = headers.index(column_header)
        except ValueError:
            headers.append(column_header)
            for row in data_rows:
                row.append(workflow_key)
        else:
            for row in data_rows:
                row[col] = workflow_key
|
||||
|
||||
|
||||
def rearrange_columns(headers, data_rows, column_order):
    """Project headers and rows onto the requested column order.

    Columns named in ``column_order`` but absent from ``headers`` are
    silently skipped; columns not named in ``column_order`` are dropped.

    Args:
        headers: Current list of column names (not mutated).
        data_rows: List of row lists (not mutated).
        column_order: Desired ordering of column names.

    Returns:
        Tuple (new_headers, new_data_rows) in the requested order.
    """
    position_of = {name: pos for pos, name in enumerate(headers)}
    keep = [position_of[name] for name in column_order if name in position_of]
    reordered_headers = [headers[pos] for pos in keep]
    reordered_rows = [[row[pos] for pos in keep] for row in data_rows]
    return reordered_headers, reordered_rows
|
||||
177
python/mrds_common/mrds/utils/manage_files.py
Normal file
177
python/mrds_common/mrds/utils/manage_files.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from . import oraconn
|
||||
from . import sql_statements
|
||||
from . import utils
|
||||
|
||||
# Get the next load id from the sequence
|
||||
|
||||
#
|
||||
# Workflows
|
||||
#
|
||||
|
||||
|
||||
def process_source_file_from_event(resource_id: str):
    """Process the source file referenced by an object-storage event URI."""
    #
    # expects object uri in the form /n/<namespace>/b/<bucket>/o/<object>
    # eg /n/frcnomajoc7v/b/dmarsdb1/o/sqlnet.log
    # and calls process_source_file with prefix and file_name extracted from that uri
    #

    _, _, prefix, file_name = utils.parse_uri_with_regex(resource_id)
    process_source_file(prefix, file_name)
|
||||
|
||||
|
||||
def process_source_file(prefix: str, filename: str):
    """Run the PL/SQL file manager against one staged source file.

    Args:
        prefix: Object-storage prefix (folder) holding the file; a
            trailing slash is tolerated.
        filename: Name of the file to process.
    """
    # rstrip caters for callers that pass the prefix with a trailing slash.
    # Fixed: the path previously embedded a literal placeholder instead of
    # the filename argument, which was unused.
    sourcefile = f"{prefix.rstrip('/')}/{filename}"
    # Connect before the try block: if connect itself fails there is no
    # conn to close, and the old code raised NameError in finally.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        oraconn.run_proc(conn, "CT_MRDS.FILE_MANAGER.PROCESS_SOURCE_FILE", [sourcefile])
        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def execute_query(query, query_parameters=None, account_alias="MRDS_LOADER"):
    """Execute a SELECT and return the first column of every fetched row.

    Args:
        query: SQL text to execute.
        query_parameters: Optional bind-parameter sequence/mapping.
        account_alias: Credential alias passed to oraconn.connect.

    Returns:
        List containing element 0 of each fetched row.
    """
    # Connect before try: a failed connect previously hit NameError on
    # conn.close() in finally, masking the real error.
    conn = oraconn.connect(account_alias)
    try:
        curs = conn.cursor()
        # PEP 8: comparisons to None use identity, not equality.
        if query_parameters is not None:
            curs.execute(query, query_parameters)
        else:
            curs.execute(query)
        query_result = curs.fetchall()
        conn.commit()
    finally:
        conn.close()
    return [t[0] for t in query_result]
|
||||
|
||||
|
||||
def get_file_prefix(source_key, source_file_id, table_id):
    """Look up the bucket path prefix configured for a source file.

    Args:
        source_key: Source system key.
        source_file_id: Source file identifier.
        table_id: Target table identifier.

    Returns:
        First column of the single row returned by the "get_file_prefix"
        statement. NOTE(review): raises TypeError if no row is returned
        (fetchone() -> None) — confirm that cannot happen for valid input.
    """
    # Connect before try: avoids NameError on conn.close() in finally
    # when the connection itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        curs = conn.cursor()
        curs.execute(
            sql_statements.get_sql("get_file_prefix"),
            [source_key, source_file_id, table_id],
        )
        query_result = curs.fetchone()
        conn.commit()
    finally:
        conn.close()
    return query_result[0]
|
||||
|
||||
|
||||
def get_inbox_bucket():
    """Return the inbox bucket name from the PL/SQL file manager."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_func(conn, "CT_MRDS.FILE_MANAGER.GET_INBOX_BUCKET", str, [])
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def get_data_bucket():
    """Return the data bucket name from the PL/SQL file manager."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_func(conn, "CT_MRDS.FILE_MANAGER.GET_DATA_BUCKET", str, [])
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def add_source_file_config(
    source_key,
    source_file_type,
    source_file_id,
    source_file_desc,
    source_file_name_pattern,
    table_id,
    template_table_name,
):
    """Register a source file configuration via the PL/SQL file manager.

    Returns:
        None — oraconn.run_proc has no return statement, so ``ret`` is
        always None; the return is kept for interface stability.
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.FILE_MANAGER.ADD_SOURCE_FILE_CONFIG",
            [
                source_key,
                source_file_type,
                source_file_id,
                source_file_desc,
                source_file_name_pattern,
                table_id,
                template_table_name,
            ],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def add_column_date_format(template_table_name, column_name, date_format):
    """Register a date format for a template-table column.

    Returns:
        None — run_proc has no return value; kept for interface stability.
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.FILE_MANAGER.ADD_column_date_format",
            [template_table_name, column_name, date_format],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def execute(stmt):
    """Execute an arbitrary SQL statement and commit.

    WARNING: ``stmt`` is executed verbatim with no parameterization —
    callers must pass trusted SQL only.
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        curs = conn.cursor()
        curs.execute(stmt)
        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def create_external_table(table_name, template_table_name, prefix):
    """Create an external table over object-storage data via PL/SQL.

    Returns:
        None — run_proc has no return value; kept for interface stability.
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("ODS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.FILE_MANAGER.CREATE_EXTERNAL_TABLE",
            [table_name, template_table_name, prefix, get_bucket("ODS")],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def get_bucket(bucket):
    """Return the object-storage URI configured for ``bucket``."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_func(
            conn, "CT_MRDS.FILE_MANAGER.GET_BUCKET_URI", str, [bucket]
        )
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
97
python/mrds_common/mrds/utils/manage_runs.py
Normal file
97
python/mrds_common/mrds/utils/manage_runs.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from . import oraconn
|
||||
from . import sql_statements
|
||||
from . import static_vars
|
||||
from . import manage_files
|
||||
|
||||
|
||||
def init_workflow(database_name: str, workflow_name: str, workflow_run_id: str):
    """Register a workflow run and return its A_WORKFLOW_HISTORY_KEY."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        a_workflow_history_key = oraconn.run_func(
            conn,
            "CT_MRDS.WORKFLOW_MANAGER.INIT_WORKFLOW",
            int,
            [database_name, workflow_run_id, workflow_name],
        )
        conn.commit()
    finally:
        conn.close()
    return a_workflow_history_key
|
||||
|
||||
|
||||
def finalise_workflow(a_workflow_history_key: int, workflow_status: str):
    """Mark a workflow run as finished with the given status."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        oraconn.run_proc(
            conn,
            "CT_MRDS.WORKFLOW_MANAGER.FINALISE_WORKFLOW",
            [a_workflow_history_key, workflow_status],
        )
        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def init_task(task_name: str, task_run_id: str, a_workflow_history_key: int):
    """Register a task run and return its A_TASK_HISTORY_KEY.

    (The old free-standing ``a_task_history_key: int`` bare annotation
    was a no-op statement and has been removed.)
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        a_task_history_key = oraconn.run_func(
            conn,
            "CT_MRDS.WORKFLOW_MANAGER.INIT_TASK",
            int,
            [task_run_id, task_name, a_workflow_history_key],
        )
        conn.commit()
    finally:
        conn.close()
    return a_task_history_key
|
||||
|
||||
|
||||
def finalise_task(a_task_history_key: int, task_status: str):
    """Mark a task run as finished with the given status."""
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        curs = conn.cursor()
        curs.execute(
            sql_statements.get_sql("finalise_task"), [task_status, a_task_history_key]
        )
        conn.commit()
    finally:
        conn.close()
|
||||
|
||||
|
||||
def set_workflow_property(
    wf_history_key: int, service_name: str, property: str, value: str
):
    """Store a named property against a workflow history record.

    Note:
        ``property`` shadows the builtin, but the name is part of the
        keyword-argument interface and is kept. run_proc returns None,
        so ``ret`` is always None; kept for interface stability.
    """
    # Connect before try: avoids NameError in finally if connect fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.WORKFLOW_MANAGER.SET_WORKFLOW_PROPERTY",
            [wf_history_key, service_name, property, value],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
|
||||
|
||||
|
||||
def select_ods_tab(table_name: str, value: str, condition="1 = 1"):
    """Select ``value`` from ``table_name`` rows matching ``condition``.

    WARNING: all three arguments are interpolated directly into the SQL
    text (identifiers cannot be bind parameters in Oracle) — callers
    must pass trusted, internally-generated values only.

    Returns:
        List of first-column values, via manage_files.execute_query.
    """
    query = f"select {value} from {table_name} where {condition}"
    print(f"query = |{query}|")
    return manage_files.execute_query(query=query, account_alias="ODS_LOADER")
|
||||
53
python/mrds_common/mrds/utils/objectstore.py
Normal file
53
python/mrds_common/mrds/utils/objectstore.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import oci
|
||||
|
||||
|
||||
def get_client():
    """Build an ObjectStorageClient using OCI principal authentication.

    Authentication tries Resource Principal (OCI Container Instances)
    first and falls back to Instance Principal (VMs) on error.

    Returns:
        oci.object_storage.ObjectStorageClient bound to the signer.
    """
    try:
        signer = oci.auth.signers.get_resource_principals_signer()
    except Exception:
        # Narrowed from a bare except (which also caught SystemExit and
        # KeyboardInterrupt); the duplicated "signer =" was removed.
        signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()

    # The empty dict is an empty config: everything comes from the signer.
    client = oci.object_storage.ObjectStorageClient({}, signer=signer)
    return client
|
||||
|
||||
|
||||
def list_bucket(client, namespace, bucket, prefix):
    """Return the ListObjects payload for objects under ``prefix``.

    Args:
        client: ObjectStorageClient (e.g. from get_client()).
        namespace: Object-storage namespace.
        bucket: Bucket name.
        prefix: Object-name prefix filter.

    Returns:
        The response ``data`` attribute (the ListObjects model).
    """
    objects = client.list_objects(namespace, bucket, prefix=prefix)
    # see https://docs.oracle.com/en-us/iaas/tools/python/2.135.0/api/request_and_response.html#oci.response.Response for all attrs
    return objects.data
|
||||
|
||||
|
||||
def upload_file(client, source_filename, namespace, bucket, prefix, target_filename):
    """Upload a local file to ``<prefix>/<target_filename>`` in the bucket.

    ``prefix`` may be passed with or without a trailing slash.
    """
    with open(source_filename, "rb") as in_file:
        client.put_object(
            namespace, bucket, f"{prefix.rstrip('/')}/{target_filename}", in_file
        )
|
||||
|
||||
|
||||
def clean_folder(client, namespace, bucket, prefix):
    """Delete every object under ``prefix`` in the bucket.

    Args:
        client: ObjectStorageClient.
        namespace: Object-storage namespace.
        bucket: Bucket name.
        prefix: Object-name prefix selecting what to delete.
    """
    objects = client.list_objects(namespace, bucket, prefix=prefix)
    for o in objects.data.objects:
        # o.name is already the full object name (delete_object below is
        # called with it as-is); the old log line prepended the prefix a
        # second time, producing misleading paths.
        print(f"Deleting {o.name}")
        client.delete_object(namespace, bucket, f"{o.name}")
|
||||
|
||||
|
||||
def delete_file(client, file, namespace, bucket, prefix):
    """Delete the object ``<prefix>/<file>`` from the bucket."""
    client.delete_object(namespace, bucket, f"{prefix.rstrip('/')}/{file}")
|
||||
|
||||
|
||||
def download_file(client, namespace, bucket, prefix, source_filename, target_filename):
    """Download ``<prefix>/<source_filename>`` to a local file.

    The object body is streamed to ``target_filename`` in 1 MiB chunks,
    so large files are never held fully in memory.
    """
    # Retrieve the file, streaming it into another file in 1 MiB chunks

    get_obj = client.get_object(
        namespace, bucket, f"{prefix.rstrip('/')}/{source_filename}"
    )
    with open(target_filename, "wb") as f:
        for chunk in get_obj.data.raw.stream(1024 * 1024, decode_content=False):
            f.write(chunk)
|
||||
38
python/mrds_common/mrds/utils/oraconn.py
Normal file
38
python/mrds_common/mrds/utils/oraconn.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import oracledb
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
|
||||
def connect(alias):
    """Open an Oracle connection using credentials from the environment.

    Reads ``<alias>_DB_USER``, ``<alias>_DB_PASS`` and ``<alias>_DB_TNS``
    and connects in thick mode.

    Args:
        alias: Prefix of the environment variables holding credentials.

    Returns:
        An open oracledb connection.

    Exits the process (status 1) when credentials are missing or the
    connection attempt fails — consistent with the existing error paths.
    """
    username = os.getenv(alias + "_DB_USER")
    password = os.getenv(alias + "_DB_PASS")
    tnsalias = os.getenv(alias + "_DB_TNS")
    # Fail fast with a clear message instead of the raw TypeError the old
    # string concatenation raised when any variable was unset.
    if not all((username, password, tnsalias)):
        print(
            f"Missing {alias}_DB_USER / {alias}_DB_PASS / {alias}_DB_TNS environment variables",
            file=sys.stderr,
        )
        sys.exit(1)
    connstr = username + "/" + password + "@" + tnsalias

    # NOTE(review): init_oracle_client is invoked on every call — confirm
    # oracledb tolerates repeated initialisation in this deployment.
    oracledb.init_oracle_client()

    try:
        conn = oracledb.connect(connstr)
        return conn
    except oracledb.DatabaseError as db_err:
        tb = traceback.format_exc()
        print(f"DatabaseError connecting to '{alias}': {db_err}\n{tb}", file=sys.stderr)
        sys.exit(1)
    except Exception as exc:
        tb = traceback.format_exc()
        print(f"Unexpected error connecting to '{alias}': {exc}\n{tb}", file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
|
||||
def run_proc(connection, proc: str, param: list):
    """Call a PL/SQL procedure.

    Args:
        connection: Open oracledb connection.
        proc: Fully qualified procedure name.
        param: Positional parameter list for the procedure.

    Returns:
        None — callers that assign the result get None.
    """
    # Annotation fixed: `param: []` was a literal-list expression, not a type.
    curs = connection.cursor()
    curs.callproc(proc, param)
|
||||
|
||||
|
||||
def run_func(connection, proc: str, rettype, param: list):
    """Call a PL/SQL function and return its result.

    Args:
        connection: Open oracledb connection.
        proc: Fully qualified function name.
        rettype: Python type of the function's return value.
        param: Positional parameter list for the function.

    Returns:
        The function's return value, converted to ``rettype``.
    """
    # Annotation fixed: `param: []` was a literal-list expression, not a type.
    curs = connection.cursor()
    ret = curs.callfunc(proc, rettype, param)

    return ret
|
||||
46
python/mrds_common/mrds/utils/secrets.py
Normal file
46
python/mrds_common/mrds/utils/secrets.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import oci
|
||||
import ast
|
||||
import base64
|
||||
|
||||
# Specify the OCID of the secret to retrieve
|
||||
|
||||
|
||||
def get_secretcontents(ocid):
    """Retrieve a secret bundle from OCI Vault.

    Authentication tries Resource Principal (OCI Container Instances)
    first and falls back to Instance Principal (VMs) on error.

    Args:
        ocid: OCID of the secret to retrieve.

    Returns:
        The secret-bundle response object.
    """
    try:
        signer = oci.auth.signers.get_resource_principals_signer()
    except Exception:
        # Narrowed from a bare except (which also caught SystemExit and
        # KeyboardInterrupt); the duplicated "signer =" was removed.
        signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()

    # Create secret client and retrieve content
    secretclient = oci.secrets.SecretsClient({}, signer=signer)
    secretcontents = secretclient.get_secret_bundle(secret_id=ocid)
    return secretcontents
|
||||
|
||||
|
||||
def get_password(ocid):
    """Return the "password" field stored inside a Vault secret.

    The secret content must be a base64-encoded Python dict literal
    containing a "password" key.
    """
    secretcontents = get_secretcontents(ocid)

    # Decode the base64 payload, then evaluate the dict literal inside it.
    encoded = secretcontents.data.secret_bundle_content.content
    decoded = base64.b64decode(encoded.encode("ascii")).decode("ascii")
    return ast.literal_eval(decoded)["password"]
|
||||
|
||||
|
||||
def get_secret(ocid):
    """Return a Vault secret's content decoded as UTF-8 text."""
    secretcontents = get_secretcontents(ocid)

    # Decode the base64 payload straight to text.
    encoded = secretcontents.data.secret_bundle_content.content
    return base64.b64decode(encoded).decode("UTF-8")
|
||||
106
python/mrds_common/mrds/utils/security_utils.py
Normal file
106
python/mrds_common/mrds/utils/security_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import re
|
||||
import logging
|
||||
|
||||
|
||||
def verify_run_id(run_id, context=None):
    """
    Verify run_id for security compliance.

    Args:
        run_id (str): The run_id to verify
        context (dict, optional): Airflow context for logging

    Returns:
        str: Verified (stripped) run_id

    Raises:
        ValueError: If run_id is invalid or suspicious
    """
    try:
        # Basic checks
        if not run_id or not isinstance(run_id, str):
            raise ValueError(
                f"Invalid run_id: must be non-empty string, got: {type(run_id).__name__}"
            )

        run_id = run_id.strip()

        if len(run_id) < 1 or len(run_id) > 250:
            raise ValueError(
                f"Invalid run_id: length must be 1-250 chars, got: {len(run_id)}"
            )

        # Allow only safe characters (Airflow ids use letters, digits,
        # and the ISO-8601 timestamp punctuation _-:+.T)
        if not re.match(r"^[a-zA-Z0-9_\-:+.T]+$", run_id):
            suspicious_chars = "".join(
                set(
                    char for char in run_id if not re.match(r"[a-zA-Z0-9_\-:+.T]", char)
                )
            )
            logging.warning(f"SECURITY: Invalid chars in run_id: '{suspicious_chars}'")
            raise ValueError("Invalid run_id: contains unsafe characters")

        # Check for attack patterns
        dangerous_patterns = [
            r"\.\./",
            r"\.\.\\",
            r"<script",
            r"javascript:",
            r"union\s+select",
            r"drop\s+table",
            r"insert\s+into",
            r"delete\s+from",
            r"exec\s*\(",
            r"system\s*\(",
            r"eval\s*\(",
            r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]",
        ]

        for pattern in dangerous_patterns:
            if re.search(pattern, run_id, re.IGNORECASE):
                logging.error(f"SECURITY: Dangerous pattern in run_id: '{run_id}'")
                raise ValueError("Invalid run_id: contains dangerous pattern")

        # Log success
        if context:
            dag_id = (
                getattr(context.get("dag"), "dag_id", "unknown")
                if context.get("dag")
                else "unknown"
            )
            logging.info(f"run_id verified: '{run_id}' for DAG: '{dag_id}'")

        return run_id

    except Exception as e:
        logging.error(
            f"SECURITY: run_id verification failed: '{run_id}', Error: {str(e)}"
        )
        # Chain the original exception (B904) so the real cause is not
        # lost when the wrapper ValueError is inspected or logged.
        raise ValueError(f"run_id verification failed: {str(e)}") from e
|
||||
|
||||
|
||||
def get_verified_run_id(context):
    """
    Extract and verify run_id from Airflow context.

    Args:
        context (dict): Airflow context

    Returns:
        str: Verified run_id
    """
    try:
        run_id = None
        if context:
            # Prefer the TaskInstance's run_id; fall back to the plain key.
            if "ti" in context:
                run_id = context["ti"].run_id
            elif "run_id" in context:
                run_id = context["run_id"]

        if not run_id:
            raise ValueError("Could not extract run_id from context")

        return verify_run_id(run_id, context)

    except Exception as e:
        logging.error(f"Failed to get verified run_id: {str(e)}")
        raise
|
||||
68
python/mrds_common/mrds/utils/sql_statements.py
Normal file
68
python/mrds_common/mrds/utils/sql_statements.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Registry of SQL statements keyed by statement id (see get_sql below).
sql_statements = {}

#
# Workflows
#

# register_workflow: Register new DW load
# NOTE(review): column name fixed from "WORKFLOW_SSUCCESSFUL" — the
# finalise_workflow statement below updates WORKFLOW_SUCCESSFUL, so the
# doubled-S spelling could never have inserted. Confirm against schema.
sql_statements[
    "register_workflow"
] = """INSERT INTO CT_MRDS.A_WORKFLOW_HISTORY
(A_WORKFLOW_HISTORY_KEY, WORKFLOW_RUN_ID,
WORKFLOW_NAME, WORKFLOW_START, WORKFLOW_SUCCESSFUL)
VALUES (:a_workflow_history_key, :workflow_run_id, :workflow_name, SYSTIMESTAMP, :running_status)
"""

# get_a_workflow_history_key: get new key from sequence
sql_statements["get_a_workflow_history_key"] = (
    "SELECT CT_MRDS.A_WORKFLOW_HISTORY_KEY_SEQ.NEXTVAL FROM DUAL"
)

# finalise_workflow: Update workflow record after completion
sql_statements[
    "finalise_workflow"
] = """UPDATE CT_MRDS.A_WORKFLOW_HISTORY
SET WORKFLOW_END = SYSTIMESTAMP, WORKFLOW_SUCCESSFUL = :workflow_status
WHERE A_WORKFLOW_HISTORY_KEY = :a_workflow_history_key
"""
#
# Tasks
#

# register_task
# NOTE(review): the old VALUES clause had five expressions for six
# columns (a copy-paste of register_workflow's binds) and would raise
# ORA-00947 on execution; bind names now mirror the task columns.
# Confirm against callers before relying on this statement.
sql_statements[
    "register_task"
] = """INSERT INTO CT_MRDS.A_TASK_HISTORY (A_TASK_HISTORY_KEY,
A_WORKFLOW_HISTORY_KEY, TASK_RUN_ID,
TASK_NAME, TASK_START, TASK_SUCCESSFUL)
VALUES (:a_task_history_key, :a_workflow_history_key, :task_run_id, :task_name, SYSTIMESTAMP, :running_status)
"""

# get_a_task_history_key: get new key from sequence
sql_statements["get_a_task_history_key"] = (
    "SELECT CT_MRDS.A_TASK_HISTORY_KEY_SEQ.NEXTVAL FROM DUAL"
)

# finalise_task: Update task record after completion
sql_statements[
    "finalise_task"
] = """UPDATE CT_MRDS.A_TASK_HISTORY
SET TASK_END = SYSTIMESTAMP, TASK_SUCCESSFUL = :workflow_status
WHERE A_TASK_HISTORY_KEY = :a_workflow_history_key
"""

#
# Files
#
sql_statements["get_file_prefix"] = (
    "SELECT CT_MRDS.FILE_MANAGER.GET_BUCKET_PATH(:source_key, :source_file_id, :table_id) FROM DUAL"
)
|
||||
|
||||
|
||||
def get_sql(stmt_id: str):
    """Return the SQL text registered under ``stmt_id``, or None if unknown."""
    return sql_statements.get(stmt_id)
|
||||
6
python/mrds_common/mrds/utils/static_vars.py
Normal file
6
python/mrds_common/mrds/utils/static_vars.py
Normal file
@@ -0,0 +1,6 @@
|
||||
#
# Task management variables
#
# Status markers written to the *_SUCCESSFUL history columns: a run is
# flagged "RUNNING" while active, then finalised to "Y" (success) or
# "N" (failure).
status_running: str = "RUNNING"
status_failed: str = "N"
status_success: str = "Y"
|
||||
83
python/mrds_common/mrds/utils/utils.py
Normal file
83
python/mrds_common/mrds/utils/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import re
|
||||
|
||||
|
||||
def parse_uri_with_regex(uri):
    """
    Parse an Oracle Object Storage URI into its components.

    Parameters:
        uri (str): URI in the format '/n/{namespace}/b/{bucketname}/o/{object_path}'

    Returns:
        tuple: (namespace, bucket_name, prefix, object_name); prefix is
        "" for objects at the bucket root, otherwise it ends with '/'.

    Raises:
        ValueError: If the URI does not match the expected format.
    """
    match = re.match(r"^/n/([^/]+)/b/([^/]+)/o/(.*)$", uri)
    if match is None:
        raise ValueError("Invalid URI format")

    namespace, bucket_name, object_path = match.groups()

    # Everything up to the last '/' is the prefix; the rest is the name.
    prefix, sep, object_name = object_path.rpartition("/")
    if sep:
        prefix = prefix + "/"
    else:
        # No '/' at all: the whole path is the object name, no prefix.
        prefix = ""
        object_name = object_path

    return namespace, bucket_name, prefix, object_name
|
||||
|
||||
|
||||
def parse_output_columns(output_columns):
    """
    Split output-column configuration entries into per-type groups.

    Each entry is a dict with at least "type" and "column_header"; some
    types also carry "value" and/or "is_key". Unknown types contribute
    only to the overall column order.

    Parameters:
        output_columns: Iterable of column-configuration dicts.

    Returns:
        tuple: (xpath_entries, csv_entries, static_entries,
        a_key_entries, workflow_key_entries, xml_position_entries,
        column_order).
    """
    xpath_entries = []
    csv_entries = []
    static_entries = []
    a_key_entries = []
    workflow_key_entries = []
    xml_position_entries = []
    column_order = []

    for entry in output_columns:
        header = entry["column_header"]
        column_order.append(header)
        kind = entry["type"]

        if kind == "xpath":
            xpath_entries.append((entry["value"], header, entry["is_key"]))
        elif kind == "csv_header":
            csv_entries.append((header, entry["value"]))
        elif kind == "static":
            static_entries.append((header, entry["value"]))
        elif kind == "a_key":
            a_key_entries.append(header)
        elif kind == "workflow_key":
            workflow_key_entries.append(header)
        elif kind == "xpath_element_id":  # TODO - update all xml_position namings to xpath_element_id
            xml_position_entries.append((entry["value"], header))

    return (
        xpath_entries,
        csv_entries,
        static_entries,
        a_key_entries,
        workflow_key_entries,
        xml_position_entries,
        column_order,
    )
|
||||
23
python/mrds_common/mrds/utils/vault.py
Normal file
23
python/mrds_common/mrds/utils/vault.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import oci
|
||||
import ast
|
||||
import base64
|
||||
|
||||
# Specify the OCID of the secret to retrieve
|
||||
|
||||
|
||||
def get_password(ocid):
    """Fetch a Vault secret and return its "password" field.

    The secret content is expected to be a base64-encoded Python dict
    literal containing a "password" key.

    Args:
        ocid: OCID of the secret to retrieve.

    Returns:
        str: The password stored in the secret.
    """
    # Authenticate with Instance Principals — no config file involved.
    # (The old comment about ~/.oci/config was inaccurate, and the
    # duplicated "signer =" assignment has been removed.)
    signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()

    # Get the secret
    secretclient = oci.secrets.SecretsClient({}, signer=signer)
    secretcontents = secretclient.get_secret_bundle(secret_id=ocid)

    # Decode the secret from base64 and evaluate the dict literal inside.
    keybase64 = secretcontents.data.secret_bundle_content.content
    keybase64bytes = keybase64.encode("ascii")
    keybytes = base64.b64decode(keybase64bytes)
    key = keybytes.decode("ascii")
    keydict = ast.literal_eval(key)
    return keydict["password"]
|
||||
177
python/mrds_common/mrds/utils/xml_utils.py
Normal file
177
python/mrds_common/mrds/utils/xml_utils.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import xmlschema
|
||||
import hashlib
|
||||
from lxml import etree
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def validate_xml(xml_file, xsd_file):
    """Validate an XML file against an XSD schema.

    Args:
        xml_file: Path of the XML document to check.
        xsd_file: Path of the XSD schema to check against.

    Returns:
        Tuple (ok, message): ok is True when the document validates,
        otherwise False with a message describing the failure.
    """
    try:
        # Strict validation: the schema itself must also be fully valid.
        strict_schema = xmlschema.XMLSchema(xsd_file, validation="strict")
        strict_schema.validate(xml_file)
    except xmlschema.validators.exceptions.XMLSchemaValidationError as e:
        return False, f"XML validation error: {str(e)}"
    except xmlschema.validators.exceptions.XMLSchemaException as e:
        return False, f"XML schema error: {str(e)}"
    except Exception as e:
        return False, f"An error occurred during XML validation: {str(e)}"
    return True, "XML file is valid against the provided XSD schema."
|
||||
|
||||
|
||||
def extract_data(
    filename,
    xpath_columns,  # List[(expr, header, is_key)]
    xml_position_columns,  # List[(expr, header)]
    namespaces,
    workflow_context,
    encoding_type="utf-8",
):
    """
    Parses an XML file using XPath expressions and extracts data.

    Parameters:
    - filename (str): The path to the XML file to parse.
    - xpath_columns (list): A list of tuples, each containing:
        - XPath expression (str)
        - CSV column header (str)
        - Indicator if the field is a key ('Y' or 'N')
    - xml_position_columns (list): tuples of (XPath expression, header);
      each produces a column holding a hashed element position
      (see _xml_pos_hasher).
    - namespaces (dict): Namespace mapping needed for lxml's xpath()
    - workflow_context (dict): must contain "a_workflow_history_key",
      used as the hash salt for position columns.
    - encoding_type (str): accepted but not used here — presumably kept
      for interface symmetry with the CSV helpers; lxml detects the
      document encoding itself. TODO confirm.

    Returns:
    - dict: A dictionary containing headers and rows with extracted data.
      One row is produced per non-key match; key-column values are
      repeated on every row.
    """

    parser = etree.XMLParser(remove_blank_text=True)
    tree = etree.parse(filename, parser)
    root = tree.getroot()

    # Separate out key vs non‐key columns
    key_cols = [ (expr, h) for expr, h, k in xpath_columns if k == "Y" ]
    nonkey_cols = [ (expr, h) for expr, h, k in xpath_columns if k == "N" ]

    # Evaluate every non‐key XPath and keep the ELEMENT nodes
    nonkey_elements = {}
    for expr, header in nonkey_cols:
        elems = root.xpath(expr, namespaces=namespaces)
        nonkey_elements[header] = elems

    # figure out how many rows total we need
    # that's the maximum length of any of the nonkey lists
    if nonkey_elements:
        row_count = max(len(lst) for lst in nonkey_elements.values())
    else:
        row_count = 0

    # pad every nonkey list up to row_count with `None`
    # (columns with fewer matches than the longest one yield "" cells)
    for header, lst in nonkey_elements.items():
        if len(lst) < row_count:
            lst.extend([None] * (row_count - len(lst)))

    # key columns: evaluated once; only the FIRST match is used, and a
    # missing match becomes the empty string
    key_values = []
    for expr, header in key_cols:
        nodes = root.xpath(expr, namespaces=namespaces)
        if not nodes:
            key_values.append("")
        else:
            first = nodes[0]
            txt = (first.text if isinstance(first, etree._Element) else str(first)) or ""
            key_values.append(txt.strip())

    # xml_position columns: candidate ancestor node-sets, matched per row below
    xml_positions = {}
    for expr, header in xml_position_columns:
        xml_positions[header] = root.xpath(expr, namespaces=namespaces)

    # prepare headers
    headers = [h for _, h in nonkey_cols] + [h for _, h in key_cols] + [h for _, h in xml_position_columns]

    # build rows
    rows = []
    for i in range(row_count):
        row = []

        # non‐key data
        for expr, header in nonkey_cols:
            elem = nonkey_elements[header][i]
            text = ""
            if isinstance(elem, etree._Element):
                text = elem.text or ""
            elif elem is not None:
                text = str(elem)
            row.append(text.strip())

        # key columns
        row.extend(key_values)

        # xml_position columns: walk up from this row's FIRST non-key
        # element until an ancestor appears in the column's node-set,
        # then emit a salted hash of that ancestor's indexed path
        for expr, header in xml_position_columns:
            if not nonkey_cols:
                row.append("")
                continue

            first_header = nonkey_cols[0][1]
            data_elem = nonkey_elements[first_header][i]
            if data_elem is None:
                row.append("")
                continue

            target_list = xml_positions[header]
            current = data_elem
            found = None
            while current is not None:
                if current in target_list:
                    found = current
                    break
                current = current.getparent()

            if not found:
                row.append("")
            else:
                # compute full‐path with indices
                # (1-based position among same-tag siblings at each level)
                path_elems = []
                walk = found
                while walk is not None:
                    idx = 1 + sum(1 for s in walk.itersiblings(preceding=True) if s.tag == walk.tag)
                    path_elems.append(f"{walk.tag}[{idx}]")
                    walk = walk.getparent()
                full_path = "/" + "/".join(reversed(path_elems))
                row.append(_xml_pos_hasher(full_path, workflow_context["a_workflow_history_key"]))

        rows.append(row)

    return {"headers": headers, "rows": rows}
|
||||
|
||||
|
||||
def _xml_pos_hasher(input_string, salt, hash_length=15):
|
||||
"""
|
||||
Helps hashing xml positions.
|
||||
|
||||
Parameters:
|
||||
input_string (str): The string to hash.
|
||||
salt (int): The integer salt to ensure deterministic, run-specific behavior.
|
||||
hash_length (int): The desired length of the resulting hash (default is 15 digits).
|
||||
|
||||
Returns:
|
||||
int: A deterministic integer hash of the specified length.
|
||||
"""
|
||||
# Ensure the hash length is valid
|
||||
if hash_length <= 0:
|
||||
raise ValueError("Hash length must be a positive integer.")
|
||||
|
||||
# Combine the input string with the salt to create a deterministic input
|
||||
salted_input = f"{salt}:{input_string}"
|
||||
|
||||
# Generate a SHA-256 hash of the salted input
|
||||
hash_object = hashlib.sha256(salted_input.encode())
|
||||
full_hash = hash_object.hexdigest()
|
||||
|
||||
# Convert the hash to an integer
|
||||
hash_integer = int(full_hash, 16)
|
||||
|
||||
# Truncate or pad the hash to the desired length
|
||||
truncated_hash = str(hash_integer)[:hash_length]
|
||||
|
||||
return int(truncated_hash)
|
||||
Reference in New Issue
Block a user