source: hidevalsplugin/0.11/hidevals/api.py @ 16751

Last change on this file since 16751 was 16751, checked in by Ryan J Ollos, 7 years ago

TracHideVals 1.2: Run autopep8 on codebase

File size: 4.5 KB
Line 
1# Created by Noah Kantrowitz on 2007-04-02.
2# Copyright (c) 2007 Noah Kantrowitz. All rights reserved.
3
4from trac.core import *
5from trac.env import IEnvironmentSetupParticipant
6from trac.perm import IPermissionRequestor, IPermissionGroupProvider, PermissionSystem
7from trac.config import ListOption
8
9try:
10    set = set
11except NameError:
12    from sets import Set as set
13
14import db_default
15
16
17class HideValsSystem(Component):
18    """Database provider for the TracHideVals plugin."""
19
20    group_providers = ExtensionPoint(IPermissionGroupProvider)
21
22    dont_filter = ListOption('hidevals', 'dont_filter',
23                             doc='Ticket fields to ignore when filtering.')
24
25    implements(IPermissionRequestor, IEnvironmentSetupParticipant)
26
27    # Public methods
28    def visible_fields(self, req, db=None):
29        db = db or self.env.get_db_cnx()
30        cursor = db.cursor()
31
32        groups = self._get_groups(req.authname)
33        fields = {}
34        for group in groups:
35            cursor.execute(
36                'SELECT field, value FROM hidevals WHERE sid = %s', (group,))
37            for f, v in cursor:
38                fields.setdefault(f, []).append(v)
39
40        return fields
41
42    # IPermissionRequestor methods
43    def get_permission_actions(self):
44        yield 'TICKET_HIDEVALS'
45
46    # IEnvironmentSetupParticipant methods
47    def environment_created(self):
48        self.found_db_version = 0
49        self.upgrade_environment(self.env.get_db_cnx())
50
51    def environment_needs_upgrade(self, db):
52        cursor = db.cursor()
53        cursor.execute("SELECT value FROM system WHERE name=%s",
54                       (db_default.name,))
55        value = cursor.fetchone()
56        if not value:
57            self.found_db_version = 0
58            return True
59        else:
60            self.found_db_version = int(value[0])
61            #self.log.debug('HideValsSystem: Found db version %s, current is %s' % (self.found_db_version, db_default.version))
62            return self.found_db_version < db_default.version
63
64    def upgrade_environment(self, db):
65        # 0.10 compatibility hack (thanks Alec)
66        try:
67            from trac.db import DatabaseManager
68            db_manager, _ = DatabaseManager(self.env)._get_connector()
69        except ImportError:
70            db_manager = db
71
72        # Insert the default table
73        old_data = {}  # {table_name: (col_names, [row, ...]), ...}
74        cursor = db.cursor()
75        if not self.found_db_version:
76            cursor.execute("INSERT INTO system (name, value) VALUES (%s, %s)",
77                           (db_default.name, db_default.version))
78        else:
79            cursor.execute("UPDATE system SET value=%s WHERE name=%s",
80                           (db_default.version, db_default.name))
81            for tbl in db_default.tables:
82                try:
83                    cursor.execute('SELECT * FROM %s' % tbl.name)
84                    old_data[tbl.name] = (
85                        [d[0] for d in cursor.description], cursor.fetchall())
86                    cursor.execute('DROP TABLE %s' % tbl.name)
87                except Exception, e:
88                    if 'OperationalError' not in e.__class__.__name__:
89                        raise e  # If it is an OperationalError, just move on to the next table
90
91        for tbl in db_default.tables:
92            for sql in db_manager.to_sql(tbl):
93                cursor.execute(sql)
94
95            # Try to reinsert any old data
96            if tbl.name in old_data:
97                data = old_data[tbl.name]
98                sql = 'INSERT INTO %s (%s) VALUES (%s)' % \
99                    (tbl.name, ','.join(data[0]),
100                     ','.join(['%s'] * len(data[0])))
101                for row in data[1]:
102                    try:
103                        cursor.execute(sql, row)
104                    except Exception, e:
105                        if 'OperationalError' not in e.__class__.__name__:
106                            raise e
107
108    # Private methods
109    def _get_groups(self, user):
110        # Get initial subjects
111        groups = set([user])
112        for provider in self.group_providers:
113            for group in provider.get_permission_groups(user):
114                groups.add(group)
115
116        perms = PermissionSystem(self.env).get_all_permissions()
117        repeat = True
118        while repeat:
119            repeat = False
120            for subject, action in perms:
121                if subject in groups and action.islower() and action not in groups:
122                    groups.add(action)
123                    repeat = True
124
125        return groups
Note: See TracBrowser for help on using the repository browser.