2010-12-02 33 views
9

PyBrain est une bibliothèque python qui fournit (entre autres choses) des réseaux neuronaux artificiels faciles à utiliser.Comment sérialiser/désérialiser les réseaux pybrain?

Je n'arrive pas à sérialiser/désérialiser correctement les réseaux PyBrain en utilisant pickle ou cPickle.

Voir l'exemple suivant:

from pybrain.datasets   import SupervisedDataSet 
from pybrain.tools.shortcuts  import buildNetwork 
from pybrain.supervised.trainers import BackpropTrainer 
import cPickle as pickle 
import numpy as np 

#generate some data 
np.random.seed(93939393) 
data = SupervisedDataSet(2, 1) 
for x in xrange(10): 
    y = x * 3 
    z = x + y + 0.2 * np.random.randn() 
    data.addSample((x, y), (z,)) 

#build a network and train it  

net1 = buildNetwork(data.indim, 2, data.outdim) 
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True) 
for i in xrange(4): 
    trainer1.trainEpochs(1) 
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0]) 

Ceci est la sortie du code ci-dessus:

Total error: 201.501998476 
    value after 0 epochs: 2.79 
Total error: 152.487616382 
    value after 1 epochs: 5.44 
Total error: 120.48092561 
    value after 2 epochs: 7.56 
Total error: 97.9884043452 
    value after 3 epochs: 8.41 

Comme vous pouvez le voir, l'erreur totale du réseau diminue à mesure que la formation progresse. Vous pouvez également voir que la valeur prédite se rapproche de la valeur attendue de 12.

Maintenant, nous allons faire un exercice similaire, mais comprendra sérialisation/désérialisation:

print 'creating net2' 
net2 = buildNetwork(data.indim, 2, data.outdim) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
trainer2.trainEpochs(1) 
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0]) 

#So far, so good. Let's test pickle 
pickle.dump(net2, open('testNetwork.dump', 'w')) 
net2 = pickle.load(open('testNetwork.dump')) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
print 'loaded net2 using pickle, continue training' 
for i in xrange(1, 4): 
     trainer2.trainEpochs(1) 
     print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0]) 

Ceci est la sortie de ce bloc:

creating net2 
Total error: 176.339378639 
    value after 1 epochs: 5.45 
loaded net2 using pickle, continue training 
Total error: 123.392181859 
    value after 1 epochs: 5.45 
Total error: 94.2867637623 
    value after 2 epochs: 5.45 
Total error: 78.076711114 
    value after 3 epochs: 5.45 

Comme vous pouvez le voir, il semble que la formation a un effet sur le réseau (la valeur d'erreur totale déclarée continue de diminuer), mais la valeur de sortie du réseau se fige sur une valeur qui est pertinente pour la première formation itération.

Y a-t-il un mécanisme de mise en cache dont j'ai besoin de savoir qui cause ce comportement erroné? Existe-t-il de meilleurs moyens de sérialiser/désérialiser les réseaux pybrain?

numéros de version pertinents:

  • Python 2.6.5 (R265: 79096 19 Mar 2010, 21:48:26) [MSC 32 bits (Intel)]
  • NumPy 1.5. 1
  • cPickle 1,71
  • pybrain 0,3

PS J'ai créé a bug report sur le site du projet et garder les deux SO et le bug tracker updatedj

+0

Etes-vous sûr de ne plus faire 'trainer2 = BackpropTrainer (net2, dataset = data, verbose = True)' après avoir rechargé 'net2'? –

+0

@Seth Johnson Bien sûr, je le fais, mais cela ne résout pas le problème. En fait, mon code de test incluait cette ligne, mais je l'ai commis par erreur en la collant ici. Fixé maintenant –

Répondre

11

cause

Le mécanisme qui provoque ce comportement est la manipulation des paramètres (.params) et dérivés (.derivs) dans PyBrain modules: en fait, tous les paramètres réseau sont stockés dans un tableau, mais les objets individuels Module ou Connection ont accès à "leur propre" .params, qui, cependant, sont juste une vue sur une tranche de la matrice totale. Cela permet à la fois des écritures et des lectures locales et à l'échelle du réseau sur la même structure de données.

Apparemment, ce lien de vue en coupe est perdu par le décapage-décapage.

Solution

Insérer

net2.sorted = False 
net2.sortModules() 

après le chargement du fichier (qui reconstitue ce partage), et cela devrait fonctionner.