source: tracmigrateplugin/0.12/tracmigrate/admin.py @ 15443

Last change on this file since 15443 was 15443, checked in by Jun Omae, 8 years ago

TracMigratePlugin: fix not migrating tables when installed egg files in plugins directory (closes #12696)

File size: 12.4 KB
Line 
1# -*- coding: utf-8 -*-
2
3import os
4import re
5import shutil
6import sys
7import time
8from ConfigParser import RawConfigParser
9from subprocess import PIPE, Popen
10from tempfile import mkdtemp
11
12from trac.core import Component, implements, TracError
13from trac.admin.api import IAdminCommandProvider, get_dir_list
14from trac.db import sqlite_backend
15from trac.db.api import DatabaseManager, get_column_names, _parse_db_str
16from trac.env import Environment
17from trac.util.compat import any, close_fds
18from trac.util.text import printerr, printout
19
20
21def get_connection(env):
22    return DatabaseManager(env).get_connection()
23
24
25class MigrateEnvironment(Environment):
26
27    abstract = True  # avoid showing in plugins admin page
28    required = False
29
30    def is_component_enabled(self, cls):
31        name = self._component_name(cls)
32        if not any(name.startswith(mod) for mod in ('trac.', 'tracopt.')):
33            return False
34        return Environment.is_component_enabled(self, cls)
35
36
37class TracMigrationCommand(Component):
38
39    implements(IAdminCommandProvider)
40
41    _help = """\
42    Migrate to another database
43
44    This command migrates to another database in new Trac Environment or this
45    Trac Environment in-place. The new Trac Environment is specified in the
46    <tracenv>. If -i/--in-place option is specified, in-place migration.
47    Another database is specified in the <dburi> and must be empty."""
48
49    def get_admin_commands(self):
50        yield ('migrate', '<tracenv|-i|--in-place> <dburi>',
51               self._help, self._complete_migrate, self._do_migrate)
52
53    def _do_migrate(self, env_path, dburi):
54        if env_path in ('-i', '--in-place'):
55            return self._do_migrate_inplace(dburi)
56        else:
57            return self._do_migrate_to_env(env_path, dburi)
58
59    def _do_migrate_to_env(self, env_path, dburi):
60        try:
61            os.rmdir(env_path)  # remove directory if it's empty
62        except OSError:
63            pass
64        if os.path.exists(env_path) or os.path.lexists(env_path):
65            self._printerr('Cannot create Trac environment: %s: File exists',
66                           env_path)
67            return 1
68
69        dst_env = self._create_env(env_path, dburi)
70        src_dburi = self.config.get('trac', 'database')
71        src_db = get_connection(self.env)
72        dst_db = get_connection(dst_env)
73        self._copy_tables(src_db, dst_db, src_dburi, dburi)
74        self._copy_directories(src_db, dst_env)
75
76    def _do_migrate_inplace(self, dburi):
77        src_dburi = self.config.get('trac', 'database')
78        if src_dburi == dburi:
79            self._printerr('Source database and destination database are '
80                           'same: %s', dburi)
81            return 1
82
83        env_path = mkdtemp(prefix='migrate-',
84                           dir=os.path.dirname(self.env.path))
85        try:
86            dst_env = self._create_env(env_path, dburi)
87            src_db = get_connection(self.env)
88            dst_db = get_connection(dst_env)
89            self._copy_tables(src_db, dst_db, src_dburi, dburi, inplace=True)
90            del src_db
91            del dst_db
92            dst_env.shutdown()
93            dst_env = None
94            if dburi.startswith('sqlite:'):
95                schema, params = _parse_db_str(dburi)
96                dbpath = os.path.join(self.env.path, params['path'])
97                dbdir = os.path.dirname(dbpath)
98                if not os.path.isdir(dbdir):
99                    os.makedirs(dbdir)
100                shutil.copy(os.path.join(env_path, params['path']), dbpath)
101        finally:
102            shutil.rmtree(env_path)
103
104        self._backup_tracini(self.env)
105        self.config.set('trac', 'database', dburi)
106        self.config.save()
107
108    def _backup_tracini(self, env):
109        dir = env.path
110        src = env.config.filename
111        basename = os.path.basename
112        dst = src + '.migrate-%d' % int(time.time())
113        shutil.copyfile(src, dst)
114        self._printout('Back up conf/%s to conf/%s in %s.', basename(src),
115                       basename(dst), dir)
116
117    def _create_env(self, env_path, dburi):
118        parser = RawConfigParser()
119        parser.read([os.path.join(self.env.path, 'conf', 'trac.ini')])
120        options = dict(((section, name), value)
121                       for section in parser.sections()
122                       for name, value in parser.items(section))
123        options[('trac', 'database')] = dburi
124        options = sorted((section, name, value) for (section, name), value
125                                                in options.iteritems())
126
127        # create an environment without plugins
128        env = MigrateEnvironment(env_path, create=True, options=options)
129        env.shutdown()
130        # copy plugins directory
131        os.rmdir(os.path.join(env_path, 'plugins'))
132        shutil.copytree(os.path.join(self.env.path, 'plugins'),
133                        os.path.join(env_path, 'plugins'))
134        # create tables for plugins to upgrade in out-process
135        # (if Python is 2.5+, it can use "-m trac.admin.console" simply)
136        tracadmin = """\
137import sys; \
138from pkg_resources import load_entry_point; \
139sys.exit(load_entry_point('Trac', 'console_scripts', 'trac-admin')())"""
140        proc = Popen((sys.executable, '-c', tracadmin, env_path, 'upgrade'),
141                     stdin=PIPE, stdout=PIPE, stderr=PIPE, close_fds=close_fds)
142        stdout, stderr = proc.communicate(input='')
143        for f in (proc.stdin, proc.stdout, proc.stderr):
144            f.close()
145        if proc.returncode != 0:
146            raise TracError("upgrade command failed (stdout %r, stderr %r)" %
147                            (stdout, stderr))
148        return Environment(env_path)
149
150    def _copy_tables(self, src_db, dst_db, src_dburi, dburi, inplace=False):
151        self._printout('Copying tables:')
152
153        if src_dburi.startswith('sqlite:'):
154            src_db.cnx._eager = False  # avoid uses of eagar cursor
155        src_cursor = src_db.cursor()
156        if src_dburi.startswith('sqlite:'):
157            if type(src_cursor.cursor) is not sqlite_backend.PyFormatCursor:
158                raise AssertionError('src_cursor.cursor is %r' %
159                                     src_cursor.cursor)
160        src_tables = set(self._get_tables(src_dburi, src_cursor))
161        cursor = dst_db.cursor()
162        tables = set(self._get_tables(dburi, cursor)) & src_tables
163        sequences = set(self._get_sequences(dburi, cursor, tables))
164        progress = self._isatty()
165        replace_cast = self._get_replace_cast(src_db, dst_db, src_dburi, dburi)
166
167        # speed-up copying data with SQLite database
168        if dburi.startswith('sqlite:'):
169            cursor.execute('PRAGMA synchronous = OFF')
170            multirows_insert = sqlite_backend.sqlite_version >= (3, 7, 11)
171            max_paramters = 999
172        else:
173            multirows_insert = True
174            max_paramters = None
175
176        def copy_table(db, cursor, table):
177            src_cursor.execute('SELECT * FROM ' + src_db.quote(table))
178            columns = get_column_names(src_cursor)
179            n_rows = 100
180            if multirows_insert and max_paramters:
181                n_rows = min(n_rows, int(max_paramters // len(columns)))
182            quoted_table = db.quote(table)
183            holders = '(%s)' % ','.join(['%s'] * len(columns))
184            count = 0
185
186            cursor.execute('DELETE FROM ' + quoted_table)
187            while True:
188                rows = src_cursor.fetchmany(n_rows)
189                if not rows:
190                    break
191                count += len(rows)
192                if progress:
193                    self._printout('%d records\r  %s table... ',
194                                   count, table, newline=False)
195                if replace_cast is not None and table == 'report':
196                    rows = self._replace_report_query(rows, columns,
197                                                      replace_cast)
198                query = 'INSERT INTO %s (%s) VALUES ' % \
199                        (quoted_table, ','.join(map(db.quote, columns)))
200                if multirows_insert:
201                    cursor.execute(query + ','.join([holders] * len(rows)),
202                                   sum(rows, ()))
203                else:
204                    cursor.executemany(query + holders, rows)
205
206            return count
207
208        try:
209            cursor = dst_db.cursor()
210            for table in sorted(tables):
211                self._printout(%s table... ', table, newline=False)
212                count = copy_table(dst_db, cursor, table)
213                self._printout('%d records.', count)
214            for table in tables & sequences:
215                dst_db.update_sequence(cursor, table)
216            dst_db.commit()
217        except:
218            dst_db.rollback()
219            raise
220
221    def _get_replace_cast(self, src_db, dst_db, src_dburi, dst_dburi):
222        if src_dburi.split(':', 1) == dst_dburi.split(':', 1):
223            return None
224
225        type_re = re.compile(r' AS ([^)]+)')
226        def cast_type(db, type):
227            match = type_re.search(db.cast('name', type))
228            return match.group(1)
229
230        type_maps = dict(filter(lambda (src, dst): src != dst.lower(),
231                                ((cast_type(src_db, t).lower(),
232                                  cast_type(dst_db, t))
233                                 for t in ('text', 'int', 'int64'))))
234        if not type_maps:
235            return None
236
237        cast_re = re.compile(r'\bCAST\(\s*([^\s)]+)\s+AS\s+(%s)\s*\)' %
238                             '|'.join(type_maps), re.IGNORECASE)
239        def replace(match):
240            name, type = match.groups()
241            return 'CAST(%s AS %s)' % (name, type_maps.get(type.lower(), type))
242        def replace_cast(text):
243            return cast_re.sub(replace, text)
244        return replace_cast
245
246    def _copy_directories(self, src_db, env):
247        self._printout('Copying directories:')
248        directories = self._get_directories(src_db)
249        for name in directories:
250            self._printout(%s directory... ', name, newline=False)
251            src = os.path.join(self.env.path, name)
252            dst = os.path.join(env.path, name)
253            if os.path.isdir(dst):
254                shutil.rmtree(dst)
255            if os.path.isdir(src):
256                shutil.copytree(src, dst)
257            self._printout('done.')
258
259    def _replace_report_query(self, rows, columns, replace_cast):
260        idx = columns.index('query')
261        def replace(row):
262            row = list(row)
263            row[idx] = replace_cast(row[idx])
264            return tuple(row)
265        return [replace(row) for row in rows]
266
267    def _complete_migrate(self, args):
268        if len(args) == 1:
269            if args[0].startswith('-'):
270                return ('-i', '--in-place')
271            else:
272                return get_dir_list(args[0])
273
274    def _get_tables(self, dburi, cursor):
275        if dburi.startswith('sqlite:'):
276            query = "SELECT name FROM sqlite_master" \
277                    " WHERE type='table' AND NOT name='sqlite_sequence'"
278        elif dburi.startswith('postgres:'):
279            query = "SELECT tablename FROM pg_tables" \
280                    " WHERE schemaname = ANY (current_schemas(false))"
281        elif dburi.startswith('mysql:'):
282            query = "SHOW TABLES"
283        else:
284            raise TracError('Unsupported %s database' % dburi.split(':')[0])
285        cursor.execute(query)
286        return sorted([row[0] for row in cursor])
287
288    def _get_sequences(self, dburi, cursor, tables):
289        if dburi.startswith('postgres:'):
290            tables = set(tables)
291            cursor.execute("""\
292                SELECT c.relname
293                FROM pg_class c
294                INNER JOIN pg_namespace n ON c.relnamespace = n.oid
295                WHERE n.nspname = ANY (current_schemas(false))
296                AND c.relkind='S' AND c.relname LIKE %s ESCAPE '!'
297                """, ('%!_id!_seq',))
298            seqs = [name[:-len('_id_seq')] for name, in cursor]
299            return sorted(name for name in seqs if name in tables)
300        return []
301
302    def _get_directories(self, db):
303        version = self.env.get_version()
304        path = ('attachments', 'files')[version >= 28]
305        return (path, 'htdocs', 'templates', 'plugins')
306
307    def _printout(self, message, *args, **kwargs):
308        if args:
309            message %= args
310        printout(message, **kwargs)
311        sys.stdout.flush()
312
313    def _printerr(self, message, *args, **kwargs):
314        if args:
315            message %= args
316        printerr(message, **kwargs)
317        sys.stderr.flush()
318
319    def _isatty(self):
320        return sys.stdout.isatty() and sys.stderr.isatty()
Note: See TracBrowser for help on using the repository browser.