## TCP hole punching via SYN injection: both sides send low-TTL SYNs whose
## sequence numbers are exchanged out of band; the accepting side injects the
## peer's SYNs into its own stack while iptables rules drop the ICMP
## time-exceeded replies those low-TTL SYNs provoke.

import asyncdispatch, asyncnet, strformat
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
import asyncutils
import ip_packet
import network_interface
import raw_socket

var IPPROTO_IP {.importc: "IPPROTO_IP", header: "<netinet/in.h>".}: cint
var IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint

type
  # State of one hole-punching attempt between (srcIp, srcPort) and dstIp.
  TcpSyniPuncher* = ref object
    srcIp: IpAddress            # local address used for every attempt
    srcPort: Port               # local port used for every attempt
    dstIp: IpAddress            # peer address
    dstPorts: seq[Port]         # peer ports to try (see `initPuncher`)
    seqNums: seq[uint32]        # peer SYN sequence numbers, used by `accept`
    firewallRules: seq[string]  # iptables rules inserted, not yet removed

  # Invoked by `connect` with the SYN sequence numbers it captured.
  PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}

  # Raised when inserting rules, connecting, accepting or injecting fails.
  PunchHoleError* = object of ValueError

proc makeFirewallRule(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress,
                      dstPort: Port): string =
  ## Builds iptables rule arguments (without the chain) that drop ICMP
  ## time-exceeded messages sent to `srcIp` which conntrack marks as RELATED
  ## to the TCP flow `srcIp:srcPort -> dstIp:dstPort`. `-w` makes iptables
  ## wait for the xtables lock.
  ##
  ## Kept on a single line so the resulting command string contains no
  ## backslash line continuations.
  result = fmt"-w -d {srcIp} -p icmp --icmp-type time-exceeded " &
    "-m conntrack --ctstate RELATED --ctproto tcp " &
    fmt"--ctorigsrc {srcIp} --ctorigsrcport {srcPort.int} " &
    fmt"--ctorigdst {dstIp} --ctorigdstport {dstPort.int} -j DROP"

proc iptablesInsert(chain: string, rule: string) {.async.} =
  ## Runs `iptables -I <chain> <rule>`; the command result is discarded.
  let firewallCmd = fmt"iptables -I {chain} {rule}"
  discard await asyncExecCmd(firewallCmd)

proc iptablesDelete(chain: string, rule: string) {.async.} =
  ## Runs `iptables -D <chain> <rule>`; the command result is discarded.
  let firewallCmd = fmt"iptables -D {chain} {rule}"
  discard await asyncExecCmd(firewallCmd)

proc closeQuietly(sock: AsyncSocket) =
  ## Closes `sock` if it is non-nil, ignoring OS errors from closing.
  if not sock.isNil:
    try:
      sock.close()
    except OSError:
      discard

proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
  ## Serializes the TCP packet `ipPacket` and sends it through the raw socket
  ## `rawFd` to the packet's own destination address and port.
  ##
  ## Raises `PunchHoleError` (after logging) if sending fails with OSError.
  assert(ipPacket.protocol == tcp)
  try:
    let packet = serialize(ipPacket)
    var destAddr: Sockaddr_storage
    var destAddrLen: SockLen
    toSockAddr(ipPacket.ipAddrDst, ipPacket.tcpPortDst, destAddr, destAddrLen)
    await rawFd.sendTo(packet.cstring, packet.len,
                       cast[ptr SockAddr](addr destAddr), destAddrLen)
    echo &"injected {ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc.int} -> " &
      &"{ipPacket.ipAddrDst}:{ipPacket.tcpPortDst.int} " &
      &"(seq {ipPacket.tcpSeqNumber})"
  except OSError as e:
    echo &"cannot inject {ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc.int} -> " &
      &"{ipPacket.ipAddrDst}:{ipPacket.tcpPortDst.int} " &
      &"(seq {ipPacket.tcpSeqNumber}): ", e.msg
    raise newException(PunchHoleError, e.msg)

proc isOutgoingToPeer(puncher: TcpSyniPuncher, parsed: IpPacket): bool =
  ## True when `parsed` is a TCP packet from the puncher's srcIp:srcPort to
  ## its dstIp on one of its dstPorts. The protocol is checked first so TCP
  ## fields are only read from TCP packets.
  if parsed.protocol == tcp and
      parsed.ipAddrSrc == puncher.srcIp and
      parsed.tcpPortSrc.int == puncher.srcPort.int and
      parsed.ipAddrDst == puncher.dstIp:
    for port in puncher.dstPorts:
      if parsed.tcpPortDst.int == port.int:
        return true
  result = false

proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD,
                       cb: PunchProgressCb) {.async.} =
  ## Reads ethernet frames from `rawFd` and collects the sequence numbers of
  ## our outgoing pure-SYN packets to the peer until there is one per
  ## destination port, or the capture socket returns no data. Then passes
  ## them to `cb`.
  ##
  ## The same SYN can be captured more than once (e.g. seen twice, or
  ## retransmitted with the same sequence number), so duplicates are ignored
  ## instead of being counted toward the target.
  # FIXME: timeout?
  var seqNums = newSeq[uint32]()
  while seqNums.len < puncher.dstPorts.len:
    let packet = await rawFd.recv(4000)
    if packet == "":
      break
    let parsed = parseEthernetPacket(packet)
    if puncher.isOutgoingToPeer(parsed) and parsed.tcpFlags == {SYN} and
        parsed.tcpSeqNumber notin seqNums:
      seqNums.add(parsed.tcpSeqNumber)
  await cb(seqNums)

proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD,
                         injectFd: AsyncFD) {.async.} =
  ## Waits for the first outgoing pure-ACK packet to the peer on `captureFd`.
  ## Re-sends it through `injectFd` with TTL 64 and then returns. Also stops
  ## if the capture socket returns no data.
  while true:
    let packet = await captureFd.recv(4000)
    if packet == "":
      break
    var parsed = parseEthernetPacket(packet)
    if puncher.isOutgoingToPeer(parsed) and parsed.tcpFlags == {ACK}:
      parsed.ipTTL = 64
      await injectFd.injectTcpPacket(parsed)
      return

proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: array[3, Port],
                  seqNums: seq[uint32] = @[]): TcpSyniPuncher =
  ## Creates a puncher that connects from `getPrimaryIPAddr(dstIp)`:`srcPort`
  ## to `dstIp`. `seqNums` are the peer's SYN sequence numbers, which
  ## `accept` injects.
  ##
  ## Port prediction is a placeholder: only `dstPorts[1]` is used. The
  ## predicted ports are three consecutive ports starting there, shifted
  ## down when needed so the last one does not exceed `uint16.high`.
  let localIp = getPrimaryIPAddr(dstIp)
  # TODO: do real port prediction
  var predictedDstPorts = newSeq[Port](3)
  let basePort = min(dstPorts[1].uint16,
                     uint16.high - (predictedDstPorts.len - 1).uint16)
  for i in 0 ..< predictedDstPorts.len:
    predictedDstPorts[i] = Port(basePort + i.uint16)
  result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                          dstPorts: predictedDstPorts, seqNums: seqNums)

proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} =
  ## Inserts one INPUT rule (see `makeFirewallRule`) per destination port and
  ## records each successfully inserted rule for `cleanup`.
  ##
  ## Raises `PunchHoleError` if inserting a rule fails with OSError; rules
  ## inserted before the failure stay recorded.
  for dstPort in puncher.dstPorts:
    let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort,
                                puncher.dstIp, dstPort)
    try:
      await iptablesInsert("INPUT", rule)
      puncher.firewallRules.add(rule)
    except OSError as e:
      echo "cannot add firewall rule: ", e.msg
      raise newException(PunchHoleError, e.msg)

proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
  ## Best-effort removal of every firewall rule recorded by the puncher.
  ##
  ## The recorded list is cleared first, so calling `cleanup` again (it is
  ## also called by `connect` and `accept`) does not delete the same rules a
  ## second time. OSErrors from deleting are ignored.
  let rules = puncher.firewallRules
  puncher.firewallRules = @[]
  for rule in rules:
    try:
      await iptablesDelete("INPUT", rule)
    except OSError:
      # At least we tried
      discard

proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress,
               dstPort: Port, future: Future[AsyncSocket],
               remaining: ref int) {.async.} =
  ## One attempt of `connectParallel`. Connects a socket bound to
  ## `srcIp:srcPort` to `dstIp:dstPort`, starting with TTL 2 and restoring
  ## TTL 64 once connected.
  ##
  ## The first successful attempt completes `future`; later winners close
  ## their socket, because completing a future twice raises. Every failure
  ## closes its socket and decrements `remaining`. When the count reaches
  ## zero while `future` is still pending, `future` fails with OSError so
  ## the caller never waits forever.
  var sock: AsyncSocket = nil
  try:
    sock = newAsyncSocket()
    sock.setSockOpt(OptReuseAddr, true)
    # Low TTL for the handshake; the ICMP time-exceeded replies this provokes
    # are dropped by the rules from `makeFirewallRule`.
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
    sock.bindAddr(srcPort, $srcIp)
    await sock.connect($dstIp, dstPort)
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 64)
    if future.finished:
      sock.closeQuietly()
    else:
      future.complete(sock)
  except OSError as e:
    echo &"connection {srcIp}:{srcPort.int} -> {dstIp}:{dstPort.int} " &
      "failed: ", e.msg
    sock.closeQuietly()
    dec remaining[]
    if remaining[] == 0 and not future.finished:
      future.fail(newException(OSError, "all connection attempts failed"))

proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] =
  ## Starts one `doConnect` per destination port. The returned future holds
  ## the first connected socket, or fails with OSError if every attempt
  ## fails or there are no destination ports.
  result = newFuture[AsyncSocket]("doConnect")
  if puncher.dstPorts.len == 0:
    result.fail(newException(OSError, "no destination ports to connect to"))
    return
  let remaining = new(int)
  remaining[] = puncher.dstPorts.len
  for dstPort in puncher.dstPorts:
    asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp,
                         dstPort, result, remaining)

proc connect*(puncher: TcpSyniPuncher,
              progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
  ## Connecting side of the punch. Starts capturing our outgoing SYN
  ## sequence numbers (reported through `progressCb`) and re-sending our
  ## first outgoing ACK with TTL 64. Then inserts the firewall rules and
  ## connects to all destination ports in parallel.
  ##
  ## Returns the first connected socket. Firewall rules are removed on
  ## success and on failure. Raises `PunchHoleError` on failure.
  # TODO(review): the three raw fds below are never closed.
  let iface = fromIpAddress(puncher.srcIp)
  let captureSeqFd = setupEthernetCapturingSocket(iface)
  let captureAckFd = setupEthernetCapturingSocket(iface)
  let injectAckFd = setupTcpInjectingSocket()
  asyncCheck puncher.captureSeqNumbers(captureSeqFd, progressCb)
  asyncCheck puncher.captureAndResendAck(captureAckFd, injectAckFd)
  try:
    await puncher.addFirewallRules()
    result = await puncher.connectParallel()
  except OSError as e:
    await puncher.cleanup()
    raise newException(PunchHoleError, e.msg)
  except PunchHoleError as e:
    # Partially inserted rules must not outlive a failed attempt.
    await puncher.cleanup()
    raise e
  await puncher.cleanup()

proc prepareAccept(puncher: TcpSyniPuncher) {.async.} =
  ## Before accepting, makes one TTL-2 connection attempt from
  ## srcIp:srcPort to each destination port, one after another. Failures
  ## are ignored, and every socket is closed afterwards.
  # FIXME: timeouts
  for dstPort in puncher.dstPorts:
    var sock: AsyncSocket = nil
    try:
      sock = newAsyncSocket()
      sock.setSockOpt(OptReuseAddr, true)
      sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
      sock.bindAddr(puncher.srcPort, $(puncher.srcIp))
      await sock.connect($(puncher.dstIp), dstPort)
      echo "connected during accept phase"
    except OSError:
      discard
    sock.closeQuietly()

proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} =
  ## Accepting side of the punch. Runs `prepareAccept` and inserts the
  ## firewall rules. Then listens on srcIp:srcPort and injects a SYN from
  ## every destination port with every known peer sequence number into the
  ## local stack.
  ##
  ## Returns the first accepted connection. The listening socket is closed
  ## and the firewall rules are removed on success and on failure. Raises
  ## `PunchHoleError` on failure.
  await puncher.prepareAccept()
  var listenSock: AsyncSocket = nil
  try:
    await puncher.addFirewallRules()
    # Listen before injecting, so a listener already exists when the
    # injected SYNs reach the local stack.
    listenSock = newAsyncSocket()
    listenSock.setSockOpt(OptReuseAddr, true)
    listenSock.bindAddr(puncher.srcPort, $(puncher.srcIp))
    listenSock.listen()
    # FIXME: timeout
    # TODO(review): rawFd is never closed.
    let rawFd = setupTcpInjectingSocket()
    for dstPort in puncher.dstPorts:
      for seqNum in puncher.seqNums:
        let ipPacket = IpPacket(protocol: tcp,
                                ipAddrSrc: puncher.dstIp,
                                ipAddrDst: puncher.srcIp,
                                ipTTL: 64,
                                tcpPortSrc: dstPort,
                                tcpPortDst: puncher.srcPort,
                                tcpSeqNumber: seqNum,
                                tcpAckNumber: 0,
                                tcpFlags: {SYN},
                                tcpWindowSize: 1452 * 10)
        # Awaited, so an injection failure surfaces here as PunchHoleError
        # instead of escaping through asyncCheck.
        await rawFd.injectTcpPacket(ipPacket)
    echo &"accepting connections from {puncher.dstIp}:" &
      &"{puncher.dstPorts[0].int}"
    result = await listenSock.accept()
  except OSError as e:
    echo &"accepting connections from {puncher.dstIp}:" &
      &"{puncher.dstPorts[0].int} failed: ", e.msg
    listenSock.closeQuietly()
    await puncher.cleanup()
    raise newException(PunchHoleError, e.msg)
  except PunchHoleError as e:
    listenSock.closeQuietly()
    await puncher.cleanup()
    raise e
  listenSock.closeQuietly()
  await puncher.cleanup()