Coverage for melissa/launcher/schema.py: 61%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1import rapidjson 

2from typing import Dict, Any, Tuple 

3import argparse 

4from jsonschema import Draft4Validator, validators 

5from jsonschema.exceptions import ValidationError 

6import logging 

7import sys 

8 

9from melissa.utility.bcolors import TextColor 

10 

11 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Bitwise OR of two rapidjson parse-mode flags (comments and trailing commas).
# NOTE(review): not used in the visible part of this module; presumably passed
# as ``parse_mode`` when the configuration file is loaded elsewhere - confirm.
CONFIG_PARSE_MODE = rapidjson.PM_COMMENTS | rapidjson.PM_TRAILING_COMMAS

15 

16 

# Schema describing the Melissa configuration dictionary.
#
# Besides the standard ``type`` / ``properties`` / ``default`` keywords, each
# entry carries a ``message`` string; ``print_options`` reads it to display a
# human-readable description of the option.
#
# NOTE(review): several entries carry ``"required": True``. JSON Schema
# Draft 4 expects ``required`` to be a list of property names on the enclosing
# object; the per-property boolean looks like Draft 3 syntax and is presumably
# not enforced by ``Draft4Validator`` - confirm before relying on it.
CONF_SCHEMA = {
    "type": "object",
    "properties": {
        # --- top-level options -------------------------------------------
        "server_filename": {"type": "string", "required": True, "message": "The name of the file containing the user defined server. Assumed to be in the same folder as the config."},
        "server_class": {"type": "string", "required": True, "message": "The class name of the user defined server inside the server_filename file."},
        "output_dir": {"type": "string", "required": True, "message": "The output dir to write results and logs. If relative path, then it is assumed relative to the CWD of the melissa-launcher command."},
        # --- study options -----------------------------------------------
        "study_options": {
            "type": "object",
            "message": "A custom dictionary which is accessible inside the server_class for users to parameterize their studies.",
            "properties": {
                "bind_simulation_to_server_rank": {"type": "boolean", "default": False, "message": "Binds all timesteps of a simulation to be sent on the same server rank."},
                "parameter_sweep_size": {"type": "integer", "required": True, "message": "The number of clients to launch (or groups if using sobol indices)."},
                "nb_parameters": {"type": "integer", "required": True, "message": "The number of parameters to sample for each client."},
                "nb_time_steps": {"type": "integer", "default": 0, "message": "Number of samples expected to arrive from each client. When not given, it can be inferred by Melissa (DL server only)."},
                "verbosity": {"type": "integer", "default": 0, "message": "Set the logger verbosity. 3 includes all levels (including info, error, warning, and debug), 0 reduces to logging to minimum (error only)."},
                "remove_client_scripts": {"type": "boolean", "default": False, "message": "Whether to remove all the client scripts after the execution."},
            },
        },
        # --- deep-learning server options --------------------------------
        "dl_config": {
            "type": "object",
            "properties": {
                "simulation_timeout": {"type": "integer", "default": 400, "message": "Seconds of client inactivity between two messages before timing out the client."},
                "batch_size": {"type": "integer", "default": 10, "message": "Number of samples to build each batch."},
                "nb_batches_update": {"type": "integer", "default": 10, "message": "Number of batches between validation checks and loss logging."},
                "buffer": {"type": "string", "default": "Reservoir", "message": "Type of buffers for batch creation. Accepted values are `FIFO`, `FIRO` and, `Reservoir`."},
                "buffer_size": {"type": "integer", "default": 10000, "message": "Maximum number of samples to store in the buffer (object used to generate batches for training)."},
                "per_server_watermark": {"type": "integer", "message": "Required number of samples in each server process buffer before batch creation and training can begin."},
                "tensorboard": {"type": "boolean", "default": True, "message": "Set to False to disable tensorboard logger entirely for production level runs where you do not wish to log metrics"},
                "checkpoint_interval": {"type": "integer", "message": "Checkpoint frequency for the deep learning. Number of batches between each checkpoint. Defaults to `nb_batches_update`."},
                "get_buffer_statistics": {"type": "boolean", "default": False, "message": "Estimate buffer statistics each time a batch is generated and add to the tensorboard log. Requires custom server imlementation of `get_buffer_statistics()`. See `examples/heat-pde/heat-pde-dl/heatpde_dl_server.py`"},
            },
            "message": "A custom dictionary which is accessible inside the server_class for users to customize their training loops and buffers.",
        },
        # --- sensitivity-analysis server options -------------------------
        "sa_config": {
            "type": "object",
            "properties": {
                "mean": {"type": "boolean", "default": True, "message": "Collect mean for all fields."},
                "variance": {"type": "boolean", "default": False, "message": "Collect variance for all fields."},
                "skewness": {"type": "boolean", "default": False, "message": "Collect skewness for all fields."},
                "kurtosis": {"type": "boolean", "default": False, "message": "Collect kurtosis for all fields."},
                "checkpoint_interval": {"type": "integer", "default": 0, "message": "Checkpoint frequency for the sensitivity analysis. Number of samples between each checkpoint."},
                "sobol_indices": {"type": "boolean", "default": False, "message": "Activate sobol indicies. Group count determined by study_options.parameter_sweep_size"},
            },
            "message": "A dictionary used to control the sensitivity analysis servers.",
        },
        # --- server job options ------------------------------------------
        "server_config": {
            "type": "object",
            "default": {"preprocessing_commands": []},
            "properties": {
                "preprocessing_commands": {"type": "array", "default": [], "message": "Commands that will be preprocessed by bash prior to launching the server job."},
                "melissa_server_env": {"type": "string", "message": "Explicit path to the server installation. Typically does not need to be touched unless two different melissa installations are used."},
            },
            "message": "Special configuration for the server only.",
        },
        # --- client job options ------------------------------------------
        "client_config": {
            "type": "object",
            "properties": {
                "executable_command": {"type": "string", "required": True, "message": "Command for executable, binary, or a script."},
                "command_default_args": {"type": "array", "default": [], "message": "Default arguments that are concatenated to the `executable_command` option."},
                "preprocessing_commands": {"type": "array", "default": [], "message": "Commands that will be preprocessed by bash prior to launching the client job."},
                "melissa_client_env": {"type": "string", "message": "Explicit path to find the client installation. Typically does not need to be touched unless two different melissa installations are used."},
            },
            "message": "Special configuration for the client only.",
        },
        # --- launcher options --------------------------------------------
        "launcher_config": {
            "type": "object",
            "properties": {
                "scheduler": {"type": "string", "required": True, "message": "Select scheduler, can be 'oar', 'slurm', 'openmpi'"},
                "server_executable": {"type": "string", "default": "server.sh", "message": "Experienced users only, used to modify the bash template."},
                "bind": {"type": "string", "default": "0.0.0.0", "message": "Address to bind the REST API."},
                "http_port": {"type": "integer", "default": 8888, "message": "Port to put the REST API."},
                "http_token": {"type": "string", "default": "", "message": "Token used to access REST API, leave empty to let Melissa generate a unique secure token on launch."},
                "fault_tolerance": {"type": "boolean", "default": True, "message": "Activate/deactivate fault tolerance."},
                "protocol": {"type": "string", "default": "auto", "message": "Experienced users only, Melissa determines best protocol automatically."},
                "std_output": {"type": "boolean", "default": True, "message": "Keep or delete the std out/err files from all jobs."},
                "scheduler_arg": {"type": "array", "default": [], "message": "Common arguments to pass to scheduler for both client and server."},
                "scheduler_arg_client": {"type": "array", "default": [], "message": "Arguments to pass to scheduler for client only."},
                "scheduler_arg_server": {"type": "array", "default": [], "message": "Arguments to pass to scheduler for server only."},
                "scheduler_server_command": {"type": "string", "message": "Option to change the execution command (e.g. in place of srun or mpirun)"},
                "scheduler_client_command": {"type": "string", "message": "Option to change the execution command (e.g. in place of srun or mpirun)"},
                "scheduler_server_command_options": {"type": "array", "default": [], "message": "Options to pass to the scheduler inside the server execution command. Example: ['mpi=pmi2'] which, with slurm, would yield an sbatch.X.sh file with srun mpi=pmi2 <other arguments>."},
                "scheduler_client_command_options": {"type": "array", "default": [], "message": "Options to pass to the scheduler inside the client execution command. Example: ['mpi=pmi2'] which, with slurm, would yield an sbatch.X.sh file with srun mpi=pmi2 <other arguments>."},
                "scheduler_arg_container": {"type": "array", "default": [], "message": "Arguments to pass to containers (e.g. oar-hybrid)."},
                "container_client_size": {"type": "integer", "default": 1, "message": "Size of the container."},
                "job_limit": {"type": "integer", "default": 1000, "message": "Maximum number of active jobs allowed."},
                "besteffort_allocation_frequency": {"type": "integer", "default": 1, "message": "The frequency of job submission to submit to best-effort queue."},
                "timer_delay": {"type": "integer", "message": "The minimal delay between two job status updates with the same value."},
                "server_timeout": {"type": "integer", "message": "Maximum amount of seconds which defines a server timeout exit."},
                "load_from_checkpoint": {"type": "boolean", "default": False, "message": "Look for checkpoint files to start the server from."},
                "verbosity": {"type": "integer", "default": 0, "message": "Set the logger verbosity. 3 includes all levels (including info, error, warning, and debug), 0 reduces to logging to minimum (error only)."},
            },
        },
    },
}

109 

110 

def _extend_validator(validator_class):
    """
    Build a validator class that also fills in schema defaults.

    The returned class behaves like ``validator_class`` except for its
    ``properties`` keyword: before running the stock ``properties``
    validation it inserts, into the instance being validated, the
    ``default`` of every listed property that the instance does not contain.
    Validation therefore mutates the validated instance in place.

    :param validator_class: jsonschema validator class to extend; its stock
        ``properties`` implementation is looked up in ``VALIDATORS``.
    :return: the class created by ``validators.extend`` with the
        default-filling ``properties`` keyword installed.
    """
    import copy  # local import: only needed by this helper

    validate_properties = validator_class.VALIDATORS['properties']

    def set_defaults(validator, properties, instance, schema):
        # Defaults can only be inserted into JSON objects. For any other
        # instance type skip the insertion so that the mismatch is reported
        # as a validation error rather than an AttributeError raised here.
        if validator.is_type(instance, 'object'):
            for prop, subschema in properties.items():
                if 'default' in subschema and prop not in instance:
                    # Copy the default: the schema holds mutable defaults
                    # (lists, dicts) and handing out the schema's own object
                    # would let changes to one validated config leak into
                    # the schema and into every config validated later.
                    instance[prop] = copy.deepcopy(subschema['default'])

        # Delegate the actual validation to the stock implementation.
        yield from validate_properties(
            validator, properties, instance, schema,
        )

    return validators.extend(
        validator_class, {'properties': set_defaults}
    )

130 

131 

def validate_config(args: argparse.Namespace,
                    config: Dict[str, Any]) -> Tuple[argparse.Namespace,
                                                     Dict[str, Any]]:
    """
    Validate ``config`` against ``CONF_SCHEMA``.

    The validator used is the Draft 4 validator extended by
    ``_extend_validator``, so ``config`` is modified in place while it is
    being validated.

    A ``ValidationError`` is not propagated: it is reported through a
    critical log message and the function returns normally.

    :param args: command-line namespace; returned unchanged.
    :param config: configuration dictionary to validate.
    :return: the ``(args, config)`` pair, whether or not validation succeeded.
    """
    validator_cls = _extend_validator(Draft4Validator)

    try:
        checker = validator_cls(CONF_SCHEMA)
        checker.validate(config)
    except ValidationError as err:
        # Report only; the caller still receives the configuration.
        logger.critical(f"Invalid configuration. Reason: {err}")

    return args, config

145 

146 

def print_options():
    """
    Print every option described by ``CONF_SCHEMA`` and exit the process.

    Each option is shown with its ``message``, ``type`` and ``default``
    (``N/A`` when the type or default is absent). Options nested under a
    ``properties`` key are printed recursively, indented by 8 more spaces
    and prefixed with ``-`` instead of ``*``. The function ends by calling
    ``sys.exit()`` and therefore does not return.
    """
    def print_dict(schema, indent=0, prefix="*"):
        pad = ' ' * indent
        for key, value in schema.items():
            # Entries that are not dictionaries are printed verbatim.
            if not isinstance(value, dict):
                print(f"{pad}{TextColor.RED}{prefix} {key}{TextColor.ENDC}: {value}")
                continue

            message = value.get("message", "")
            type_info = value.get("type", "N/A")
            default = value.get("default", "N/A")
            print(
                f"{pad}{TextColor.RED}{prefix} {key}{TextColor.ENDC}: {message} "
                f"{TextColor.GREEN}Type: {TextColor.UNDERLINE}{type_info}{TextColor.ENDC}, "
                f"{TextColor.BLUE}Default: {default}{TextColor.ENDC}"
            )

            # Descend into nested option groups.
            nested = value.get("properties") if "properties" in value else None
            if nested is not None:
                print_dict(nested, indent + 8, "-")

    print(f"{TextColor.HEADER}{TextColor.BOLD}Configuration Options:{TextColor.ENDC}", end="\n\n")
    print_dict(CONF_SCHEMA["properties"])
    sys.exit()
166 sys.exit()