0

我正在尝试实现将 aQAbstractProxyModel映射SqlTableModel到树状数据结构的 a 。该表有一个名为 的列parent_id,其值createIndex作为第三个参数添加到调用中。这个问题类似于这个用户的帖子,只是我在 Python 中工作,而不是在 C++ 中。

TreeView 正确加载:

见附图

但是当我尝试展开一个项目时,应用程序崩溃了。调试告诉我,似乎有一个无限循环indexrowCount并且mapToSource被调用。

我已经走投无路了。你有什么想法?请参阅下面的 MWE。

from __future__ import annotations
from PySide6.QtWidgets import QGridLayout
from PySide6.QtWidgets import QTreeView
from PySide6.QtWidgets import QApplication
from PySide6.QtWidgets import QMainWindow
from PySide6.QtWidgets import QWidget
from PySide6.QtCore import QModelIndex
from PySide6.QtCore import Qt
from PySide6.QtCore import Slot
from PySide6.QtCore import QAbstractProxyModel
from PySide6.QtSql import QSqlDatabase
from PySide6.QtSql import QSqlQuery
from PySide6.QtSql import QSqlTableModel


class CustomTreeModel(QAbstractProxyModel):
    def __init__(self, database: str, parent: QWidget = None):
        QAbstractProxyModel.__init__(self, parent)
        sourceModel = QSqlTableModel(parent, database)
        sourceModel.setTable('test')
        sourceModel.select()
        self.setSourceModel(sourceModel)

    def flags(self, proxyIndex: QModelIndex) -> Qt.ItemFlags:
        return Qt.ItemIsEnabled | Qt.ItemIsEditable

    def data(self, proxyIndex: QModelIndex, role: int):
        print("data")
        if proxyIndex.isValid:
            sourceIndex = self.mapToSource(proxyIndex)
            return sourceIndex.data(role)
        return None

    def index(
            self,
            row: int,
            column: int,
            parentIndex: QModelIndex
            ) -> QModelIndex:
        print("index")
        if row < 0 and column < 0:
            return QModelIndex()
        parentId = parentIndex.internalPointer()
        return self.createIndex(row, column, parentId)

    def mapFromSource(self, sourceIndex: QModelIndex) -> QModelIndex:
        print("mapFromSource")
        if self.isRootItem(sourceIndex):
            return QModelIndex()
        if sourceIndex.column() == 0:
            sourceId = sourceIndex.data()
        else:
            sourceId = sourceIndex.siblingAtColumn(0).data()
        parentId = self.getParentId(sourceId)
        childIds = self.getChildIds(parentId)
        row = childIds.index(sourceId)
        column = sourceIndex.column()
        proxyIndex = self.createIndex(row, column, parentId)
        return proxyIndex

    def mapToSource(self, proxyIndex: QModelIndex) -> QModelIndex:
        print("mapToSource")
        if self.isRootItem(proxyIndex):
            return QModelIndex()
        parentId = proxyIndex.internalPointer()
        childIds = self.getChildIds(parentId)
        rowId = childIds[proxyIndex.row()]
        rowIds = self.getAllIds()
        sourceRow = rowIds.index(rowId)
        sourceColumn = proxyIndex.column()
        sourceIndex = self.sourceModel().index(sourceRow, sourceColumn)
        return sourceIndex

    def rowCount(self, parentIndex: QModelIndex) -> int:
        print("rowCount")
        if parentIndex.column() > 0:
            return 0
        parentId = parentIndex.internalPointer()
        childIds = self.getChildIds(parentId)
        return len(childIds)

    def columnCount(self, parentIndex: QModelIndex) -> int:
        print("columnCount")
        if parentIndex.column() > 0:
            return 0
        numColumns = self.sourceModel().columnCount(parentIndex)
        return numColumns

    def parent(self, childIndex: QModelIndex) -> QModelIndex:
        print("parent")
        if childIndex.column() > 0:
            return QModelIndex()
        sourceIndex = self.mapToSource(childIndex)
        childId = sourceIndex.siblingAtColumn(0).data()
        parentId = self.getParentId(childId)
        if not parentId:
            return QModelIndex()
        parentParentId = self.getParentId(parentId)
        parentIds = self.getChildIds(parentParentId)
        parentRow = parentIds.index(parentId)
        parentIndex = self.createIndex(parentRow, 0, parentId)
        return parentIndex

    def getParentId(self, childId: str) -> str | None:
        table = self.sourceModel().tableName()
        query = QSqlQuery()
        query.prepare(f"""
            SELECT parent_id
            FROM {table}
            WHERE id=?
            """)
        query.addBindValue(childId)
        query.exec_()
        if query.first():
            parentId = query.value(0)
            return parentId if parentId else None
        return None

    def hasChildren(self, parentIndex: QModelIndex) -> bool:
        if parentIndex.column() > 0:
            return False
        parentId = parentIndex.internalPointer()
        childIds = self.getChildIds(parentId)
        return len(childIds) > 0

    def getAllIds(self) -> list[str]:
        table = self.sourceModel().tableName()
        query = QSqlQuery()
        query.prepare(f"""
            SELECT id
            FROM {table}
            """)
        query.exec_()
        ids = []
        while query.next():
            ids.append(query.value(0))
        return ids

    def getChildIds(self, parentId: str | None) -> list[str]:
        table = self.sourceModel().tableName()
        query = QSqlQuery()
        if not parentId or parentId == '':
            query.prepare(f"""
                SELECT id
                FROM {table}
                WHERE parent_id IS NULL OR parent_id=''
                """)
        else:
            query.prepare(f"""
                SELECT id
                FROM {table}
                WHERE parent_id=?""")
            query.addBindValue(parentId)
        query.exec_()

        childIds = []
        while query.next():
            childIds.append(query.value(0))
        return childIds

    def isRootItem(self, index: QModelIndex):
        return index.row() == -1 and index.column() == -1


class CustomTreeWidget(QWidget):
    def __init__(self, parent: QWidget = None):
        QWidget.__init__(self, parent)
        self.model: CustomTreeModel
        self.view = QTreeView(self)

        layout = QGridLayout(self)
        layout.addWidget(self.view)
        self.setLayout(layout)

    @Slot()
    def setDatabase(self):
        database = QSqlDatabase.database()
        model = CustomTreeModel(database, self)
        self.view.setModel(model)
        self.model = model


def initTestDatabase():
    query = QSqlQuery()
    query.prepare("""
        CREATE TABLE test (
            "id"    TEXT,
            "text"  TEXT,
            "parent_id" TEXT,
            PRIMARY KEY("id")
            );
            """)
    query.exec_()

    query = QSqlQuery()
    query.prepare("""
        INSERT INTO test (
            id, text, parent_id)
        VALUES
            (?, ?, ?),
            (?, ?, ?),
            (?, ?, ?),
            (?, ?, ?);
        """)
    query.addBindValue("ID101")
    query.addBindValue("Text")
    query.addBindValue(None)

    query.addBindValue("ID102")
    query.addBindValue("Text")
    query.addBindValue("ID101")

    query.addBindValue("ID103")
    query.addBindValue("Text")
    query.addBindValue("ID101")

    query.addBindValue("ID104")
    query.addBindValue("Text")
    query.addBindValue(None)
    query.exec_()


if __name__ == "__main__":
    projectDb = QSqlDatabase.addDatabase("QSQLITE")
    projectDb.setDatabaseName(":memory:")
    projectDb.open()

    initTestDatabase()

    app = QApplication()
    mainWindow = QMainWindow()
    widget = CustomTreeWidget(mainWindow)
    widget.setDatabase()

    mainWindow.setCentralWidget(widget)
    mainWindow.showMaximized()

    app.exec_()
4

0 回答 0