Tuesday 28 July 2009

IOCP-based sockets with ctypes in Python: 6

Previous post: IOCP-based sockets with ctypes in Python: 5

The goal of this step is to implement part of a Stackless-compatible socket module, providing the basic functionality that the standard one has. It should be usable for blocking socket IO within Stackless tasklets.

The most straightforward approach is to base the code on how it was done in stacklesssocket.py, where a tasklet is launched to poll asyncore:

managerRunning = False

def ManageSockets():
    global managerRunning

    while len(asyncore.socket_map):
        # Check the sockets for activity.
        asyncore.poll(0.05)
        # Yield to give other tasklets a chance to be scheduled.
        stackless.schedule()

    managerRunning = False
There, I was looping and polling asyncore. I can easily substitute the polling I have already built around GetQueuedCompletionStatus:
managerRunning = False

def ManageSockets():
    # hIOCP is global so that the socket class can bind new sockets to the port.
    global managerRunning, hIOCP

    wsaData = WSADATA()
    ret = WSAStartup(MAKEWORD(2, 2), LP_WSADATA(wsaData))
    if ret != 0:
        raise WinError(ret)

    hIOCP = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL_HANDLE, NULL, NULL)
    if hIOCP == 0:
        WSACleanup()
        raise WinError()

    numberOfBytes = DWORD()
    completionKey = c_ulong()
    ovCompletedPtr = POINTER(OVERLAPPED)()

    while True:
        while True:
            # Yield to give other tasklets a chance to be scheduled.
            stackless.schedule()

            # Wait up to 50 milliseconds for a completion packet.
            ret = GetQueuedCompletionStatus(hIOCP, byref(numberOfBytes), byref(completionKey), byref(ovCompletedPtr), 50)
            if ret == FALSE:
                err = WSAGetLastError()
                if err == WAIT_TIMEOUT:
                    continue

                # No packet was dequeued, so there is no tasklet to notify.
                if not ovCompletedPtr:
                    raise WinError(err)

                # A failed IO operation: raise the error on the tasklet that started it.
                # (WinError is a factory function, so the exception class is passed instead.)
                ovCompletedPtr.contents.channel.send_exception(WindowsError, err, FormatError(err).strip())
                continue

            break

        # Handle the completed packet.
        ovCompletedPtr.contents.channel.send(numberOfBytes.value)

    managerRunning = False
There are several things to note here.

As with the asyncore polling, I regularly call stackless.schedule() to yield to the scheduler, so that any other tasklets in it get a chance to run.
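The same cooperative pattern can be seen without Stackless. In the sketch below, which is an analogy rather than Stackless code, plain generators stand in for tasklets, each `yield` plays the role of `stackless.schedule()`, and the hypothetical `run_round_robin` function plays the scheduler:

```python
from collections import deque

def run_round_robin(tasks):
    """Resume each generator in turn until all of them have finished."""
    queue = deque(tasks)
    while queue:
        task = queue.popleft()
        try:
            next(task)           # run the task until its next yield
        except StopIteration:
            continue             # finished tasks are dropped
        queue.append(task)       # unfinished tasks go to the back of the queue

def poller(log, polls):
    for i in range(polls):
        log.append("poll %d" % i)    # stands in for one GetQueuedCompletionStatus call
        yield                        # stands in for stackless.schedule()

def worker(log, steps):
    for i in range(steps):
        log.append("work %d" % i)
        yield

log = []
run_round_robin([poller(log, 3), worker(log, 2)])
print(log)  # ['poll 0', 'work 0', 'poll 1', 'work 1', 'poll 2']
```

Because the poller gives up control after every poll, the worker's steps are interleaved with the polls instead of waiting for the poller to finish.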

The custom channel attribute on the OVERLAPPED structure is expected to have had a stackless.channel instance assigned to it. On error, I can wake up the tasklet that initiated the asynchronous IO and raise an exception on it. On success, I can pass useful information back to the code that triggered the IO, so that it can handle the success case more easily. The only piece of information I can see being of use is numberOfBytes, as passing it along saves having to query it manually with WSAGetOverlappedResult.
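It matters that the channel lives in the structure's memory rather than only on the Python object: ctypes builds a new wrapper object every time a pointer is dereferenced, so a plain attribute set on the original OVERLAPPED instance is not visible through `ovCompletedPtr.contents`. A minimal sketch, assuming the channel is declared as a `py_object` field (the OVERLAPPED definition from the earlier posts is not shown here; `DemoOverlapped` is a hypothetical, simplified stand-in that runs on any platform):

```python
import ctypes

class DemoOverlapped(ctypes.Structure):
    # Hypothetical, simplified stand-in for OVERLAPPED: the real structure's
    # Internal/InternalHigh/Offset/OffsetHigh/hEvent fields are omitted.
    _fields_ = [
        ("Internal", ctypes.c_void_p),
        ("channel", ctypes.py_object),  # object reference stored in the struct's own memory
    ]

ov = DemoOverlapped()
marker = object()              # stands in for a stackless.channel()
ov.channel = marker            # written into the structure's memory
ov.note = "plain attribute"    # lives only in this Python wrapper's __dict__

# Simulate what GetQueuedCompletionStatus hands back: a pointer to the same memory.
ptr = ctypes.cast(ctypes.pointer(ov), ctypes.POINTER(DemoOverlapped))
view = ptr.contents            # a new wrapper object around the same memory

print(view.channel is marker)  # True: the field survives the round trip
print(hasattr(view, "note"))   # False: the plain attribute does not
```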

Now I need to make a socket class that is compatible with the standard Python one.
class socket:
    def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=IPPROTO_TCP):
        ret = WSASocket(family, type, proto, None, 0, WSA_FLAG_OVERLAPPED)
        if ret == INVALID_SOCKET:
            raise WinError()

        # Bind the socket to the shared IO completion port.
        CreateIoCompletionPort(ret, hIOCP, NULL, NULL)

        self._socket = ret
        # Buffers used by send and recv; recv allocates its data buffer on first use.
        self.sendBuffer = (WSABUF * 1)()
        self.recvBuffer = None
That gives me a fully set-up socket object on instantiation. Moving the send, recv, accept and bind code I have already written onto it as methods should be just as straightforward.

Starting with send:
    def send(self, data):
        self.sendBuffer[0].buf = data
        self.sendBuffer[0].len = len(data)

        bytesSent = DWORD()
        ovSend = OVERLAPPED()
        c = ovSend.channel = stackless.channel()

        ret = WSASend(self._socket, cast(self.sendBuffer, POINTER(WSABUF)), 1, byref(bytesSent), 0, byref(ovSend), 0)
        if ret != 0:
            err = WSAGetLastError()
            # The operation was successful and is currently in progress. Ignore this error...
            if err != ERROR_IO_PENDING:
                raise WinError(err)

        # Block until the manager returns the number of bytes that were sent.
        return c.receive()
This was a fairly direct conversion. The only additions were setting the channel attribute, blocking on the channel until it returns the number of bytes sent (or raises an exception), and returning that number to the caller.
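The `cast(self.sendBuffer, POINTER(WSABUF))` call works because a one-element ctypes array can be viewed as a pointer to its first element, which is the `LPWSABUF` form WSASend expects alongside a buffer count of 1. A small portable sketch, assuming a WSABUF declaration that mirrors Winsock's `ULONG len; CHAR FAR *buf;` (c_uint32 is used for `len` so the field matches the 32-bit Windows ULONG everywhere; written for Python 3, where c_char_p takes bytes):

```python
import ctypes

class WSABUF(ctypes.Structure):
    # Assumed declaration mirroring Winsock: ULONG len; CHAR FAR *buf;
    _fields_ = [
        ("len", ctypes.c_uint32),
        ("buf", ctypes.c_char_p),
    ]

sendBuffer = (WSABUF * 1)()      # one-element array, as in the socket class
data = b"hello"
sendBuffer[0].buf = data
sendBuffer[0].len = len(data)

# View the array as a pointer to its first WSABUF, as passed to WSASend.
lpBuffers = ctypes.cast(sendBuffer, ctypes.POINTER(WSABUF))
print(lpBuffers[0].len, lpBuffers[0].buf)
```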

Next is the recv method:
    def recv(self, byteCount, flags=0):
        if self.recvBuffer is None:
            self.recvBuffer = (WSABUF * 1)()
            self.recvBuffer[0].buf = ' ' * READ_BUFFER_SIZE

        # WARNING: For now, we cap the readable amount to the size of the preallocated buffer.
        byteCount = min(byteCount, READ_BUFFER_SIZE)
        # Only let WSARecv fill as many bytes as the caller asked for.
        self.recvBuffer[0].len = byteCount

        numberOfBytesRecvd = DWORD()
        # WSARecv takes the flags by reference; start from the caller's value.
        flags = DWORD(flags)
        ovRecv = OVERLAPPED()
        c = ovRecv.channel = stackless.channel()

        ret = WSARecv(self._socket, cast(self.recvBuffer, POINTER(WSABUF)), 1, byref(numberOfBytesRecvd), byref(flags), byref(ovRecv), 0)
        if ret != 0:
            err = WSAGetLastError()
            # The operation was successful and is currently in progress. Ignore this error...
            if err != ERROR_IO_PENDING:
                raise WinError(err)

        # Block until the overlapped operation completes.
        numberOfBytes = c.receive()
        return self.recvBuffer[0].buf[:numberOfBytes]
This was also straightforward. Again, a channel is provided and blocked on. A receive buffer is created on the first call and reused on subsequent calls, and the amount of data that can be received is capped at the size of that buffer. On success, the channel returns the number of bytes received, and that much of the buffer is sliced off and returned to the caller.

If the remote end disconnects, the channel indicates that 0 bytes were received, which results in an empty string being returned. This is correct, as an empty string is how recv indicates that the peer has closed the connection.
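This is the same convention the standard socket module follows, which is what lets code written against it run unchanged. A quick check using the standard library's `socket.socketpair()` (Windows support was added in Python 3.5); on Python 3 the empty result is `b""` rather than `""`:

```python
import socket

a, b = socket.socketpair()
a.sendall(b"ping")
first = b.recv(256)      # data that was sent before the close
a.close()                # the peer disconnects
second = b.recv(256)     # an orderly shutdown is reported as an empty result
b.close()
print(first, second)
```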

bind is extremely straightforward, so that came next:
    def bind(self, address):
        host, port = address

        sa = sockaddr_in()
        sa.sin_family = AF_INET
        sa.sin_addr.s_addr = inet_addr(host)
        sa.sin_port = htons(port)

        ret = bind(self._socket, sockaddr_inp(sa), sizeof(sa))
        if ret == SOCKET_ERROR:
            raise WinError()
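bind relies on the port and address in sockaddr_in being in network (big-endian) byte order, which is what htons and inet_addr produce. A portable sketch of the same 16-byte layout, assuming the Winsock field order (`sin_family`, `sin_port`, `sin_addr`, `sin_zero[8]`), with `sin_addr` simplified to four raw bytes and the standard library's `socket.htons` and `socket.inet_aton` used in place of the ctypes-wrapped Winsock functions:

```python
import ctypes
import socket

class sockaddr_in(ctypes.Structure):
    # Assumed layout mirroring Winsock's SOCKADDR_IN; sin_addr is kept as raw bytes here.
    _fields_ = [
        ("sin_family", ctypes.c_short),
        ("sin_port", ctypes.c_ushort),
        ("sin_addr", ctypes.c_ubyte * 4),
        ("sin_zero", ctypes.c_char * 8),
    ]

sa = sockaddr_in()
sa.sin_family = socket.AF_INET
sa.sin_port = socket.htons(3000)  # host order -> network order
# inet_aton already returns the four address bytes in network order.
sa.sin_addr = (ctypes.c_ubyte * 4)(*bytearray(socket.inet_aton("127.0.0.1")))

raw = ctypes.string_at(ctypes.addressof(sa), ctypes.sizeof(sa))
print(ctypes.sizeof(sa))          # 16
print(raw[2:8])                   # port 3000 (0x0BB8), then 127.0.0.1, as big-endian bytes
```

Whatever the byte order of the host, the port bytes come out as `0x0B 0xB8`, which is exactly what htons is for.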
As is listen:
    def listen(self, backlog):
        ret = listen(self._socket, backlog)
        if ret != 0:
            raise WinError()
A little more complex is accept:
    def accept(self):
        dwReceiveDataLength = 0
        dwLocalAddressLength = sizeof(sockaddr_in) + 16
        dwRemoteAddressLength = sizeof(sockaddr_in) + 16
        outputBuffer = create_string_buffer(dwReceiveDataLength + dwLocalAddressLength + dwRemoteAddressLength)

        dwBytesReceived = DWORD()
        ovAccept = OVERLAPPED()
        c = ovAccept.channel = stackless.channel()

        acceptSocket = socket()

        ret = AcceptEx(self._socket, acceptSocket._socket, outputBuffer, dwReceiveDataLength, dwLocalAddressLength, dwRemoteAddressLength, byref(dwBytesReceived), byref(ovAccept))
        if ret == FALSE:
            err = WSAGetLastError()
            # The operation was successful and is currently in progress. Ignore this error...
            if err != ERROR_IO_PENDING:
                closesocket(acceptSocket._socket)
                raise WinError(err)

        # Block until the overlapped operation completes.
        c.receive()

        localSockaddr = sockaddr_in()
        localSockaddrSize = c_int(sizeof(sockaddr_in))
        remoteSockaddr = sockaddr_in()
        remoteSockaddrSize = c_int(sizeof(sockaddr_in))

        GetAcceptExSockaddrs(outputBuffer, dwReceiveDataLength, dwLocalAddressLength, dwRemoteAddressLength, byref(localSockaddr), byref(localSockaddrSize), byref(remoteSockaddr), byref(remoteSockaddrSize))

        hostbuf = create_string_buffer(NI_MAXHOST)
        servbuf = c_char_p()

        port = ntohs(localSockaddr.sin_port)

        localSockaddr.sin_family = AF_INET
        ret = getnameinfo(localSockaddr, sizeof(sockaddr_in), hostbuf, sizeof(hostbuf), servbuf, 0, NI_NUMERICHOST)
        if ret != 0:
            err = WSAGetLastError()
            closesocket(acceptSocket._socket)
            raise WinError(err)

        # host = inet_ntoa(localSockaddr.sin_addr)

        return (acceptSocket, (hostbuf.value, port))
This was refactored in much the same way as the other asynchronous methods (send and recv), with a channel to block on for the result. However, it got a little more complicated when it came to providing the address part of the return value. When I telnet from my machine to a standard Python socket listening on the same machine, the host is reported as the standard localhost address (127.0.0.1). However, this did not give the same result:
host = inet_ntoa(localSockaddr.sin_addr)
Instead, it gave the address 0.0.0.0. So I looked at how the low-level Python C source does it, saw that it calls getnameinfo, and based my next attempt on that:
        localSockaddr.sin_family = AF_INET
        ret = getnameinfo(localSockaddr, sizeof(sockaddr_in), hostbuf, sizeof(hostbuf), servbuf, 0, NI_NUMERICHOST)
        if ret != 0:
            err = WSAGetLastError()
            closesocket(acceptSocket._socket)
            raise WinError(err)
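However, this gives the same host address. For now, you can see both approaches in the accept logic, with the inet_ntoa one commented out. I have to assume that this is just some idiosyncrasy that does not really matter, although it is worth noting that the standard accept returns the address of the connecting peer, which corresponds to remoteSockaddr here rather than localSockaddr.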
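The buffer sizes handed to AcceptEx follow its documented requirement that each address slot be at least 16 bytes larger than the maximum address length for the transport protocol. A worked check of the arithmetic in accept, assuming the correct 16-byte sockaddr_in; the hypothetical `acceptex_buffer_size` helper just restates the expressions from the code:

```python
SIZEOF_SOCKADDR_IN = 16      # sizeof(sockaddr_in) with a correctly defined in_addr
ACCEPTEX_PADDING = 16        # extra bytes AcceptEx requires per address slot

def acceptex_buffer_size(receive_data_length, sockaddr_size=SIZEOF_SOCKADDR_IN):
    """Total output buffer size: initial data plus local and remote address slots."""
    local_len = sockaddr_size + ACCEPTEX_PADDING
    remote_len = sockaddr_size + ACCEPTEX_PADDING
    return receive_data_length + local_len + remote_len

print(acceptex_buffer_size(0))   # 64: no initial data, two 32-byte address slots
```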

One thing I did notice, though, was that my in_addr structure was not working correctly. It turned out that I had defined the in_addr part of the sockaddr_in structure incorrectly.

Instead of:
class _UN(Structure):
    _fields_ = [
        ("s_un_b", _UN_b),
        ("s_un_w", _UN_w),
        ("s_addr", c_ulong),
    ]

class in_addr(Union):
    _fields_ = [
        ("s_un", _UN),
    ]
    _anonymous_ = ("s_un",)
I should have had:
class in_addr(Union):
    _fields_ = [
        ("s_un_b", _UN_b),
        ("s_un_w", _UN_w),
        ("s_addr", c_ulong),
    ]
As a foreign function interface, ctypes is extremely impressive and lets you do pretty much anything you want. Unions, however, seem to be strangely hard to get right: in the first version, wrapping the three members in a Structure laid them out one after another instead of overlapping them. With this change, the in_addr field of sockaddr_in now works correctly, although it did not fix my host address problem.
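The difference is easy to see by comparing sizes: a ctypes Structure lays its members out one after another, while a Union overlaps them at offset 0. A portable sketch, assuming `_UN_b` and `_UN_w` mirror Winsock's `S_un_b` (four `u_char`) and `S_un_w` (two `u_short`), and using c_uint32 for `s_addr` because c_ulong is 64-bit on some non-Windows platforms (on Windows it is 32-bit):

```python
import ctypes
import socket

class _UN_b(ctypes.Structure):
    _fields_ = [("s_b1", ctypes.c_ubyte), ("s_b2", ctypes.c_ubyte),
                ("s_b3", ctypes.c_ubyte), ("s_b4", ctypes.c_ubyte)]

class _UN_w(ctypes.Structure):
    _fields_ = [("s_w1", ctypes.c_ushort), ("s_w2", ctypes.c_ushort)]

# Original, incorrect version: the three views sit in sequence inside a Structure.
class _UN(ctypes.Structure):
    _fields_ = [("s_un_b", _UN_b), ("s_un_w", _UN_w), ("s_addr", ctypes.c_uint32)]

class in_addr_wrong(ctypes.Union):
    _fields_ = [("s_un", _UN)]

# Corrected version: all three views share the same four bytes.
class in_addr(ctypes.Union):
    _fields_ = [("s_un_b", _UN_b), ("s_un_w", _UN_w), ("s_addr", ctypes.c_uint32)]

packed = socket.inet_aton("127.0.0.1")
good = in_addr.from_buffer_copy(packed)
bad = in_addr_wrong.from_buffer_copy(packed + b"\x00" * 8)

print(ctypes.sizeof(in_addr_wrong), ctypes.sizeof(in_addr))           # 12 4
print(good.s_un_b.s_b1, good.s_un_b.s_b4)                             # 127 1
print(socket.inet_ntoa(ctypes.string_at(ctypes.addressof(good), 4)))  # 127.0.0.1
print(bad.s_un.s_addr)   # 0: s_addr sits at offset 8, past the real address bytes
```

In the incorrect layout, reading s_addr picks up whatever follows the real address in memory, here zeros, and the whole sockaddr_in grows from 16 to 24 bytes.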

This completes the basic socket functionality. A good next step is to write a simple echo server that uses standard Python blocking socket calls within Stackless Python:
def Run():
    address = ("127.0.0.1", 3000)
    listenSocket = socket(AF_INET, SOCK_STREAM)
    listenSocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    listenSocket.bind(address)
    listenSocket.listen(5)

    def handle_echo(_socket, _address):
        # Use the socket passed in, not currentSocket, which changes with each new connection.
        while True:
            data = _socket.recv(256)
            if data == "":
                print _address, "DISCONNECTED"
                return

            print _address, "READ", data, len(data)
            dlen = _socket.send(data)
            print _address, "ECHOD", dlen

    while True:
        print "Waiting for new connection"
        currentSocket, clientAddress = listenSocket.accept()
        print "Connection", currentSocket.fileno(), "from", clientAddress

        stackless.tasklet(handle_echo)(currentSocket, clientAddress)

if __name__ == "__main__":
    StartManager()

    print "STARTED"
    stackless.tasklet(Run)()
    stackless.run()
    print "EXITED"
This works perfectly, but there is one problem: when Ctrl-C is pressed, the Python interpreter crashes with an access violation or something similar.

A reasonable assumption is that Stackless tasklets blocked on IOCP-related operations are not being cleaned up properly. To make sure the required cleanup happens, I wrapped the main polling loop in a try/finally block:
    try:
        # The GetQueuedCompletionStatus polling loop from ManageSockets goes here.
        pass
    finally:
        _CleanupActiveIO()
        CloseHandle(hIOCP)
        WSACleanup()
The definition of _CleanupActiveIO was as follows:
def _CleanupActiveIO():
    for k, v in activeIO.items():
        ret = CancelIo(k)
        if ret == 0:
            raise WinError()

        # Any tasklets blocked on IO are killed silently.
        v.send_exception(TaskletExit)
Raising a real exception on the blocked tasklets causes the interpreter to exit on the given tasklet with that exception. I want them to be killed silently after cleaning up properly, so that the exception on the main tasklet is the one that causes the interpreter to exit and gets displayed.

So that the cleanup function knows about the channels in use and the tasklets blocked on them, every method that starts asynchronous IO has the following line added before its channel receive call:
        activeIO[self._socket] = c
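The registry-plus-silent-kill idea can be tried out with plain generators. A generator's close() method raises GeneratorExit at the point where the generator is suspended, runs its finally blocks, and does not pass the exception on to the caller, which is close in spirit to killing a tasklet with TaskletExit. This is an analogy, not Stackless code: `active_io`, `pending_recv`, `start_io`, `complete_io` and `cleanup_active_io` are hypothetical names, there is no CancelIo step, and unlike the line above, the sketch also removes a registry entry when its IO completes.

```python
active_io = {}      # handle -> suspended waiter, playing the role of activeIO
cleaned_up = []     # records what each waiter did on the way out

def pending_recv(handle):
    """A waiter blocked on IO; the bare yield stands in for c.receive()."""
    try:
        nbytes = yield
        cleaned_up.append(("completed", handle, nbytes))
    finally:
        # Runs both when the IO completes and when the waiter is killed.
        cleaned_up.append(("finally", handle))

def start_io(handle):
    waiter = pending_recv(handle)
    next(waiter)                  # run the waiter up to the point where it blocks
    active_io[handle] = waiter

def complete_io(handle, nbytes):
    waiter = active_io.pop(handle)    # completed IO is no longer active
    try:
        waiter.send(nbytes)           # wake the waiter with its result
    except StopIteration:
        pass                          # the waiter finished normally

def cleanup_active_io():
    for handle, waiter in list(active_io.items()):
        waiter.close()                # GeneratorExit is raised inside, not seen here
        del active_io[handle]

start_io(5)
complete_io(5, 42)
start_io(3)
start_io(4)

surfaced = None
try:
    try:
        raise KeyboardInterrupt       # stands in for Ctrl-C arriving in the polling loop
    finally:
        cleanup_active_io()
except KeyboardInterrupt as exc:
    surfaced = exc                    # only the main exception reaches this point
```

The waiters leave nothing behind except the output of their finally blocks, and the KeyboardInterrupt from the main flow is the exception that surfaces.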
CancelIo is, of course, a Windows function:
BOOL WINAPI CancelIo(
  __in  HANDLE hFile
);
With the matching ctypes definition:
CancelIo = windll.kernel32.CancelIo
CancelIo.argtypes = (HANDLE,)
CancelIo.restype = BOOL
Next steps

At this point, I have a socket object that can be used with Stackless Python to do synchronous IO, while transparently doing asynchronous IO with IO completion ports behind the scenes. It still needs to be fleshed out into a module that can be monkey-patched in place of the standard socket module, as stacklesssocket.py can be.

Next post: IOCP-based sockets with ctypes in Python: 7
Script source code: 05 - Socket.py
