init
This commit is contained in:
366
python/mrds_common/mrds/core.py
Normal file
366
python/mrds_common/mrds/core.py
Normal file
@@ -0,0 +1,366 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
import yaml
|
||||
import zipfile
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mrds import __version__
|
||||
|
||||
from mrds.processors import get_file_processor
|
||||
from mrds.utils import (
|
||||
manage_runs,
|
||||
objectstore,
|
||||
static_vars,
|
||||
xml_utils,
|
||||
)
|
||||
|
||||
|
||||
# environment variables
# Deployment environment name (defaults to proof-of-concept).
MRDS_ENV = os.getenv("MRDS_ENV", "poc")
# Object-storage bucket the inbound source files arrive in.
BUCKET = os.getenv("INBOX_BUCKET", "mrds_inbox_poc")
# Object-storage namespace (tenancy) the bucket belongs to.
BUCKET_NAMESPACE = os.getenv("BUCKET_NAMESPACE", "frcnomajoc7v")


# Static configuration variables
WORKFLOW_TYPE = "ODS"  # workflow type passed to manage_runs.init_workflow
ENCODING_TYPE = "utf-8"  # default encoding when the config file omits "encoding_type"

# Keys that must exist at the top level of the YAML configuration file
# (checked by initialize_config).
CONFIG_REQUIRED_KEYS = [
    "tmpdir",
    "inbox_prefix",
    "archive_prefix",
    "workflow_name",
    "validation_schema_path",
    "tasks",
    "file_type",
]

# Keys that must exist in each entry of the configuration's "tasks" list
# (checked by initialize_config; "namespaces" is optional and defaults to {}).
TASK_REQUIRED_KEYS = [
    "task_name",
    "ods_prefix",
    "output_table",
    "output_columns",
]

# Workflow run statuses, aliased from the shared static_vars module.
STATUS_SUCCESS = static_vars.status_success
STATUS_FAILURE = static_vars.status_failed
|
||||
|
||||
|
||||
@dataclass
class GlobalConfig:
    """Workflow-wide configuration assembled from the YAML config file,
    environment variables, and the incoming source filename.
    """

    tmpdir: str  # base directory for temporary working files (overridden in main)
    inbox_prefix: str  # object-storage prefix the source file is downloaded from
    archive_prefix: str  # object-storage prefix processed files are archived under
    workflow_name: str  # logical name of this workflow
    source_filename: str  # current source file name (rewritten after ZIP extraction)
    validation_schema_path: str  # path to the schema used to validate the source file
    bucket: str  # object-storage bucket name
    bucket_namespace: str  # object-storage namespace (tenancy)
    file_type: str  # declared type of the source file, e.g. "xml" or "csv"
    encoding_type: str  # text encoding of the source file

    def __post_init__(self):
        # Remember the name as downloaded: if the file turns out to be a ZIP,
        # source_filename is later replaced with the extracted member's name,
        # but archiving/deleting must still reference the original object.
        self.original_source_filename = self.source_filename  # keep this in case we have a zip file to archive

    @property
    def source_filepath(self) -> str:
        """Full local path of the (possibly extracted) source file in tmpdir."""
        return os.path.join(self.tmpdir, self.source_filename)

    @property
    def original_source_filepath(self) -> str:
        """Full local path of the file as originally downloaded (e.g. the ZIP)."""
        return os.path.join(self.tmpdir, self.original_source_filename)
|
||||
|
||||
|
||||
@dataclass
class TaskConfig:
    """Configuration for a single task — one entry of the config's "tasks" list."""

    task_name: str  # task identifier, used in logging
    ods_prefix: str  # ODS object-storage prefix for this task's output — presumably consumed by the file processor; confirm
    output_table: str  # destination table for this task's output
    namespaces: dict  # XML namespace map; defaults to {} when absent from the config
    output_columns: list  # columns expected in this task's output
|
||||
|
||||
|
||||
def initialize_config(source_filename, config_file_path):
    """Load and validate the YAML workflow configuration.

    Args:
        source_filename: Name of the incoming source file.
        config_file_path: Path to the YAML configuration file.

    Returns:
        tuple: (GlobalConfig, list[TaskConfig]) — the global configuration
        plus one TaskConfig per entry in the config's "tasks" list.

    Raises:
        FileNotFoundError: If config_file_path does not exist.
        ValueError: If the configuration is empty or not a mapping, or if
            required top-level or per-task keys are missing.
    """
    logging.info(f"Source filename is set to: {source_filename}")
    logging.info(f"Loading configuration from {config_file_path}")
    # Ensure the file exists
    if not os.path.exists(config_file_path):
        raise FileNotFoundError(f"Configuration file {config_file_path} not found.")

    # Load the configuration
    with open(config_file_path, "r") as f:
        config_data = yaml.safe_load(f)
    logging.debug(f"Configuration data: {config_data}")

    # safe_load returns None for an empty file (and a scalar/list for other
    # documents); fail with a clear error instead of a TypeError on the
    # membership checks below.
    if not isinstance(config_data, dict):
        raise ValueError(
            f"Configuration file {config_file_path} must contain a mapping, "
            f"got {type(config_data).__name__}."
        )

    missing_keys = [key for key in CONFIG_REQUIRED_KEYS if key not in config_data]
    if missing_keys:
        raise ValueError(f"Missing required keys in configuration: {missing_keys}")

    # Create GlobalConfig instance
    global_config = GlobalConfig(
        tmpdir=config_data["tmpdir"],
        inbox_prefix=config_data["inbox_prefix"],
        archive_prefix=config_data["archive_prefix"],
        workflow_name=config_data["workflow_name"],
        source_filename=source_filename,
        validation_schema_path=config_data["validation_schema_path"],
        bucket=BUCKET,
        bucket_namespace=BUCKET_NAMESPACE,
        file_type=config_data["file_type"],
        encoding_type=config_data.get("encoding_type", ENCODING_TYPE),
    )

    # Create list of TaskConfig instances
    tasks = []
    for task_data in config_data["tasks"]:
        # Validate required keys in task_data
        missing_task_keys = [key for key in TASK_REQUIRED_KEYS if key not in task_data]
        if missing_task_keys:
            raise ValueError(
                f"Missing required keys in task configuration: {missing_task_keys}"
            )

        tasks.append(
            TaskConfig(
                task_name=task_data["task_name"],
                ods_prefix=task_data["ods_prefix"],
                output_table=task_data["output_table"],
                # "namespaces" is optional; default to an empty mapping
                namespaces=task_data.get("namespaces", {}),
                output_columns=task_data["output_columns"],
            )
        )

    return global_config, tasks
|
||||
|
||||
|
||||
def initialize_workflow(global_config):
    """Register a new workflow run and return its tracking context.

    Args:
        global_config: GlobalConfig providing the workflow name.

    Returns:
        dict: {"run_id": ..., "a_workflow_history_key": ...} identifying the run.
    """
    new_run_id = str(uuid.uuid4())

    logging.info(f"Initializing workflow '{global_config.workflow_name}'")
    history_key = manage_runs.init_workflow(
        WORKFLOW_TYPE, global_config.workflow_name, new_run_id
    )

    return {"run_id": new_run_id, "a_workflow_history_key": history_key}
|
||||
|
||||
|
||||
def download_source_file(client, global_config):
    """Download the source file from the inbox prefix into the temp directory.

    Args:
        client: Object-storage client, as returned by objectstore.get_client().
        global_config: GlobalConfig with bucket/prefix/filename settings.
    """
    cfg = global_config
    logging.info(
        f"Downloading source file '{cfg.source_filename}' "
        f"from '{cfg.bucket}/{cfg.inbox_prefix}'"
    )
    objectstore.download_file(
        client,
        cfg.bucket_namespace,
        cfg.bucket,
        cfg.inbox_prefix,
        cfg.source_filename,
        cfg.source_filepath,
    )
    logging.info(f"Source file downloaded to '{cfg.source_filepath}'")
|
||||
|
||||
|
||||
def delete_source_file(client, global_config):
    """Delete the originally downloaded source file from the inbox prefix.

    Uses original_source_filename so the object removed is the one that was
    downloaded, even if a ZIP was extracted locally afterwards.
    """
    inbox_object = (
        f"{global_config.bucket}/{global_config.inbox_prefix}/"
        f"{global_config.original_source_filename}"
    )

    logging.info(f"Deleting source file '{inbox_object}'")
    objectstore.delete_file(
        client,
        global_config.original_source_filename,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.inbox_prefix,
    )
    logging.info(f"Deleted source file '{inbox_object}'")
|
||||
|
||||
|
||||
def archive_source_file(client, global_config):
    """Upload the originally downloaded source file to the archive prefix.

    Uses the original (pre-extraction) filename/path so the archived object
    matches what arrived in the inbox.
    """
    archive_object = (
        f"{global_config.bucket}/{global_config.archive_prefix}/"
        f"{global_config.original_source_filename}"
    )

    logging.info(f"Archiving source file to '{archive_object}'")
    objectstore.upload_file(
        client,
        global_config.original_source_filepath,
        global_config.bucket_namespace,
        global_config.bucket,
        global_config.archive_prefix,
        global_config.original_source_filename,
    )
    logging.info(f"Source file archived to '{archive_object}'")
|
||||
|
||||
|
||||
def unzip_source_file_if_needed(global_config):
    """If the downloaded source file is a ZIP archive, extract its single member.

    On success, global_config.source_filename is updated to the extracted
    member so source_filepath points at the usable file.

    Args:
        global_config: GlobalConfig whose source_filepath may be a ZIP file.

    Returns:
        bool: True if the file is not a ZIP or was extracted successfully;
        False if the archive does not contain exactly one member, the member
        path is unsafe, or extraction fails.
    """
    source_filepath = global_config.source_filepath

    # If it's not a zip, nothing to do
    if not zipfile.is_zipfile(source_filepath):
        logging.info(f"File '{source_filepath}' is not a ZIP file.")
        return True

    logging.info(f"File '{source_filepath}' is a ZIP file. Unzipping...")

    extract_dir = os.path.dirname(source_filepath)

    try:
        with zipfile.ZipFile(source_filepath, "r") as zip_ref:
            members = zip_ref.namelist()

            if len(members) != 1:
                logging.error(
                    f"Expected one file in the ZIP, but found {len(members)} files."
                )
                return False

            member = members[0]

            # Guard against zip-slip: reject absolute paths and names with
            # ".." components that would resolve outside extract_dir.
            target = os.path.realpath(os.path.join(extract_dir, member))
            if not target.startswith(os.path.realpath(extract_dir) + os.sep):
                logging.error(f"Unsafe member path in ZIP archive: '{member}'")
                return False

            # Extract only the validated member (not extractall on arbitrary input)
            zip_ref.extract(member, extract_dir)

    except Exception as e:
        logging.error(f"Error while extracting '{source_filepath}': {e}")
        return False

    # Update the global_config to point to the extracted file
    global_config.source_filename = member

    logging.info(
        f"Extracted '{member}' to '{extract_dir}'. "
        f"Updated source_filepath to '{global_config.source_filepath}'."
    )

    return True
|
||||
|
||||
|
||||
def validate_source_file(global_config):
    """Validate the downloaded source file according to its declared type.

    Args:
        global_config: GlobalConfig with file_type, source_filepath and
            validation_schema_path.

    Returns:
        bool: True when validation passes (CSV currently passes unchecked).

    Raises:
        ValueError: If XML validation fails or the file type is unsupported.
    """
    normalized_type = global_config.file_type.lower()

    if normalized_type == "xml":
        is_valid, validation_message = xml_utils.validate_xml(
            global_config.source_filepath, global_config.validation_schema_path
        )
        if not is_valid:
            raise ValueError(f"XML validation failed: {validation_message}")
        logging.info(validation_message)
        return True

    if normalized_type == "csv":
        # TODO: add CSV validation here
        return True

    raise ValueError(f"Unsupported file type: {normalized_type}")
|
||||
|
||||
|
||||
def process_tasks(tasks, global_config, workflow_context, client):
    """Run every configured task through the file processor.

    Args:
        tasks: List of TaskConfig entries to execute, in order.
        global_config: Workflow-wide configuration.
        workflow_context: Run-tracking context dict.
        client: Object-storage client passed through to the processor.
    """
    # Resolve the processor class once; it depends only on the global config.
    processor_cls = get_file_processor(global_config)

    for task in tasks:
        logging.info(f"Starting task '{task.task_name}'")
        processor = processor_cls(global_config, task, client, workflow_context)
        processor.process()
|
||||
|
||||
|
||||
def finalize_workflow(workflow_context, success=True):
    """Record the final status of the workflow run and log the outcome.

    Args:
        workflow_context: Dict holding "a_workflow_history_key" for the run.
        success: Whether the workflow completed successfully.
    """
    outcome = STATUS_SUCCESS if success else STATUS_FAILURE
    manage_runs.finalise_workflow(workflow_context["a_workflow_history_key"], outcome)
    if success:
        logging.info("Workflow completed successfully")
        return
    logging.error("Workflow failed")
|
||||
|
||||
|
||||
def main(
    workflow_context: dict,
    source_filename: str,
    config_file_path: str,
    generate_workflow_context=False,
    keep_source_file=False,
    keep_tmp_dir=False,
):
    """Entry point for the MRDS ingestion workflow.

    Downloads the source file into a temporary directory, unzips it if
    needed, validates it, runs every configured task, then (optionally)
    archives and deletes the inbox copy.

    Args:
        workflow_context: Existing run context ("run_id",
            "a_workflow_history_key"); ignored when
            generate_workflow_context is True.
        source_filename: Name of the file to pick up from the inbox.
        config_file_path: Path to the YAML workflow configuration.
        generate_workflow_context: When True, register a new workflow run
            here and also finalize it here.
        keep_source_file: When True, leave the inbox copy in place
            (skip archive + delete).
        keep_tmp_dir: When True, the temporary working directory is not
            deleted afterwards.

    Raises:
        RuntimeError: If any step of the workflow fails (chained to the
            original exception).
    """
    logging.info(f"Initializing mrds app, version {__version__}")

    tmpdir_manager = None
    tmpdir = None
    # True only once initialize_workflow has succeeded, so the failure path
    # never tries to finalize a context that was never created.
    workflow_initialized = False

    try:
        # get configs
        global_config, tasks = initialize_config(source_filename, config_file_path)

        # Handle temporary dirs
        if keep_tmp_dir:
            # dir is created and never deleted
            tmpdir = tempfile.mkdtemp(prefix="mrds_", dir=global_config.tmpdir)
            logging.info(
                f"Created temporary working directory (not auto-deleted): {tmpdir}"
            )
        else:
            tmpdir_manager = tempfile.TemporaryDirectory(
                prefix="mrds_", dir=global_config.tmpdir
            )
            tmpdir = tmpdir_manager.name
            logging.info(
                f"Created temporary working directory (auto-deleted): {tmpdir}"
            )

        # override tmpdir with newly created tmpdir
        global_config.tmpdir = tmpdir

        client = objectstore.get_client()

        # Handle workflow_context generation if required
        if generate_workflow_context:
            logging.info("Generating workflow context automatically.")
            workflow_context = initialize_workflow(global_config)
            workflow_initialized = True
            logging.info(f"Generated workflow context: {workflow_context}")
        else:
            logging.info(f"Using provided workflow context: {workflow_context}")

        download_source_file(client, global_config)

        # A False return means a bad archive (wrong member count, unsafe
        # path, or extraction error) — do not continue with an unusable file.
        if not unzip_source_file_if_needed(global_config):
            raise ValueError(
                f"Could not extract source file '{global_config.source_filename}'"
            )

        validate_source_file(global_config)
        process_tasks(tasks, global_config, workflow_context, client)

        if generate_workflow_context:
            finalize_workflow(workflow_context)

        if not keep_source_file:
            archive_source_file(client, global_config)
            delete_source_file(client, global_config)

    except Exception as e:
        logging.error(f"Critical error: {str(e)}")

        # Finalize with failure only if this run created the workflow
        # record; an externally supplied context is the caller's to finalize.
        if generate_workflow_context and workflow_initialized:
            finalize_workflow(workflow_context, success=False)

        # Chain the original exception so the real traceback is preserved.
        raise RuntimeError(f"Workflow failed due to: {e}") from e

    finally:
        # Always attempt to remove tmpdir if created a TemporaryDirectory manager
        if tmpdir_manager and not keep_tmp_dir:
            try:
                tmpdir_manager.cleanup()
                logging.info(f"Deleted temporary working directory {tmpdir}")
            except Exception:
                logging.exception(
                    f"Failed to delete temporary working directory {tmpdir}"
                )
|
||||
Reference in New Issue
Block a user