2010-10-21 45 views
7

Je veux accélérer le code suivant à l'aide cython:Cython peut-il accélérer le tableau d'itérations d'objets?

class A(object): 
    cdef fun(self): 
     return 3 


class B(object): 
    cdef fun(self): 
     return 2 

def test(): 
    cdef int x, y, i, s = 0 
    a = [ [A(), B()], [B(), A()]] 
    for i in xrange(1000): 
     for x in xrange(2): 
      for y in xrange(2): 
       s += a[x][y].fun() 
    return s 

La seule chose qui vient à l'esprit est quelque chose comme ceci:

def test(): 
    cdef int x, y, i, s = 0 
    types = [ [0, 1], [1, 0]] 
    data = [[...], [...]] 
    for i in xrange(1000): 
     for x in xrange(2): 
      for y in xrange(2): 
       if types[x,y] == 0: 
        s+= A(data[x,y]).fun() 
       else: 
        s+= B(data[x,y]).fun() 
    return s 

Fondamentalement, la solution en C++ sera d'avoir gamme des pointeurs vers une classe de base avec la méthode virtuelle fun(), alors vous pouvez parcourir rapidement. Y a-t-il un moyen de le faire en utilisant python/cython? BTW: Serait-il plus rapide d'utiliser le tableau 2D de numpy avec dtype = object_, plutôt que des listes python?

+0

Essayez les 2 boucles dérouler internes, les nombres sont petits, donc il ne sera pas ajouter du code beaucoup plus. Je pense qu'il y a de bonnes chances que cela puisse aider. –

+0

Ceci est juste un exemple, dans le code réel, une taille est grande et connue seulement à l'exécution – Maxim

Répondre

5

On dirait code comme ceci donne environ 20x speedup:

import numpy as np 
cimport numpy as np 
cdef class Base(object): 
    cdef int fun(self): 
     return -1 

cdef class A(Base): 
    cdef int fun(self): 
     return 3 


cdef class B(Base): 
    cdef int fun(self): 
     return 2 

def test(): 
    bbb = np.array([[A(), B()], [B(), A()]], dtype=np.object_) 
    cdef np.ndarray[dtype=object, ndim=2] a = bbb 

    cdef int i, x, y 
    cdef int s = 0 
    cdef Base u 

    for i in xrange(1000): 
     for x in xrange(2): 
      for y in xrange(2): 
       u = a[x,y]     
       s += u.fun() 
    return s 

Il vérifie même, que A et B sont héritées de base, probablement il y a moyen de le désactiver dans la version builds et obtenir gain de vitesse

supplémentaires

EDIT: Vérifiez peut être retiré à l'aide

u = <Base>a[x,y] 
+2

Y at-il une raison pour laquelle vous stockez les objets dans un tableau numpy plutôt qu'une liste ou une autre structure de données? – Zephyr