[Author Prev][Author Next][Thread Prev][Thread Next][Author Index][Thread Index]

[pygame] Patch to pygame.sprite.AbstractGroup



If you look at pygame.sprite.AbstractGroup, you'll see that .add, .remove, and
.has all share a great deal of duplicate code. Furthermore, .remove calls .add
in one case, which is a bug that was probably introduced by copying the whole
function and trying to tweak it accordingly. This patch removes these recursive
cases from these functions and puts them in one function, ._slist, which turns
the sprite argument into a simple flat list of sprites, regardless of whether
it's a sprite, a list of sprites, a SpriteGroup, or any combination thereof.

This may cause a small efficiency hit on these functions, particularly .has,
which used to return as soon as it found a missing sprite and now flattens all
of its arguments first, but I believe it increases readability and
maintainability by a comparable amount. I hope you'll apply it.

Ethan
Index: lib/sprite.py
===================================================================
RCS file: /home/cvspsrv/cvsroot/games/pygame/lib/sprite.py,v
retrieving revision 1.29
diff -u -r1.29 sprite.py
--- lib/sprite.py	2 Feb 2005 16:44:01 -0000	1.29
+++ lib/sprite.py	12 Jul 2005 03:20:48 -0000
@@ -208,64 +208,56 @@
     def __contains__(self, sprite):
         return self.has(sprite)
 
-    def add(self, *sprites):
-        """add(sprite, list, or group, ...)
-           add sprite to group
-
-           Add a sprite or sequence of sprites to a group."""
+    def _slist(self, sprites):
+        """Flatten sprites, sequences and groups into one list of sprites."""
+        temp = []
         for sprite in sprites:
             # It's possible that some sprite is also an iterator.
             # If this is the case, we should add the sprite itself,
             # and not the objects it iterates over.
             if isinstance(sprite, Sprite):
-                if not self.has_internal(sprite):
-                    self.add_internal(sprite)
-                    sprite.add_internal(self)
+                temp.append(sprite)
             else:
                 try:
                     # See if sprite is an iterator, like a list or sprite
                     # group.
+                    # Expand each element on its own, in case it too happens
+                    # to be a list or otherwise iterable.
                     for spr in sprite:
-                        self.add(spr)
+                        temp.extend(self._slist((spr,)))
                 except (TypeError, AttributeError):
                     # Not iterable, this is probably a sprite that happens
                     # to not subclass Sprite. Alternately, it could be an
                     # old-style sprite group.
                     if hasattr(sprite, '_spritegroup'):
                         for spr in sprite.sprites():
-                            if not self.has_internal(spr):
-                                self.add_internal(spr)
-                                spr.add_internal(self)
-                    elif not self.has_internal(sprite):
-                        self.add_internal(sprite)
-                        sprite.add_internal(self)
+                            temp.append(spr)
+                    else:
+                        temp.append(sprite)
+        return temp
+
+    def add(self, *sprites):
+        """add(sprite, list, or group, ...)
+           add sprite to group
+
+           Add a sprite or sequence of sprites to a group."""
+        # Flatten the arguments, then add each sprite not already present.
+        sprites = self._slist(sprites)
+        for sprite in sprites:
+            if not self.has_internal(sprite):
+                self.add_internal(sprite)
+                sprite.add_internal(self)
 
     def remove(self, *sprites):
         """remove(sprite, list, or group, ...)
            remove sprite from group
 
            Remove a sprite or sequence of sprites from a group."""
-        # This function behaves essentially the same as Group.add.
-        # Check for Spritehood, check for iterability, check for
-        # old-style sprite group, and fall back to assuming
-        # spritehood.
+        sprites = self._slist(sprites)
         for sprite in sprites:
-            if isinstance(sprite, Sprite):
-                if self.has_internal(sprite):
-                    self.remove_internal(sprite)
-                    sprite.remove_internal(self)
-            else:
-                try:
-                    for spr in sprite: self.add(spr)
-                except (TypeError, AttributeError):
-                    if hasattr(sprite, '_spritegroup'):
-                        for spr in sprite.sprites():
-                            if self.has_internal(spr):
-                                self.remove_internal(spr)
-                                spr.remove_internal(self)
-                    elif self.has_internal(sprite):
-                        self.remove_internal(sprite)
-                        sprite.remove_internal(self)
+            if self.has_internal(sprite):
+                self.remove_internal(sprite)
+                sprite.remove_internal(self)
 
     def has(self, *sprites):
         """has(sprite or group, ...)
@@ -276,23 +268,11 @@
            or 'subgroup in group'."""
         # Again, this follows the basic pattern of Group.add and
         # Group.remove.
+        sprites = self._slist(sprites)
         for sprite in sprites:
-            if isinstance(sprite, Sprite):
-                return self.has_internal(sprite)
-
-            try:
-                for spr in sprite:
-                    if not self.has(sprite):
-                        return False
-                return True
-            except (TypeError, AttributeError):
-                if hasattr(sprite, '_spritegroup'):
-                    for spr in sprite.sprites():
-                        if not self.has_internal(spr):
-                            return False
-                    return True
-                else:
-                    return self.has_internal(sprite)
+            if not self.has_internal(sprite):
+                return False
+        return True
 
     def update(self, *args):
         """update(*args)

Attachment: signature.asc
Description: OpenPGP digital signature