How many sets of 4 numbers are there such that their xor is 0?
I have two non-negative integers x and y, both of which are at most 30 bits (so their values are around 10^9).
I would like to calculate how many quadruples of non-negative integers (a_1, a_2, a_3, a_4) exist such that a_1 + a_2 = x, a_3 + a_4 = y, and the xor of all four numbers is 0.
What is the fastest algorithm to solve this problem?
The fastest approach I can think of is to rearrange the xor equation to a_1 xor a_2 = a_3 xor a_4.
Then I can compute all the left-side values in O(x) and all the right-side values in O(y), so the whole algorithm runs in O(x + y).
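A minimal Python sketch of this O(x + y) idea, assuming the two sides are matched by tallying xor values in a `collections.Counter` and multiplying the counts of equal values. The names `count_xor_values` and `n_linear` are hypothetical and chosen only for this illustration.

```python
from collections import Counter


def count_xor_values(s):
    """Tally a ^ (s - a) over every split of s into two non-negative parts."""
    return Counter(a ^ (s - a) for a in range(s + 1))


def n_linear(x, y):
    """Count (a1, a2, a3, a4) >= 0 with a1 + a2 = x, a3 + a4 = y, a1 ^ a2 == a3 ^ a4."""
    if x < 0 or y < 0:
        return 0
    left = count_xor_values(x)   # O(x) work
    right = count_xor_values(y)  # O(y) work
    # Each left split pairs with every right split that has the same xor value;
    # a missing key in a Counter reads as 0.
    return sum(cnt * right[v] for v, cnt in left.items())


print(n_linear(2, 2))  # 5
```

The time is linear in x and y, but so is the memory for the tallies, which is why this does not scale comfortably to values near 2^30.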
Let N(x, y) be the number of solutions to this problem. Obviously N(0, 0) equals 1, since the only solution is (0, 0, 0, 0). And if either x or y is negative, there are no solutions, since we require a1, a2, a3, a4 to all be non-negative.
Otherwise, we can proceed by solving for the least significant bit and generating a recurrence relation. Let's write n:0 and n:1 to denote 2n + 0 and 2n + 1 (so that 0 and 1 are the least significant bits).
Then:
N(0, 0) = 1
N(x, y) = 0 if x < 0 or y < 0
N(x:0, y:0) = N(x, y) + N(x-1, y) + N(x, y-1) + N(x-1, y-1)
N(x:0, y:1) = N(x:1, y:0) = 0
N(x:1, y:1) = 4 * N(x, y)
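As a quick sanity check of the recurrence, here is N(2, 2) worked by hand, using 2 = 1:0 and 1 = 0:1:

```latex
\begin{aligned}
N(2, 2) = N(1{:}0,\ 1{:}0) &= N(1, 1) + N(0, 1) + N(1, 0) + N(0, 0) \\
                           &= 4 \cdot N(0, 0) + 0 + 0 + 1 \\
                           &= 5
\end{aligned}
```

This agrees with direct enumeration: a1 + a2 = 2 has the splits (0, 2), (1, 1), (2, 0) with xor values 2, 0, 2, and the same holds for a3 + a4, giving 2·2 + 1·1 = 5 matching combinations.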
To see why these hold, consider the possible least significant bits of a1, a2, a3, a4.
First, N(x:0, y:0). The least significant bit of a1 + a2 must be 0, which means a1 and a2 are either both even or both odd. If both are odd, there is a carry, so the higher bits of a1 and a2 must sum to x - 1 instead of x. The same logic applies to a3 and a4. In every such case the low bits of the four numbers xor to 0, so only the higher bits remain constrained. That gives 4 possibilities: the low bits of a1, a2, a3, a4 are all 0; only those of a1 and a2 are 1; only those of a3 and a4 are 1; or all four are 1. These are the 4 terms of the recurrence.
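To make the carry concrete, this small Python sketch takes even X = x:0 and Y = y:0 and lists the higher-bit sub-problem that each of the four low-bit choices reduces to. The helper name `even_even_subproblems` is hypothetical. The results come out as exactly (x, y), (x, y-1), (x-1, y) and (x-1, y-1).

```python
def even_even_subproblems(X, Y):
    """For even X and Y, return the (x', y') sub-problem arguments produced by
    the four valid low-bit choices for (a1, a2) and (a3, a4)."""
    assert X % 2 == 0 and Y % 2 == 0
    subproblems = []
    for b12 in (0, 1):      # low bits of a1 and a2 are both b12
        for b34 in (0, 1):  # low bits of a3 and a4 are both b34
            # Two low bits of 1 add 2 to the sum (a carry into the next bit),
            # so the higher bits must make up X - 2 instead of X; halving
            # gives the argument of the smaller sub-problem.
            x_high = (X - 2 * b12) // 2
            y_high = (Y - 2 * b34) // 2
            subproblems.append((x_high, y_high))
    return subproblems


print(even_even_subproblems(4, 6))  # [(2, 3), (2, 2), (1, 3), (1, 2)]
```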
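Second, N(x:0, y:1) and N(x:1, y:0). If one sum is even and the other is odd, there are no solutions: the low bit of a1 + a2 is the xor of the low bits of a1 and a2 (and likewise for a3 + a4), so the low bits of all four numbers xor to 0 xor 1 = 1, never 0. You can also check each combination of the least significant bits of a1, a2, a3, a4 to confirm this.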
Third, N(x:1, y:1). Exactly one of a1 and a2 must be odd, and likewise exactly one of a3 and a4. There are 4 possibilities for this, and none of them produces a carry, so the higher bits must still sum to x and y respectively.
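The three cases can also be checked mechanically. This sketch brute-forces all 16 low-bit patterns (b1, b2, b3, b4) and keeps those consistent with the parities of x and y and with a zero xor; the helper name `low_bit_cases` is hypothetical. It finds 4 patterns when both sums are even, 4 when both are odd, and none otherwise, matching the multipliers in the recurrence.

```python
from itertools import product


def low_bit_cases(x, y):
    """Return every choice of low bits (b1, b2, b3, b4) compatible with
    a1 + a2 = x, a3 + a4 = y and a1 ^ a2 ^ a3 ^ a4 == 0 on the lowest bit."""
    cases = []
    for b1, b2, b3, b4 in product((0, 1), repeat=4):
        # The low bit of a sum is the xor of the operands' low bits.
        if (b1 ^ b2) != x % 2 or (b3 ^ b4) != y % 2:
            continue
        # The xor of all four numbers must also be 0 on this bit.
        if b1 ^ b2 ^ b3 ^ b4:
            continue
        cases.append((b1, b2, b3, b4))
    return cases


for x, y in [(4, 6), (4, 7), (5, 6), (5, 7)]:
    print(x % 2, y % 2, len(low_bit_cases(x, y)))
# Expected output:
# 0 0 4
# 0 1 0
# 1 0 0
# 1 1 4
```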
Here's the complete solution:
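```python
def N(x, y):
    # Number of (a1, a2, a3, a4) >= 0 with a1 + a2 = x, a3 + a4 = y and xor 0.
    if x == y == 0: return 1
    if x < 0 or y < 0: return 0
    if x % 2 == y % 2 == 0:
        # Both sums even: each pair's low bits are (0, 0) or (1, 1); (1, 1) carries.
        return N(x//2, y//2) + N(x//2-1, y//2) + N(x//2, y//2-1) + N(x//2-1, y//2-1)
    elif x % 2 == y % 2 == 1:
        # Both sums odd: 4 low-bit patterns, none of which carries.
        return 4 * N(x//2, y//2)
    else:
        # Mixed parity: the low bits of the four numbers cannot xor to 0.
        return 0
```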
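The function makes up to four recursive calls per step, so in theory its running time is exponential in the number of bits. In practice, many branches end quickly (a negative argument or mismatched parities return 0 at once), so the code runs fast enough for values up to 2^30. To guarantee O(log(x) + log(y)) running time, you can add a cache or use a dynamic-programming table: at recursion depth k each argument is either floor(x / 2^k) or floor(x / 2^k) - 1 (and likewise for y), so there are at most four distinct states per level.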
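One way to add the cache is Python's `functools.lru_cache`. The sketch below is the same recurrence as `N`, renamed to the hypothetical `N_cached` so it can sit alongside the original; `cache_info().currsize` reports how many distinct (x, y) states were actually visited.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def N_cached(x, y):
    """Same recurrence as N, with each (x, y) result memoized."""
    if x == y == 0:
        return 1
    if x < 0 or y < 0:
        return 0
    if x % 2 == 0 and y % 2 == 0:
        # Both sums even: low-bit pairs are (0, 0) or (1, 1), the latter with a carry.
        hx, hy = x // 2, y // 2
        return (N_cached(hx, hy) + N_cached(hx - 1, hy)
                + N_cached(hx, hy - 1) + N_cached(hx - 1, hy - 1))
    if x % 2 == 1 and y % 2 == 1:
        # Both sums odd: 4 low-bit patterns, none of which carries.
        return 4 * N_cached(x // 2, y // 2)
    # Mixed parity: the low bits can never xor to 0.
    return 0


print(N_cached(2, 2))  # 5
print(N_cached(2**30 - 2, 2**30 - 2))
print(N_cached.cache_info().currsize)  # distinct (x, y) states visited so far
```

Because only a handful of states exist per bit, the cache stays small even for 30-bit inputs.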
Finally, to increase confidence in correctness, here are some tests against a naive O(xy) solution:
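```python
def N_slow(x, y):
    # Brute force: try every split of x and of y and check the xor directly.
    s = 0
    for a1 in range(x + 1):
        for a3 in range(y + 1):
            a2 = x - a1
            a4 = y - a3
            if a1 ^ a2 ^ a3 ^ a4:
                continue
            s += 1
    return s

# Compare the recurrence with the brute force on all small inputs.
for x in range(50):
    for y in range(50):
        n = N(x, y)
        ns = N_slow(x, y)
        if n != ns:
            print('N(%d, %d) = %d, want %d' % (x, y, n, ns))
```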