Sorting as a convex optimization problem
Lately, I have been working on problems that require convex optimization, and in the course of this work I have been using cvxpy. The cvxpy homepage states:
CVXPY is a Python-embedded modeling language for convex optimization problems.
It is inspired by CVX, which in turn implements a language for disciplined convex programming (DCP).
Just for fun, I used cvxpy to write a program that sorts a list of numbers. It is based on the observation that for weights \(w \in \mathbb{R}^{n}\) and values \(x \in \mathbb{R}^{n}\), the weighted sum \[
w^{T} x = \sum_{i=1}^{n} w_{i}x_{i}
\] is largest when the larger values receive the larger weights (this is the rearrangement inequality). This means that for a permutation matrix \(P\) \[
w^{T} P x
\] is maximal when \(P\) reorders \(x\) so that the entries of \(P x\) appear in the same order as the weights. In particular, if we let \(w_{i} < w_{i+1}\) for \(i \in \{1,2,\ldots, n-1\}\), the permutation matrix \(P\) that sorts \(x\) in ascending order as \(P x\) is the one that maximizes \(w^{T} P x\). Searching among permutation matrices to find the right one is difficult, because there are \(n!\) of them. However, if instead of looking for a matrix with entries 0 and 1 we look for a matrix with entries in \([0,1]\), things become easier, particularly if we restrict the search to doubly stochastic matrices: nonnegative matrices whose rows and columns each sum to 1. Every such matrix is a convex combination of permutation matrices (Birkhoff’s Theorem). Since \(w^{T} P x\) is linear in \(P\), its maximum over the doubly stochastic matrices is attained at a permutation matrix, and consequently the permutation matrix \(P\) that sorts \(x\) as \(P x\) is a doubly stochastic matrix that maximizes \(w^{T} P x\). When the entries of \(x\) are distinct, it is the only maximizer.
This is the program I wrote:
import cvxpy as cy
import numpy as np
n = 6
xvals = np.round(np.random.uniform(-1, 1, n) * 10, 0)  # values to sort
wvals = np.arange(n)  # strictly increasing weights
# cvxpy definitions
x = cy.Constant(xvals)
w = cy.Constant(wvals)
ones = cy.Constant(np.ones(n))  # constant vector of all ones
P = cy.Variable((n, n), nonneg=True)  # matrix P
obj = cy.Minimize(-w.T @ P @ x)  # objective function
cons = [P @ ones == ones, ones.T @ P == ones]  # doubly stochastic P
p = cy.Problem(obj, cons)  # problem definition
# solve problem: doubly stochastic P becomes a permutation at optimum
p.solve(solver='ECOS', verbose=False)
print('x:', xvals)
print('w:', w.value)
print()
print('objective value :', -w.value.dot(P.value.dot(xvals)))
print('P . x :', list(np.round(P.value.dot(xvals), 5)))
print('sorted(x) :', sorted(xvals))  # ascending, because w is increasing
print()
print('Permutation matrix P (rounded to 7 digits):')
print(np.round(P.value, 7))
This gives the output:
x: [ 0.27256 0.65008 -0.31671 0.47928 -0.06169 0.08558]
w: [0. 1. 2. 3. 4. 5.]
objective value : -6.094669999566384
P . x : [-0.31671, -0.06169, 0.08558, 0.27256, 0.47928, 0.65008]
sorted(x) : [-0.31671, -0.06169, 0.08558, 0.27256, 0.47928, 0.65008]
Permutation matrix P (rounded to 7 digits):
[[0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1. 0.]
[0. 0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0. 0.]
[0. 0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0. 0.]]
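Since the solver works in floating point, it can be useful to confirm that the returned matrix really is a permutation matrix up to a tolerance, and to read off which entry of \(x\) each row selects. The sketch below does this in plain Python on the matrix printed above; the helper `as_permutation` and its tolerance are assumptions for illustration and not part of cvxpy. For a solver result, the input could be obtained with `P.value.tolist()`.

```python
def as_permutation(P, tol=1e-6):
    """Return perm with perm[i] = column index of the 1 in row i,
    or None if P is not within tol of a permutation matrix."""
    n = len(P)
    perm = []
    for row in P:
        if len(row) != n:
            return None
        ones = [j for j, v in enumerate(row) if abs(v - 1.0) <= tol]
        zeros = [j for j, v in enumerate(row) if abs(v) <= tol]
        # each row needs exactly one entry near 1 and all others near 0
        if len(ones) != 1 or len(zeros) != n - 1:
            return None
        perm.append(ones[0])
    # each column must be selected exactly once
    return perm if sorted(perm) == list(range(n)) else None


# Matrix and values copied from the output above.
P = [[0, 0, 1, 0, 0, 0],
     [0, 0, 0, 0, 1, 0],
     [0, 0, 0, 0, 0, 1],
     [1, 0, 0, 0, 0, 0],
     [0, 0, 0, 1, 0, 0],
     [0, 1, 0, 0, 0, 0]]
x = [0.27256, 0.65008, -0.31671, 0.47928, -0.06169, 0.08558]

perm = as_permutation(P)
print(perm)                   # column picked by each row
print([x[j] for j in perm])   # x reordered by P
```

A matrix such as `[[0.5, 0.5], [0.5, 0.5]]` is doubly stochastic but not a permutation, and the helper returns `None` for it. This is the kind of result to watch for when \(x\) contains repeated values, because the maximizer is then no longer unique.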
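While the program works, it is not the most efficient way of sorting: it solves a linear program in \(n^{2}\) variables, whereas standard comparison sorts need only \(O(n \log n)\) comparisons. More on that over at Wikipedia’s page on sorting algorithms.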
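For comparison, the same permutation matrix can be read off from an ordinary sort. This is a minimal sketch in plain Python with hypothetical helper names (`sorting_permutation_matrix`, `matvec`), using the values from the output above.

```python
def sorting_permutation_matrix(x):
    """Build the 0/1 matrix P with P x == sorted(x) from an ordinary sort."""
    n = len(x)
    # order[i] is the index in x of the i-th smallest value (an "argsort")
    order = sorted(range(n), key=lambda j: x[j])
    return [[1 if j == order[i] else 0 for j in range(n)] for i in range(n)]


def matvec(P, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(p * v for p, v in zip(row, x)) for row in P]


x = [0.27256, 0.65008, -0.31671, 0.47928, -0.06169, 0.08558]
P = sorting_permutation_matrix(x)
for row in P:
    print(row)
print(matvec(P, x))  # x in ascending order
```

With NumPy, `np.argsort` would play the role of the `sorted(range(n), ...)` line.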