Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import re 

2 

3from sqlalchemy import __version__ 

4from sqlalchemy import inspect 

5from sqlalchemy import schema 

6from sqlalchemy import sql 

7from sqlalchemy import types as sqltypes 

8from sqlalchemy.ext.compiler import compiles 

9from sqlalchemy.schema import CheckConstraint 

10from sqlalchemy.schema import Column 

11from sqlalchemy.schema import ForeignKeyConstraint 

12from sqlalchemy.sql.elements import quoted_name 

13from sqlalchemy.sql.expression import _BindParamClause 

14from sqlalchemy.sql.expression import _TextClause as TextClause 

15from sqlalchemy.sql.visitors import traverse 

16 

17from . import compat 

18 

19 

def _safe_int(value):
    """Convert ``value`` to an int, returning it unchanged if that fails.

    Used to normalize version tokens: numeric parts such as ``"4"`` become
    ints, while pre-release markers such as ``"b1"`` are kept as strings.

    :param value: the object to convert.
    :returns: ``int(value)`` when conversion succeeds, otherwise ``value``.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        # Only conversion failures are expected here; a bare ``except``
        # would also swallow KeyboardInterrupt / SystemExit.
        return value

25 

26 

# The installed SQLAlchemy version as a tuple. Numeric components become
# ints, while pre-release markers such as "b1" are kept as strings
# (see _safe_int).
_vers = tuple(
    _safe_int(part) for part in re.findall(r"(\d+|[abc]\d)", __version__)
)

# Feature flags: True when the installed SQLAlchemy is at least the
# version encoded in the flag's name.
sqla_110 = _vers >= (1, 1, 0)
sqla_1115 = _vers >= (1, 1, 15)
sqla_120 = _vers >= (1, 2, 0)
sqla_1216 = _vers >= (1, 2, 16)
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)

# Detect whether this SQLAlchemy version provides the Computed construct.
try:
    from sqlalchemy import Computed  # noqa
except ImportError:
    has_computed = False
    has_computed_reflection = False
else:
    has_computed = True
    # reflection support is additionally gated on SQLAlchemy 1.3.16+
    has_computed_reflection = _vers >= (1, 3, 16)

45 

# The string "auto" used as the autoincrement default.
# NOTE(review): presumably mirrors SQLAlchemy's default for
# Column(autoincrement=...) so callers can detect an unchanged setting;
# confirm against the code that compares with it.
AUTOINCREMENT_DEFAULT = "auto"

47 

48 

def _connectable_has_table(connectable, tablename, schemaname):
    """Check for ``tablename`` in ``schemaname`` via ``has_table()``.

    SQLAlchemy 1.4+ goes through the Inspector built from ``connectable``;
    older versions call the dialect's ``has_table()`` directly, passing the
    connectable itself as the first argument.
    """
    if not sqla_14:
        dialect = connectable.dialect
        return dialect.has_table(connectable, tablename, schemaname)
    inspector = inspect(connectable)
    return inspector.has_table(tablename, schemaname)

56 

57 

def _exec_on_inspector(inspector, statement, **params):
    """Execute ``statement`` with ``params`` on the inspector's connection.

    The keyword arguments are collected into a dict and passed as the
    parameter argument of ``execute()``.  Pre-1.4 versions execute on
    ``inspector.bind``; 1.4+ obtains a connection from the inspector's
    private ``_operation_context()`` context manager.
    """
    if not sqla_14:
        return inspector.bind.execute(statement, params)
    with inspector._operation_context() as connection:
        return connection.execute(statement, params)

64 

65 

def _server_default_is_computed(column):
    """Return True if ``column.computed`` is a ``Computed`` construct.

    Always False on SQLAlchemy versions without ``Computed``; in that case
    ``column.computed`` is never accessed.
    """
    return has_computed and isinstance(column.computed, Computed)

71 

72 

def _table_for_constraint(constraint):
    """Return the Table that owns ``constraint``.

    Foreign key constraints expose their table as ``.parent``; all other
    constraints use ``.table``.
    """
    is_foreign_key = isinstance(constraint, ForeignKeyConstraint)
    return constraint.parent if is_foreign_key else constraint.table

78 

79 

def _columns_for_constraint(constraint):
    """Return the local Column objects that ``constraint`` involves.

    * ForeignKeyConstraint: the parent column of each element, as a list.
    * CheckConstraint: every column found in its SQL text, as a set.
    * anything else: ``constraint.columns`` copied into a list.
    """
    if isinstance(constraint, ForeignKeyConstraint):
        return [element.parent for element in constraint.elements]
    if isinstance(constraint, CheckConstraint):
        return _find_columns(constraint.sqltext)
    return list(constraint.columns)

87 

88 

def _fk_spec(constraint):
    """Describe a foreign key constraint as a flat 10-tuple.

    The tuple is ``(source_schema, source_table, source_columns,
    target_schema, target_table, target_columns, onupdate, ondelete,
    deferrable, initially)``.  The target schema and table are taken from
    the column referenced by the constraint's first element.
    """
    source_columns = [
        constraint.columns[column_key].name
        for column_key in constraint.column_keys
    ]
    parent_table = constraint.parent
    referenced_table = constraint.elements[0].column.table
    target_columns = [fk.column.name for fk in constraint.elements]
    return (
        parent_table.schema,
        parent_table.name,
        source_columns,
        referenced_table.schema,
        referenced_table.name,
        target_columns,
        constraint.onupdate,
        constraint.ondelete,
        constraint.deferrable,
        constraint.initially,
    )

115 

116 

def _fk_is_self_referential(constraint):
    """Return True if the constraint's first element targets its own table.

    The first element's column spec string has its trailing column-name
    token removed and the remainder is compared with ``constraint.parent.key``.
    """
    colspec = constraint.elements[0]._get_colspec()
    # everything before the last "." (empty string when there is no dot)
    table_key = colspec.rpartition(".")[0]
    return table_key == constraint.parent.key

123 

124 

def _is_type_bound(constraint):
    """Return the constraint's ``_type_bound`` flag.

    Relates to SQLAlchemy #3260: CHECK constraints generated by a type are
    flagged so they are not copied (the type will generate them again).
    """
    generated_by_type = constraint._type_bound
    return generated_by_type

130 

131 

def _find_columns(clause):
    """Return the set of Column objects located within ``clause``."""
    found_columns = set()
    collect = found_columns.add
    # the visitor map routes every "column" node of the expression to collect
    traverse(clause, {}, {"column": collect})
    return found_columns

138 

139 

def _remove_column_from_collection(collection, column):
    """Remove the member of ``collection`` stored under ``column.key``.

    Workaround for older SQLAlchemy: the exact object held by the
    collection is looked up by key and removed, rather than ``column``.
    """
    present = collection[column.key]
    collection.remove(present)

147 

148 

def _textual_index_column(table, text_):
    """Produce an index element for a string or ``text()`` expression.

    A plain string becomes a NULLTYPE Column appended to ``table``.  A
    TextClause is wrapped in ``_textual_index_element``.  Anything else
    raises ValueError.
    """
    if isinstance(text_, compat.string_types):
        new_column = Column(text_, sqltypes.NULLTYPE)
        table.append_column(new_column)
        return new_column
    if isinstance(text_, TextClause):
        return _textual_index_element(table, text_)
    raise ValueError("String or text() construct expected")

159 

160 

class _textual_index_element(sql.ColumnElement):
    """Wrap around a sqlalchemy text() construct in such a way that
    we appear like a column-oriented SQL expression to an Index
    construct.

    The issue here is that currently the Postgresql dialect, the biggest
    recipient of functional indexes, keys all the index expressions to
    the corresponding column expressions when rendering CREATE INDEX,
    so the Index we create here needs to have a .columns collection that
    is the same length as the .expressions collection.  Ultimately
    SQLAlchemy should support text() expressions in indexes.

    See SQLAlchemy issue 3174.

    """

    __visit_name__ = "_textual_idx_element"

    def __init__(self, table, text):
        # The wrapped text() construct is what gets rendered (see the
        # @compiles hook registered for this class).
        self.table = table
        self.text = text
        self.key = text.text
        # Placeholder column named after the SQL text.  It is appended to the
        # table so the Index's .columns line up with its .expressions.
        self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
        table.append_column(self.fake_column)

    def get_children(self, **kw):
        """Return the placeholder column as this element's only child.

        Keyword arguments are accepted and ignored.  The base
        ``ClauseElement.get_children()`` accepts keyword options, and
        SQLAlchemy's visitor iteration may forward traversal options (for
        example from ``traverse(clause, opts, ...)``).  A no-argument
        override would raise TypeError in that case.
        """
        return [self.fake_column]

189 

@compiles(_textual_index_element)
def _render_textual_index_column(element, compiler, **kw):
    """Compile a ``_textual_index_element`` as its wrapped text() construct."""
    wrapped_text = element.text
    return compiler.process(wrapped_text, **kw)

193 

194 

class _literal_bindparam(_BindParamClause):
    """Bind parameter marker type with no behavior of its own.

    Its compiled form comes from the ``@compiles`` hook registered for it
    in this module, which delegates to
    ``compiler.render_literal_bindparam``.
    """

197 

198 

@compiles(_literal_bindparam)
def _render_literal_bindparam(element, compiler, **kw):
    """Compile a ``_literal_bindparam`` via ``render_literal_bindparam()``."""
    render = compiler.render_literal_bindparam
    return render(element, **kw)

202 

203 

def _get_index_expressions(idx):
    """Return a new list containing the index's ``expressions``."""
    expressions = list(idx.expressions)
    return expressions

206 

207 

def _get_index_column_names(idx):
    """Return the ``name`` of each index expression, or None if it has none.

    Expressions without a ``name`` attribute (e.g. functional expressions)
    contribute None, so the result lines up with the expressions.
    """
    expressions = _get_index_expressions(idx)
    return [getattr(expr, "name", None) for expr in expressions]

210 

211 

def _column_kwargs(col):
    """Return ``col.kwargs`` on SQLAlchemy 1.3+, otherwise an empty dict."""
    return col.kwargs if sqla_13 else {}

217 

218 

def _get_constraint_final_name(constraint, dialect):
    """Return the unquoted name ``constraint`` would compile to on ``dialect``.

    Returns None when the constraint has no name.

    On SQLAlchemy 1.4+ the dialect's identifier preparer is asked for the
    final name directly.  On older versions, a new instance of the
    constraint's class carrying only a ``quote=False`` copy of the name is
    built.  That instance is then formatted so the dialect's quoting is
    bypassed.  Indexes go through the DDL compiler's
    ``_prepared_index_name()``; other constraints use ``format_constraint()``.
    """
    if constraint.name is None:
        return None
    elif sqla_14:
        # for SQLAlchemy 1.4 we would like to have the option to expand
        # the use of "deferred" names for constraints as well as to have
        # some flexibility with "None" name and similar; make use of new
        # SQLAlchemy API to return what would be the final compiled form of
        # the name for this dialect.
        # NOTE(review): ``_alembic_quote`` is presumably a private keyword
        # SQLAlchemy 1.4 accepts on Alembic's behalf; confirm.
        return dialect.identifier_preparer.format_constraint(
            constraint, _alembic_quote=False
        )
    else:

        # prior to SQLAlchemy 1.4, work around quoting logic to get at the
        # final compiled name without quotes.
        if hasattr(constraint.name, "quote"):
            # might be quoted_name, might be truncated_name, keep it the
            # same
            quoted_name_cls = type(constraint.name)
        else:
            quoted_name_cls = quoted_name

        new_name = quoted_name_cls(str(constraint.name), quote=False)
        # Rebinds the local name to a fresh object of the same class that
        # carries only the unquoted name; the caller's constraint is left
        # untouched.
        # NOTE(review): assumes the class can be constructed from ``name``
        # alone; confirm for classes whose constructor requires positional
        # arguments (e.g. ForeignKeyConstraint).
        constraint = constraint.__class__(name=new_name)

        if isinstance(constraint, schema.Index):
            # name should not be quoted.
            return dialect.ddl_compiler(dialect, None)._prepared_index_name(
                constraint
            )
        else:
            # name should not be quoted.
            return dialect.identifier_preparer.format_constraint(constraint)

253 

254 

def _constraint_is_named(constraint, dialect):
    """Return True if ``constraint`` has a name.

    On SQLAlchemy 1.4+ the name must also survive formatting by the
    dialect's identifier preparer, i.e. ``format_constraint()`` must not
    return None.
    """
    if constraint.name is None:
        return False
    if not sqla_14:
        return True
    formatted = dialect.identifier_preparer.format_constraint(
        constraint, _alembic_quote=False
    )
    return formatted is not None

265 

266 

def _dialect_supports_comments(dialect):
    """Return ``dialect.supports_comments`` on SQLAlchemy 1.2+, else False."""
    if not sqla_120:
        return False
    return dialect.supports_comments

272 

273 

def _comment_attribute(obj):
    """Return ``obj.comment`` for a Table or Column on SQLAlchemy 1.2+.

    Returns None on older versions, which lack the attribute.
    """
    return obj.comment if sqla_120 else None

281 

282 

def _is_mariadb(mysql_dialect):
    """Report whether ``"MariaDB"`` is among the dialect's version components.

    When ``server_version_info`` is empty or None, that falsy value itself
    is returned; otherwise a bool is returned.
    """
    version_info = mysql_dialect.server_version_info
    if not version_info:
        return version_info
    return "MariaDB" in version_info

288 

289 

def _mariadb_normalized_version_info(mysql_dialect):
    """Return the dialect's version info with a 3-item prefix stripped if long.

    Tuples longer than five items lose their first three components; shorter
    ones are returned as-is.
    """
    version_info = mysql_dialect.server_version_info
    return version_info[3:] if len(version_info) > 5 else version_info

295 

296 

if sqla_14:
    from sqlalchemy import create_mock_engine
else:
    from sqlalchemy import create_engine

    def create_mock_engine(url, executor):
        """Stand-in for SQLAlchemy 1.4's ``create_mock_engine()``.

        Uses the legacy ``strategy="mock"`` engine strategy with
        ``executor`` as the callback.

        :param url: database URL (string or URL object); selects the dialect.
        :param executor: callable handed to the mock strategy.
        """
        # Honor the caller's URL so the mock engine uses the requested
        # dialect.  Previously "postgresql://" was hard-coded and ``url`` was
        # silently ignored, diverging from the 1.4 API this emulates.
        return create_engine(url, strategy="mock", executor=executor)