367 lines
11 KiB
Python
367 lines
11 KiB
Python
import os
|
|
import uuid
|
|
import logging
|
|
import yaml
|
|
import zipfile
|
|
import tempfile
|
|
from dataclasses import dataclass, field
|
|
|
|
from mrds import __version__
|
|
|
|
from mrds.processors import get_file_processor
|
|
from mrds.utils import (
|
|
manage_runs,
|
|
objectstore,
|
|
static_vars,
|
|
xml_utils,
|
|
)
|
|
|
|
|
|
# Environment-driven settings (overridable at deploy time).
MRDS_ENV = os.getenv("MRDS_ENV", "poc")  # deployment environment name
BUCKET = os.getenv("INBOX_BUCKET", "mrds_inbox_poc")  # object-store inbox bucket
BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE", "frcnomajoc7v")  # object-store namespace

# Static configuration variables
WORKFLOW_TYPE = "ODS"  # workflow type passed to manage_runs.init_workflow
ENCODING_TYPE = "utf-8"  # default encoding when the YAML config omits one

# Keys that must be present at the top level of the YAML configuration.
CONFIG_REQUIRED_KEYS = [
    "tmpdir",
    "inbox_prefix",
    "archive_prefix",
    "workflow_name",
    "validation_schema_path",
    "tasks",
    "file_type",
]

# Keys that must be present in every entry of the YAML "tasks" list.
TASK_REQUIRED_KEYS = [
    "task_name",
    "ods_prefix",
    "output_table",
    "output_columns",
]

# Workflow run statuses recorded by finalize_workflow.
STATUS_SUCCESS = static_vars.status_success
STATUS_FAILURE = static_vars.status_failed
|
|
|
|
|
|
@dataclass
class GlobalConfig:
    """Run-wide configuration shared by every task in the workflow.

    Built by ``initialize_config`` from the YAML config file plus the
    bucket settings derived from environment variables.
    """

    tmpdir: str  # working directory; main() replaces this with a per-run temp dir
    inbox_prefix: str  # object-store prefix the source file is downloaded from
    archive_prefix: str  # object-store prefix the source file is archived to
    workflow_name: str
    source_filename: str  # repointed to the extracted file if the source is a ZIP
    validation_schema_path: str  # schema used by validate_source_file (XML path)
    bucket: str
    bucket_namespace: str
    file_type: str  # "xml" or "csv"; anything else is rejected in validate_source_file
    encoding_type: str

    def __post_init__(self):
        # Remember the name as downloaded so a ZIP can still be archived and
        # deleted under its original name after source_filename is repointed
        # to the extracted member.
        self.original_source_filename = self.source_filename  # keep this in case we have a zip file to archive

    @property
    def source_filepath(self) -> str:
        """Full local path of the (possibly extracted) source file."""
        return os.path.join(self.tmpdir, self.source_filename)

    @property
    def original_source_filepath(self) -> str:
        """Full local path of the file as originally downloaded."""
        return os.path.join(self.tmpdir, self.original_source_filename)
|
|
|
|
|
|
@dataclass
class TaskConfig:
    """Per-task configuration (one entry of the YAML ``tasks`` list)."""

    task_name: str
    ods_prefix: str  # object-store prefix for this task's output
    output_table: str
    namespaces: dict  # optional in the YAML; defaults to {} when absent
    output_columns: list
|
|
|
|
|
|
def initialize_config(source_filename, config_file_path):
    """Load and validate the YAML configuration for one workflow run.

    Args:
        source_filename: Name of the source file in the inbox bucket.
        config_file_path: Path to the YAML configuration file.

    Returns:
        A tuple of (GlobalConfig, list of TaskConfig).

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        ValueError: If the file is not a YAML mapping or required keys
            are missing at the top level or in any task entry.
    """
    logging.info(f"Source filename is set to: {source_filename}")
    logging.info(f"Loading configuration from {config_file_path}")
    # Ensure the file exists
    if not os.path.exists(config_file_path):
        raise FileNotFoundError(f"Configuration file {config_file_path} not found.")

    # Load the configuration; read as UTF-8 explicitly so behavior does not
    # depend on the platform's default encoding.
    with open(config_file_path, "r", encoding="utf-8") as f:
        config_data = yaml.safe_load(f)
    logging.debug(f"Configuration data: {config_data}")

    # An empty or non-mapping YAML document would otherwise surface as a
    # confusing TypeError on the membership checks below.
    if not isinstance(config_data, dict):
        raise ValueError(
            f"Configuration file {config_file_path} must contain a YAML mapping."
        )

    missing_keys = [key for key in CONFIG_REQUIRED_KEYS if key not in config_data]
    if missing_keys:
        raise ValueError(f"Missing required keys in configuration: {missing_keys}")

    # Create GlobalConfig instance
    global_config = GlobalConfig(
        tmpdir=config_data["tmpdir"],
        inbox_prefix=config_data["inbox_prefix"],
        archive_prefix=config_data["archive_prefix"],
        workflow_name=config_data["workflow_name"],
        source_filename=source_filename,
        validation_schema_path=config_data["validation_schema_path"],
        bucket=BUCKET,
        bucket_namespace=BUCKET_NAMESPACE,
        file_type=config_data["file_type"],
        encoding_type=config_data.get("encoding_type", ENCODING_TYPE),
    )

    # Create list of TaskConfig instances
    tasks = [_build_task_config(task_data) for task_data in config_data["tasks"]]

    return global_config, tasks


def _build_task_config(task_data):
    """Validate one entry of the YAML ``tasks`` list and build its TaskConfig.

    Raises:
        ValueError: If any of TASK_REQUIRED_KEYS is missing from the entry.
    """
    missing_task_keys = [key for key in TASK_REQUIRED_KEYS if key not in task_data]
    if missing_task_keys:
        raise ValueError(
            f"Missing required keys in task configuration: {missing_task_keys}"
        )
    return TaskConfig(
        task_name=task_data["task_name"],
        ods_prefix=task_data["ods_prefix"],
        output_table=task_data["output_table"],
        namespaces=task_data.get("namespaces", {}),  # optional key
        output_columns=task_data["output_columns"],
    )
|
|
|
|
|
|
def initialize_workflow(global_config):
    """Register a new workflow run and return its context dictionary."""
    new_run_id = str(uuid.uuid4())

    logging.info(f"Initializing workflow '{global_config.workflow_name}'")
    history_key = manage_runs.init_workflow(
        WORKFLOW_TYPE, global_config.workflow_name, new_run_id
    )

    return {"run_id": new_run_id, "a_workflow_history_key": history_key}
|
|
|
|
|
|
def download_source_file(client, global_config):
    """Fetch the workflow's source file from the inbox into the temp dir."""
    logging.info(
        f"Downloading source file '{global_config.source_filename}' "
        f"from '{global_config.bucket}/{global_config.inbox_prefix}'"
    )
    cfg = global_config
    objectstore.download_file(
        client,
        cfg.bucket_namespace,
        cfg.bucket,
        cfg.inbox_prefix,
        cfg.source_filename,
        cfg.source_filepath,
    )
    logging.info(f"Source file downloaded to '{global_config.source_filepath}'")
|
|
|
|
|
|
def delete_source_file(client, global_config):
    """Remove the originally downloaded source file from the inbox prefix."""
    cfg = global_config
    location = f"{cfg.bucket}/{cfg.inbox_prefix}/{cfg.original_source_filename}"

    logging.info(f"Deleting source file '{location}'")
    objectstore.delete_file(
        client,
        cfg.original_source_filename,
        cfg.bucket_namespace,
        cfg.bucket,
        cfg.inbox_prefix,
    )
    logging.info(f"Deleted source file '{location}'")
|
|
|
|
|
|
def archive_source_file(client, global_config):
    """Copy the originally downloaded source file to the archive prefix."""
    cfg = global_config
    destination = f"{cfg.bucket}/{cfg.archive_prefix}/{cfg.original_source_filename}"

    logging.info(f"Archiving source file to '{destination}'")
    objectstore.upload_file(
        client,
        cfg.original_source_filepath,
        cfg.bucket_namespace,
        cfg.bucket,
        cfg.archive_prefix,
        cfg.original_source_filename,
    )
    logging.info(f"Source file archived to '{destination}'")
|
|
|
|
|
|
def unzip_source_file_if_needed(global_config):
    """Extract the source file in place if it is a ZIP archive.

    Expects exactly one member inside the archive; on success,
    ``global_config.source_filename`` is repointed to the extracted file
    (``original_source_filename`` keeps the ZIP's name for archiving).

    Args:
        global_config: GlobalConfig whose ``source_filepath`` names the
            downloaded file.

    Returns:
        True when the file is not a ZIP, or it was extracted successfully;
        False when the archive has an unexpected layout or extraction fails.
    """
    source_filepath = global_config.source_filepath

    # If it's not a zip, nothing to do
    if not zipfile.is_zipfile(source_filepath):
        logging.info(f"File '{source_filepath}' is not a ZIP file.")
        return True

    logging.info(f"File '{source_filepath}' is a ZIP file. Unzipping...")

    extract_dir = os.path.dirname(source_filepath)

    try:
        with zipfile.ZipFile(source_filepath, "r") as zip_ref:
            extracted_files = zip_ref.namelist()

            if len(extracted_files) != 1:
                logging.error(
                    f"Expected one file in the ZIP, but found {len(extracted_files)} files."
                )
                return False

            # Guard against "zip-slip": a crafted member name (absolute, or
            # containing "..") would let extractall() write outside
            # extract_dir. Resolve the target and require containment.
            member = extracted_files[0]
            safe_root = os.path.realpath(extract_dir)
            target = os.path.realpath(os.path.join(extract_dir, member))
            if not target.startswith(safe_root + os.sep):
                logging.error(f"Unsafe path in ZIP member: '{member}'")
                return False

            # Extract everything
            zip_ref.extractall(extract_dir)

    except Exception as e:
        logging.error(f"Error while extracting '{source_filepath}': {e}")
        return False

    # Update the global_config to point to the extracted file
    extracted_filename = extracted_files[0]
    global_config.source_filename = extracted_filename

    logging.info(
        f"Extracted '{extracted_filename}' to '{extract_dir}'. "
        f"Updated source_filepath to '{global_config.source_filepath}'."
    )

    return True
|
|
|
|
|
|
def validate_source_file(global_config):
    """Validate the downloaded source file according to its declared type.

    Raises ValueError on an unsupported file type or a failed XML
    validation; returns True otherwise.
    """
    file_type = global_config.file_type.lower()

    if file_type == "csv":
        # TODO: add CSV validation here
        return True

    if file_type != "xml":
        raise ValueError(f"Unsupported file type: {file_type}")

    is_valid, message = xml_utils.validate_xml(
        global_config.source_filepath, global_config.validation_schema_path
    )
    if not is_valid:
        raise ValueError(f"XML validation failed: {message}")
    logging.info(message)

    return True
|
|
|
|
|
|
def process_tasks(tasks, global_config, workflow_context, client):
    """Run every configured task through the matching file processor."""
    # get appropriate task processor
    processor_class = get_file_processor(global_config)

    for task_conf in tasks:
        logging.info(f"Starting task '{task_conf.task_name}'")
        processor = processor_class(global_config, task_conf, client, workflow_context)
        processor.process()
|
|
|
|
|
|
def finalize_workflow(workflow_context, success=True):
    """Record the final workflow run status and log the outcome."""
    if success:
        status = STATUS_SUCCESS
    else:
        status = STATUS_FAILURE
    manage_runs.finalise_workflow(workflow_context["a_workflow_history_key"], status)
    if success:
        logging.info("Workflow completed successfully")
    else:
        logging.error("Workflow failed")
|
|
|
|
|
|
def main(
    workflow_context: dict,
    source_filename: str,
    config_file_path: str,
    generate_workflow_context=False,
    keep_source_file=False,
    keep_tmp_dir=False,
):
    """Run the end-to-end MRDS ingestion workflow.

    Args:
        workflow_context: Externally supplied run context; ignored (and
            regenerated) when ``generate_workflow_context`` is True.
        source_filename: Name of the source file in the inbox bucket.
        config_file_path: Path to the YAML workflow configuration.
        generate_workflow_context: Create (and finalize) our own workflow
            run record instead of using the provided context.
        keep_source_file: Skip archiving and deleting the inbox file.
        keep_tmp_dir: Create a working dir that is not auto-deleted.

    Raises:
        RuntimeError: If any step of the workflow fails (chained to the
            original exception).
    """
    logging.info(f"Initializing mrds app, version {__version__}")

    tmpdir_manager = None
    # Tracks whether WE created the workflow run record, so the error path
    # only finalizes a run that actually exists. The previous check of
    # `"workflow_context" in locals()` was always true (it is a parameter),
    # so a failure before initialize_workflow() would try to finalize the
    # caller-supplied context and mask the original error.
    workflow_started = False

    try:
        # get configs
        global_config, tasks = initialize_config(source_filename, config_file_path)

        # Handle temporary dirs
        if keep_tmp_dir:
            tmpdir = tempfile.mkdtemp(
                prefix="mrds_", dir=global_config.tmpdir
            )  # dir is created and never deleted
            logging.info(
                f"Created temporary working directory (not auto-deleted): {tmpdir}"
            )
        else:
            tmpdir_manager = tempfile.TemporaryDirectory(
                prefix="mrds_", dir=global_config.tmpdir
            )
            tmpdir = tmpdir_manager.name
            logging.info(
                f"Created temporary working directory (auto-deleted): {tmpdir}"
            )

        # override tmpdir with newly created tmpdir
        global_config.tmpdir = tmpdir

        client = objectstore.get_client()

        # Handle workflow_context generation if required
        if generate_workflow_context:
            logging.info("Generating workflow context automatically.")
            workflow_context = initialize_workflow(global_config)
            workflow_started = True
            logging.info(f"Generated workflow context: {workflow_context}")
        else:
            logging.info(f"Using provided workflow context: {workflow_context}")

        download_source_file(client, global_config)
        unzip_source_file_if_needed(global_config)
        validate_source_file(global_config)
        process_tasks(tasks, global_config, workflow_context, client)

        if generate_workflow_context:
            finalize_workflow(workflow_context)

        if not keep_source_file:
            archive_source_file(client, global_config)
            delete_source_file(client, global_config)

    except Exception as e:
        logging.error(f"Critical error: {str(e)}")

        # Finalize with failure only if we created the run record ourselves.
        if workflow_started:
            finalize_workflow(workflow_context, success=False)

        # Chain the cause explicitly so the original traceback is preserved.
        raise RuntimeError(f"Workflow failed due to: {e}") from e

    finally:
        # Always attempt to remove tmpdir if created a TemporaryDirectory manager
        if tmpdir_manager and not keep_tmp_dir:
            try:
                tmpdir_manager.cleanup()
                logging.info(f"Deleted temporary working directory {tmpdir}")
            except Exception:
                logging.exception(
                    f"Failed to delete temporary working directory {tmpdir}"
                )
|