Coverage for melissa/launcher/schema.py: 61%
41 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
1import rapidjson
2from typing import Dict, Any, Tuple
3import argparse
4from jsonschema import Draft4Validator, validators
5from jsonschema.exceptions import ValidationError
6import logging
7import sys
9from melissa.utility.bcolors import TextColor
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Bitwise OR of the rapidjson flags PM_COMMENTS and PM_TRAILING_COMMAS.
# NOTE(review): presumably passed as rapidjson's `parse_mode` when the
# configuration file is loaded elsewhere; it is not used in this section.
CONFIG_PARSE_MODE = rapidjson.PM_COMMENTS | rapidjson.PM_TRAILING_COMMAS
# Schema describing the Melissa configuration dictionary.
#
# Besides the standard "type" / "properties" keywords, entries carry:
#   - "message": human-readable help text, printed by `print_options`.
#   - "default": value inserted for a missing key by the validator built in
#     `_extend_validator`; also printed by `print_options`.
#   - "required": True on individual properties.
#     NOTE(review): JSON Schema Draft 4 defines `required` as an array of
#     property names on the enclosing object, so this per-property boolean is
#     presumably informational only and not enforced by `Draft4Validator` --
#     confirm before relying on it.
#
# Key order is significant: it drives both the order in which validators run
# and the order in which `print_options` lists the options.
CONF_SCHEMA = {
    'type': 'object',
    'properties': {
        # Top-level options describing the user-defined server and output location.
        "server_filename": {"type": "string", "required": True, "message": "The name of the file containing the user defined server. Assumed to be in the same folder as the config."},
        "server_class": {"type": "string", "required": True, "message": "The class name of the user defined server inside the server_filename file."},
        "output_dir": {"type": "string", "required": True, "message": "The output dir to write results and logs. If relative path, then it is assumed relative to the CWD of the melissa-launcher command."},
        # Study-wide options.
        "study_options": {
            "type": "object",
            "message": "A custom dictionary which is accessible inside the server_class for users to parameterize their studies.",
            "properties": {
                "bind_simulation_to_server_rank": {"type": "boolean", "default": False, "message": "Binds all timesteps of a simulation to be sent on the same server rank."},
                "parameter_sweep_size": {"type": "integer", "required": True, "message": "The number of clients to launch (or groups if using sobol indices)."},
                "nb_parameters": {"type": "integer", "required": True, "message": "The number of parameters to sample for each client."},
                "nb_time_steps": {"type": "integer", "default": 0, "message": "Number of samples expected to arrive from each client. When not given, it can be inferred by Melissa (DL server only)."},
                "verbosity": {"type": "integer", "default": 0, "message": "Set the logger verbosity. 3 includes all levels (including info, error, warning, and debug), 0 reduces to logging to minimum (error only)."},
                "remove_client_scripts": {"type": "boolean", "default": False, "message": "Whether to remove all the client scripts after the execution."}
            }
        },
        # Deep-learning server options.
        "dl_config": {
            "type": "object",
            "properties": {
                "simulation_timeout": {"type": "integer", "default": 400, "message": "Seconds of client inactivity between two messages before timing out the client."},
                "batch_size": {"type": "integer", "default": 10, "message": "Number of samples to build each batch."},
                "nb_batches_update": {"type": "integer", "default": 10, "message": "Number of batches between validation checks and loss logging."},
                "buffer": {"type": "string", "default": "Reservoir", "message": "Type of buffers for batch creation. Accepted values are `FIFO`, `FIRO` and, `Reservoir`."},
                "buffer_size": {"type": "integer", "default": 10000, "message": "Maximum number of samples to store in the buffer (object used to generate batches for training)."},
                "per_server_watermark": {"type": "integer", "message": "Required number of samples in each server process buffer before batch creation and training can begin."},
                "tensorboard": {"type": "boolean", "default": True, "message": "Set to False to disable tensorboard logger entirely for production level runs where you do not wish to log metrics"},
                "checkpoint_interval": {"type": "integer", "message": "Checkpoint frequency for the deep learning. Number of batches between each checkpoint. Defaults to `nb_batches_update`."},
                "get_buffer_statistics": {"type": "boolean", "default": False, "message": "Estimate buffer statistics each time a batch is generated and add to the tensorboard log. Requires custom server imlementation of `get_buffer_statistics()`. See `examples/heat-pde/heat-pde-dl/heatpde_dl_server.py`"}
            },
            "message": "A custom dictionary which is accessible inside the server_class for users to customize their training loops and buffers."},
        # Sensitivity-analysis server options.
        "sa_config": {
            "type": "object",
            "properties": {
                "mean": {"type": "boolean", "default": True, "message": "Collect mean for all fields."},
                "variance": {"type": "boolean", "default": False, "message": "Collect variance for all fields."},
                "skewness": {"type": "boolean", "default": False, "message": "Collect skewness for all fields."},
                "kurtosis": {"type": "boolean", "default": False, "message": "Collect kurtosis for all fields."},
                "checkpoint_interval": {"type": "integer", "default": 0, "message": "Checkpoint frequency for the sensitivity analysis. Number of samples between each checkpoint."},
                "sobol_indices": {"type": "boolean", "default": False, "message": "Activate sobol indicies. Group count determined by study_options.parameter_sweep_size"},
            },
            "message": "A dictionary used to control the sensitivity analysis servers."
        },
        # Server job options; the only section that has an object-level default.
        "server_config": {
            "type": "object",
            "default": {"preprocessing_commands": []},
            "properties": {
                "preprocessing_commands": {"type": "array", "default": [], "message": "Commands that will be preprocessed by bash prior to launching the server job."},
                "melissa_server_env": {"type": "string", "message": "Explicit path to the server installation. Typically does not need to be touched unless two different melissa installations are used."}
            },
            "message": "Special configuration for the server only.",
        },
        # Client job options.
        "client_config": {
            "type": "object",
            "properties": {
                "executable_command": {"type": "string", "required": True, "message": "Command for executable, binary, or a script."},
                "command_default_args": {"type": "array", "default": [], "message": "Default arguments that are concatenated to the `executable_command` option."},
                "preprocessing_commands": {"type": "array", "default": [], "message": "Commands that will be preprocessed by bash prior to launching the client job."},
                "melissa_client_env": {"type": "string", "message": "Explicit path to find the client installation. Typically does not need to be touched unless two different melissa installations are used."}
            },
            "message": "Special configuration for the client only."},
        # Launcher options (scheduler, REST API, job limits, ...). This section has no "message".
        "launcher_config": {
            "type": "object",
            "properties": {
                "scheduler": {"type": "string", "required": True, "message": "Select scheduler, can be 'oar', 'slurm', 'openmpi'"},
                "server_executable": {"type": "string", "default": "server.sh", "message": "Experienced users only, used to modify the bash template."},
                "bind": {"type": "string", "default": "0.0.0.0", "message": "Address to bind the REST API."},
                "http_port": {"type": "integer", "default": 8888, "message": "Port to put the REST API."},
                "http_token": {"type": "string", "default": "", "message": "Token used to access REST API, leave empty to let Melissa generate a unique secure token on launch."},
                "fault_tolerance": {"type": "boolean", "default": True, "message": "Activate/deactivate fault tolerance."},
                "protocol": {"type": "string", "default": "auto", "message": "Experienced users only, Melissa determines best protocol automatically."},
                "std_output": {"type": "boolean", "default": True, "message": "Keep or delete the std out/err files from all jobs."},
                "scheduler_arg": {"type": "array", "default": [], "message": "Common arguments to pass to scheduler for both client and server."},
                "scheduler_arg_client": {"type": "array", "default": [], "message": "Arguments to pass to scheduler for client only."},
                "scheduler_arg_server": {"type": "array", "default": [], "message": "Arguments to pass to scheduler for server only."},
                "scheduler_server_command": {"type": "string", "message": "Option to change the execution command (e.g. in place of srun or mpirun)"},
                "scheduler_client_command": {"type": "string", "message": "Option to change the execution command (e.g. in place of srun or mpirun)"},
                "scheduler_server_command_options": {"type": "array", "default": [], "message": "Options to pass to the scheduler inside the server execution command. Example: ['mpi=pmi2'] which, with slurm, would yield an sbatch.X.sh file with srun mpi=pmi2 <other arguments>."},
                "scheduler_client_command_options": {"type": "array", "default": [], "message": "Options to pass to the scheduler inside the client execution command. Example: ['mpi=pmi2'] which, with slurm, would yield an sbatch.X.sh file with srun mpi=pmi2 <other arguments>."},
                "scheduler_arg_container": {"type": "array", "default": [], "message": "Arguments to pass to containers (e.g. oar-hybrid)."},
                "container_client_size": {"type": "integer", "default": 1, "message": "Size of the container."},
                "job_limit": {"type": "integer", "default": 1000, "message": "Maximum number of active jobs allowed."},
                "besteffort_allocation_frequency": {"type": "integer", "default": 1, "message": "The frequency of job submission to submit to best-effort queue."},
                "timer_delay": {"type": "integer", "message": "The minimal delay between two job status updates with the same value."},
                "server_timeout": {"type": "integer", "message": "Maximum amount of seconds which defines a server timeout exit."},
                "load_from_checkpoint": {"type": "boolean", "default": False, "message": "Look for checkpoint files to start the server from."},
                "verbosity": {"type": "integer", "default": 0, "message": "Set the logger verbosity. 3 includes all levels (including info, error, warning, and debug), 0 reduces to logging to minimum (error only)."}
            },
        }
    }
}
def _extend_validator(validator_class):
    """
    Extended validator for Melissa.

    Wraps the ``properties`` keyword of ``validator_class`` so that, before
    the regular property validation runs, every property whose subschema
    declares a ``default`` and which is missing from the instance is filled
    in with that default. The instance is therefore modified in place.

    Parameters
    ----------
    validator_class
        A jsonschema validator class exposing a ``VALIDATORS`` mapping
        (e.g. ``Draft4Validator``).

    Returns
    -------
    A new validator class, produced by ``validators.extend``, whose
    ``properties`` keyword applies defaults and then delegates to the
    original implementation.
    """
    # Local import keeps this block self-contained without touching the
    # module-level import section.
    from copy import deepcopy

    validate_properties = validator_class.VALIDATORS['properties']

    def set_defaults(validator, properties, instance, schema):
        # The `properties` keyword is invoked for instances of any type; only
        # objects can receive defaults. Without this guard a non-object value
        # (e.g. a string given where a section is expected) would raise
        # AttributeError here instead of being reported as a validation error.
        if validator.is_type(instance, "object"):
            for prop, subschema in properties.items():
                if 'default' in subschema and prop not in instance:
                    # Copy the default so that later mutation of the validated
                    # config (e.g. appending to a list) cannot alter the
                    # schema and leak into subsequent validations.
                    instance[prop] = deepcopy(subschema['default'])

        # Delegate to the stock implementation for the actual validation.
        yield from validate_properties(
            validator, properties, instance, schema,
        )

    return validators.extend(
        validator_class, {'properties': set_defaults}
    )
def validate_config(args: argparse.Namespace,
                    config: Dict[str, Any]) -> Tuple[argparse.Namespace,
                                                     Dict[str, Any]]:
    """
    Validate ``config`` against ``CONF_SCHEMA``.

    The validator class comes from ``_extend_validator``, so schema defaults
    are written into ``config`` in place while it is being validated.

    A ``ValidationError`` is not propagated: it is logged at CRITICAL level
    and the function still returns normally.

    Returns the ``(args, config)`` pair; ``args`` is passed through untouched.
    """
    validator_cls = _extend_validator(Draft4Validator)
    try:
        checker = validator_cls(CONF_SCHEMA)
        checker.validate(config)
    except ValidationError as err:
        logger.critical(f"Invalid configuration. Reason: {err}")

    return args, config
def print_options():
    """
    Print every option described by ``CONF_SCHEMA`` and exit the process.

    A header line is printed first, then one colored line per option showing
    its help message, type and default. Options nested under a
    ``"properties"`` key are printed recursively with a deeper indent.
    Finishes by calling ``sys.exit()``.
    """
    def print_dict(schema, indent=0, prefix="*"):
        # One output line per entry; dict entries are treated as option
        # descriptions, anything else is printed verbatim as `key: value`.
        for key, value in schema.items():
            if not isinstance(value, dict):
                print(f"{' ' * indent}{TextColor.RED}{prefix} {key}{TextColor.ENDC}: {value}")
                continue

            # Missing fields fall back to an empty message or "N/A".
            message = value.get("message", "")
            type_info = value.get("type", "N/A")
            default = value.get("default", "N/A")
            line = (
                f"{' ' * indent}{TextColor.RED}{prefix} {key}{TextColor.ENDC}: {message} "
                f"{TextColor.GREEN}Type: {TextColor.UNDERLINE}{type_info}{TextColor.ENDC}, "
                f"{TextColor.BLUE}Default: {default}{TextColor.ENDC}"
            )
            print(line)

            # Descend into sub-options, indented by 8 more spaces and using "-".
            if "properties" in value:
                print_dict(value["properties"], indent + 8, "-")

    print(f"{TextColor.HEADER}{TextColor.BOLD}Configuration Options:{TextColor.ENDC}", end="\n\n")
    print_dict(CONF_SCHEMA["properties"])
    sys.exit()