diff --git a/puncher.nim b/puncher.nim index a57f75d..4afcc0f 100644 --- a/puncher.nim +++ b/puncher.nim @@ -6,21 +6,20 @@ from nativesockets import SockLen, getSockOptInt, setSockOptInt -from sequtils import any +from sequtils import any, map type Attempt* = object ## A hole punching attempt. - srcPorts*: seq[Port] + socks*: seq[AsyncSocket] dstIp*: IpAddress dstPorts*: seq[Port] future*: Future[(AsyncSocket, Port)] Puncher* = ref object - socks: seq[AsyncSocket] - srcPorts: seq[Port] - natProps: NatProperties - attempts: seq[Attempt] + socks*: seq[AsyncSocket] + natProps*: NatProperties + attempts*: seq[Attempt] PunchHoleError* = object of ValueError @@ -28,50 +27,73 @@ var IPPROTO_IP {.importc: "IPPROTO_IP", header: "".}: cint var IP_TTL {.importc: "IP_TTL", header: "".}: cint const Timeout = 3000 +const InitiatorMaxSockCount = 1000 +const ResponderMaxSockCount = 70 +const MaxSockCount = max(InitiatorMaxSockCount, ResponderMaxSockCount) + +proc srcPort(sock: AsyncSocket): Port = + result = sock.getLocalAddr[1] proc `==`(a, b: Attempt): bool = ## ``==`` for hole punching attempts. ## ## Two hole punching attempts are considered equal if their ``dstIp`` is - ## equal, their ``srcPorts`` overlap and their ``dstPorts`` overlap. - a.dstIp == b.dstIp and - a.srcPorts.any(proc (p: Port): bool = p in b.srcPorts) and - a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) + ## equal and their ``dstPorts`` overlap. + a.dstIp == b.dstIp and a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) -proc initPuncher*(socks: seq[AsyncSocket], probedSrcPorts: seq[Port]): Puncher = - assert(socks.len > 0) - var srcPorts = newSeq[Port](socks.len) - for i in 0 .. socks.len - 1: - let (_, srcPort) = socks[i].getLocalAddr() - srcPorts[i] = srcPort - let natProps = getNatProperties(srcPorts[0], probedSrcPorts) - Puncher(socks: socks, srcPorts: srcPorts, natProps: natProps) +proc initPuncher*(sock: AsyncSocket, probedSrcPorts: seq[Port]): Puncher = + # TODO: determine IP_TTL + let (_, primarySrcPort) = sock.getLocalAddr() + let natProps = getNatProperties(primarySrcPort, probedSrcPorts) + result = Puncher(socks: @[sock], natProps: natProps) + if result.natProps.natType == SymmetricRandom: + # our NAT is of the evil symmetric type with random port allocation. We are + # trying to help the other peer by allocating a lot of auxillary sockets + # for punching more holes + result.socks.setLen(MaxSockCount) + for i in 1 .. MaxSockCount - 1: + result.socks[i] = newAsyncSocket(sockType = SOCK_DGRAM, + protocol = IPPROTO_UDP, buffered = false) + result.socks[i].bindAddr(Port(0)) +proc primarySrcPort*(puncher: Puncher): Port = + puncher.socks[0].srcPort + +# TODO: lowTTL -> isInitiating, if isInitiating: punch with all auxSocks, else only use 70 proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, - peerProbedPorts: seq[Port], lowTTL: bool, msg: string): + peerProbedPorts: seq[Port], isInitiating: bool, msg: string): Future[Attempt] {.async.} = let punchFuture = newFuture[(AsyncSocket, Port)]("punch") - let natProps = getNatProperties(peerPort, peerProbedPorts) - let predictedDstPorts = predictPortRange(natProps) - result = Attempt(srcPorts: @[puncher.srcPorts[0]], dstIp: peerIp, - dstPorts: predictedDstPorts, future: punchFuture) - if puncher.natProps.natType == SymmetricRandom and puncher.srcPorts.len > 1: - # our NAT is of the evil symmetric type with random port allocation. We are + let peerNatProps = getNatProperties(peerPort, peerProbedPorts) + var sockCount = 1 + if puncher.natProps.natType == SymmetricRandom: + # Our NAT is of the evil symmetric type with random port allocation. We are # trying to help the other peer by punching more holes using all our # sockets. - result.srcPorts.add(puncher.srcPorts[1 .. ^1]) + if peerNatProps.natType == SymmetricRandom: + # If the other peer is behind a SymmetricRandom NAT too we give up. + raise newException(PunchHoleError, + "both peers behind symmetric NAT with random port allocation") + sockCount = if isInitiating: + InitiatorMaxSockCount + else: + ResponderMaxSockCount + let predictedDstPorts = predictPortRange(peerNatProps) + result = Attempt(dstIp: peerIp, dstPorts: predictedDstPorts, + future: punchFuture) if puncher.attempts.contains(result): raise newException(PunchHoleError, - "hole punching for given parameters already active") + "hole punching to given destination already active") puncher.attempts.add(result) - echo &"sending msg {msg} to {peerIp}, srcPorts: {result.srcPorts}, dstPorts: {result.dstPorts}" + let srcPorts = puncher.socks[0 .. sockCount - 1].map(srcPort) + echo &"sending msg {msg} to {peerIp}, srcPorts: {srcPorts}, dstPorts: {result.dstPorts}" var peerAddr: Sockaddr_storage var peerSockLen: SockLen try: var defaultTTL: int - for i in 0 .. result.srcPorts.len - 1: + for i in 0 .. sockCount - 1: let sock = puncher.socks[i] - if lowTTL: + if isInitiating: defaultTTL = sock.getFd.getSockOptInt(IPPROTO_IP, IP_TTL) sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2) for dstPort in result.dstPorts: @@ -79,7 +101,7 @@ proc punch(puncher: Puncher, peerIp: IpAddress, peerPort: Port, # TODO: replace asyncdispatch.sendTo with asyncnet.sendTo (Nim 1.4 required) await sendTo(sock.getFd.AsyncFD, msg.cstring, msg.len, cast[ptr SockAddr](addr peerAddr), peerSockLen) - if lowTTL: + if isInitiating: sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, defaultTTL) except OSError as e: raise newException(PunchHoleError, e.msg) @@ -107,8 +129,7 @@ proc handleMsg*(puncher: Puncher, msg: string, sock: AsyncSocket, # We received a SYN packet. We ignore it because we expected it to be # filtered by our NAT. return - let query = Attempt(srcPorts: puncher.srcPorts, dstIp: peerIp, - dstPorts: @[peerPort]) + let query = Attempt(dstIp: peerIp, dstPorts: @[peerPort]) let i = puncher.attempts.find(query) if i != -1: if msg == "ACK": diff --git a/quicp2p.nim b/quicp2p.nim index e519643..d5abce8 100644 --- a/quicp2p.nim +++ b/quicp2p.nim @@ -18,7 +18,6 @@ import quicly/defaults import quicly/recvstate import quicly/sendstate import quicly/streambuf -import random import server_connection import strformat import strutils @@ -48,8 +47,6 @@ type expectedPeerId: string QuicP2PContext = ref object - socks: seq[AsyncSocket] - puncher: Puncher streamOpen: quicly_stream_open_t nextCid: quicly_cid_plaintext_t signCertCb: ptls_openssl_sign_certificate_t @@ -82,10 +79,6 @@ proc relativeTimeout(ctx: QuicP2PContext): int32 = let delta = nextTimeout - now result = min(delta, int32.high).int32 -proc srcPort(ctx: QuicP2PContext): Port = - let (_, myPort) = ctx.socks[0].getLocalAddr() - myPort - proc peerId(ctx: QuicP2PContext): string = assert(ctx.tlsCtx.certificates.count == 2) let firstCertAddr = cast[ByteAddress](ctx.tlsCtx.certificates.list) @@ -214,8 +207,7 @@ proc verifyCerts(self: ptr ptls_verify_certificate_t, tls: ptr ptls_t, X509_STORE_free(store) X509_free(caCert) -proc initContext(socks: seq[AsyncSocket], certChainPath: string, - keyPath: string, +proc initContext(certChainPath: string, keyPath: string, streamOpenCb: typeof(quicly_stream_open_t.cb)): QuicP2PContext = var tlsCtx = ptls_context_t(randomBytes: ptls_openssl_random_bytes, @@ -223,8 +215,7 @@ proc initContext(socks: seq[AsyncSocket], certChainPath: string, keyExchanges: ptls_openssl_key_exchanges, cipherSuites: ptls_openssl_cipher_suites) quicly_amend_ptls_context(addr tlsCtx) - result = QuicP2PContext(socks: socks, - streamOpen: quicly_stream_open_t(cb: streamOpenCb), + result = QuicP2PContext(streamOpen: quicly_stream_open_t(cb: streamOpenCb), verifyCertsCb: ptls_verify_certificate_t(cb: verifyCerts), tlsCtx: tlsCtx, quiclyCtx: quicly_spec_context) result.quiclyCtx.tls = addr result.tlsCtx @@ -350,19 +341,23 @@ proc handleNotification(puncher: Puncher, notification: NotifyPeer) {.async.} = discard await attempt.finalize() proc runApp(ctx: QuicP2PContext, peerId: string) {.async.} = - let serverConn = await initServerConnection(rendezvousServers[0].hostname, + let primarySock = newAsyncSocket(sockType = SOCK_DGRAM, + protocol = IPPROTO_UDP, buffered = false) + primarySock.bindAddr(Port(0)) + let serverConn = await initServerConnection(primarySock, + rendezvousServers[0].hostname, rendezvousServers[0].port, - ctx.srcPort, rendezvousServers) - let puncher = initPuncher(ctx.socks, serverConn.probedSrcPorts) + rendezvousServers) + let puncher = initPuncher(primarySock, serverConn.probedSrcPorts) asyncCheck handleServerMessages(serverConn) - for sock in ctx.socks: + for sock in puncher.socks: asyncCheck receive(ctx, puncher, sock, peerId) if peerId.len == 0: # We are the responder let probedPorts = serverConn.probedSrcPorts.join(",") - let req = &"{ctx.peerId}|{serverConn.probedIp}|{ctx.srcPort}|{probedPorts}" + let req = &"{ctx.peerId}|{serverConn.probedIp}|{puncher.primarySrcPort}|{probedPorts}" discard await serverConn.sendRequest("register", req) while true: let (hasData, data) = await serverConn.peerNotifications.read() @@ -382,7 +377,7 @@ proc runApp(ctx: QuicP2PContext, peerId: string) {.async.} = let peerInfo = parseMessage[OkGetPeerInfo](serverResponse) let myProbedPorts = serverConn.probedSrcPorts.join(",") let peerProbedPorts = peerInfo.probedPorts.join(",") - let req = &"{ctx.peerId}|{peerId}|{serverConn.probedIp}|{ctx.srcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}" + let req = &"{ctx.peerId}|{peerId}|{serverConn.probedIp}|{puncher.primarySrcPort}|{myProbedPorts}|{peerInfo.ip}|{peerInfo.localPort}|{peerProbedPorts}" let attempt = await puncher.initiate(peerInfo.ip, peerInfo.localPort, peerInfo.probedPorts) discard await serverConn.sendRequest("notify-peer", req) @@ -391,31 +386,15 @@ proc runApp(ctx: QuicP2PContext, peerId: string) {.async.} = proc main() = var ctx: QuicP2PContext - var socks = newSeq[AsyncSocket]() - randomize() - for i in 0 .. 70: - # FIXME: close socks - let sock = newAsyncSocket(sockType = SOCK_DGRAM, protocol = IPPROTO_UDP, - buffered = false) - let srcPort = rand(Port(1024) .. Port.high) - # FIXME: Once we start using UDP for endpoint probing (currently done in - # initServerConnction) we either have to - # - finish the probing before we bind the srcPort here - # - pass the primary socket to initServerConnection - sock.bindAddr(srcPort) - socks.add(sock) - case paramCount(): of 0: - ctx = initContext(socks, serverCertChainPath, serverKeyPath, - onServerStreamOpen) + ctx = initContext(serverCertChainPath, serverKeyPath, onServerStreamOpen) ctx.tlsCtx.require_client_authentication = 1 asyncCheck runApp(ctx, "") of 1: let peerId = paramStr(1) - ctx = initContext(socks, clientCertChainPath, clientKeyPath, - onClientStreamOpen) + ctx = initContext(clientCertChainPath, clientKeyPath, onClientStreamOpen) asyncCheck runApp(ctx, peerId) else: diff --git a/server_connection.nim b/server_connection.nim index 137812f..dafcac1 100644 --- a/server_connection.nim +++ b/server_connection.nim @@ -8,7 +8,6 @@ type outMessages: TableRef[string, Future[string]] peerNotifications*: FutureStream[string] probedIp*: IpAddress - srcPort*: Port probedSrcPorts*: seq[Port] ServerError* = object of ValueError @@ -32,39 +31,41 @@ type dstPort*: Port probedDstPorts*: seq[Port] -proc getEndpoint(srcPort: Port, serverHostname: string, serverPort: Port): +proc getEndpoint(sock: AsyncSocket, serverHostname: string, serverPort: Port): Future[OkGetEndpoint] {.async.} = - let sock = newAsyncSocket() + # TODO: use sock (UDP socket) for probing + let tcpSock = newAsyncSocket() + let (_, srcPort) = sock.getLocalAddr var failCount = 0 while true: try: - sock.bindAddr(srcPort) + tcpSock.bindAddr(srcPort) break except OSError as e: if failCount == 3: raise e failCount.inc await sleepAsync(100) - await sock.connect(serverHostname, serverPort) + await tcpSock.connect(serverHostname, serverPort) let id = rand(uint32) - await sock.send(&"get-endpoint|{id}\n") - let line = await sock.recvLine(maxLength = 400) + await tcpSock.send(&"get-endpoint|{id}\n") + let line = await tcpSock.recvLine(maxLength = 400) let args = line.parseArgs(3) assert(args[0] == "ok") assert(args[1] == $id) result = parseMessage[OkGetEndpoint](args[2]) - let emptyLine = await sock.recvLine(maxLength = 400) + let emptyLine = await tcpSock.recvLine(maxLength = 400) assert(emptyLine.len == 0) - sock.close() + tcpSock.close() -proc initServerConnection*(serverHostname: string, serverPort: Port, - srcPort: Port, probingServers: seq[Endpoint]): +proc initServerConnection*(sock: AsyncSocket, serverHostname: string, + serverPort: Port, probingServers: seq[Endpoint]): Future[ServerConnection] {.async.} = + let peerNotifications = newFutureStream[string]("initServerConnection") result = ServerConnection(outMessages: newTable[string, Future[string]](), - peerNotifications: newFutureStream[string]("initServerConnection"), - srcPort: srcPort) + peerNotifications: peerNotifications) for s in probingServers: - let endpoint = await getEndpoint(srcPort, s.hostname, s.port) + let endpoint = await getEndpoint(sock, s.hostname, s.port) # FIXME: what if we get get different IPs from different servers result.probedIp = endpoint.ip result.probedSrcPorts.add(endpoint.port)