What you want is a proxy model that acts as a bridge between the QSql*Model and the view. For that, you need to subclass QAbstractProxyModel. You need a consistent way of finding parent-child relationships in the proxy model and mapping them to the source model, which may require keeping some bookkeeping in the proxy model.
When you subclass QAbstractProxyModel, you need to reimplement, at minimum, these methods:
- rowCount
- columnCount
- parent
- index
- data
- mapToSource
- mapFromSource
Also, keep in mind that QAbstractProxyModel does not automatically propagate signals. So, in order for the view to be aware of changes in the source model (such as inserts, deletes, and updates), you need to re-emit them from the proxy model (while, of course, updating your mappings in the proxy model).
It will require some work, but in the end you'll have a more flexible structure. It will also eliminate all the work you would otherwise need to do to keep the database and a custom QAbstractItemModel synchronized.
Edit
A custom proxy model that groups items from a flat model according to a given column:
import sys
from collections import namedtuple
import random
from PyQt4 import QtCore, QtGui
# One group shown at the top level of the proxy: its display name, the list of
# rowItem children that belong to it, and the index object identifying it.
groupItem = namedtuple("groupItem", "name children index")
# One source row: the index object of the group it belongs to, plus a second
# field that callers in this file fill with random.random().
# NOTE(review): presumably the random value keeps otherwise-equal rows
# distinguishable in list.index() lookups -- confirm.
rowItem = namedtuple("rowItem", "groupIndex random")
class GrouperProxyModel(QtGui.QAbstractProxyModel):
def __init__(self, parent=None):
    """Create an empty grouping proxy.

    No grouping data exists until groupBy() runs (setSourceModel() calls it).
    """
    super(GrouperProxyModel, self).__init__(parent)

    # Column of the source model whose value decides the group of a row.
    self._groupColumn = 0

    # Invalid index used as the internal pointer of top-level (group) items.
    self._rootItem = QtCore.QModelIndex()

    # Bookkeeping, all empty until a grouping pass fills them:
    self._groups = []        # groupItem records, in group-row order
    self._groupIndexes = []  # index object of every group, used to find its row
    self._groupMap = {}      # group name -> index object of that group
    self._sourceRows = []    # rowItem of every source row, by source row number
def setSourceModel(self, source, groupColumn=0):
    """Attach *source* as the source model and group its rows.

    source      -- the flat model to group, or None to detach.
    groupColumn -- source column whose display value names the group.

    Column signals of the source are re-emitted unchanged by the proxy;
    row and data changes are routed to the private handlers that keep the
    grouping bookkeeping up to date.
    """
    super(GrouperProxyModel, self).setSourceModel(source)

    if source is None:
        # Nothing to connect to or to read rows from: just drop the cached
        # grouping instead of failing on a None source model.
        self.beginResetModel()
        self._clearGroups()
        self._groupColumn = groupColumn
        self.endResetModel()
        return

    # NOTE(review): connections made to a previously attached source model
    # are not removed here -- confirm whether models are ever swapped.

    # connect signals
    source.columnsAboutToBeInserted.connect(self.columnsAboutToBeInserted.emit)
    source.columnsInserted.connect(self.columnsInserted.emit)
    source.columnsAboutToBeRemoved.connect(self.columnsAboutToBeRemoved.emit)
    source.columnsRemoved.connect(self.columnsRemoved.emit)

    source.rowsInserted.connect(self._rowsInserted)
    source.rowsRemoved.connect(self._rowsRemoved)
    source.dataChanged.connect(self._dataChanged)

    # set grouping
    self.groupBy(groupColumn)
def rowCount(self, parent):
    """Return the number of rows under *parent*.

    The root holds one row per group, a group holds one row per child,
    and anything deeper has no rows.
    """
    if parent == self._rootItem:
        # root level: one row for each group
        return len(self._groups)

    if parent.internalPointer() == self._rootItem:
        # parent is a group: count its children
        group = self._groups[parent.row()]
        return len(group.children)

    return 0
def columnCount(self, parent):
    """Return the source model's top-level column count, or 0 without a source.

    *parent* is ignored: every level reports the same number of columns.
    """
    if not self.sourceModel():
        return 0
    return self.sourceModel().columnCount(QtCore.QModelIndex())
def index(self, row, column, parent):
    """Return the proxy index at (*row*, *column*) under *parent*.

    Group items carry the root item as internal pointer; child items carry
    the index object of their group. An invalid index is returned when the
    position is out of range or *parent* is itself a child item.
    """
    # Reject positions outside rowCount()/columnCount() so callers never get
    # a valid-looking index for an item that does not exist.
    if not self.hasIndex(row, column, parent):
        return QtCore.QModelIndex()

    if parent == self._rootItem:
        # this is a group
        return self.createIndex(row, column, self._rootItem)

    if parent.internalPointer() == self._rootItem:
        # this is a child of the group at parent.row()
        return self.createIndex(row, column, self._groups[parent.row()].index)

    return QtCore.QModelIndex()
def parent(self, index):
    """Return the parent of *index*.

    Group items (and invalid indexes) have the root item as parent; child
    items have the column-0 index of their group.
    """
    if not index.isValid():
        # An invalid index has no internal pointer; without this guard it
        # would be reported as a child of the first group.
        return self._rootItem

    owner = index.internalPointer()
    if owner == self._rootItem:
        return self._rootItem

    # owner is the group's index object; look up its current row.
    return self.createIndex(self._getGroupRow(owner), 0, self._rootItem)
def data(self, index, role):
    """Return the display data for *index*; None for any other role.

    Group items return the group name (for every column). Child items
    return whatever the source model reports for the matching source cell.
    """
    if role != QtCore.Qt.DisplayRole or not index.isValid():
        return None

    owner = index.internalPointer()
    if owner == self._rootItem:
        return self._groups[index.row()].name

    groupRow = self._getGroupRow(owner)
    child = self._groups[groupRow].children[index.row()]
    sourceRow = self._sourceRows.index(child)
    # Ask the source model for its own index: an index made with the proxy's
    # createIndex() belongs to the proxy, not to the source model.
    source = self.sourceModel()
    return source.data(source.index(sourceRow, index.column()), role)
def flags(self, index):
    """Return enabled + selectable for every item, regardless of *index*."""
    enabled = QtCore.Qt.ItemIsEnabled
    selectable = QtCore.Qt.ItemIsSelectable
    return enabled | selectable
def headerData(self, section, orientation, role):
    """Return the source model's header data for the same section."""
    source = self.sourceModel()
    return source.headerData(section, orientation, role)
def mapToSource(self, index):
    """Map a proxy index to the matching source-model index.

    Group items exist only in the proxy, so they (and invalid indexes) map
    to an invalid index. Child items map to the source row they were built
    from, same column.
    """
    if not index.isValid():
        return QtCore.QModelIndex()

    owner = index.internalPointer()
    if not owner.isValid():
        # Group items carry the (invalid) root item as internal pointer.
        return QtCore.QModelIndex()

    # Locate the group by identity rather than by the row stored in its
    # index object: that stored row is fixed when the group is created and
    # goes stale once an earlier group has been removed.
    groupRow = self._getGroupRow(owner)
    child = self._groups[groupRow].children[index.row()]
    sourceRow = self._sourceRows.index(child)
    # The result must be an index of the source model, so let it build one.
    return self.sourceModel().index(sourceRow, index.column())
def mapFromSource(self, index):
    """Map a source-model index to the proxy index of the matching child item.

    Returns an invalid index when *index* is invalid or refers to a source
    row the proxy has no record of.
    """
    # An invalid index reports row -1, which would otherwise silently pick
    # the last known source row.
    if not index.isValid() or index.row() >= len(self._sourceRows):
        return QtCore.QModelIndex()

    child = self._sourceRows[index.row()]
    groupRow = self._getGroupRow(child.groupIndex)
    childRow = self._groups[groupRow].children.index(child)
    return self.createIndex(childRow, index.column(), self._groupIndexes[groupRow])
def _clearGroups(self):
    """Forget all grouping bookkeeping (groups, name map, index list, rows)."""
    self._groupMap = {}
    self._groups = []
    # Must be reset together with _groups: _getGroupRow() returns positions
    # in this list and callers use them to index _groups.
    self._groupIndexes = []
    self._sourceRows = []
def groupBy(self, column=0):
    """Rebuild all groups from the source model using *column* as the key.

    Runs inside a model reset. Every source row becomes a rowItem that is
    appended to the group named by the row's display value in *column*.
    With no source model attached the result is simply an empty grouping.
    """
    self.beginResetModel()
    self._clearGroups()
    self._groupColumn = column

    source = self.sourceModel()
    # No source model means no rows to group.
    total = source.rowCount(QtCore.QModelIndex()) if source is not None else 0
    for row in range(total):
        # Use the source model's own index; one made with the proxy's
        # createIndex() would belong to the proxy instead.
        groupName = source.data(source.index(row, column),
                                QtCore.Qt.DisplayRole)
        groupIndex = self._getGroupIndex(groupName)
        rowItem_ = rowItem(groupIndex, random.random())
        self._groups[groupIndex.row()].children.append(rowItem_)
        self._sourceRows.append(rowItem_)

    self.endResetModel()
def _getGroupIndex(self, groupName):
    """Return the index object of the group called *groupName*.

    A group that does not exist yet is created first (appended after the
    existing groups) and layoutChanged is emitted.
    """
    if groupName in self._groupMap:
        return self._groupMap[groupName]

    # New group: it takes the next free row at the top level.
    newRow = len(self._groupMap)
    newIndex = self.createIndex(newRow, 0, self._rootItem)

    self._groups.append(groupItem(groupName, [], newIndex))
    self._groupIndexes.append(newIndex)
    self._groupMap[groupName] = newIndex

    self.layoutChanged.emit()
    return newIndex
def _getGroupRow(self, groupIndex):
    """Return the position of *groupIndex* in the list of group indexes.

    The match is by object identity, not equality. Returns 0 when the
    object is not in the list.
    """
    for position, known in enumerate(self._groupIndexes):
        if known is groupIndex:
            return position
    # NOTE(review): falling back to 0 makes an unknown group look like the
    # first one -- kept because callers in this file rely on a row coming back.
    return 0
def _rowsInserted(self, parent, start, end):
    """Slot for the source model's rowsInserted signal.

    Adds a rowItem for every new source row *start*..*end* (inclusive) to
    the group named by its grouping column, then emits layoutChanged.
    *parent* is not used.
    """
    source = self.sourceModel()
    for row in range(start, end + 1):
        # Use the source model's own index; one made with the proxy's
        # createIndex() would belong to the proxy instead.
        groupName = source.data(source.index(row, self._groupColumn),
                                QtCore.Qt.DisplayRole)
        groupIndex = self._getGroupIndex(groupName)
        groupItem_ = self._groups[self._getGroupRow(groupIndex)]
        rowItem_ = rowItem(groupIndex, random.random())
        groupItem_.children.append(rowItem_)
        # Keep _sourceRows ordered by source row number.
        self._sourceRows.insert(row, rowItem_)
    self.layoutChanged.emit()
def _rowsRemoved(self, parent, start, end):
    """Slot for the source model's rowsRemoved signal.

    Drops the bookkeeping of source rows *start*..*end* (inclusive), removes
    any group left without children, then emits layoutChanged. *parent* is
    not used.
    """
    for _ in range(start, end + 1):
        # Each pass pops position `start`, so the next row to remove slides
        # into that same position.
        gone = self._sourceRows[start]
        groupRow = self._getGroupRow(gone.groupIndex)
        owner = self._groups[groupRow]

        owner.children.pop(owner.children.index(gone))
        self._sourceRows.pop(start)

        if len(owner.children):
            continue

        # remove the group
        emptyRow = self._getGroupRow(gone.groupIndex)
        emptyName = self._groups[emptyRow].name
        self._groups.pop(emptyRow)
        self._groupIndexes.pop(emptyRow)
        del self._groupMap[emptyName]
    self.layoutChanged.emit()
def _dataChanged(self, topLeft, bottomRight):
    """Slot for the source model's dataChanged signal.

    For every source row between *topLeft* and *bottomRight* (inclusive),
    re-reads the grouping column and moves the row to another group when
    the value no longer matches its current group. Groups left empty are
    removed. Emits layoutChanged once at the end.
    """
    source = self.sourceModel()

    # loop through all the changed data
    for row in range(topLeft.row(), bottomRight.row() + 1):
        rowItem_ = self._sourceRows[row]
        oldGroupItem = self._groups[self._getGroupRow(rowItem_.groupIndex)]
        # Use the source model's own index; one made with the proxy's
        # createIndex() would belong to the proxy instead.
        newGroupName = source.data(source.index(row, self._groupColumn),
                                   QtCore.Qt.DisplayRole)
        if newGroupName == oldGroupItem.name:
            continue

        # move to new group...
        newGroupIndex = self._getGroupIndex(newGroupName)
        newGroupItem = self._groups[self._getGroupRow(newGroupIndex)]

        # delete from old group
        oldGroupItem.children.remove(rowItem_)

        # The record must point at its new group, otherwise later lookups
        # through rowItem.groupIndex would still land on the old group.
        movedItem = rowItem_._replace(groupIndex=newGroupIndex)
        newGroupItem.children.append(movedItem)
        self._sourceRows[row] = movedItem

        if not oldGroupItem.children:
            # remove the group
            groupRow = self._getGroupRow(oldGroupItem.index)
            self._groups.pop(groupRow)
            self._groupIndexes.pop(groupRow)
            del self._groupMap[oldGroupItem.name]

    self.layoutChanged.emit()