注意:本文不會將全部完整源碼貼出,只是將具體的思路以及部分源碼貼出,須要感興趣的讀者本身實驗而後實現吆。 html
公司最近的項目須要將以前的部分業務的數據庫鏈接方式改成jdbc,但因爲以前的項目都使用sqlarchemy做爲orm框架,該框架彷佛沒有支持jdbc,爲了能作最小的修改並知足需求,因此須要修改sqlarchemy的源碼。java
基本配置介紹python
sqlalchemy 版本:1.1.15mysql
使用jaydebeapi模塊調用jdbc鏈接mysqlgit
前提:github
1 學會使用jaydebeapi模塊,使用方法具體能夠參考:sql
https://pypi.python.org/pypi/JayDeBeApi數據庫
介紹的比較詳細的能夠參考:http://shuaizki.github.io/language_related/2013/06/22/introduction-to-jpype.htmlapi
jaydebeapi是一個基於jpype的在Cpython中能夠經過jdbc鏈接數據庫的模塊。該模塊的python代碼不多,基本上能夠分爲鏈接部分、遊標部分、結果轉換部分這三個。通常來講咱們可能須要修改的就是結果轉換部分,好比說sqlalchemy查詢時若是某條記錄中含TIME字段,那麼該字段通常要表現爲timedelta對象。而在jaydebeapi中則返回的是字符串對象,這樣在sqlalchemy中會報錯的。session
sqlarchemy爲咱們實現了ORM對象與語句的轉換,鏈接池,session(包括對線程的支持scope_session)等較爲上層的邏輯,但這些東西在這裏咱們不須要考慮(固然建立一個鏈接,生成curcor仍是要考慮的),咱們要考慮的僅僅是當sqlarchemy把sql語句以及參數傳過來的時候咱們該怎麼作,以及當sql語句執行後如何對結果進行轉換。
1 sql語句以及參數傳過來的時候咱們該怎麼作:
1.1 對參數進行轉義,防止sql注入
2 執行完sql語句後對結果如何處理:
2.1 咱們知道python的基礎sql模塊會對結果進行處理,好比說把NUll轉換爲None,把數據庫中的date字段轉換爲python的date對象等等
2.2 一些不知道該怎麼形容的數據:
當咱們查詢時,獲取的數據對應字段的元信息
當咱們update或者delete等操做時須要獲取影響了多少行
當咱們插入數據後,若是主鍵是自增字段,咱們通常(能夠說在sqlarchemy中這是必須)須要獲取該記錄的主鍵值
3 sqlalchemy增長代碼,使其支持咱們修改後的jaydebeapi
1.1解決方案:
人家pymysql咋搞,我就咋搞!
在pymysql.corsors文件中Cursor類中有一個叫作mogrify的方法,這個方法不只對參數轉義,並且會將參數放置到sql語句中組成完整的可執行sql語句。因此偷一些代碼而後稍加修改就是這樣:
#!/usr/bin/env python # -*- coding: utf-8 -*- from functools import partial from pymysql.converters import escape_item, escape_string import sys PY2 = sys.version_info[0] == 2 if PY2: import __builtin__ range_type = xrange text_type = unicode long_type = long str_type = basestring unichr = __builtin__.unichr else: range_type = range text_type = str long_type = int str_type = str unichr = chr def _ensure_bytes(x, encoding="utf8"): if isinstance(x, text_type): x = x.encode(encoding) return x def _escape_args(args, encoding): ensure_bytes = partial(_ensure_bytes, encoding=encoding) if isinstance(args, (tuple, list)): if PY2: args = tuple(map(ensure_bytes, args)) return tuple(escape(arg, encoding) for arg in args) elif isinstance(args, dict): if PY2: args = dict((ensure_bytes(key), ensure_bytes(val)) for (key, val) in args.items()) return dict((key, escape(val, encoding)) for (key, val) in args.items()) def escape(obj, charset, mapping=None): if isinstance(obj, str_type): return "'" + escape_string(obj) + "'" return escape_item(obj, charset, mapping=mapping) def mogrify(query, encoding, args=None): if PY2: # Use bytes on Python 2 always query = _ensure_bytes(query, encoding=encoding) if args is not None: # r = _escape_args(args, encoding) query = query % _escape_args(args, encoding) return query # 調用一下mogrigy函數 # print(mogrify("select * from ll where a in %s and b = %s", "utf8", [[2, 1], 3]))
2.1解決方案:
人家pymysql咋搞,我就咋搞!
在pymysql.converters中有一個名爲decoders的字典,這裏面存放了mysql字段與python對象的轉換關係!大概是這樣
def _convert_second_fraction(s): if not s: return 0 # Pad zeros to ensure the fraction length in microseconds s = s.ljust(6, '0') return int(s[:6]) DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") def convert_datetime(obj): """Returns a DATETIME or TIMESTAMP column value as a datetime object: >>> datetime_or_None('2007-02-25 23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) >>> datetime_or_None('2007-02-25T23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) Illegal values are returned as None: >>> datetime_or_None('2007-02-31T23:06:20') is None True >>> datetime_or_None('0000-00-00 00:00:00') is None True """ if not PY2 and isinstance(obj, (bytes, bytearray)): obj = obj.decode('ascii') m = DATETIME_RE.match(obj) if not m: return convert_date(obj) try: groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) return datetime.datetime(*[ int(x) for x in groups ]) except ValueError: return convert_date(obj) TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") def convert_timedelta(obj): """Returns a TIME column as a timedelta object: >>> timedelta_or_None('25:06:17') datetime.timedelta(1, 3977) >>> timedelta_or_None('-25:06:17') datetime.timedelta(-2, 83177) Illegal values are returned as None: >>> timedelta_or_None('random crap') is None True Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not be parsed correctly by this function. """ if not PY2 and isinstance(obj, (bytes, bytearray)): obj = obj.decode('ascii') m = TIMEDELTA_RE.match(obj) if not m: return None try: groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) negate = -1 if groups[0] else 1 hours, minutes, seconds, microseconds = groups[1:] tdelta = datetime.timedelta( hours = int(hours), minutes = int(minutes), seconds = int(seconds), microseconds = int(microseconds) ) * negate return tdelta except ValueError: return None TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?") def convert_time(obj): """Returns a TIME column as a time object: >>> time_or_None('15:06:17') datetime.time(15, 6, 17) Illegal values are returned as None: >>> time_or_None('-25:06:17') is None True >>> time_or_None('random crap') is None True Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not be parsed correctly by this function. Also note that MySQL's TIME column corresponds more closely to Python's timedelta and not time. However if you want TIME columns to be treated as time-of-day and not a time offset, then you can use set this function as the converter for FIELD_TYPE.TIME. """ if not PY2 and isinstance(obj, (bytes, bytearray)): obj = obj.decode('ascii') m = TIME_RE.match(obj) if not m: return None try: groups = list(m.groups()) groups[-1] = _convert_second_fraction(groups[-1]) hours, minutes, seconds, microseconds = groups return datetime.time(hour=int(hours), minute=int(minutes), second=int(seconds), microsecond=int(microseconds)) except ValueError: return None def convert_date(obj): """Returns a DATE column as a date object: >>> date_or_None('2007-02-26') datetime.date(2007, 2, 26) Illegal values are returned as None: >>> date_or_None('2007-02-31') is None True >>> date_or_None('0000-00-00') is None True """ if not PY2 and isinstance(obj, (bytes, bytearray)): obj = obj.decode('ascii') try: return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) except ValueError: return None def convert_mysql_timestamp(timestamp): """Convert a MySQL TIMESTAMP to a Timestamp object. MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: >>> mysql_timestamp_converter('2007-02-25 22:32:17') datetime.datetime(2007, 2, 25, 22, 32, 17) MySQL < 4.1 uses a big string of numbers: >>> mysql_timestamp_converter('20070225223217') datetime.datetime(2007, 2, 25, 22, 32, 17) Illegal values are returned as None: >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None True >>> mysql_timestamp_converter('00000000000000') is None True """ if not PY2 and isinstance(timestamp, (bytes, bytearray)): timestamp = timestamp.decode('ascii') if timestamp[4] == '-': return convert_datetime(timestamp) timestamp += "0"*(14-len(timestamp)) # padding year, month, day, hour, minute, second = \ int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14]) try: return datetime.datetime(year, month, day, hour, minute, second) except ValueError: return None def convert_set(s): if isinstance(s, (bytes, bytearray)): return set(s.split(b",")) return set(s.split(",")) def through(x): return x #def convert_bit(b): # b = "\x00" * (8 - len(b)) + b # pad w/ zeroes # return struct.unpack(">Q", b)[0] # # the snippet above is right, but MySQLdb doesn't process bits, # so we shouldn't either convert_bit = through def convert_characters(connection, field, data): field_charset = charset_by_id(field.charsetnr).name encoding = charset_to_encoding(field_charset) if field.flags & FLAG.SET: return convert_set(data.decode(encoding)) if field.flags & FLAG.BINARY: return data if connection.use_unicode: data = data.decode(encoding) elif connection.charset != field_charset: data = data.decode(encoding) data = data.encode(connection.encoding) return data encoders = { bool: escape_bool, int: escape_int, long_type: escape_int, float: escape_float, str: escape_str, text_type: escape_unicode, tuple: escape_sequence, list: escape_sequence, set: escape_sequence, frozenset: escape_sequence, dict: escape_dict, bytearray: escape_bytes, type(None): escape_None, datetime.date: escape_date, datetime.datetime: escape_datetime, datetime.timedelta: escape_timedelta, datetime.time: escape_time, time.struct_time: escape_struct_time, Decimal: escape_object, } if not PY2 or JYTHON or IRONPYTHON: encoders[bytes] = escape_bytes decoders = { FIELD_TYPE.BIT: convert_bit, FIELD_TYPE.TINY: int, FIELD_TYPE.SHORT: int, FIELD_TYPE.LONG: int, FIELD_TYPE.FLOAT: float, FIELD_TYPE.DOUBLE: float, FIELD_TYPE.LONGLONG: int, FIELD_TYPE.INT24: int, FIELD_TYPE.YEAR: int, FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.DATE: convert_date, FIELD_TYPE.SET: convert_set, FIELD_TYPE.BLOB: through, FIELD_TYPE.TINY_BLOB: through, FIELD_TYPE.MEDIUM_BLOB: through, FIELD_TYPE.LONG_BLOB: through, FIELD_TYPE.STRING: through, FIELD_TYPE.VAR_STRING: through, FIELD_TYPE.VARCHAR: through, FIELD_TYPE.DECIMAL: Decimal, FIELD_TYPE.NEWDECIMAL: Decimal, }
而在jaydebeapi中也有一些類似的代碼:
def _to_datetime(rs, col): java_val = rs.getTimestamp(col) if not java_val: return d = datetime.datetime.strptime(str(java_val)[:19], "%Y-%m-%d %H:%M:%S") d = d.replace(microsecond=int(str(java_val.getNanos())[:6])) return str(d) def _to_time(rs, col): java_val = rs.getTime(col) if not java_val: return return str(java_val) def _to_date(rs, col): java_val = rs.getDate(col) if not java_val: return # The following code requires Python 3.3+ on dates before year 1900. # d = datetime.datetime.strptime(str(java_val)[:10], "%Y-%m-%d") # return d.strftime("%Y-%m-%d") # Workaround / simpler soltution (see # https://github.com/baztian/jaydebeapi/issues/18): return str(java_val)[:10] def _to_binary(rs, col): java_val = rs.getObject(col) if java_val is None: return return str(java_val) def _java_to_py(java_method): def to_py(rs, col): java_val = rs.getObject(col) if java_val is None: return if PY2 and isinstance(java_val, (string_type, int, long, float, bool)): return java_val elif isinstance(java_val, (string_type, int, float, bool)): return java_val return getattr(java_val, java_method)() return to_py _to_double = _java_to_py('doubleValue') _to_int = _java_to_py('intValue') _to_boolean = _java_to_py('booleanValue') _DEFAULT_CONVERTERS = { # see # http://download.oracle.com/javase/8/docs/api/java/sql/Types.html # for possible keys 'TIMESTAMP': _to_datetime, 'TIME': _to_time, 'DATE': _to_date, 'BINARY': _to_binary, 'DECIMAL': _to_double, 'NUMERIC': _to_double, 'DOUBLE': _to_double, 'FLOAT': _to_double, 'TINYINT': _to_int, 'INTEGER': _to_int, 'SMALLINT': _to_int, 'BOOLEAN': _to_boolean, 'BIT': _to_boolean }
而後咱們稍微修改一下便可。
2.2解決方案
在jaydebeapi中的Cursor類中,有一個屬性叫作description這個屬性,經過他咱們就能獲取查詢時表的字段的元信息
在jaydebeapi中的Cursor類中,是有rowcount這個屬性的,他表示當咱們進行插入更新刪除操做時受影響的行數。
而在pymysql的cursors文件中的Cursor類中的_do_get_result方法中不只僅有受影響的行數rowcount,還有lastrowid這個屬性,他表示當咱們插入數據且對應主鍵是自增字段時,最後一條數據的主鍵值。可是在jaydebeapi中是沒有的,而這個屬性在sqlalchemy中偏偏是須要的,因此咱們要爲jaydebeapi的Cursor類加上這個屬性。代碼以下:
class Cursor(object): lastrowid = None rowcount = -1 _meta = None _prep = None _rs = None _description = None
...此處省略部分不相關代碼...
def execute(self, operation, parameters=None): if self._connection._closed: raise Error() if not parameters: parameters = () self._close_last() self._prep = self._connection.jconn.prepareStatement(operation) self._set_stmt_parms(self._prep, parameters) try: is_rs = self._prep.execute() # print is_rs except: _handle_sql_exception() # print(dir(self._prep)) # 若是是查詢的話 is_rs就是1 if is_rs: self._rs = self._prep.getResultSet() self._meta = self._rs.getMetaData() self.rowcount = -1 self.lastrowid = None # 插入/修改/刪除時 is_rs都爲0 else: self.rowcount = self._prep.getUpdateCount() self.lastrowid = int(self._prep.lastInsertID)
注意:上面的代碼中紅色的代碼是我新增的
3解決方案
sqlarchemy中底層數據庫鏈接模塊都放在dialects這個包中,這個包裏面有多個包分別是mysql oracle等數據庫的基本數據庫鏈接類,由於公司只使用mysql數據庫,因此僅僅作了mysql的jdbc擴展,就放到了mysql包中。
大致介紹一下咱們將要修改的或者用到的類:
MySQLDialect
位置:sqlarchemy.dialects.mysql.base
描述:它是一個提供了對mysql數據庫的鏈接、語句的執行等操做的基類,因此咱們須要新寫一個jdbcdialect類並繼承它,而後重寫某些方法。
爲何會用到:這個就不用多說了
ExecutionContext
位置:sqlarchemy.engine.interface
描述:經過這個東西咱們能夠獲取當前遊標的執行環境,好比說本次sql語句的執行影響了多少行,咱們剛插入的一行的自增主鍵值是多少。他也負責把咱們所寫的python ORM語句轉換爲能夠被底層數據庫模塊好比pymysql能夠執行的東西。
建立dialect類:
咱們知道使用sqlalchemy時首先須要建立一個engine,engine的第一個參數是一個URL,就像這樣:mysql+pymysql://user:password@host:port/db?charset=utf8
這段URL主要配置了三項:
配置1 首先聲明瞭咱們要鏈接mysql數據庫
配置2 而後配置了底層鏈接數據庫的dialect(這個單詞翻譯過來叫方言,就比如同是漢語(鏈接mysql),咱們能夠說山東話(pymysql)也能夠說湖南話(mysqldb))模塊是pymysql
配置3 配置了用戶名,密碼,主機地址,端口,數據庫名等信息
經過查看代碼咱們能夠看到:
上面中的配置1實際上就是說接下來要在 sqlalchemy.dialects.mysql包中獲取提供數據庫操做等方法的class了。
配置2實際上就是說 配置1想要找的的class我定義在了sqlalcehmy.dialects.mysql.pymysql中
配置3會做爲URL類包裝解析,而後做爲參數傳入dialect實例的create_connect_args方法,以獲取數據庫鏈接參數。
而後建立engine時還能夠指定許多額外的參數,好比說鏈接池的配置等,這裏面有幾個咱們須要注意的參數:
假如咱們沒有指定module(數據庫鏈接底層模塊),默認會調用dialect類的類方法dbapi。
假如咱們沒有指定creator(與數據庫創建鏈接的方法,通常是個函數)這個參數的話默認創建鏈接時會調用dialect實例的connect方法,並把create_connect_args返回的鏈接參數傳入。
當咱們第一次與數據庫創建鏈接時,會調用dialect實例的initialize方法,這個方法會作一系列操做,好比說獲取當前數據庫的版本信息:dialect實例的_get_server_version_info方法;獲取當前isolation級別:dialect實例的get_isolation_level方法
而後就很簡單了:在sqlalchemy中找到sqlalchemy.dialects.mysql這個目錄,而後新建一個名叫jaydebeapi的文件,並找到該目錄下的pymysql文件,你會看到:
from .mysqldb import MySQLDialect_mysqldb from ...util import langhelpers, py3k class MySQLDialect_pymysql(MySQLDialect_mysqldb): driver = 'pymysql' description_encoding = None # generally, these two values should be both True # or both False. PyMySQL unicode tests pass all the way back # to 0.4 either way. See [ticket:3337] supports_unicode_statements = True supports_unicode_binds = True def __init__(self, server_side_cursors=False, **kwargs): super(MySQLDialect_pymysql, self).__init__(**kwargs) self.server_side_cursors = server_side_cursors @langhelpers.memoized_property def supports_server_side_cursors(self): try: cursors = __import__('pymysql.cursors').cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): return False @classmethod def dbapi(cls): return __import__('pymysql') if py3k: def _extract_error_code(self, exception): if isinstance(exception.args[0], Exception): exception = exception.args[0] return exception.args[0] dialect = MySQLDialect_pymysql
就這一個類,咱們只須要繼承這個類並重寫某些方法就是了。就像這樣:
#!/usr/bin/env python # -*- coding: utf-8 -*- import re from .pymysql import MySQLDialect_mysqldb class MySQLDialect_jaydebeapi(MySQLDialect_mysqldb): driver = 'jaydebeapi' @classmethod def dbapi(cls): return __import__('jaydebeapi') def connect(self, *cargs, **cparams): # get_jdbc_conn這個方法就本身寫吧,實際上就是用jaydebeapi生成一個鏈接,但須要注意,鏈接的autocommit要設置爲False return get_jdbc_conn(self.dbapi, **cparams) def _get_server_version_info(self, connection): dbapi_con = connection.connection cursor = dbapi_con.cursor() cursor.execute("select version()") version = str(cursor.fetchone()[0]) cursor.close() version_list = [] r = re.compile(r'[.\-]') for n in r.split(version): try: version_list.append(int(n)) except ValueError: version_list.append(n) return tuple(version_list) def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" try: # note: the SQL here would be # "SHOW VARIABLES LIKE 'character_set%%'" # print dir(connection.connection) cset_name = connection.connection.character_set_name except AttributeError: return 'utf8' else: return cset_name()
點1:
com.mysql.jdbc.exceptions.MySQLNonTransientConnectionException: Can’t call rollback when autocommit=true
1. 當開啓autocommit=true時,回滾沒有意義,不管成功/失敗都已經已經將事務提交
2. autocommit=false,咱們須要運行conn.commit()執行事務, 若是失敗則須要conn.rollback()對事務進行回滾;
點2:
嘗試鏈接mysql時報錯:Unknown system variable 'transaction_isolation'
這是由於個人MySQLDialect_jaydebeapi類中的_get_server_version_info方法返回寫死爲5.7.21版本,而在mysql的Mysqldialect類的get_isolation_level中,會判斷若是版本大於等於5.7.20的話執行SELECT @@transaction_isolation,反之會執行SELECT @@tx_isolation。
因而看了看本身的mysql版本是5.7.11 ,遂改變版本號。