# Source code for Bcfg2.Server.SSLServer

""" Bcfg2 SSL server used by the builtin server core
(:mod:`Bcfg2.Server.BuiltinCore`).  This needs to be documented
better. """

import os
import sys
import socket
import signal
import logging
import ssl
import threading
import time
from Bcfg2.Compat import xmlrpclib, SimpleXMLRPCServer, SocketServer, \
    b64decode


class XMLRPCACLCheckException(Exception):
    """ Signals that an RPC request was rejected by the ACL check. """
class XMLRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
    """ An XML-RPC dispatcher.

    Every call is first checked against the registered instance's
    ``check_acls()``; methods whose name contains no ``.`` additionally
    receive the client address as their first argument. """

    def __init__(self, allow_none, encoding):
        """
        :param allow_none: Allow None values in XML-RPC
        :type allow_none: bool
        :param encoding: Encoding to use for XML-RPC
        """
        try:
            SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self,
                                                               allow_none,
                                                               encoding)
        except TypeError:
            # Python 2.4?  The base constructor takes no arguments there.
            SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self)
        self.logger = logging.getLogger(self.__class__.__name__)
        self.allow_none = allow_none
        self.encoding = encoding

    @staticmethod
    def _normalize_response(response):
        """ Turn a handler result into something xmlrpclib can marshal.

        Sets become lists, and decodable non-``str`` values (``bytes`` on
        py3k, ``unicode`` on Python 2) are decoded as UTF-8.  Everything
        else -- numbers, tuples, dicts, ... -- is passed through
        untouched instead of being forced through ``.decode()``. """
        if type(response) is set:
            return list(response)
        if type(response) is not str and hasattr(response, 'decode'):
            # py3k compatibility
            return response.decode('utf-8')
        return response

    def _marshaled_dispatch(self, address, data):
        """ Dispatch one marshaled XML-RPC request.

        :param address: Address of the calling client
        :param data: The XML-RPC request document
        :returns: string - the marshaled XML-RPC response (a method
                  response on success, a fault document on failure)
        :raises: :class:`XMLRPCACLCheckException` if the instance's
                 ``check_acls()`` rejects the call
        """
        params, method = xmlrpclib.loads(data)
        if not self.instance.check_acls(address, method):
            raise XMLRPCACLCheckException
        try:
            if '.' not in method:
                params = (address, ) + params
            response = self.instance._dispatch(method, params, self.funcs)
            response = (self._normalize_response(response), )
            raw_response = xmlrpclib.dumps(response, methodresponse=True,
                                           allow_none=self.allow_none,
                                           encoding=self.encoding)
        except xmlrpclib.Fault:
            fault = sys.exc_info()[1]
            raw_response = xmlrpclib.dumps(fault, methodresponse=True,
                                           allow_none=self.allow_none,
                                           encoding=self.encoding)
        except Exception:
            err = sys.exc_info()
            self.logger.error("Unexpected handler error", exc_info=1)
            # report exception back to server
            raw_response = xmlrpclib.dumps(
                xmlrpclib.Fault(1, "%s:%s" % (err[0].__name__, err[1])),
                methodresponse=True, allow_none=self.allow_none,
                encoding=self.encoding)
        return raw_response
class SSLServer(SocketServer.TCPServer, object):
    """ TCP server supporting SSL encryption. """

    allow_reuse_address = True

    def __init__(self, listen_all, server_address, RequestHandlerClass,
                 keyfile=None, certfile=None, reqCert=False, ca=None,
                 timeout=None, protocol='xmlrpc/tlsv1'):
        """
        :param listen_all: Listen on all interfaces
        :type listen_all: bool
        :param server_address: Address to bind to the server
        :param RequestHandlerClass: Request handler used by TCP server
        :param keyfile: Full path to SSL encryption key file
        :type keyfile: string
        :param certfile: Full path to SSL certificate file
        :type certfile: string
        :param reqCert: Require client to present certificate
        :type reqCert: bool
        :param ca: Full path to SSL CA that signed the key and cert
        :type ca: string
        :param timeout: Timeout for non-blocking request handling
        :param protocol: The protocol to serve.  Supported values are
                         ``xmlrpc/ssl`` and ``xmlrpc/tlsv1``.
        :type protocol: string
        :raises: Exception - if the key, certificate or CA file is
                 missing or unreadable, or the protocol is unknown
        """
        # check whether or not we should listen on all interfaces
        if listen_all:
            listen_address = ('', server_address[1])
        else:
            listen_address = (server_address[0], server_address[1])

        # check for IPv6 address
        if ':' in server_address[0]:
            self.address_family = socket.AF_INET6

        self.logger = logging.getLogger(self.__class__.__name__)

        try:
            SocketServer.TCPServer.__init__(self, listen_address,
                                            RequestHandlerClass)
        except socket.gaierror:
            e = sys.exc_info()[1]
            self.logger.error("Failed to bind to socket: %s" % e)
            raise
        except socket.error:
            self.logger.error("Failed to bind to socket")
            raise

        try:
            self.timeout = timeout
            self.socket.settimeout(timeout)
            self.keyfile = self._require_readable(keyfile, "Keyfile")
            self.certfile = self._require_readable(certfile, "Certfile")
            self.ca = self._require_readable(ca, "CA")
            # NOTE(review): reqCert is stored but never consulted when
            # the verification mode is chosen below -- confirm whether
            # it is meant to select ssl.CERT_REQUIRED.
            self.reqCert = reqCert
            if ca and certfile:
                self.mode = ssl.CERT_OPTIONAL
            else:
                self.mode = ssl.CERT_NONE
            if protocol == 'xmlrpc/ssl':
                self.ssl_protocol = ssl.PROTOCOL_SSLv23
            elif protocol == 'xmlrpc/tlsv1':
                self.ssl_protocol = ssl.PROTOCOL_TLSv1
            else:
                self.logger.error("Unknown protocol %s" % (protocol))
                raise Exception("unknown protocol %s" % protocol)
        except Exception:
            # the constructor is about to fail, so nobody will ever call
            # server_close(); release the already-bound listening socket
            self.socket.close()
            raise

    def _require_readable(self, path, label):
        """ Return ``path`` unchanged if it is None or names an existing,
        readable file; otherwise log and raise.

        :param path: The path to check (None disables the check)
        :param label: Name used for the file in the error message
        :type label: string
        :raises: Exception
        """
        if (path is not None and
            (path is False or
             not os.path.exists(path) or
             not os.access(path, os.R_OK))):
            msg = "%s %s does not exist or is not readable" % (label, path)
            self.logger.error(msg)
            raise Exception(msg)
        return path

    def _wrap_socket(self, sock):
        """ Wrap an accepted socket for server-side SSL using the
        key, certificate, CA, verification mode and protocol configured
        on this server. """
        if not hasattr(ssl, 'SSLContext'):
            # old interpreters only offer the module-level function
            # pylint: disable=E1101
            return ssl.wrap_socket(sock, server_side=True,
                                   certfile=self.certfile,
                                   keyfile=self.keyfile,
                                   cert_reqs=self.mode,
                                   ca_certs=self.ca,
                                   ssl_version=self.ssl_protocol)
        # ssl.wrap_socket() no longer exists on current Pythons; build
        # the context the same way that function used to.
        if not self.certfile:
            raise ValueError("certfile must be specified for server-side "
                             "operations")
        context = ssl.SSLContext(self.ssl_protocol)
        context.verify_mode = self.mode
        if self.ca:
            context.load_verify_locations(self.ca)
        context.load_cert_chain(self.certfile, self.keyfile)
        return context.wrap_socket(sock, server_side=True)

    def get_request(self):
        """ Accept a connection and wrap it in SSL.

        :returns: tuple - (SSL-wrapped socket, client address info)
        """
        (sock, sockinfo) = self.socket.accept()
        try:
            sock.settimeout(self.timeout)
            sslsock = self._wrap_socket(sock)
        except Exception:
            # no request object is handed back on failure, so this is
            # the only place left to release the accepted connection
            sock.close()
            raise
        return sslsock, sockinfo

    def close_request(self, request):
        """ Shut down SSL on the request socket and close it.  Both steps
        are best-effort: the peer may already have gone away. """
        try:
            request.unwrap()
        except Exception:
            pass
        try:
            request.close()
        except Exception:
            pass

    def _get_url(self):
        """ Build the ``https://host:port`` URL of this server from the
        local hostname and the port the listening socket is bound to. """
        port = self.socket.getsockname()[1]
        hostname = socket.gethostname()
        protocol = "https"
        return "%s://%s:%i" % (protocol, hostname, port)
    url = property(_get_url)
class XMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    """ XML-RPC request handler.

    Adds support for HTTP authentication. """

    def __init__(self, *args, **kwargs):
        self.logger = logging.getLogger(self.__class__.__name__)
        SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.__init__(self, *args,
                                                               **kwargs)

    def authenticate(self):
        """ Check the HTTP Basic credentials of the current request.

        The decoded username and password are handed, together with the
        peer certificate and peer address, to the ``authenticate()``
        method of the instance registered on the server.

        :returns: False if no ``Authorization`` header was sent,
                  otherwise whatever the instance's ``authenticate()``
                  returns
        """
        # .get() rather than [] because not every header container
        # raises KeyError for a missing header
        header = self.headers.get('Authorization')
        if header is None:
            self.logger.error("No authentication data presented")
            return False
        auth_content = b64decode(header.split()[1])
        try:
            # Split on the first colon only: the password itself may
            # contain colons.
            # py3k compatibility
            try:
                username, password = auth_content.split(":", 1)
            except TypeError:
                # pylint: disable=E0602
                username, pw = auth_content.split(bytes(":",
                                                        encoding='utf-8'), 1)
                password = pw.decode('utf-8')
                # pylint: enable=E0602
        except ValueError:
            # no colon at all: treat everything as the username
            username = auth_content
            password = ""
        cert = self.request.getpeercert()
        client_address = self.request.getpeername()
        return self.server.instance.authenticate(cert, username,
                                                 password, client_address)

    def parse_request(self):
        """Extends parse_request.

        Optionally check HTTP authentication when parsing.

        :returns: bool - True if the request parsed and authenticated;
                  False otherwise (a 401 has been sent when the failure
                  was in authentication)
        """
        if not SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.parse_request(
                self):
            return False
        try:
            if not self.authenticate():
                self.logger.error("Authentication Failure")
                self.send_error(401, self.responses[401][0])
                return False
        except:  # pylint: disable=W0702
            # deliberately broad: any failure while authenticating must
            # end in a 401, never in an authenticated request
            self.logger.error("Unexpected Authentication Failure", exc_info=1)
            self.send_error(401, self.responses[401][0])
            return False
        return True

    def do_POST(self):
        """ Read the XML-RPC request body, dispatch it through the
        server and write the response back.

        ACL failures are answered with a 401 and other dispatch errors
        with an empty 500.  Errors while sending a successful response
        are logged and swallowed. """
        try:
            max_chunk_size = 10 * 1024 * 1024
            size_remaining = int(self.headers["content-length"])
            chunks = []
            while size_remaining:
                chunk = self.rfile.read(min(size_remaining, max_chunk_size))
                if not chunk:
                    break
                chunks.append(chunk)
                # Content-Length counts bytes, so count the raw chunk
                size_remaining -= len(chunk)
            # decode once at the end so a multi-byte sequence that
            # straddles two chunks cannot break decoding
            data = b''.join(chunks).decode('utf-8')
            response = self.server._marshaled_dispatch(self.client_address,
                                                       data)
            if sys.hexversion >= 0x03000000:
                response = response.encode('utf-8')
        except XMLRPCACLCheckException:
            self.send_error(401, self.responses[401][0])
            self.end_headers()
        except:  # pylint: disable=W0702
            self.logger.error("Unexpected dispatch error for %s: %s" %
                              (self.client_address, sys.exc_info()[1]))
            try:
                self.send_response(500)
                self.send_header("Content-length", "0")
                self.end_headers()
            except:
                (etype, msg) = sys.exc_info()[:2]
                self.logger.error("Error sending 500 response (%s): %s" %
                                  (etype.__name__, msg))
                raise
        else:
            # got a valid XML RPC response
            try:
                self.send_response(200)
                self.send_header("Content-type", "text/xml")
                self.send_header("Content-length", str(len(response)))
                self.end_headers()
                failcount = 0
                while True:
                    try:
                        # If we hit SSL3_WRITE_PENDING here try to resend.
                        self.wfile.write(response)
                        break
                    except ssl.SSLError:
                        e = sys.exc_info()[1]
                        if str(e).find("SSL3_WRITE_PENDING") < 0:
                            raise
                        self.logger.error("SSL3_WRITE_PENDING")
                        failcount += 1
                        if failcount < 5:
                            continue
                        raise
            except ssl.SSLError:
                # must precede socket.error: SSLError is a subclass of
                # it and would otherwise never reach this handler
                err = sys.exc_info()[1]
                self.logger.warning("SSLError handling client %s: %s" %
                                    (self.client_address[0], err))
            except socket.error:
                err = sys.exc_info()[1]
                # exceptions cannot be indexed on py3k; use .errno
                code = getattr(err, 'errno', None)
                if isinstance(err, socket.timeout):
                    self.logger.warning("Connection timed out for %s" %
                                        self.client_address[0])
                elif code == 32:  # EPIPE on Linux
                    self.logger.warning("Connection dropped from %s" %
                                        self.client_address[0])
                elif code == 104:  # ECONNRESET on Linux
                    self.logger.warning("Connection reset by peer: %s" %
                                        self.client_address[0])
                else:
                    self.logger.warning("Socket error sending response to %s: "
                                        "%s" % (self.client_address[0], err))
            except Exception:
                etype, err = sys.exc_info()[:2]
                self.logger.error("Unknown error sending response to %s: "
                                  "%s (%s)" %
                                  (self.client_address[0], err,
                                   etype.__name__))

    def finish(self):
        """ Shut down the connection, logging (rather than raising)
        socket errors that occur while doing so. """
        try:
            SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.finish(self)
        except socket.error:
            err = sys.exc_info()[1]
            self.logger.warning("Error closing connection: %s" % err)
class XMLRPCServer(SocketServer.ThreadingMixIn, SSLServer,
                   XMLRPCDispatcher, object):
    """ Component XMLRPCServer. """

    def __init__(self, listen_all, server_address, RequestHandlerClass=None,
                 keyfile=None, certfile=None, ca=None, protocol='xmlrpc/tlsv1',
                 timeout=10, logRequests=False,
                 register=True, allow_none=True, encoding=None):
        """
        :param listen_all: Listen on all interfaces
        :type listen_all: bool
        :param server_address: Address to bind to the server
        :param RequestHandlerClass: request handler used by TCP server
        :param keyfile: Full path to SSL encryption key file
        :type keyfile: string
        :param certfile: Full path to SSL certificate file
        :type certfile: string
        :param ca: Full path to SSL CA that signed the key and cert
        :type ca: string
        :param protocol: The protocol to serve, passed through to
                         :class:`SSLServer`
        :type protocol: string
        :param timeout: Socket timeout passed through to
                        :class:`SSLServer`; also the number of seconds
                        the task thread sleeps between runs
        :param logRequests: Log all requests
        :type logRequests: bool
        :param register: Presence should be reported to service-location
        :type register: bool
        :param allow_none: Allow None values in XML-RPC
        :type allow_none: bool
        :param encoding: Encoding to use for XML-RPC
        """

        XMLRPCDispatcher.__init__(self, allow_none, encoding)

        if not RequestHandlerClass:
            # pylint: disable=E0102
            class RequestHandlerClass(XMLRPCRequestHandler):
                """A subclassed request handler to prevent
                class-attribute conflicts."""
            # pylint: enable=E0102

        SSLServer.__init__(self,
                           listen_all,
                           server_address,
                           RequestHandlerClass,
                           ca=ca,
                           timeout=timeout,
                           keyfile=keyfile,
                           certfile=certfile,
                           protocol=protocol)
        self.logRequests = logRequests
        self.serve = False
        self.register = register
        self.register_introspection_functions()
        self.register_function(self.ping)
        self.logger.info("service available at %s" % self.url)
        self.timeout = timeout

    def _tasks_thread(self):
        """ Body of the background thread: while the server is serving,
        call the registered instance's ``do_tasks()`` (if it has one)
        and then sleep for ``self.timeout`` seconds.  Failures of a
        single run are logged and do not stop the loop. """
        try:
            while self.serve:
                try:
                    if self.instance and hasattr(self.instance, 'do_tasks'):
                        self.instance.do_tasks()
                except Exception:
                    self.logger.error("Unexpected task failure", exc_info=1)
                time.sleep(self.timeout)
        except Exception:
            self.logger.error("tasks_thread failed", exc_info=1)

    def server_close(self):
        """ Close the server socket and log that this happened. """
        SSLServer.server_close(self)
        self.logger.info("server_close()")

    def _get_require_auth(self):
        """ ``require_auth`` as set on the request handler class
        (False if it was never set). """
        return getattr(self.RequestHandlerClass, "require_auth", False)

    def _set_require_auth(self, value):
        """ Store ``require_auth`` on the request handler class. """
        self.RequestHandlerClass.require_auth = value
    require_auth = property(_get_require_auth, _set_require_auth)

    def _get_credentials(self):
        """ ``credentials`` as set on the request handler class (an
        empty dict if they were never set). """
        try:
            return self.RequestHandlerClass.credentials
        except AttributeError:
            return dict()

    def _set_credentials(self, value):
        """ Store ``credentials`` on the request handler class. """
        self.RequestHandlerClass.credentials = value
    credentials = property(_get_credentials, _set_credentials)

    def register_instance(self, instance, *args, **kwargs):
        """ Register ``instance`` with the dispatcher and, if it offers
        ``_get_rmi()``, register every function that returns under its
        own name. """
        XMLRPCDispatcher.register_instance(self, instance, *args, **kwargs)
        try:
            name = instance.name
        except AttributeError:
            name = "unknown"
        if hasattr(instance, '_get_rmi'):
            for fname, func in instance._get_rmi().items():
                self.register_function(func, name=fname)
        self.logger.info("serving %s at %s" % (name, self.url))

    def serve_forever(self):
        """Serve single requests until (self.serve == False).

        Starts the task thread and installs SIGINT/SIGTERM handlers that
        request shutdown before entering the request loop. """
        self.serve = True
        self.task_thread = \
            threading.Thread(name="%sThread" % self.__class__.__name__,
                             target=self._tasks_thread)
        self.task_thread.start()
        self.logger.info("serve_forever() [start]")
        signal.signal(signal.SIGINT, self._handle_shutdown_signal)
        signal.signal(signal.SIGTERM, self._handle_shutdown_signal)

        try:
            while self.serve:
                try:
                    self.handle_request()
                except socket.timeout:
                    pass
                except Exception:
                    self.logger.error("Got unexpected error in handle_request",
                                      exc_info=1)
        finally:
            # also reached when the loop is left by an exception; make
            # sure the task thread sees the stop request in that case
            self.serve = False
            self.logger.info("serve_forever() [stop]")

    def shutdown(self):
        """Signal that automatic service should stop."""
        self.serve = False

    def _handle_shutdown_signal(self, *_):
        """ Signal handler: request shutdown. """
        self.shutdown()

    def ping(self, *args):
        """Echo response."""
        self.logger.info("ping(%s)" % (", ".join([repr(arg) for arg in args])))
        return args