Algebraic Data Types 2018-06-04

By Max Woerner Chase

I'm going to try to speed through setting up intervals and such a little, because I want to try scraping together a prototype for writing MIDI files. Long term, I'd like to try doing my own synthesis because that seems cool, but it's easier to synthesize a note when you've decided just how a note is represented.

Ugh, thinking about how to do this, I feel like I'm getting bitten by the fact that I never did get my ADT code into a workable state. I think I'll switch gears a little, to getting that working. I've come to the conclusion that I want to try using dataclasses-style class decorators, because attempting to use metaclasses was really unpleasant. I'd like to paste some old code to help make sense of all the stuff I'm saying, but the old code I have doesn't really make sense. Instead, I'll sketch out how I'd like to use the new thing.

from typing import Generic, TypeVar

from wherever import algebraic

R = TypeVar('R')
E = TypeVar('E')

@algebraic
class Result(Generic[R, E]):

    Ok: R
    Err: E

The decorator should do most of the same work as the dataclass decorator, but with some differences:

Let's see what methods I want to generate:

Let's get started on this.

Okay, I've got code that resembles a stripped-down version of dataclasses, but because the layout is more constrained, there's less generated code. I think for now, I'm going to avoid flexibility until I discover I need it, so all algebraic classes will be comparable, orderable, and hashable, with no way to opt out. Consumers will be able to tweak the repr logic, because that's not a big deal.

Actually, I really should put the customization in, but I want to consider carefully how this is going to work. If the algebraic code defines equality, then it "might as well" define hashing. If the algebraic code does not define equality, then it should not define hashing. I'm currently using equality in some pattern-matching code, but I could just use the functions directly... Except actually, it should be fine. What I'm going to do differently is, I'm going to say that the methods get added in bundles.
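The reason equality and hashing travel together is a Python rule: a class that defines __eq__ without also defining __hash__ gets its __hash__ set to None, which makes its instances unhashable. Here is a minimal sketch of that rule. The class names EqOnly and EqAndHash are hypothetical and are not part of the code in this post.

```python
class EqOnly:
    """Defines equality but not hashing."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if other.__class__ is self.__class__:
            return self.value == other.value
        return NotImplemented


class EqAndHash(EqOnly):
    """Adds a hash that agrees with the inherited equality."""

    def __hash__(self):
        return hash(self.value)


# Python set __hash__ to None on EqOnly, because it defines __eq__ alone.
eq_only_is_unhashable = EqOnly.__hash__ is None

# Equal instances hash equally, so they collapse to one set element.
distinct = len({EqAndHash(1), EqAndHash(1), EqAndHash(2)})
```

That is why a bundle that installs __eq__ should install __hash__ along with it, and why skipping the equality bundle means leaving hashing alone too.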

Except I just realized that my implementation of ordering is very wrong. When I get this into a repository, I'm going to strip it out.

Regardless, aside from that hiccup, I'm sure the code I wrote over the course of two days is perfect in every way.

... Okay, so I need to use the descriptor protocol to get at the name of wrapped methods. Fine.
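Here is a sketch of why going through the descriptor protocol helps. Wrapper objects such as staticmethod and classmethod do not necessarily expose __name__ themselves (older Python versions don't), but calling __get__ on them yields the underlying function or a bound method, which does. The Example class and the three functions are hypothetical, made up only for illustration.

```python
def name_of(cls, function) -> str:
    """Return the name of a function, unwrapping it via __get__."""
    return function.__get__(None, cls).__name__


class Example:
    """Hypothetical class, used only as the owner passed to __get__."""


def plain(self):
    return self


@staticmethod
def static(cls):
    return cls


@classmethod
def klass(cls):
    return cls


# The same call works for a plain function and for both wrappers.
found = [name_of(Example, f) for f in (plain, static, klass)]
```

Plain functions are descriptors too, so the same helper covers all three cases without any special-casing.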

... Okay, it coughs up bizarre errors if I attempt to actually instantiate generic classes. It looks like this is a result of some kind of optimization attempt that bypasses the normal MRO. The problem appears to be that I'm creating subclasses of a generic class, that add a built-in class as a superclass. I believe I can hack around this by looking for an attribute, and changing it if it exists.

I was incorrect. The problem was that I wasn't explicitly delegating __new__ calls through a base class, which I think confused things more than it should have? Anyway, next I removed the __init_subclass__ stuff, because I forgot that specializing a generic class in Python subclasses it. ... Actually, I can keep that in, I just need to restrict the scope.
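Here is a sketch of what restricting the scope could look like. Instead of rejecting every subclass, __init_subclass__ rejects only classes derived from one of the generated constructors, and lets everything else through. The names Base, Ok, Specialized, and sealed are hypothetical stand-ins, not the code in this post.

```python
sealed = set()  # classes that must not be subclassed any further


class Base:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only reject classes that derive from a sealed class.
        if any(issubclass(cls, parent) for parent in sealed):
            raise TypeError(f'cannot subclass a constructor: {cls.__name__}')


class Ok(Base):
    """Stands in for a generated constructor."""


sealed.add(Ok)


class Specialized(Base):
    """Stands in for an unrelated subclass of Base; this is allowed."""


try:
    class Bad(Ok):
        pass
except TypeError:
    rejected = True
else:
    rejected = False
```

Ok itself is created before it is added to sealed, so the check never blocks the constructors being generated.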

All right, here's where things ended up. My main issue right now is, I made this in part to work well with typing information, and mypy reacts to my carefully crafted annotations like I threw garbage in its face. I can't actually use this in my projects until I work out how to alter mypy's interpretation of it.

"""Class decorator for defining algebraic data types."""

import collections
import keyword
import typing


def _as_tuple(value) -> tuple:
    """Convert non-tuple values to a 1-tuple, return tuples unchanged."""
    if not isinstance(value, tuple):
        value = (value,)
    return value


def _name(cls, function) -> str:
    """Return the name of a function accessed through a descriptor."""
    return function.__get__(None, cls).__name__


def _set_new_functions(cls, *functions) -> typing.Optional[str]:
    """Attempt to set the attributes corresponding to the functions on cls.

    If any attributes are already defined, fail before setting, and return the
    already-defined name.
    """
    for function in functions:
        if _name(cls, function) in cls.__dict__:
            return _name(cls, function)
    for function in functions:
        setattr(cls, _name(cls, function), function)
    return None


class MatchFailure(BaseException):
    """An exception that signals a failure in ADT matching."""


def desugar(constructor: type, instance: tuple) -> tuple:
    """Return the inside of an ADT instance, given its constructor."""
    # I really do want to match exactly, I swear.
    if type(instance) is not constructor:
        raise MatchFailure
    return tuple.__getitem__(instance, slice(None))


def _unpack(instance: tuple) -> tuple:
    """Return the inside of any ADT instance.

    This function is not meant for general use.
    """
    return desugar(type(instance), instance)


def _algebraic_base(obj):
    return getattr(obj.__class__, '__algebraic_base__', None)


class AlgebraicConstructor:

    """Base class for ADT Constructor classes."""

    __slots__ = ()

    def __new__(cls, *args, **kwargs):
        """Explicitly forward to base class."""
        return super().__new__(cls, *args, **kwargs)


def __repr__(self):
    return self.__class__.__qualname__ + (
        f'({", ".join(repr(item) for item in _unpack(self))})')


def __eq__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) == _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        return False
    return NotImplemented


def __ne__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) != _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        return True
    return NotImplemented


def __lt__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) < _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        order = _algebraic_base(self).__subclass_order__
        self_index = order.index(self.__class__)
        other_index = order.index(other.__class__)
        return self_index < other_index
    return NotImplemented


def __le__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) <= _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        order = _algebraic_base(self).__subclass_order__
        self_index = order.index(self.__class__)
        other_index = order.index(other.__class__)
        return self_index <= other_index
    return NotImplemented


def __gt__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) > _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        order = _algebraic_base(self).__subclass_order__
        self_index = order.index(self.__class__)
        other_index = order.index(other.__class__)
        return self_index > other_index
    return NotImplemented


def __ge__(self, other):
    if other.__class__ is self.__class__:
        return _unpack(self) >= _unpack(other)
    if _algebraic_base(other) is _algebraic_base(self):
        order = _algebraic_base(self).__subclass_order__
        self_index = order.index(self.__class__)
        other_index = order.index(other.__class__)
        return self_index >= other_index
    return NotImplemented


def __hash__(self):
    return hash(_unpack(self))


def _make_constructor(_cls, name, length, subclasses, subclass_order):
    class Constructor(_cls, AlgebraicConstructor, tuple):
        """Auto-generated subclass of an ADT."""
        __slots__ = ()

        __algebraic_base__ = _cls

        def __new__(cls, *args):
            if len(args) != length:
                raise ValueError
            return super().__new__(cls, args)

    Constructor.__name__ = name
    Constructor.__qualname__ = f'{_cls.__qualname__}.{name}'

    subclasses.add(Constructor)
    setattr(_cls, name, Constructor)
    subclass_order.append(Constructor)


def _process_class(_cls, _repr, eq, order):
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    lengths = {}
    subclasses = set()
    subclass_order = []
    for cls in reversed(_cls.__mro__):
        for key, value in getattr(cls, '__annotations__', {}).items():
            lengths[key] = len(_as_tuple(value))

    for name, length in lengths.items():
        _make_constructor(_cls, name, length, subclasses, subclass_order)

    @classmethod
    def __init_subclass__(cls, **kwargs):
        if issubclass(cls, tuple(subclasses)):
            raise TypeError
        # Allow it to go through otherwise, because Generic.
        return super(_cls, cls).__init_subclass__(**kwargs)

    _cls.__init_subclass__ = __init_subclass__

    @staticmethod
    def __new__(cls, args):
        if cls not in subclasses:
            raise TypeError
        return super(_cls, cls).__new__(cls, args)

    if _set_new_functions(_cls, __new__):
        base__new__ = _cls.__new__

        @staticmethod
        def __new__(cls, args):
            if cls not in subclasses:
                raise TypeError
            return base__new__(cls, args)

        _cls.__new__ = __new__

    if _repr:
        _set_new_functions(_cls, __repr__)

    equality_methods_were_set = False

    if eq:
        equality_methods_were_set = not _set_new_functions(
            _cls, __eq__, __ne__)

    if equality_methods_were_set:
        _cls.__hash__ = __hash__

    if order:
        if not equality_methods_were_set:
            raise ValueError(
                "Can't add ordering methods if equality methods are provided.")
        collision = _set_new_functions(_cls, __lt__, __le__, __gt__, __ge__)
        if collision:
            raise TypeError(f'Cannot overwrite attribute {collision} '
                            f'in class {_cls.__name__}. Consider using '
                            'functools.total_ordering')

    _cls.__subclass_order__ = tuple(subclass_order)

    return _cls


def algebraic(_cls=None, *, repr=True, eq=True, order=False):
    """Decorate a class to be an algebraic data type."""

    def wrap(cls):
        """Return the processed class."""
        return _process_class(cls, repr, eq, order)

    if _cls is None:
        return wrap

    return wrap(_cls)


DISCARD = object()


class Matcher(tuple):
    """A matcher that binds a value to a name."""

    __slots__ = ()

    def __new__(cls, name: str):
        if name == '_':
            return DISCARD
        if not name.isidentifier():
            raise ValueError
        if keyword.iskeyword(name):
            raise ValueError
        return super().__new__(cls, (name,))

    @property
    def name(self):
        """Return the name of the matcher."""
        return self[0]

    def __matmul__(self, other):
        return AsMatcher(self, other)


class AsMatcher(tuple):
    """A matcher that contains further bindings."""

    __slots__ = ()

    def __new__(cls, matcher: Matcher, match):
        if matcher is DISCARD:
            return match
        return super().__new__(cls, (matcher, match))

    @property
    def matcher(self):
        """Return the left-hand-side of the as-match."""
        return self[0]

    @property
    def match(self):
        """Return the right-hand-side of the as-match."""
        return self[1]


def names(target):
    """Return every name bound by a target."""
    name_list = []
    names_seen = set()
    to_process = [target]
    while to_process:
        item = to_process.pop()
        if isinstance(item, Matcher):
            if item.name in names_seen:
                raise ValueError
            names_seen.add(item.name)
            name_list.append(item.name)
        elif isinstance(item, AsMatcher):
            to_process.append(item.match)
            to_process.append(item.matcher)
        elif isinstance(item, AlgebraicConstructor):
            to_process.extend(reversed(_unpack(item)))
        elif isinstance(item, tuple):
            to_process.extend(reversed(item))
    yield from name_list


def _match(target, value):
    match_dict = collections.OrderedDict()
    to_process = [(target, value)]
    while to_process:
        target, value = to_process.pop()
        if target is DISCARD:
            pass
        elif isinstance(target, Matcher):
            if target.name in match_dict:
                raise ValueError
            match_dict[target.name] = value
        elif isinstance(target, AsMatcher):
            to_process.append((target.match, value))
            to_process.append((target.matcher, value))
        elif isinstance(target, AlgebraicConstructor):
            to_process.extend(zip(reversed(_unpack(target)),
                                  reversed(desugar(type(target), value))))
        elif (isinstance(target, tuple) and
              target.__class__ is value.__class__ and
              len(target) == len(value)):
            to_process.extend(zip(reversed(target), reversed(value)))
        elif target != value:
            raise MatchFailure
    return match_dict


def get_values(dct, keys):
    """Unpack a dict, in order."""
    for key in keys:
        yield dct[key]


class ValueMatcher:
    """Given a value, attempt to match against a target."""

    def __init__(self, value):
        self.value = value
        self.matches = None

    def match(self, target):
        """Match against target, generating a set of bindings."""
        try:
            self.matches = _match(target, self.value)
        except MatchFailure:
            self.matches = None
        return self.matches is not None


if __name__ == '__main__':

    R = typing.TypeVar('R')
    E = typing.TypeVar('E')

    @algebraic
    class Result(typing.Generic[R, E]):
        """Experimental generic ADT."""

        Ok: R
        Err: E

    my_result: Result[int, str] = Result.Ok(10)

Next week, or possibly sooner, I try to figure out how to write a mypy plugin.