Coverage for melissa/server/main.py: 53%

60 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1""" 

2Main Melissa script 

3""" 

4import argparse 

5import importlib.util 

6import sys 

7import rapidjson 

8import logging 

9import os 

10from pathlib import Path 

11from typing import Any, Dict 

12from melissa.launcher.__main__ import validate_config, CONFIG_PARSE_MODE 

13 

14from melissa.utility.logger import configure_logger, get_log_level_from_verbosity 

15from melissa.utility.networking import get_rank_and_num_server_proc 

16 

logger = logging.getLogger(__name__)

# Determine this server process's rank and the total number of server
# processes. SLURM exposes both through environment variables; if either one
# is missing, ask MPI for the values instead.
try:
    rank, nb_proc_server = (
        int(os.environ["SLURM_PROCID"]),
        int(os.environ["SLURM_NTASKS"]),
    )
except KeyError:
    from mpi4py import MPI
    rank, nb_proc_server = MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size()

26 

27 

def main() -> None:
    """
    Initiate the Melissa server.

    !!!
    This should not be called directly by the user, it is
    called by the `melissa-launcher`.
    !!!

    1. Parse the CLI flags (``--project_dir`` and optional ``--config_name``).
    2. Load ``<project_dir>/<config_name>.json`` and validate it.
    3. Configure the per-rank server log file.
    4. Resolve the custom server class and instantiate it with the config.
    5. Initialize the server connections and call ``server.start()``.

    The server.sh file should be calling this function, e.g.:
        exec python -m melissa
            --project_dir /path/to/project
            --config_name config

    ``--config_name`` is a file name without the ``.json`` extension, looked
    up inside ``--project_dir``; it defaults to ``config``.
    """
    parser = argparse.ArgumentParser(prog="melissa", description="Melissa")

    # CLI flags
    parser.add_argument(
        "--project_dir",
        # The config and server files are both resolved relative to this
        # directory, so nothing below can work without it.
        required=True,
        help="Directory to all necessary files:\n"
        " client.sh\n"
        " server.sh\n"
        " config.json\n"
        " data_generator\n"
        " CustomServer.py",
    )

    parser.add_argument(
        "--config_name",
        help="Defaults to `config` but user can change "
        "the search name by indicating it with this flag",
        default=None,
    )

    args = parser.parse_args()
    # `--config_name` defaults to None (or may be empty); fall back to `config`.
    conf_name = args.config_name or "config"

    # load the config into a python dictionary
    config_path = Path(args.project_dir) / f"{conf_name}.json"
    with open(config_path, encoding="utf-8") as json_file:
        config_dict = rapidjson.load(json_file, parse_mode=CONFIG_PARSE_MODE)

    config_dict['user_data_dir'] = Path('user_data')

    # ensure user passed the necessary information for software configuration
    args, config_dict = validate_config(args, config_dict)

    rank, _ = get_rank_and_num_server_proc()

    # set server log level; one log file per server rank, with a suffix when
    # MELISSA_RESTART reports a non-zero restart count
    sconfig = config_dict['study_options']  # this will become server_config
    log_level = get_log_level_from_verbosity(sconfig.get("verbosity", 3))
    restart_count = int(os.environ.get("MELISSA_RESTART", 0))
    app_str = f"_restart_{restart_count}" if restart_count else ""
    configure_logger(f"melissa_server_{rank}{app_str}.log", log_level)

    # Resolve the server
    myserver = get_resolved_server(args, config_dict)
    try:
        myserver.initialize_connections()
        myserver.start()
    except Exception as msg:
        # Top-level boundary: record the full traceback, then close the
        # server's connections. The `1` is presumably an error/exit code —
        # NOTE(review): confirm against close_connection.
        logger.exception("Server failed with msg %s.", msg)
        myserver.close_connection(1)

105 

106 

def get_resolved_server(args: argparse.Namespace, config_dict: Dict[str, Any]) -> Any:
    """
    Import the user's server module from the project directory and
    instantiate its server class.

    The module is loaded from ``<args.project_dir>/<server_filename>``
    (``server_filename`` defaults to ``"server.py"``). The class named by
    ``server_class`` (default ``"MyServer"``) is looked up in that module and
    called with ``config_dict`` as its only argument.

    Args:
        args: Parsed CLI namespace; only ``project_dir`` is read.
        config_dict: Study configuration; the optional keys
            ``server_filename`` and ``server_class`` are read.

    Returns:
        The instantiated server object.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
        AttributeError: If the module does not define the requested class.
    """
    server_file_name = config_dict.get("server_filename", "server.py")
    server_path = Path(args.project_dir) / server_file_name
    server_module_name = server_path.stem
    spec_server = importlib.util.spec_from_file_location(server_module_name, server_path)
    if spec_server is None or spec_server.loader is None:
        # Without a spec/loader there is no server to return, so stop here
        # with an explicit error rather than an unbound-variable failure.
        logger.error('Unable to import server')
        raise ImportError(f"Unable to import server from {server_path}")

    # Let the user's server module import sibling files from its own
    # directory; only add the entry once so repeated calls don't grow sys.path.
    server_dir = str(server_path.parent)
    if server_dir not in sys.path:
        sys.path.append(server_dir)
    server = importlib.util.module_from_spec(spec_server)
    spec_server.loader.exec_module(server)

    server_class_name = config_dict.get("server_class", "MyServer")
    MyServerClass = getattr(server, server_class_name)
    return MyServerClass(config_dict)