base.py 6.05 KB
Newer Older
m!nus's avatar
m!nus committed
1 2 3
import sys
import socket
import select
4
import queue
m!nus's avatar
m!nus committed
5 6 7 8 9 10 11 12 13 14 15
from struct import unpack
from time import time
from collections import defaultdict
import logging

# Configure root logging when this module is imported (basicConfig is a no-op
# if the root logger already has handlers) and create this module's logger ``L``.
logging.basicConfig(
	level=logging.DEBUG,
	format="[%(asctime)s] %(levelname)s: %(funcName)s: %(message)s",
)
L = logging.getLogger(name=__name__)

def get_address(host, port=8303, family=0):
	"""Resolve *host* to a numeric Address for UDP communication.

	:param host: host name or numeric IP address string
	:param port: UDP port, 8303 by default
	:param family: address family passed to getaddrinfo() (0 = any)
	:returns: an Address built from the first getaddrinfo() result, carrying
		that result's address family, or None if resolution failed
	"""
	try:
		info = socket.getaddrinfo(host, port, family, socket.SOCK_DGRAM)
	except socket.gaierror as e:
		L.warning('getaddrinfo failed: %s', e)
		return None
	# Take the family from the result, not from the argument: with family=0
	# the argument says nothing about whether an IPv4 or an IPv6 address came
	# back, and senders choose the socket based on Address.family.
	resolved_family, _, _, _, sockaddr = info[0]
	return Address(sockaddr[0], sockaddr[1], resolved_family)

def listdata2addresslist(listdata):
	"""Decode packed address records into a list of Address objects.

	*listdata* is a bytes object holding consecutive 18-byte records: a
	16-byte IPv6 address followed by a 2-byte big-endian (network order) port.
	Records in the IPv4-mapped range ::ffff:0:0/96 are returned as AF_INET
	addresses, all others as AF_INET6 addresses. A trailing chunk of exactly
	6 bytes is read as a raw 4-byte IPv4 address plus port.

	:returns: list of Address objects, in input order
	:raises ValueError: if a chunk has any other (unexpected) length
	"""
	addresses = []
	for offset in range(0, len(listdata), 18):
		record = listdata[offset:offset + 18]

		# ::ffff:0:0/96 == IPv4 mapping; keep only the 4 address bytes + port
		if record[0:12] == b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff":
			record = record[12:18]

		if len(record) == 6:
			host = socket.inet_ntoa(record[:4])
			port = unpack("!H", record[4:])[0]
			addresses.append(Address(host, port, socket.AF_INET))
		elif len(record) == 18:
			port = unpack("!H", record[16:])[0]
			raw_ip = record[:16]
			if hasattr(socket, "inet_ntop"):
				# Available on Unix and, since Python 3.4, on Windows as well.
				host = socket.inet_ntop(socket.AF_INET6, raw_ip)
			else:
				# Fallback: eight uncompressed hex groups. Indexing bytes
				# yields ints on Python 3, so no ord() is needed.
				host = ":".join(
					"{:x}".format((hi << 8) | lo)
					for hi, lo in zip(raw_ip[0::2], raw_ip[1::2])
				)
			addresses.append(Address(host, port, socket.AF_INET6))
		else:
			raise ValueError("Invalid IP data")

	return addresses


class Address(object):
	"""A network endpoint: host, port and address family.

	Instances order by (host, port), so lists of them can be sorted.
	"""
	def __init__(self, host=None, port=None, family=None):
		"""
		:param host: host string, used as the first element of address_tuple()
		:param port: port number, used as the second element of address_tuple()
		:param family: socket address family constant (e.g. socket.AF_INET6)
		"""
		self.host = host
		self.port = port
		self.family = family

	def address_tuple(self):
		"""Return (host, port), the form accepted by socket.sendto()."""
		return (self.host, self.port)

	def __lt__(self, other):
		"""Compare by (host, port)."""
		return self.address_tuple() < other.address_tuple()

	def __str__(self):
		"""Format as host:port, bracketing IPv6 hosts as [host]:port."""
		if self.family == socket.AF_INET6:
			# Named fields must be passed as keywords; positional args would
			# raise KeyError for "{host}".
			return "[{host}]:{port}".format(host=self.host, port=self.port)
		return "{host}:{port}".format(host=self.host, port=self.port)

	def __repr__(self):
		return "Address({!r}, {!r}, {!r})".format(self.host, self.port, self.family)


class Request(object):
	"""Base wrapper for one outgoing datagram and the responses to it.

	Subclasses supply the payload through get_data() and may override the
	sent() and response_received() callbacks. By default the destination is
	taken from the ``address`` attribute.
	"""
	# destination Address; expected to be set by subclasses or instances
	address = None

	def sent(self):
		"""Callback run after the request was processed and sent; no-op by default."""
		return None

	def response_received(self, data):
		"""Callback run with each datagram received from the destination.

		Return True while more data is expected; the default returns False.
		"""
		return False

	def get_address(self):
		"""Return the destination Address (the ``address`` attribute by default)."""
		return self.address

	def get_data(self):
		"""Return the payload to send; subclasses must override this."""
		raise NotImplementedError()

	def __str__(self):
		destination = self.get_address()
		payload = self.get_data()
		return "<Request to {dst} with {size} bytes>".format(dst=destination, size=len(payload))


class EventSocket(object):
	"""Rate-limited, queue-based UDP request/response handler.

	Requests queued with send() are transmitted by run() at no more than
	``packets_per_second``. Datagrams received from a (host, port) are handed
	to every outstanding Request sent to that (host, port); a request stays
	outstanding while its response_received() returns True.
	"""
	# Winsock error WSAECONNRESET. On a UDP socket it reports that an earlier
	# datagram triggered an ICMP "Port Unreachable".
	_WSAECONNRESET = 10054

	def __init__(self, packets_per_second=200, idle_limit=10):
		"""Create the IPv4 socket and, when possible, an IPv6 socket.

		:param packets_per_second: maximum send rate enforced by run()
		:param idle_limit: run() returns once more than this many consecutive
			rounds passed with nothing received and nothing queued
		"""
		self._max_packets_per_second = packets_per_second
		self._idle_counter = 0
		self._idle_limit = idle_limit
		# Send budget: +1 per packet sent, drained by _packet_rate_update()
		# at _max_packets_per_second per second.
		self._packet_rate = 0
		self._packet_rate_last_update = 0
		self._sockets = {}
		# (host, port) -> set of sent Requests, so received data can be routed back
		self._requests = defaultdict(set)
		self._queue = queue.Queue()
		self._sockets[socket.AF_INET] = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.SOL_UDP)
		self.has_ipv6 = socket.has_ipv6
		if self.has_ipv6:
			try:
				self._sockets[socket.AF_INET6] = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.SOL_UDP)
			except socket.error:
				# Python supports IPv6 but the host refused an IPv6 socket
				# (e.g. IPv6 disabled); continue IPv4-only.
				self.has_ipv6 = False

	def close(self):
		"""Close all sockets owned by this instance."""
		for sock in self._sockets.values():
			sock.close()

	def _packet_rate_update(self):
		"""Drain the send budget by the time elapsed since the last update."""
		cur_time = time()
		diff = cur_time - self._packet_rate_last_update
		self._packet_rate_last_update = cur_time
		self._packet_rate -= diff * self._max_packets_per_second
		if self._packet_rate < 0:
			self._packet_rate = 0

	def _select(self):
		"""Wait until a socket is readable.

		Waits at most one send interval (1/packets_per_second) while requests
		are queued, otherwise up to 1 second.

		:returns: the (readable, writable, exceptional) lists from select.select();
			all empty on timeout (no exception is raised)
		"""
		if self._queue.empty():
			timeout = 1.0
		else:
			timeout = 1.0 / self._max_packets_per_second
		return select.select(list(self._sockets.values()), [], [], timeout)

	def _send(self, request):
		"""Send *request* immediately on the socket matching its address family.

		The request is registered as outstanding only after sendto() succeeded,
		so a failed send does not leave it waiting for a response.

		:returns: number of bytes sent
		:raises socket.error: for an IPv6 destination without an IPv6 socket,
			or when sendto() fails
		"""
		addr = request.get_address()
		socket_type = socket.AF_INET
		if addr.family == socket.AF_INET6:
			if not self.has_ipv6:
				raise socket.error("Cannot send IPv6 packet without IPv6 socket")
			socket_type = socket.AF_INET6
		length = self._sockets[socket_type].sendto(request.get_data(), addr.address_tuple())
		self._requests[addr.address_tuple()].add(request)
		L.debug("Sent %s", request)
		request.sent()
		self._packet_rate += 1
		return length

	def send(self, request):
		"""Queue *request* (a Request instance) for sending by run()."""
		assert isinstance(request, Request)
		self._queue.put(request)

	def run(self):
		"""Main loop: send queued requests and dispatch received datagrams.

		On each round where nothing was received and the queue is empty, it
		returns if no sent request is outstanding, or if that has happened
		more than idle_limit rounds in a row.
		"""
		while True:
			self._packet_rate_update()

			if not self._queue.empty() and self._packet_rate < self._max_packets_per_second:
				self._send_next()

			(readable, _, _) = self._select()

			# read from all receivable sockets
			for sock in readable:
				self._receive(sock)

			if readable or not self._queue.empty():
				self._idle_counter = 0
				continue

			self._idle_counter += 1
			remaining = self._outstanding_requests()
			if not remaining:
				L.info("No more outstanding requests, returning")
				return
			if self._idle_counter > self._idle_limit:
				L.info("Idle limit hit, returning")
				L.debug("Remaining sent requests: {}".format(remaining))
				return

	def _send_next(self):
		"""Take one request from the queue and send it.

		A send failing with WSAECONNRESET (ICMP port unreachable) is logged
		and the request dropped; other socket errors propagate.
		"""
		request = self._queue.get()
		data = request.get_data()
		try:
			length = self._send(request)
		except socket.error as e:
			if e.errno == self._WSAECONNRESET:  # ICMP port unreachable
				L.debug('ICMP port unreachable, {} discarded'.format(request))
				return
			raise
		if length != len(data):
			L.warning('Sent {} of {} bytes of {}'.format(length, len(data), request))

	def _receive(self, sock):
		"""Read one datagram from *sock* and pass it to the requests sent to its sender."""
		try:
			(data, address) = sock.recvfrom(65535)
		except socket.error as e:
			if e.errno == self._WSAECONNRESET:
				# Windows reports an earlier ICMP port unreachable on recvfrom()
				L.debug('ICMP port unreachable reported on receive, ignored')
				return
			raise
		# recvfrom() on an IPv6 socket returns (host, port, flowinfo, scope_id);
		# requests are keyed by (host, port) only.
		requests = self._requests.get(address[:2])
		if not requests:
			L.warning("Nothing sent to {} but received response".format(address))
			return
		finished = [request for request in requests if not request.response_received(data)]
		requests.difference_update(finished)

	def _outstanding_requests(self):
		"""Return a list of all sent requests still waiting for responses."""
		return [request for requests in self._requests.values() for request in requests]