1 | # Created by Noah Kantrowitz on 2007-04-02. |
---|
2 | # Copyright (c) 2007 Noah Kantrowitz. All rights reserved. |
---|
3 | |
---|
4 | from trac.core import * |
---|
5 | from trac.env import IEnvironmentSetupParticipant |
---|
6 | from trac.perm import IPermissionRequestor, IPermissionGroupProvider, PermissionSystem |
---|
7 | from trac.config import ListOption |
---|
8 | |
---|
9 | try: |
---|
10 | set = set |
---|
11 | except NameError: |
---|
12 | from sets import Set as set |
---|
13 | |
---|
14 | import db_default |
---|
15 | |
---|
16 | |
---|
class HideValsSystem(Component):
    """Database provider for the TracHideVals plugin.

    Owns the plugin's tables (installed/upgraded through
    IEnvironmentSetupParticipant), declares the TICKET_HIDEVALS permission
    and looks up the `hidevals` rows that apply to a request's user.
    """

    group_providers = ExtensionPoint(IPermissionGroupProvider)

    dont_filter = ListOption('hidevals', 'dont_filter',
        doc='Ticket fields to ignore when filtering.')

    implements(IPermissionRequestor, IEnvironmentSetupParticipant)

    # Public methods
    def visible_fields(self, req, db=None):
        """Collect the `hidevals` rows that apply to `req.authname`.

        Rows are fetched for the user name itself and for every group the
        user belongs to (as resolved by `_get_groups`).

        :param req: request object; only `req.authname` is read.
        :param db: optional database connection; when omitted one is
                   obtained from `self.env.get_db_cnx()`.
        :return: dict mapping each `field` to the list of `value`s found for
                 it across all of the user's subjects.
        """
        db = db or self.env.get_db_cnx()
        cursor = db.cursor()

        fields = {}
        for group in self._get_groups(req.authname):
            cursor.execute(
                'SELECT field, value FROM hidevals WHERE sid = %s', (group,))
            for field, value in cursor:
                fields.setdefault(field, []).append(value)
        return fields

    # IPermissionRequestor methods
    def get_permission_actions(self):
        """Declare the single permission action provided by this plugin."""
        yield 'TICKET_HIDEVALS'

    # IEnvironmentSetupParticipant methods
    def environment_created(self):
        """Install the plugin's schema into a newly created environment.

        A new environment has no version row yet, so the upgrade path runs
        as a full install from version 0. The work is committed here:
        Trac's environment creation does not commit on behalf of setup
        participants (only the upgrade path does), so without this the
        rows would be rolled back when the connection is released.
        """
        self.found_db_version = 0
        db = self.env.get_db_cnx()
        self.upgrade_environment(db)
        db.commit()

    def environment_needs_upgrade(self, db):
        """Read the stored schema version and report whether it is stale.

        Side effect: sets `self.found_db_version` (0 when no row named
        `db_default.name` exists in the `system` table), which
        `upgrade_environment` relies on.

        :return: True when the plugin is not installed or its stored version
                 is older than `db_default.version`.
        """
        cursor = db.cursor()
        cursor.execute("SELECT value FROM system WHERE name=%s",
                       (db_default.name,))
        row = cursor.fetchone()
        if not row:
            self.found_db_version = 0
            return True
        self.found_db_version = int(row[0])
        return self.found_db_version < db_default.version

    def upgrade_environment(self, db):
        """Install or upgrade the plugin's schema to `db_default.version`.

        Records the new version in the `system` table, then rebuilds every
        table listed in `db_default.tables`: existing rows are read into
        memory, the table is dropped and re-created from its definition, and
        the saved rows are re-inserted.

        Expects `self.found_db_version` to be set beforehand (by
        `environment_needs_upgrade` or `environment_created`); it decides
        whether the version row is inserted or updated. Does not commit.
        """
        # 0.10 compatibility hack (thanks Alec): when DatabaseManager is
        # importable, its connector renders table definitions to SQL;
        # otherwise the connection object itself is expected to provide
        # to_sql().
        try:
            from trac.db import DatabaseManager
            db_manager, _ = DatabaseManager(self.env)._get_connector()
        except ImportError:
            db_manager = db

        cursor = db.cursor()
        self._store_db_version(cursor)
        old_data = self._backup_and_drop_tables(cursor)
        self._create_tables(cursor, db_manager, old_data)

    # Private methods
    def _store_db_version(self, cursor):
        """Insert or update this plugin's version row in `system`."""
        if not self.found_db_version:
            cursor.execute("INSERT INTO system (name, value) VALUES (%s, %s)",
                           (db_default.name, db_default.version))
        else:
            cursor.execute("UPDATE system SET value=%s WHERE name=%s",
                           (db_default.version, db_default.name))

    def _backup_and_drop_tables(self, cursor):
        """Snapshot and drop every table in `db_default.tables` that exists.

        :return: {table_name: (column_names, rows)} for each table that
                 could be read before it was dropped.
        :raises: any database error other than an OperationalError.
        """
        old_data = {}
        for tbl in db_default.tables:
            try:
                cursor.execute('SELECT * FROM %s' % tbl.name)
                old_data[tbl.name] = (
                    [d[0] for d in cursor.description], cursor.fetchall())
                cursor.execute('DROP TABLE %s' % tbl.name)
            except Exception:
                # An OperationalError is tolerated (most likely the table
                # does not exist yet); move on to the next table. Anything
                # else propagates with its original traceback.
                if not self._current_error_is_operational():
                    raise
        return old_data

    def _create_tables(self, cursor, db_manager, old_data):
        """Create each table in `db_default.tables` and restore saved rows.

        :param db_manager: object whose `to_sql(table)` yields the DDL
                           statements for a table definition.
        :param old_data: snapshot as returned by `_backup_and_drop_tables`.
        :raises: any database error other than an OperationalError.
        """
        for tbl in db_default.tables:
            for sql in db_manager.to_sql(tbl):
                cursor.execute(sql)

            # Try to reinsert any old data
            if tbl.name not in old_data:
                continue
            columns, rows = old_data[tbl.name]
            sql = 'INSERT INTO %s (%s) VALUES (%s)' % (
                tbl.name, ','.join(columns), ','.join(['%s'] * len(columns)))
            for row in rows:
                try:
                    cursor.execute(sql, row)
                except Exception:
                    # A row rejected with an OperationalError is skipped
                    # (presumably it no longer fits the new schema); other
                    # errors propagate with their original traceback.
                    if not self._current_error_is_operational():
                        raise

    def _current_error_is_operational(self):
        """Tell whether the exception being handled is an OperationalError.

        Must be called from inside an `except` clause. The check is by class
        name, presumably so any DB-API driver's OperationalError matches
        without importing that driver.
        """
        import sys
        return 'OperationalError' in sys.exc_info()[1].__class__.__name__

    def _get_groups(self, user):
        """Return `user` plus every group they belong to, transitively.

        Starts from the user name and the groups reported by all
        IPermissionGroupProvider extensions, then repeatedly follows
        permission entries whose action is lower-case (treated as a group
        name) until no new subject is added.
        """
        # Get initial subjects
        groups = set([user])
        for provider in self.group_providers:
            for group in provider.get_permission_groups(user):
                groups.add(group)

        perms = PermissionSystem(self.env).get_all_permissions()
        # Iterate to a fixed point: a newly added group may itself be a
        # member of further groups.
        repeat = True
        while repeat:
            repeat = False
            for subject, action in perms:
                if subject in groups and action.islower() \
                        and action not in groups:
                    groups.add(action)
                    repeat = True
        return groups
---|