#!/usr/bin/env python2 # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function import crcmod import logging import struct from enum import Enum from time import sleep try: from time import monotonic except: from .missing import monotonic csum_func = crcmod.predefined.mkCrcFun('xmodem') class RatpState(Enum): listen = "listen" # 1 syn_sent = "syn-sent" # 2 syn_received = "syn-received" # 3 established = "established" # 4 fin_wait = "fin-wait" # 5 last_ack = "last-ack" # 6 closing = "closing" # 7 time_wait = "time-wait" # 8 closed = "closed" # 9 class RatpInvalidHeader(ValueError): pass class RatpInvalidPayload(ValueError): pass class RatpError(ValueError): pass class RatpPacket(object): def __init__(self, data=None, flags=''): self.payload = None self.synch = 0x01 self._control = 0 self.length = 0 self.csum = 0 self.c_syn = False self.c_ack = False self.c_fin = False self.c_rst = False self.c_sn = 0 self.c_an = 0 self.c_eor = False self.c_so = False if data: (self.synch, self._control, self.length, self.csum) = \ struct.unpack('!BBBB', data) if self.synch != 0x01: raise RatpInvalidHeader("invalid synch octet (%x != %x)" % (self.synch, 0x01)) csum = (self._control + self.length + self.csum) & 0xff if csum != 0xff: raise RatpInvalidHeader("invalid csum octet (%x != %x)" % (csum, 0xff)) self._unpack_control() elif flags: if 'S' in flags: self.c_syn = True if 'A' in flags: self.c_ack = True if 'F' in flags: self.c_fin = True if 'R' in flags: self.c_rst = True if 'E' in flags: self.c_eor = True def __repr__(self): s = "RatpPacket(" if self.c_syn: s += "SYN," if self.c_ack: s += "ACK," if self.c_fin: s += "FIN," if self.c_rst: s += "RST," s += "SN=%i,AN=%i," % (self.c_sn, self.c_an) if self.c_eor: s += "EOR," if self.c_so: s += "SO,DATA=%i)" % self.length else: s += "DATA=%i)" % self.length return s def _pack_control(self): self._control = 0 | \ self.c_syn << 7 | \ self.c_ack << 6 | \ self.c_fin << 5 | \ self.c_rst << 4 | \ self.c_sn << 3 | \ self.c_an << 2 | \ self.c_eor << 1 | \ self.c_so << 0 def _unpack_control(self): self.c_syn = bool(self._control & 1 << 7) self.c_ack = bool(self._control & 1 << 6) self.c_fin = bool(self._control & 1 << 5) self.c_rst = bool(self._control & 1 << 4) self.c_sn = bool(self._control & 1 << 3) self.c_an = bool(self._control & 1 << 2) self.c_eor = bool(self._control & 1 << 1) self.c_so = bool(self._control & 1 << 0) def pack(self): self._pack_control() self.csum = 0 self.csum = (self._control + self.length + self.csum) self.csum = (self.csum & 0xff) ^ 0xff return struct.pack('!BBBB', self.synch, self._control, self.length, self.csum) def unpack_payload(self, payload): (c_recv,) = struct.unpack('!H', payload[-2:]) c_calc = csum_func(payload[:-2]) if c_recv != c_calc: raise RatpInvalidPayload("bad checksum (%04x != %04x)" % (c_recv, c_calc)) self.payload = payload[:-2] def pack_payload(self): c_calc = csum_func(self.payload) return self.payload+struct.pack('!H', c_calc) class RatpConnection(object): def __init__(self): self._state = RatpState.closed self._passive = True self._input = b'' self._s_sn = 0 self._r_sn = 0 self._retrans = None self._retrans_counter = None self._retrans_deadline = None self._r_mdl = None self._s_mdl = 0xff self._rx_buf = [] # reassembly buffer self._rx_queue = [] self._tx_queue = [] self._rtt_alpha = 0.8 self._rtt_beta = 2.0 self._srtt = 0.2 self._rto_min, self._rto_max = 0.2, 1 self._tx_timestamp = None self.total_retransmits = 0 self.total_crc_errors = 0 def _update_srtt(self, rtt): self._srtt = (self._rtt_alpha * self._srtt) + \ ((1.0 - self._rtt_alpha) * rtt) logging.info("SRTT: %r", self._srtt) def _get_rto(self): return min(self._rto_max, max(self._rto_min, self._rtt_beta * self._srtt)) def _write(self, pkt): if pkt.payload or pkt.c_so or pkt.c_syn or pkt.c_rst or pkt.c_fin: self._s_sn = pkt.c_sn if not self._retrans: self._retrans = pkt self._retrans_counter = 0 else: self.total_retransmits += 1 self._retrans_counter += 1 if self._retrans_counter > 10: raise RatpError("Maximum retransmit count exceeded") self._retrans_deadline = monotonic()+self._get_rto() logging.info("Write: %r", pkt) self._write_raw(pkt.pack()) if pkt.payload: self._write_raw(pkt.pack_payload()) self._tx_timestamp = monotonic() def _check_rto(self): if self._retrans is None: return if self._retrans_deadline < monotonic(): logging.debug("Retransmit...") self._write(self._retrans) def _check_time_wait(self): if not self._state == RatpState.time_wait: return remaining = self._time_wait_deadline - monotonic() if remaining < 0: self._state = RatpState.closed else: logging.debug("Time-Wait: %.2f remaining" % remaining) sleep(min(remaining, 0.1)) def _read(self): if len(self._input) < 4: self._input += self._read_raw(4-len(self._input)) if len(self._input) < 4: return try: pkt = RatpPacket(data=self._input[:4]) except RatpInvalidHeader as e: logging.info("%r", e) self._input = self._input[1:] return self._input = self._input[4:] logging.info("Read: %r", pkt) if pkt.c_syn or pkt.c_rst or pkt.c_so or pkt.c_fin: return pkt if pkt.length == 0: return pkt while len(self._input) < pkt.length+2: self._input += self._read_raw() try: pkt.unpack_payload(self._input[:pkt.length+2]) except RatpInvalidPayload as e: self.total_crc_errors += 1 return finally: self._input = self._input[pkt.length+2:] return pkt def _close(self): pass def _a(self, r): logging.info("A") if r.c_rst: return True if r.c_ack: s = RatpPacket(flags='R') s.c_sn = r.c_an self._write(s) return False if r.c_syn: self._r_mdl = r.length s = RatpPacket(flags='SA') s.c_sn = 0 s.c_an = (r.c_sn + 1) % 2 s.length = self._s_mdl self._write(s) self._state = RatpState.syn_received return False return False def _b(self, r): logging.info("B") if r.c_ack and r.c_an != (self._s_sn + 1) % 2: if r.c_rst: return False else: s = RatpPacket(flags='R') s.c_sn = r.c_an self._write(s) return False if r.c_rst: if r.c_ack: self._retrans = None # FIXME: delete the TCB self._state = RatpState.closed return False else: return False if r.c_syn: if r.c_ack: self._r_mdl = r.length self._retrans = None self._r_sn = r.c_sn s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._state = RatpState.established return False else: self._retrans = None s = RatpPacket(flags='SA') s.c_sn = 0 s.c_an = (r.c_sn + 1) % 2 s.length = self._s_mdl self._write(s) self._state = RatpState.syn_received return False return False def _c1(self, r): logging.info("C1") if r.c_sn != self._r_sn: return True if r.c_rst or r.c_fin: return False s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) return False def _c2(self, r): logging.info("C2") if r.c_sn != self._r_sn: return True if r.c_rst or r.c_fin: return False if r.c_syn: s = RatpPacket(flags='RA') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._retrans = None # FIXME: inform the user "Error: Connection reset" self._state = RatpState.closed return False logging.info("C2: duplicate packet") s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) return False def _d1(self, r): logging.info("D1") if not r.c_rst: return True if self._passive: self._retrans = None self._state = RatpState.listen return False else: self._retrans = None self._state = RatpState.closed raise RatpError("Connection refused") def _d2(self, r): logging.info("D2") if not r.c_rst: return True self._retrans = None self._state = RatpState.closed raise RatpError("Connection reset") def _d3(self, r): logging.info("C3") if not r.c_rst: return True self._state = RatpState.closed return False def _e(self, r): logging.info("E") if not r.c_syn: return True self._retrans = None s = RatpPacket(flags='R') if r.c_ack: s.c_sn = r.c_an else: s.c_sn = 0 self._write(s) self._state = RatpState.closed raise RatpError("Connection reset") def _f1(self, r): logging.info("F1") if not r.c_ack: return False if r.c_an == (self._s_sn + 1) % 2: return True if self._passive: self._retrans = None s = RatpPacket(flags='R') s.c_sn = r.c_an self._write(s) self._state = RatpState.listen return False else: self._retrans = None s = RatpPacket(flags='R') s.c_sn = r.c_an self._write(s) self._state = RatpState.closed raise RatpError("Connection refused") def _f2(self, r): logging.info("F2") if not r.c_ack: return False if r.c_an == (self._s_sn + 1) % 2: if self._retrans: self._retrans = None self._update_srtt(monotonic()-self._tx_timestamp) # FIXME: inform the user with an "Ok" if a buffer has been # entirely acknowledged. Another packet containing data may # now be sent. return True return True def _f3(self, r): logging.info("F3") if not r.c_ack: return False if r.c_an == (self._s_sn + 1) % 2: return True return True def _g(self, r): logging.info("G") if not r.c_rst: return False self._retrans = None if r.c_ack: s = RatpPacket(flags='R') s.c_sn = r.c_an self._write(s) else: s = RatpPacket(flags='RA') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) return False def _h1(self, r): logging.info("H1") self._state = RatpState.established return self._common_i1(r) def _h2(self, r): logging.info("H2") if not r.c_fin: return True if self._retrans is not None: # FIXME: inform the user "Warning: Data left unsent.", "Connection closing." self._retrans = None s = RatpPacket(flags='FA') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._state = RatpState.last_ack raise RatpError("Connection closed by remote") def _h3(self, r): logging.info("H3") if not r.c_fin: # Our fin was lost, rely on retransmission return False if (r.length and not r.c_syn and not r.c_rst and not r.c_fin) or r.c_so: self._retrans = None s = RatpPacket(flags='RA') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._state = RatpState.closed raise RatpError("Connection reset") if r.c_an == (self._s_sn + 1) % 2: self._retrans = None s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._time_wait_deadline = monotonic() + self._get_rto() self._state = RatpState.time_wait return False else: self._retrans = None s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._state = RatpState.closing return False def _h4(self, r): logging.info("H4") if r.c_an == (self._s_sn + 1) % 2: self._retrans = None self._time_wait_deadline = monotonic() + self._get_rto() self._state = RatpState.time_wait return False return False def _h5(self, r): logging.info("H5") if r.c_an == (self._s_sn + 1) % 2: self._time_wait_deadline = monotonic() + self._get_rto() self._state = RatpState.time_wait return False return False def _h6(self, r): logging.info("H6") if not r.c_ack: return False if not r.c_fin: return False self._retrans = None s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) self._time_wait_deadline = monotonic() + self._get_rto() return False def _common_i1(self, r): if r.c_so: self._r_sn = r.c_sn self._rx_buf.append(chr(r.length)) elif r.length and not r.c_syn and not r.c_rst and not r.c_fin: self._r_sn = r.c_sn self._rx_buf.append(r.payload) else: return False # reassemble if r.c_eor: logging.info("Reassembling %i frames", len(self._rx_buf)) self._rx_queue.append(''.join(self._rx_buf)) self._rx_buf = [] s = RatpPacket(flags='A') s.c_sn = r.c_an s.c_an = (r.c_sn + 1) % 2 self._write(s) return False def _i1(self, r): logging.info("I1") return self._common_i1(r) def _machine(self, pkt): logging.info("State: %r", self._state) if self._state == RatpState.listen: self._a(pkt) elif self._state == RatpState.syn_sent: self._b(pkt) elif self._state == RatpState.syn_received: self._c1(pkt) and \ self._d1(pkt) and \ self._e(pkt) and \ self._f1(pkt) and \ self._h1(pkt) elif self._state == RatpState.established: self._c2(pkt) and \ self._d2(pkt) and \ self._e(pkt) and \ self._f2(pkt) and \ self._h2(pkt) and \ self._i1(pkt) elif self._state == RatpState.fin_wait: self._c2(pkt) and \ self._d2(pkt) and \ self._e(pkt) and \ self._f3(pkt) and \ self._h3(pkt) elif self._state == RatpState.last_ack: self._c2(pkt) and \ self._d3(pkt) and \ self._e(pkt) and \ self._f3(pkt) and \ self._h4(pkt) elif self._state == RatpState.closing: self._c2(pkt) and \ self._d3(pkt) and \ self._e(pkt) and \ self._f3(pkt) and \ self._h5(pkt) elif self._state == RatpState.time_wait: self._d3(pkt) and \ self._e(pkt) and \ self._f3(pkt) and \ self._h6(pkt) elif self._state == RatpState.closed: self._g(pkt) def wait(self, deadline): while deadline is None or deadline > monotonic(): pkt = self._read() if pkt: self._machine(pkt) else: self._check_rto() self._check_time_wait() if not self._retrans or self._rx_queue: return def wait1(self, deadline): while deadline is None or deadline > monotonic(): pkt = self._read() if pkt: self._machine(pkt) else: self._check_rto() self._check_time_wait() if not self._retrans: return def listen(self): logging.info("LISTEN") self._state = RatpState.listen def connect(self, timeout=5.0): deadline = monotonic() + timeout logging.info("CONNECT") self._retrans = None syn = RatpPacket(flags='S') syn.length = self._s_mdl self._write(syn) self._state = RatpState.syn_sent self.wait(deadline) def send_one(self, data, eor=True, timeout=1.0): deadline = monotonic() + timeout logging.info("SEND_ONE (len=%i, eor=%r)", len(data), eor) assert self._state == RatpState.established assert self._retrans is None snd = RatpPacket(flags='A') snd.c_eor = eor snd.c_sn = (self._s_sn + 1) % 2 snd.c_an = (self._r_sn + 1) % 2 snd.length = len(data) snd.payload = data self._write(snd) self.wait1(deadline=None) def send(self, data, timeout=1.0): logging.info("SEND (len=%i)", len(data)) while len(data) > 255: self.send_one(data[:255], eor=False, timeout=timeout) data = data[255:] self.send_one(data, eor=True, timeout=timeout) def recv(self, timeout=1.0): deadline = monotonic() + timeout assert self._state == RatpState.established if self._rx_queue: return self._rx_queue.pop(0) self.wait(deadline) if self._rx_queue: return self._rx_queue.pop(0) def close(self, timeout=1.0): deadline = monotonic() + timeout logging.info("CLOSE") if self._state == RatpState.established or self._state == RatpState.syn_received: fin = RatpPacket(flags='FA') fin.c_sn = (self._s_sn + 1) % 2 fin.c_an = (self._r_sn + 1) % 2 self._write(fin) self._state = RatpState.fin_wait while deadline > monotonic() and not self._state == RatpState.time_wait: self.wait(deadline) while self._state == RatpState.time_wait: self.wait(None) if self._state == RatpState.closed: logging.info("CLOSE: success") else: logging.info("CLOSE: failure") def abort(self): logging.info("ABORT") def status(self): logging.info("STATUS") return self._state class SerialRatpConnection(RatpConnection): def __init__(self, port): super(SerialRatpConnection, self).__init__() self.__port = port self.__port.timeout = 0.01 self.__port.writeTimeout = None self.__port.flushInput() def _write_raw(self, data): if data: logging.debug("-> %r", bytearray(data)) return self.__port.write(data) def _read_raw(self, size=1): data = self.__port.read(size) if data: logging.debug("<- %r", bytearray(data)) return data