1 | # -*- coding: utf-8 -*- |
---|
2 | |
---|
3 | import os |
---|
4 | import re |
---|
5 | import shutil |
---|
6 | import sys |
---|
7 | import time |
---|
8 | from ConfigParser import RawConfigParser |
---|
9 | from subprocess import PIPE, Popen |
---|
10 | from tempfile import mkdtemp |
---|
11 | |
---|
12 | from trac.core import Component, implements, TracError |
---|
13 | from trac.admin.api import IAdminCommandProvider, get_dir_list |
---|
14 | from trac.db import sqlite_backend |
---|
15 | from trac.db.api import DatabaseManager, get_column_names, _parse_db_str |
---|
16 | from trac.env import Environment |
---|
17 | from trac.util.compat import any, close_fds |
---|
18 | from trac.util.text import printerr, printout |
---|
19 | |
---|
20 | |
---|
def get_connection(env):
    """Return a database connection for `env`.

    The connection is obtained from the environment's ``DatabaseManager``.
    """
    manager = DatabaseManager(env)
    return manager.get_connection()
---|
23 | |
---|
24 | |
---|
class MigrateEnvironment(Environment):
    """Environment that only enables components from ``trac.``/``tracopt.``.

    Any component whose name starts with neither prefix is reported as
    disabled. Every other component is decided by the regular
    ``Environment.is_component_enabled`` logic.
    """

    abstract = True  # keep this class off the plugins admin page
    required = False

    def is_component_enabled(self, cls):
        """Return False for non-Trac components, else defer to the base class."""
        component_name = self._component_name(cls)
        for prefix in ('trac.', 'tracopt.'):
            if component_name.startswith(prefix):
                return Environment.is_component_enabled(self, cls)
        return False
---|
35 | |
---|
36 | |
---|
class TracMigrationCommand(Component):
    """Provide the ``trac-admin migrate`` command.

    The command copies this environment's tables into another, empty
    database. The target is either a newly created environment or this
    environment itself (in-place).
    """

    implements(IAdminCommandProvider)

    _help = """\
Migrate to another database

This command migrates to another database in new Trac Environment or this
Trac Environment in-place. The new Trac Environment is specified in the
<tracenv>. If -i/--in-place option is specified, in-place migration.
Another database is specified in the <dburi> and must be empty."""

    def get_admin_commands(self):
        """Yield the ``migrate`` command specification."""
        yield ('migrate', '<tracenv|-i|--in-place> <dburi>',
               self._help, self._complete_migrate, self._do_migrate)

    def _do_migrate(self, env_path, dburi):
        """Dispatch to in-place or new-environment migration.

        :param env_path: path of the new environment, or ``-i`` /
                         ``--in-place`` to migrate this environment.
        :param dburi: connection string of the destination database.
        :return: ``1`` on a detected usage error, otherwise ``None``.
        """
        if env_path in ('-i', '--in-place'):
            return self._do_migrate_inplace(dburi)
        return self._do_migrate_to_env(env_path, dburi)

    def _do_migrate_to_env(self, env_path, dburi):
        """Create an environment at `env_path` using `dburi` and copy this
        environment's tables and directories into it.

        :return: ``1`` if `env_path` exists and is not an empty directory,
                 otherwise ``None``.
        """
        try:
            os.rmdir(env_path)  # remove directory if it's empty
        except OSError:
            pass
        # lexists() is true for existing paths and for dangling symlinks
        if os.path.lexists(env_path):
            self._printerr('Cannot create Trac environment: %s: File exists',
                           env_path)
            return 1

        dst_env = self._create_env(env_path, dburi)
        src_dburi = self.config.get('trac', 'database')
        src_db = get_connection(self.env)
        dst_db = get_connection(dst_env)
        self._copy_tables(src_db, dst_db, src_dburi, dburi)
        self._copy_directories(src_db, dst_env)
        # Release the connections before shutting the new environment down,
        # in the same order as _do_migrate_inplace.
        del src_db
        del dst_db
        dst_env.shutdown()

    def _do_migrate_inplace(self, dburi):
        """Migrate this environment's database to `dburi` in-place.

        Tables are copied through a temporary environment created next to
        this one and removed afterwards. On success, ``trac.ini`` is backed
        up and its ``[trac] database`` option is set to `dburi`.

        :return: ``1`` if `dburi` equals the current database, otherwise
                 ``None``.
        """
        src_dburi = self.config.get('trac', 'database')
        if src_dburi == dburi:
            self._printerr('Source database and destination database are '
                           'same: %s', dburi)
            return 1

        env_path = mkdtemp(prefix='migrate-',
                           dir=os.path.dirname(self.env.path))
        try:
            dst_env = self._create_env(env_path, dburi)
            src_db = get_connection(self.env)
            dst_db = get_connection(dst_env)
            self._copy_tables(src_db, dst_db, src_dburi, dburi, inplace=True)
            del src_db
            del dst_db
            dst_env.shutdown()
            dst_env = None
            if dburi.startswith('sqlite:'):
                schema, params = _parse_db_str(dburi)
                dbfile = params['path']
                # A relative SQLite path was created inside the temporary
                # environment and must be copied into this one. An absolute
                # path already names the final database file; copying it
                # onto itself would make shutil.copy() raise.
                if not os.path.isabs(dbfile):
                    dbpath = os.path.join(self.env.path, dbfile)
                    dbdir = os.path.dirname(dbpath)
                    if not os.path.isdir(dbdir):
                        os.makedirs(dbdir)
                    shutil.copy(os.path.join(env_path, dbfile), dbpath)
        finally:
            shutil.rmtree(env_path)

        self._backup_tracini(self.env)
        self.config.set('trac', 'database', dburi)
        self.config.save()

    def _backup_tracini(self, env):
        """Copy `env`'s config file to ``<file>.migrate-<unix time>``."""
        env_dir = env.path
        src = env.config.filename
        dst = src + '.migrate-%d' % int(time.time())
        shutil.copyfile(src, dst)
        self._printout('Back up conf/%s to conf/%s in %s.',
                       os.path.basename(src), os.path.basename(dst), env_dir)

    def _create_env(self, env_path, dburi):
        """Create a new environment at `env_path` that uses `dburi`.

        The new configuration holds every option from this environment's
        ``conf/trac.ini``, with ``[trac] database`` replaced by `dburi`. The
        environment is first created with Trac's own components only. This
        environment's ``plugins`` directory, if any, is then copied in, and
        ``trac-admin <env_path> upgrade`` runs in a subprocess.

        :raises TracError: if the ``upgrade`` subprocess exits non-zero.
        :return: an ``Environment`` opened on `env_path`.
        """
        parser = RawConfigParser()
        parser.read([os.path.join(self.env.path, 'conf', 'trac.ini')])
        options = dict(((section, name), value)
                       for section in parser.sections()
                       for name, value in parser.items(section))
        options[('trac', 'database')] = dburi
        options = sorted((section, name, value) for (section, name), value
                         in options.iteritems())

        # create an environment without plugins
        env = MigrateEnvironment(env_path, create=True, options=options)
        env.shutdown()
        # Replace the new (empty) plugins directory with this environment's
        # one. As in _copy_directories, a missing source is not an error.
        src_plugins = os.path.join(self.env.path, 'plugins')
        dst_plugins = os.path.join(env_path, 'plugins')
        if os.path.isdir(src_plugins):
            os.rmdir(dst_plugins)
            shutil.copytree(src_plugins, dst_plugins)
        # Run "upgrade" in a separate process so plugin tables get created
        # (if Python is 2.5+, it can use "-m trac.admin.console" simply).
        tracadmin = """\
import sys; \
from pkg_resources import load_entry_point; \
sys.exit(load_entry_point('Trac', 'console_scripts', 'trac-admin')())"""
        proc = Popen((sys.executable, '-c', tracadmin, env_path, 'upgrade'),
                     stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=close_fds)
        stdout, stderr = proc.communicate(input='')
        for f in (proc.stdin, proc.stdout, proc.stderr):
            f.close()
        if proc.returncode != 0:
            raise TracError("upgrade command failed (stdout %r, stderr %r)" %
                            (stdout, stderr))
        return Environment(env_path)

    def _copy_tables(self, src_db, dst_db, src_dburi, dburi, inplace=False):
        """Copy the rows of every table present in both databases.

        Each destination table is emptied, then filled in batches. Sequences
        reported by ``_get_sequences`` (PostgreSQL only) are then updated.
        Everything is committed once, and rolled back on any error. When the
        backends differ, ``CAST`` expressions in the ``report`` table's
        ``query`` column are rewritten for the destination.

        `inplace` is accepted for interface compatibility and is not used.
        """
        self._printout('Copying tables:')

        if src_dburi.startswith('sqlite:'):
            src_db.cnx._eager = False  # avoid uses of eager cursor
        src_cursor = src_db.cursor()
        if src_dburi.startswith('sqlite:'):
            if type(src_cursor.cursor) is not sqlite_backend.PyFormatCursor:
                raise AssertionError('src_cursor.cursor is %r' %
                                     src_cursor.cursor)
        src_tables = set(self._get_tables(src_dburi, src_cursor))
        cursor = dst_db.cursor()
        tables = set(self._get_tables(dburi, cursor)) & src_tables
        sequences = set(self._get_sequences(dburi, cursor, tables))
        progress = self._isatty()
        replace_cast = self._get_replace_cast(src_db, dst_db, src_dburi, dburi)

        # speed-up copying data with SQLite database
        if dburi.startswith('sqlite:'):
            cursor.execute('PRAGMA synchronous = OFF')
            # Multi-row VALUES needs SQLite 3.7.11+. 999 is SQLite's
            # historical default limit on bound parameters per statement.
            multirows_insert = sqlite_backend.sqlite_version >= (3, 7, 11)
            max_parameters = 999
        else:
            multirows_insert = True
            max_parameters = None

        def copy_table(db, cursor, table):
            """Replace the rows of `table` in `db` with the source rows and
            return the number of rows copied."""
            src_cursor.execute('SELECT * FROM ' + src_db.quote(table))
            columns = get_column_names(src_cursor)
            n_rows = 100
            if multirows_insert and max_parameters:
                # Keep rows * columns within the parameter limit, but always
                # fetch at least one row per batch.
                n_rows = max(1, min(n_rows, max_parameters // len(columns)))
            quoted_table = db.quote(table)
            holders = '(%s)' % ','.join(['%s'] * len(columns))
            # the INSERT prefix is the same for every batch of this table
            query = 'INSERT INTO %s (%s) VALUES ' % \
                    (quoted_table, ','.join(map(db.quote, columns)))
            count = 0

            cursor.execute('DELETE FROM ' + quoted_table)
            while True:
                rows = src_cursor.fetchmany(n_rows)
                if not rows:
                    break
                count += len(rows)
                if progress:
                    self._printout('%d records\r %s table... ',
                                   count, table, newline=False)
                if replace_cast is not None and table == 'report':
                    rows = self._replace_report_query(rows, columns,
                                                      replace_cast)
                if multirows_insert:
                    cursor.execute(query + ','.join([holders] * len(rows)),
                                   sum(rows, ()))
                else:
                    cursor.executemany(query + holders, rows)

            return count

        try:
            cursor = dst_db.cursor()
            for table in sorted(tables):
                self._printout(' %s table... ', table, newline=False)
                count = copy_table(dst_db, cursor, table)
                self._printout('%d records.', count)
            for table in tables & sequences:
                dst_db.update_sequence(cursor, table)
            dst_db.commit()
        except:
            # roll back on any failure (including interrupts), then re-raise
            dst_db.rollback()
            raise

    def _get_replace_cast(self, src_db, dst_db, src_dburi, dst_dburi):
        """Return a function rewriting ``CAST(x AS type)`` for the target.

        The returned function maps the source backend's cast type names to
        the destination's. It is ``None`` when both URIs share a scheme or
        no type name differs.
        """
        # Only the scheme selects the SQL dialect; the rest of the URI
        # (host, path, ...) has no influence on cast type names.
        if src_dburi.split(':', 1)[0] == dst_dburi.split(':', 1)[0]:
            return None

        type_re = re.compile(r' AS ([^)]+)')

        def cast_type(db, type_name):
            match = type_re.search(db.cast('name', type_name))
            return match.group(1)

        # lower-cased source type name -> destination type name
        type_maps = {}
        for generic in ('text', 'int', 'int64'):
            src_type = cast_type(src_db, generic).lower()
            dst_type = cast_type(dst_db, generic)
            if src_type != dst_type.lower():
                type_maps[src_type] = dst_type
        if not type_maps:
            return None

        cast_re = re.compile(r'\bCAST\(\s*([^\s)]+)\s+AS\s+(%s)\s*\)' %
                             '|'.join(map(re.escape, type_maps)),
                             re.IGNORECASE)

        def replace(match):
            name, type_name = match.groups()
            return 'CAST(%s AS %s)' % (name, type_maps.get(type_name.lower(),
                                                            type_name))

        def replace_cast(text):
            return cast_re.sub(replace, text)
        return replace_cast

    def _copy_directories(self, src_db, env):
        """Copy each directory from ``_get_directories`` into `env`.

        An existing destination directory is removed first. Missing source
        directories are skipped.
        """
        self._printout('Copying directories:')
        directories = self._get_directories(src_db)
        for name in directories:
            self._printout(' %s directory... ', name, newline=False)
            src = os.path.join(self.env.path, name)
            dst = os.path.join(env.path, name)
            if os.path.isdir(dst):
                shutil.rmtree(dst)
            if os.path.isdir(src):
                shutil.copytree(src, dst)
            self._printout('done.')

    def _replace_report_query(self, rows, columns, replace_cast):
        """Return `rows` as tuples with `replace_cast` applied to ``query``.

        NULL (``None``) query values are passed through unchanged.
        """
        idx = columns.index('query')

        def replace(row):
            row = list(row)
            if row[idx] is not None:
                row[idx] = replace_cast(row[idx])
            return tuple(row)
        return [replace(row) for row in rows]

    def _complete_migrate(self, args):
        """Complete the first argument: the in-place flags or a directory."""
        if len(args) == 1:
            if args[0].startswith('-'):
                return ('-i', '--in-place')
            else:
                return get_dir_list(args[0])

    def _get_tables(self, dburi, cursor):
        """Return the sorted table names of the database `dburi`.

        :raises TracError: for schemes other than sqlite, postgres, mysql.
        """
        if dburi.startswith('sqlite:'):
            query = "SELECT name FROM sqlite_master" \
                    " WHERE type='table' AND NOT name='sqlite_sequence'"
        elif dburi.startswith('postgres:'):
            query = "SELECT tablename FROM pg_tables" \
                    " WHERE schemaname = ANY (current_schemas(false))"
        elif dburi.startswith('mysql:'):
            query = "SHOW TABLES"
        else:
            raise TracError('Unsupported %s database' % dburi.split(':')[0])
        cursor.execute(query)
        return sorted([row[0] for row in cursor])

    def _get_sequences(self, dburi, cursor, tables):
        """Return the sorted names of `tables` that have a sequence.

        Only PostgreSQL is queried, for sequences named ``<table>_id_seq``
        in the current schemas. Other backends yield an empty list.
        """
        if dburi.startswith('postgres:'):
            tables = set(tables)
            cursor.execute("""\
SELECT c.relname
FROM pg_class c
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname = ANY (current_schemas(false))
AND c.relkind='S' AND c.relname LIKE %s ESCAPE '!'
""", ('%!_id!_seq',))
            seqs = [name[:-len('_id_seq')] for name, in cursor]
            return sorted(name for name in seqs if name in tables)
        return []

    def _get_directories(self, db):
        """Return the environment-relative directories to copy.

        The first entry is ``'files'`` when the environment's database
        version is >= 28, otherwise ``'attachments'``. `db` is unused.
        """
        version = self.env.get_version()
        path = ('attachments', 'files')[version >= 28]
        return (path, 'htdocs', 'templates', 'plugins')

    def _printout(self, message, *args, **kwargs):
        """%-format `message` with `args`, print it to stdout and flush."""
        if args:
            message %= args
        printout(message, **kwargs)
        sys.stdout.flush()

    def _printerr(self, message, *args, **kwargs):
        """%-format `message` with `args`, print it to stderr and flush."""
        if args:
            message %= args
        printerr(message, **kwargs)
        sys.stderr.flush()

    def _isatty(self):
        """Return True only when both stdout and stderr are terminals.

        Used to decide whether per-batch progress is printed.
        """
        return sys.stdout.isatty() and sys.stderr.isatty()
---|