]> gitweb.michael.orlitzky.com - sage.d.git/commitdiff
Remove the __call__() code and refactor everything into subcalls of __getitem__().
authorMichael Orlitzky <michael@orlitzky.com>
Tue, 6 Nov 2012 03:42:52 +0000 (22:42 -0500)
committerMichael Orlitzky <michael@orlitzky.com>
Tue, 6 Nov 2012 03:42:52 +0000 (22:42 -0500)
mjo/symbol_sequence.py

index a6459304e2ddb7a660395473250b16fa7d9e2855..757c60d853ab1b36fae3e7b3df6e8bcd7a7512ba 100644 (file)
@@ -30,7 +30,8 @@ class SymbolSequence:
         sage: a[1]
         a1
 
-    Create coefficients for polynomials of arbitrary degree::
+    Create polynomials with symbolic coefficients of arbitrary
+    degree::
 
         sage: a = SymbolSequence('a')
         sage: p = sum([ a[i]*x^i for i in range(0,5)])
@@ -55,7 +56,7 @@ class SymbolSequence:
         sage: [ a[i,j] for i in range(0,2) for j in range(0,2) ]
         [a00, a01, a10, a11]
 
-    You can pass slice objects instead of integers to obtain a list of
+    You can pass slices instead of integers to obtain a list of
     symbols::
 
         sage: a = SymbolSequence('a')
@@ -85,7 +86,7 @@ class SymbolSequence:
 
         sage: a = SymbolSequence()
         sage: a0str = str(a[0])
-        sage: str(a(0)) == a0str
+        sage: str(a[0]) == a0str
         True
 
     Slices and single indices work when combined::
@@ -101,13 +102,9 @@ class SymbolSequence:
     def __init__(self, name=None, latex_name=None, domain=None):
         # We store a dict of already-created symbols so that we don't
         # recreate a symbol which already exists. This is especially
-        # helpful when using unnamed variables, if you want e.g. a(0)
+        # helpful when using unnamed variables, if you want e.g. a[0]
         # to return the same variable each time.
-        #
-        # The entry corresponding to None is the un-subscripted symbol
-        # with our name.
-        unsubscripted = SR.symbol(name, latex_name, domain)
-        self._symbols = { None: unsubscripted }
+        self._symbols = {}
 
         self._name = name
         self._latex_name = latex_name
@@ -115,16 +112,19 @@ class SymbolSequence:
 
 
     def _create_symbol_(self, subscript):
-        if self._name is None:
-            # Allow creating unnamed symbols, for consistency with
-            # SR.symbol().
-            name = None
-        else:
+        """
+        Return a symbol with the given subscript. Creates the
+        appropriate name and latex_name before delegating to
+        SR.symbol().
+        """
+        # Allow creating unnamed symbols, for consistency with
+        # SR.symbol().
+        name = None
+        if self._name is not None:
             name = '%s%d' % (self._name, subscript)
 
-        if self._latex_name is None:
-            latex_name = None
-        else:
+        latex_name = None
+        if self._latex_name is not None:
             latex_name = r'%s_{%d}' % (self._latex_name, subscript)
 
         return SR.symbol(name, latex_name, self._domain)
@@ -133,7 +133,9 @@ class SymbolSequence:
     def _flatten_list_(self, l):
         """
         Recursively flatten the given list, allowing for some elements
-        to be non-iterable.
+        to be non-iterable. This is slow, but also works, which is
+        more than can be said about some of the snappier solutions of
+        lore.
         """
         result = []
 
@@ -147,67 +149,93 @@ class SymbolSequence:
 
 
     def __getitem__(self, key):
+        """
+        This handles individual integer arguments, slices, and
+        tuples. It just hands off the real work to
+        self._subscript_foo_().
+        """
         if isinstance(key, tuple):
-            return self(*key)
+            return self._subscript_tuple_(key)
 
         if isinstance(key, slice):
-            # We were given a slice. Clean up some of its properties
-            # first. The start/step are default for lists. We make
-            # copies of these because they're read-only.
-            (start, step) = (key.start, key.step)
-            if start is None:
-                start = 0
-            if key.stop is None:
-                # Would otherwise loop forever since our "length" is
-                # undefined.
-                raise ValueError('You must supply an terminal index')
-            if step is None:
-               step = 1
+            return self._subscript_slice_(key)
 
-            # If the user asks for a slice, we'll be returning a list
-            # of symbols.
-            return [ self(idx) for idx in range(start, key.stop, step) ]
+        # This is the most common case so it would make sense to have
+        # this test first. But there are too many different "integer"
+        # classes that you have to check for.
+        return self._subscript_integer_(key)
 
-        return self(key)
 
+    def _subscript_integer_(self, n):
+        """
+        The subscript is a single integer, or something that acts like
+        one.
+        """
+        if n < 0:
+            # Cowardly refuse to create a variable named "a-1".
+            raise IndexError('Indices must be nonnegative')
 
-    def __call__(self, *args):
-        args = list(args)
+        try:
+            return self._symbols[n]
+        except KeyError:
+            self._symbols[n] = self._create_symbol_(n)
+            return self._symbols[n]
 
-        if len(args) == 0:
-            return self._symbols[None]
 
-        # This is safe after the len == 0 test.
+    def _subscript_slice_(self, s):
+        """
+        We were given a slice. Clean up some of its properties
+        first. The start/step are default for lists. We make
+        copies of these because they're read-only.
+        """
+        (start, step) = (s.start, s.step)
+        if start is None:
+            start = 0
+        if s.stop is None:
+            # Would otherwise loop forever since our "length" is
+            # undefined.
+            raise ValueError('You must supply an terminal index')
+        if step is None:
+            step = 1
+
+        # If the user asks for a slice, we'll be returning a list
+        # of symbols.
+        return [ self._subscript_integer_(idx)
+                 for idx in range(start, s.stop, step) ]
+
+
+
+    def _subscript_tuple_(self, args):
+        """
+        When we have more than one level of subscripts, we pick off
+        the first one and generate the rest recursively.
+        """
+
+        # We never call this method without an argument.
         key = args[0]
         args = args[1:] # Peel off the first arg, which we've called 'key'
 
+        # We don't know the type of 'key', but __getitem__ will figure
+        # it out and dispatch properly.
+        v = self[key]
+        if len(args) == 0:
+            # There was only one element left in the tuple.
+            return v
+
+        # At this point, we know we were given at least a two-tuple.
+        # The symbols corresponding to the first entry are already
+        # computed, in 'v'. Here we recursively compute the symbols
+        # corresponding to the second coordinate, with the first
+        # coordinate(s) fixed.
         if isinstance(key, slice):
-            if len(args) == 0:
-                return self[key]
-            else:
-                v = self[key]
-                ss = [ SymbolSequence(w._repr_(), w._latex_(), self._domain)
-                       for w in v ]
-
-                # This might be nested...
-                maybe_nested_list = [ s(*args) for s in ss ]
-                return self._flatten_list_(maybe_nested_list)
+            ss = [ SymbolSequence(w._repr_(), w._latex_(), self._domain)
+                   for w in v ]
 
-        if key < 0:
-            # Cowardly refuse to create a variable named "a-1".
-            raise IndexError('Indices must be nonnegative')
+            # This might be nested...
+            maybe_nested_list = [ s._subscript_tuple_(args) for s in ss ]
+            return self._flatten_list_(maybe_nested_list)
 
-        if len(args) == 0:
-            # Base case, create a symbol and return it.
-            try:
-                return self._symbols[key]
-            except KeyError:
-                self._symbols[key] = self._create_symbol_(key)
-                return self._symbols[key]
         else:
-            # If we're given more than one index, we want to create the
-            # subsequences recursively. For example, if we're asked for
-            # x(1,2), this is really SymbolSequence('x1')(2).
-            v = self(key) # x(1) -> x1
+            # If it's not a slice, it's an integer.
             ss = SymbolSequence(v._repr_(), v._latex_(), self._domain)
-            return ss(*args)
+            return ss._subscript_tuple_(args)