Optimize set operations with a numba function

I am trying to implement the fastest possible version of the Jaccard distance in Python using Numba:

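import numba as nb

@nb.jit()
def nbjaccard(seq1, seq2):
    # Jaccard distance between the sets of items (here: characters) of seq1 and seq2
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
    # identical pure-Python version, for comparison
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))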


%%timeit
nbjaccard("compare this string","compare a different string")

- 12.4 ms

%%timeit
jaccard("compare this string","compare a different string")

- 3.87 ms

Why does the numba version take longer? Is there any way to speed it up?



2 answers


In my opinion, it is a small design flaw that numba allows functions to be compiled entirely in object mode (or at least that it gives no warning when it implements a whole function using Python objects), because such functions are usually a little slower than pure Python functions.

Numba is very powerful (it has a type dispatcher, and you can write Python code without type declarations, which compared to C extensions or Cython is really great), but only for the operations it supports, i.e. those listed in numba's documentation of supported Python features.

Any operation that is not listed there is not supported in "nopython" mode. And if numba has to fall back to "object mode", then beware:

object mode

A Numba compilation mode that generates code which treats all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than interpreted Python code, unless the Numba compiler can take advantage of loop-jitting.

And this is exactly what happened in your case: your function runs entirely in object mode:

>>> nbjaccard.inspect_types()

[...]
# --- LINE 3 --- 
#   seq1 = arg(0, name=seq1)  :: pyobject
#   seq2 = arg(1, name=seq2)  :: pyobject
#   $0.1 = global(set: <class 'set'>)  :: pyobject
#   $0.3 = call $0.1(seq1)  :: pyobject
#   $0.4 = global(set: <class 'set'>)  :: pyobject
#   $0.6 = call $0.4(seq2)  :: pyobject
#   set1 = $0.3  :: pyobject
#   set2 = $0.6  :: pyobject

set1, set2 = set(seq1), set(seq2)

# --- LINE 4 --- 
#   $const0.7 = const(int, 1)  :: pyobject
#   $0.8 = global(len: <built-in function len>)  :: pyobject
#   $0.11 = set1 & set2  :: pyobject
#   $0.12 = call $0.8($0.11)  :: pyobject
#   $0.13 = global(float: <class 'float'>)  :: pyobject
#   $0.14 = global(len: <built-in function len>)  :: pyobject
#   $0.17 = set1 | set2  :: pyobject
#   $0.18 = call $0.14($0.17)  :: pyobject
#   $0.19 = call $0.13($0.18)  :: pyobject
#   $0.20 = $0.12 / $0.19  :: pyobject
#   $0.21 = $const0.7 - $0.20  :: pyobject
#   $0.22 = cast(value=$0.21)  :: pyobject
#   return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

As you can see, every operation works on Python objects (as indicated by :: pyobject at the end of each line). This is because numba does not support str and set in nopython mode. So there is absolutely nothing here that could be made faster. Perhaps you can reformulate the problem using numpy arrays or homogeneous lists (of a numeric type).
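A minimal sketch of that numpy-array idea, assuming the inputs are plain ASCII strings: each string is converted to a uint8 array, and a 256-entry boolean presence table replaces the Python set, so the kernel only touches arrays and integers (types nopython mode handles). The names jaccard_codes and jaccard_ascii are hypothetical, and the try/except fallback simply runs the kernel as ordinary Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: run the kernel as plain Python if numba is missing.
    def njit(func):
        return func


@njit
def jaccard_codes(codes1, codes2):
    """Jaccard distance between the sets of byte values in two uint8 arrays."""
    # One presence flag per possible byte value replaces a Python set.
    present1 = np.zeros(256, dtype=np.bool_)
    present2 = np.zeros(256, dtype=np.bool_)
    for c in codes1:
        present1[c] = True
    for c in codes2:
        present2[c] = True
    n_inter = 0
    n_union = 0
    for i in range(256):
        if present1[i] and present2[i]:
            n_inter += 1
        if present1[i] or present2[i]:
            n_union += 1
    return 1.0 - n_inter / n_union


def jaccard_ascii(s1, s2):
    """Assumes ASCII input: convert to uint8 arrays, then call the kernel."""
    codes1 = np.frombuffer(s1.encode("ascii"), dtype=np.uint8).copy()
    codes2 = np.frombuffer(s2.encode("ascii"), dtype=np.uint8).copy()
    return jaccard_codes(codes1, codes2)
```

The string-to-array conversion still happens in Python, so for strings as short as the ones in the question it may cost more than it saves; a gain is more plausible when the data are already numeric arrays or much longer.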

On my computer (using numba 0.32.0) the time difference is much larger, but the individual timings are much faster: microseconds (10**-6 seconds) instead of milliseconds (10**-3 seconds):

%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop

%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop

Note that jit is lazy by default, so you should call the function once before timing it: the first call includes the time needed to compile the code.
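Outside IPython, the same warm-up rule can be applied with the standard timeit module. A minimal sketch; time_after_warmup is a hypothetical helper name, not part of numba or timeit:

```python
import timeit


def time_after_warmup(func, *args, number=10000):
    """Return the mean time per call in seconds, excluding the first call."""
    # For a lazily jitted function the first call triggers compilation
    # for these argument types, so it is made outside the timed region.
    func(*args)
    total = timeit.timeit(lambda: func(*args), number=number)
    return total / number
```

For example, time_after_warmup(nbjaccard, "compare this string", "compare a different string") would report the per-call cost of the already-compiled function.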




However, there is one optimization you could make: if you know the size of the intersection of two sets, you can compute the size of their union from it (as @Paul Hankin mentioned in his now-deleted answer):

len(union) = len(set1) + len(set2) - len(intersection)
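A quick check of this identity (inclusion-exclusion) on the example strings from the question, in plain Python:

```python
set1 = set("compare this string")
set2 = set("compare a different string")
n_inter = len(set1 & set2)
# The union size follows from the two set sizes and the intersection size.
n_union = len(set1) + len(set2) - n_inter
assert n_union == len(set1 | set2)
print(len(set1), len(set2), n_inter, n_union)  # 14 15 13 16
```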

This leads to the following (pure Python) code:

def jaccard2(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    num_intersection = len(set1 & set2)
    return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

Not much faster, but better.


There is more room for improvement if you use Cython:

%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
    cdef set set1 = set(seq1)
    cdef set set2 = set()

    cdef Py_ssize_t length_intersect = 0

    for char in seq2:
        if char not in set2:
            if char in set1:
                length_intersect += 1
            set2.add(char)

    return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

The main advantage here is that a single pass over seq2 both builds set2 and counts the elements in the intersection, without having to create an intersection set at all!
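For comparison, the same single-pass idea can be written in plain Python (jaccard_one_pass is a hypothetical name). Without Cython's static typing this version is not necessarily faster than jaccard2, so it is mainly useful as a readable reference; time it on your own inputs.

```python
def jaccard_one_pass(seq1, seq2):
    set1 = set(seq1)
    seen2 = set()          # becomes set(seq2) by the end of the loop
    n_inter = 0
    for item in seq2:
        if item not in seen2:
            seen2.add(item)
            # Count each distinct item of seq2 that also occurs in seq1.
            if item in set1:
                n_inter += 1
    # Union size via inclusion-exclusion, no union or intersection set built.
    return 1 - n_inter / float(len(set1) + len(seen2) - n_inter)
```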



When I time these two functions, nbjaccard takes ~4.7 microseconds (after the jit warms up) and plain Python takes ~3.2 microseconds, using Numba 0.32.0. However, I don't expect numba to give you a speedup in this case, as there is currently no nopython-mode support for strings. This means you are going through the Python object layer, which is generally no different from running without jit, unless numba is able to do some clever loop-lifting (compiling sub-blocks such as loops into native code rather than Python object operations). In the numba case you are probably just paying a little extra overhead for checking the types of the inputs.



I think the gist is that you are trying to use numba for a use case it does not currently cover. Where Numba really excels is working with numpy arrays and numerical scalar operations, or problems that can be ported to the GPU.







