import asyncfutures, asyncdispatch, asyncnet, strformat
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr,
                parseIpAddress
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
from sequtils import any

import asyncutils
import ip_packet
import network_interface
import raw_socket

# C socket-option constants, used with setSockOptInt to change the IP TTL of
# a socket. Both are declared in <netinet/in.h> (an empty header string makes
# the generated C code fail to compile).
var IPPROTO_IP {.importc: "IPPROTO_IP", header: "<netinet/in.h>".}: cint
var IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint

const Timeout = 3000 ## Milliseconds passed to sleepAsync when waiting for a punch.

type
  ConnectAttempt* = ref object
    ## Bookkeeping for one outgoing hole-punching attempt.
    srcIp: IpAddress
    srcPort: Port
    dstIp: IpAddress
    dstPorts: seq[Port]
    firewallRules: seq[string] ## iptables rules currently installed for it

  AcceptAttempt* = ref object
    ## Bookkeeping for one incoming hole-punching attempt.
    srcIp: IpAddress
    srcPort: Port
    dstIp: IpAddress
    dstPorts: seq[Port]
    seqNums: seq[uint32]
    firewallRules: seq[string] ## iptables rules currently installed for it
    future: Future[AsyncSocket] ## completed with the accepted peer socket

  TcpSyniPuncher* = ref object
    ## Registry of the connect and accept attempts currently in progress.
    connectAttempts: seq[ConnectAttempt]
    acceptAttempts: seq[AcceptAttempt]

  PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}

  PunchHoleError* = object of ValueError

proc makeFirewallRule(srcIp: IpAddress, srcPort: Port,
                      dstIp: IpAddress, dstPort: Port): string =
  ## Builds the iptables rule arguments (everything after the chain name) that
  ## drop incoming ICMP time-exceeded messages addressed to `srcIp` which
  ## conntrack relates to the TCP flow srcIp:srcPort -> dstIp:dstPort.
  # NOTE(review): presumably this keeps the ICMP error triggered by a low-TTL
  # SYN from reaching the local TCP stack — confirm.
  # The rule is built on a single line (no shell line continuations), so the
  # same arguments reach iptables whether or not asyncExecCmd uses a shell.
  result = fmt"-w -d {srcIp} -p icmp --icmp-type time-exceeded " &
    "-m conntrack --ctstate RELATED --ctproto tcp " &
    fmt"--ctorigsrc {srcIp} --ctorigsrcport {srcPort.int} " &
    fmt"--ctorigdst {dstIp} --ctorigdstport {dstPort.int} -j DROP"

proc iptablesInsert(chain: string, rule: string) {.async.} =
  ## Runs `iptables -I <chain> <rule>` through asyncExecCmd. The command's
  ## result is discarded.
  let firewallCmd = fmt"iptables -I {chain} {rule}"
  discard await asyncExecCmd(firewallCmd)

proc iptablesDelete(chain: string, rule: string) {.async.} =
  ## Runs `iptables -D <chain> <rule>` through asyncExecCmd. The command's
  ## result is discarded.
  let firewallCmd = fmt"iptables -D {chain} {rule}"
  discard await asyncExecCmd(firewallCmd)

proc addFirewallRules[T](attempt: T) {.async.} =
  ## Inserts one INPUT rule per destination port of `attempt`. Each installed
  ## rule is recorded in `attempt.firewallRules`, so deleteFirewallRules can
  ## remove it even if a later insertion fails.
  ##
  ## Raises PunchHoleError when asyncExecCmd raises OSError.
  for dstPort in attempt.dstPorts:
    let rule = makeFirewallRule(attempt.srcIp, attempt.srcPort,
                                attempt.dstIp, dstPort)
    try:
      await iptablesInsert("INPUT", rule)
      attempt.firewallRules.add(rule)
    except OSError as e:
      echo "cannot add firewall rule: ", e.msg
      raise newException(PunchHoleError, e.msg)

proc deleteFirewallRules[T](attempt: T) {.async.} =
  ## Removes every rule recorded in `attempt.firewallRules` (best effort) and
  ## then forgets them. Calling it a second time is therefore a no-op.
  for rule in attempt.firewallRules:
    # FIXME: close sock?
    try:
      await iptablesDelete("INPUT", rule)
    except OSError:
      # At least we tried
      discard
  attempt.firewallRules.setLen(0)

proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
  ## Serializes `ipPacket` (which must be TCP) and sends it through the raw
  ## socket `rawFd` to its destination address and port.
  ##
  ## Raises PunchHoleError when the send raises OSError.
  assert(ipPacket.protocol == tcp)
  try:
    let packet = serialize(ipPacket)
    var sockaddr: Sockaddr_storage
    var sockaddrLen: SockLen
    toSockAddr(ipPacket.ipAddrDst, ipPacket.tcpPortDst, sockaddr, sockaddrLen)
    await rawFd.sendTo(packet.cstring, packet.len,
                       cast[ptr SockAddr](addr sockaddr), sockaddrLen)
  except OSError as e:
    raise newException(PunchHoleError, e.msg)

proc matchesConnectAttempt(packet: IpPacket, attempt: ConnectAttempt): bool =
  ## True if `packet` is a TCP packet sent from the attempt's local endpoint
  ## to the attempt's peer address, with a destination port listed in
  ## `attempt.dstPorts`.
  # The protocol is checked first so that TCP fields are only read from TCP
  # packets.
  if packet.protocol != tcp or packet.ipAddrSrc != attempt.srcIp or
      packet.tcpPortSrc.int != attempt.srcPort.int or
      packet.ipAddrDst != attempt.dstIp:
    return false
  for port in attempt.dstPorts:
    if packet.tcpPortDst.int == port.int:
      return true
  return false

proc captureSeqNumbers(attempt: ConnectAttempt, rawFd: AsyncFD,
                       cb: PunchProgressCb) {.async.} =
  ## Reads frames from the capturing socket `rawFd` and collects the sequence
  ## number of every outgoing pure-SYN belonging to `attempt`. It stops after
  ## one number per destination port, or when recv returns an empty string.
  ## It then closes `rawFd` and passes the collected numbers to `cb`.
  # FIXME: timeout?
  # FIXME: create raw socket here
  var seqNums = newSeq[uint32]()
  while seqNums.len < attempt.dstPorts.len:
    let packet = await rawFd.recv(4000)
    if packet == "":
      break
    let parsed = parseEthernetPacket(packet)
    if parsed.matchesConnectAttempt(attempt) and parsed.tcpFlags == {SYN}:
      seqNums.add(parsed.tcpSeqNumber)
  closeSocket(rawFd)
  await cb(seqNums)

proc captureAndResendAck(attempt: ConnectAttempt, captureFd: AsyncFD,
                         injectFd: AsyncFD) {.async.} =
  ## Waits on `captureFd` for the first outgoing pure-ACK belonging to
  ## `attempt`. It re-sends that ACK through `injectFd` with its IP TTL set
  ## to 64, then closes both sockets. It also stops when recv returns an
  ## empty string.
  # FIXME: create raw socket here
  while true:
    let packet = await captureFd.recv(4000)
    if packet == "":
      break
    var parsed = parseEthernetPacket(packet)
    if parsed.matchesConnectAttempt(attempt) and parsed.tcpFlags == {ACK}:
      parsed.ipTTL = 64
      echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}"
      await injectFd.injectTcpPacket(parsed)
      break
  closeSocket(captureFd)
  closeSocket(injectFd)
proc initPuncher*(): TcpSyniPuncher =
  ## Creates a puncher with no connect or accept attempts registered.
  TcpSyniPuncher()

proc attemptIndex[T](attempts: seq[T], srcIp: IpAddress, srcPort: Port,
                     dstIp: IpAddress, dstPorts: seq[Port]): int =
  ## Index of the first attempt with the given local endpoint and peer address
  ## whose port list shares at least one port with `dstPorts`, or -1.
  for index, attempt in attempts:
    if attempt.srcIp == srcIp and attempt.srcPort == srcPort and
        attempt.dstIp == dstIp and
        attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts):
      return index
  return -1

proc findConnectAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
                        srcPort: Port, dstIp: IpAddress,
                        dstPorts: seq[Port]): int =
  ## Index into `puncher.connectAttempts` of an overlapping attempt, or -1.
  puncher.connectAttempts.attemptIndex(srcIp, srcPort, dstIp, dstPorts)

proc findAcceptAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
                       srcPort: Port, dstIp: IpAddress,
                       dstPorts: seq[Port]): int =
  ## Index into `puncher.acceptAttempts` of an overlapping attempt, or -1.
  puncher.acceptAttempts.attemptIndex(srcIp, srcPort, dstIp, dstPorts)

proc findAcceptAttemptsByLocalAddr(puncher: TcpSyniPuncher,
                                   address: IpAddress,
                                   port: Port): seq[AcceptAttempt] =
  ## All registered accept attempts bound to the local endpoint address:port.
  for attempt in puncher.acceptAttempts:
    if attempt.srcIp == address and attempt.srcPort == port:
      result.add(attempt)

proc predictPortRange(dstPorts: seq[Port]): seq[Port] =
  ## Returns the peer ports to punch towards. Currently this is a single port
  ## derived from the supplied ports (no real prediction yet).
  ##
  ## Raises PunchHoleError if `dstPorts` is empty.
  # TODO: do real port prediction
  const rangeLen = 1
  if dstPorts.len == 0:
    raise newException(PunchHoleError, "no destination ports given")
  # NOTE(review): the second supplied port is used as the base when present
  # (as before); with a single port that port is used instead of crashing.
  let hint = if dstPorts.len > 1: dstPorts[1] else: dstPorts[0]
  # Clamp so the last port of the range cannot overflow uint16.
  let basePort = min(hint.uint16, uint16.high - uint16(rangeLen - 1))
  result = newSeq[Port](rangeLen)
  for i in 0 ..< rangeLen:
    result[i] = Port(basePort + i.uint16)

proc removeAttempt[T](attempts: var seq[T], attempt: T) =
  ## Removes `attempt` from `attempts` if it is still registered. It may
  ## already be gone, e.g. when `cleanup` popped it concurrently.
  let i = attempts.find(attempt)
  if i != -1:
    attempts.del(i)

proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
  ## Unregisters every connect and accept attempt and deletes the firewall
  ## rules installed for each of them.
  while puncher.connectAttempts.len() != 0:
    await puncher.connectAttempts.pop().deleteFirewallRules()
  while puncher.acceptAttempts.len() != 0:
    await puncher.acceptAttempts.pop().deleteFirewallRules()

proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress,
               dstPort: Port, future: Future[AsyncSocket]) {.async.} =
  ## Connects from srcIp:srcPort to dstIp:dstPort with the socket's IP TTL
  ## set to 2, restoring it to 64 once connected. The connected socket
  ## completes `future` unless another call already completed it, in which
  ## case the socket is closed. Failures are logged and the socket closed.
  let sock = newAsyncSocket()
  var connected = false
  try:
    sock.setSockOpt(OptReuseAddr, true)
    # NOTE(review): presumably the low TTL lets the SYN open local NAT state
    # without reaching the peer — confirm.
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
    echo &"doConnect {srcIp}:{srcPort} -> {dstIp}:{dstPort}"
    sock.bindAddr(srcPort, $srcIp)
    await sock.connect($dstIp, dstPort)
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 64)
    connected = true
  except OSError as e:
    echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg
    sock.close()
  if connected:
    if future.finished():
      # Another destination port won the race; completing twice would raise.
      sock.close()
    else:
      future.complete(sock)

proc connect*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
              dstPorts: seq[Port],
              progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
  ## Punches an outgoing TCP connection from the primary local address:srcPort
  ## to `dstIp` on the predicted ports. Captured SYN sequence numbers are
  ## reported through `progressCb`. The attempt and its firewall rules are
  ## removed on every exit path.
  ##
  ## Raises PunchHoleError when an overlapping attempt is already active, on
  ## timeout, or when connecting fails. Errors raised while installing rules
  ## or opening the capture sockets are re-raised unchanged.
  let localIp = getPrimaryIPAddr(dstIp)
  if puncher.findConnectAttempt(localIp, srcPort, dstIp, dstPorts) != -1:
    raise newException(PunchHoleError,
                       "hole punching for given parameters already active")
  let attempt = ConnectAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                               dstPorts: predictPortRange(dstPorts))
  puncher.connectAttempts.add(attempt)
  let connectFuture = newFuture[AsyncSocket]("connect")
  # Errors are captured instead of propagated immediately, so the firewall
  # rules and the registry entry are always cleaned up below.
  var failure: ref CatchableError = nil
  try:
    await attempt.addFirewallRules()
    let iface = fromIpAddress(attempt.srcIp)
    let captureSeqFd = setupEthernetCapturingSocket(iface)
    let captureAckFd = setupEthernetCapturingSocket(iface)
    let injectAckFd = setupTcpInjectingSocket()
    asyncCheck attempt.captureSeqNumbers(captureSeqFd, progressCb)
    asyncCheck attempt.captureAndResendAck(captureAckFd, injectAckFd)
  except CatchableError as e:
    failure = e
  if failure.isNil:
    try:
      for dstPort in attempt.dstPorts:
        asyncCheck doConnect(attempt.srcIp, attempt.srcPort, attempt.dstIp,
                             dstPort, connectFuture)
      await connectFuture or sleepAsync(Timeout)
    except OSError as e:
      failure = newException(PunchHoleError, e.msg)
  await attempt.deleteFirewallRules()
  puncher.connectAttempts.removeAttempt(attempt)
  if not failure.isNil:
    raise failure
  if connectFuture.finished():
    result = connectFuture.read()
  else:
    raise newException(PunchHoleError, "timeout")

proc doAccept(puncher: TcpSyniPuncher, srcIp: IpAddress,
              srcPort: Port) {.async.} =
  ## Listens on srcIp:srcPort and hands each accepted connection to the
  ## registered accept attempt for that peer. Connections without a matching
  ## attempt are closed. Listening stops once no accept attempt on this local
  ## endpoint is still waiting for a connection.
  let sock = newAsyncSocket()
  sock.setSockOpt(OptReuseAddr, true)
  sock.bindAddr(srcPort, $(srcIp))
  sock.listen()
  # A pending accept() is kept across timeouts. Issuing a fresh one after
  # every timeout would queue several accepts on the socket, and an incoming
  # connection could then complete a stale future that nobody reads.
  var acceptFuture: Future[AsyncSocket] = nil
  while true:
    if acceptFuture.isNil or acceptFuture.finished():
      acceptFuture = sock.accept()
    await acceptFuture or sleepAsync(Timeout)
    if acceptFuture.finished():
      let peer = acceptFuture.read()
      let (peerAddr, peerPort) = peer.getPeerAddr()
      let peerIp = parseIpAddress(peerAddr)
      let i = puncher.findAcceptAttempt(srcIp, srcPort, peerIp, @[peerPort])
      if i == -1:
        echo "Accepted connection, but no attempt found. Discarding."
        peer.close()
        continue
      let attempt = puncher.acceptAttempts[i]
      if attempt.future.finished():
        # Already satisfied by an earlier connection; completing twice raises.
        peer.close()
      else:
        attempt.future.complete(peer)
    # FIXME: should attempts have timestamps, so we can decide here which ones to delete?
    let attempts = puncher.findAcceptAttemptsByLocalAddr(srcIp, srcPort)
    if not attempts.any(proc (a: AcceptAttempt): bool = not a.future.finished()):
      break
  sock.close()

proc accept*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
             dstPorts: seq[Port],
             seqNums: seq[uint32]): Future[AsyncSocket] {.async.} =
  ## Punches an incoming TCP connection from `dstIp` to the primary local
  ## address:srcPort. For each predicted peer port it injects an outgoing
  ## SYN (TTL 2) and one spoofed incoming SYN per entry of `seqNums`, then
  ## waits for the listener started by doAccept to deliver the connection.
  ## The attempt, its firewall rules and the raw socket are cleaned up on
  ## every exit path.
  ##
  ## Raises PunchHoleError when an overlapping attempt is already active, on
  ## timeout, or on any OSError (previously such errors returned nil).
  let localIp = getPrimaryIPAddr(dstIp)
  let existingAttempts = puncher.findAcceptAttemptsByLocalAddr(localIp, srcPort)
  if existingAttempts.len() == 0:
    echo &"accepting connections from {dstIp}:{dstPorts[0].int}"
    asyncCheck puncher.doAccept(localIp, srcPort)
  else:
    for a in existingAttempts:
      if a.dstIp == dstIp and
          a.dstPorts.any(proc (p: Port): bool = p in dstPorts):
        raise newException(PunchHoleError,
                           "hole punching for given parameters already active")
  var attempt: AcceptAttempt = nil
  var rawFd: AsyncFD
  var rawFdOpen = false
  # Errors are captured instead of propagated immediately, so cleanup below
  # always runs.
  var failure: ref CatchableError = nil
  try:
    rawFd = setupTcpInjectingSocket()
    rawFdOpen = true
    attempt = AcceptAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                            dstPorts: predictPortRange(dstPorts),
                            seqNums: seqNums,
                            future: newFuture[AsyncSocket]("accept"))
    puncher.acceptAttempts.add(attempt)
    await attempt.addFirewallRules()
    for dstPort in attempt.dstPorts:
      let synOut = IpPacket(protocol: tcp,
                            ipAddrSrc: attempt.srcIp,
                            ipAddrDst: attempt.dstIp,
                            ipTTL: 2,
                            tcpPortSrc: attempt.srcPort,
                            tcpPortDst: dstPort,
                            tcpSeqNumber: 0,
                            tcpAckNumber: 0,
                            tcpFlags: {SYN},
                            tcpWindowSize: 1452 * 10)
      await rawFd.injectTcpPacket(synOut)
      for seqNum in attempt.seqNums:
        let synIn = IpPacket(protocol: tcp,
                             ipAddrSrc: attempt.dstIp,
                             ipAddrDst: attempt.srcIp,
                             ipTTL: 64,
                             tcpPortSrc: dstPort,
                             tcpPortDst: attempt.srcPort,
                             tcpSeqNumber: seqNum,
                             tcpAckNumber: 0,
                             tcpFlags: {SYN},
                             tcpWindowSize: 1452 * 10)
        echo &"[{synIn.ipAddrSrc}:{synIn.tcpPortSrc} -> {synIn.ipAddrDst}:{synIn.tcpPortDst}, SEQ {synIn.tcpSeqNumber}] injecting SYN"
        await rawFd.injectTcpPacket(synIn)
    rawFdOpen = false
    closeSocket(rawFd)
    await attempt.future or sleepAsync(Timeout)
  except OSError as e:
    echo &"accepting connections from {dstIP}:{dstPorts[0].int} failed: ", e.msg
    failure = newException(PunchHoleError, e.msg)
  except CatchableError as e:
    failure = e
  if rawFdOpen:
    try:
      closeSocket(rawFd)
    except CatchableError:
      discard # best effort; the original failure is reported below
  if not attempt.isNil:
    await attempt.deleteFirewallRules()
    puncher.acceptAttempts.removeAttempt(attempt)
  if not failure.isNil:
    raise failure
  if attempt.future.finished():
    result = attempt.future.read()
  else:
    raise newException(PunchHoleError, "timeout")