X-Git-Url: http://gitweb.michael.orlitzky.com/?a=blobdiff_plain;f=mjo%2Feja%2Feja_utils.py;h=8334f516094fcbea657637da037f825c1cf4cbd2;hb=dc92538d3fc92d16c9b6432ad17c37cb0d6b2be9;hp=6f8cab6d8019dcbba0be1e81c3872a7ba738f807;hpb=a46720db62543983ab375654dee211ca844ac46c;p=sage.d.git diff --git a/mjo/eja/eja_utils.py b/mjo/eja/eja_utils.py index 6f8cab6..8334f51 100644 --- a/mjo/eja/eja_utils.py +++ b/mjo/eja/eja_utils.py @@ -1,6 +1,4 @@ -from sage.functions.other import sqrt -from sage.matrix.constructor import matrix -from sage.modules.free_module_element import vector +from sage.structure.element import is_Matrix def _scale(x, alpha): r""" @@ -106,21 +104,21 @@ def _all2list(x): [3, 4, 1, 0, 0, 0, 0, 0, 0, 0] """ - if hasattr(x, 'list') and hasattr(x, 'to_vector'): - # This avoids calling to_vector() on a matrix algebra with - # e.g. quaternions where the returned vector is of the wrong - # length (three instead of four) because the quaternions don't - # know how many generators they have. - return _all2list(x.list()) - if hasattr(x, 'to_vector'): # This works on matrices of e.g. octonions directly, without # first needing to convert them to a list of octonions and # then recursing down into the list. It also avoids the wonky # list(x) when x is an element of a CFM. I don't know what it - # returns but it aint the coordinates. This will fall through - # to the iterable case the next time around. - return _all2list(x.to_vector()) + # returns but it aint the coordinates. We don't recurse + # because vectors can only contain ring elements as entries. + return x.to_vector().list() + + if is_Matrix(x): + # This sucks, but for performance reasons we don't want to + # call _all2list recursively on the contents of a matrix + # when we don't have to (they only contain ring elements + # as entries) + return x.list() try: xl = list(x) @@ -131,15 +129,8 @@ def _all2list(x): # Avoid the retardation of list(QQ(1)) == [1]. return [x] - return sum(list( map(_all2list, xl) ), []) - - - -def _mat2vec(m): - return vector(m.base_ring(), m.list()) + return sum( map(_all2list, xl) , []) -def _vec2mat(v): - return matrix(v.base_ring(), sqrt(v.degree()), v.list()) def gram_schmidt(v, inner_product=None): """ @@ -254,54 +245,40 @@ def gram_schmidt(v, inner_product=None): sage: len(gram_schmidt(v)) == 2 True """ + if len(v) == 0: + # cool + return v + + V = v[0].parent() + if inner_product is None: inner_product = lambda x,y: x.inner_product(y) + def norm(x): - ip = inner_product(x,x) # Don't expand the given field; the inner-product's codomain # is already correct. For example QQ(2).sqrt() returns sqrt(2) # in SR, and that will give you weird errors about symbolics # when what's really going wrong is that you're trying to # orthonormalize in QQ. - return ip.parent()(ip.sqrt()) - - v = list(v) # make a copy, don't clobber the input - - # Drop all zero vectors before we start. - v = [ v_i for v_i in v if not v_i.is_zero() ] - - if len(v) == 0: - # cool - return v - - R = v[0].base_ring() - - # Our "zero" needs to belong to the right space for sum() to work. - zero = v[0].parent().zero() + return V.base_ring()(inner_product(x,x).sqrt()) sc = lambda x,a: a*x - if hasattr(v[0], 'cartesian_factors'): + if hasattr(V, 'cartesian_factors'): # Only use the slow implementation if necessary. sc = _scale def proj(x,y): + # project y onto the span of {x} return sc(x, (inner_product(x,y)/inner_product(x,x))) - # First orthogonalize... - for i in range(1,len(v)): - # Earlier vectors can be made into zero so we have to ignore them. - v[i] -= sum( (proj(v[j],v[i]) - for j in range(i) - if not v[j].is_zero() ), - zero ) + def normalize(x): + return sc(x, ~norm(x)) - # And now drop all zero vectors again if they were "orthogonalized out." - v = [ v_i for v_i in v if not v_i.is_zero() ] + v_out = [] # make a copy, don't clobber the input - # Just normalize. If the algebra is missing the roots, we can't add - # them here because then our subalgebra would have a bigger field - # than the superalgebra. - for i in range(len(v)): - v[i] = sc(v[i], ~norm(v[i])) + for (i, v_i) in enumerate(v): + ortho_v_i = v_i - V.sum( proj(v_out[j],v_i) for j in range(i) ) + if not ortho_v_i.is_zero(): + v_out.append(normalize(ortho_v_i)) - return v + return v_out