]> gitweb.michael.orlitzky.com - dunshire.git/blobdiff - dunshire/matrices.py
Figure out the base field automatically in condition_number().
[dunshire.git] / dunshire / matrices.py
index bcf83778d62752436f7894b5a6fbad4cc3e1012e..123085c6c02f69574dc614fcff49d46207b7981c 100644 (file)
@@ -3,6 +3,7 @@ Utility functions for working with CVXOPT matrices (instances of the
 class:`cvxopt.base.matrix` class).
 """
 
+from copy import copy
 from math import sqrt
 from cvxopt import matrix
 from cvxopt.lapack import gees, gesdd, syevr
@@ -142,7 +143,7 @@ def eigenvalues(symmat):
     eigs = matrix(0, (domain_dim, 1), tc='d')
 
     # Create a copy of ``symmat`` here because ``syevr`` clobbers it.
-    dummy = matrix(symmat, symmat.size)
+    dummy = copy(symmat)
     syevr(dummy, eigs)
     return list(eigs)
 
@@ -329,12 +330,12 @@ def norm(matrix_or_vector):
     --------
 
         >>> v = matrix([1,1])
-        >>> print('{:.5f}'.format(norm(v)))
-        1.41421
+        >>> norm(v)
+        1.414...
 
         >>> A = matrix([1,1,1,1], (2,2))
         >>> norm(A)
-        2.0
+        2.0...
 
     """
     return sqrt(inner_product(matrix_or_vector, matrix_or_vector))
@@ -421,25 +422,25 @@ def condition_number(mat):
     Examples
     --------
 
-    >>> condition_number(identity(1, typecode='d'))
-    1.0
-    >>> condition_number(identity(2, typecode='d'))
-    1.0
-    >>> condition_number(identity(3, typecode='d'))
+    >>> condition_number(identity(3))
     1.0
 
-    >>> A = matrix([[2,1],[1,2]], tc='d')
+    >>> A = matrix([[2,1],[1,2]])
     >>> abs(condition_number(A) - 3.0) < options.ABS_TOL
     True
 
-    >>> A = matrix([[2,1j],[-1j,2]], tc='z')
+    >>> A = matrix([[2,1j],[-1j,2]])
     >>> abs(condition_number(A) - 3.0) < options.ABS_TOL
     True
 
     """
     num_eigs = min(mat.size)
     eigs = matrix(0, (num_eigs, 1), tc='d')
-    gesdd(mat, eigs)
+    typecode = 'd'
+    if any([isinstance(entry, complex) for entry in mat]):
+        typecode = 'z'
+    dummy = matrix(mat, mat.size, tc=typecode)
+    gesdd(dummy, eigs)
 
     if len(eigs) > 0:
         return eigs[0]/eigs[-1]