Coverage for melissa/server/fault_tolerance.py: 39%

46 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1from typing import Dict, List 

2import time 

3import logging 

4from melissa.server.simulation import Group 

5 

# Module-level logger; FaultTolerance uses it to report group failures and time-outs.
logger = logging.getLogger(__name__)

7 

8 

class FaultTolerance:
    """Bookkeeping for failed or timed-out simulation groups.

    This class counts failures on the ``Group`` objects it is given. It records
    the ids of groups that failed and fills ``restart_grp`` with the groups to
    relaunch. It raises when fault tolerance is off or when every group has
    failed. Relaunching the groups is left to the caller.
    """

    def __init__(self, ft_off: bool, max_delay: float, crashes_before_redraw: int, nb_group: int):
        """
        Fault-Tolerance object has the following attributes:
        - ft_off: Fault-Tolerance switch (when True, the first group failure
          raises and aborts the server)
        - max_delay: simulation walltime; compared with the time elapsed
          (per ``time.time()``) since a simulation's last message
        - crashes_before_redraw: number of accepted failures before input
          resampling
        - restart_grp: dictionary indicating which groups need relaunching
          with or without resampling (True: failures exceeded
          crashes_before_redraw, so inputs should be redrawn)
        - failed_id: unique ids list of failed groups
        - nb_group: total number of groups
        """
        self.ft_off = ft_off
        self.max_delay = max_delay
        self.crashes_before_redraw = crashes_before_redraw
        self.restart_grp: Dict[int, bool] = {}
        self.failed_id: List[int] = []
        self.nb_group: int = nb_group

    def checkpointing(self) -> None:
        """
        Not implemented yet: currently a no-op.

        This function needs to be designed with the following features:
        - save NN model [DL] or Stats [SA]
        - save buffer [DL]
        - save simulations dictionary [DL, SA]
        """

    def restart(self) -> None:
        """
        Not implemented yet: currently a no-op.

        This function needs to be designed with the following features:
        - import NN model [DL] or Stats [SA]
        - import buffer [DL]
        - import simulations dictionary [DL, SA]
        """

    def handle_failed_group(self, group_id: int, group: Group) -> bool:
        """
        React to a failed group of simulations.

        Increments ``group.n_failures``, logs the failure and records
        ``group_id`` as failed.

        - group_id: id of the failed group
        - group: the failed group; its failure counter is incremented
        Returns True if the group has now failed more than
        ``crashes_before_redraw`` times (its inputs should be redrawn),
        False otherwise.
        Raises Exception (via ``append_failed_id``) if fault tolerance is off
        or if every group has now failed.
        """
        group.n_failures += 1
        redraw = group.n_failures > self.crashes_before_redraw

        if redraw:
            logger.warning("Group with id %s failed too many times", group_id)
        else:
            logger.warning("Group with id %s failed %s times", group_id, group.n_failures)

        self.append_failed_id(group_id)
        return redraw

    def check_time_out(self, groups: Dict[int, Group]) -> bool:
        """
        Detect groups containing a timed-out simulation and schedule them
        for relaunch.

        A simulation is timed-out when it has sent at least one message,
        is not finished, and its last message is older than ``max_delay``.
        For every group holding such a simulation this function:
        - increments the group's number of failures
        - resets ``last_message`` of all its simulations
        - records in ``restart_grp`` whether it must be relaunched with
          resampling (failures exceeded ``crashes_before_redraw``)
        - tries to update the list of unique failed ids (may raise, see
          ``append_failed_id``)
        Returns True if ``restart_grp`` holds at least one group to relaunch.
        """
        # One timestamp per check keeps all comparisons consistent and avoids
        # a system call per simulation.
        now = time.time()
        timed_out_ids: List[int] = [
            grp_id
            for grp_id, grp in groups.items()
            if any(
                sim.last_message is not None
                and not sim.finished()
                and now - sim.last_message > self.max_delay
                for sim in grp.simulations.values()
            )
        ]

        for grp_id in timed_out_ids:
            grp = groups[grp_id]
            logger.warning("Simulation(s) in group of id %s timed-out", grp_id)
            grp.n_failures += 1
            # A None timestamp excludes these simulations from the time-out
            # test above until they send a new message.
            for sim in grp.simulations.values():
                sim.last_message = None
            self.restart_grp[grp_id] = grp.n_failures > self.crashes_before_redraw
            self.append_failed_id(grp_id)

        # NOTE(review): this reflects every entry in restart_grp, not only the
        # groups that timed out during this call; entries persist here unless
        # the caller clears restart_grp after relaunching. Confirm with caller.
        return len(self.restart_grp) > 0

    def append_failed_id(self, group_id: int) -> None:
        """
        Record ``group_id`` as failed (only once) and raise Exception if
        failures can no longer be tolerated.

        Raises Exception when a new group id is recorded and either:
        - fault tolerance is off (``ft_off``), or
        - every one of the ``nb_group`` groups has failed.
        Already-recorded ids are ignored and never raise.
        """
        if group_id in self.failed_id:
            return

        self.failed_id.append(group_id)
        if self.ft_off:
            raise Exception(
                f"Fault-Tolerance is off, group {group_id} "
                "failure will cause the server to abort"
            )
        # `>=` rather than `==`, so the abort cannot be skipped if more
        # distinct ids than nb_group are ever recorded.
        if len(self.failed_id) >= self.nb_group:
            raise Exception(
                "All groups failed please make sure: \n"
                "- the path to the executable is correct, \n"
                "- the number of expected time steps is correct, \n"
                "- the simulation walltime was well estimated."
            )