diff --git a/punchd.nim b/punchd.nim index b71bc07..70e55c8 100644 --- a/punchd.nim +++ b/punchd.nim @@ -61,11 +61,9 @@ proc handleRequest(line: string, unixSock: AsyncSocket) {.async.} = let unixFd = unixSock.getFd.AsyncFD await unixFd.asyncSendMsg(&"ok|{id}\n", @[fromFd(sock.getFd.AsyncFD)]) - await puncher.cleanup except PunchHoleError as e: await unixSock.send(&"error|{id}|{e.msg}\n") - await puncher.cleanup except ValueError: unixSock.close diff --git a/tcp_syni.nim b/tcp_syni.nim index efde0a8..8b177ae 100644 --- a/tcp_syni.nim +++ b/tcp_syni.nim @@ -35,10 +35,7 @@ proc addFirewallRule(srcIp: IpAddress, srcPort: Port, --ctorigdst {dstIp} \ --ctorigdstport {dstPort.int} \ -j DROP""" - try: - discard await asyncExecCmd(firewall_cmd) - except OSError: - raise newException(PunchHoleError, "cannot add firewall rule") + discard await asyncExecCmd(firewall_cmd) proc delFirewallRule(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port) {.async.} = @@ -54,10 +51,7 @@ proc delFirewallRule(srcIp: IpAddress, srcPort: Port, --ctorigdst {dstIp} \ --ctorigdstport {dstPort.int} \ -j DROP""" - try: - discard await asyncExecCmd(firewall_cmd) - except OSError: - raise newException(PunchHoleError, "cannot delete firewall rule") + discard await asyncExecCmd(firewall_cmd) proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD, cb: PunchProgressCb) {.async.} = @@ -114,12 +108,6 @@ proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: array[3, Port], predictedDstPorts[i] = Port(basePort + i.uint16) result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp, dstPorts: predictedDstPorts, seqNums: seqNums) - for dstPort in result.dstPorts: - await addFirewallRule(result.srcIp, result.srcPort, result.dstIp, dstPort) - -proc cleanup*(puncher: TcpSyniPuncher) {.async.} = - for dstPort in puncher.dstPorts: - await delFirewallRule(puncher.srcIp, puncher.srcPort, puncher.dstIp, dstPort) proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port, future: Future[AsyncSocket]) {.async.} = @@ -127,12 +115,21 @@ proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, sock.setSockOpt(OptReuseAddr, true) sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) sock.bindAddr(srcPort, $srcIp) + try: + await addFirewallRule(srcIp, srcPort, dstIp, dstPort) + except OSError as e: + echo "cannot add firewall rule: ", e.msg + return try: await sock.connect($dstIp, dstPort) future.complete(sock) except OSError as e: echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg discard + try: + await delFirewallRule(srcIp, srcPort, dstIp, dstPort) + except OSError as e: + echo "cannot delete firewall rule: ", e.msg proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = for dstPort in puncher.dstPorts: @@ -147,6 +144,12 @@ proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = sock.close() except OSError: discard + try: + await addFirewallRule(puncher.srcIp, puncher.srcPort, puncher.dstIp, + dstPort) + except OSError as e: + echo "cannot add firewall rule: ", e.msg + return try: # FIXME: timeout let rawFd = setupTcpInjectingSocket() @@ -163,6 +166,12 @@ proc doAccept(puncher: TcpSyniPuncher, future: Future[AsyncSocket]) {.async.} = except OSError as e: echo &"accepting connections from {puncher.dstIP}:{puncher.dstPorts[0].int} failed: ", e.msg discard + for dstPort in puncher.dstPorts: + try: + await delFirewallRule(puncher.srcIp, puncher.srcPort, puncher.dstIp, + dstPort) + except OSError as e: + echo "cannot delete firewall rule: ", e.msg proc connect*(puncher: TcpSyniPuncher, progressCb: PunchProgressCb): Future[AsyncSocket] =