This commit is contained in:
Grzegorz Michalski
2026-03-02 09:47:35 +01:00
commit 2c225d68ac
715 changed files with 130067 additions and 0 deletions

View File

@@ -0,0 +1 @@
__version__ = "0.6.0"

View File

@@ -0,0 +1,117 @@
import click
import json
import logging
import sys
from mrds import __version__
from mrds.core import main
@click.command()
@click.version_option(version=__version__, prog_name="mrds")
@click.option(
    "--workflow-context",
    "-w",
    required=False,
    help="Workflow context to be used by the application. This is required unless --generate-workflow-context is provided.",
)
@click.option(
    "--source-filename",
    "-s",
    required=True,
    help="Source filename to be processed.",
)
@click.option(
    "--config-file",
    "-c",
    type=click.Path(exists=True),
    required=True,
    help="Path to the YAML configuration file.",
)
@click.option(
    "--generate-workflow-context",
    is_flag=True,
    default=False,
    help="Generate a workflow context automatically. If this is set, --workflow-context is not required.",
)
@click.option(
    "--keep-source-file",
    is_flag=True,
    default=False,
    help="Keep source file, instead of deleting it.",
)
@click.option(
    "--keep-tmp-dir",
    is_flag=True,
    default=False,
    help="Keep tmp directory, instead of deleting it.",
)
def cli_main(
    workflow_context,
    source_filename,
    config_file,
    generate_workflow_context,
    keep_source_file,
    keep_tmp_dir,
):
    """Command-line entry point for the mrds application.

    Validates the mutually exclusive --workflow-context /
    --generate-workflow-context options, parses the JSON workflow context
    when one is supplied, and delegates all processing to mrds.core.main.

    Raises click.UsageError on conflicting/missing options or malformed
    workflow-context JSON.
    """
    # Configure logging for the whole process (stdout only).
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
        ],
    )
    # Handle conflicting options: a context may be supplied OR generated,
    # never both.
    if workflow_context and generate_workflow_context:
        raise click.UsageError(
            "You cannot use both --workflow-context and --generate-workflow-context at the same time. "
            "Please provide only one."
        )
    # Enforce that either --workflow-context or --generate-workflow-context must be provided
    if not workflow_context and not generate_workflow_context:
        raise click.UsageError(
            "You must provide --workflow-context or use --generate-workflow-context flag."
        )
    # Parse and validate the workflow_context if provided
    if workflow_context:
        try:
            workflow_context = json.loads(workflow_context)
        except json.JSONDecodeError as e:
            raise click.UsageError(f"Invalid JSON for --workflow-context: {e}")
        # Validate that the workflow_context matches the expected structure:
        # a JSON object carrying both keys used by the core workflow.
        if (
            not isinstance(workflow_context, dict)
            or "run_id" not in workflow_context
            or "a_workflow_history_key" not in workflow_context
        ):
            raise click.UsageError(
                "Invalid workflow context structure. It must be a JSON object with 'run_id' and 'a_workflow_history_key'."
            )
    # Call the core processing function
    main(
        workflow_context,
        source_filename,
        config_file,
        generate_workflow_context,
        keep_source_file,
        keep_tmp_dir,
    )
if __name__ == "__main__":
try:
cli_main()
sys.exit(0)
except click.UsageError as e:
logging.error(f"Usage error: {e}")
sys.exit(2)
except Exception as e:
logging.error(f"Unexpected error: {e}")
sys.exit(1)

View File

@@ -0,0 +1,366 @@
import os
import uuid
import logging
import yaml
import zipfile
import tempfile
from dataclasses import dataclass, field
from mrds import __version__
from mrds.processors import get_file_processor
from mrds.utils import (
manage_runs,
objectstore,
static_vars,
xml_utils,
)
# environment variables (defaults target the poc environment)
MRDS_ENV = os.getenv("MRDS_ENV", "poc")
BUCKET = os.getenv("INBOX_BUCKET", "mrds_inbox_poc")
BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE", "frcnomajoc7v")
# Static configuration variables
WORKFLOW_TYPE = "ODS"
ENCODING_TYPE = "utf-8"  # default; YAML config may override via 'encoding_type'
# Keys that must be present at the top level of the YAML configuration.
CONFIG_REQUIRED_KEYS = [
    "tmpdir",
    "inbox_prefix",
    "archive_prefix",
    "workflow_name",
    "validation_schema_path",
    "tasks",
    "file_type",
]
# Keys that must be present in every entry of the config's 'tasks' list.
TASK_REQUIRED_KEYS = [
    "task_name",
    "ods_prefix",
    "output_table",
    "output_columns",
]
# Aliases for the shared run-status constants.
STATUS_SUCCESS = static_vars.status_success
STATUS_FAILURE = static_vars.status_failed
@dataclass
class GlobalConfig:
    """Run-wide configuration shared by every task of a workflow run."""

    tmpdir: str
    inbox_prefix: str
    archive_prefix: str
    workflow_name: str
    source_filename: str
    validation_schema_path: str
    bucket: str
    bucket_namespace: str
    file_type: str
    encoding_type: str

    def __post_init__(self):
        # Remember the name as originally provided: when the source is a ZIP,
        # source_filename is later rewritten to the extracted member, but the
        # archive/delete steps still need the original object name.
        self.original_source_filename = self.source_filename

    def _in_tmpdir(self, filename: str) -> str:
        # All working files live directly under the run's tmpdir.
        return os.path.join(self.tmpdir, filename)

    @property
    def source_filepath(self) -> str:
        """Local path of the (possibly extracted) source file."""
        return self._in_tmpdir(self.source_filename)

    @property
    def original_source_filepath(self) -> str:
        """Local path of the file exactly as it was downloaded."""
        return self._in_tmpdir(self.original_source_filename)
@dataclass
class TaskConfig:
    """Per-task settings parsed from one entry of the config's 'tasks' list."""

    task_name: str        # registered in the run history and used in logs
    ods_prefix: str       # object-store prefix the task output is uploaded to
    output_table: str     # target table name; embedded in the output filename
    namespaces: dict      # XML namespace map (defaults to {} for non-XML tasks)
    output_columns: list  # column plan consumed by parse_output_columns
def initialize_config(source_filename, config_file_path):
    """Load the YAML configuration file and build the run configuration.

    Returns a (GlobalConfig, list[TaskConfig]) tuple.

    Raises FileNotFoundError when the config file is absent and ValueError
    when required global or per-task keys are missing.
    """
    logging.info(f"Source filename is set to: {source_filename}")
    logging.info(f"Loading configuration from {config_file_path}")
    # Fail fast when the configuration file does not exist.
    if not os.path.exists(config_file_path):
        raise FileNotFoundError(f"Configuration file {config_file_path} not found.")
    with open(config_file_path, "r") as f:
        config_data = yaml.safe_load(f)
    logging.debug(f"Configuration data: {config_data}")
    missing_keys = [key for key in CONFIG_REQUIRED_KEYS if key not in config_data]
    if missing_keys:
        raise ValueError(f"Missing required keys in configuration: {missing_keys}")

    def build_task(task_data):
        # Per-task validation mirrors the global key validation above.
        missing_task_keys = [key for key in TASK_REQUIRED_KEYS if key not in task_data]
        if missing_task_keys:
            raise ValueError(
                f"Missing required keys in task configuration: {missing_task_keys}"
            )
        return TaskConfig(
            task_name=task_data["task_name"],
            ods_prefix=task_data["ods_prefix"],
            output_table=task_data["output_table"],
            namespaces=task_data.get("namespaces", {}),
            output_columns=task_data["output_columns"],
        )

    global_config = GlobalConfig(
        tmpdir=config_data["tmpdir"],
        inbox_prefix=config_data["inbox_prefix"],
        archive_prefix=config_data["archive_prefix"],
        workflow_name=config_data["workflow_name"],
        source_filename=source_filename,
        validation_schema_path=config_data["validation_schema_path"],
        bucket=BUCKET,
        bucket_namespace=BUCKET_NAMESPACE,
        file_type=config_data["file_type"],
        encoding_type=config_data.get("encoding_type", ENCODING_TYPE),
    )
    tasks = [build_task(task_data) for task_data in config_data["tasks"]]
    return global_config, tasks
def initialize_workflow(global_config):
    """Register a new workflow run and return its context dict.

    The returned dict has exactly the two keys the rest of the pipeline
    expects: 'run_id' (a fresh uuid4 string) and 'a_workflow_history_key'
    (whatever manage_runs.init_workflow returns).
    """
    run_id = str(uuid.uuid4())
    logging.info(f"Initializing workflow '{global_config.workflow_name}'")
    a_workflow_history_key = manage_runs.init_workflow(
        WORKFLOW_TYPE, global_config.workflow_name, run_id
    )
    return {
        "run_id": run_id,
        "a_workflow_history_key": a_workflow_history_key,
    }
def download_source_file(client, global_config):
    """Download the source object from the inbox prefix into the tmpdir.

    Thin wrapper around objectstore.download_file; the destination path is
    global_config.source_filepath (tmpdir + source_filename).
    """
    logging.info(
        f"Downloading source file '{global_config.source_filename}' "
        f"from '{global_config.bucket}/{global_config.inbox_prefix}'"
    )
    objectstore.download_file(
        client,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.inbox_prefix,
        global_config.source_filename,
        global_config.source_filepath,
    )
    logging.info(f"Source file downloaded to '{global_config.source_filepath}'")
def delete_source_file(client, global_config):
    """Delete the original source object from the inbox prefix.

    Uses original_source_filename (not source_filename) so the object that
    was actually downloaded — e.g. the ZIP, not its extracted member — is
    the one removed.
    """
    logging.info(
        f"Deleting source file '{global_config.bucket}/{global_config.inbox_prefix}/{global_config.original_source_filename}'"
    )
    objectstore.delete_file(
        client,
        global_config.original_source_filename,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.inbox_prefix,
    )
    logging.info(
        f"Deleted source file '{global_config.bucket}/{global_config.inbox_prefix}/{global_config.original_source_filename}'"
    )
def archive_source_file(client, global_config):
    """Upload the originally downloaded file to the archive prefix.

    Uses original_source_filepath/filename so the archived object is the
    file as received (e.g. the ZIP), regardless of any extraction.
    """
    logging.info(
        f"Archiving source file to '{global_config.bucket}/{global_config.archive_prefix}/{global_config.original_source_filename}'"
    )
    objectstore.upload_file(
        client,
        global_config.original_source_filepath,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.archive_prefix,
        global_config.original_source_filename,
    )
    logging.info(
        f"Source file archived to '{global_config.bucket}/{global_config.archive_prefix}/{global_config.original_source_filename}'"
    )
def unzip_source_file_if_needed(global_config):
    """Extract the source file when it is a single-member ZIP archive.

    Non-ZIP files pass through untouched. For a ZIP containing exactly one
    member, the member is extracted next to the archive and
    global_config.source_filename is repointed at it.

    Returns True when processing can continue (non-ZIP, or extraction
    succeeded), False when the archive is unreadable or does not contain
    exactly one file.
    """
    source_filepath = global_config.source_filepath
    # Nothing to do for plain (non-ZIP) files.
    if not zipfile.is_zipfile(source_filepath):
        logging.info(f"File '{source_filepath}' is not a ZIP file.")
        return True

    logging.info(f"File '{source_filepath}' is a ZIP file. Unzipping...")
    extract_dir = os.path.dirname(source_filepath)
    try:
        with zipfile.ZipFile(source_filepath, "r") as zip_ref:
            extracted_files = zip_ref.namelist()
            member_count = len(extracted_files)
            # The pipeline handles exactly one payload file per archive.
            if member_count != 1:
                logging.error(
                    f"Expected one file in the ZIP, but found {member_count} files."
                )
                return False
            zip_ref.extractall(extract_dir)
    except Exception as e:
        logging.error(f"Error while extracting '{source_filepath}': {e}")
        return False

    # Repoint the config at the extracted member; source_filepath now
    # resolves to the extracted file.
    extracted_filename = extracted_files[0]
    global_config.source_filename = extracted_filename
    logging.info(
        f"Extracted '{extracted_filename}' to '{extract_dir}'. "
        f"Updated source_filepath to '{global_config.source_filepath}'."
    )
    return True
def validate_source_file(global_config):
    """Validate the downloaded source file according to its configured type.

    XML files are checked against global_config.validation_schema_path; a
    failed validation raises ValueError. CSV validation is not implemented
    yet and always passes. Any other file type raises ValueError.

    Returns True when validation passes.
    """
    file_type = global_config.file_type.lower()
    if file_type == "csv":
        # TODO: add CSV validation here
        return True
    if file_type == "xml":
        xml_is_valid, xml_validation_message = xml_utils.validate_xml(
            global_config.source_filepath, global_config.validation_schema_path
        )
        if not xml_is_valid:
            raise ValueError(f"XML validation failed: {xml_validation_message}")
        logging.info(xml_validation_message)
        return True
    raise ValueError(f"Unsupported file type: {file_type}")
def process_tasks(tasks, global_config, workflow_context, client):
    """Run every configured task against the downloaded source file.

    The processor class is chosen once from global_config.file_type and a
    fresh instance is created per task; each instance drives its own
    process() pipeline.
    """
    # get appropriate task processor
    processor_class = get_file_processor(global_config)
    for task_conf in tasks:
        logging.info(f"Starting task '{task_conf.task_name}'")
        file_processor = processor_class(
            global_config, task_conf, client, workflow_context
        )
        file_processor.process()
def finalize_workflow(workflow_context, success=True):
    """Record the workflow run as finished with the given outcome.

    main() only calls this for contexts it generated itself
    (--generate-workflow-context).
    """
    status = STATUS_SUCCESS if success else STATUS_FAILURE
    manage_runs.finalise_workflow(workflow_context["a_workflow_history_key"], status)
    if success:
        logging.info("Workflow completed successfully")
    else:
        logging.error("Workflow failed")
def main(
    workflow_context: dict,
    source_filename: str,
    config_file_path: str,
    generate_workflow_context=False,
    keep_source_file=False,
    keep_tmp_dir=False,
):
    """Run the full mrds workflow for one source file.

    Steps: load config, create a working tmpdir, (optionally) generate a
    workflow context, download the source object, unzip/validate it, run
    every configured task, then finalize/archive/clean up.

    Raises RuntimeError (chained to the original cause) on any failure.
    """
    logging.info(f"Initializing mrds app, version {__version__}")
    tmpdir_manager = None
    try:
        # get configs
        global_config, tasks = initialize_config(source_filename, config_file_path)
        # Handle temporary dirs
        if keep_tmp_dir:
            tmpdir = tempfile.mkdtemp(
                prefix="mrds_", dir=global_config.tmpdir
            )  # dir is created and never deleted
            logging.info(
                f"Created temporary working directory (not auto-deleted): {tmpdir}"
            )
        else:
            tmpdir_manager = tempfile.TemporaryDirectory(
                prefix="mrds_", dir=global_config.tmpdir
            )
            tmpdir = tmpdir_manager.name
            logging.info(
                f"Created temporary working directory (auto-deleted): {tmpdir}"
            )
        # override tmpdir with newly created tmpdir
        global_config.tmpdir = tmpdir
        client = objectstore.get_client()
        # Handle workflow_context generation if required
        if generate_workflow_context:
            logging.info("Generating workflow context automatically.")
            workflow_context = initialize_workflow(global_config)
            logging.info(f"Generated workflow context: {workflow_context}")
        else:
            logging.info(f"Using provided workflow context: {workflow_context}")
        download_source_file(client, global_config)
        # BUG FIX: the boolean result was previously ignored, so a failed or
        # multi-member ZIP silently continued into validation/processing.
        if not unzip_source_file_if_needed(global_config):
            raise RuntimeError(
                f"Unzipping source file '{global_config.source_filepath}' failed"
            )
        validate_source_file(global_config)
        process_tasks(tasks, global_config, workflow_context, client)
        if generate_workflow_context:
            finalize_workflow(workflow_context)
        if not keep_source_file:
            archive_source_file(client, global_config)
            delete_source_file(client, global_config)
    except Exception as e:
        logging.error(f"Critical error: {str(e)}")
        # Finalize workflow with failure if we generated the context ourselves
        if generate_workflow_context and "workflow_context" in locals():
            finalize_workflow(workflow_context, success=False)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Workflow failed due to: {e}") from e
    finally:
        # Always attempt to remove tmpdir if we created a TemporaryDirectory
        # manager (keep_tmp_dir runs use mkdtemp and are intentionally kept).
        if tmpdir_manager and not keep_tmp_dir:
            try:
                tmpdir_manager.cleanup()
                logging.info(f"Deleted temporary working directory {tmpdir}")
            except Exception:
                logging.exception(
                    f"Failed to delete temporary working directory {tmpdir}"
                )

View File

@@ -0,0 +1,186 @@
# static configs
tmpdir: /tmp
inbox_prefix: INBOX/RQSD/RQSD_PROCESS
workflow_name: w_ODS_RQSD_PROCESS_DEVO
# NOTE(review): core's CONFIG_REQUIRED_KEYS also requires 'archive_prefix',
# which is missing from this file — loading it will fail validation until
# the key is added.
# NOTE(review): the bare word below parses as the STRING "None", not YAML
# null; use `validation_schema_path: null` if "no schema" is intended.
validation_schema_path: None
file_type: csv
# task configs
tasks:
- task_name: m_ODS_RQSD_OBSERVATIONS_PARSE
ods_prefix: INBOX/RQSD/RQSD_PROCESS/RQSD_OBSERVATIONS
output_table: RQSD_OBSERVATIONS
output_columns:
- type: 'workflow_key'
column_header: 'A_WORKFLOW_HISTORY_KEY'
- type: 'csv_header'
value: 'datacollectioncode'
column_header: 'datacollectioncode'
- type: 'csv_header'
value: 'datacollectionname'
column_header: 'datacollectionname'
- type: 'csv_header'
value: 'datacollectionowner'
column_header: 'datacollectionowner'
- type: 'csv_header'
value: 'reportingcyclename'
column_header: 'reportingcyclename'
- type: 'csv_header'
value: 'reportingcyclestatus'
column_header: 'reportingcyclestatus'
- type: 'csv_header'
value: 'modulecode'
column_header: 'modulecode'
- type: 'csv_header'
value: 'modulename'
column_header: 'modulename'
- type: 'csv_header'
value: 'moduleversionnumber'
column_header: 'moduleversionnumber'
- type: 'csv_header'
value: 'reportingentitycollectionuniqueid'
column_header: 'reportingentitycollectionuniqueid'
- type: 'csv_header'
value: 'entityattributereportingcode'
column_header: 'entityattributereportingcode'
- type: 'csv_header'
value: 'reportingentityname'
column_header: 'reportingentityname'
- type: 'csv_header'
value: 'reportingentityentitytype'
column_header: 'reportingentityentitytype'
- type: 'csv_header'
value: 'entityattributecountry'
column_header: 'entityattributecountry'
- type: 'csv_header'
value: 'entitygroupentityname'
column_header: 'entitygroupentityname'
- type: 'csv_header'
value: 'obligationmodulereferencedate'
column_header: 'obligationmodulereferencedate'
- type: 'csv_header'
value: 'obligationmoduleremittancedate'
column_header: 'obligationmoduleremittancedate'
- type: 'csv_header'
value: 'receivedfilereceiveddate'
column_header: 'receivedfilereceiveddate'
- type: 'csv_header'
value: 'obligationmoduleexpected'
column_header: 'obligationmoduleexpected'
- type: 'csv_header'
value: 'receivedfileversionnumber'
column_header: 'receivedfileversionnumber'
- type: 'csv_header'
value: 'revalidationversionnumber'
column_header: 'revalidationversionnumber'
- type: 'csv_header'
value: 'revalidationdate'
column_header: 'revalidationdate'
- type: 'csv_header'
value: 'receivedfilesystemfilename'
column_header: 'receivedfilesystemfilename'
- type: 'csv_header'
value: 'obligationstatusstatus'
column_header: 'obligationstatusstatus'
- type: 'csv_header'
value: 'filestatussetsubmissionstatus'
column_header: 'filestatussetsubmissionstatus'
- type: 'csv_header'
value: 'filestatussetvalidationstatus'
column_header: 'filestatussetvalidationstatus'
- type: 'csv_header'
value: 'filestatussetexternalvalidationstatus'
column_header: 'filestatussetexternalvalidationstatus'
- type: 'csv_header'
value: 'numberoferrors'
column_header: 'numberoferrors'
- type: 'csv_header'
value: 'numberofwarnings'
column_header: 'numberofwarnings'
- type: 'csv_header'
value: 'delayindays'
column_header: 'delayindays'
- type: 'csv_header'
value: 'failedattempts'
column_header: 'failedattempts'
- type: 'csv_header'
value: 'observationvalue'
column_header: 'observationvalue'
- type: 'csv_header'
value: 'observationtextvalue'
column_header: 'observationtextvalue'
- type: 'csv_header'
value: 'observationdatevalue'
column_header: 'observationdatevalue'
- type: 'csv_header'
value: 'datapointsetdatapointidentifier'
column_header: 'datapointsetdatapointidentifier'
- type: 'csv_header'
value: 'datapointsetlabel'
column_header: 'datapointsetlabel'
- type: 'csv_header'
value: 'obsrvdescdatatype'
column_header: 'obsrvdescdatatype'
- type: 'csv_header'
value: 'ordinatecode'
column_header: 'ordinatecode'
- type: 'csv_header'
value: 'ordinateposition'
column_header: 'ordinateposition'
- type: 'csv_header'
value: 'tablename'
column_header: 'tablename'
- type: 'csv_header'
value: 'isstock'
column_header: 'isstock'
- type: 'csv_header'
value: 'scale'
column_header: 'scale'
- type: 'csv_header'
value: 'currency'
column_header: 'currency'
- type: 'csv_header'
value: 'numbertype'
column_header: 'numbertype'
- type: 'csv_header'
value: 'ismandatory'
column_header: 'ismandatory'
- type: 'csv_header'
value: 'decimalplaces'
column_header: 'decimalplaces'
- type: 'csv_header'
value: 'serieskey'
column_header: 'serieskey'
- type: 'csv_header'
value: 'tec_source_system'
column_header: 'tec_source_system'
- type: 'csv_header'
value: 'tec_dataset'
column_header: 'tec_dataset'
- type: 'csv_header'
value: 'tec_surrogate_key'
column_header: 'tec_surrogate_key'
- type: 'csv_header'
value: 'tec_crc'
column_header: 'tec_crc'
- type: 'csv_header'
value: 'tec_ingestion_date'
column_header: 'tec_ingestion_date'
- type: 'csv_header'
value: 'tec_version_id'
column_header: 'tec_version_id'
- type: 'csv_header'
value: 'tec_execution_date'
column_header: 'tec_execution_date'
- type: 'csv_header'
value: 'tec_run_id'
column_header: 'tec_run_id'
# NOTE(review): the 'static' entry below ('test test' / 'BLABLA') and the
# 'Test!' suffix on the tec_business_date column header look like leftover
# test data — confirm they are intentional before production use.
- type: 'static'
value: 'test test'
column_header: 'BLABLA'
- type: 'a_key'
column_header: 'A_KEY'
- type: 'csv_header'
value: 'tec_business_date'
column_header: 'tec_business_dateTest!'

View File

@@ -0,0 +1,50 @@
# file uploader
import os
import sys
import logging
from mrds.utils import objectstore
# Target bucket settings (same defaults as the core module).
BUCKET = os.getenv("INBOX_BUCKET", "mrds_inbox_poc")
BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE", "frcnomajoc7v")
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
    ],
)
# Hard-coded upload source/target. NOTE(review): this is an ad-hoc helper
# script; these values look environment-specific and should probably become
# CLI parameters or environment variables.
source_filepath = '/home/dbt/tmp/mrds_4twsw_ib/20250630_Pre-Production_DV_P2_DBT_I4.zip'
source_filename = '20250630_Pre-Production_DV_P2_DBT_I4.zip'
target_prefix = 'INBOX/CSDB/STC_CentralizedSecuritiesDissemination_ECB'
def upload_file():
    """Upload the module-level configured file to the object-store inbox."""
    client = objectstore.get_client()
    remote_path = f"{BUCKET}/{target_prefix}/{source_filename}"
    logging.info(f"uploading source file to '{remote_path}'")
    objectstore.upload_file(
        client,
        source_filepath,
        BUCKET_NAMESPACE,
        BUCKET,
        target_prefix,
        source_filename,
    )
    logging.info(f"Source file uploaded to '{remote_path}'")
if __name__ == "__main__":
try:
upload_file()
sys.exit(0)
except Exception as e:
logging.error(f"Unexpected error: {e}")
sys.exit(1)

View File

@@ -0,0 +1,15 @@
from .xml_processor import XMLTaskProcessor
from .csv_processor import CSVTaskProcessor
def get_file_processor(global_config):
    """Pick the task-processor class matching the configured file type.

    Returns XMLTaskProcessor for 'xml' and CSVTaskProcessor for 'csv'
    (case-insensitive); raises ValueError for anything else.
    """
    file_type = global_config.file_type.lower()
    if file_type not in ("xml", "csv"):
        raise ValueError(f"Unsupported file type: {file_type}")
    return XMLTaskProcessor if file_type == "xml" else CSVTaskProcessor

View File

@@ -0,0 +1,211 @@
import logging
import os
import csv
from abc import ABC, abstractmethod
from mrds.utils.utils import parse_output_columns
from mrds.utils import (
manage_files,
manage_runs,
objectstore,
static_vars,
)
# Per-task output object name: "<OUTPUT_TABLE>-<task_history_key>.csv".
OUTPUT_FILENAME_TEMPLATE = "{output_table}-{task_history_key}.csv"
STATUS_SUCCESS = static_vars.status_success  # TODO: duplicated in core; move to a shared module
class TaskProcessor(ABC):
    """Template-method base class for one configured task run.

    process() drives the fixed pipeline:
    _extract (abstract, format-specific) -> _enrich -> _upload ->
    _process_remote -> _finalize.
    """

    def __init__(self, global_config, task_conf, client, workflow_context):
        # global_config: run-wide GlobalConfig; task_conf: this task's
        # TaskConfig; client: object-store client; workflow_context: dict
        # with 'run_id' and 'a_workflow_history_key'.
        self.global_config = global_config
        self.task_conf = task_conf
        self.client = client
        self.workflow_context = workflow_context
        self._init_common()
        self._post_init()

    def _init_common(self):
        """Register the task run, derive output paths and the column plan."""
        # Initialize task in the run history; the returned key is also
        # embedded in the output filename below.
        self.a_task_history_key = manage_runs.init_task(
            self.task_conf.task_name,
            self.workflow_context["run_id"],
            self.workflow_context["a_workflow_history_key"],
        )
        logging.info(f"Task initialized with history key: {self.a_task_history_key}")
        # Define output file paths
        self.output_filename = OUTPUT_FILENAME_TEMPLATE.format(
            output_table=self.task_conf.output_table,
            task_history_key=self.a_task_history_key,
        )
        self.output_filepath = os.path.join(
            self.global_config.tmpdir, self.output_filename
        )
        # Parse the output_columns config into per-kind entry lists plus the
        # final output column order.
        (
            self.xpath_entries,
            self.csv_entries,
            self.static_entries,
            self.a_key_entries,
            self.workflow_key_entries,
            self.xml_position_entries,
            self.column_order,
        ) = parse_output_columns(self.task_conf.output_columns)

    def _post_init(self):
        """Optional hook for classes to override"""
        pass

    @abstractmethod
    def _extract(self):
        """Non-optional hook for classes to override"""
        pass

    def _enrich(self):
        """
        Stream-based enrich: read one row at a time, append static/A-key/workflow-key,
        reorder columns, and write out immediately.
        """
        # NOTE: also defined module-level in csv_utils — keep in sync.
        TASK_HISTORY_MULTIPLIER = 1_000_000_000
        logging.info(f"Enriching CSV file at '{self.output_filepath}'")
        temp_output = self.output_filepath + ".tmp"
        encoding = self.global_config.encoding_type
        with open(self.output_filepath, newline="", encoding=encoding) as inf, open(
            temp_output, newline="", encoding=encoding, mode="w"
        ) as outf:
            reader = csv.reader(inf)
            writer = csv.writer(outf, quoting=csv.QUOTE_ALL)
            # Read the original header
            original_headers = next(reader)
            # Compute the full set of headers
            headers = list(original_headers)
            # Add static column headers if missing
            for col_name, _ in self.static_entries:
                if col_name not in headers:
                    headers.append(col_name)
            # Add A-key column headers if missing
            for col_name in self.a_key_entries:
                if col_name not in headers:
                    headers.append(col_name)
            # Add workflow key column headers if missing
            for col_name in self.workflow_key_entries:
                if col_name not in headers:
                    headers.append(col_name)
            # Rearrange headers to the desired order; headers absent from
            # column_order are dropped from the output.
            header_to_index = {h: i for i, h in enumerate(headers)}
            out_indices = [
                header_to_index[h] for h in self.column_order if h in header_to_index
            ]
            out_headers = [headers[i] for i in out_indices]
            # Write the new header
            writer.writerow(out_headers)
            # Stream each row, enrich in-place, reorder, and write
            row_count = 0
            # Hoisted: base of the per-row surrogate key.
            base_task_history = int(self.a_task_history_key) * TASK_HISTORY_MULTIPLIER
            for i, in_row in enumerate(reader, start=1):
                # Build a working list that matches `headers` order.
                # Start by copying the existing columns (slots stay None
                # until filled below).
                work_row = [None] * len(headers)
                for j, h in enumerate(original_headers):
                    idx = header_to_index[h]
                    work_row[idx] = in_row[j]
                # Fill static columns
                for col_name, value in self.static_entries:
                    idx = header_to_index[col_name]
                    work_row[idx] = value
                # Fill A-key columns: task_history_key * 1e9 + 1-based row number
                for col_name in self.a_key_entries:
                    idx = header_to_index[col_name]
                    a_key_value = base_task_history + i
                    work_row[idx] = str(a_key_value)
                # Fill workflow key columns
                wf_val = self.workflow_context["a_workflow_history_key"]
                for col_name in self.workflow_key_entries:
                    idx = header_to_index[col_name]
                    work_row[idx] = wf_val
                # Reorder to output order and write
                out_row = [work_row[j] for j in out_indices]
                writer.writerow(out_row)
                row_count += 1
        # Atomically replace
        os.replace(temp_output, self.output_filepath)
        logging.info(
            f"CSV file enriched at '{self.output_filepath}', {row_count} rows generated"
        )

    def _upload(self):
        """Upload the enriched CSV to the task's ODS prefix."""
        # Upload CSV to object store
        logging.info(
            f"Uploading CSV file to '{self.global_config.bucket}/{self.task_conf.ods_prefix}/{self.output_filename}'"
        )
        objectstore.upload_file(
            self.client,
            self.output_filepath,
            self.global_config.bucket_namespace,
            self.global_config.bucket,
            self.task_conf.ods_prefix,
            self.output_filename,
        )
        logging.info(
            f"CSV file uploaded to '{self.global_config.bucket}/{self.task_conf.ods_prefix}/{self.output_filename}'"
        )

    def _process_remote(self):
        """Trigger DB-side processing of the uploaded file; delete it on failure."""
        # Process the source file
        logging.info(f"Processing source file '{self.output_filename}' with CT_MRDS.FILE_MANAGER.PROCESS_SOURCE_FILE database function.")
        try:
            manage_files.process_source_file(
                self.task_conf.ods_prefix, self.output_filename
            )
        except Exception as e:
            # Cleanup: remove the uploaded object so a failed run does not
            # leave a half-processed file behind, then re-raise the cause.
            logging.error(
                f"Processing source file '{self.output_filename}' failed. Cleaning up..."
            )
            objectstore.delete_file(
                self.client,
                self.output_filename,
                self.global_config.bucket_namespace,
                self.global_config.bucket,
                self.task_conf.ods_prefix,
            )
            logging.error(
                f"CSV file '{self.global_config.bucket}/{self.task_conf.ods_prefix}/{self.output_filename}' deleted."
            )
            raise
        else:
            logging.info(f"Source file '{self.output_filename}' processed")

    def _finalize(self):
        """Record the task as successful in the run history."""
        # Finalize task
        manage_runs.finalise_task(self.a_task_history_key, STATUS_SUCCESS)
        logging.info(f"Task '{self.task_conf.task_name}' completed successfully")

    def process(self):
        # main processor function: fixed pipeline order
        self._extract()
        self._enrich()
        self._upload()
        self._process_remote()
        self._finalize()

View File

@@ -0,0 +1,52 @@
import logging
import csv
import os
from .base import TaskProcessor
class CSVTaskProcessor(TaskProcessor):
    """Task processor for CSV sources.

    _extract keeps only the columns listed in csv_entries, renames them,
    and streams the result row by row into the task output file.
    """

    def _extract(self):
        """Project the configured columns out of the source CSV into
        self.output_filepath (written via a temp file, replaced atomically)."""
        input_path = self.global_config.source_filepath
        output_path = self.output_filepath
        encoding = self.global_config.encoding_type
        logging.info(f"Reading source CSV file at '{input_path}'")
        temp_output = output_path + ".tmp"
        with open(input_path, newline="", encoding=encoding) as inf, open(
            temp_output, newline="", encoding=encoding, mode="w"
        ) as outf:
            reader = csv.reader(inf)
            writer = csv.writer(outf, quoting=csv.QUOTE_ALL)
            headers = next(reader)
            # csv_entries holds (output_name, source_header) pairs.
            headers_to_keep = [old for _, old in self.csv_entries]
            # Every requested source header must exist in the input file.
            missing = [h for h in headers_to_keep if h not in headers]
            if missing:
                raise ValueError(
                    f"The following headers are not in the input CSV: {missing}"
                )
            indices = [headers.index(old) for old in headers_to_keep]
            # Header row carries the renamed (output) names.
            writer.writerow([new for new, _ in self.csv_entries])
            # Stream every data row, keeping only the selected columns.
            for row in reader:
                writer.writerow([row[i] for i in indices])
        # Atomically replace the old file
        os.replace(temp_output, output_path)
        logging.info(f"Core data written to CSV file at '{output_path}'")

View File

@@ -0,0 +1,30 @@
import logging
from .base import TaskProcessor
from mrds.utils import (
xml_utils,
csv_utils,
)
class XMLTaskProcessor(TaskProcessor):
    """Task processor for XML sources: extracts configured XPath/position
    entries and writes them to the task's CSV output file."""

    def _extract(self):
        # Extract data from XML using the parsed column plan and the task's
        # namespace map.
        csv_data = xml_utils.extract_data(
            self.global_config.source_filepath,
            self.xpath_entries,
            self.xml_position_entries,
            self.task_conf.namespaces,
            self.workflow_context,
            self.global_config.encoding_type,
        )
        logging.info(f"CSV data extracted for task '{self.task_conf.task_name}'")
        # Generate CSV
        logging.info(f"Writing core data to CSV file at '{self.output_filepath}'")
        csv_utils.write_data_to_csv_file(
            self.output_filepath, csv_data, self.global_config.encoding_type
        )
        logging.info(f"Core data written to CSV file at '{self.output_filepath}'")

View File

@@ -0,0 +1,69 @@
import csv
import os
# Surrogate-key base: a_key = task_history_key * 1e9 + 1-based row number.
TASK_HISTORY_MULTIPLIER = 1_000_000_000
def read_csv_file(csv_filepath, encoding_type="utf-8"):
    """Read a whole CSV file into memory.

    Returns a (headers, data_rows) tuple of lists. An empty file yields
    ([], []) instead of raising IndexError (previous behavior), and rows
    are consumed from the reader directly instead of materializing the
    file twice.
    """
    with open(csv_filepath, "r", newline="", encoding=encoding_type) as csvfile:
        reader = csv.reader(csvfile)
        # next(..., []) guards against a completely empty file.
        headers = next(reader, [])
        data_rows = list(reader)
    return headers, data_rows
def write_data_to_csv_file(csv_filepath, data, encoding_type="utf-8"):
    """Write {'headers': [...], 'rows': [...]} to csv_filepath.

    Everything is quoted (QUOTE_ALL) and the file is written to a '.tmp'
    sibling first, then moved into place so readers never observe a
    half-written file.
    """
    tmp_path = csv_filepath + ".tmp"
    with open(tmp_path, "w", newline="", encoding=encoding_type) as fh:
        out = csv.writer(fh, quoting=csv.QUOTE_ALL)
        # Header first, then all data rows.
        out.writerows([data["headers"]] + list(data["rows"]))
    os.replace(tmp_path, csv_filepath)
def add_static_columns(data_rows, headers, static_entries):
    """For each (column_header, value) pair, fill the column with the
    constant value in every row, creating the column when absent.

    Mutates headers and data_rows in place.
    """
    for column_header, value in static_entries:
        try:
            idx = headers.index(column_header)
        except ValueError:
            # New column: extend the header and every row.
            headers.append(column_header)
            for row in data_rows:
                row.append(value)
        else:
            # Existing column: overwrite in place.
            for row in data_rows:
                row[idx] = value
def add_a_key_columns(data_rows, headers, a_key_entries, task_history_key):
    """Fill surrogate-key columns with task_history_key * 1e9 + row number.

    Row numbers are 1-based; values are written as strings. Columns are
    created when absent, overwritten otherwise. Mutates headers and
    data_rows in place.
    """
    # Hoist the loop-invariant base out of the per-row loops (previously
    # recomputed for every row).
    base = int(task_history_key) * TASK_HISTORY_MULTIPLIER
    for column_header in a_key_entries:
        if column_header not in headers:
            headers.append(column_header)
            for i, row in enumerate(data_rows, start=1):
                row.append(str(base + i))
        else:
            idx = headers.index(column_header)
            for i, row in enumerate(data_rows, start=1):
                row[idx] = str(base + i)
def add_workflow_key_columns(data_rows, headers, workflow_key_entries, workflow_key):
    """Fill each listed column with the workflow history key, creating the
    column when absent. Mutates headers and data_rows in place."""
    for column_header in workflow_key_entries:
        if column_header in headers:
            # Existing column: overwrite in place.
            idx = headers.index(column_header)
            for row in data_rows:
                row[idx] = workflow_key
        else:
            # New column: extend the header and every row.
            headers.append(column_header)
            for row in data_rows:
                row.append(workflow_key)
def rearrange_columns(headers, data_rows, column_order):
    """Return (headers, data_rows) restricted and reordered to column_order.

    Names in column_order that are not present in headers are silently
    skipped; columns not mentioned in column_order are dropped. The inputs
    are not mutated.
    """
    positions = {name: i for i, name in enumerate(headers)}
    keep = [positions[name] for name in column_order if name in positions]
    new_headers = [headers[i] for i in keep]
    new_rows = [[row[i] for i in keep] for row in data_rows]
    return new_headers, new_rows

View File

@@ -0,0 +1,177 @@
from . import oraconn
from . import sql_statements
from . import utils
# Get the next load id from the sequence
#
# Workflows
#
def process_source_file_from_event(resource_id: str):
    """Handle an object-store event by processing the referenced object.

    expects object uri in the form /n/<namespace>/b/<bucket>/o/<object>
    eg /n/frcnomajoc7v/b/dmarsdb1/o/sqlnet.log
    and calls process_source_file with prefix and file_name extracted
    from that uri (namespace and bucket are ignored).
    """
    _, _, prefix, file_name = utils.parse_uri_with_regex(resource_id)
    process_source_file(prefix, file_name)
def process_source_file(prefix: str, filename: str):
    """Run CT_MRDS.FILE_MANAGER.PROCESS_SOURCE_FILE for one uploaded object.

    The DB procedure receives '<prefix>/<filename>'.
    """
    # BUG FIX: the path previously did not use the `filename` parameter.
    # rstrip to cater for cases where the prefix is passed with a trailing slash
    sourcefile = f"{prefix.rstrip('/')}/{filename}"
    # Connect BEFORE the try block: if connect() raises, there is no `conn`
    # to close, and the old code hit UnboundLocalError in `finally`.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        oraconn.run_proc(conn, "CT_MRDS.FILE_MANAGER.PROCESS_SOURCE_FILE", [sourcefile])
        conn.commit()
    finally:
        conn.close()
def execute_query(query, query_parameters=None, account_alias="MRDS_LOADER"):
    """Execute a query and return the first column of every fetched row.

    query_parameters, when given, are passed as bind parameters. The
    connection is committed after the fetch and always closed.
    """
    # Connect BEFORE the try block: a failed connect previously raised
    # UnboundLocalError from `conn.close()` in the finally clause.
    conn = oraconn.connect(account_alias)
    try:
        curs = conn.cursor()
        try:
            if query_parameters is not None:  # `is not None`, not `!= None`
                curs.execute(query, query_parameters)
            else:
                curs.execute(query)
            query_result = curs.fetchall()
            conn.commit()
        finally:
            # Release the cursor explicitly (previously leaked).
            curs.close()
    finally:
        conn.close()
    return [t[0] for t in query_result]
def get_file_prefix(source_key, source_file_id, table_id):
    """Look up the configured file prefix for a source file / table pair.

    Returns the first column of the single fetched row.
    NOTE(review): if no row matches, fetchone() returns None and the final
    subscript raises TypeError — confirm whether a miss is possible here.
    """
    # Connect BEFORE the try block so a failed connect cannot trigger an
    # UnboundLocalError on conn.close() in the finally clause.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        curs = conn.cursor()
        try:
            curs.execute(
                sql_statements.get_sql("get_file_prefix"),
                [source_key, source_file_id, table_id],
            )
            query_result = curs.fetchone()
            conn.commit()
        finally:
            curs.close()
    finally:
        conn.close()
    return query_result[0]
def get_inbox_bucket():
    """Return the inbox bucket name from CT_MRDS.FILE_MANAGER.GET_INBOX_BUCKET."""
    # Connect before try: avoids UnboundLocalError on conn.close() when
    # connect() itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_func(conn, "CT_MRDS.FILE_MANAGER.GET_INBOX_BUCKET", str, [])
        conn.commit()
    finally:
        conn.close()
    return ret
def get_data_bucket():
    """Return the data bucket name from CT_MRDS.FILE_MANAGER.GET_DATA_BUCKET."""
    # Connect before try: avoids UnboundLocalError on conn.close() when
    # connect() itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_func(conn, "CT_MRDS.FILE_MANAGER.GET_DATA_BUCKET", str, [])
        conn.commit()
    finally:
        conn.close()
    return ret
def add_source_file_config(
    source_key,
    source_file_type,
    source_file_id,
    source_file_desc,
    source_file_name_pattern,
    table_id,
    template_table_name,
):
    """Register a source-file configuration via
    CT_MRDS.FILE_MANAGER.ADD_SOURCE_FILE_CONFIG and return the proc result.
    """
    # Connect before try: avoids UnboundLocalError on conn.close() when
    # connect() itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.FILE_MANAGER.ADD_SOURCE_FILE_CONFIG",
            [
                source_key,
                source_file_type,
                source_file_id,
                source_file_desc,
                source_file_name_pattern,
                table_id,
                template_table_name,
            ],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
def add_column_date_format(template_table_name, column_name, date_format):
    """Register a date format for a template-table column via
    CT_MRDS.FILE_MANAGER.ADD_column_date_format and return the proc result.
    """
    # Connect before try: avoids UnboundLocalError on conn.close() when
    # connect() itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        ret = oraconn.run_proc(
            conn,
            "CT_MRDS.FILE_MANAGER.ADD_column_date_format",
            [template_table_name, column_name, date_format],
        )
        conn.commit()
    finally:
        conn.close()
    return ret
def execute(stmt):
    """Execute an arbitrary statement as MRDS_LOADER and commit.

    NOTE(review): callers pass raw SQL — do not feed untrusted input here.
    """
    # Connect before try: avoids UnboundLocalError on conn.close() when
    # connect() itself fails.
    conn = oraconn.connect("MRDS_LOADER")
    try:
        curs = conn.cursor()
        try:
            curs.execute(stmt)
            conn.commit()
        finally:
            # Release the cursor explicitly (previously leaked).
            curs.close()
    finally:
        conn.close()
def create_external_table(table_name, template_table_name, prefix):
    """Create an external table over objects under `prefix` in the ODS bucket.

    Connects with the ODS_LOADER alias and delegates to
    CT_MRDS.FILE_MANAGER.CREATE_EXTERNAL_TABLE.
    """
    try:
        connection = oraconn.connect("ODS_LOADER")
        outcome = oraconn.run_proc(
            connection,
            "CT_MRDS.FILE_MANAGER.CREATE_EXTERNAL_TABLE",
            [table_name, template_table_name, prefix, get_bucket("ODS")],
        )
        connection.commit()
    finally:
        connection.close()
    return outcome
def get_bucket(bucket):
    """Return the URI of the named bucket via CT_MRDS.FILE_MANAGER.GET_BUCKET_URI."""
    try:
        connection = oraconn.connect("MRDS_LOADER")
        bucket_uri = oraconn.run_func(
            connection, "CT_MRDS.FILE_MANAGER.GET_BUCKET_URI", str, [bucket]
        )
        connection.commit()
    finally:
        connection.close()
    return bucket_uri

View File

@@ -0,0 +1,97 @@
from . import oraconn
from . import sql_statements
from . import static_vars
from . import manage_files
def init_workflow(database_name: str, workflow_name: str, workflow_run_id: str):
    """Register a workflow run in the audit tables and return its history key."""
    try:
        connection = oraconn.connect("MRDS_LOADER")
        history_key = oraconn.run_func(
            connection,
            "CT_MRDS.WORKFLOW_MANAGER.INIT_WORKFLOW",
            int,
            [database_name, workflow_run_id, workflow_name],
        )
        connection.commit()
    finally:
        connection.close()
    return history_key
def finalise_workflow(a_workflow_history_key: int, workflow_status: str):
    """Mark the workflow run identified by its history key as finished."""
    try:
        connection = oraconn.connect("MRDS_LOADER")
        oraconn.run_proc(
            connection,
            "CT_MRDS.WORKFLOW_MANAGER.FINALISE_WORKFLOW",
            [a_workflow_history_key, workflow_status],
        )
        connection.commit()
    finally:
        connection.close()
def init_task(task_name: str, task_run_id: str, a_workflow_history_key: int):
    """Register a task run under the given workflow and return its history key."""
    task_history_key: int
    try:
        connection = oraconn.connect("MRDS_LOADER")
        task_history_key = oraconn.run_func(
            connection,
            "CT_MRDS.WORKFLOW_MANAGER.INIT_TASK",
            int,
            [task_run_id, task_name, a_workflow_history_key],
        )
        connection.commit()
    finally:
        connection.close()
    return task_history_key
def finalise_task(a_task_history_key: int, task_status: str):
    """Close a task history record with the given completion status."""
    try:
        connection = oraconn.connect("MRDS_LOADER")
        cursor = connection.cursor()
        # Binds are positional: [status, key] matches the UPDATE's placeholders.
        cursor.execute(
            sql_statements.get_sql("finalise_task"), [task_status, a_task_history_key]
        )
        connection.commit()
    finally:
        connection.close()
def set_workflow_property(
    wf_history_key: int, service_name: str, property: str, value: str
):
    """Attach a name/value property to a workflow history record.

    NOTE(review): the `property` parameter shadows the builtin of the same
    name; kept unchanged so keyword-argument callers keep working.
    """
    try:
        connection = oraconn.connect("MRDS_LOADER")
        outcome = oraconn.run_proc(
            connection,
            "CT_MRDS.WORKFLOW_MANAGER.SET_WORKFLOW_PROPERTY",
            [wf_history_key, service_name, property, value],
        )
        connection.commit()
    finally:
        connection.close()
    return outcome
def select_ods_tab(table_name: str, value: str, condition="1 = 1"):
    """Select the column expression `value` from `table_name` where `condition` holds.

    Delegates to manage_files.execute_query with the ODS_LOADER alias, so the
    first column of every matching row is returned.

    NOTE(review): the statement is assembled by string interpolation — the
    arguments must come from trusted code only (SQL-injection risk otherwise).
    """
    query = "select %s from %s where %s" % (value, table_name, condition)
    print("query = |%s|" % query)
    return manage_files.execute_query(query=query, account_alias="ODS_LOADER")

View File

@@ -0,0 +1,53 @@
import oci
def get_client():
    """Build an OCI ObjectStorageClient using ambient principal credentials.

    Resource Principals (OCI Container Instances) are tried first; on any
    failure the code falls back to Instance Principals (VMs).

    Returns:
        oci.object_storage.ObjectStorageClient: authenticated client.
    """
    try:
        signer = oci.auth.signers.get_resource_principals_signer()
    except Exception:
        # No Resource Principal available; assume we run on a VM.
        signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
    # The empty dict is the (unused) SDK config; the signer carries the auth.
    client = oci.object_storage.ObjectStorageClient({}, signer=signer)
    return client
def list_bucket(client, namespace, bucket, prefix):
    """Return the ListObjects payload for objects under `prefix` in the bucket.

    NOTE(review): only one list_objects call is made; very large folders may
    need pagination — confirm against expected object counts.
    """
    objects = client.list_objects(namespace, bucket, prefix=prefix)
    # see https://docs.oracle.com/en-us/iaas/tools/python/2.135.0/api/request_and_response.html#oci.response.Response for all attrs
    return objects.data
def upload_file(client, source_filename, namespace, bucket, prefix, target_filename):
    """Upload a local file to <prefix>/<target_filename> in the bucket."""
    object_name = f"{prefix.rstrip('/')}/{target_filename}"
    with open(source_filename, "rb") as payload:
        client.put_object(namespace, bucket, object_name, payload)
def clean_folder(client, namespace, bucket, prefix):
    """Delete every object found under `prefix` in the bucket.

    NOTE(review): a single list_objects call is used; very large folders may
    need pagination — confirm against expected folder sizes.
    """
    listing = client.list_objects(namespace, bucket, prefix=prefix)
    for obj in listing.data.objects:
        # obj.name is already the full object path (it includes the prefix);
        # the previous log line double-prepended the prefix.
        print(f"Deleting {obj.name}")
        client.delete_object(namespace, bucket, obj.name)
def delete_file(client, file, namespace, bucket, prefix):
    """Delete the single object <prefix>/<file> from the given bucket."""
    client.delete_object(namespace, bucket, f"{prefix.rstrip('/')}/{file}")
def download_file(client, namespace, bucket, prefix, source_filename, target_filename):
    """Stream an object from the bucket into a local file in 1 MiB chunks."""
    response = client.get_object(
        namespace, bucket, f"{prefix.rstrip('/')}/{source_filename}"
    )
    with open(target_filename, "wb") as out_file:
        for chunk in response.data.raw.stream(1024 * 1024, decode_content=False):
            out_file.write(chunk)

View File

@@ -0,0 +1,38 @@
import oracledb
import os
import traceback
import sys
def connect(alias):
    """Open an Oracle connection using credentials from <ALIAS>_DB_* env vars.

    Reads <alias>_DB_USER, <alias>_DB_PASS and <alias>_DB_TNS, initialises the
    Oracle client library and connects. Any failure is fatal: a message goes
    to stderr and the process exits with status 1.
    """
    username = os.getenv(alias + "_DB_USER")
    password = os.getenv(alias + "_DB_PASS")
    tnsalias = os.getenv(alias + "_DB_TNS")
    # Fail with a clear message instead of a TypeError when a variable is unset.
    if not (username and password and tnsalias):
        print(
            f"Missing one of {alias}_DB_USER/{alias}_DB_PASS/{alias}_DB_TNS "
            "environment variables",
            file=sys.stderr,
        )
        sys.exit(1)
    connstr = username + "/" + password + "@" + tnsalias
    # NOTE(review): init_oracle_client() may only be callable once per process
    # in some driver versions — confirm repeated connect() calls are safe.
    oracledb.init_oracle_client()
    try:
        conn = oracledb.connect(connstr)
        return conn
    except oracledb.DatabaseError as db_err:
        tb = traceback.format_exc()
        print(f"DatabaseError connecting to '{alias}': {db_err}\n{tb}", file=sys.stderr)
        sys.exit(1)
    except Exception as exc:
        tb = traceback.format_exc()
        print(f"Unexpected error connecting to '{alias}': {exc}\n{tb}", file=sys.stderr)
        sys.exit(1)
def run_proc(connection, proc: str, param: list):
    """Call stored procedure `proc` with positional arguments `param`.

    Args:
        connection: An open DB-API connection (caller commits/closes).
        proc: Fully qualified procedure name.
        param: Positional arguments for the call.

    Returns:
        None — the procedure's out-parameters are not propagated.
    """
    # Note: annotation was the mutable literal `[]`; `list` is the intended type.
    curs = connection.cursor()
    curs.callproc(proc, param)
def run_func(connection, proc: str, rettype, param: list):
    """Call stored function `proc` and return its result cast to `rettype`.

    Args:
        connection: An open DB-API connection (caller commits/closes).
        proc: Fully qualified function name.
        rettype: Python type the driver converts the return value to.
        param: Positional arguments for the call.

    Returns:
        The function's return value, converted to `rettype`.
    """
    # Note: annotation was the mutable literal `[]`; `list` is the intended type.
    curs = connection.cursor()
    ret = curs.callfunc(proc, rettype, param)
    return ret

View File

@@ -0,0 +1,46 @@
import oci
import ast
import base64
# Specify the OCID of the secret to retrieve
def get_secretcontents(ocid):
    """Fetch the raw secret bundle for `ocid` from OCI Vault.

    Authentication uses Resource Principals (OCI Container Instances) when
    available and falls back to Instance Principals (VMs) on any failure.

    Returns:
        The get_secret_bundle response object.
    """
    try:
        signer = oci.auth.signers.get_resource_principals_signer()
    except Exception:
        # No Resource Principal available; assume we run on a VM.
        signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
    # Create secret client and retrieve content; the empty dict is the
    # (unused) SDK config — the signer carries the auth.
    secretclient = oci.secrets.SecretsClient({}, signer=signer)
    secretcontents = secretclient.get_secret_bundle(secret_id=ocid)
    return secretcontents
def get_password(ocid):
    """Return the 'password' field stored in the vault secret `ocid`.

    The secret content is base64-encoded ASCII text holding a Python dict
    literal; it is decoded and parsed with ast.literal_eval.
    """
    secretcontents = get_secretcontents(ocid)
    encoded = secretcontents.data.secret_bundle_content.content
    decoded = base64.b64decode(encoded.encode("ascii")).decode("ascii")
    secret_dict = ast.literal_eval(decoded)
    return secret_dict["password"]
def get_secret(ocid):
    """Return the vault secret `ocid` decoded from base64 to UTF-8 text."""
    secretcontents = get_secretcontents(ocid)
    raw = base64.b64decode(secretcontents.data.secret_bundle_content.content)
    return raw.decode("UTF-8")

View File

@@ -0,0 +1,106 @@
import re
import logging
def verify_run_id(run_id, context=None):
    """
    Verify run_id for security compliance.

    Args:
        run_id (str): The run_id to verify
        context (dict, optional): Airflow context for logging

    Returns:
        str: Verified (stripped) run_id

    Raises:
        ValueError: If run_id is invalid or suspicious
    """
    try:
        # Basic type/emptiness checks
        if not run_id or not isinstance(run_id, str):
            raise ValueError(
                f"Invalid run_id: must be non-empty string, got: {type(run_id).__name__}"
            )
        run_id = run_id.strip()
        if len(run_id) < 1 or len(run_id) > 250:
            raise ValueError(
                f"Invalid run_id: length must be 1-250 chars, got: {len(run_id)}"
            )
        # Allow only safe characters (letters, digits, and _-:+.T as used by
        # Airflow timestamp-style run ids)
        if not re.match(r"^[a-zA-Z0-9_\-:+.T]+$", run_id):
            suspicious_chars = "".join(
                set(
                    char for char in run_id if not re.match(r"[a-zA-Z0-9_\-:+.T]", char)
                )
            )
            logging.warning(f"SECURITY: Invalid chars in run_id: '{suspicious_chars}'")
            raise ValueError("Invalid run_id: contains unsafe characters")
        # Defence in depth: reject known attack patterns (path traversal, XSS,
        # SQL, code execution, control characters)
        dangerous_patterns = [
            r"\.\./",
            r"\.\.\\",
            r"<script",
            r"javascript:",
            r"union\s+select",
            r"drop\s+table",
            r"insert\s+into",
            r"delete\s+from",
            r"exec\s*\(",
            r"system\s*\(",
            r"eval\s*\(",
            r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]",
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, run_id, re.IGNORECASE):
                logging.error(f"SECURITY: Dangerous pattern in run_id: '{run_id}'")
                raise ValueError("Invalid run_id: contains dangerous pattern")
        # Log success (include DAG id when an Airflow context is available)
        if context:
            dag_id = (
                getattr(context.get("dag"), "dag_id", "unknown")
                if context.get("dag")
                else "unknown"
            )
            logging.info(f"run_id verified: '{run_id}' for DAG: '{dag_id}'")
        return run_id
    except Exception as e:
        logging.error(
            f"SECURITY: run_id verification failed: '{run_id}', Error: {str(e)}"
        )
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"run_id verification failed: {str(e)}") from e
def get_verified_run_id(context):
    """
    Extract run_id from an Airflow context and return it after verification.

    Args:
        context (dict): Airflow context

    Returns:
        str: Verified run_id
    """
    try:
        candidate = None
        # Prefer the TaskInstance's run_id; fall back to the top-level key.
        if context and "ti" in context:
            candidate = context["ti"].run_id
        elif context and "run_id" in context:
            candidate = context["run_id"]
        if not candidate:
            raise ValueError("Could not extract run_id from context")
        return verify_run_id(candidate, context)
    except Exception as e:
        logging.error(f"Failed to get verified run_id: {str(e)}")
        raise

View File

@@ -0,0 +1,68 @@
# Registry of named SQL statements used by the workflow/task audit helpers.
sql_statements = {}
#
# Workflows
#
# register_workflow: Register new DW load
# NOTE(review): the insert column was spelled WORKFLOW_SSUCCESSFUL; the
# finalise_workflow statement below updates WORKFLOW_SUCCESSFUL on the same
# table, so the spelling is aligned here — confirm against the table DDL.
sql_statements[
    "register_workflow"
] = """INSERT INTO CT_MRDS.A_WORKFLOW_HISTORY
(A_WORKFLOW_HISTORY_KEY, WORKFLOW_RUN_ID,
WORKFLOW_NAME, WORKFLOW_START, WORKFLOW_SUCCESSFUL)
VALUES (:a_workflow_history_key, :workflow_run_id, :workflow_name, SYSTIMESTAMP, :running_status)
"""
# get_a_workflow_history_key: get new key from sequence
sql_statements["get_a_workflow_history_key"] = (
    "SELECT CT_MRDS.A_WORKFLOW_HISTORY_KEY_SEQ.NEXTVAL FROM DUAL"
)
# finalise_workflow: close the workflow record after completion
sql_statements[
    "finalise_workflow"
] = """UPDATE CT_MRDS.A_WORKFLOW_HISTORY
SET WORKFLOW_END = SYSTIMESTAMP, WORKFLOW_SUCCESSFUL = :workflow_status
WHERE A_WORKFLOW_HISTORY_KEY = :a_workflow_history_key
"""
#
# Tasks
#
# register_task: insert a task history row
# NOTE(review): the VALUES clause previously listed five workflow bind names
# for six task columns (a copy/paste remnant); the binds are aligned with the
# task column list here.
sql_statements[
    "register_task"
] = """INSERT INTO CT_MRDS.A_TASK_HISTORY (A_TASK_HISTORY_KEY,
A_WORKFLOW_HISTORY_KEY, TASK_RUN_ID,
TASK_NAME, TASK_START, TASK_SUCCESSFUL)
VALUES (:a_task_history_key, :a_workflow_history_key, :task_run_id, :task_name, SYSTIMESTAMP, :running_status)
"""
# get_a_task_history_key: get new key from sequence
sql_statements["get_a_task_history_key"] = (
    "SELECT CT_MRDS.A_TASK_HISTORY_KEY_SEQ.NEXTVAL FROM DUAL"
)
# finalise_task: close the task record after completion (binds are used
# positionally by the caller: [status, key])
sql_statements[
    "finalise_task"
] = """UPDATE CT_MRDS.A_TASK_HISTORY
SET TASK_END = SYSTIMESTAMP, TASK_SUCCESSFUL = :workflow_status
WHERE A_TASK_HISTORY_KEY = :a_workflow_history_key
"""
#
# Files
#
# get_file_prefix: bucket path for a (source, file, table) triple
sql_statements["get_file_prefix"] = (
    "SELECT CT_MRDS.FILE_MANAGER.GET_BUCKET_PATH(:source_key, :source_file_id, :table_id) FROM DUAL"
)
def get_sql(stmt_id: str):
    """Return the SQL text registered under `stmt_id`, or None when unknown."""
    return sql_statements.get(stmt_id)

View File

@@ -0,0 +1,6 @@
#
# Task management variables
#
# Status markers presumably written to the *_SUCCESSFUL audit columns:
# RUNNING at start, then Y/N on completion — confirm against the schema.
status_running: str = "RUNNING"  # workflow/task still in progress
status_failed: str = "N"  # finished unsuccessfully
status_success: str = "Y"  # finished successfully

View File

@@ -0,0 +1,83 @@
import re
def parse_uri_with_regex(uri):
    """
    Split an Oracle Object Storage URI into its addressing components.

    Parameters:
        uri (str): URI of the form '/n/{namespace}/b/{bucketname}/o/{object_path}'

    Returns:
        tuple: (namespace, bucket_name, prefix, object_name); prefix is ''
        when the object sits at the bucket root, otherwise it ends with '/'.

    Raises:
        ValueError: if the URI does not match the expected layout.
    """
    matched = re.match(r"^/n/([^/]+)/b/([^/]+)/o/(.*)$", uri)
    if matched is None:
        raise ValueError("Invalid URI format")
    namespace, bucket_name, object_path = matched.groups()
    # Split the object path at its last '/' into prefix and object name.
    prefix, separator, object_name = object_path.rpartition("/")
    if separator:
        prefix += "/"
    else:
        # No '/' present: whole path is the object name, no prefix.
        prefix = ""
        object_name = object_path
    return namespace, bucket_name, prefix, object_name
def parse_output_columns(output_columns):
    """
    Partition output-column config entries by their 'type' field.

    Each entry is a dict with at least 'type' and 'column_header'; most types
    also carry a 'value'. Unknown types contribute only to column_order.

    Returns:
        tuple: (xpath_entries, csv_entries, static_entries, a_key_entries,
                workflow_key_entries, xml_position_entries, column_order)
    """
    xpath_entries, csv_entries, static_entries = [], [], []
    a_key_entries, workflow_key_entries, xml_position_entries = [], [], []
    column_order = []
    for entry in output_columns:
        kind = entry["type"]
        header = entry["column_header"]
        column_order.append(header)
        if kind == "xpath":
            xpath_entries.append((entry["value"], header, entry["is_key"]))
        elif kind == "csv_header":
            csv_entries.append((header, entry["value"]))
        elif kind == "static":
            static_entries.append((header, entry["value"]))
        elif kind == "a_key":
            a_key_entries.append(header)
        elif kind == "workflow_key":
            workflow_key_entries.append(header)
        elif kind == "xpath_element_id":
            # TODO - update all xml_position namings to xpath_element_id
            xml_position_entries.append((entry["value"], header))
    return (
        xpath_entries,
        csv_entries,
        static_entries,
        a_key_entries,
        workflow_key_entries,
        xml_position_entries,
        column_order,
    )

View File

@@ -0,0 +1,23 @@
import oci
import ast
import base64
# Specify the OCID of the secret to retrieve
def get_password(ocid):
    """Return the 'password' field of the vault secret `ocid`.

    Authenticates with an Instance Principal signer, fetches the secret bundle
    and decodes it from base64; the decoded text is a Python dict literal.
    """
    signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
    secrets_client = oci.secrets.SecretsClient({}, signer=signer)
    bundle = secrets_client.get_secret_bundle(secret_id=ocid)
    encoded = bundle.data.secret_bundle_content.content
    decoded = base64.b64decode(encoded.encode("ascii")).decode("ascii")
    return ast.literal_eval(decoded)["password"]

View File

@@ -0,0 +1,177 @@
import xmlschema
import hashlib
from lxml import etree
from typing import Dict, List
def validate_xml(xml_file, xsd_file):
    """Validate `xml_file` against the XSD schema in `xsd_file`.

    Returns:
        tuple(bool, str): (True, message) when valid, (False, reason) otherwise.
    """
    try:
        # Strict validation rejects schemas with unresolved components.
        schema = xmlschema.XMLSchema(xsd_file, validation="strict")
        schema.validate(xml_file)
    except xmlschema.validators.exceptions.XMLSchemaValidationError as e:
        return False, f"XML validation error: {str(e)}"
    except xmlschema.validators.exceptions.XMLSchemaException as e:
        return False, f"XML schema error: {str(e)}"
    except Exception as e:
        return False, f"An error occurred during XML validation: {str(e)}"
    return True, "XML file is valid against the provided XSD schema."
def extract_data(
    filename,
    xpath_columns,  # List[(expr, header, is_key)]
    xml_position_columns,  # List[(expr, header)]
    namespaces,
    workflow_context,
    encoding_type="utf-8",
):
    """
    Parses an XML file using XPath expressions and extracts tabular data.

    Parameters:
    - filename (str): The path to the XML file to parse.
    - xpath_columns (list): A list of tuples, each containing:
        - XPath expression (str)
        - CSV column header (str)
        - Indicator if the field is a key ('Y' or 'N')
    - xml_position_columns (list): tuples of (XPath expression, header); each
      yields a hashed identifier of the matched ancestor element's position,
      salted with workflow_context["a_workflow_history_key"].
    - namespaces (dict): Namespace mapping needed for lxml's xpath()
    - workflow_context (dict): must contain 'a_workflow_history_key'.
    - encoding_type (str): accepted for interface compatibility; not used in
      this function's body.

    Returns:
    - dict: {"headers": [...], "rows": [[...], ...]} with extracted data.
      Header order is nonkey columns, then key columns, then position columns.
    """
    parser = etree.XMLParser(remove_blank_text=True)
    tree = etree.parse(filename, parser)
    root = tree.getroot()
    # Separate out key vs nonkey columns
    key_cols = [ (expr, h) for expr, h, k in xpath_columns if k == "Y" ]
    nonkey_cols = [ (expr, h) for expr, h, k in xpath_columns if k == "N" ]
    # Evaluate every nonkey XPath and keep the ELEMENT nodes
    nonkey_elements = {}
    for expr, header in nonkey_cols:
        elems = root.xpath(expr, namespaces=namespaces)
        nonkey_elements[header] = elems
    # figure out how many rows total we need
    # that's the maximum length of any of the nonkey lists
    if nonkey_elements:
        row_count = max(len(lst) for lst in nonkey_elements.values())
    else:
        row_count = 0
    # pad every nonkey list up to row_count with `None`
    for header, lst in nonkey_elements.items():
        if len(lst) < row_count:
            lst.extend([None] * (row_count - len(lst)))
    # key columns: evaluated once; the first match (or "") is repeated per row
    key_values = []
    for expr, header in key_cols:
        nodes = root.xpath(expr, namespaces=namespaces)
        if not nodes:
            key_values.append("")
        else:
            first = nodes[0]
            # XPath may return elements or scalars (e.g. attribute strings)
            txt = (first.text if isinstance(first, etree._Element) else str(first)) or ""
            key_values.append(txt.strip())
    # xml_position columns: candidate target elements per expression
    xml_positions = {}
    for expr, header in xml_position_columns:
        xml_positions[header] = root.xpath(expr, namespaces=namespaces)
    # prepare headers
    headers = [h for _, h in nonkey_cols] + [h for _, h in key_cols] + [h for _, h in xml_position_columns]
    # build rows
    rows = []
    for i in range(row_count):
        row = []
        # nonkey data
        for expr, header in nonkey_cols:
            elem = nonkey_elements[header][i]
            text = ""
            if isinstance(elem, etree._Element):
                text = elem.text or ""
            elif elem is not None:
                text = str(elem)
            row.append(text.strip())
        # key columns
        row.extend(key_values)
        # xml_position columns: walk up from this row's FIRST nonkey element
        # until an ancestor that matched the position XPath is found
        for expr, header in xml_position_columns:
            if not nonkey_cols:
                row.append("")
                continue
            first_header = nonkey_cols[0][1]
            data_elem = nonkey_elements[first_header][i]
            if data_elem is None:
                row.append("")
                continue
            target_list = xml_positions[header]
            current = data_elem
            found = None
            while current is not None:
                if current in target_list:
                    found = current
                    break
                current = current.getparent()
            if not found:
                row.append("")
            else:
                # compute fullpath with indices
                path_elems = []
                walk = found
                while walk is not None:
                    # 1-based position of `walk` among same-tag siblings
                    idx = 1 + sum(1 for s in walk.itersiblings(preceding=True) if s.tag == walk.tag)
                    path_elems.append(f"{walk.tag}[{idx}]")
                    walk = walk.getparent()
                full_path = "/" + "/".join(reversed(path_elems))
                row.append(_xml_pos_hasher(full_path, workflow_context["a_workflow_history_key"]))
        rows.append(row)
    return {"headers": headers, "rows": rows}
def _xml_pos_hasher(input_string, salt, hash_length=15):
"""
Helps hashing xml positions.
Parameters:
input_string (str): The string to hash.
salt (int): The integer salt to ensure deterministic, run-specific behavior.
hash_length (int): The desired length of the resulting hash (default is 15 digits).
Returns:
int: A deterministic integer hash of the specified length.
"""
# Ensure the hash length is valid
if hash_length <= 0:
raise ValueError("Hash length must be a positive integer.")
# Combine the input string with the salt to create a deterministic input
salted_input = f"{salt}:{input_string}"
# Generate a SHA-256 hash of the salted input
hash_object = hashlib.sha256(salted_input.encode())
full_hash = hash_object.hexdigest()
# Convert the hash to an integer
hash_integer = int(full_hash, 16)
# Truncate or pad the hash to the desired length
truncated_hash = str(hash_integer)[:hash_length]
return int(truncated_hash)