How to make the program faster [Keypad_Sticky_Note]
Sticky note attached to the keypad
The minions have some of their secrets safely locked away. Or so they think. In fact, they are so confident that they even have a password hint attached to the lock's keypad.
Opening the lock requires a pair of non-negative integers (a, b) to be entered on the keypad. Since the integers can be as large as 2 billion, you turn to the sticky note for help.
The sticky note contains two numbers, but even the minions know better than to write the password down directly. Instead, they wrote down the sum (denoted s) and the bitwise exclusive or (xor, denoted x) of the password integers (a, b). That way, they only need to remember one of the two integers: they can recover the other by subtraction or, if they have trouble with subtraction, with the bitwise exclusive or.
That is, we have s = a + b and x = a ^ b (where ^ is the bitwise XOR operation).
With your automated hacking equipment, each guess takes a few milliseconds. Since you have only a little time, you want to know how long it will take to try all the combinations. Thanks to the note, you can eliminate some combinations without even typing them on the keypad, and you can work out how long it will take to crack the lock in the worst case.
Write a function called answer(s, x) that finds the number of pairs (a, b) that have the target sum and xor.
For example, if s = 10 and x = 4, then the possible values for (a, b) are (3, 7) and (7, 3), so answer would return 2.
If s = 5 and x = 3, then there are no possible values, so answer would return 0.
s and x are at least 0 and at most 2 billion.
Languages
To provide a Python solution, edit the solution.py file. To provide a Java solution, edit the solution.java file.
Test cases
Inputs: (int) s = 10 (int) x = 4 Output: (int) 2
Inputs: (int) s = 0 (int) x = 0 Output: (int) 1
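import java.util.ArrayList;
import java.util.List;

public static int answer(int s, int x) {
    List<Integer> num = new ArrayList<>();
    int a;
    int b;
    int sum;
    int finalans;
    // Try every (i, e) with 0 <= i, e <= s
    for (int i = 0; i <= s; i++) {
        for (int e = 0; e <= s; e++) {
            sum = i + e;
            if (sum == s) {
                if ((i ^ e) == x) {
                    // Record each distinct value that appears in a matching pair
                    if (!num.contains(i)) {
                        num.add(i);
                    }
                    if (!num.contains(e)) {
                        num.add(e);
                    }
                }
            }
        }
    }
    // Derive the result from the number of distinct values recorded
    finalans = num.size();
    if ((finalans % 2) == 0) {
        return finalans * 2;
    } else if (!((finalans % 2) == 0)) {
        return finalans;
    }
    return 0;
}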
My code works, but it takes too long when s and x get large. How can I make this program run faster?
You can solve this by realizing that, for each incoming state (xor digit, sum digit, incoming carry), there is only a limited number of possible outgoing states (outgoing carry). You can handle each state with an if condition and use recursion to calculate the total number of combinations. Memoization makes the recursion efficient. My solution below solves the problem in O(m) time, where m is the number of binary digits in your number data type. Since the problem uses int (m = 32), this is technically an O(1) solution.
Let me know if you have any questions. I have tried to add helpful comments to the code to explain various cases.
public class SumAndXor {
    public static void main(String[] args) {
        int a = 3;
        int b = 7;
        int sum = a + b;
        int xor = a ^ b;
        System.out.println(answer(sum, xor));
    }

    private static final int NOT_SET = -1;

    // Driver
    public static int answer(int sum, int xor) {
        int numBitsPerInt = Integer.toBinaryString(Integer.MAX_VALUE).length() + 1;
        int[][] cache = new int[numBitsPerInt][2];
        for (int i = 0; i < numBitsPerInt; ++i) {
            cache[i][0] = NOT_SET;
            cache[i][1] = NOT_SET;
        }
        return answer(sum, xor, 0, 0, cache);
    }

    // Recursive helper
    public static int answer(int sum, int xor, int carry, int index, int[][] cache) {
        // Return memoized value if available
        if (cache[index][carry] != NOT_SET) {
            return cache[index][carry];
        }
        // Base case: nothing else to process
        if ((sum >> index) == 0 && (xor >> index) == 0 && carry == 0) {
            return 1;
        }
        // Get least significant bits
        int sumLSB = (sum >> index) & 1;
        int xorLSB = (xor >> index) & 1;
        // Recursion
        int result = 0;
        if (carry == 0) {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 0, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 0, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 0, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 0, both [0, 1] and [1, 0] are valid. We
                // recurse with a carry of 0 to represent [0, 1], and we recurse with a carry
                // of 0 to represent [1, 0].
                result = 2 * answer(sum, xor, 0, index + 1, cache);
            }
        } else {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 1, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 1, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 1, both [0, 1] and [1, 0] are valid. We
                // recurse with a carry of 1 to represent [0, 1], and we recurse with a carry
                // of 1 to represent [1, 0].
                result = 2 * answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 1, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            }
        }
        cache[index][carry] = result;
        return result;
    }
}
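For comparison, the same carry-state idea can be written without recursion or a cache: walk the 32 bit positions from least to most significant and keep two counters, one for each possible incoming carry. This is a sketch, not part of the answer above; the class name SumAndXorIterative is made up, and it returns long so overflow never needs a second thought.

```java
public class SumAndXorIterative {
    // Counts ordered pairs (a, b) of non-negative ints with a + b == s and a ^ b == x.
    // carry0 / carry1 hold the number of ways to fill the lower bits so that the
    // carry coming into the current bit is 0 / 1.
    public static long answer(int s, int x) {
        long carry0 = 1;
        long carry1 = 0;
        for (int i = 0; i < 32; i++) {
            int sBit = (s >>> i) & 1;
            int xBit = (x >>> i) & 1;
            long next0 = 0;
            long next1 = 0;
            // Incoming carry 0
            if (xBit == 0 && sBit == 0) {
                next0 += carry0; // [0, 0] gives no carry out
                next1 += carry0; // [1, 1] gives a carry out
            } else if (xBit == 1 && sBit == 1) {
                next0 += 2 * carry0; // [0, 1] or [1, 0], no carry out
            }
            // Incoming carry 1
            if (xBit == 0 && sBit == 1) {
                next0 += carry1; // [0, 0] + 1 gives no carry out
                next1 += carry1; // [1, 1] + 1 gives a carry out
            } else if (xBit == 1 && sBit == 0) {
                next1 += 2 * carry1; // [0, 1] or [1, 0] + 1 gives a carry out
            }
            carry0 = next0;
            carry1 = next1;
        }
        // A valid pair must not leave a carry past the top bit
        return carry0;
    }

    public static void main(String[] args) {
        System.out.println(answer(10, 4)); // prints: 2
    }
}
```

Each bit does a constant amount of work, so this is also O(m) for m bits, with O(1) extra space instead of an m × 2 cache.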
Google says this takes too long because your algorithm runs in O(n^2), and Google wants it in O(lg n). If you ask me, this was too difficult for a level 3 challenge; I had an easier level 4. The solution is not what you would expect. In fact, the correct answer never even sets values for (a, b) and compares the resulting sum and xor with (S, x). This seems counterintuitive until you see and understand the solution.
Either way, it helps to plot the correct answers in a 2D graph or an Excel spreadsheet, using S for rows and x for columns (leaving zeros blank). Then look for patterns. In fact, the data points form a Sierpinski triangle (see http://en.wikipedia.org/wiki/Sierpinski_triangle ).
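A quick way to reproduce that picture is a brute-force table printer. This is a sketch: the class name SumXorTable and the 32 × 32 table size are arbitrary choices, and the brute force is only practical for small S.

```java
public class SumXorTable {
    // Brute-force count of ordered pairs (a, b) with a + b == s and a ^ b == x.
    // b is fixed by a (b = s - a), so a single loop over a is enough.
    public static int bruteForce(int s, int x) {
        int count = 0;
        for (int a = 0; a <= s; a++) {
            if ((a ^ (s - a)) == x) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int n = 32; // arbitrary table size: rows are s, columns are x
        for (int s = 0; s < n; s++) {
            StringBuilder row = new StringBuilder();
            for (int x = 0; x < n; x++) {
                int c = bruteForce(s, x);
                // Leave zero entries blank so the triangle pattern stands out
                row.append(c == 0 ? "   " : String.format("%3d", c));
            }
            System.out.println(row);
        }
    }
}
```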
You will also notice that every nonzero data point in a column has the same value, so given your x value, you automatically know what the final answer should be, as long as the row corresponding to your S value crosses a data point of the triangle. You just need to determine whether the S row intersects the triangle in column x. Does that make sense?
Even the values in the columns form a pattern, for x = 0, 1, 2, ...: 1, 2, 2, 4, 2, 4, 4, 8, 2, 4, 4, 8, 4, 8, 8, 16, ... I'm sure you can figure out what it is.
Here is the "final value given x" method along with most of the remaining code (in Python; Java is too verbose and complicated). You just need to write the triangle traversal algorithm (I won't give it, but this is a solid push in the right direction):
import math

def final(x, t):
    if x > 0:
        if x % 2:  # x is odd
            return final(x // 2, t * 2)
        else:  # x is even
            return final(x // 2, t)
    else:
        return t

def mid(l, r):
    return (l + r) // 2

def sierpinski_traverse(s_mod_xms, x, lo, hi, e, l, r):
    # you can do this in 16 lines of code to end with...
    if intersect:
        # always start with a t-value of 1 when first calling final in case x=0
        return final(x, 1)
    else:
        return 0

def answer(s, x):
    print(final(x, 1))
    if s < 0 or x < 0 or s > 2000000000 or x > 2000000000 or s < x or s % 2 != x % 2:
        return 0
    if x == 0:
        return 1
    x_modulus_size = 2 ** int(math.log(x, 2) + 2)
    s_mod_xms = s % x_modulus_size
    lo_root = x_modulus_size // 4
    hi_root = x_modulus_size // 2
    exp = x_modulus_size // 4  # exponent of 2 (e.g. 2 ** exp)
    return sierpinski_traverse(s_mod_xms, x, lo_root, hi_root, exp, exp, 2 * exp)

if __name__ == '__main__':
    answer(10, 4)
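The column values listed earlier (1, 2, 2, 4, 2, 4, 4, 8, ...) appear to be 2 raised to the number of 1 bits in x, and final(x, 1) computes exactly that: it doubles t once per odd step, that is, once per set bit. The sketch below is a Java port of final (renamed finalValue, since final is a reserved word in Java; the class name ColumnValue is made up) that checks it against Integer.bitCount for x from 0 to 15.

```java
public class ColumnValue {
    // Java port of the Python final(x, t): each 1 bit of x doubles t,
    // each 0 bit leaves it unchanged.
    public static int finalValue(int x, int t) {
        if (x > 0) {
            if (x % 2 == 1) {
                return finalValue(x / 2, t * 2); // x is odd
            }
            return finalValue(x / 2, t); // x is even
        }
        return t;
    }

    public static void main(String[] args) {
        StringBuilder line = new StringBuilder();
        for (int x = 0; x < 16; x++) {
            int v = finalValue(x, 1);
            // Should equal 2 raised to the number of set bits in x
            if (v != (1 << Integer.bitCount(x))) {
                throw new IllegalStateException("mismatch at x = " + x);
            }
            line.append(v).append(' ');
        }
        System.out.println(line.toString().trim());
        // prints: 1 2 2 4 2 4 4 8 2 4 4 8 4 8 8 16
    }
}
```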
Try changing num to a HashSet. You can also clean up the if/else at the end.
For example:
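import java.util.HashSet;

public static int answer(int s, int x) {
    // HashSet add/contains are expected O(1), unlike ArrayList.contains, which is O(n)
    HashSet<Integer> num = new HashSet<>();
    int sum;
    int finalans;
    for (int i = 0; i <= s; i++) {
        for (int e = 0; e <= s; e++) {
            sum = i + e;
            if (sum == s) {
                if ((i ^ e) == x) {
                    // add() ignores duplicates, so no contains() check is needed
                    num.add(i);
                    num.add(e);
                }
            }
        }
    }
    finalans = num.size();
    if ((finalans % 2) == 0) {
        return finalans * 2;
    } else {
        return finalans;
    }
}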
Most of the steps in your algorithm do too much work:
- You perform a linear scan over all non-negative integer values up to s. Since the problem is symmetric, scanning up to s/2 is sufficient.
- You perform a second linear scan to find, for each a, a different integer b that satisfies a + b = s. Simple algebra shows that there is only one such b, namely s - a, so no linear scan is required at all.
- You perform a third linear scan to check whether you have already found a pair (a, b). If you only iterate up to s/2, you will always have a ≤ b, and therefore you will never double count.
Finally, I can think of one simple optimization to save some work:
- If s is even, then a and b are either both even or both odd. Hence, a ^ b is even in this case.
- If s is odd, exactly one of a and b is odd, and therefore a ^ b is odd.
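Both parity rules follow from the identity a + b = (a ^ b) + 2(a & b): XOR adds the bits without carrying, and a & b marks exactly the positions that generate a carry. Because 2(a & b) is even, a + b and a ^ b always have the same parity, so s % 2 must equal x % 2. The sketch below (the class name ParityCheck is made up) checks the identity exhaustively for small values; it illustrates the argument rather than proving it.

```java
public class ParityCheck {
    // Returns true if, for every a and b in [0, limit),
    // a + b == (a ^ b) + 2 * (a & b) and a + b has the same parity as a ^ b.
    public static boolean identityHolds(int limit) {
        for (int a = 0; a < limit; a++) {
            for (int b = 0; b < limit; b++) {
                if (a + b != (a ^ b) + 2 * (a & b)) {
                    return false;
                }
                if ((a + b) % 2 != (a ^ b) % 2) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(identityHolds(256)); // prints: true
    }
}
```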
You can add this check before doing any work:
public static int answer(int s, int x) {
    int result = 0;
    // a + b and a ^ b always have the same parity, so skip the scan otherwise
    if (s % 2 == x % 2) {
        // Only scan a <= s/2; b is then determined as s - a
        for (int a = 0; a <= s / 2; a++) {
            int b = s - a;
            if ((a ^ b) == x) {
                // count both (a, b) and (b, a)
                result += 2;
            }
        }
        // we might have double counted the pair (s/2, s/2)
        // decrement the count if needed
        if (s % 2 == 0 && ((s / 2) ^ (s / 2)) == x) {
            result--;
        }
    }
    return result;
}
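The loop above still takes O(s) time. Pushing the identity a + b = (a ^ b) + 2(a & b) one step further removes the loop entirely; this is a different technique from the answers in this thread, shown as a sketch with a made-up class name. Given s and x, the bits shared by a and b are a & b = (s - x) / 2, which must be a non-negative integer and must have no bits in common with x (a bit set in both a and b cannot also be set in a ^ b). Every set bit of x can then be given to either a or b independently, so there are 2^bitCount(x) ordered pairs.

```java
public class SumAndXorClosedForm {
    // Counts ordered pairs (a, b) of non-negative integers with
    // a + b == s and a ^ b == x, using a + b == (a ^ b) + 2 * (a & b).
    public static long answer(int s, int x) {
        int diff = s - x; // fits in an int because 0 <= s, x <= 2,000,000,000
        // a & b = diff / 2 must be a non-negative integer
        if (diff < 0 || diff % 2 != 0) {
            return 0;
        }
        int common = diff / 2; // bits set in both a and b
        // A bit set in both a and b cannot also be set in a ^ b
        if ((common & x) != 0) {
            return 0;
        }
        // Each 1 bit of x goes to exactly one of a, b: two choices per bit
        return 1L << Integer.bitCount(x);
    }

    public static void main(String[] args) {
        System.out.println(answer(10, 4)); // prints: 2
        System.out.println(answer(5, 3));  // prints: 0
        System.out.println(answer(0, 0));  // prints: 1
    }
}
```

For s = 10 and x = 4, a & b = 3 and x has a single set bit, which gives the two pairs (3, 7) and (7, 3) from the problem statement.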
To explain my previous answer further, look at the big picture... literally. The triangle traversal algorithm works like a binary search, except with three options instead of two ("3D" search?). Look at the 3 largest triangles within the largest triangle that spans both S and x. Then select the one of those three that contains S and x. Then look at the three largest triangles within the selected triangle and select the one that contains S and x. Repeat until you reach a single point. If this point is nonzero, return the "final" value I specified. There are also some if-else statements that speed this up by stopping early when, in the selected triangle, the S row does not cross a data point.