```python
import starsim as ss


class MyCustomSIR(ss.SIR):
    def __init__(self, **kwargs):
        super().__init__()
        # Add custom parameters
        self.define_pars(my_param=0.5)
        self.update_pars(**kwargs)

    def set_prognoses(self, uids, sources=None):
        # Custom progression logic
        super().set_prognoses(uids, sources)
        # Additional custom logic here
```
Diseases
Disease class architecture
Starsim has a two-tier disease class hierarchy:
ss.Disease
- Base class for all diseases
- Defines step methods and basic disease structure
- Does not include transmission logic
- Used for non-communicable diseases (NCDs)
- Key methods: define_states(), set_prognoses()
ss.Infection
- Inherits from ss.Disease
- Includes transmission logic via the infect() method
- Used for all communicable diseases
- Handles network-based transmission automatically
- Applies network-specific betas and agent susceptibility/transmissibility
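The two-tier split can be illustrated without Starsim. In the pure-NumPy sketch below, ToyDisease and ToyInfection are hypothetical stand-ins, not Starsim classes. They only mimic the idea that the base class owns states and prognoses while the subclass adds a transmission step. Transmission here is deterministic for clarity.

```python
import numpy as np


class ToyDisease:
    """Hypothetical base class: owns states and prognoses, but no transmission."""

    def __init__(self, n_agents):
        self.n_agents = n_agents
        self.define_states()

    def define_states(self):
        # One boolean array per state, indexed by agent ID
        self.susceptible = np.ones(self.n_agents, dtype=bool)
        self.affected = np.zeros(self.n_agents, dtype=bool)

    def set_prognoses(self, uids):
        # Move the given agents from susceptible to affected
        self.susceptible[uids] = False
        self.affected[uids] = True


class ToyInfection(ToyDisease):
    """Hypothetical subclass: inherits states/prognoses and adds contact-based transmission."""

    def infect(self, sources, targets):
        # Edge i links sources[i] -> targets[i]; every eligible edge transmits (deterministic toy)
        eligible = self.affected[sources] & self.susceptible[targets]
        new_uids = np.unique(targets[eligible])
        self.set_prognoses(new_uids)
        return new_uids


# Seed agent 0, then transmit along the edges 0 -> 2 and 1 -> 3
model = ToyInfection(n_agents=4)
model.set_prognoses(np.array([0]))
print(model.infect(sources=np.array([0, 1]), targets=np.array([2, 3])))  # [2]
```

In this toy, a non-communicable disease would subclass ToyDisease directly and call set_prognoses() from something other than contacts, such as an age- or risk-based rule.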
Important: Almost all diseases should inherit from ss.Infection. Do not write your own infect() method unless you have very specific requirements; the built-in method correctly handles:
- Looping over agents in each network
- Applying network- and disease-specific transmission probabilities
- Managing agent transmissibility and susceptibility
- Mixing pool logic
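To make "transmission probabilities" and "transmissibility and susceptibility" concrete, the sketch below computes a per-contact probability as beta × rel_trans(source) × rel_sus(target) for a single network and then draws random outcomes. The function toy_infect, its arguments, and this exact formula are assumptions for illustration, not Starsim's infect(). Details such as time-step handling, edge direction, and mixing pools are deliberately ignored.

```python
import numpy as np


def toy_infect(p1, p2, beta, infectious, susceptible, rel_trans, rel_sus, rng):
    """Hypothetical one-network transmission step (not Starsim's infect()).

    Each edge (p1[i], p2[i]) is treated as one-directional, p1 -> p2.
    Returns the sorted IDs of newly infected agents.
    """
    # Only edges from an infectious source to a susceptible target can transmit
    active = infectious[p1] & susceptible[p2]
    src, trg = p1[active], p2[active]
    # Per-edge probability: network beta * source transmissibility * target susceptibility
    p_edge = beta * rel_trans[src] * rel_sus[trg]
    # One Bernoulli draw per active edge
    hits = rng.random(len(p_edge)) < p_edge
    # A target reached by several edges is only counted once
    return np.unique(trg[hits])


# Agent 0 is infectious; agent 3 has zero relative susceptibility, so only agent 2 can be infected
rng = np.random.default_rng(1)
infectious = np.array([True, False, False, False])
susceptible = ~infectious
new = toy_infect(
    p1=np.array([0, 0, 1]), p2=np.array([2, 3, 2]), beta=1.0,
    infectious=infectious, susceptible=susceptible,
    rel_trans=np.ones(4), rel_sus=np.array([1.0, 1.0, 1.0, 0.0]), rng=rng,
)
print(new)  # [2]
```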
Key implementation methods
| Method | Purpose | When to override |
|---|---|---|
| define_states() | Initialize disease states (S, I, R, etc.) | Always for custom diseases |
| set_prognoses() | Set outcomes for newly infected people | (Almost) always for custom diseases |
| step_state() | Update states each timestep | When adding new state transitions |
| step_die() | Handle deaths | When the disease has custom states |
| infect() | Handle transmission | Rarely; use the built-in version |
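The step_die() row is easiest to see with a parent/child pair. In this plain-NumPy sketch, ToySIR and ToySEIR are hypothetical classes, not Starsim's. The parent only knows how to clear its own three states. A child that adds an exposed state therefore overrides step_die(), calls the parent, and clears the extra state itself.

```python
import numpy as np


class ToySIR:
    """Hypothetical parent model with three boolean states."""

    def __init__(self, n_agents):
        self.susceptible = np.ones(n_agents, dtype=bool)
        self.infected = np.zeros(n_agents, dtype=bool)
        self.recovered = np.zeros(n_agents, dtype=bool)

    def step_die(self, uids):
        # The parent only clears the states it defines
        self.susceptible[uids] = False
        self.infected[uids] = False
        self.recovered[uids] = False


class ToySEIR(ToySIR):
    """Hypothetical child model that adds a custom 'exposed' state."""

    def __init__(self, n_agents):
        super().__init__(n_agents)
        self.exposed = np.zeros(n_agents, dtype=bool)

    def step_die(self, uids):
        super().step_die(uids)      # clear the inherited states
        self.exposed[uids] = False  # and the custom one the parent doesn't know about


model = ToySEIR(3)
model.susceptible[1] = False
model.exposed[1] = True
model.step_die(np.array([1]))
print(model.exposed.tolist())  # [False, False, False]
```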
Implementation patterns
Pattern 1: Extending existing diseases
When you need to modify an existing disease model, inherit from it and override specific methods:
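The shape of this pattern (call the parent's method, then adjust its result) mirrors the MyCustomSIR example at the start of this page. The standalone sketch below uses hypothetical classes ToyBaseSIR and ToyCustomSIR, not Starsim code. The custom parameter my_param shortens the infectious period the parent scheduled.

```python
import numpy as np


class ToyBaseSIR:
    """Hypothetical parent model (not Starsim's ss.SIR)."""

    def __init__(self, n_agents, dur_inf=10.0):
        self.pars = {'dur_inf': dur_inf}
        self.infected = np.zeros(n_agents, dtype=bool)
        self.ti_recovered = np.full(n_agents, np.nan)
        self.ti = 0  # current time step index

    def set_prognoses(self, uids):
        # Default behaviour: infect, and schedule recovery dur_inf steps from now
        self.infected[uids] = True
        self.ti_recovered[uids] = self.ti + self.pars['dur_inf']


class ToyCustomSIR(ToyBaseSIR):
    def __init__(self, n_agents, my_param=0.5, **kwargs):
        super().__init__(n_agents, **kwargs)
        self.pars['my_param'] = my_param  # extra, hypothetical parameter

    def set_prognoses(self, uids):
        super().set_prognoses(uids)  # keep the parent's logic
        # Custom tweak: scale the infectious period by my_param
        self.ti_recovered[uids] = self.ti + self.pars['my_param'] * self.pars['dur_inf']


model = ToyCustomSIR(n_agents=3, my_param=0.5, dur_inf=10.0)
model.set_prognoses(np.array([0, 2]))
print(model.ti_recovered.tolist())  # [5.0, nan, 5.0]
```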
Pattern 2: Adding new states
To add states to an existing disease:
```python
import starsim as ss


class MySEIR(ss.SIR):
    def __init__(self, **kwargs):
        super().__init__()
        # Add new parameters
        self.define_pars(dur_exp=ss.lognorm_ex(0.5))
        self.update_pars(**kwargs)

        # Add new states
        self.define_states(
            ss.BoolState('exposed', label='Exposed'),
            ss.FloatArr('ti_exposed', label='Time of exposure'),
        )

    @property
    def infectious(self):
        # Define who can transmit (both infected and exposed)
        return self.infected | self.exposed

    def step_state(self):
        # Call parent state updates first
        super().step_state()

        # Add custom state transitions: exposed agents whose infection time has arrived
        transitioning = self.exposed & (self.ti_infected <= self.ti)
        self.exposed[transitioning] = False
        self.infected[transitioning] = True
```
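The mask logic in step_state() can be checked in isolation with plain NumPy arrays. This sketch assumes, as the MySEIR code does, that ti_infected holds the time step at which an exposed agent becomes infected. Filling in that value would be the job of set_prognoses(), which is not shown here. The function name toy_step_state is hypothetical.

```python
import numpy as np


def toy_step_state(exposed, infected, ti_infected, ti):
    """Move exposed agents whose scheduled infection time has arrived into 'infected'.

    Arrays are modified in place; the boolean mask of transitioning agents is returned.
    """
    # Comparisons with NaN are False, so agents with no scheduled time never transition
    transitioning = exposed & (ti_infected <= ti)
    exposed[transitioning] = False
    infected[transitioning] = True
    return transitioning


exposed = np.array([True, True, False])
infected = np.zeros(3, dtype=bool)
ti_infected = np.array([3.0, 7.0, np.nan])
toy_step_state(exposed, infected, ti_infected, ti=5)
print(exposed.tolist(), infected.tolist())  # [False, True, False] [True, False, False]
```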