The Higher Education and Research forge

Home My Page Projects Code Snippets Project Openings EMULSION public releases
Summary Activity Surveys SCM Listes Sympa

SCM Repository

1 """
2 .. module:: emulsion.model.state_machines
4 .. moduleauthor:: Sébastien Picault <sebastien.picault@inra.fr>
6 """
9 # EMULSION (Epidemiological Multi-Level Simulation framework)
10 # ===========================================================
11
12 # Contributors and contact:
13 # -------------------------
14
15 #     - Sébastien Picault (sebastien.picault@inra.fr)
16 #     - Yu-Lin Huang
17 #     - Vianney Sicard
18 #     - Sandie Arnoux
19 #     - Gaël Beaunée
20 #     - Pauline Ezanno (pauline.ezanno@inra.fr)
21
22 #     BIOEPAR, INRA, Oniris, Atlanpole La Chantrerie,
23 #     Nantes CS 44307 CEDEX, France
24
25
26 # How to cite:
27 # ------------
28
29 #     S. Picault, Y.-L. Huang, V. Sicard, P. Ezanno (2017). "Enhancing
30 #     Sustainability of Complex Epidemiological Models through a Generic
31 #     Multilevel Agent-based Approach", in: C. Sierra (ed.), 26th
32 #     International Joint Conference on Artificial Intelligence (IJCAI),
33 #     AAAI, p. 374-380. DOI: 10.24963/ijcai.2017/53
34
35
36 # License:
37 # --------
38
39 #    Copyright 2016 INRA and Univ. Lille
40
41 #    Inter Deposit Digital Number: IDDN.FR.001.280043.000.R.P.2018.000.10000
42
43 #    Agence pour la Protection des Programmes,
44 #    54 rue de Paradis, 75010 Paris, France
45
46 #    Licensed under the Apache License, Version 2.0 (the "License");
47 #    you may not use this file except in compliance with the License.
48 #    You may obtain a copy of the License at
49
50 #        http://www.apache.org/licenses/LICENSE-2.0
51
52 #    Unless required by applicable law or agreed to in writing, software
53 #    distributed under the License is distributed on an "AS IS" BASIS,
54 #    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
55 #    See the License for the specific language governing permissions and
56 #    limitations under the License.
59 from   functools               import partial
61 import numpy                   as     np
62 from   sympy                   import sympify
64 from   sortedcontainers        import SortedSet
66 import emulsion.tools.graph    as     enx
67 from   emulsion.agent.action   import AbstractAction
68 from   emulsion.tools.state    import StateVarDict, EmulsionEnum
70 from   emulsion.model.functions import ACTION_SYMBOL, WHEN_SYMBOL, ESCAPE_SYMBOL,\
71     COND_SYMBOL, CROSS_SYMBOL, EDGE_KEYWORDS, CLOCK_SYMBOL,\
72     make_when_condition, make_duration_condition, make_duration_init_action
74 from   emulsion.model.exceptions     import SemanticException
77 #   _____ _        _       __  __            _     _
78 #  / ____| |      | |     |  \/  |          | |   (_)
79 # | (___ | |_ __ _| |_ ___| \  / | __ _  ___| |__  _ _ __   ___
80 #  \___ \| __/ _` | __/ _ \ |\/| |/ _` |/ __| '_ \| | '_ \ / _ \
81 #  ____) | || (_| | ||  __/ |  | | (_| | (__| | | | | | | |  __/
82 # |_____/ \__\__,_|\__\___|_|  |_|\__,_|\___|_| |_|_|_| |_|\___|
86 class StateMachine(object):
87     """Class in charge of the description of biological or economical
88     processes, modeled as Finite State Machines. The formalism
89     implemented here is based on UML state machine diagrams, with
90     adaptations to biology.
92     """
93     def __init__(self, machine_name, description, model):
94         """Build a State Machine within the specified model, based on
95         the specified description (dictionary).
97         """
98         self.model = model
99         self.machine_name = machine_name
100         self.parse(description)
102     def _reset_all(self):
103         self._statedesc = {}
104         self._description = {}
105         self.states = None
106         self.graph = enx.MultiDiGraph()
107         self.stateprops = StateVarDict()
108         self.state_actions = {}
109 #        self.edge_actions = {}
111     def parse(self, description):
112         """Build the State Machine from the specified dictionary
113         (expected to come from a YAML configuration file).
115         """
116         self._reset_all()
117         # keep an exhaustive description
118         self._description = description
119         # build the enumeration of the states
120         self.build_states()
121         # build the graph based on the states and the transitions between them
122         self.build_graph()
123         # build actions associated with the state machine (states or edges)
124         self.build_actions()
126     def get_property(self, state_name, property_name):
127         """Return the property associated to the specified state."""
128         if state_name not in self.stateprops or\
129            property_name not in self.stateprops[state_name]:
130             return self.graph.node[state_name][property_name]\
131                 if property_name in self.graph.node[state_name]\
132                    else None
133         return self.stateprops[state_name][property_name]
135     def build_states(self):
136         """Parse the description of the state machine and extract the
137         existing states. States are described as list items, endowed
138         with key-value properties. It is recommended to define only
139         one state per list item (especially to ensure that states are
140         always stored in the same order in all executions).
142         Example of YAML specification:
143         ------------------------------
144         states:
145           - S:
146               name: Susceptible
147               desc: 'non-shedder cows without antibodies'
148           - I+:
149               name: Infectious plus
150               desc: 'shedder cows with antibodies'
151               fillcolor: orange
152               on_stay:
153                 - increase: total_E
154                   rate: Q1
156         """
157         states = []
158         default_state = None
159         # retrieve information for each state
160         for statedict in self._description['states']:
161             for name, value in statedict.items():
162                 states.append(name)
163                 # provide a default fillcolor
164                 if 'fillcolor' not in value:
165                     value['fillcolor'] = 'lightgray'
166                 # if properties are provided, add the corresponding
167                 # expression to the model
168                 if 'properties' not in value:
169                     value['properties'] = {}
170                 # store special property: "autoremove: yes"
171                 value['properties']['autoremove'] = value['autoremove']\
172                                                     if 'autoremove' in value else False
173                 # store special property: "default: yes"
174                 # if several states are marked "default", take the first one
175                 value['properties']['default'] = False
176                 if ('default' in value) and (value['default']) and (default_state is None):
177                     value['properties']['default'] = True
178                     default_state = name
179                 self.stateprops[name] = {k: self.model.add_expression(v)
180                                          for k, v in value['properties'].items()}
181                 # store other information
182                 self._statedesc[name] = value
183                 # and retrieve available actions if any
184                 for keyword in ['on_enter', 'on_stay', 'on_exit']:
185                     if keyword in value:
186                         self._add_state_actions(name, keyword, value[keyword])
187         # build the enumeration of the states
188         self.states = EmulsionEnum(self.machine_name.capitalize(),
189                                    states, module=__name__)
190         self.states.state_machine = self
191         self.states.autoremove = False
193         for state in self.states:
194             if state.name in self.model.states:
195                 other_machine = self.model.states[state.name].__class__.__name__
196                 raise SemanticException(
197                     'Conflict: State %s found in statemachines %s and %s' %
198                     (state.name, other_machine, state.__class__.__name__))
199             if state.name in self.model.parameters:
200                 raise SemanticException(
201                     'Conflict: State %s of statemachines %s found in parameters'
202                     % (state.name, state.__class__.__name__))
203             self.model.states[state.name] = state
204             if self.stateprops[state.name]['autoremove']:
205                 state.autoremove = True
207         self.states.is_default = False
208         if default_state is not None:
209             self.states[default_state].is_default = True
210             self.states.default = self.states[default_state]
211             self.states.available = (self.states.default,)
212         else:
213             self.states.default = None
214             self.states.available = tuple(s for s in self.states if not s.autoremove)
215         # print(self.states, self.states.default, self.states.available)
217         self.model._values['_random_' + self.machine_name] = self.get_random_state
218         if self.states.default:
219             self.model._values['_default_' + self.machine_name] = self.get_default_state
220         else:
221             self.model._values['_default_' + self.machine_name] = self.get_random_state
224     def get_random_state(self, caller=None):
225         """Return a random state for this state machine."""
226         return np.random.choice([state for state in self.states if not state.autoremove])
    def get_default_state(self, caller=None):
        """Return the default state of this state machine, i.e. the
        value stored in `self.states.default`. The `caller` argument
        is not used here.

        """
        return self.states.default
232     @property
233     def state_colors(self):
234         """Return a dictionary of state names associated with fill colors."""
235         return {state.name: self._statedesc[state.name]['fillcolor']
236                 for state in self.states
237                 if not state.autoremove}
240     def build_graph(self):
241         """Parse the description of the state machine and extract the
242         graph of the transitions between the states. Since a
243         MultiDiGraph is used, each pair of nodes can be bound by
244         several transitions if needed (beware the associated
245         semantics).
247         Example of YAML specification:
248         ------------------------------
249         transitions:
250           - {from: S, to: I-, proba: p, cond: not_vaccinated}
251           - {from: I-, to: S, proba: m}
252           - {from: I-, to: I+m, proba: 'q*plp'}
253           - {from: I-, to: I+, proba: 'q*(1-plp)'}
255         """
256         # add a node for each state
257         for state in self.states:
258             name = state.name
259             self._statedesc[name]['tooltip'] = self.describe_state(name)
260             self.graph.add_node(name, **self._statedesc[name])
261         # build edges between states according to specified transitions
262         if 'transitions' in self._description:
263             self._parse_edges(self._description['transitions'],
264                               type_id=enx.EdgeTypes.TRANSITION)
265         if 'productions' in self._description:
266             self._parse_edges(self._description['productions'],
267                               type_id=enx.EdgeTypes.PRODUCTION)
269     def _parse_edges(self, edges, type_id=enx.EdgeTypes.TRANSITION):
270         """Parse the description of edges, with the difference
271         transitions/productions
273         """
274         for edge in edges:
275             from_ = edge['from']
276             to_ = edge['to']
277             others = {k: v for (k, v) in edge.items()
278                       if k != 'from' and k != 'to'}
279             for kwd in EDGE_KEYWORDS:
280                 if kwd in others:
281                     # parm = pretty(sympify(others[kwd], locals=self.model._namespace))
282                     parm = others[kwd]
283                     label = '{}: {}'.format(kwd, parm)
284             # label = ', '.join([pretty(sympify(x, locals=self.model._namespace))
285             #                    for x in others.values()])
286                     if str(parm) in self.model.parameters:
287                         others['labeltooltip'] = self.model.describe_parameter(parm)
288                     else:
289                         others['labeltooltip'] = label
290             # others['labeltooltip'] = ', '.join([self.model.describe_parameter(x)
291             #                                     for x in others.values()
292             #                                     if str(x) in self.model.parameters])
293             # handle conditions if any on the edge
294             cond, escape = None, False
295             if 'cond' in others:
296                 cond = others['cond']
297                 others['truecond'] = others['cond']
298             if ('escape' in others) and (type_id == enx.EdgeTypes.TRANSITION):
299                 cond = others['escape']
300                 escape = True
301             if cond is not None:
302                 ### WARNING the operation below is not completely
303                 ### safe... it is done to replace conditions of the form
304                 ### 'x == y' by 'Eq(x, y)', but it is a simple
305                 ### substitution instead of parsing the syntax
306                 ### tree... Thus it is *highly* recommended to express
307                 ### conditions directly with Eq(x, y)
308                 if '==' in str(cond):
309                     cond = 'Eq({})'.format(','.join(cond.split('==')))
310                     # others['label'] = ', '.join(others.values())
311             # if duration specified for this state, handle it as an
312             # additional condition
313             if ('duration' in self._statedesc[from_]) and (type_id == enx.EdgeTypes.TRANSITION):
314                 duration_cond = make_duration_condition(self.model, self.machine_name)
315                 if cond is None:
316                     cond = duration_cond
317                 elif escape:
318                     cond = 'AND(Not({}),{})'.format(duration_cond, cond)
319                 else:
320                     cond = 'AND({},{})'.format(duration_cond, cond)
321                     # print(cond)
322                 others['cond'] = cond
323             if cond is not None:
324                 ## DEBUG:                print(cond, self.model._namespace)
325                 self.model.conditions[cond] = sympify(cond,
326                                                       locals=self.model._namespace)
327             # handle 'when' clause if any on the edge
328             self._parse_when(others)
329             # handle 'duration', 'escape' and 'condition' clauses if
330             # any on the edge
331             if type_id == enx.EdgeTypes.TRANSITION:
332                 self._parse_conditions_durations(from_, others)
333             # parse actions on cross if any
334             if ('on_cross' in others) and (type_id == enx.EdgeTypes.TRANSITION):
335                 l_actions = self._parse_action_list(others['on_cross'])
336                 others['actions'] = l_actions
337             others['label'] = label
338             others['type_id'] = type_id
339             self.graph.add_edge(from_, to_, **others)
340             # register rate/proba/amount expressions in the model
341             for keyword in EDGE_KEYWORDS:
342                 if keyword in others:
343                     self.model.add_expression(others[keyword])
346     def _parse_when(self, edge_desc):
347         """Parse the edge description in search for a 'when'
348         clause. This special condition is aimed at globally assessing
349         a time period within the whole simulation.
351         """
352         if 'when' in edge_desc:
353             expression = sympify(edge_desc['when'],
354                                  locals=self.model._event_namespace)
355             edge_desc['when'] = str(expression)
356             self.model._values[str(expression)] = make_when_condition(
357                 expression, modules=self.model.modules)
359     def _parse_conditions_durations(self, from_, edge_desc):
360         """Parse the edge description in search for durations,
361         escapement and conditions specifications. Durations
362         ('duration' clause )are handled as an additional condition
363         (agents entering the state are given a 'time to live' in the
364         state, then they are not allowed to leave the state until
365         their stay reaches that value). Escapements ('escape' clause)
366         are also translated as a condition, allowing the agent to
367         leave the state when the expression is true, only while the
368         stay duration is below its nominal value.
370         """
371         cond, escape = None, False
372         if 'cond' in edge_desc:
373             cond = edge_desc['cond']
374         if 'escape' in edge_desc:
375             cond = edge_desc['escape']
376             escape = True
377         if cond is not None:
378             ### WARNING the operation below is not completely
379             ### safe... it is done to replace conditions of the form
380             ### 'x == y' by 'Eq(x, y)', but it is a simple
381             ### substitution instead of parsing the syntax
382             ### tree... Thus it is *highly* recommended to express
383             ### conditions directly with Eq(x, y)
384             if '==' in str(cond):
385                 cond = 'Eq({})'.format(','.join(cond.split('==')))
386                 # edge_desc['label'] = ', '.join(edge_desc.values()) if
387         # duration specified for this state, handle it as an
388         # additional condition
389         if 'duration' in self._statedesc[from_]:
390             duration_cond = make_duration_condition(self.model, self.machine_name)
391             if cond is None:
392                 cond = duration_cond
393             elif escape:
394                 cond = 'AND(Not({}),{})'.format(duration_cond, cond)
395             else:
396                 cond = 'AND({},{})'.format(duration_cond, cond)
397             edge_desc['cond'] = cond
398         if cond is not None:
399             self.model.conditions[cond] = sympify(cond, locals=self.model._namespace)
401     def build_actions(self):
402         """Parse the description of the state machine and extract the
403         actions that agents running this state machine must have.
405         Example of YAML specification:
406         ------------------------------
407         actions:
408           say_hello:
409             desc: action performed when entering the S state
411         """
412         for name, value in self._statedesc.items():
413             for keyword in ['on_enter', 'on_stay', 'on_exit']:
414                 if keyword in value:
415                     self._add_state_actions(name, keyword, value[keyword])
416             if 'duration' in value:
417                 val = value['duration']
418                 self._add_state_duration_actions(name, val)
    def get_value(self, name):
        """Return the value associated with the specified name, as
        provided by the model (delegates to `self.model.get_value`).

        """
        return self.model.get_value(name)
425     def _add_state_duration_actions(self, state_name, duration_value):
426         """Add implicit actions to manage stay duration in the specified state
427         name. The `duration_value` can be either a parameter, a
428         'statevar' or a distribution.
430         """
431         # initialize the actions associated to the state if none
432         if state_name not in self.state_actions:
433             self.state_actions[state_name] = {}
434         # retrieve the list of actions on enter for this state, if any
435         lenter = self.state_actions[state_name]['on_enter']\
436                    if 'on_enter' in self.state_actions[state_name] else []
437         # build a partial function based on the current state machine name
438         enter_action = partial(make_duration_init_action,
439                                machine_name=self.machine_name)
440         # set the name of the action
441         enter_action.__name__ = 'init_duration'
442         # set the action parameters (the expression associated to the duration)
443         enter_params = [self.model.add_expression(duration_value)]
444         # instantiate the action
445         init_action = AbstractAction.build_action('duration',
446                                                   function=enter_action,
447                                                   l_params=enter_params,
448                                                   state_machine=self)
449         # and insert it at the beginning of the list of actions
450         lenter.insert(0, init_action)
451         self.model.add_init_action(self.machine_name,
452                                    self.states[state_name],
453                                    init_action)
454         # lstay = self.state_actions[name]['on_stay']\
455         #           if 'on_stay' in self.state_actions[name] else []
456         # stay_action = partial(make_TTL_increase_action,
457         #                       machine_name=self.machine_name)
458         # stay_action.__name__ = '+_time_spent'
459         # lstay.insert(0, AbstractAction.build_action('duration',
460         #                                             function=stay_action,
461         #                                             state_machine=self))
462         self.state_actions[state_name]['on_enter'] = lenter
463         # self.state_actions[name]['on_stay'] = lstay
465     def _add_state_actions(self, name, event, actions):
466         """Add the specified actions for the state with the given
467         name, associated with the event (e.g. 'on_stay', 'on_enter',
468         'on_exit'). Expressions contained in the parameters lists or
469         dicts are automatically expanded.
471         """
472         if name not in self.state_actions:
473             self.state_actions[name] = {}
474         l_actions = self._parse_action_list(actions)
475         self.state_actions[name][event] = l_actions
477     def _parse_action_list(self, actions):
478         """Parse the list of actions associated with a state."""
479         l_actions = []
480         for d_action in actions:
481             if 'action' in d_action:
482                 action = d_action['action']
483                 l_params = [self.model.add_expression(expr)
484                             for expr in d_action['l_params']]\
485                                 if 'l_params' in d_action\
486                                 else []
487                 # if 'd_params' in d_action:
488                 #     print(d_action['d_params'])
489                 d_params = {key: self.model.add_expression(expr)
490                             for key, expr in d_action['d_params'].items()}\
491                                 if 'd_params' in d_action\
492                                 else {}
493                 l_actions.append(
494                     AbstractAction.build_action('action',
495                                                 method=action,
496                                                 l_params=l_params,
497                                                 d_params=d_params,
498                                                 state_machine=self))
499             else: #TODO: dispatch through AbstractAction (factory),
500                   #make subclasses responsible for parameter parsing
501                 understood = False
502                 for keyword in ['increase', 'decrease',
503                                 'increase_stoch', 'decrease_stoch']:
504                     if keyword in d_action:
505                         # assume that increase statevar with rate
506                         l_actions.append(
507                             AbstractAction.build_action(
508                                 keyword,
509                                 statevar_name=d_action[keyword],
510                                 parameter=self.model.add_expression(d_action['rate']),
511                                 delta_t=self.model.delta_t,
512                                 state_machine=self
513                             )
514                         )
515                         understood = True
516                 for keyword in ['set_var']:
517                     if keyword in d_action:
518                         # assume that increase statevar with rate
519                         l_actions.append(
520                             AbstractAction.build_action(
521                                 keyword,
522                                 statevar_name=d_action[keyword],
523                                 parameter=self.model.add_expression(d_action['value']),
524                                 model=self.model
525                             )
526                         )
527                         understood = True
528                 for keyword in ['become', 'clone', 'produce_offspring']:
529                     if keyword in d_action:
530                         amount = d_action['amount'] if 'amount' in d_action else None
531                         probas = d_action['proba'] if 'proba' in d_action else None
532                         l_actions.append(
533                             AbstractAction.build_action(
534                                 keyword,
535                                 prototypes=d_action[keyword],
536                                 amount = amount,
537                                 probas = probas,
538                                 model = self.model
539                             )
540                         )
541                         understood = True
542                 for keyword in ['message']:
543                     if keyword in d_action:
544                         l_actions.append(
545                             AbstractAction.build_action(
546                                 keyword,
547                                 parameter=d_action[keyword]
548                             )
549                         )
550                         understood = True
551                 if not understood:
552                     print('ERROR !!!!') # but there is certainly a fatal error !
553                     print(d_action)
554         return l_actions
559     #----------------------------------------------------------------
560     # Output facilities
562     def describe_state(self, name):
563         """Return the description of the state with the specified
564         name.
566         """
567         desc = self._statedesc[name]
568         return "{} ({}):\n\t{}".format(name, desc['name'], desc['desc'])
570     def write_dot(self, filename, view_actions=True):
571         """Write the graph of the current state machine in the
572         specified filename, according to the dot/graphviz format.
574         """
576         rankdir = "LR" if self.graph.edges() else "TB"
577         output = '''digraph {
578           charset="utf-8"
579         '''
580         output += '''\trankdir={};
581         '''.format(rankdir)
582         output += '''
583         \tnode[fontsize=16, fontname=Arial, shape=box, style="filled,rounded"];
584         \tedge[minlen=1.5, fontname=Times, penwidth=1.5, tailtooltip="", headtooltip=""];
586         '''
587         for state in self.states:
588             name = state.name
589             name_lab = name
590             if 'duration' in self._statedesc[name]:
591                 name_lab += '&nbsp;{}'.format(CLOCK_SYMBOL)
592             actions = 'shape="Mrecord", label="{}", '.format(name_lab)
593             nodestyle = "filled,rounded"
594             if state.is_default:
595                 nodestyle += ",bold"
596             if state.autoremove:
597                 nodestyle += ",dotted"
598             if view_actions:
599                 onenter = ACTION_SYMBOL+'|'\
600                           if 'on_enter' in self._statedesc[name] else ''
601                 onstay = '|'+ACTION_SYMBOL\
602                          if 'on_stay' in self._statedesc[name] else ''
603                 onexit = '|'+ACTION_SYMBOL\
604                          if 'on_exit' in self._statedesc[name] else ''
605                 if onenter or onstay or onexit:
606                     actions = 'shape="Mrecord", label="{%s{\ %s\ %s}%s}", ' % (
607                         onenter, name_lab, onstay, onexit)
608             output += '\t"{}" [{}tooltip="{}", fillcolor={}, style="{}"] ;\n'.format(
609                 name, actions,
610                 self._statedesc[name]['tooltip'],
611                 # '\n\tON ENTER: {}'.format(self.state_actions[name]['on_enter'])\
612                 # if onenter else '' +\
613                 # '\n\tON STAY: {}'.format(self.state_actions[name]['on_stay'])\
614                 # if onstay else '' +\
615                 # '\n\tON EXIT: {}'.format(self.state_actions[name]['on_exit'])\
616                 # if onexit else '',
617                 self._statedesc[name]['fillcolor'],
618                 nodestyle)
619         for from_, to_ in SortedSet(self.graph.edges()):
620             for desc in self.graph.edge[from_][to_].values():
621                 edgetip = ''
622                 tail = 'none'
623                 if 'when' in desc:
624                     tail += WHEN_SYMBOL
625                     edgetip += 'WHEN: {}'.format(desc['when'])
626                 if 'escape' in desc:
627                     tail += ESCAPE_SYMBOL
628                     edgetip += 'ESCAPE: {}'.format(desc['escape'])
629                 if 'truecond' in desc:
630                     tail += COND_SYMBOL
631                     edgetip += 'COND: {}'.format(desc['truecond'])
632                 head = 'normalnone'
633                 if 'on_cross' in desc:
634                     head += CROSS_SYMBOL
635                     # edgetip += 'ON CROSS: {}\\n'.format(desc['on_cross'])
636                 output += ('\t"{}" -> "{}" [label="{}", labeltooltip="{}", '
637                            'arrowtail="{}", arrowhead="{}", dir=both, '
638                            'tooltip="{}", minlen=3, style="{}"];\n').format(
639                                from_, to_, desc['label'], desc['labeltooltip'],
640                                tail, head, edgetip, desc['type_id'].linestyle)
641         output += '}'
642         with open(filename, 'w', encoding="utf8") as f:
643             f.write(output)