Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2Mediation analysis 

3 

4Implements algorithm 1 ('parametric inference') and algorithm 2 

5('nonparametric inference') from: 

6 

7Imai, Keele, Tingley (2010). A general approach to causal mediation 

8analysis. Psychological Methods 15:4, 309-334. 

9 

10http://imai.princeton.edu/research/files/BaronKenny.pdf 

11 

12The algorithms are described on page 317 of the paper. 

13 

In the case of linear models with no interactions involving the
mediator, the results should be similar or identical to the earlier
Baron-Kenny approach.

17""" 

18import numpy as np 

19import pandas as pd 

20from statsmodels.graphics.utils import maybe_name_or_idx 

21 

22 

class Mediation(object):
    """
    Conduct a mediation analysis.

    Parameters
    ----------
    outcome_model : statsmodels model
        Regression model for the outcome. Predictor variables include
        the treatment/exposure, the mediator, and any other variables
        of interest.
    mediator_model : statsmodels model
        Regression model for the mediator variable. Predictor
        variables include the treatment/exposure and any other
        variables of interest.
    exposure : str or (int, int) tuple
        The name or column position of the treatment/exposure
        variable. If positions are given, the first integer is the
        column position of the exposure variable in the outcome model
        and the second integer is the position of the exposure variable
        in the mediator model. If a string is given, it must be the name
        of the exposure variable in both regression models.
    mediator : {str, int}
        The name or column position of the mediator variable in the
        outcome regression model. If None, infer the name from the
        mediator model formula (if present).
    moderators : dict
        Map from variable names or index positions to values of
        moderator variables that are held fixed when calculating
        mediation effects. If the keys are index positions they must
        be tuples `(i, j)` where `i` is the index in the outcome model
        and `j` is the index in the mediator model. Otherwise the
        keys must be variable names.
    outcome_fit_kwargs : dict-like
        Keyword arguments to use when fitting the outcome model.
    mediator_fit_kwargs : dict-like
        Keyword arguments to use when fitting the mediator model.

    Notes
    -----
    The mediator model class must implement ``get_distribution``.

    Call ``fit`` to run the analysis; it returns a ``MediationResults``
    object.

    Examples
    --------
    A basic mediation analysis using formulas:

    >>> import statsmodels.api as sm
    >>> import statsmodels.genmod.families.links as links
    >>> probit = links.probit
    >>> outcome_model = sm.GLM.from_formula("cong_mesg ~ emo + treat + age + educ + gender + income",
    ...                                     data, family=sm.families.Binomial(link=probit()))
    >>> mediator_model = sm.OLS.from_formula("emo ~ treat + age + educ + gender + income", data)
    >>> med = Mediation(outcome_model, mediator_model, "treat", "emo").fit()
    >>> med.summary()

    A basic mediation analysis without formulas. This may be slightly
    faster than the approach using formulas. If there are any
    interactions involving the treatment or mediator variables this
    approach will not work, you must use formulas.

    >>> import patsy
    >>> outcome = np.asarray(data["cong_mesg"])
    >>> outcome_exog = patsy.dmatrix("emo + treat + age + educ + gender + income", data,
    ...                              return_type='dataframe')
    >>> probit = sm.families.links.probit
    >>> outcome_model = sm.GLM(outcome, outcome_exog, family=sm.families.Binomial(link=probit()))
    >>> mediator = np.asarray(data["emo"])
    >>> mediator_exog = patsy.dmatrix("treat + age + educ + gender + income", data,
    ...                               return_type='dataframe')
    >>> mediator_model = sm.OLS(mediator, mediator_exog)
    >>> tx_pos = [outcome_exog.columns.tolist().index("treat"),
    ...           mediator_exog.columns.tolist().index("treat")]
    >>> med_pos = outcome_exog.columns.tolist().index("emo")
    >>> med = Mediation(outcome_model, mediator_model, tx_pos, med_pos).fit()
    >>> med.summary()

    A moderated mediation analysis. The mediation effect is computed
    for people of age 20.

    >>> fml = "cong_mesg ~ emo + treat*age + emo*age + educ + gender + income"
    >>> outcome_model = sm.GLM.from_formula(fml, data,
    ...                                     family=sm.families.Binomial())
    >>> mediator_model = sm.OLS.from_formula("emo ~ treat*age + educ + gender + income", data)
    >>> moderators = {"age" : 20}
    >>> med = Mediation(outcome_model, mediator_model, "treat", "emo",
    ...                 moderators=moderators).fit()

    References
    ----------
    Imai, Keele, Tingley (2010). A general approach to causal mediation
    analysis. Psychological Methods 15:4, 309-334.
    http://imai.princeton.edu/research/files/BaronKenny.pdf

    Tingley, Yamamoto, Hirose, Keele, Imai (2014). mediation : R
    package for causal mediation analysis. Journal of Statistical
    Software 59:5. http://www.jstatsoft.org/v59/i05/paper
    """

    def __init__(self, outcome_model, mediator_model, exposure, mediator=None,
                 moderators=None, outcome_fit_kwargs=None, mediator_fit_kwargs=None):

        self.outcome_model = outcome_model
        self.mediator_model = mediator_model
        self.exposure = exposure
        self.moderators = moderators if moderators is not None else {}

        if mediator is None:
            # Use the left-hand side of the mediator model's formula.
            self.mediator = self._guess_endog_name(mediator_model, 'mediator')
        else:
            self.mediator = mediator

        self._outcome_fit_kwargs = (outcome_fit_kwargs if outcome_fit_kwargs
                                    is not None else {})
        self._mediator_fit_kwargs = (mediator_fit_kwargs if mediator_fit_kwargs
                                     is not None else {})

        # The non-formula code paths overwrite columns of these matrices
        # in place, so work on copies rather than the models' own exog.
        self._outcome_exog = outcome_model.exog.copy()
        self._mediator_exog = mediator_model.exog.copy()

        # Position of the exposure variable in the mediator model.
        self._exp_pos_mediator = self._variable_pos('exposure', 'mediator')

        # Position of the exposure variable in the outcome model.
        self._exp_pos_outcome = self._variable_pos('exposure', 'outcome')

        # Position of the mediator variable in the outcome model.
        self._med_pos_outcome = self._variable_pos('mediator', 'outcome')

    def _variable_pos(self, var, model):
        """
        Return the exog column position of a variable in one of the models.

        Parameters
        ----------
        var : {'exposure', 'mediator'}
            Which variable to locate.
        model : {'mediator', 'outcome'}
            Which model to search; any value other than 'mediator'
            selects the outcome model.
        """
        if model == 'mediator':
            mod = self.mediator_model
        else:
            mod = self.outcome_model

        if var == 'mediator':
            return maybe_name_or_idx(self.mediator, mod)[1]

        exp = self.exposure
        # A length-2 non-string exposure is an explicit
        # (outcome position, mediator position) pair.
        exp_is_2 = ((len(exp) == 2) and not isinstance(exp, str))

        if exp_is_2:
            if model == 'outcome':
                return exp[0]
            elif model == 'mediator':
                return exp[1]
        else:
            return maybe_name_or_idx(exp, mod)[1]

    def _guess_endog_name(self, model, typ):
        """
        Return the left-hand side of ``model.formula``.

        Raises
        ------
        ValueError
            If the model has no ``formula`` attribute.
        """
        if hasattr(model, 'formula'):
            return model.formula.split("~")[0].strip()
        else:
            raise ValueError('cannot infer %s name without formula' % typ)

    def _simulate_params(self, result):
        """
        Simulate model parameters from fitted sampling distribution.

        Draws one vector from a multivariate normal centred at
        ``result.params`` with covariance ``result.cov_params()``.
        """
        mn = result.params
        cov = result.cov_params()
        return np.random.multivariate_normal(mn, cov)

    def _get_mediator_exog(self, exposure):
        """
        Return the mediator exog matrix with exposure set to the given
        value. Set values of moderated variables as needed.

        For non-formula models the cached copy of the exog is modified
        in place and returned; for formula models the design matrix is
        rebuilt from a modified copy of the model's data frame.
        """
        mediator_exog = self._mediator_exog
        if not hasattr(self.mediator_model, 'formula'):
            mediator_exog[:, self._exp_pos_mediator] = exposure
            for ix in self.moderators:
                v = self.moderators[ix]
                # Moderator keys are (outcome_pos, mediator_pos) tuples here.
                mediator_exog[:, ix[1]] = v
        else:
            # Need to regenerate the model exog so that interaction
            # terms reflect the new exposure/moderator values.
            df = self.mediator_model.data.frame.copy()
            df.loc[:, self.exposure] = exposure
            for vname in self.moderators:
                v = self.moderators[vname]
                df.loc[:, vname] = v
            klass = self.mediator_model.__class__
            init_kwargs = self.mediator_model._get_init_kwds()
            model = klass.from_formula(data=df, **init_kwargs)
            mediator_exog = model.exog

        return mediator_exog

    def _get_outcome_exog(self, exposure, mediator):
        """
        Return the exog design matrix with mediator and exposure set to
        the given values. Set values of moderated variables as
        needed.

        For non-formula models the cached copy of the exog is modified
        in place and returned; for formula models the design matrix is
        rebuilt from a modified copy of the model's data frame.
        """
        outcome_exog = self._outcome_exog
        if not hasattr(self.outcome_model, 'formula'):
            outcome_exog[:, self._med_pos_outcome] = mediator
            outcome_exog[:, self._exp_pos_outcome] = exposure
            for ix in self.moderators:
                v = self.moderators[ix]
                # Moderator keys are (outcome_pos, mediator_pos) tuples here.
                outcome_exog[:, ix[0]] = v
        else:
            # Need to regenerate the model exog so that interaction
            # terms reflect the new exposure/mediator/moderator values.
            df = self.outcome_model.data.frame.copy()
            df.loc[:, self.exposure] = exposure
            df.loc[:, self.mediator] = mediator
            for vname in self.moderators:
                v = self.moderators[vname]
                df.loc[:, vname] = v
            klass = self.outcome_model.__class__
            init_kwargs = self.outcome_model._get_init_kwds()
            model = klass.from_formula(data=df, **init_kwargs)
            outcome_exog = model.exog

        return outcome_exog

    def _fit_model(self, model, fit_kwargs, boot=False, boot_idx=None):
        """
        Fit a new instance of ``model``'s class, optionally to a resample.

        Parameters
        ----------
        model : statsmodels model
            Template model; its class, init keywords, endog and exog are
            reused.
        fit_kwargs : dict
            Keyword arguments passed to ``fit``.
        boot : bool
            If True, fit to a bootstrap resample of the rows.
        boot_idx : array_like of int, optional
            Row indices defining the resample when ``boot`` is True. If
            None, indices are drawn uniformly with replacement.

        Returns
        -------
        The fitted results object.
        """
        klass = model.__class__
        init_kwargs = model._get_init_kwds()
        endog = model.endog
        exog = model.exog
        if boot:
            if boot_idx is None:
                boot_idx = np.random.randint(0, len(endog), len(endog))
            endog = endog[boot_idx]
            exog = exog[boot_idx, :]
        new_model = klass(endog, exog, **init_kwargs)
        return new_model.fit(**fit_kwargs)

    def fit(self, method="parametric", n_rep=1000):
        """
        Simulate the mediation effects.

        Parameters
        ----------
        method : str
            Either 'parametric' or 'bootstrap'. Any string beginning
            with 'para' or 'boot' respectively is accepted.
        n_rep : int
            The number of simulation replications.

        Returns
        -------
        MediationResults

        Raises
        ------
        ValueError
            If ``method`` is neither parametric nor bootstrap.
        """
        # Decide the algorithm once so the initial fit and the
        # per-replication branch can never disagree.
        if method.startswith("para"):
            parametric = True
        elif method.startswith("boot"):
            parametric = False
        else:
            raise ValueError("method must be either 'parametric' or 'bootstrap'")

        if parametric:
            # Initial fit to unperturbed data.
            outcome_result = self._fit_model(self.outcome_model,
                                             self._outcome_fit_kwargs)
            mediator_result = self._fit_model(self.mediator_model,
                                              self._mediator_fit_kwargs)
        else:
            # Algorithm 2 refits both models to the same bootstrap sample.
            # Rows can only be paired when both models have the same
            # number of observations; otherwise resample each separately.
            nobs = len(self.outcome_model.endog)
            paired = len(self.mediator_model.endog) == nobs

        indirect_effects = [[], []]
        direct_effects = [[], []]

        for _ in range(n_rep):

            if parametric:
                # Realization of outcome model parameters from sampling distribution
                outcome_params = self._simulate_params(outcome_result)

                # Realization of mediation model parameters from sampling distribution
                mediation_params = self._simulate_params(mediator_result)
            else:
                boot_idx = np.random.randint(0, nobs, nobs) if paired else None
                outcome_result = self._fit_model(self.outcome_model,
                                                 self._outcome_fit_kwargs,
                                                 boot=True, boot_idx=boot_idx)
                outcome_params = outcome_result.params
                mediator_result = self._fit_model(self.mediator_model,
                                                  self._mediator_fit_kwargs,
                                                  boot=True, boot_idx=boot_idx)
                mediation_params = mediator_result.params

            # predicted_outcomes[tm][te] is the outcome when the
            # mediator is drawn with exposure set to tm and the
            # outcome model's exposure is set to te.
            predicted_outcomes = [[None, None], [None, None]]
            for tm in 0, 1:
                mex = self._get_mediator_exog(tm)
                gen = self.mediator_model.get_distribution(mediation_params,
                                                           mediator_result.scale,
                                                           exog=mex)
                potential_mediator = gen.rvs(mex.shape[0])

                for te in 0, 1:
                    oex = self._get_outcome_exog(te, potential_mediator)
                    po = self.outcome_model.predict(outcome_params, oex)
                    predicted_outcomes[tm][te] = po

            for t in 0, 1:
                indirect_effects[t].append(predicted_outcomes[1][t] - predicted_outcomes[0][t])
                direct_effects[t].append(predicted_outcomes[t][1] - predicted_outcomes[t][0])

        # Each list holds n_rep per-observation vectors; transpose so
        # rows are observations and columns are replications.
        for t in 0, 1:
            indirect_effects[t] = np.asarray(indirect_effects[t]).T
            direct_effects[t] = np.asarray(direct_effects[t]).T

        self.indirect_effects = indirect_effects
        self.direct_effects = direct_effects

        rslt = MediationResults(self.indirect_effects, self.direct_effects)
        rslt.method = method
        return rslt

329 

330 

def _pvalue(vec):
    """
    Two-sided simulation p-value for a vector of simulated effects.

    Twice the smaller of the counts of strictly positive and strictly
    negative entries, divided by the number of entries.
    """
    n_pos = sum(vec > 0)
    n_neg = sum(vec < 0)
    return 2 * min(n_pos, n_neg) / float(len(vec))

333 

334 

class MediationResults(object):
    """
    A class for holding the results of a mediation analysis.

    The following terms are used in the summary output:

    ACME : average causal mediated effect
    ADE : average direct effect

    Parameters
    ----------
    indirect_effects : sequence of two arrays
        Element ``t`` holds simulated indirect (mediated) effects for
        exposure level ``t``. Each array is averaged along axis 0,
        leaving one value per simulation replication.
    direct_effects : sequence of two arrays
        Element ``t`` holds simulated direct effects for level ``t``,
        laid out like ``indirect_effects``.
    """

    def __init__(self, indirect_effects, direct_effects):

        self.indirect_effects = indirect_effects
        self.direct_effects = direct_effects

        # Average over axis 0, keeping one value per replication.
        self.ACME_ctrl, self.ACME_tx = (indirect_effects[level].mean(0)
                                        for level in (0, 1))
        self.ADE_ctrl, self.ADE_tx = (direct_effects[level].mean(0)
                                      for level in (0, 1))

        self.total_effect = (self.ACME_ctrl + self.ACME_tx +
                             self.ADE_ctrl + self.ADE_tx) / 2

        self.prop_med_ctrl = self.ACME_ctrl / self.total_effect
        self.prop_med_tx = self.ACME_tx / self.total_effect
        self.prop_med_avg = (self.prop_med_ctrl + self.prop_med_tx) / 2

        self.ACME_avg = (self.ACME_ctrl + self.ACME_tx) / 2
        self.ADE_avg = (self.ADE_ctrl + self.ADE_tx) / 2

    def summary(self, alpha=0.05):
        """
        Provide a summary of a mediation analysis.

        Parameters
        ----------
        alpha : float
            The interval bounds are the ``100 * alpha / 2`` and
            ``100 * (1 - alpha / 2)`` percentiles of the simulated values.

        Returns
        -------
        pandas.DataFrame
            One row per effect with its estimate (the median for the
            proportions mediated, the mean otherwise), lower and upper
            interval bounds, and a simulation p-value.
        """
        entries = [
            ("ACME (control)", self.ACME_ctrl),
            ("ACME (treated)", self.ACME_tx),
            ("ADE (control)", self.ADE_ctrl),
            ("ADE (treated)", self.ADE_tx),
            ("Total effect", self.total_effect),
            ("Prop. mediated (control)", self.prop_med_ctrl),
            ("Prop. mediated (treated)", self.prop_med_tx),
            ("ACME (average)", self.ACME_avg),
            ("ADE (average)", self.ADE_avg),
            ("Prop. mediated (average)", self.prop_med_avg),
        ]
        proportions = (self.prop_med_ctrl, self.prop_med_tx, self.prop_med_avg)
        lower_pct = 100 * alpha / 2
        upper_pct = 100 * (1 - alpha / 2)

        table = pd.DataFrame(
            columns=["Estimate", "Lower CI bound", "Upper CI bound", "P-value"],
            index=[label for label, _ in entries])

        for row, (_, values) in enumerate(entries):
            # Proportions mediated are summarized by their median; every
            # other effect by its mean.
            if any(values is prop for prop in proportions):
                center = np.median(values)
            else:
                center = values.mean()
            table.iloc[row, 0] = center
            table.iloc[row, 1] = np.percentile(values, lower_pct)
            table.iloc[row, 2] = np.percentile(values, upper_pct)
            table.iloc[row, 3] = _pvalue(values)

        return table.apply(pd.to_numeric, errors='coerce')