store transport protocol in an Attempt; consider protocol when comparing attempts; puncher.getProtocol not needed anymore

This commit is contained in:
Christian Ulrich 2020-10-25 10:46:29 +01:00
parent 5c6050faaf
commit 5bdb69f214
No known key found for this signature in database
GPG Key ID: 8241BE099775A097
4 changed files with 28 additions and 28 deletions

View File

@ -31,10 +31,10 @@ proc sendToClient(unixSock: AsyncSocket, msg: string,
let unixFd = unixSock.getFd.AsyncFD let unixFd = unixSock.getFd.AsyncFD
await unixFd.asyncSendMsg(msg, cmsgs) await unixFd.asyncSendMsg(msg, cmsgs)
proc findAttemptsByLocalAddr(punchd: Punchd, srcIp: IpAddress, proc findAttemptsByLocalAddr(punchd: Punchd, protocol: Protocol,
srcPort: Port): seq[Attempt] = srcIp: IpAddress, srcPort: Port): seq[Attempt] =
proc matchesLocalAddr(a: Attempt): bool = proc matchesLocalAddr(a: Attempt): bool =
a.srcIp == srcIp and a.srcPort == srcPort a.protocol == protocol and a.srcIp == srcIp and a.srcPort == srcPort
punchd.attempts.filter(matchesLocalAddr) punchd.attempts.filter(matchesLocalAddr)
proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port, proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port,
@ -68,19 +68,20 @@ proc acceptConnections(punchd: Punchd, ip: IpAddress, port: Port,
else: else:
let acceptFuture = punchd.attempts[i].acceptFuture.get() let acceptFuture = punchd.attempts[i].acceptFuture.get()
acceptFuture.complete(peer) acceptFuture.complete(peer)
let localAddrMatches = punchd.findAttemptsByLocalAddr(ip, port) let localAddrMatches = punchd.findAttemptsByLocalAddr(protocol, ip, port)
if localAddrMatches.len() <= 1: if localAddrMatches.len() <= 1:
break break
sock.close() sock.close()
proc addAttempt(punchd: Punchd, attempt: Attempt, puncher: Puncher) = proc addAttempt(punchd: Punchd, attempt: Attempt) =
let localAddrMatches = punchd.findAttemptsByLocalAddr(attempt.srcIp, let localAddrMatches = punchd.findAttemptsByLocalAddr(attempt.protocol,
attempt.srcIp,
attempt.srcPort) attempt.srcPort)
punchd.attempts.add(attempt) punchd.attempts.add(attempt)
if localAddrMatches.len() == 0: if localAddrMatches.len() == 0:
if attempt.acceptFuture.isSome(): if attempt.acceptFuture.isSome():
asyncCheck punchd.acceptConnections(attempt.srcIp, attempt.srcPort, asyncCheck punchd.acceptConnections(attempt.srcIp, attempt.srcPort,
puncher.getProtocol()) attempt.protocol)
elif localAddrMatches.contains(attempt): elif localAddrMatches.contains(attempt):
raise newException(PunchHoleError, raise newException(PunchHoleError,
"hole punching for given parameters already active") "hole punching for given parameters already active")
@ -101,7 +102,7 @@ proc handleRequest(punchd: Punchd, line: string,
case args[0]: case args[0]:
of "initiate": of "initiate":
attempt = puncher.parseInitiateRequest(args[3]) attempt = puncher.parseInitiateRequest(args[3])
punchd.addAttempt(attempt, puncher) punchd.addAttempt(attempt)
proc progress(extraArgs: string) {.async.} = proc progress(extraArgs: string) {.async.} =
let msg = &"progress|{id}|{args[2]}|{args[3]}|{extraArgs}\n" let msg = &"progress|{id}|{args[2]}|{args[3]}|{extraArgs}\n"
await sendToClient(unixSock, msg) await sendToClient(unixSock, msg)
@ -110,7 +111,7 @@ proc handleRequest(punchd: Punchd, line: string,
of "respond": of "respond":
attempt = puncher.parseRespondRequest(args[3]) attempt = puncher.parseRespondRequest(args[3])
punchd.addAttempt(attempt, puncher) punchd.addAttempt(attempt)
sock = await puncher.respond(attempt) sock = await puncher.respond(attempt)
punchd.removeAttempt(attempt) punchd.removeAttempt(attempt)

View File

@ -16,6 +16,7 @@ type
## protocol can be obtained by calling ``getProcotol``. The puncher expects ## protocol can be obtained by calling ``getProcotol``. The puncher expects
## the caller to complete the future when a connections from ## the caller to complete the future when a connections from
## ``dstIp``:``dstPort`` has been accepted. ## ``dstIp``:``dstPort`` has been accepted.
protocol*: Protocol
srcIp*: IpAddress srcIp*: IpAddress
srcPort*: Port srcPort*: Port
dstIp*: IpAddress dstIp*: IpAddress
@ -40,10 +41,11 @@ const Timeout* = 3000
proc `==`*(a, b: Attempt): bool = proc `==`*(a, b: Attempt): bool =
## ``==`` for hole punching attempts. ## ``==`` for hole punching attempts.
## ##
## Two hole punching attempts are considered equal if their ``srcIp``, ## Two hole punching attempts are considered equal if their ``protocol`` is
## ``srcPort`` and ``dstIp`` are equal and their ``dstPorts`` overlap. ## the same, ``srcIp``, ``srcPort`` and ``dstIp`` are equal and their
a.srcIp == b.srcIp and a.srcPort == b.srcPort and a.dstIp == b.dstIp and ## ``dstPorts`` overlap.
a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts) a.protocol == b.protocol and a.srcIp == b.srcIp and a.srcPort == b.srcPort and
a.dstIp == b.dstIp and a.dstPorts.any(proc (p: Port): bool = p in b.dstPorts)
method cleanup*(attempt: Attempt): Future[void] {.base, async.} = method cleanup*(attempt: Attempt): Future[void] {.base, async.} =
## Cleans up when an attempt finished (either successful or not). ## Cleans up when an attempt finished (either successful or not).
@ -51,10 +53,6 @@ method cleanup*(attempt: Attempt): Future[void] {.base, async.} =
## Does nothing. Override for custom attempt types. ## Does nothing. Override for custom attempt types.
discard discard
method getProtocol*(puncher: Puncher): Protocol {.base.} =
## Returns the transport protocol the puncher employs.
raise newException(CatchableError, "Method without implementation override")
method parseInitiateRequest*(puncher: Puncher, args: string): Attempt {.base.} = method parseInitiateRequest*(puncher: Puncher, args: string): Attempt {.base.} =
## Creates a new hole punching attempt by parsing arguments of an ``initiate`` ## Creates a new hole punching attempt by parsing arguments of an ``initiate``
## request. ## request.

View File

@ -56,23 +56,22 @@ proc initTcpNutssPuncher*(): TcpNutssPuncher =
randomize() randomize()
TcpNutssPuncher() TcpNutssPuncher()
method getProtocol*(puncher: TcpNutssPuncher): Protocol =
IPPROTO_TCP
method parseInitiateRequest*(puncher: TcpNutssPuncher, args: string): Attempt = method parseInitiateRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
let parsed = parseMessage[InitiateRequest](args) let parsed = parseMessage[InitiateRequest](args)
let localIp = getPrimaryIPAddr(parsed.dstIp) let localIp = getPrimaryIPAddr(parsed.dstIp)
let predictedDstPorts = predictPortRange(parsed.dstPorts) let predictedDstPorts = predictPortRange(parsed.dstPorts)
let acceptFuture = newFuture[AsyncSocket]("parseInitiateRequest") let acceptFuture = newFuture[AsyncSocket]("parseInitiateRequest")
Attempt(srcIp: localIp, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp, Attempt(protocol: IPPROTO_TCP, srcIp: localIp, srcPort: parsed.srcPorts[0],
dstPorts: predictedDstPorts, acceptFuture: some(acceptFuture)) dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
acceptFuture: some(acceptFuture))
method parseRespondRequest*(puncher: TcpNutssPuncher, args: string): Attempt = method parseRespondRequest*(puncher: TcpNutssPuncher, args: string): Attempt =
let parsed = parseMessage[RespondRequest](args) let parsed = parseMessage[RespondRequest](args)
let localIp = getPrimaryIPAddr(parsed.dstIp) let localIp = getPrimaryIPAddr(parsed.dstIp)
let predictedDstPorts = predictPortRange(parsed.dstPorts) let predictedDstPorts = predictPortRange(parsed.dstPorts)
Attempt(srcIp: localIp, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp, Attempt(protocol: IPPROTO_TCP, srcIp: localIp, srcPort: parsed.srcPorts[0],
dstPorts: predictedDstPorts, acceptFuture: none(Future[AsyncSocket])) dstIp: parsed.dstIp, dstPorts: predictedDstPorts,
acceptFuture: none(Future[AsyncSocket]))
method initiate*(puncher: TcpNutssPuncher, attempt: Attempt, method initiate*(puncher: TcpNutssPuncher, attempt: Attempt,
progress: PunchProgressCb): Future[AsyncSocket] {.async.} = progress: PunchProgressCb): Future[AsyncSocket] {.async.} =

View File

@ -137,8 +137,9 @@ method parseInitiateRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
let parsed = parseMessage[InitiateRequest](args) let parsed = parseMessage[InitiateRequest](args)
let localIp = getPrimaryIPAddr(parsed.dstIp) let localIp = getPrimaryIPAddr(parsed.dstIp)
let predictedDstPorts = predictPortRange(parsed.dstPorts) let predictedDstPorts = predictPortRange(parsed.dstPorts)
TcpSyniInitiateAttempt(srcIp: localIp, srcPort: parsed.srcPorts[0], TcpSyniInitiateAttempt(protocol: IPPROTO_TCP, srcIp: localIp,
dstIp: parsed.dstIp, dstPorts: predictedDstPorts, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
dstPorts: predictedDstPorts,
acceptFuture: none(Future[AsyncSocket])) acceptFuture: none(Future[AsyncSocket]))
method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt = method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
@ -146,8 +147,9 @@ method parseRespondRequest*(puncher: TcpSyniPuncher, args: string): Attempt =
let localIp = getPrimaryIPAddr(parsed.dstIp) let localIp = getPrimaryIPAddr(parsed.dstIp)
let predictedDstPorts = predictPortRange(parsed.dstPorts) let predictedDstPorts = predictPortRange(parsed.dstPorts)
let acceptFuture = newFuture[AsyncSocket]("parseRespondRequest") let acceptFuture = newFuture[AsyncSocket]("parseRespondRequest")
TcpSyniRespondAttempt(srcIp: localIp, srcPort: parsed.srcPorts[0], TcpSyniRespondAttempt(protocol: IPPROTO_TCP, srcIp: localIp,
dstIp: parsed.dstIp, dstPorts: predictedDstPorts, srcPort: parsed.srcPorts[0], dstIp: parsed.dstIp,
dstPorts: predictedDstPorts,
acceptFuture: some(acceptFuture), acceptFuture: some(acceptFuture),
seqNums: parsed.seqNums) seqNums: parsed.seqNums)