#!/usr/bin/env python
#####################################################################
# $ProjectHeader: munchy 0.7.1 Tue, 25 Apr 2000 21:15:56 -0600 nas $
# Neil's Ad Munching HTTP Proxy Server
#
# Usage: proxy.py [port]
#
# This code has been placed in the public domain.
# Neil Schemenauer <nascheme@enme.ucalgary.ca>
#####################################################################

# default port if none specified on command line
PORT = 8000

# html munching
BLOCK_ADS = 1       # remove blocked <a>, <img> and <layer> elements from HTML
FIX_MS_CHARS = 1    # replace Windows-1252 "smart" punctuation with ASCII
BLOCK_COOKIES = 1   # strip Set-Cookie headers from HTML responses

# debugging level, 0 = no debugging
DEBUG_LEVEL = 1
BLOCK_ALL = 0 # block all URLs (for testing HTML parsing)

# Patterns of sites and paths to block, as (host regex, path regex) pairs.
# A URL is blocked when both regexes match somewhere (searched
# case-insensitively); an empty regex matches everything.  Raw strings keep
# the backslashes intact, and literal dots are escaped so they do not match
# arbitrary characters.
BLOCK_PATTERNS = (
    # we don't like domains that serve ads
    (r'^ads?[0-9]*\..*\.(com|net)$', r'[0-9]'),
    (r'adse?rv.*\.(com|net)$',  r'[0-9]'),
    (r'doubleclick\.net$', r''),
    (r'^rd\.yahoo\.com$', r''),
    (r'(mediaplex|realmedia|imgis)\.com$', r''),
    # Serving an ad by cgi? I don't think so.
    (r'', r'(/ads?/|cgi-bin/.*adlog).*([=&?]|\.gif)'),
    # this is for slashdot
    (r'', r'^/redir\.pl[?]'),
    # ads usually come in gif format
    (r'\.com$', r'/(ads?|banners)/.*\.gif$'),
    # is this a trend? no hostname, no data from you
    (r'^[0-9.:]+$', r'[=&?]|\.gif$'),
    # kill cgi things with bad words in them
    (r'', r'(ad.*click|click.*thr|click.*ad).*[=&?]'),
    (r'', r'(advert|banner|adid|track|profileid).*[=&?]'),
    # javascript: URLs bite too, but they cannot be matched here: the path
    # regex never sees the URL scheme, and the old '*javascript:*' entry was
    # a glob, not a regex (re.compile rejects it).
)


#####################################################################
# End of user configuration
#####################################################################

import sys
import os
import time
import string
import socket
import urlparse
import re
import fnmatch
import SocketServer


def log(s, level=1):
    """Write debug message *s* to stderr.

    The message is emitted only when *level* does not exceed DEBUG_LEVEL,
    so larger levels mean more verbose (less important) output.  No newline
    is appended; callers include their own.
    """
    if level > DEBUG_LEVEL:
        return
    sys.stderr.write(s)


class Blocker:
    """Decide whether a URL should be blocked.

    *patterns* is a sequence of (host_regex, path_regex) string pairs, both
    compiled case-insensitively.  A URL is blocked when, for some pair,
    host_regex is found in the URL's host name and path_regex is found in
    the rest of the URL (path, params, query and fragment).  An empty regex
    matches anything.
    """

    def __init__(self, patterns):
        self.patterns = [] # [ (compiled host re, compiled path re) ... ]
        for host, path in patterns:
            self.patterns.append((re.compile(host, re.I),
                                  re.compile(path, re.I)))

    def split_url(self, url):
        """Return (host, path) for *url*.

        host is the network location with any "user:password@" prefix and
        ":port" suffix removed, so host patterns anchored with '$' still
        match when a URL names an explicit port.  path is everything after
        the network location.  Relative URLs give an empty host.
        """
        parts = urlparse.urlparse(url)
        site = parts[1]
        site = site[site.rfind('@') + 1:]
        colon = site.rfind(':')
        # a ':' inside "[...]" belongs to an IPv6 literal, not a port
        if colon != -1 and site.rfind(']') < colon:
            site = site[:colon]
        path = urlparse.urlunparse(('', '', parts[2], parts[3],
                parts[4], parts[5]))
        return site, path

    def blocked(self, url):
        """Return 1 if *url* matches any pattern pair, else 0.

        An empty or None url is never blocked.
        """
        if not url:
            return 0
        site, path = self.split_url(url)
        log("testing %s %s for block\n" % (site, path), 3)
        for site_re, path_re in self.patterns:
            if site_re.search(site) and path_re.search(path):
                log('BLOCKED %s "%s" "%s"\n' % (url,
                        site_re.pattern, path_re.pattern), 1)
                return 1
        return 0


class MicrosoftCharacterFilter:
    """Replace Windows-1252 "smart" punctuation with plain ASCII.

    Bytes 0x91-0x94 (curly quotes) and 0x96/0x97 (en/em dash) are
    punctuation in Windows-1252 but C1 control characters in ISO-8859-1, so
    browsers that do not assume Windows-1252 may display them incorrectly.
    Each is mapped one-for-one to an ASCII stand-in; all other characters
    pass through unchanged.
    """

    # cp1252 byte:  0x91 0x92 0x93 0x94 0x96 0x97
    # replacement:   `    '    "    "    -    -
    _FROM = '\221\222\223\224\226\227'
    _TO = '`\'""--'
    try:
        TRANS = string.maketrans(_FROM, _TO)
    except AttributeError:
        # Python 3 removed string.maketrans; str.maketrans builds the table
        TRANS = str.maketrans(_FROM, _TO)

    def filter(self, s):
        """Return *s* with the characters listed above replaced."""
        return s.translate(self.TRANS)


class HtmlUrlFilter:
    """Streaming filter that drops <a>, <img> and <layer> elements whose URL
    attribute is blocked.

    The document may be fed to filter() in arbitrary chunks.  Text that
    cannot be decided yet -- a trailing fragment that might still become a
    tag start, an opening tag cut off mid-way, or a blocked element whose
    closing tag has not arrived -- is kept in self.unparsed and re-examined
    together with the next chunk.  finish() returns whatever is still held
    back once the input is exhausted (that text is not filtered further).
    """

    tag_start_re = re.compile(r"""
        <                           # open tag
        (?P<name> a|img|layer)      # tag name
        \s+                         # whitespace
        """, re.IGNORECASE | re.VERBOSE)

    attr_re = re.compile(r"""
        \s*                     # whitespace
        (?P<name>[^\s"'>/=]+)   # attribute name
        (\s*=\s*                # equals sign
          (?P<value>
            ".*?" |             # double quoted value
            '.*?' |             # single quoted value
            [^>\s]+)            # unquoted value
        )?                      # value optional
        """, re.IGNORECASE | re.VERBOSE)

    TAG_ATTR = 0
    TAG_CLOSE = 1
    tags = { # name: (URL attribute, closing-tag regex or None)
            "a":  ("href", re.compile(r"(?i)</a\s*>")),
            "layer": ("src", re.compile(r"(?i)</layer\s*>")),
            "img": ("src",  None),
           }

    def __init__(self):
        self.unparsed = ''

    def finish(self):
        """Return the text still held back; call once input is exhausted."""
        return self.unparsed

    def find_close(self, name, s, i):
        """Return the index just past the end of the <name> element whose
        opening tag ends just before s[i], or -1 if its closing tag is not in
        s yet.  Elements without a closing tag (img) end at i itself."""
        close_re = self.tags[name][self.TAG_CLOSE]
        if close_re is None:
            return i
        m = close_re.search(s, i)
        if m:
            return m.end()
        return -1

    def parse_attr(self, s, i):
        """Parse the attributes of a tag starting at s[i].

        Returns (attrs, j).  attrs maps lower-cased attribute names to their
        values with surrounding quotes stripped (None for an attribute given
        without a value).  j is one past the position where parsing stopped,
        i.e. just past the '>' for a well-formed tag.
        """
        attr = {}
        while 1:
            m = self.attr_re.match(s, i)
            if not m:
                break
            value = m.group('value')
            if value and value[0] in '"\'':
                value = value[1:-1]
            attr[m.group('name').lower()] = value
            i = m.end()
        return attr, i + 1

    def tag_blocked(self, name, attr):
        """Return 1 if the tag's URL attribute is blocked (or BLOCK_ALL is
        set), else 0."""
        url = attr.get(self.tags[name][self.TAG_ATTR])
        if SiteBlocker.blocked(url) or BLOCK_ALL:
            return 1
        return 0

    def _held_back(self, s, i):
        """Return where a trailing fragment of s[i:] that could still grow
        into a tag start (such as '<' or '<im') begins, or len(s) if none."""
        j = s.rfind('<', i)
        if j != -1:
            tail = s[j:].lower()
            for name in self.tags:
                if ('<' + name).startswith(tail):
                    return j
        return len(s)

    def filter(self, s):
        """Filter the next chunk *s* and return the text that is now safe to
        emit; undecided text is kept in self.unparsed."""
        if self.unparsed:
            log('begin, unparsed = %s\n' % repr(self.unparsed), 3)
            s = self.unparsed + s
            self.unparsed = ''
        filtered = []
        i = 0
        while 1:
            # find start of tag
            m = self.tag_start_re.search(s, i)
            if not m:
                # Nothing left to filter; hold back only a trailing fragment
                # that the next chunk could complete into a tag start.
                keep = self._held_back(s, i)
                filtered.append(s[:keep])
                self.unparsed = s[keep:]
                break
            name = m.group('name').lower()
            start = m.start()
            # parse contents of tag
            attr, i = self.parse_attr(s, m.end())
            log('tag attr %s = %s\n' % (name, repr(attr)), 3)
            # parse_attr normally stops on the tag's '>'; searching from there
            # also tolerates stray characters such as the '/' in "<img ... />".
            gt = s.find('>', i - 1)
            if gt == -1:
                end = -1            # opening tag not complete yet
            elif self.tag_blocked(name, attr):
                end = self.find_close(name, s, gt + 1)
            else:
                # Allowed: keep the tag and continue right after it, so tags
                # nested inside this element are examined too.
                i = gt + 1
                continue
            if end == -1:
                filtered.append(s[:start])
                self.unparsed = s[start:]
                log('no close, unparsed = %s\n' % repr(self.unparsed[-40:]), 3)
                break
            log('removing %s\n' % repr(s[start:end]), 3)
            filtered.append(s[:start])
            s = s[end:]
            i = 0
        return ''.join(filtered)


def munch_html(input, output):
    """Copy an HTML body from *input* to *output*, 512 bytes at a time.

    Each chunk has Windows punctuation fixed (when FIX_MS_CHARS is set) and
    blocked elements removed (when BLOCK_ADS is set) before it is written.
    At end of input, whatever the ad filter is still holding back is
    written as well.
    """
    ms_chars = MicrosoftCharacterFilter()
    ads = HtmlUrlFilter()
    chunk = input.read(512)
    while chunk:
        log('munch data = %s' % chunk, 4)
        if FIX_MS_CHARS:
            chunk = ms_chars.filter(chunk)
        if BLOCK_ADS:
            chunk = ads.filter(chunk)
        output.write(chunk)
        chunk = input.read(512)
    output.write(ads.finish())



def try_del(dict, key):
    """Remove *key* from mapping *dict* if present; a missing key is ignored."""
    dict.pop(key, None)


class ProxyHandler(SocketServer.StreamRequestHandler):
    """Handle one proxied HTTP request per connection.

    The browser's request is rewritten as an HTTP/1.0 request and forwarded
    to the origin server; the response is relayed back, with HTML bodies
    passed through munch_html().  Errors are reported to the browser by
    error(), which ends the handler by raising SystemExit.
    """

    def handle(self):
        """handle one request from the browser"""
        host, port, request = self.read_request()
        srfile, swfile = self.connect(host, port)
        try:
            log('sending request to server "%s"\n' % repr(request), 2)
            self.send_request(swfile, request, host)
            try:
                self.handle_response(srfile)
            except IOError:
                pass # browser closed connection?
        finally:
            # release the connection to the origin server
            srfile.close()
            swfile.close()
        log('finished request\n', 2)

    def read_request(self):
        """Read the browser's request and rewrite it for the origin server.

        Returns (host, port, request), where request is the complete
        HTTP/1.0 request text (request line, headers, blank line, body).
        Malformed or unsupported requests are reported via error().
        """
        request = self.rfile.readline()
        sys.stdout.write('%s - %s - %s' % (
                                self.client_address[0],
                                time.ctime(time.time()),
                                request))
        try:
            method, url, protocol = request.split()
        except ValueError:
            self.error(400, "Can't parse request")
        if method not in ('GET', 'HEAD', 'POST'):
            self.error(501, "Unknown request method (%s)" % method)
        # split url into site and path
        scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
        if scheme.lower() != 'http':
            self.error(501, "Unknown request scheme (%s)" % scheme)
        try:
            host, port = self.split_netloc(netloc)
        except ValueError:
            self.error(400, "Bad host or port (%s)" % netloc)
        # an absolute URL without a path ("http://host") still needs "/"
        path = urlparse.urlunparse(('', '', path, params, query, fragment)) or '/'
        # read headers
        headers = self.read_headers(self.rfile)
        if method == 'POST' and 'content-length' not in headers:
            self.error(400, "Missing Content-Length for POST method")
        try:
            length = int(headers.get('content-length', 0))
            if length < 0:
                raise ValueError(length)
        except ValueError:
            self.error(400, "Bad Content-Length")
        # read content if any
        content = self.rfile.read(length)
        log('content = %s\n' % repr(content), 2)
        # Accept-Encoding is dropped so HTML arrives uncompressed for
        # munch_html.  Connection, Keep-Alive and Proxy-Connection are
        # hop-by-hop headers and must not be forwarded; without them the
        # server closes the connection after the response, which is how
        # the relay loops below detect its end.
        try_del(headers, 'accept-encoding')
        try_del(headers, 'proxy-connection')
        try_del(headers, 'connection')
        try_del(headers, 'keep-alive')
        request = '%s %s HTTP/1.0\r\n%s\r\n%s' % (method, path,
                                                  self.join_headers(headers),
                                                  content)
        return host, port, request

    def split_netloc(self, netloc, default_port=80):
        """Split a URL network location into (host, port).

        Any "user:password@" prefix is dropped; a missing or empty ":port"
        gives *default_port*.  Raises ValueError for an empty host or a
        port that is not a number in 1..65535.
        """
        host = netloc[netloc.rfind('@') + 1:]
        port = default_port
        colon = host.rfind(':')
        # a ':' inside "[...]" belongs to an IPv6 literal, not a port
        if colon != -1 and host.rfind(']') < colon:
            digits = host[colon + 1:]
            host = host[:colon]
            if digits:
                port = int(digits)
        if not host or not 0 < port < 65536:
            raise ValueError('bad host or port: %r' % netloc)
        return host, port

    def connect(self, host, port):
        """Open a TCP connection to host:port and return (read file, write
        file) for it.  Failures are reported as 502 Bad Gateway."""
        server = None
        try:
            addr = socket.gethostbyname(host)
            server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server.connect((addr, port))
        except socket.error as err:
            if server is not None:
                server.close()
            self.error(502, 'Error connecting to "%s" (%s)' % (host, err))
        return server.makefile('rb'), server.makefile('wb')

    def read_headers(self, input):
        """Read an HTTP header block from file *input*.

        Returns a dict mapping lower-cased header names to values.  Repeated
        headers are joined with ', '; folded continuation lines are kept as
        '\\r\\n ' + text.  Reading stops at the blank line ending the block or
        at EOF.  Lines without a ':' (and their continuations) are ignored.
        """
        # NOTE(review): joining repeated Set-Cookie headers with ', ' can
        # produce a value browsers misparse; consider keeping them separate.
        headers = {}
        name = None
        while 1:
            line = input.readline()
            if not line or line in ('\r\n', '\n'):
                break
            if line[0] in ' \t':
                # continued header
                if name is not None:
                    headers[name] = headers[name] + '\r\n ' + line.strip()
                continue
            i = line.find(':')
            if i == -1:
                name = None     # malformed line: skip it
                continue
            name = line[:i].strip().lower()
            value = line[i+1:].strip()
            if name in headers:
                # merge value
                headers[name] = headers[name] + ', ' + value
            else:
                headers[name] = value
        return headers

    def join_headers(self, headers):
        """Format a header dict as 'name: value\\r\\n' lines."""
        return ''.join(['%s: %s\r\n' % item for item in headers.items()])

    def send_request(self, server, request, host=None):
        """Write *request* to the server's write file and flush it.

        *host* is used only in the error message.
        """
        try:
            server.write(request)
            server.flush()
        except socket.error as err:
            self.error(502, 'Error sending data to "%s" (%s)' % (host, err))

    def handle_response(self, server):
        """Relay the response read from file *server* to the browser.

        The status line is rewritten as HTTP/1.0.  For text/html responses
        Content-Length is dropped (munching changes the body length) and,
        when BLOCK_COOKIES is set, so is Set-Cookie; the body then goes
        through munch_html().  Other bodies are copied unchanged until EOF.
        """
        log('reading server response\n', 2)
        response = server.readline()
        log('response = %s\n' % response, 2)
        fields = response.split()
        if len(fields) < 2:
            self.error(502, 'Bad response from server (%s)' % repr(response))
        status = fields[1]
        comment = ' '.join(fields[2:])
        self.wfile.write('HTTP/1.0 %s %s\r\n' % (status, comment))
        log('reading response headers\n', 2)
        headers = self.read_headers(server)
        ctype = headers.get('content-type', 'unknown')
        is_html = ctype[:9].lower() == 'text/html' # match 'text/html; blah blah'
        if is_html:
            ctype = 'text/html'
            try_del(headers, 'content-length')
            if BLOCK_COOKIES:
                try_del(headers, 'set-cookie')
        # hop-by-hop headers describe the server connection, not ours
        try_del(headers, 'connection')
        try_del(headers, 'keep-alive')
        log('writing headers to client\n', 3)
        self.wfile.write(self.join_headers(headers))
        self.wfile.write('\r\n')
        log('encoding %s\n' % ctype, 2)
        if is_html:
            log('munching html\n', 2)
            munch_html(server, self.wfile)
        else:
            # read by blocks
            log('transfering raw data\n', 2)
            while 1:
                data = server.read(4096)
                log('data = %s\n' % repr(data), 5)
                if not data:
                    break
                self.wfile.write(data)
        self.wfile.flush()

    def finish(self):
        """Discard client input that is already waiting, then close streams.

        On many TCP stacks, closing a socket with unread input sends a reset
        that can cost the browser the tail of the response.  The original
        code wrote those leftover bytes back to the browser, appending them
        to the response body; they are now read and dropped instead.
        """
        import select
        try:
            if not self.rfile.closed:
                self.connection.setblocking(0)
                r, w, e = select.select([self.rfile], [], [], 0)
                if r:
                    self.rfile.read()
        except (IOError, ValueError, select.error, socket.error):
            pass # best effort only
        try:
            SocketServer.StreamRequestHandler.finish(self)
        except (IOError, socket.error):
            pass

    def error(self, code, body):
        """Send an HTML error page with status *code*, close both streams and
        end the handler by raising SystemExit.

        *body* is plain text that may contain client-supplied data (request
        method, host, ...), so it is HTML-escaped before being embedded.
        """
        import BaseHTTPServer
        response = BaseHTTPServer.BaseHTTPRequestHandler.responses[code][0]
        body = body.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
        self.wfile.write("HTTP/1.0 %s %s\r\n" % (code, response))
        self.wfile.write("Server: Neil's Proxy\r\n")
        self.wfile.write("Content-type: text/html\r\n")
        self.wfile.write("\r\n")
        self.wfile.write('<html><head>\n<title>%d %s</title>\n</head>\n'
                '<body>\n%s\n</body>\n</html>' % (code, response, body))
        self.wfile.flush()
        self.wfile.close()
        self.rfile.close()
        raise SystemExit


# Serve each connection in its own thread when the interpreter has thread
# support, otherwise in a forked child process.  (For single-request
# debugging, SocketServer.TCPServer can be substituted here.)
try:
    import thread
except ImportError:
    ServerBase = SocketServer.ForkingTCPServer
else:
    ServerBase = SocketServer.ThreadingTCPServer

class ProxyServer(ServerBase):
    """The proxy server: listens on all interfaces at *port* and hands each
    connection to ProxyHandler.  Request dispatch is inherited unchanged
    from ServerBase."""

    def __init__(self, port):
        log('Starting proxy on port %d\n' % port, 1)
        address = ('', port)
        ServerBase.__init__(self, address, ProxyHandler)


# module-wide blocker consulted by HtmlUrlFilter.tag_blocked
SiteBlocker = Blocker(BLOCK_PATTERNS)
if __name__ == '__main__':
    # an optional first command-line argument overrides the default port
    if sys.argv[1:]:
        port = int(sys.argv[1])
    else:
        port = PORT
    ProxyServer(port).serve_forever()
