diff --git a/punchd.nim b/punchd.nim index 8802d2f..61921bf 100644 --- a/punchd.nim +++ b/punchd.nim @@ -13,7 +13,7 @@ const PunchdSocket = "/tmp/punchd.socket" type Punchd = ref object unixSocket: AsyncSocket - punchers: seq[TcpSyniPuncher] + tcpSyniPuncher: TcpSyniPuncher Sigint = object of CatchableError @@ -52,16 +52,13 @@ proc handleRequest(punchd: Punchd, line: string, $req.dstIp, req.dstPorts.join(","), seqNumbers.join(",")].join("|") await unixSock.send(&"progress|{id}|{content}\n") - puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts) - punchd.punchers.add(puncher) - sock = await puncher.connect(handleSeqNumbers) - + sock = await punchd.tcpSyniPuncher.connect(req.srcPorts[0], req.dstIp, + req.dstPorts, handleSeqNumbers) + of "tcp-syni-accept": let req = parseMessage[TcpSyniAccept](args[2]) - puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts, - req.seqNums) - punchd.punchers.add(puncher) - sock = await puncher.accept() + sock = await punchd.tcpSyniPuncher.accept(req.srcPorts[0], req.dstIp, + req.dstPorts, req.seqNums) else: raise newException(ValueError, "invalid request") @@ -91,19 +88,19 @@ proc handleUsers(punchd: Punchd) {.async.} = proc main() = setControlCHook(handleSigint) removeFile(PunchdSocket) - let punchd = Punchd(unixSocket: newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP)) - punchd.unixSocket.bindUnix(PunchdSocket) + let unixSocket = newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP) + unixSocket.bindUnix(PunchdSocket) + unixSocket.listen() setFilePermissions(PunchdSocket, {fpUserRead, fpUserWrite, fpGroupRead, fpGroupWrite, fpOthersRead, fpOthersWrite}) - punchd.unixSocket.listen() + let punchd = Punchd(unixSocket: unixSocket, tcpSyniPuncher: initPuncher()) asyncCheck handleUsers(punchd) try: runForever() except Sigint: - for puncher in punchd.punchers: - waitFor puncher.cleanup() + waitFor punchd.tcpSyniPuncher.cleanup() punchd.unixSocket.close() removeFile(PunchdSocket) quit(0) diff --git a/tcp_syni.nim b/tcp_syni.nim index ab80994..90146e4 100644 --- a/tcp_syni.nim +++ b/tcp_syni.nim @@ -1,6 +1,7 @@ import asyncfutures, asyncdispatch, asyncnet, strformat -from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr +from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr, parseIpAddress from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt +from sequtils import any import asyncutils import ip_packet import network_interface @@ -12,13 +13,25 @@ var IP_TTL {.importc: "IP_TTL", header: "".}: cint const Timeout = 3000 type - TcpSyniPuncher* = ref object + ConnectAttempt* = ref object + srcIp: IpAddress + srcPort: Port + dstIp: IpAddress + dstPorts: seq[Port] + firewallRules: seq[string] + + AcceptAttempt* = ref object srcIp: IpAddress srcPort: Port dstIp: IpAddress dstPorts: seq[Port] seqNums: seq[uint32] firewallRules: seq[string] + future: Future[AsyncSocket] + + TcpSyniPuncher* = ref object + connectAttempts: seq[ConnectAttempt] + acceptAttempts: seq[AcceptAttempt] PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.} @@ -47,6 +60,26 @@ proc iptablesDelete(chain: string, rule: string) {.async.} = let firewall_cmd = fmt"iptables -D {chain} {rule}" discard await asyncExecCmd(firewall_cmd) +proc addFirewallRules[T](attempt: T) {.async.} = + for dstPort in attempt.dstPorts: + let rule = makeFirewallRule(attempt.srcIp, attempt.srcPort, + attempt.dstIp, dstPort) + try: + await iptablesInsert("INPUT", rule) + attempt.firewallRules.add(rule) + except OSError as e: + echo "cannot add firewall rule: ", e.msg + raise newException(PunchHoleError, e.msg) + +proc deleteFirewallRules[T](attempt: T) {.async.} = + for rule in attempt.firewallRules: + # FIXME: close sock? + try: + await iptablesDelete("INPUT", rule) + except OSError: + # At least we tried + discard + proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} = assert(ipPacket.protocol == tcp) try: @@ -59,27 +92,27 @@ proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} = except OSError as e: raise newException(PunchHoleError, e.msg) -proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD, +proc captureSeqNumbers(attempt: ConnectAttempt, rawFd: AsyncFD, cb: PunchProgressCb) {.async.} = # FIXME: timeout? var seqNums = newSeq[uint32]() - while seqNums.len < puncher.dstPorts.len: + while seqNums.len < attempt.dstPorts.len: let packet = await rawFd.recv(4000) if packet == "": break let parsed = parseEthernetPacket(packet) if parsed.protocol == tcp and - parsed.ipAddrSrc == puncher.srcIp and - parsed.tcpPortSrc.int == puncher.srcPort.int and - parsed.ipAddrDst == puncher.dstIp and + parsed.ipAddrSrc == attempt.srcIp and + parsed.tcpPortSrc.int == attempt.srcPort.int and + parsed.ipAddrDst == attempt.dstIp and parsed.tcpFlags == {SYN}: - for port in puncher.dstPorts: + for port in attempt.dstPorts: if parsed.tcpPortDst.int == port.int: seqNums.add(parsed.tcpSeqNumber) break await cb(seqNums) -proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD, +proc captureAndResendAck(attempt: ConnectAttempt, captureFd: AsyncFD, injectFd: AsyncFD) {.async.} = while true: let packet = await captureFd.recv(4000) @@ -87,47 +120,56 @@ proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD, break var parsed = parseEthernetPacket(packet) if parsed.protocol == tcp and - parsed.ipAddrSrc == puncher.srcIp and - parsed.tcpPortSrc.int == puncher.srcPort.int and - parsed.ipAddrDst == puncher.dstIp and + parsed.ipAddrSrc == attempt.srcIp and + parsed.tcpPortSrc.int == attempt.srcPort.int and + parsed.ipAddrDst == attempt.dstIp and parsed.tcpFlags == {ACK}: - for port in puncher.dstPorts: + for port in attempt.dstPorts: if parsed.tcpPortDst.int == port.int: parsed.ipTTL = 64 echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}" await injectFd.injectTcpPacket(parsed) return -proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: seq[Port], - seqNums: seq[uint32] = @[]): TcpSyniPuncher = - let localIp = getPrimaryIPAddr(dstIp) - # TODO: do real port prediction - var predictedDstPorts = newSeq[Port](1) - let basePort = min(dstPorts[1].uint16, - uint16.high - (predictedDstPorts.len - 1).uint16) - for i in 0 .. predictedDstPorts.len - 1: - predictedDstPorts[i] = Port(basePort + i.uint16) - result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp, - dstPorts: predictedDstPorts, seqNums: seqNums) +proc initPuncher*(): TcpSyniPuncher = TcpSyniPuncher() -proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} = - for dstPort in puncher.dstPorts: - let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort, - puncher.dstIp, dstPort) - try: - await iptablesInsert("INPUT", rule) - puncher.firewallRules.add(rule) - except OSError as e: - echo "cannot add firewall rule: ", e.msg - raise newException(PunchHoleError, e.msg) +proc findConnectAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress, + srcPort: Port, dstIp: IpAddress, + dstPorts: seq[Port]): int = + for (index, attempt) in puncher.connectAttempts.pairs(): + if attempt.srcIp == srcIp and attempt.srcPort == srcPort and + attempt.dstIp == dstIp and + attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts): + return index + +proc findAcceptAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress, + srcPort: Port, dstIp: IpAddress, + dstPorts: seq[Port]): int = + for (index, attempt) in puncher.acceptAttempts.pairs(): + if attempt.srcIp == srcIp and attempt.srcPort == srcPort and + attempt.dstIp == dstIp and + attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts): + return index + +proc findAcceptAttemptsByLocalAddr(puncher: TcpSyniPuncher, address: IpAddress, + port: Port): seq[AcceptAttempt] = + for attempt in puncher.acceptAttempts: + if attempt.srcIp == address and attempt.srcPort == port: + result.add(attempt) + +proc predictPortRange(dstPorts: seq[Port]): seq[Port] = + # TODO: do real port prediction + result = newSeq[Port](1) + let basePort = min(dstPorts[1].uint16, + uint16.high - (result.len - 1).uint16) + for i in 0 .. result.len - 1: + result[i] = Port(basePort + i.uint16) proc cleanup*(puncher: TcpSyniPuncher) {.async.} = - for rule in puncher.firewallRules: - try: - await iptablesDelete("INPUT", rule) - except OSError: - # At least we tried - discard + for attempt in puncher.connectAttempts: + await attempt.deleteFirewallRules() + for attempt in puncher.acceptAttempts: + await attempt.deleteFirewallRules() proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port, future: Future[AsyncSocket]) {.async.} = @@ -144,39 +186,45 @@ proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port, echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg discard -proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] = - result = newFuture[AsyncSocket]("doConnect") - for dstPort in puncher.dstPorts: - asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp, dstPort, result) - -proc connect*(puncher: TcpSyniPuncher, +proc connect*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress, + dstPorts: seq[Port], progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} = - let iface = fromIpAddress(puncher.srcIp) + let localIp = getPrimaryIPAddr(dstIp) + if puncher.findConnectAttempt(localIp, srcPort, dstIp, dstPorts) != -1: + raise newException(PunchHoleError, "hole punching for given parameters already active") + let attempt = ConnectAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp, + dstPorts: predictPortRange(dstPorts)) + puncher.connectAttempts.add(attempt) + await attempt.addFirewallRules() + let iface = fromIpAddress(attempt.srcIp) let captureSeqFd = setupEthernetCapturingSocket(iface) let captureAckFd = setupEthernetCapturingSocket(iface) let injectAckFd = setupTcpInjectingSocket() - asyncCheck puncher.captureSeqNumbers(captureSeqFd, progressCb) - asyncCheck puncher.captureAndResendAck(captureAckFd, injectAckFd) - await puncher.addFirewallRules() + asyncCheck attempt.captureSeqNumbers(captureSeqFd, progressCb) + asyncCheck attempt.captureAndResendAck(captureAckFd, injectAckFd) try: - let connectParallelFuture = puncher.connectParallel() - await connectParallelFuture or sleepAsync(Timeout) - await puncher.cleanup() - if connectParallelFuture.finished(): - result = connectParallelFuture.read() + let connectFuture = newFuture[AsyncSocket]("connect") + for dstPort in attempt.dstPorts: + asyncCheck doConnect(attempt.srcIp, attempt.srcPort, attempt.dstIp, + dstPort, connectfuture) + await connectFuture or sleepAsync(Timeout) + await attempt.deleteFirewallRules() + puncher.connectAttempts.del(puncher.connectAttempts.find(attempt)) + if connectFuture.finished(): + result = connectFuture.read() else: raise newException(PunchHoleError, "timeout") except OSError as e: raise newException(PunchHoleError, e.msg) -proc prepareAccept(puncher: TcpSyniPuncher) {.async.} = - for dstPort in puncher.dstPorts: +proc prepareAccept(attempt: AcceptAttempt) {.async.} = + for dstPort in attempt.dstPorts: try: let sock = newAsyncSocket() sock.setSockOpt(OptReuseAddr, true) sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) - sock.bindAddr(puncher.srcPort, $(puncher.srcIp)) - let connectFuture = sock.connect($(puncher.dstIp), dstPort) + sock.bindAddr(attempt.srcPort, $(attempt.srcIp)) + let connectFuture = sock.connect($(attempt.dstIp), dstPort) await connectFuture or sleepAsync(Timeout) if connectFuture.finished(): echo "connected during accept phase" @@ -184,38 +232,68 @@ proc prepareAccept(puncher: TcpSyniPuncher) {.async.} = except OSError: discard -proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} = - await puncher.prepareAccept() - await puncher.addFirewallRules() - try: - let rawFd = setupTcpInjectingSocket() - for dstPort in puncher.dstPorts: - for seqNum in puncher.seqNums: - let ipPacket = IpPacket(protocol: tcp, - ipAddrSrc: puncher.dstIp, - ipAddrDst: puncher.srcIp, - ipTTL: 64, - tcpPortSrc: dstPort, - tcpPortDst: puncher.srcPort, - tcpSeqNumber: seqNum, - tcpAckNumber: 0, - tcpFlags: {SYN}, - tcpWindowSize: 1452 * 10) - echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN" - asyncCheck rawFd.injectTcpPacket(ipPacket) - let sock = newAsyncSocket() - sock.setSockOpt(OptReuseAddr, true) - sock.bindAddr(puncher.srcPort, $(puncher.srcIp)) - sock.listen() - echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int}" +proc doAccept(puncher: TcpSyniPuncher, srcIp: IpAddress, + srcPort: Port) {.async.} = + let sock = newAsyncSocket() + sock.setSockOpt(OptReuseAddr, true) + sock.bindAddr(srcPort, $(srcIp)) + sock.listen() + while true: let acceptFuture = sock.accept() await acceptFuture or sleepAsync(Timeout) - await puncher.cleanup() if acceptFuture.finished(): - result = acceptFuture.read() + let peer = acceptFuture.read() + let (peerAddr, peerPort) = peer.getPeerAddr() + let peerIp = parseIpAddress(peerAddr) + let i = puncher.findAcceptAttempt(srcIp, srcPort, peerIp, @[peerPort]) + if i == -1: + echo "Accepted connection, but no attempt found. Discarding." + else: + let attempt = puncher.acceptAttempts[i] + attempt.future.complete(peer) + else: + let attempts = puncher.findAcceptAttemptsByLocalAddr(srcIp, srcPort) + # FIXME: should attempts have timestamps, so we can decide here which ones to delete? + if attempts.len() <= 1: + break + +proc accept*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress, + dstPorts: seq[Port], + seqNums: seq[uint32]): Future[AsyncSocket] {.async.} = + let localIp = getPrimaryIPAddr(dstIp) + let existingAttempts = puncher.findAcceptAttemptsByLocalAddr(localIp, srcPort) + if existingAttempts.len() == 0: + echo &"accepting connections from {dstIp}:{dstPorts[0].int}" + asyncCheck puncher.doAccept(localIp, srcPort) + else: + for a in existingAttempts: + if a.dstIp == dstIp and + a.dstPorts.any(proc (p: Port): bool = p in dstPorts): + raise newException(PunchHoleError, "hole punching for given parameters already active") + let attempt = AcceptAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp, + dstPorts: dstPorts, seqNums: seqNums, + future: newFuture[AsyncSocket]("accept")) + puncher.acceptAttempts.add(attempt) + await attempt.addFirewallRules() + await attempt.prepareAccept() + try: + let rawFd = setupTcpInjectingSocket() + for dstPort in attempt.dstPorts: + for seqNum in attempt.seqNums: + let ipPacket = IpPacket(protocol: tcp, ipAddrSrc: attempt.dstIp, + ipAddrDst: attempt.srcIp, ipTTL: 64, + tcpPortSrc: dstPort, + tcpPortDst: attempt.srcPort, + tcpSeqNumber: seqNum, tcpAckNumber: 0, + tcpFlags: {SYN}, tcpWindowSize: 1452 * 10) + echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN" + asyncCheck rawFd.injectTcpPacket(ipPacket) + await attempt.future or sleepAsync(Timeout) + await attempt.deleteFirewallRules() + puncher.acceptAttempts.del(puncher.acceptAttempts.find(attempt)) + if attempt.future.finished(): + result = attempt.future.read() else: raise newException(PunchHoleError, "timeout") except OSError as e: - echo &"accepting connections from {puncher.dstIP}:{puncher.dstPorts[0].int} failed: ", e.msg - await puncher.cleanup() - raise newException(PunchHoleError, e.msg) + echo &"accepting connections from {dstIP}:{dstPorts[0].int} failed: ", e.msg