Coverage for melissa/server/fault_tolerance.py: 39%
46 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1from typing import Dict, List
2import time
3import logging
4from melissa.server.simulation import Group
logger = logging.getLogger(__name__)


class FaultTolerance:
    """Keep track of failed or timed-out simulation groups.

    Attributes:
        ft_off: fault-tolerance switch; when true, the first recorded
            failure of a group raises instead of being tolerated.
        max_delay: simulation walltime. A simulation that is not finished
            and whose ``last_message`` is older than this is considered
            timed-out (assumes ``last_message`` is a ``time.time()``
            timestamp -- TODO confirm against the simulation class).
        crashes_before_redraw: number of accepted failures before input
            resampling.
        restart_grp: dictionary indicating which groups need relaunching,
            with (True) or without (False) resampling.
        failed_id: list of the unique ids of the groups that failed.
        nb_group: total number of groups.
    """

    def __init__(self, ft_off: bool, max_delay: float, crashes_before_redraw: int, nb_group: int):
        """Store the configuration and start with no recorded failure."""
        self.ft_off = ft_off
        self.max_delay = max_delay
        self.crashes_before_redraw = crashes_before_redraw
        self.restart_grp: Dict[int, bool] = {}
        self.failed_id: List[int] = []
        self.nb_group: int = nb_group

    def checkpointing(self) -> None:
        """Placeholder, currently does nothing.

        This function needs to be designed with the following features:
        - save NN model [DL] or Stats [SA]
        - save buffer [DL]
        - save simulations dictionary [DL, SA]
        """

    def restart(self) -> None:
        """Placeholder, currently does nothing.

        This function needs to be designed with the following features:
        - import NN model [DL] or Stats [SA]
        - import buffer [DL]
        - import simulations dictionary [DL, SA]
        """

    def handle_failed_group(self, group_id: int, group: "Group") -> bool:
        """React to a failed simulation group.

        Increments ``group.n_failures``, logs a warning and records
        ``group_id`` among the failed ids.

        Returns:
            True when the group has now failed more than
            ``crashes_before_redraw`` times, False otherwise.

        Raises:
            Exception: propagated from ``append_failed_id``.
        """
        group.n_failures += 1
        too_many_failures = group.n_failures > self.crashes_before_redraw

        if too_many_failures:
            logger.warning(f"Group with id {group_id} failed too many times")
        else:
            logger.warning(f"Group with id {group_id} failed {group.n_failures} times")

        self.append_failed_id(group_id)
        return too_many_failures

    def _has_timed_out(self, sim, now: float) -> bool:
        """Tell whether an unfinished simulation has been silent longer than max_delay."""
        return (
            sim.last_message is not None
            and not sim.finished()
            and now - sim.last_message > self.max_delay
        )

    def check_time_out(self, groups: Dict[int, "Group"]) -> bool:
        """Detect timed-out simulations and schedule their groups for relaunch.

        For every group holding at least one timed-out simulation:
        - increment the group's number of failures,
        - reset ``last_message`` of all its simulations to None so the same
          time-out is not detected again on the next call,
        - store in ``restart_grp`` whether the group exceeded
          ``crashes_before_redraw``,
        - record the group among the failed ids.

        Returns:
            True when ``restart_grp`` is not empty once the pass is done.
            Entries left over from previous calls count as well.

        Raises:
            Exception: propagated from ``append_failed_id``.
        """
        # One timestamp for the whole pass: every simulation is judged
        # against the same instant.
        now = time.time()
        timed_out_ids: List[int] = [
            grp_id
            for grp_id, grp in groups.items()
            if any(self._has_timed_out(sim, now) for sim in grp.simulations.values())
        ]

        for grp_id in timed_out_ids:
            grp = groups[grp_id]
            logger.warning(f"Simulation(s) in group of id {grp_id} timed-out")
            grp.n_failures += 1
            for sim in grp.simulations.values():
                sim.last_message = None
            self.restart_grp[grp_id] = grp.n_failures > self.crashes_before_redraw
            self.append_failed_id(grp_id)

        return bool(self.restart_grp)

    def append_failed_id(self, group_id: int) -> None:
        """Record ``group_id`` among the unique failed ids.

        Raises:
            Exception: when a new id is recorded while fault-tolerance is
                off, or when the number of recorded ids reaches
                ``nb_group``.
        """
        if group_id in self.failed_id:
            return

        # NOTE(review): the abort checks below are assumed to apply only to
        # a newly recorded id -- confirm the intended nesting.
        self.failed_id.append(group_id)
        if self.ft_off:
            raise Exception(
                f"Fault-Tolerance is off, group {group_id} "
                "failure will cause the server to abort"
            )
        if len(self.failed_id) == self.nb_group:
            raise Exception(
                "All groups failed please make sure: \n"
                "- the path to the executable is correct, \n"
                "- the number of expected time steps is correct, \n"
                "- the simulation walltime was well estimated."
            )