Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of Patsy 

2# Copyright (C) 2012-2013 Nathaniel Smith <njs@pobox.com> 

3# See file LICENSE.txt for license information. 

4 

5# R-compatible spline basis functions 

6 

7# These are made available in the patsy.* namespace 

8__all__ = ["bs"] 

9 

10import numpy as np 

11 

12from patsy.util import have_pandas, no_pickling, assert_no_pickling 

13from patsy.state import stateful_transform 

14 

15if have_pandas: 

16 import pandas 

17 

18def _eval_bspline_basis(x, knots, degree): 

19 try: 

20 from scipy.interpolate import splev 

21 except ImportError: # pragma: no cover 

22 raise ImportError("spline functionality requires scipy") 

23 # 'knots' are assumed to be already pre-processed. E.g. usually you 

24 # want to include duplicate copies of boundary knots; you should do 

25 # that *before* calling this constructor. 

26 knots = np.atleast_1d(np.asarray(knots, dtype=float)) 

27 assert knots.ndim == 1 

28 knots.sort() 

29 degree = int(degree) 

30 x = np.atleast_1d(x) 

31 if x.ndim == 2 and x.shape[1] == 1: 

32 x = x[:, 0] 

33 assert x.ndim == 1 

34 # XX FIXME: when points fall outside of the boundaries, splev and R seem 

35 # to handle them differently. I don't know why yet. So until we understand 

36 # this and decide what to do with it, I'm going to play it safe and 

37 # disallow such points. 

38 if np.min(x) < np.min(knots) or np.max(x) > np.max(knots): 

39 raise NotImplementedError("some data points fall outside the " 

40 "outermost knots, and I'm not sure how " 

41 "to handle them. (Patches accepted!)") 

42 # Thanks to Charles Harris for explaining splev. It's not well 

43 # documented, but basically it computes an arbitrary b-spline basis 

44 # given knots and degree on some specificed points (or derivatives 

45 # thereof, but we don't use that functionality), and then returns some 

46 # linear combination of these basis functions. To get out the basis 

47 # functions themselves, we use linear combinations like [1, 0, 0], [0, 

48 # 1, 0], [0, 0, 1]. 

49 # NB: This probably makes it rather inefficient (though I haven't checked 

50 # to be sure -- maybe the fortran code actually skips computing the basis 

51 # function for coefficients that are zero). 

52 # Note: the order of a spline is the same as its degree + 1. 

53 # Note: there are (len(knots) - order) basis functions. 

54 n_bases = len(knots) - (degree + 1) 

55 basis = np.empty((x.shape[0], n_bases), dtype=float) 

56 for i in range(n_bases): 

57 coefs = np.zeros((n_bases,)) 

58 coefs[i] = 1 

59 basis[:, i] = splev(x, (knots, coefs, degree)) 

60 return basis 

61 

62def _R_compat_quantile(x, probs): 

63 #return np.percentile(x, 100 * np.asarray(probs)) 

64 probs = np.asarray(probs) 

65 quantiles = np.asarray([np.percentile(x, 100 * prob) 

66 for prob in probs.ravel(order="C")]) 

67 return quantiles.reshape(probs.shape, order="C") 

68 

def test__R_compat_quantile():
    """Check _R_compat_quantile on scalar and list-valued probabilities."""
    # Each case is (data, probabilities, expected quantiles).
    cases = [
        ([10, 20], 0.5, 15),
        ([10, 20], 0.3, 13),
        ([10, 20], [0.3, 0.7], [13, 17]),
        (list(range(10)), [0.3, 0.7], [2.7, 6.3]),
    ]
    for data, prob, expected in cases:
        assert np.allclose(_R_compat_quantile(data, prob), expected)

76 

class BS(object):
    """bs(x, df=None, knots=None, degree=3, include_intercept=False, lower_bound=None, upper_bound=None)

    Generates a B-spline basis for ``x``, allowing non-linear fits. The usual
    usage is something like::

      y ~ 1 + bs(x, 4)

    to fit ``y`` as a smooth function of ``x``, with 4 degrees of freedom
    given to the smooth.

    :arg df: The number of degrees of freedom to use for this spline. The
      return value will have this many columns. You must specify at least one
      of ``df`` and ``knots``.
    :arg knots: The interior knots to use for the spline. If unspecified, then
      equally spaced quantiles of the input data are used. You must specify at
      least one of ``df`` and ``knots``.
    :arg degree: The degree of the spline to use.
    :arg include_intercept: If ``True``, then the resulting
      spline basis will span the intercept term (i.e., the constant
      function). If ``False`` (the default) then this will not be the case,
      which is useful for avoiding overspecification in models that include
      multiple spline terms and/or an intercept term.
    :arg lower_bound: The lower exterior knot location.
    :arg upper_bound: The upper exterior knot location.

    A spline with ``degree=0`` is piecewise constant with breakpoints at each
    knot, and the default knot positions are quantiles of the input. So if you
    find yourself in the situation of wanting to quantize a continuous
    variable into ``num_bins`` equal-sized bins with a constant effect across
    each bin, you can use ``bs(x, num_bins - 1, degree=0)``. (The ``- 1`` is
    because one degree of freedom will be taken by the intercept;
    alternatively, you could leave the intercept term out of your model and
    use ``bs(x, num_bins, degree=0, include_intercept=True)``.)

    A spline with ``degree=1`` is piecewise linear with breakpoints at each
    knot.

    The default is ``degree=3``, which gives a cubic b-spline.

    This is a stateful transform (for details see
    :ref:`stateful-transforms`). If ``knots``, ``lower_bound``, or
    ``upper_bound`` are not specified, they will be calculated from the data
    and then the chosen values will be remembered and re-used for prediction
    from the fitted model.

    Using this function requires scipy be installed.

    .. note:: This function is very similar to the R function of the same
       name. In cases where both return output at all (e.g., R's ``bs`` will
       raise an error if ``degree=0``, while patsy's will not), they should
       produce identical output given identical input and parameter settings.

    .. warning:: I'm not sure on what the proper handling of points outside
       the lower/upper bounds is, so for now attempting to evaluate a spline
       basis at such points produces an error. Patches gratefully accepted.

    .. versionadded:: 0.2.0
    """
    def __init__(self):
        # Scratch space used while memorizing; removed by memorize_finish().
        self._tmp = {}
        # Set by memorize_finish(): the spline degree and the full, sorted
        # knot vector (boundary knots repeated, plus the interior knots).
        self._degree = None
        self._all_knots = None

    def memorize_chunk(self, x, df=None, knots=None, degree=3,
                       include_intercept=False,
                       lower_bound=None, upper_bound=None):
        """Record the call arguments and stash one chunk of ``x``.

        :raises ValueError: if ``x`` is neither 1-d nor a 2-d column vector.
        """
        args = {"df": df,
                "knots": knots,
                "degree": degree,
                "include_intercept": include_intercept,
                "lower_bound": lower_bound,
                "upper_bound": upper_bound,
                }
        self._tmp["args"] = args
        # XX: check whether we need x values before saving them
        x = np.atleast_1d(x)
        if x.ndim == 2 and x.shape[1] == 1:
            x = x[:, 0]
        if x.ndim > 1:
            raise ValueError("input to 'bs' must be 1-d, "
                             "or a 2-d column vector")
        # There's no better way to compute exact quantiles than memorizing
        # all data.
        self._tmp.setdefault("xs", []).append(x)

    def memorize_finish(self):
        """Validate the recorded arguments and compute the knot vector.

        Stores the results in ``self._degree`` and ``self._all_knots`` and
        discards the memorized data.

        :raises ValueError: if ``degree`` is negative or not an integer, if
          neither ``df`` nor ``knots`` was given, if ``df`` is too small or
          disagrees with the number of ``knots``, if the bounds are
          inverted, or if the knots are not 1-d or fall outside the bounds.
        """
        tmp = self._tmp
        args = tmp["args"]
        del self._tmp

        if args["degree"] < 0:
            # A degree of 0 is accepted, so the limit is inclusive.
            raise ValueError("degree must be greater than or equal to 0 "
                             "(not %r)"
                             % (args["degree"],))
        if int(args["degree"]) != args["degree"]:
            # Report the offending argument; self._degree has not been
            # assigned yet at this point and is still None.
            raise ValueError("degree must be an integer (not %r)"
                             % (args["degree"],))

        # These are guaranteed to all be 1d vectors by the code above
        x = np.concatenate(tmp["xs"])
        if args["df"] is None and args["knots"] is None:
            raise ValueError("must specify either df or knots")
        order = args["degree"] + 1
        if args["df"] is not None:
            n_inner_knots = args["df"] - order
            if not args["include_intercept"]:
                n_inner_knots += 1
            if n_inner_knots < 0:
                raise ValueError("df=%r is too small for degree=%r and "
                                 "include_intercept=%r; must be >= %s"
                                 % (args["df"], args["degree"],
                                    args["include_intercept"],
                                    # We know that n_inner_knots is negative;
                                    # if df were that much larger, it would
                                    # have been zero, and things would work.
                                    args["df"] - n_inner_knots))
            if args["knots"] is not None:
                if len(args["knots"]) != n_inner_knots:
                    raise ValueError("df=%s with degree=%r implies %s knots, "
                                     "but %s knots were provided"
                                     % (args["df"], args["degree"],
                                        n_inner_knots, len(args["knots"])))
            else:
                # Need to compute inner knots
                knot_quantiles = np.linspace(0, 1, n_inner_knots + 2)[1:-1]
                inner_knots = _R_compat_quantile(x, knot_quantiles)
        if args["knots"] is not None:
            inner_knots = args["knots"]
        if args["lower_bound"] is not None:
            lower_bound = args["lower_bound"]
        else:
            lower_bound = np.min(x)
        if args["upper_bound"] is not None:
            upper_bound = args["upper_bound"]
        else:
            upper_bound = np.max(x)
        if lower_bound > upper_bound:
            raise ValueError("lower_bound > upper_bound (%r > %r)"
                             % (lower_bound, upper_bound))
        inner_knots = np.asarray(inner_knots)
        if inner_knots.ndim > 1:
            raise ValueError("knots must be 1 dimensional")
        if np.any(inner_knots < lower_bound):
            raise ValueError("some knot values (%s) fall below lower bound "
                             "(%r)"
                             % (inner_knots[inner_knots < lower_bound],
                                lower_bound))
        if np.any(inner_knots > upper_bound):
            raise ValueError("some knot values (%s) fall above upper bound "
                             "(%r)"
                             % (inner_knots[inner_knots > upper_bound],
                                upper_bound))
        # Each boundary knot appears `order` times, followed by the interior
        # knots; the combined vector is then sorted.
        all_knots = np.concatenate(([lower_bound, upper_bound] * order,
                                    inner_knots))
        all_knots.sort()

        self._degree = args["degree"]
        self._all_knots = all_knots

    def transform(self, x, df=None, knots=None, degree=3,
                  include_intercept=False,
                  lower_bound=None, upper_bound=None):
        """Evaluate the memorized spline basis at ``x``.

        Of the keyword arguments, only ``include_intercept`` is consulted
        here; the knots and degree come from the state saved by
        ``memorize_finish``. When pandas is available and ``x`` is a pandas
        Series or DataFrame, the result is a DataFrame sharing ``x``'s index.
        """
        basis = _eval_bspline_basis(x, self._all_knots, self._degree)
        if not include_intercept:
            # Drop the first basis column.
            basis = basis[:, 1:]
        if have_pandas:
            if isinstance(x, (pandas.Series, pandas.DataFrame)):
                basis = pandas.DataFrame(basis)
                basis.index = x.index
        return basis

    __getstate__ = no_pickling

bs = stateful_transform(BS)

251 

def test_bs_compat():
    """Run BS through ``check_stateful`` on every recorded R test case.

    ``R_bs_test_data`` is text made of blocks delimited by
    ``--BEGIN TEST CASE--`` / ``--END TEST CASE--`` lines, each block holding
    ``key=value`` lines. Every block is translated into keyword arguments
    for BS and compared against its recorded ``output``.
    """
    from patsy.test_state import check_stateful
    from patsy.test_splines_bs_data import (R_bs_test_x,
                                            R_bs_test_data,
                                            R_bs_num_tests)
    begin_marker = "--BEGIN TEST CASE--"
    end_marker = "--END TEST CASE--"
    lines = R_bs_test_data.split("\n")
    n_cases_run = 0
    cursor = lines.index(begin_marker)
    while lines[cursor] == begin_marker:
        first = cursor + 1
        last = lines.index(end_marker, first)
        fields = dict(line.split("=", 1) for line in lines[first:last])
        # Translate the R output into Python calling conventions.
        # NOTE: eval() is applied to the project's own test fixture text; it
        # must never be pointed at untrusted data.
        kwargs = {}
        kwargs["degree"] = int(fields["degree"])
        # integer, or None
        kwargs["df"] = eval(fields["df"])
        # np.array() call, or None
        kwargs["knots"] = eval(fields["knots"])
        boundary = fields["Boundary.knots"]
        if boundary != "None":
            kwargs["lower_bound"], kwargs["upper_bound"] = eval(boundary)
        kwargs["include_intercept"] = fields["intercept"] == "TRUE"
        expected = np.asarray(eval(fields["output"]))
        # df is passed through unadjusted, so when it is given it must equal
        # the number of columns in the recorded output.
        if kwargs["df"] is not None:
            assert expected.shape[1] == kwargs["df"]
        # Do the actual test
        check_stateful(BS, False, R_bs_test_x, expected, **kwargs)
        n_cases_run += 1
        # Move on to the line following this block's end marker
        cursor = last + 1
    assert n_cases_run == R_bs_num_tests

test_bs_compat.slow = 1

298 

299# This isn't checked by the above, because R doesn't have zero degree 

300# b-splines. 

def test_bs_0degree():
    """Check that ``degree=0`` yields piecewise-constant indicator columns.

    R doesn't have zero degree b-splines, so this case needs its own test.
    """
    x = np.logspace(-1, 1, 10)
    result = bs(x, knots=[1, 4], degree=0, include_intercept=True)
    assert result.shape[1] == 3
    # One indicator column per interval between consecutive knots.
    interval_masks = [x < 1, (x >= 1) & (x < 4), x >= 4]
    for column, mask in enumerate(interval_masks):
        expected = np.zeros(10)
        expected[mask] = 1
        assert np.array_equal(result[:, column], expected)
    # Check handling of points that exactly fall on knots. They arbitrarily
    # get included into the larger region, not the smaller. This is consistent
    # with Python's half-open interval convention -- each basis function is
    # constant on [knot[i], knot[i + 1]).
    on_knot = bs([0, 1, 2], degree=0, knots=[1], include_intercept=True)
    assert np.array_equal(on_knot,
                          [[1, 0],
                           [0, 1],
                           [0, 1]])

    # Leaving out the intercept drops exactly the first column.
    with_intercept = bs(x, knots=[1, 4], degree=0, include_intercept=True)
    without_intercept = bs(x, knots=[1, 4], degree=0, include_intercept=False)
    assert np.array_equal(with_intercept[:, 1:], without_intercept)

327 

def test_bs_errors():
    """Exercise the argument validation performed by ``bs``."""
    from nose.tools import assert_raises
    x = np.linspace(-10, 10, 20)

    def rejects(exc_type, *args, **kwargs):
        # Assert that bs(*args, **kwargs) raises exc_type.
        assert_raises(exc_type, bs, *args, **kwargs)

    # out of bounds
    rejects(NotImplementedError, x, 3, lower_bound=0)
    rejects(NotImplementedError, x, 3, upper_bound=0)
    # must specify df or knots
    rejects(ValueError, x)
    # df/knots match/mismatch (with and without intercept). Each entry is
    # (include_intercept, number of knots consistent with df=10, extra kwargs).
    consistent = [
        (False, 7, {}),
        (True, 6, {}),
        (False, 9, {"degree": 1}),
        (True, 8, {"degree": 1}),
    ]
    # match:
    for intercept, n_knots, extra in consistent:
        bs(x, df=10, include_intercept=intercept, knots=[0] * n_knots,
           **extra)
    # too many knots (+1), then too few knots (-1):
    for delta in (1, -1):
        for intercept, n_knots, extra in consistent:
            rejects(ValueError, x, df=10, include_intercept=intercept,
                    knots=[0] * (n_knots + delta), **extra)
    # df too small
    rejects(ValueError, x, df=1, degree=3)
    rejects(ValueError, x, df=3, degree=5)
    # bad degree
    rejects(ValueError, x, df=10, degree=-1)
    rejects(ValueError, x, df=10, degree=1.5)
    # upper_bound < lower_bound
    rejects(ValueError, x, 3, lower_bound=1, upper_bound=-1)
    # multidimensional input
    rejects(ValueError, np.column_stack((x, x)), 3)
    # unsorted knots are okay, and get sorted
    assert np.array_equal(bs(x, knots=[1, 4]), bs(x, knots=[4, 1]))
    # 2d knots
    rejects(ValueError, x, knots=[[0], [20]])
    # knots > upper_bound
    rejects(ValueError, x, knots=[0, 20])
    rejects(ValueError, x, knots=[0, 4], upper_bound=3)
    # knots < lower_bound
    rejects(ValueError, x, knots=[-20, 0])
    rejects(ValueError, x, knots=[-4, 0], lower_bound=-3)

396 

397 

398 

# Differences between bs and ns (since the R code is a pile of copy-paste):
# - degree is always 3
# - different number of interior knots given df (because fewer dof are used
#   at the edges, I guess)
# - boundary knots are always repeated exactly 4 times (same as bs with
#   degree=3)
# - complications at the end to handle boundary conditions
# The 'rcs' function uses slightly different conventions -- in particular it
# picks boundary knots that are not quite at the edges of the data, which
# makes sense for a natural spline.

407# makes sense for a natural spline.