Files
mars-elt/python/mrds_common/mrds/core.py
Grzegorz Michalski 2c225d68ac init
2026-03-02 09:47:35 +01:00

367 lines
11 KiB
Python

import os
import uuid
import logging
import yaml
import zipfile
import tempfile
from dataclasses import dataclass, field
from mrds import __version__
from mrds.processors import get_file_processor
from mrds.utils import (
manage_runs,
objectstore,
static_vars,
xml_utils,
)
# Environment variables (defaults target the "poc" environment).
MRDS_ENV = os.getenv("MRDS_ENV", "poc")
BUCKET = os.getenv("INBOX_BUCKET", "mrds_inbox_poc")
BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE", "frcnomajoc7v")

# Static configuration variables
WORKFLOW_TYPE = "ODS"  # workflow type passed to manage_runs.init_workflow
ENCODING_TYPE = "utf-8"  # default used when the config omits "encoding_type"

# Keys that must exist at the top level of the YAML configuration file.
CONFIG_REQUIRED_KEYS = [
    "tmpdir",
    "inbox_prefix",
    "archive_prefix",
    "workflow_name",
    "validation_schema_path",
    "tasks",
    "file_type",
]
# Keys that must exist in every entry under "tasks" in the configuration.
TASK_REQUIRED_KEYS = [
    "task_name",
    "ods_prefix",
    "output_table",
    "output_columns",
]
# Run-history status labels (defined in mrds.utils.static_vars).
STATUS_SUCCESS = static_vars.status_success
STATUS_FAILURE = static_vars.status_failed
@dataclass
class GlobalConfig:
    """Workflow-wide settings assembled from the YAML config and environment.

    ``source_filename`` may be repointed at an extracted member when the
    downloaded file turns out to be a ZIP archive; ``original_source_filename``
    keeps the name the file arrived under so it can still be archived/deleted.
    """

    tmpdir: str
    inbox_prefix: str
    archive_prefix: str
    workflow_name: str
    source_filename: str
    validation_schema_path: str
    bucket: str
    bucket_namespace: str
    file_type: str
    encoding_type: str
    # Declared as a non-init field (previously set ad hoc in __post_init__,
    # invisible to the dataclass machinery). repr/compare are disabled so the
    # generated __repr__ and __eq__ behave exactly as before.
    original_source_filename: str = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        # Keep this in case we have a zip file to archive.
        self.original_source_filename = self.source_filename

    @property
    def source_filepath(self) -> str:
        """Local path of the (possibly extracted) source file under tmpdir."""
        return os.path.join(self.tmpdir, self.source_filename)

    @property
    def original_source_filepath(self) -> str:
        """Local path of the originally downloaded file under tmpdir."""
        return os.path.join(self.tmpdir, self.original_source_filename)
@dataclass
class TaskConfig:
    # One instance per entry under "tasks" in the YAML configuration.
    task_name: str        # identifier used in logs and run tracking
    ods_prefix: str       # object-store prefix for this task's output
    output_table: str     # destination table name
    namespaces: dict      # presumably XML namespace map; defaults to {} — see initialize_config
    output_columns: list  # column names for the output table
def initialize_config(source_filename, config_file_path):
    """Load the YAML workflow configuration and build config objects.

    Args:
        source_filename: Name of the input file sitting in the inbox prefix.
        config_file_path: Path to the YAML configuration file.

    Returns:
        Tuple of (GlobalConfig, list[TaskConfig]).

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: If required global or per-task keys are missing.
    """
    logging.info(f"Source filename is set to: {source_filename}")
    logging.info(f"Loading configuration from {config_file_path}")
    # Ensure the file exists
    if not os.path.exists(config_file_path):
        raise FileNotFoundError(f"Configuration file {config_file_path} not found.")
    # Load the configuration; explicit encoding avoids locale-dependent reads,
    # and safe_load avoids arbitrary object construction from the YAML.
    with open(config_file_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    logging.debug(f"Configuration data: {config_data}")
    missing_keys = [key for key in CONFIG_REQUIRED_KEYS if key not in config_data]
    if missing_keys:
        raise ValueError(f"Missing required keys in configuration: {missing_keys}")
    # Create GlobalConfig instance
    global_config = GlobalConfig(
        tmpdir=config_data["tmpdir"],
        inbox_prefix=config_data["inbox_prefix"],
        archive_prefix=config_data["archive_prefix"],
        workflow_name=config_data["workflow_name"],
        source_filename=source_filename,
        validation_schema_path=config_data["validation_schema_path"],
        bucket=BUCKET,
        bucket_namespace=BUCKET_NAMESPACE,
        file_type=config_data["file_type"],
        encoding_type=config_data.get("encoding_type", ENCODING_TYPE),
    )
    # Create list of TaskConfig instances
    tasks = [_build_task_config(task_data) for task_data in config_data["tasks"]]
    return global_config, tasks


def _build_task_config(task_data):
    """Validate one "tasks" mapping and convert it into a TaskConfig."""
    missing_task_keys = [key for key in TASK_REQUIRED_KEYS if key not in task_data]
    if missing_task_keys:
        raise ValueError(
            f"Missing required keys in task configuration: {missing_task_keys}"
        )
    return TaskConfig(
        task_name=task_data["task_name"],
        ods_prefix=task_data["ods_prefix"],
        output_table=task_data["output_table"],
        namespaces=task_data.get("namespaces", {}),
        output_columns=task_data["output_columns"],
    )
def initialize_workflow(global_config):
    """Register a new workflow run and return its context dictionary."""
    new_run_id = str(uuid.uuid4())
    logging.info(f"Initializing workflow '{global_config.workflow_name}'")
    history_key = manage_runs.init_workflow(
        WORKFLOW_TYPE, global_config.workflow_name, new_run_id
    )
    context = {
        "run_id": new_run_id,
        "a_workflow_history_key": history_key,
    }
    return context
def download_source_file(client, global_config):
    """Fetch the source object from the inbox prefix into the local tmpdir."""
    target_path = global_config.source_filepath
    logging.info(
        f"Downloading source file '{global_config.source_filename}' "
        f"from '{global_config.bucket}/{global_config.inbox_prefix}'"
    )
    objectstore.download_file(
        client,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.inbox_prefix,
        global_config.source_filename,
        target_path,
    )
    logging.info(f"Source file downloaded to '{target_path}'")
def delete_source_file(client, global_config):
    """Remove the original input object from the inbox prefix."""
    object_ref = (
        f"{global_config.bucket}/{global_config.inbox_prefix}/"
        f"{global_config.original_source_filename}"
    )
    logging.info(f"Deleting source file '{object_ref}'")
    objectstore.delete_file(
        client,
        global_config.original_source_filename,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.inbox_prefix,
    )
    logging.info(f"Deleted source file '{object_ref}'")
def archive_source_file(client, global_config):
    """Upload the originally downloaded file to the archive prefix."""
    destination = (
        f"{global_config.bucket}/{global_config.archive_prefix}/"
        f"{global_config.original_source_filename}"
    )
    logging.info(f"Archiving source file to '{destination}'")
    objectstore.upload_file(
        client,
        global_config.original_source_filepath,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.archive_prefix,
        global_config.original_source_filename,
    )
    logging.info(f"Source file archived to '{destination}'")
def unzip_source_file_if_needed(global_config):
    """Extract the source file in place when it is a ZIP archive.

    The archive must contain exactly one regular file. On success,
    ``global_config.source_filename`` is repointed at the extracted member,
    while ``original_source_filename`` still names the ZIP for archiving.

    Returns:
        True if no extraction was needed or it succeeded; False when the
        archive is malformed (not exactly one regular file) or extraction
        fails.
    """
    source_filepath = global_config.source_filepath
    # If it's not a zip, nothing to do
    if not zipfile.is_zipfile(source_filepath):
        logging.info(f"File '{source_filepath}' is not a ZIP file.")
        return True
    logging.info(f"File '{source_filepath}' is a ZIP file. Unzipping...")
    extract_dir = os.path.dirname(source_filepath)
    try:
        with zipfile.ZipFile(source_filepath, "r") as zip_ref:
            members = zip_ref.infolist()
            # Reject directory entries too: a lone directory would pass a bare
            # count check and leave source_filename pointing at a directory.
            if len(members) != 1 or members[0].is_dir():
                logging.error(
                    f"Expected one file in the ZIP, but found {len(members)} files."
                )
                return False
            # NOTE(review): extractall trusts member names; inbox objects are
            # presumably produced internally — confirm they are trusted
            # (zip-slip risk per the zipfile docs' extraction warning).
            zip_ref.extractall(extract_dir)
            extracted_filename = members[0].filename
    except Exception as e:
        logging.error(f"Error while extracting '{source_filepath}': {e}")
        return False
    # Update the global_config to point to the extracted file
    global_config.source_filename = extracted_filename
    logging.info(
        f"Extracted '{extracted_filename}' to '{extract_dir}'. "
        f"Updated source_filepath to '{global_config.source_filepath}'."
    )
    return True
def validate_source_file(global_config):
    """Validate the downloaded source file according to its configured type.

    Returns True when validation passes (or is not implemented for the
    type); raises ValueError on validation failure or an unknown type.
    """
    file_type = global_config.file_type.lower()
    if file_type == "csv":
        # TODO: add CSV validation here
        return True
    if file_type == "xml":
        is_valid, message = xml_utils.validate_xml(
            global_config.source_filepath, global_config.validation_schema_path
        )
        if not is_valid:
            raise ValueError(f"XML validation failed: {message}")
        logging.info(message)
        return True
    raise ValueError(f"Unsupported file type: {file_type}")
def process_tasks(tasks, global_config, workflow_context, client):
    """Run every configured task through the file-type-specific processor."""
    # One processor class serves all tasks; it is chosen from the global
    # config (e.g. by file type), then instantiated per task.
    processor_cls = get_file_processor(global_config)
    for task_conf in tasks:
        logging.info(f"Starting task '{task_conf.task_name}'")
        processor = processor_cls(global_config, task_conf, client, workflow_context)
        processor.process()
def finalize_workflow(workflow_context, success=True):
    """Record the final status of the workflow run in the run history."""
    if success:
        status = STATUS_SUCCESS
    else:
        status = STATUS_FAILURE
    manage_runs.finalise_workflow(workflow_context["a_workflow_history_key"], status)
    if success:
        logging.info("Workflow completed successfully")
    else:
        logging.error("Workflow failed")
def main(
    workflow_context: dict,
    source_filename: str,
    config_file_path: str,
    generate_workflow_context=False,
    keep_source_file=False,
    keep_tmp_dir=False,
):
    """Run the full workflow for a single source file.

    Args:
        workflow_context: Externally supplied run context; ignored (and
            regenerated) when ``generate_workflow_context`` is True.
        source_filename: Name of the file to process from the inbox prefix.
        config_file_path: Path to the YAML workflow configuration.
        generate_workflow_context: Create (and finalize) a new workflow run
            instead of using the provided context.
        keep_source_file: Skip archiving/deleting the inbox object after
            a successful run.
        keep_tmp_dir: Keep the temporary working directory for debugging.

    Raises:
        RuntimeError: Wrapping any error raised during the workflow.
    """
    logging.info(f"Initializing mrds app, version {__version__}")
    tmpdir_manager = None
    tmpdir = None
    # True only once initialize_workflow() has created a run we must finalize.
    workflow_initialized = False
    try:
        # get configs
        global_config, tasks = initialize_config(source_filename, config_file_path)
        # Handle temporary dirs
        if keep_tmp_dir:
            tmpdir = tempfile.mkdtemp(
                prefix="mrds_", dir=global_config.tmpdir
            )  # dir is created and never deleted
            logging.info(
                f"Created temporary working directory (not auto-deleted): {tmpdir}"
            )
        else:
            tmpdir_manager = tempfile.TemporaryDirectory(
                prefix="mrds_", dir=global_config.tmpdir
            )
            tmpdir = tmpdir_manager.name
            logging.info(
                f"Created temporary working directory (auto-deleted): {tmpdir}"
            )
        # override tmpdir with newly created tmpdir
        global_config.tmpdir = tmpdir
        client = objectstore.get_client()
        # Handle workflow_context generation if required
        if generate_workflow_context:
            logging.info("Generating workflow context automatically.")
            workflow_context = initialize_workflow(global_config)
            workflow_initialized = True
            logging.info(f"Generated workflow context: {workflow_context}")
        else:
            logging.info(f"Using provided workflow context: {workflow_context}")
        download_source_file(client, global_config)
        # The unzip helper reports failure via its return value, which the
        # previous code silently ignored — surface it as an error instead.
        if not unzip_source_file_if_needed(global_config):
            raise ValueError(
                f"Failed to extract ZIP source file '{global_config.source_filepath}'"
            )
        validate_source_file(global_config)
        process_tasks(tasks, global_config, workflow_context, client)
        if workflow_initialized:
            finalize_workflow(workflow_context)
        if not keep_source_file:
            archive_source_file(client, global_config)
            delete_source_file(client, global_config)
    except Exception as e:
        logging.error(f"Critical error: {str(e)}")
        # Only finalize a run this invocation created. The previous test,
        # `"workflow_context" in locals()`, was always true because
        # workflow_context is a parameter, so a failure *before*
        # initialize_workflow() could touch a run that was never started.
        if workflow_initialized:
            finalize_workflow(workflow_context, success=False)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Workflow failed due to: {e}") from e
    finally:
        # Always attempt to remove tmpdir if created a TemporaryDirectory
        # manager (only happens when keep_tmp_dir is False).
        if tmpdir_manager is not None:
            try:
                tmpdir_manager.cleanup()
                logging.info(f"Deleted temporary working directory {tmpdir}")
            except Exception:
                logging.exception(
                    f"Failed to delete temporary working directory {tmpdir}"
                )