Skip to content

Commit

Permalink
Fixed the issue of len not working with the result of db query. (#554)
Browse files Browse the repository at this point in the history
* Fixed the issue of len not working with the result of db query.

The earlier implentation was adding an __len__ function to the result
object and that worked fine for old-style classes. With old-style
classes gone in Python 3, that started giving trouble. Fixed it by
writing a ResultSet class and a special SqliteResultSet which doesn't
support len, but supports bool.

Fixes #547.

* cleanup of db tests.
  • Loading branch information
anandology authored and iredmail committed Sep 27, 2019
1 parent 663cc97 commit fa27cc5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 36 deletions.
36 changes: 18 additions & 18 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ class DBTest(unittest.TestCase):

def setUp(self):
self.db = setup_database(self.dbname, driver=self.driver)
self.db.query("DROP TABLE IF EXISTS person")
self.db.query("CREATE TABLE person (name text, email text, active boolean)")

def tearDown(self):
# there might be some error with the current connection, delete from a new connection
self.db = setup_database(self.dbname, driver=self.driver)
self.db.query("DROP TABLE person")
self.db.query("DROP TABLE IF EXISTS person")
self.db.query("DROP TABLE IF EXISTS mi")
self.db.ctx.db.close()

def _testable(self):
try:
Expand Down Expand Up @@ -148,11 +149,14 @@ def testPooling(self):
except ImportError:
return
db = setup_database(self.dbname, pooling=True)
self.assertEqual(db.ctx.db.__class__.__module__, "DBUtils.PooledDB")
db.select("person", limit=1)
try:
self.assertEqual(db.ctx.db.__class__.__module__, "DBUtils.PooledDB")
db.select("person", limit=1)
finally:
db.ctx.db.close()

def test_multiple_insert(self):
db = setup_database(self.dbname)
db = self.db
db.multiple_insert("person", [dict(name="a"), dict(name="b")], seqname=False)

assert db.select("person", where="name='a'").list()
Expand Down Expand Up @@ -186,15 +190,13 @@ def test_multiple_insert(self):

def test_result_is_unicode(self):
# TODO : not sure this test has still meaning with Py3
db = setup_database(self.dbname)
self.db.insert("person", False, name="user")
name = db.select("person")[0].name
name = self.db.select("person")[0].name
self.assertEqual(type(name), unicode)

def test_result_is_true(self):
db = setup_database(self.dbname)
self.db.insert("person", False, name="user")
self.assertEqual(bool(db.select("person")), True)
self.assertEqual(bool(self.db.select("person")), True)

def testBoolean(self):
def t(active):
Expand All @@ -207,15 +209,13 @@ def t(active):
t(True)

def test_insert_default_values(self):
db = setup_database(self.dbname)
db.insert("person")
self.db.insert("person")

def test_where(self):
db = setup_database(self.dbname)
db.insert("person", False, name="Foo")
d = db.where("person", name="Foo").list()
self.db.insert("person", False, name="Foo")
d = self.db.where("person", name="Foo").list()
assert len(d) == 1
d = db.where("person").list()
d = self.db.where("person").list()
assert len(d) == 1


Expand All @@ -225,7 +225,7 @@ class PostgresTest2(DBTest):
driver = "psycopg2"

def test_limit_with_unsafe_value(self):
db = setup_database(self.dbname)
db = self.db
db.insert("person", False, name="Foo")
assert len(db.select("person").list()) == 1

Expand All @@ -238,7 +238,7 @@ def test_limit_with_unsafe_value(self):
assert len(db.select("person").list()) == 1

def test_offset_with_unsafe_value(self):
db = setup_database(self.dbname)
db = self.db
db.insert("person", False, name="Foo")
assert len(db.select("person").list()) == 1

Expand Down
118 changes: 100 additions & 18 deletions web/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,100 @@ def sqlquote(a):
return sqlparam(a).sqlquery()


class BaseResultSet:
"""Base implementation of Result Set, the result of a db query.
"""

def __init__(self, cursor):
self.cursor = cursor
self.names = [x[0] for x in cursor.description]
self._index = 0

def list(self):
rows = [self._prepare_row(d) for d in self.cursor.fetchall()]
self._index += len(rows)
return rows

def _prepare_row(self, row):
return storage(dict(zip(self.names, row)))

def __iter__(self):
return self

def __next__(self):
row = self.cursor.fetchone()
if row is None:
raise StopIteration()
self._index += 1
return self._prepare_row(row)

next = __next__ # for python 2.7 support

def first(self, default=None):
"""Returns the first row of this ResultSet or None when there are no
elements.
If the optional argument default is specified, that is returned instead
of None when there are no elements.
"""
try:
return next(iter(self))
except StopIteration:
return default

def __getitem__(self, i):
# todo: slices
if i < self._index:
raise IndexError("already passed " + str(i))
try:
while i > self._index:
next(self)
self._index += 1
# now self._index == i
self._index += 1
return next(self)
except StopIteration:
raise IndexError(str(i))


class ResultSet(BaseResultSet):
"""The result of a database query.
"""

def __len__(self):
return int(self.cursor.rowcount)


class SqliteResultSet(BaseResultSet):
"""Result Set for sqlite.
Same functionaly as ResultSet except len is not supported.
"""

def __init__(self, cursor):
BaseResultSet.__init__(self, cursor)
self._head = None

def __next__(self):
if self._head is not None:
self._index += 1
return self._head
else:
return super().__next__()

def __bool__(self):
# The ResultSet class class doesn't need to support __bool__ explicity
# because it has __len__. Since SqliteResultSet doesn't support len,
# we need to peep into the result to find if the result is empty of not.
if self._head is None:
try:
self._head = next(self)
self._index -= 1 # reset the index
except StopIteration:
return False
return True


class Transaction:
"""Database transaction."""

Expand Down Expand Up @@ -737,26 +831,17 @@ def query(self, sql_query, vars=None, processed=False, _test=False):
self._db_execute(db_cursor, sql_query)

if db_cursor.description:
names = [x[0] for x in db_cursor.description]

def iterwrapper():
row = db_cursor.fetchone()
while row:
yield storage(dict(zip(names, row)))
row = db_cursor.fetchone()

out = iterbetter(iterwrapper())
out.__len__ = lambda: int(db_cursor.rowcount)
out.list = lambda: [
storage(dict(zip(names, x))) for x in db_cursor.fetchall()
]
return self.create_result_set(db_cursor)
else:
out = db_cursor.rowcount

if not self.ctx.transactions:
self.ctx.commit()
return out

def create_result_set(self, cursor):
return ResultSet(cursor)

def select(
self,
tables,
Expand Down Expand Up @@ -1232,11 +1317,8 @@ def __init__(self, **keywords):
def _process_insert_query(self, query, tablename, seqname):
return query, SQLQuery("SELECT last_insert_rowid();")

def query(self, *a, **kw):
out = DB.query(self, *a, **kw)
if isinstance(out, iterbetter):
del out.__len__
return out
def create_result_set(self, cursor):
return SqliteResultSet(cursor)


class FirebirdDB(DB):
Expand Down

0 comments on commit fa27cc5

Please sign in to comment.