Coverage for melissa/server/main.py: 53%
60 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1"""
2Main Melissa script
3"""
4import argparse
5import importlib.util
6import sys
7import rapidjson
8import logging
9import os
10from pathlib import Path
11from typing import Any, Dict
12from melissa.launcher.__main__ import validate_config, CONFIG_PARSE_MODE
14from melissa.utility.logger import configure_logger, get_log_level_from_verbosity
15from melissa.utility.networking import get_rank_and_num_server_proc
logger = logging.getLogger(__name__)

# Rank of this process and number of server processes, resolved once at import
# time: read from SLURM's task variables when both are set, otherwise taken
# from the MPI world communicator (mpi4py is only imported in that fallback).
try:
    rank, nb_proc_server = (
        int(os.environ[env_var]) for env_var in ("SLURM_PROCID", "SLURM_NTASKS")
    )
except KeyError:
    # At least one SLURM variable is missing: ask MPI for the layout instead.
    from mpi4py import MPI

    rank, nb_proc_server = MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size()
def _parse_args() -> argparse.Namespace:
    """Parse the server's command-line flags (``--project_dir``, ``--config_name``)."""
    parser = argparse.ArgumentParser(
        prog="melissa", description="Melissa"
    )

    # CLI flags
    parser.add_argument(
        "--project_dir",
        # Both the config file and the server module are resolved relative to
        # this directory, so the server cannot start without it.
        required=True,
        # NOTE: argparse's default HelpFormatter collapses these newlines.
        help="Directory to all necessary files:\n"
        " client.sh\n"
        " server.sh\n"
        " config.py\n"
        " data_generator\n"
        " CustomServer.py"
    )

    parser.add_argument(
        "--config_name",
        help="Defaults to `config` but user can change "
        "the search name by indicating it with this flag",
        default=None
    )

    return parser.parse_args()


def _load_config(args: argparse.Namespace) -> Dict[str, Any]:
    """
    Load ``<project_dir>/<config_name>.json`` into a dictionary.

    ``config_name`` falls back to ``config`` when the flag is absent or empty.
    The key ``user_data_dir`` is set to ``Path('user_data')`` in the result.
    """
    conf_name = args.config_name or "config"
    config_path = Path(args.project_dir) / f"{conf_name}.json"
    # JSON text is UTF-8 by specification; do not depend on the host locale.
    with open(config_path, encoding="utf-8") as json_file:
        config_dict = rapidjson.load(json_file, parse_mode=CONFIG_PARSE_MODE)

    config_dict['user_data_dir'] = Path('user_data')
    return config_dict


def _configure_server_logging(config_dict: Dict[str, Any]) -> None:
    """
    Route this rank's log output to ``melissa_server_<rank>[_restart_<n>].log``.

    The level comes from ``study_options["verbosity"]`` (default 3) and the
    restart suffix from the ``MELISSA_RESTART`` environment variable (omitted
    when it is unset or 0).
    """
    server_rank, _ = get_rank_and_num_server_proc()

    sconfig = config_dict['study_options']  # this will become server_config
    log_level = get_log_level_from_verbosity(sconfig.get("verbosity", 3))
    restart_count = int(os.environ.get("MELISSA_RESTART", 0))
    app_str = f"_restart_{restart_count}" if restart_count else ""
    configure_logger(f"melissa_server_{server_rank}{app_str}.log", log_level)


def main() -> None:
    """
    This function initiates Melissa server.

    !!!
    This should not be called directly by the user, it is
    called by the `melissa-launcher`.
    !!!

    1. Parses the command-line flags (``--project_dir`` is required,
       ``--config_name`` is optional).
    2. Loads ``<project_dir>/<config_name>.json`` and validates it.
    3. Configures the per-rank server log file.
    4. Resolves the custom server class and instantiates it with the config.
    5. Calls ``server.initialize_connections()`` then ``server.start()``.

    The server.sh file is expected to call it as::

        exec python -m melissa --project_dir /path/to/project --config_name config

    where ``--config_name`` is the config file name without its ``.json``
    extension (``config`` when omitted).
    """
    args = _parse_args()
    config_dict = _load_config(args)

    # ensure user passed the necessary information for software configuration
    args, config_dict = validate_config(args, config_dict)

    # set server log level (done after validation, as in the launcher flow)
    _configure_server_logging(config_dict)

    # Resolve the server
    myserver = get_resolved_server(args, config_dict)
    try:
        myserver.initialize_connections()
        myserver.start()
    except Exception as msg:
        # Top-level boundary: record the full traceback, then hand the failure
        # to the server. NOTE(review): the argument 1 presumably marks an
        # error exit status — confirm against close_connection.
        logger.exception("Server failed with msg %s.", msg)
        myserver.close_connection(1)
def get_resolved_server(args: argparse.Namespace, config_dict: Dict[str, Any]):
    """
    Import the user's server module and instantiate its server class.

    The module is loaded from ``<args.project_dir>/<server_filename>``, where
    ``server_filename`` is read from ``config_dict`` (default ``server.py``).
    The class named by ``config_dict["server_class"]`` (default ``MyServer``)
    is looked up in that module and called with ``config_dict``.

    The module's directory is appended to ``sys.path`` (once) so the module
    can import sibling files from the project directory.

    Returns:
        The server instance.

    Raises:
        ImportError: if no module spec/loader can be built for the file
            (for example when the file name lacks a Python source suffix).
        AttributeError: if the module does not define the configured class.
    """
    server_file_name = config_dict.get("server_filename", "server.py")
    server_path = Path(args.project_dir) / server_file_name
    server_module_name = server_path.stem
    spec_server = importlib.util.spec_from_file_location(server_module_name, server_path)
    if spec_server is None or spec_server.loader is None:
        # Nothing to instantiate: fail with a clear message instead of falling
        # through to an unbound return value.
        raise ImportError(f"Unable to import server from {server_path}")

    server_dir = str(server_path.parent)
    if server_dir not in sys.path:
        sys.path.append(server_dir)

    server_module = importlib.util.module_from_spec(spec_server)
    spec_server.loader.exec_module(server_module)

    server_class_name = config_dict.get("server_class", "MyServer")
    try:
        server_class = getattr(server_module, server_class_name)
    except AttributeError as err:
        raise AttributeError(
            f"Server class '{server_class_name}' not found in {server_path}"
        ) from err
    return server_class(config_dict)