Diseases

Disease class architecture

Starsim has a two-tier disease class hierarchy:

ss.Disease

  • Base class for all diseases
  • Defines step methods and basic disease structure
  • Does not include transmission logic
  • Used for non-communicable diseases (NCDs)
  • Key methods: define_states(), set_prognoses()

ss.Infection

  • Inherits from ss.Disease
  • Includes transmission logic via the infect() method
  • Used for all communicable diseases
  • Handles network-based transmission automatically
  • Applies network-specific betas and agent susceptibility/transmissibility
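As a rough illustration of how the two tiers divide the work, here is a plain-Python sketch. `ToyDisease` and `ToyInfection` are hypothetical stand-ins, not Starsim classes. They only mirror the split described above: the base class owns states and prognoses, and the subclass adds a transmission step.

```python
import random


class ToyDisease:
    """Hypothetical stand-in for a base disease class: states and prognoses only."""

    def __init__(self, n_agents):
        self.n = n_agents
        self.define_states()

    def define_states(self):
        # One flag per agent; subclasses could add more states here
        self.affected = [False] * self.n

    def set_prognoses(self, uids):
        # Mark the given agents as affected
        for uid in uids:
            self.affected[uid] = True

    def step_state(self):
        # No transitions in this minimal base class
        pass


class ToyInfection(ToyDisease):
    """Hypothetical stand-in for an infection class: adds a transmission step."""

    def __init__(self, n_agents, beta, seed=0):
        super().__init__(n_agents)
        self.beta = beta
        self.rng = random.Random(seed)

    def infect(self, contacts):
        # contacts: (source, target) pairs from a single network
        new = set()
        for src, tgt in contacts:
            if self.affected[src] and not self.affected[tgt] and self.rng.random() < self.beta:
                new.add(tgt)
        # Apply prognoses after the loop so nobody infected this step transmits this step
        self.set_prognoses(sorted(new))
        return sorted(new)


flu = ToyInfection(n_agents=4, beta=1.0)
flu.set_prognoses([0])
print(flu.infect([(0, 1), (2, 3)]))  # [1]
```

Agent 2 was never infected, so only the contact with agent 0 can transmit. Deferring `set_prognoses()` until after the loop is one simple way to stop chains of infection from forming within a single timestep.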

Important: Almost all diseases should inherit from ss.Infection. Do not write your own infect() method unless you have very specific requirements; the built-in method already handles:

  • Looping over agents in each network
  • Applying network- and disease-specific transmission probabilities
  • Managing agent transmissibility and susceptibility
  • Mixing pool logic
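To make the list above concrete, the sketch below combines a per-network beta with the source's transmissibility and the target's susceptibility, checks each contact in both directions, and merges exposures from several contacts. The function `infection_probabilities` and its arguments are hypothetical. It uses the common formula P(infection) = 1 − ∏(1 − p_edge) for independent exposures; Starsim's own infect() may differ in its details (for example, in how it draws random numbers or handles mixing pools).

```python
def infection_probabilities(networks, betas, infected, rel_trans, rel_sus):
    """
    Hypothetical sketch: probability that each agent is infected this step.

    networks:  dict of network name -> list of (p1, p2) contact pairs (undirected)
    betas:     dict of network name -> per-contact transmission probability
    infected:  list of bools, True if the agent is currently infectious
    rel_trans: per-agent relative transmissibility multipliers
    rel_sus:   per-agent relative susceptibility multipliers
    """
    n = len(infected)
    # Probability of escaping every exposure; starts at 1 for all agents
    p_escape = [1.0] * n
    for name, edges in networks.items():
        beta = betas.get(name, 0.0)
        for p1, p2 in edges:
            # Check both directions so either partner can be the source
            for src, tgt in ((p1, p2), (p2, p1)):
                if infected[src] and not infected[tgt]:
                    p = min(1.0, beta * rel_trans[src] * rel_sus[tgt])
                    p_escape[tgt] *= 1.0 - p
    # Independent exposures: P(infected) = 1 - prod(1 - p_edge)
    return [1.0 - e for e in p_escape]


probs = infection_probabilities(
    networks={'household': [(0, 1)], 'school': [(0, 1), (2, 1)]},
    betas={'household': 0.5, 'school': 0.2},
    infected=[True, False, False],
    rel_trans=[1.0, 1.0, 1.0],
    rel_sus=[1.0, 1.0, 1.0],
)
# Agent 1 is exposed to agent 0 in two networks: 1 - (1 - 0.5) * (1 - 0.2) = 0.6
```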

Key implementation methods

Method            Purpose                                     When to override
define_states()   Initialize disease states (S, I, R, etc.)   Always for custom diseases
set_prognoses()   Set outcomes for newly infected people      (Almost) always for custom diseases
step_state()      Update states each timestep                 When adding new state transitions
step_die()        Handle deaths                               When the disease has custom states
infect()          Handle transmission                         Rarely; use the built-in version
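The roles in this table can be illustrated with a plain-Python SIR model. `ToySIR` is hypothetical and is not the Starsim implementation. It borrows the method names from the table, stores one list entry per agent, and assumes a fixed infectious duration measured in timesteps.

```python
class ToySIR:
    """Hypothetical plain-Python SIR illustrating the hook methods; not Starsim code."""

    def __init__(self, n_agents, dur_inf=3):
        self.n = n_agents
        self.dur_inf = dur_inf  # fixed infectious duration in timesteps (assumption)
        self.ti = 0             # current timestep index
        self.define_states()

    def define_states(self):
        # One entry per agent for each state
        self.susceptible = [True] * self.n
        self.infected = [False] * self.n
        self.recovered = [False] * self.n
        self.ti_recovered = [None] * self.n  # scheduled recovery time

    def set_prognoses(self, uids):
        # Newly infected agents: change state and schedule their outcome
        for uid in uids:
            self.susceptible[uid] = False
            self.infected[uid] = True
            self.ti_recovered[uid] = self.ti + self.dur_inf

    def step_state(self):
        # Apply transitions that are due at the current timestep
        for uid in range(self.n):
            if self.infected[uid] and self.ti_recovered[uid] <= self.ti:
                self.infected[uid] = False
                self.recovered[uid] = True

    def step_die(self, uids):
        # Clear disease states for agents who died this step
        for uid in uids:
            self.susceptible[uid] = False
            self.infected[uid] = False
            self.recovered[uid] = False


sir = ToySIR(n_agents=3, dur_inf=2)
sir.set_prognoses([0])
sir.ti = 2
sir.step_state()
print(sir.recovered)  # [True, False, False]
```

The outcome is decided in `set_prognoses()` by scheduling `ti_recovered`, and `step_state()` only applies transitions that are already due. If you add a new state, `step_die()` also has to clear it, which is why the table recommends overriding it when a disease has custom states.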

Implementation patterns

Pattern 1: Extending existing diseases

When you need to modify an existing disease model, inherit from it and override specific methods:

import starsim as ss

class MyCustomSIR(ss.SIR):
    def __init__(self, **kwargs):
        super().__init__()
        # Add custom parameters
        self.define_pars(my_param=0.5)
        self.update_pars(**kwargs)
        
    def set_prognoses(self, uids, sources=None):
        # Run the standard SIR prognoses first
        super().set_prognoses(uids, sources)
        # Then add custom logic here, e.g. using self.pars.my_param
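The ordering in the constructor above (parent defaults first, then the new parameter, then the user's keyword arguments) can be sketched in plain Python. `BaseModel`, `CustomModel`, and their `define_pars`/`update_pars` methods are hypothetical toys, not Starsim's implementation. In this toy, `update_pars` rejects unknown keys, so keyword arguments can only be applied after `my_param` has been defined.

```python
class BaseModel:
    """Hypothetical stand-in for a built-in disease with default parameters."""

    def __init__(self, **kwargs):
        self.pars = {}
        self.define_pars(beta=0.1, dur_inf=6)
        self.update_pars(**kwargs)

    def define_pars(self, **defaults):
        # Register default values for parameters
        self.pars.update(defaults)

    def update_pars(self, **kwargs):
        # User values override defaults; unknown keys are rejected in this toy
        for key, value in kwargs.items():
            if key not in self.pars:
                raise KeyError(f'Unknown parameter: {key}')
            self.pars[key] = value

    def set_prognoses(self, uids):
        self.log = [f'infected {uid}' for uid in uids]


class CustomModel(BaseModel):
    def __init__(self, **kwargs):
        super().__init__()              # parent defaults only, no kwargs yet
        self.define_pars(my_param=0.5)  # add the new parameter
        self.update_pars(**kwargs)      # now every key, including my_param, is known

    def set_prognoses(self, uids):
        super().set_prognoses(uids)     # keep the parent behavior
        self.log.append(f'my_param={self.pars["my_param"]}')


m = CustomModel(my_param=0.8, beta=0.2)
m.set_prognoses([3])
print(m.pars)  # {'beta': 0.2, 'dur_inf': 6, 'my_param': 0.8}
print(m.log)   # ['infected 3', 'my_param=0.8']
```

Passing `**kwargs` straight to `super().__init__()` would fail in this toy, because `my_param` does not exist yet when the parent applies the overrides.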

Pattern 2: Adding new states

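To add states to an existing disease, declare them with define_states() in __init__() and add the corresponding transitions in step_state(). This example extends ss.SIR with an exposed state: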

class MySEIR(ss.SIR):
    def __init__(self, **kwargs):
        super().__init__()
        # Add new parameters
        self.define_pars(dur_exp=ss.lognorm_ex(0.5))
        self.update_pars(**kwargs)
        
        # Add new states
        self.define_states(
            ss.BoolState('exposed', label='Exposed'),
            ss.FloatArr('ti_exposed', label='Time of exposure'),
        )

    @property
    def infectious(self):
        # Define who can transmit (both infected and exposed)
        return self.infected | self.exposed

    def step_state(self):
        # Call parent state updates first
        super().step_state()
        
        # Progress exposed -> infected once the scheduled time is reached
        # (ti_infected must be scheduled when an agent becomes exposed)
        transitioning = self.exposed & (self.ti_infected <= self.ti)
        self.exposed[transitioning] = False
        self.infected[transitioning] = True
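The transition in step_state() only fires for agents whose `exposed` flag is set and whose `ti_infected` has been scheduled, so newly infected agents must be marked as exposed when their prognoses are set. The plain-Python sketch below shows that bookkeeping under a fixed latent period in timesteps. `ToySEIR` and its fields are hypothetical and are not Starsim code.

```python
class ToySEIR:
    """Hypothetical sketch of exposed -> infected bookkeeping; not Starsim code."""

    def __init__(self, n_agents, dur_exp=2):
        self.ti = 0
        self.dur_exp = dur_exp  # latent period in timesteps (assumption)
        self.exposed = [False] * n_agents
        self.infected = [False] * n_agents
        self.ti_infected = [float('inf')] * n_agents  # inf means "not scheduled"

    def set_prognoses(self, uids):
        # New infections start as exposed and become infected after dur_exp steps
        for uid in uids:
            self.exposed[uid] = True
            self.ti_infected[uid] = self.ti + self.dur_exp

    @property
    def infectious(self):
        # As in the example above: both exposed and infected agents can transmit
        return [e or i for e, i in zip(self.exposed, self.infected)]

    def step_state(self):
        # Same rule as above: exposed & (ti_infected <= ti) -> infected
        for uid, due in enumerate(self.ti_infected):
            if self.exposed[uid] and due <= self.ti:
                self.exposed[uid] = False
                self.infected[uid] = True


seir = ToySEIR(n_agents=2, dur_exp=2)
seir.set_prognoses([0])
for ti in range(3):
    seir.ti = ti
    seir.step_state()
    print(ti, seir.exposed, seir.infected)
# 0 [True, False] [False, False]
# 1 [True, False] [False, False]
# 2 [False, False] [True, False]
```

Using `inf` for unscheduled times keeps the comparison `due <= ti` false for agents who were never exposed, so the transition rule needs no extra check.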