#! /usr/bin/env python
__author__ = 'sjbrown'
'''
rectangle collision resolver

Contains one function of interest:
resolve_collisions(movingRect, deltaX, deltaY, 
                   collidables, strategy=CheckHexagonUsingRotation,
                   visitFn=None):

You may also be interested in the Strategy classes available. They can be
used as optional strategy=... arguments to resolve_collisions()

Available strategies:
FastAtomic, CheckHexagonUsingBranching, CheckHexagonUsingRotation
'''

import operator
from pygame import Rect

#------------------------------------------------------------------------------
def rotate_ccw_90(rect):
    '''Return a new Rect holding `rect` rotated about the origin so that
    each point (x, y) maps to (-y, x).  Width and height swap.'''
    newLeft = -rect.bottom
    newTop = rect.left
    return Rect(newLeft, newTop, rect.height, rect.width)

#------------------------------------------------------------------------------
def rotate_ccw_180(rect):
    '''Return a new Rect holding `rect` rotated about the origin so that
    each point (x, y) maps to (-x, -y).  This rotation is its own inverse.'''
    newLeft = -rect.right
    newTop = -rect.bottom
    return Rect(newLeft, newTop, rect.width, rect.height)

#------------------------------------------------------------------------------
def rotate_ccw_270(rect):
    '''Return a new Rect holding `rect` rotated about the origin so that
    each point (x, y) maps to (y, -x), the inverse of rotate_ccw_90.
    Width and height swap.'''
    newLeft = rect.top
    newTop = -rect.right
    return Rect(newLeft, newTop, rect.height, rect.width)

#------------------------------------------------------------------------------
class Line:
    '''Non-vertical straight line y = slope*x + y_axis_intersect.

    Built from two points with different x coordinates; two points sharing
    an x coordinate raise ZeroDivisionError.  above()/below() compare raw
    numeric y values, so (assuming pygame's y-down screen coordinates)
    "above" means visually lower on screen.
    '''
    def __init__(self, p1, p2):
        x1, y1 = p1[0], p1[1]
        dx = p2[0] - x1
        dy = p2[1] - y1
        # float() keeps integer inputs from using floor division
        self.slope = float(dy) / dx
        self.y_axis_intersect = y1 - x1 * self.slope

    def above(self, point):
        '''True if point's y is strictly greater than the line's y there.'''
        return point[1] > self.y_at_x(point[0])

    def below(self, point):
        '''True if point's y is strictly less than the line's y there.'''
        return point[1] < self.y_at_x(point[0])

    def y_at_x(self, xcoord):
        '''y coordinate of the line at xcoord.'''
        return self.slope*xcoord + self.y_axis_intersect

    def x_at_y(self, ycoord):
        '''x coordinate of the line at ycoord (ZeroDivisionError if flat).'''
        offset = ycoord - self.y_axis_intersect
        return offset / float(self.slope)

#------------------------------------------------------------------------------
class Strategy:
    '''Base class for the collision-resolution strategies.

    A strategy is built from the rect to move and its intended displacement,
    then run() is fed (subject, rect) pairs.  Subclasses supply
    run_step(subject, rect), which returns a (limitsX, limitsY) pair and may
    update movedRect, xLimiter and yLimiter.  If visitFn is set it is called
    as visitFn(subject, limitsX, limitsY) after every step.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        # collidables that stopped motion along each axis (None = unobstructed)
        self.xLimiter = None
        self.yLimiter = None
        self.origRect = origRect
        self.deltaX, self.deltaY = deltaX, deltaY
        # final position; starts at the unobstructed destination
        self.movedRect = origRect.move(deltaX, deltaY)
        self.visitFn = visitFn

    def run(self, iterable):
        '''Process every (subject, rect) pair, reporting to visitFn if set.'''
        stepAll = self.run_with_visit if self.visitFn else self.run_sans_visit
        return stepAll(iterable)

    def run_with_visit(self, iterable):
        for pair in iterable:
            subject, rect = pair
            hitX, hitY = self.run_step(subject, rect)
            self.visitFn(subject, hitX, hitY)

    def run_sans_visit(self, iterable):
        for pair in iterable:
            subject, rect = pair
            _hitX, _hitY = self.run_step(subject, rect)



#------------------------------------------------------------------------------
class CheckOneDirection(Strategy):
    '''Abstract base for straight-line (single-axis) sweeps.

    Rects are rotated with self.rotate_fn into a working frame where the
    motion is intended to be along +x.  The sweep is done there by
    run_step_forward(), and run() rotates movedRect back with
    self.reverse_rotate_fn at the end.  Concrete subclasses (CheckForward,
    CheckBackward, CheckUp, CheckDown) must set both functions before
    calling this __init__.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        Strategy.__init__(self, origRect, deltaX, deltaY, visitFn)
        # rect (in the working frame) of the nearest limiter found so far
        self.xLimiterRect = None
        toFrame = self.rotate_fn
        self.movedRect = toFrame(self.movedRect)
        # bounding rect of the whole straight sweep, in the working frame
        self.bigRect = toFrame(self.origRect).union(self.movedRect)

    def run_step_forward(self, subject, rect):
        '''Consider one working-frame rect; return (limitsX, limitsY).'''
        if not rect.colliderect(self.bigRect):
            return False, False
        best = self.xLimiterRect
        if best and rect.left > best.left:
            # farther along +x than the current limiter: irrelevant
            return False, False
        self.xLimiter, self.xLimiterRect = subject, rect
        # stop the moving rect flush against the limiter's near edge
        self.movedRect.right = rect.left
        return True, False

    def run(self, iterable):
        Strategy.run(self, iterable)
        # back from the working frame to caller coordinates
        self.movedRect = self.reverse_rotate_fn(self.movedRect)

    def run_step(self, subject, rect):
        return self.run_step_forward(subject, self.rotate_fn(rect))

#------------------------------------------------------------------------------
class CheckForward(CheckOneDirection):
    '''Straight sweep along +x; the working frame is the caller's frame.

    Both rotation functions are the Rect constructor, which just copies.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        # must be in place before CheckOneDirection.__init__ uses them
        self.rotate_fn = self.reverse_rotate_fn = Rect
        CheckOneDirection.__init__(self, origRect, deltaX, deltaY, visitFn)

#------------------------------------------------------------------------------
class CheckBackward(CheckOneDirection):
    '''Straight sweep along -x, done in a frame rotated by rotate_ccw_180.

    The 180-degree rotation is its own inverse, so it also undoes itself.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        # must be in place before CheckOneDirection.__init__ uses them
        self.rotate_fn = self.reverse_rotate_fn = rotate_ccw_180
        CheckOneDirection.__init__(self, origRect, deltaX, deltaY, visitFn)

#------------------------------------------------------------------------------
class CheckYDirection(CheckOneDirection):
    '''Base for vertical straight sweeps (CheckUp, CheckDown).

    In the rotated working frame vertical motion runs along x.  The limiter
    the forward sweep records as xLimiter is therefore relabelled as
    yLimiter, and visitFn receives the two limit flags swapped.
    '''
    def _promote_limiter_to_y(self):
        if self.xLimiter:
            self.yLimiter, self.xLimiter = self.xLimiter, None

    def run_with_visit(self, iterable):
        for pair in iterable:
            subject, rect = pair
            frameX, frameY = self.run_step(subject, rect)
            # working-frame x is the caller's y, so report (y, x)
            self.visitFn(subject, frameY, frameX)
        self._promote_limiter_to_y()

    def run_sans_visit(self, iterable):
        CheckOneDirection.run_sans_visit(self, iterable)
        self._promote_limiter_to_y()

#------------------------------------------------------------------------------
class CheckUp(CheckYDirection):
    '''Straight sweep along -y.

    rotate_ccw_90 maps (x, y) to (-y, x), so -y motion becomes +x motion;
    rotate_ccw_270 maps it back.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        self.rotate_fn, self.reverse_rotate_fn = rotate_ccw_90, rotate_ccw_270
        CheckYDirection.__init__(self, origRect, deltaX, deltaY, visitFn)

#------------------------------------------------------------------------------
class CheckDown(CheckYDirection):
    '''Straight sweep along +y.

    rotate_ccw_270 maps (x, y) to (y, -x), so +y motion becomes +x motion;
    rotate_ccw_90 maps it back.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        self.rotate_fn, self.reverse_rotate_fn = rotate_ccw_270, rotate_ccw_90
        CheckYDirection.__init__(self, origRect, deltaX, deltaY, visitFn)


#------------------------------------------------------------------------------
class _AbstractCheckHexagon(Strategy):
    '''Shared state for the diagonal ("sweep hexagon") strategies.

    Subclasses provide run_step() and post_deflection_retreat().  bigRect,
    the union of the start and end rects, serves as a cheap first filter
    before the exact hexagon test.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        Strategy.__init__(self, origRect, deltaX, deltaY, visitFn)
        self.bigRect = origRect.union(self.movedRect)
        # (subject, rect) pairs that touch bigRect but miss the hexagon
        self.outsideHexInsideBig = []
        # rects of the nearest X / Y limiters found so far
        self.maxXLimiterRect = self.maxYLimiterRect = None

        # NOTE: when deltaX == 0 or deltaY == 0 the sweep is a rectangle,
        # not a hexagon; a CheckOneDirection strategy is cheaper there.

    def run(self, iterable):
        Strategy.run(self, iterable)
        self.post_deflection_retreat()

    def post_deflection_retreat(self):
        '''Handle collidables skipped because they lay outside the hexagon.

        A deflection may leave movedRect outside the sweep hexagon, where it
        can overlap a collidable that run_step discarded for being outside
        it.  Those collidables are kept in self.outsideHexInsideBig.
        Implementations re-check them and pull movedRect back toward the
        hexagon if any are hit.
        '''
        raise NotImplementedError('Abstract class')


#------------------------------------------------------------------------------
class _CheckHexagonForwardDown(_AbstractCheckHexagon):
    '''Sweep-hexagon collision check for motion with deltaX > 0, deltaY > 0.

    The sweep is bounded by the paths of three corners: topright
    (topHexBorder), bottomleft (bottomHexBorder) and the leading corner
    bottomright (leadingLine).  A colliding rect is classified as an X or a
    Y limiter by which side of leadingLine its topleft corner lies on.  The
    nearest limiter of each kind clamps movedRect.

    NOTE: Line.above/below compare raw y values; assuming pygame's y-down
    screen coordinates, "above" means visually lower.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        _AbstractCheckHexagon.__init__(self, origRect, deltaX, deltaY, visitFn)
        # path followed by the leading (bottomright) corner
        self.leadingLine = Line(self.origRect.bottomright, 
                                self.movedRect.bottomright)
        # the two long sides of the sweep
        self.topHexBorder = Line(self.origRect.topright, 
                                 self.movedRect.topright)
        self.bottomHexBorder = Line(self.origRect.bottomleft,
                                    self.movedRect.bottomleft)

    def post_deflection_retreat(self):
        '''Re-check set-aside collidables after movedRect was clamped.

        Compares the final bottomright corner with leadingLine.  It then runs
        a straight CheckDown or CheckForward sweep over
        self.outsideHexInsideBig, from the matching point on leadingLine to
        the current movedRect.  A limiter found by that sweep overrides the
        corresponding one found so far.
        '''
        subStrat = None
        leadingPoint = self.movedRect.bottomright
        farX, farY = leadingPoint
        dX, dY = 0,0
            
        if self.leadingLine.above(leadingPoint):
            # corner has larger y than the line at farX (typically x was
            # clamped): sweep vertically from the line to farY
            dY = farY - self.leadingLine.y_at_x(farX)
            subStrat = CheckDown
        elif self.leadingLine.below(leadingPoint):
            # corner has smaller y than the line at farX (typically y was
            # clamped): sweep horizontally from the line to farX
            dX = farX - self.leadingLine.x_at_y(farY)
            subStrat = CheckForward

        if subStrat:
            # shift to an 'original' position
            movingRect = self.movedRect.move(-dX,-dY)
            s = subStrat(movingRect, dX, dY, self.visitFn)
            s.run(self.outsideHexInsideBig)
            self.movedRect = s.movedRect
            if subStrat == CheckDown and s.yLimiter:
                self.yLimiter = s.yLimiter
            elif s.xLimiter:
                self.xLimiter = s.xLimiter

    def run_step(self, subject, rect):
        '''Classify one collidable rect; return (limitsX, limitsY).

        Rects that touch bigRect but neither movedRect nor the hexagon are
        stored in self.outsideHexInsideBig.  Note that movedRect may already
        have been clamped by earlier steps when it is tested here.
        '''
        if not rect.colliderect(self.bigRect):
            return False, False
        if not ( rect.colliderect(self.movedRect) 
                 or self.corner_in_hexagon(rect) ):
            self.outsideHexInsideBig.append((subject, rect))
            return False, False
        # we've filtered out all the rects that don't collide with the sweep
        # hexagon, so now we just need to see if this is the x or y limiter
        limitsX, limitsY = False, False
        opposingPoint = rect.topleft
        if self.leadingLine.below(opposingPoint):
            # topleft has smaller y than the leading path: X limiter
            limitsX = True
            self.check_and_set_maxXLimiter(opposingPoint, subject, rect)
        elif self.leadingLine.above(opposingPoint):
            # topleft has larger y than the leading path: Y limiter
            limitsY = True
            self.check_and_set_maxYLimiter(opposingPoint, subject, rect)
        else:
            # opposingPoint *intersects* leadingLine
            limitsX, limitsY = True, True
            self.check_and_set_maxXLimiter(opposingPoint, subject, rect)
            self.check_and_set_maxYLimiter(opposingPoint, subject, rect)
        return limitsX, limitsY

    def corner_in_hexagon(self, rect):
        '''True if rect's topright has smaller y than bottomHexBorder and its
        bottomleft has larger y than topHexBorder, i.e. rect reaches into the
        strip between the two sweep borders.'''
        return (self.bottomHexBorder.below(rect.topright)
                and self.topHexBorder.above(rect.bottomleft) )

    def check_and_set_maxXLimiter(self, opposingPoint, subject, rect):
        '''Adopt rect as the X limiter if opposingPoint's x is strictly
        less than the current X limiter's left edge (or none is set yet).
        On adoption, movedRect.right is set to rect.left.'''
        if (self.maxXLimiterRect
            and opposingPoint[0] >= self.maxXLimiterRect.left):
            return
        self.maxXLimiterRect = rect
        self.xLimiter = subject
        self.movedRect.right = self.maxXLimiterRect.left

    def check_and_set_maxYLimiter(self, opposingPoint, subject, rect):
        '''Adopt rect as the Y limiter if opposingPoint's y is strictly
        less than the current Y limiter's top edge (or none is set yet).
        On adoption, movedRect.bottom is set to rect.top.'''
        if (self.maxYLimiterRect
            and opposingPoint[1] >= self.maxYLimiterRect.top):
            return
        self.maxYLimiterRect = rect
        self.yLimiter = subject
        self.movedRect.bottom = self.maxYLimiterRect.top


#------------------------------------------------------------------------------
class _CheckHexagonUsingRotation(_CheckHexagonForwardDown):
    '''Reuse the forward-down sweep for the other diagonal directions.

    Subclasses set rotate_fn / reverse_rotate_fn and pass an already
    rotated rect and delta to the constructor, so every computation happens
    in the forward-down frame.  Each collidable rect is rotated into that
    frame as it arrives, and movedRect is rotated back after the run.
    '''
    def run_step(self, subject, rect):
        rotatedRect = self.rotate_fn(rect)
        return _CheckHexagonForwardDown.run_step(self, subject, rotatedRect)

    def run(self, iterable):
        _CheckHexagonForwardDown.run(self, iterable)
        # leave the working frame
        self.movedRect = self.reverse_rotate_fn(self.movedRect)

#------------------------------------------------------------------------------
class _CheckHexagonForwardUp(_CheckHexagonUsingRotation):
    '''Diagonal sweep for deltaX > 0, deltaY < 0.

    rotate_ccw_90 maps (x, y) to (-y, x), turning this motion into the
    positive/positive case handled by _CheckHexagonForwardDown.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        self.rotate_fn, self.reverse_rotate_fn = rotate_ccw_90, rotate_ccw_270
        # rotate the displacement as a zero-size rect at (deltaX, deltaY)
        turned = self.rotate_fn(Rect(deltaX, deltaY, 0, 0))
        _CheckHexagonUsingRotation.__init__(self, self.rotate_fn(origRect),
                                            turned.left, turned.top, visitFn)

#------------------------------------------------------------------------------
class _CheckHexagonBackwardUp(_CheckHexagonUsingRotation):
    '''Diagonal sweep for deltaX < 0, deltaY < 0.

    rotate_ccw_180 maps (x, y) to (-x, -y), turning this motion into the
    positive/positive case handled by _CheckHexagonForwardDown.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        self.rotate_fn = self.reverse_rotate_fn = rotate_ccw_180
        # rotate the displacement as a zero-size rect at (deltaX, deltaY)
        turned = self.rotate_fn(Rect(deltaX, deltaY, 0, 0))
        _CheckHexagonUsingRotation.__init__(self, self.rotate_fn(origRect),
                                            turned.left, turned.top, visitFn)

#------------------------------------------------------------------------------
class _CheckHexagonBackwardDown(_CheckHexagonUsingRotation):
    '''Diagonal sweep for deltaX < 0, deltaY > 0.

    rotate_ccw_270 maps (x, y) to (y, -x), turning this motion into the
    positive/positive case handled by _CheckHexagonForwardDown.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        self.rotate_fn, self.reverse_rotate_fn = rotate_ccw_270, rotate_ccw_90
        # rotate the displacement as a zero-size rect at (deltaX, deltaY)
        turned = self.rotate_fn(Rect(deltaX, deltaY, 0, 0))
        _CheckHexagonUsingRotation.__init__(self, self.rotate_fn(origRect),
                                            turned.left, turned.top, visitFn)


#------------------------------------------------------------------------------
class _CheckHexagonUsingBranching(_AbstractCheckHexagon):
    '''Sweep-hexagon collision check for any diagonal direction.

    Instead of rotating into a canonical frame, this class uses the signs
    of deltaX and deltaY to pick four things:
    - which corners trace the two sweep borders;
    - which corner leads the motion;
    - which corner of a collidable faces it;
    - which comparison means "nearer".
    Attributes such as leadingPoint, xToCheck and topPointToCheck hold
    pygame Rect attribute names and are read back with getattr().

    NOTE: Line.above/below compare raw y values; assuming pygame's y-down
    screen coordinates, "above" means visually lower.
    '''
    def __init__(self, origRect, deltaX, deltaY, visitFn):
        _AbstractCheckHexagon.__init__(self, origRect, deltaX, deltaY, visitFn)

        forwardX = deltaX >= 0
        downY = deltaY >= 0

        # The sweep borders are traced by the two corners that are neither
        # leading nor trailing the motion.
        if forwardX == downY:
            self.topPointToCheck = 'topright'
            self.bottomPointToCheck = 'bottomleft'
        else:
            self.topPointToCheck = 'topleft'
            self.bottomPointToCheck = 'bottomright'
        self.topHexBorder = self._corner_path(self.topPointToCheck)
        self.bottomHexBorder = self._corner_path(self.bottomPointToCheck)

        # The nearer limiter has the smaller near edge when moving toward +x
        # or +y, and the larger one when moving toward -x or -y.
        if forwardX:
            self.betterX, self.xToCheck = operator.lt, 'left'
        else:
            self.betterX, self.xToCheck = operator.gt, 'right'
        if downY:
            self.betterY, self.yToCheck = operator.lt, 'top'
        else:
            self.betterY, self.yToCheck = operator.gt, 'bottom'

        # The leading corner points along the motion; a collidable is judged
        # by its diagonally opposite corner, the one facing the mover.
        self.leadingPoint = (('bottom' if downY else 'top')
                             + ('right' if forwardX else 'left'))
        self.opposingPointToCheck = (('top' if downY else 'bottom')
                                     + ('left' if forwardX else 'right'))
        self.leadingLine = self._corner_path(self.leadingPoint)

    def _corner_path(self, corner):
        '''Line traced by the named corner from origRect to movedRect.'''
        return Line(getattr(self.origRect, corner),
                    getattr(self.movedRect, corner))

    def post_deflection_retreat(self):
        '''Re-check set-aside collidables after movedRect was clamped.

        Compares the final leading corner with leadingLine.  It then runs a
        straight sub-sweep over self.outsideHexInsideBig, from the matching
        point on leadingLine to the current movedRect.  A limiter found by
        that sweep overrides the corresponding one found so far.
        '''
        leadingPoint = getattr(self.movedRect, self.leadingPoint)
        farX, farY = leadingPoint

        if self.leadingLine.above(leadingPoint):
            sweepVertically = self.deltaY > 0
        elif self.leadingLine.below(leadingPoint):
            sweepVertically = not (self.deltaY > 0)
        else:
            # the leading corner is still on its path: nothing to retreat
            return

        # The signed distance from the line to the corner already points in
        # the direction of the sub-sweep, so it must NOT be negated (doing so
        # for the -x / -y cases swept the wrong side of movedRect).
        if sweepVertically:
            subStrat = CheckDown if self.deltaY > 0 else CheckUp
            dX, dY = 0, farY - self.leadingLine.y_at_x(farX)
        else:
            subStrat = CheckForward if self.deltaX > 0 else CheckBackward
            dX, dY = farX - self.leadingLine.x_at_y(farY), 0

        # shift to an 'original' position on the leading line
        movingRect = self.movedRect.move(-dX, -dY)
        s = subStrat(movingRect, dX, dY, self.visitFn)
        s.run(self.outsideHexInsideBig)
        self.movedRect = s.movedRect
        if subStrat in (CheckUp, CheckDown) and s.yLimiter:
            self.yLimiter = s.yLimiter
        elif s.xLimiter:
            self.xLimiter = s.xLimiter

    def run_step(self, subject, rect):
        '''Classify one collidable rect; return (limitsX, limitsY).

        Rects that touch bigRect but neither movedRect nor the hexagon are
        stored in self.outsideHexInsideBig for post_deflection_retreat().
        '''
        if not rect.colliderect(self.bigRect):
            return False, False
        if not ( rect.colliderect(self.movedRect)
                 or self.corner_in_hexagon(rect) ):
            self.outsideHexInsideBig.append((subject, rect))
            return False, False
        # we've filtered out all the rects that don't collide with the sweep
        # hexagon, so now we just need to see if this is the x or y limiter
        limitsX, limitsY = False, False
        opposingPoint = getattr(rect, self.opposingPointToCheck)
        if self.leadingLine.below(opposingPoint):
            if self.deltaY >= 0:
                limitsX = True
                self.check_and_set_maxXLimiter(opposingPoint, subject, rect)
            else:
                limitsY = True
                self.check_and_set_maxYLimiter(opposingPoint, subject, rect)
        elif self.leadingLine.above(opposingPoint):
            if self.deltaY >= 0:
                limitsY = True
                self.check_and_set_maxYLimiter(opposingPoint, subject, rect)
            else:
                limitsX = True
                self.check_and_set_maxXLimiter(opposingPoint, subject, rect)
        else:
            # opposingPoint intersects leadingLine
            limitsX, limitsY = True, True
            self.check_and_set_maxXLimiter(opposingPoint, subject, rect)
            self.check_and_set_maxYLimiter(opposingPoint, subject, rect)
        return limitsX, limitsY

    def corner_in_hexagon(self, rect):
        '''True if rect reaches into the strip between the sweep borders.'''
        topPoint = getattr(rect, self.topPointToCheck)
        bottomPoint = getattr(rect, self.bottomPointToCheck)
        return (self.bottomHexBorder.below(topPoint)
                and self.topHexBorder.above(bottomPoint) )

    def set_maxXLimiter(self, subject, rect):
        '''Make rect the X limiter and clamp movedRect against it.'''
        self.maxXLimiterRect = rect
        self.xLimiter = subject
        if self.deltaX >= 0:
            self.movedRect.right = self.maxXLimiterRect.left
        else:
            self.movedRect.left = self.maxXLimiterRect.right

    def set_maxYLimiter(self, subject, rect):
        '''Make rect the Y limiter and clamp movedRect against it.'''
        self.maxYLimiterRect = rect
        self.yLimiter = subject
        if self.deltaY >= 0:
            self.movedRect.bottom = self.maxYLimiterRect.top
        else:
            self.movedRect.top = self.maxYLimiterRect.bottom

    def check_and_set_maxXLimiter(self, opposingPoint, subject, rect):
        '''Adopt rect as X limiter if none is set or it is strictly nearer.'''
        if not self.maxXLimiterRect:
            self.set_maxXLimiter(subject, rect)
            return

        bestX = getattr(self.maxXLimiterRect, self.xToCheck)
        if self.betterX(opposingPoint[0], bestX):
            self.set_maxXLimiter(subject, rect)

    def check_and_set_maxYLimiter(self, opposingPoint, subject, rect):
        '''Adopt rect as Y limiter if none is set or it is strictly nearer.'''
        if not self.maxYLimiterRect:
            self.set_maxYLimiter(subject, rect)
            return

        bestY = getattr(self.maxYLimiterRect, self.yToCheck)
        if self.betterY(opposingPoint[1], bestY):
            self.set_maxYLimiter(subject, rect)

#------------------------------------------------------------------------------
class FastAtomic(Strategy):
    '''The simplest collision strategy: check whether any collidable
    intersects the Rect moved by deltaX,deltaY.

    On the first collision the rect is put back at its original position,
    the colliding collidable is reported as both the x and the y limiter,
    and no further collidables are examined.
    This is not "correct", but may be sufficient.
    '''
    def run_sans_visit(self, iterable):
        for subject, rect in iterable:
            limitsX, limitsY = self.run_step(subject, rect)
            if limitsX or limitsY:
                # the move is cancelled; nothing else can change the outcome
                return

    def run_with_visit(self, iterable):
        for subject, rect in iterable:
            limitsX, limitsY = self.run_step(subject, rect)
            self.visitFn(subject, limitsX, limitsY)
            if limitsX or limitsY:
                return

    def run_step(self, subject, rect):
        '''Return (hit, hit) where hit is whether movedRect overlaps rect.'''
        hit = self.movedRect.colliderect(rect)
        if hit:
            self.xLimiter = self.yLimiter = subject
            # copy so the result is never the caller's own Rect object
            self.movedRect = self.origRect.copy()
        return hit, hit

#------------------------------------------------------------------------------
def CheckHexagonUsingBranching(movingRect, deltaX, deltaY, visitFn):
    '''Factory: build the strategy suited to the displacement.

    Axis-aligned motion gets a CheckOneDirection subclass.  That is
    CheckDown/CheckUp when deltaX == 0, otherwise CheckForward/CheckBackward
    when deltaY == 0.  Diagonal motion gets a _CheckHexagonUsingBranching.
    '''
    if deltaX == 0:
        chosen = CheckDown if deltaY >= 0 else CheckUp
    elif deltaY == 0:
        chosen = CheckForward if deltaX >= 0 else CheckBackward
    else:
        chosen = _CheckHexagonUsingBranching
    return chosen(movingRect, deltaX, deltaY, visitFn)

#------------------------------------------------------------------------------
def CheckHexagonUsingRotation(movingRect, deltaX, deltaY, visitFn):
    '''Factory: build the strategy suited to the displacement.

    Axis-aligned motion gets a CheckOneDirection subclass.  Diagonal motion
    gets the _CheckHexagon* class for its quadrant; each of those reuses the
    forward-down sweep through rotation.  If none of the sign tests match
    (e.g. a NaN delta), None is returned.
    '''
    args = (movingRect, deltaX, deltaY, visitFn)
    if deltaX == 0:
        return (CheckDown if deltaY >= 0 else CheckUp)(*args)
    if deltaY == 0:
        return (CheckForward if deltaX >= 0 else CheckBackward)(*args)
    if deltaX > 0:
        if deltaY > 0:
            return _CheckHexagonForwardDown(*args)
        if deltaY < 0:
            return _CheckHexagonForwardUp(*args)
    elif deltaX < 0:
        if deltaY > 0:
            return _CheckHexagonBackwardDown(*args)
        if deltaY < 0:
            return _CheckHexagonBackwardUp(*args)
    return None

#------------------------------------------------------------------------------
def resolve_collisions(movingRect, deltaX, deltaY, 
                       collidables, strategy=CheckHexagonUsingRotation,
                       visitFn=None):
    '''Takes a Rect to be moved by deltaX,deltaY.

    Returns a 3-tuple: the moved Rect, the collidable that limited the x
    movement, and the one that limited the y movement.  If a limiting
    collidable was not found, None is returned in its place.
    If it hits collidables, the returned Rect will be adjusted to a
    position such that it doesn't share any space with a collidable.

    collidables: any iterable (list, sprite group, generator, ...) of
        objects with a `.rect` attribute, or of Rects themselves.  Only the
        first item is inspected to decide which.
    strategy: called as strategy(movingRect, deltaX, deltaY, visitFn); the
        result must offer run(), movedRect, xLimiter and yLimiter.
    visitFn: optional callback forwarded to the strategy.
    '''
    import itertools

    movedRect = movingRect.move(deltaX,deltaY)

    # do nothing if not actually moving
    if 0 == deltaX == deltaY:
        return movedRect, None, None

    # Peek at the first collidable without losing it: one-shot iterators
    # would otherwise drop it, and generators are always truthy.
    remaining = iter(collidables)
    try:
        first = next(remaining)
    except StopIteration:
        # do nothing if collidables is empty
        return movedRect, None, None
    allCollidables = itertools.chain([first], remaining)

    # if the collection has .rect members use them, otherwise treat the
    # items as if they were themselves Rects
    if hasattr(first, 'rect'):
        rectGenerator = ((sprite, sprite.rect) for sprite in allCollidables)
    else:
        rectGenerator = ((rect, rect) for rect in allCollidables)

    s = strategy(movingRect, deltaX, deltaY, visitFn)
    s.run(rectGenerator)
    return s.movedRect, s.xLimiter, s.yLimiter

