summaryrefslogtreecommitdiff
path: root/src/mailman/database/types.py
blob: bbd040d961aaac672492e880e4e2dba296d0658f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# Copyright (C) 2007-2016 by the Free Software Foundation, Inc.
#
# This file is part of GNU Mailman.
#
# GNU Mailman is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# GNU Mailman is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# GNU Mailman.  If not, see <http://www.gnu.org/licenses/>.

"""Database type conversions."""

import uuid

from mailman import public
from sqlalchemy import Integer
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.types import CHAR, TypeDecorator, Unicode


@public
class Enum(TypeDecorator):
    """Persist an integer-valued Python ``enum`` member in an Integer column.

    On the way into the database a member is reduced to its ``.value``;
    on the way out the stored integer is turned back into a member of the
    enum class given to the constructor.  ``None`` is passed through
    untouched in both directions.
    """
    impl = Integer

    def __init__(self, enum, *args, **kw):
        super().__init__(*args, **kw)
        # The enum class used to rebuild members from stored values.
        self.enum = enum

    def process_bind_param(self, value, dialect):
        # NULL stays NULL; otherwise store the member's underlying value.
        return None if value is None else value.value

    def process_result_value(self, value, dialect):
        # NULL stays NULL; otherwise look the value up in the enum class.
        return None if value is None else self.enum(value)


@public
class UUID(TypeDecorator):
    """Platform-independent GUID type.

    On PostgreSQL the native ``UUID`` column type is used.  On every other
    dialect the value is stored in a ``CHAR(32)`` column as 32 zero-padded,
    lower-case hexadecimal digits (no hyphens).  Non-NULL values read back
    are always returned as ``uuid.UUID`` instances.
    """
    impl = CHAR

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(postgresql.UUID())
        return dialect.type_descriptor(CHAR(32))

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        if dialect.name == 'postgresql':
            # For a uuid.UUID, str() gives the canonical hyphenated form.
            return str(value)
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        # 32 zero-padded hex digits, matching the CHAR(32) column width.
        return '%.32x' % value.int

    def process_result_value(self, value, dialect):
        if value is None:
            return value
        # Some driver/SQLAlchemy combinations already hand back a uuid.UUID
        # (e.g. postgresql.UUID with as_uuid=True, the SQLAlchemy 2.0
        # default).  uuid.UUID() only accepts strings, so pass those
        # through unchanged instead of re-parsing them.
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)


@public
class SAUnicode(TypeDecorator):
    """Unicode datatype to support fixed length VARCHAR in MySQL.

    This type compiles to VARCHAR(255) in case of MySQL, and in case of
    other dialects defaults to the Unicode type.  This was created so
    that we don't have to alter the output of the default Unicode data
    type and it can still be used if needed in the codebase.
    """
    impl = Unicode


@compiles(SAUnicode)
def default_sa_unicode(element, compiler, **kw):
    # Non-MySQL dialects: render exactly like the plain Unicode type.  The
    # type compiler's hook for Unicode is the lower-case ``visit_unicode``;
    # there is no ``visit_Unicode`` method.
    return compiler.visit_unicode(element, **kw)


@compiles(SAUnicode, 'mysql')
def compile_sa_unicode(element, compiler, **kw):
    # We hardcode the collate here to make string comparison case sensitive.
    return 'VARCHAR(255) COLLATE utf8_bin'


@public
class SAUnicodeLarge(TypeDecorator):
    """Similar to SAUnicode type, but compiles to VARCHAR(510).

    This is double size of SAUnicode defined above.  On dialects other
    than MySQL it defaults to the Unicode type.
    """
    impl = Unicode


@compiles(SAUnicodeLarge, 'mysql')
def compile_sa_unicode_large(element, compiler, **kw):
    # We hardcode the collate here to make string comparison case sensitive.
    return 'VARCHAR(510) COLLATE utf8_bin'


@compiles(SAUnicodeLarge)
def default_sa_unicode_large(element, compiler, **kw):
    # Non-MySQL dialects: render exactly like the plain Unicode type.  This
    # must be registered on SAUnicodeLarge, not SAUnicode; registering it on
    # SAUnicode would replace that type's default hook above.
    return compiler.visit_unicode(element, **kw)