diff --git a/punchd.nim b/punchd.nim index 70e55c8..cea9779 100644 --- a/punchd.nim +++ b/punchd.nim @@ -47,18 +47,18 @@ proc handleRequest(line: string, unixSock: AsyncSocket) {.async.} = $req.dstIp, req.dstPorts.join(","), seqNumbers.join(",")].join("|") await unixSock.send(&"progress|{id}|{content}\n") - puncher = await initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts) + puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts) sock = await puncher.connect(handleSeqNumbers) of "tcp-syni-accept": let req = parseMessage[TcpSyniAccept](args[2]) - puncher = await initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts, - req.seqNums) + puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts, + req.seqNums) sock = await puncher.accept() else: raise newException(ValueError, "invalid request") - + let unixFd = unixSock.getFd.AsyncFD await unixFd.asyncSendMsg(&"ok|{id}\n", @[fromFd(sock.getFd.AsyncFD)]) diff --git a/tcp_syni.nim b/tcp_syni.nim index 8b177ae..dcfa061 100644 --- a/tcp_syni.nim +++ b/tcp_syni.nim @@ -10,20 +10,21 @@ var IPPROTO_IP {.importc: "IPPROTO_IP", header: "".}: cint var IP_TTL {.importc: "IP_TTL", header: "".}: cint type - TcpSyniPuncher* = object + TcpSyniPuncher* = ref object srcIp: IpAddress srcPort: Port dstIp: IpAddress dstPorts: seq[Port] seqNums: seq[uint32] + firewallRules: seq[string] PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.} PunchHoleError* = object of ValueError -proc addFirewallRule(srcIp: IpAddress, srcPort: Port, - dstIp: IpAddress, dstPort: Port) {.async.} = - let firewall_cmd = fmt"""iptables -I INPUT \ +proc makeFirewallRule(srcIp: IpAddress, srcPort: Port, + dstIp: IpAddress, dstPort: Port): string = + result = fmt"""-w \ -d {srcIp} \ -p icmp \ --icmp-type time-exceeded \ @@ -35,22 +36,13 @@ proc addFirewallRule(srcIp: IpAddress, srcPort: Port, --ctorigdst {dstIp} \ --ctorigdstport {dstPort.int} \ -j DROP""" + +proc iptablesInsert(chain: string, rule: string) {.async.} = + let firewall_cmd = fmt"iptables -I {chain} {rule}" discard await asyncExecCmd(firewall_cmd) -proc delFirewallRule(srcIp: IpAddress, srcPort: Port, - dstIp: IpAddress, dstPort: Port) {.async.} = - let firewall_cmd = fmt"""iptables -D INPUT \ - -d {srcIp} \ - -p icmp \ - --icmp-type time-exceeded \ - -m conntrack \ - --ctstate RELATED \ - --ctproto tcp \ - --ctorigsrc {srcIp} \ - --ctorigsrcport {srcPort.int} \ - --ctorigdst {dstIp} \ - --ctorigdstport {dstPort.int} \ - -j DROP""" +proc iptablesDelete(chain: string, rule: string) {.async.} = + let firewall_cmd = fmt"iptables -D {chain} {rule}" discard await asyncExecCmd(firewall_cmd) proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD, @@ -98,7 +90,7 @@ proc injectSyns(rawFd: AsyncFD, srcIp: IpAddress, srcPort: Port, echo "cannot inject {srcIp}:{srcPort.int} -> {dstIp}:{dstPort.int} (seq {seqNum}): ", e.msg proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: array[3, Port], - seqNums: seq[uint32] = @[]): Future[TcpSyniPuncher] {.async.} = + seqNums: seq[uint32] = @[]): TcpSyniPuncher = let localIp = getPrimaryIPAddr(dstIp) # TODO: do real port prediction var predictedDstPorts = newSeq[Port](3) @@ -109,31 +101,58 @@ proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: array[3, Port], result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp, dstPorts: predictedDstPorts, seqNums: seqNums) -proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, - dstPort: Port, future: Future[AsyncSocket]) {.async.} = +proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} = + for dstPort in puncher.dstPorts: + let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort, + puncher.dstIp, dstPort) + try: + await iptablesInsert("INPUT", rule) + puncher.firewallRules.add(rule) + except OSError as e: + echo "cannot add firewall rule: ", e.msg + raise newException(PunchHoleError, e.msg) + +proc cleanup*(puncher: TcpSyniPuncher) {.async.} = + for rule in puncher.firewallRules: + try: + await iptablesDelete("INPUT", rule) + except OSError: + # At least we tried + discard + +proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port, + future: Future[AsyncSocket]) {.async.} = let sock = newAsyncSocket() sock.setSockOpt(OptReuseAddr, true) sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) sock.bindAddr(srcPort, $srcIp) - try: - await addFirewallRule(srcIp, srcPort, dstIp, dstPort) - except OSError as e: - echo "cannot add firewall rule: ", e.msg - return try: await sock.connect($dstIp, dstPort) future.complete(sock) except OSError as e: echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg discard - try: - await delFirewallRule(srcIp, srcPort, dstIp, dstPort) - except OSError as e: - echo "cannot delete firewall rule: ", e.msg -proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = +proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] = + result = newFuture[AsyncSocket]("doConnect") + for dstPort in puncher.dstPorts: + asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp, dstPort, result) + +proc connect*(puncher: TcpSyniPuncher, + progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} = + let iface = fromIpAddress(puncher.srcIp) + let rawFd = setupEthernetCapturingSocket(iface) + asyncCheck puncher.captureSeqNumbers(rawFd, progressCb) + await puncher.addFirewallRules() + try: + result = await puncher.connectParallel() + await puncher.cleanup() + except OSError as e: + raise newException(PunchHoleError, e.msg) + +proc prepareAccept(puncher: TcpSyniPuncher) {.async.} = + # FIXME: timeouts for dstPort in puncher.dstPorts: - # TODO: connect in parallel for better performance try: let sock = newAsyncSocket() sock.setSockOpt(OptReuseAddr, true) @@ -144,12 +163,10 @@ proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = sock.close() except OSError: discard - try: - await addFirewallRule(puncher.srcIp, puncher.srcPort, puncher.dstIp, - dstPort) - except OSError as e: - echo "cannot add firewall rule: ", e.msg - return + +proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} = + await puncher.prepareAccept() + await puncher.addFirewallRules() try: # FIXME: timeout let rawFd = setupTcpInjectingSocket() @@ -161,28 +178,9 @@ proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = sock.bindAddr(puncher.srcPort, $(puncher.srcIp)) sock.listen() echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int}" - let connectedSock = await sock.accept() - future.complete(connectedSock) + result = await sock.accept() + await puncher.cleanup() except OSError as e: echo &"accepting connections from {puncher.dstIP}:{puncher.dstPorts[0].int} failed: ", e.msg - discard - for dstPort in puncher.dstPorts: - try: - await delFirewallRule(puncher.srcIp, puncher.srcPort, puncher.dstIp, - dstPort) - except OSError as e: - echo "cannot delete firewall rule: ", e.msg - -proc connect*(puncher: TcpSyniPuncher, - progressCb: PunchProgressCb): Future[AsyncSocket] = - result = newFuture[AsyncSocket]("tcp_syni.connect") - let iface = fromIpAddress(puncher.srcIp) - let rawFd = setupEthernetCapturingSocket(iface) - asyncCheck puncher.captureSeqNumbers(rawFd, progressCb) - for dstPort in puncher.dstPorts: - asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp, - dstPort, result) - -proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] = - result = newFuture[AsyncSocket]("tcp_syni.accept") - asyncCheck puncher.doAccept(result) + await puncher.cleanup() + raise newException(PunchHoleError, e.msg)