prototype SchemeQL/Roe-like thing in Python

Kragen Sitaker kragen at pobox.com
Thu Apr 22 01:10:36 EDT 2004


#!/usr/bin/python
# relational algebra in Python 2.3

# Allows you to construct simple relational algebra queries in Python
# and translates them on the fly to SQL.  Then you can iterate over
# the query results, which are Python dictionaries.  I've tried this
# in Python 2.3; it may also work in other versions of Python 2.

# This is just the barest of bare bones here, but I implemented it in
# one evening.  You can see the kind of stuff it supports by looking
# at the test() function down below.

# Credit goes to Avi Bryant's vaporware "Roe" for Smalltalk, to
# SchemeQL, and to E. F. Codd for inspiration.

# A short list of the most egregiously missing pieces:
# - union, intersect, difference
# - more joins than just the simple inner join
# - a more convenient simple inner join
# - aggregate operations (and thus 'group by')
# - other computations in the list of output columns
# - aliases for output columns
# - relops other than '=' in the where list
# - more convenient syntax for specifying foo.column('baz'),
#   e.g. foo['baz'] or foo.baz
# - support for query.column('foo.bar')
# - insertion and update of data

import types

class WideningProjection(Exception):
	"""Raised by query.project() when a requested column is not among
	the columns the query has already been narrowed to."""
class AliasingAmbiguous(Exception):
	"""Raised by the query aliasing method ('as') when the query does
	not consist of exactly one table, so there is nothing unambiguous
	to rename."""
class AmbiguousAttributeName(Exception):
	"""Raised by query.column() and query.starattr() when an attribute
	name cannot be resolved to exactly one column of the query."""
class CantProjectColumnsFromAnotherTable(Exception):
	"""Raised by query.project() when a requested column belongs to a
	table alias that is not part of the query."""

# The string and integer type names differ between Python 2 and Python 3,
# so probe for the Python 2 spellings and fall back to the Python 3 ones.
try:
	_STRING_TYPES = (str, unicode)
except NameError:
	_STRING_TYPES = (str,)
try:
	_INTEGER_TYPES = (int, long)
except NameError:
	_INTEGER_TYPES = (int,)

class query:
	"""A relational-algebra expression that can render itself as SQL.

	A query holds a list of (tablename, alias) pairs, a list of
	(variable, relop, value) conditions that are and-ed together, and
	a list of (alias, attrname) output columns; an empty column list
	means "select *".  Every operation returns a new query built with
	clonebut() and leaves the receiver untouched.
	"""
	def __init__(self, tables, dbconn=None, where=None, columns=None):
		# None instead of [] as the default, and a copy of every
		# sequence, so that no two queries ever share a list object.
		self.tables = list(tables)
		self.dbconn = dbconn
		self.where = list(where or [])
		self.columns = list(columns or [])
	def clonebut(self, **overrides):
		"""Return a copy of this query with the given constructor
		arguments (tables, where, columns, dbconn) replaced."""
		defaults = {
			'tables': self.tables,
			'where': self.where,
			'columns': self.columns,
			'dbconn': self.dbconn,
		}
		defaults.update(overrides)
		return self.__class__(**defaults)
	def sql(self):
		"""Return the SQL 'select' statement for this query as a string."""
		return 'select %(columns)s from %(tables)s%(where)s' % {
			'columns': ', '.join(self.columnnames()),
			'tables': self.tablespecs(),
			'where': self.whereclause()
		}
	def default_table(self):
		"""Return the alias of the only table, or None if the query
		does not have exactly one table."""
		if len(self.tables) != 1: return None
		else: return self.tables[0][1]
	def columnnames(self):
		"""Return the rendered output column names; a query with no
		explicit columns selects '*' from every table."""
		columns = self.columns or [
			(alias, '*') for table, alias in self.tables]
		return [self.namefor(column) for column in columns]
	def tablespecs(self):
		"""Return the contents of the 'from' clause, adding
		'as <alias>' only where the alias differs from the table name."""
		rv = []
		for tablename, alias in self.tables:
			if alias != tablename:
				rv.append('%s as %s' % (tablename, alias))
			else: rv.append(tablename)
		return ', '.join(rv)
	def whereclause(self):
		"""Return ' where ...' with all conditions and-ed together,
		or the empty string when there are no conditions."""
		if not self.where: return ''
		rv = []
		for variable, relop, value in self.where:
			rv.append('%s %s %s' % (self.namefor(variable), relop, self.namefor(value)))
		return ' where ' + ' and '.join(rv)
	def project(self, *columns):
		"""Return a query restricted to the given output columns.

		Each column is either an attribute name (resolved with
		column()) or an object with tablename and attrname
		attributes.  Raises CantProjectColumnsFromAnotherTable for
		a column whose table is not in this query, and
		WideningProjection for a column that an earlier projection
		has already dropped.
		"""
		ncolumns = []
		available_tables = [alias for tablename, alias in self.tables]
		for item in columns:
			if isinstance(item, _STRING_TYPES):
				item = self.column(item)
			if item.tablename not in available_tables:
				raise CantProjectColumnsFromAnotherTable(
					item.tablename,
					item.attrname,
					available_tables
				)
			ncolumns.append((item.tablename, item.attrname))
		if self.columns:
			for col in ncolumns:
				if col not in self.columns:
					raise WideningProjection(col, self.columns)
		return self.clonebut(columns=ncolumns)
	def _colref(self, thing):
		# Return (tablename, attrname) if thing is a column reference
		# in the sense of namefor() (a pair or an object carrying
		# tablename/attrname), otherwise None.
		if isinstance(thing, tuple) and len(thing) == 2:
			return (thing[0], thing[1])
		if hasattr(thing, 'tablename'):
			return (thing.tablename, thing.attrname)
		return None
	def as_(self, alias):
		"""Return this single-table query with its table renamed to alias.

		The method is also reachable under the attribute name 'as'
		(see the setattr below the class); 'as' cannot be written
		as a method name directly because it is a reserved word
		from Python 2.6 on.  Raises AliasingAmbiguous unless the
		query has exactly one table.
		"""
		if len(self.tables) != 1:
			raise AliasingAmbiguous(alias, self.tables)
		# Alias the underlying table, not the previous alias, so that
		# aliasing twice still names a table that really exists.
		tablename, oldalias = self.tables[0]
		def realias(thing):
			# Point references to the old alias at the new one;
			# literals and other tables' columns pass through.
			ref = self._colref(thing)
			if ref is not None and ref[0] == oldalias:
				return (alias, ref[1])
			return thing
		return self.clonebut(
			tables=[(tablename, alias)],
			where=[(realias(variable), relop, realias(value))
				for variable, relop, value in self.where],
			columns=[realias(col) for col in self.columns]
		)
	def __mul__(self, other):   # cartesian product
		"""Return the cartesian product of this query and other,
		keeping this query's dbconn."""
		# XXX shouldn't we be worrying about renaming
		# conflicts here?
		return self.clonebut(
			tables=(self.tables + other.tables),
			where=(self.where + other.where),
			columns=(self.columns + other.columns)
		)
	def select(self, *conditions, **sugarclauses):
		"""Return a query with extra conditions and-ed on.

		Positional conditions are (variable, relop, value) triples
		such as those produced by comparing a tablecolumn with
		'=='.  Each keyword argument name=value is shorthand for
		an equality test on this query's column of that name.
		"""
		# Sort the keyword names so the generated SQL does not depend
		# on dictionary ordering.
		names = list(sugarclauses)
		names.sort()
		eqclauses = [(self.column(name), '=', sugarclauses[name])
			for name in names]
		return self.clonebut(where=self.where + eqclauses + list(conditions))
	def column(self, attrname):
		"""Return the tablecolumn for attrname in this query.

		Raises AmbiguousAttributeName unless exactly one column
		matches.
		"""
		if not self.columns: return self.starattr(attrname)
		candidates = [(table, col) for (table, col) in self.columns
			if col == attrname]
		if len(candidates) == 1:
			return tablecolumn(self, candidates[0])
		else:
			raise AmbiguousAttributeName(attrname, candidates)
	def starattr(self, attrname):
		"""Resolve attrname for a "select *" query, which is only
		possible when there is exactly one table."""
		default = self.default_table()
		if not default:
			# Report (alias, attrname) candidates, the same
			# shape column() uses.
			raise AmbiguousAttributeName(attrname,
				[(alias, attrname) for tablename, alias in self.tables])
		return tablecolumn(self, (default, attrname))
	def namefor(self, thing):
		"""Render thing as it should appear in the SQL text.

		Strings become quoted literals, integers become decimal
		literals, and (tablename, attrname) pairs or objects with
		those attributes become column names, qualified with the
		table alias unless it is this query's only table.  Raises
		TypeError for anything else.
		"""
		if isinstance(thing, _STRING_TYPES):
			# quoting that works for Postgres and MySQL
			# but breaks standard databases
			return "'%s'" % thing.replace('\\', '\\\\').replace("'", "''")
		if type(thing) in _INTEGER_TYPES:
			return str(thing)
		ref = self._colref(thing)
		if ref is None:
			raise TypeError(thing)
		tablename, attrname = ref
		if tablename == self.default_table(): return attrname
		else: return '%s.%s' % (tablename, attrname)
	def __iter__(self):
		"""Run the query on self.dbconn and iterate over its rows."""
		return query_results(self.sql(), self.dbconn)

# Keep the historical method name available: code written for Python 2.3
# calls q.as(alias), and newer code can use getattr(q, 'as')(alias).
setattr(query, 'as', query.__dict__['as_'])

class tablecolumn:
	"""A reference to one column (tablename.attrname) of a query.

	Comparing a tablecolumn with '==' does not test equality; it
	returns a (column, '=', other) triple of the kind that
	query.select() accepts as a condition.
	"""
	def __init__(self, query, column):
		# column is a (tablename, attrname) pair.  It is unpacked
		# here rather than in the signature because tuple parameters
		# are a syntax error in Python 3.
		tablename, attrname = column
		(self.query, self.tablename, self.attrname) = (
			query, tablename, attrname)
	def __eq__(self, other):
		"""Return the condition triple (self, '=', other)."""
		return self, '=', other
	def name(self):
		"""Return this column's name as rendered by the owning query."""
		return self.query.namefor(self)

class query_results:
	"""Iterator over the rows produced by running a SQL string.

	dbconn must provide query(sql) and use_result(); the object
	returned by use_result() must provide fetch_row(how=1).  The
	statement is executed as soon as the iterator is constructed.
	"""
	def __init__(self, sql, dbconn):
		dbconn.query(sql)
		self.results = dbconn.use_result()
	def __iter__(self): return self
	def next(self):
		"""Return the next row, or raise StopIteration when
		fetch_row() comes back empty."""
		rows = self.results.fetch_row(how=1)  # how=1 returns dicts
		if not rows: raise StopIteration
		return rows[0]
	# Python 3 spells the iterator method __next__; keep both names so
	# the class is an iterator on either major version.
	__next__ = next

def table(tablename, dbconn=None):
	"""Return the "select *" query over the single table tablename,
	bound to dbconn if one is given."""
	# A base table starts out aliased to its own name.
	unaliased = (tablename, tablename)
	return query(dbconn=dbconn, tables=[unaliased])

def ok(a, b):
	"""Assert that a equals b, reporting both values on failure."""
	pair = (a, b)
	assert a == b, pair

def test():
	"""Self-test: build a handful of queries and compare the SQL they
	generate with the expected text.  No database connection is used."""
	def aliased(q, alias):
		# 'as' is a reserved word from Python 2.6 on, so writing
		# q.as(alias) no longer parses; look the method up by name.
		return getattr(q, 'as')(alias)
	foo = table("foo")
	ok(foo.sql(), 'select * from foo')
	ok(foo.project('a', 'b').sql(), 'select a, b from foo')
	ok(aliased(foo, 'bar').sql(), 'select * from foo as bar')
	ok((foo * aliased(foo, 'baz')).sql(),
		'select foo.*, baz.* from foo, foo as baz')
	ok(foo.select(a=3).sql(), 'select * from foo where a = 3')
	ok(foo.select(a=5, b='asdf').sql(), "select * from foo where a = 5 and b = 'asdf'")
	ok(foo.select(b="Can't").sql(), "select * from foo where b = 'Can''t'")
	bar = table('bar')
	joinq = (foo * bar).select(bar.column('id') == foo.column('barid')).project(foo.column('a'), foo.column('b'), bar.column('d'))
	ok(joinq.sql(), "select foo.a, foo.b, bar.d from foo, bar where bar.id = foo.barid")
	child = aliased(foo, 'child')
	ok((foo * child).select(foo.column('id') == child.column('parentid'))
		.sql(), "select foo.*, child.* from foo, foo as child "
		"where foo.id = child.parentid")

# Run the self-test unconditionally: this call is not guarded, so it
# executes both when the file is run as a script and whenever the module
# is imported.
test()

# I tested the MySQL connectivity as follows:
# import _mysql, relalg
# db = _mysql.connect(db='kragen')
# q = relalg.table('foo', db)
# list(q)
# list(q.select(a=3))
# list(q.select(a=4))

# Before this, I'd had to install the python-mysql Debian package to
# get the _mysql module, and I'd had to create the database, grant
# myself access, create a table, and put stuff in the table.  For
# future reference, MySQL doesn't have a 'varchar' type, just 'text',
# and a default MySQL installation allows "mysql -u root" to create
# databases and "grant all on newdatabasename.* to ''@'localhost'".
# Surprisingly, the '' around the @ really *are* important.


More information about the Kragen-hacks mailing list