change puncher interface to allow accepting multiple connections
This commit is contained in:
parent
dce5115c5c
commit
84cb8611ef
25
punchd.nim
25
punchd.nim
|
@ -13,7 +13,7 @@ const PunchdSocket = "/tmp/punchd.socket"
|
|||
type
|
||||
Punchd = ref object
|
||||
unixSocket: AsyncSocket
|
||||
punchers: seq[TcpSyniPuncher]
|
||||
tcpSyniPuncher: TcpSyniPuncher
|
||||
|
||||
Sigint = object of CatchableError
|
||||
|
||||
|
@ -52,16 +52,13 @@ proc handleRequest(punchd: Punchd, line: string,
|
|||
$req.dstIp, req.dstPorts.join(","),
|
||||
seqNumbers.join(",")].join("|")
|
||||
await unixSock.send(&"progress|{id}|{content}\n")
|
||||
puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts)
|
||||
punchd.punchers.add(puncher)
|
||||
sock = await puncher.connect(handleSeqNumbers)
|
||||
|
||||
sock = await punchd.tcpSyniPuncher.connect(req.srcPorts[0], req.dstIp,
|
||||
req.dstPorts, handleSeqNumbers)
|
||||
|
||||
of "tcp-syni-accept":
|
||||
let req = parseMessage[TcpSyniAccept](args[2])
|
||||
puncher = initPuncher(req.srcPorts[0], req.dstIp, req.dstPorts,
|
||||
req.seqNums)
|
||||
punchd.punchers.add(puncher)
|
||||
sock = await puncher.accept()
|
||||
sock = await punchd.tcpSyniPuncher.accept(req.srcPorts[0], req.dstIp,
|
||||
req.dstPorts, req.seqNums)
|
||||
|
||||
else:
|
||||
raise newException(ValueError, "invalid request")
|
||||
|
@ -91,19 +88,19 @@ proc handleUsers(punchd: Punchd) {.async.} =
|
|||
proc main() =
|
||||
setControlCHook(handleSigint)
|
||||
removeFile(PunchdSocket)
|
||||
let punchd = Punchd(unixSocket: newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP))
|
||||
punchd.unixSocket.bindUnix(PunchdSocket)
|
||||
let unixSocket = newAsyncSocket(AF_UNIX, SOCK_STREAM, IPPROTO_IP)
|
||||
unixSocket.bindUnix(PunchdSocket)
|
||||
unixSocket.listen()
|
||||
setFilePermissions(PunchdSocket,
|
||||
{fpUserRead, fpUserWrite, fpGroupRead, fpGroupWrite,
|
||||
fpOthersRead, fpOthersWrite})
|
||||
punchd.unixSocket.listen()
|
||||
let punchd = Punchd(unixSocket: unixSocket, tcpSyniPuncher: initPuncher())
|
||||
asyncCheck handleUsers(punchd)
|
||||
try:
|
||||
runForever()
|
||||
|
||||
except Sigint:
|
||||
for puncher in punchd.punchers:
|
||||
waitFor puncher.cleanup()
|
||||
waitFor punchd.tcpSyniPuncher.cleanup()
|
||||
punchd.unixSocket.close()
|
||||
removeFile(PunchdSocket)
|
||||
quit(0)
|
||||
|
|
254
tcp_syni.nim
254
tcp_syni.nim
|
@ -1,6 +1,7 @@
|
|||
import asyncfutures, asyncdispatch, asyncnet, strformat
|
||||
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr
|
||||
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr, parseIpAddress
|
||||
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
|
||||
from sequtils import any
|
||||
import asyncutils
|
||||
import ip_packet
|
||||
import network_interface
|
||||
|
@ -12,13 +13,25 @@ var IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint
|
|||
const Timeout = 3000
|
||||
|
||||
type
|
||||
TcpSyniPuncher* = ref object
|
||||
ConnectAttempt* = ref object
|
||||
srcIp: IpAddress
|
||||
srcPort: Port
|
||||
dstIp: IpAddress
|
||||
dstPorts: seq[Port]
|
||||
firewallRules: seq[string]
|
||||
|
||||
AcceptAttempt* = ref object
|
||||
srcIp: IpAddress
|
||||
srcPort: Port
|
||||
dstIp: IpAddress
|
||||
dstPorts: seq[Port]
|
||||
seqNums: seq[uint32]
|
||||
firewallRules: seq[string]
|
||||
future: Future[AsyncSocket]
|
||||
|
||||
TcpSyniPuncher* = ref object
|
||||
connectAttempts: seq[ConnectAttempt]
|
||||
acceptAttempts: seq[AcceptAttempt]
|
||||
|
||||
PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}
|
||||
|
||||
|
@ -47,6 +60,26 @@ proc iptablesDelete(chain: string, rule: string) {.async.} =
|
|||
let firewall_cmd = fmt"iptables -D {chain} {rule}"
|
||||
discard await asyncExecCmd(firewall_cmd)
|
||||
|
||||
proc addFirewallRules[T](attempt: T) {.async.} =
|
||||
for dstPort in attempt.dstPorts:
|
||||
let rule = makeFirewallRule(attempt.srcIp, attempt.srcPort,
|
||||
attempt.dstIp, dstPort)
|
||||
try:
|
||||
await iptablesInsert("INPUT", rule)
|
||||
attempt.firewallRules.add(rule)
|
||||
except OSError as e:
|
||||
echo "cannot add firewall rule: ", e.msg
|
||||
raise newException(PunchHoleError, e.msg)
|
||||
|
||||
proc deleteFirewallRules[T](attempt: T) {.async.} =
|
||||
for rule in attempt.firewallRules:
|
||||
# FIXME: close sock?
|
||||
try:
|
||||
await iptablesDelete("INPUT", rule)
|
||||
except OSError:
|
||||
# At least we tried
|
||||
discard
|
||||
|
||||
proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
|
||||
assert(ipPacket.protocol == tcp)
|
||||
try:
|
||||
|
@ -59,27 +92,27 @@ proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
|
|||
except OSError as e:
|
||||
raise newException(PunchHoleError, e.msg)
|
||||
|
||||
proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD,
|
||||
proc captureSeqNumbers(attempt: ConnectAttempt, rawFd: AsyncFD,
|
||||
cb: PunchProgressCb) {.async.} =
|
||||
# FIXME: timeout?
|
||||
var seqNums = newSeq[uint32]()
|
||||
while seqNums.len < puncher.dstPorts.len:
|
||||
while seqNums.len < attempt.dstPorts.len:
|
||||
let packet = await rawFd.recv(4000)
|
||||
if packet == "":
|
||||
break
|
||||
let parsed = parseEthernetPacket(packet)
|
||||
if parsed.protocol == tcp and
|
||||
parsed.ipAddrSrc == puncher.srcIp and
|
||||
parsed.tcpPortSrc.int == puncher.srcPort.int and
|
||||
parsed.ipAddrDst == puncher.dstIp and
|
||||
parsed.ipAddrSrc == attempt.srcIp and
|
||||
parsed.tcpPortSrc.int == attempt.srcPort.int and
|
||||
parsed.ipAddrDst == attempt.dstIp and
|
||||
parsed.tcpFlags == {SYN}:
|
||||
for port in puncher.dstPorts:
|
||||
for port in attempt.dstPorts:
|
||||
if parsed.tcpPortDst.int == port.int:
|
||||
seqNums.add(parsed.tcpSeqNumber)
|
||||
break
|
||||
await cb(seqNums)
|
||||
|
||||
proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD,
|
||||
proc captureAndResendAck(attempt: ConnectAttempt, captureFd: AsyncFD,
|
||||
injectFd: AsyncFD) {.async.} =
|
||||
while true:
|
||||
let packet = await captureFd.recv(4000)
|
||||
|
@ -87,47 +120,56 @@ proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD,
|
|||
break
|
||||
var parsed = parseEthernetPacket(packet)
|
||||
if parsed.protocol == tcp and
|
||||
parsed.ipAddrSrc == puncher.srcIp and
|
||||
parsed.tcpPortSrc.int == puncher.srcPort.int and
|
||||
parsed.ipAddrDst == puncher.dstIp and
|
||||
parsed.ipAddrSrc == attempt.srcIp and
|
||||
parsed.tcpPortSrc.int == attempt.srcPort.int and
|
||||
parsed.ipAddrDst == attempt.dstIp and
|
||||
parsed.tcpFlags == {ACK}:
|
||||
for port in puncher.dstPorts:
|
||||
for port in attempt.dstPorts:
|
||||
if parsed.tcpPortDst.int == port.int:
|
||||
parsed.ipTTL = 64
|
||||
echo &"[{parsed.ipAddrSrc}:{parsed.tcpPortSrc.int} -> {parsed.ipAddrDst}:{parsed.tcpPortDst}, SEQ {parsed.tcpSeqNumber}] resending ACK with TTL {parsed.ipTTL}"
|
||||
await injectFd.injectTcpPacket(parsed)
|
||||
return
|
||||
|
||||
proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: seq[Port],
|
||||
seqNums: seq[uint32] = @[]): TcpSyniPuncher =
|
||||
let localIp = getPrimaryIPAddr(dstIp)
|
||||
# TODO: do real port prediction
|
||||
var predictedDstPorts = newSeq[Port](1)
|
||||
let basePort = min(dstPorts[1].uint16,
|
||||
uint16.high - (predictedDstPorts.len - 1).uint16)
|
||||
for i in 0 .. predictedDstPorts.len - 1:
|
||||
predictedDstPorts[i] = Port(basePort + i.uint16)
|
||||
result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
|
||||
dstPorts: predictedDstPorts, seqNums: seqNums)
|
||||
proc initPuncher*(): TcpSyniPuncher = TcpSyniPuncher()
|
||||
|
||||
proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} =
|
||||
for dstPort in puncher.dstPorts:
|
||||
let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort,
|
||||
puncher.dstIp, dstPort)
|
||||
try:
|
||||
await iptablesInsert("INPUT", rule)
|
||||
puncher.firewallRules.add(rule)
|
||||
except OSError as e:
|
||||
echo "cannot add firewall rule: ", e.msg
|
||||
raise newException(PunchHoleError, e.msg)
|
||||
proc findConnectAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
|
||||
srcPort: Port, dstIp: IpAddress,
|
||||
dstPorts: seq[Port]): int =
|
||||
for (index, attempt) in puncher.connectAttempts.pairs():
|
||||
if attempt.srcIp == srcIp and attempt.srcPort == srcPort and
|
||||
attempt.dstIp == dstIp and
|
||||
attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts):
|
||||
return index
|
||||
|
||||
proc findAcceptAttempt(puncher: TcpSyniPuncher, srcIp: IpAddress,
|
||||
srcPort: Port, dstIp: IpAddress,
|
||||
dstPorts: seq[Port]): int =
|
||||
for (index, attempt) in puncher.acceptAttempts.pairs():
|
||||
if attempt.srcIp == srcIp and attempt.srcPort == srcPort and
|
||||
attempt.dstIp == dstIp and
|
||||
attempt.dstPorts.any(proc (p: Port): bool = p in dstPorts):
|
||||
return index
|
||||
|
||||
proc findAcceptAttemptsByLocalAddr(puncher: TcpSyniPuncher, address: IpAddress,
|
||||
port: Port): seq[AcceptAttempt] =
|
||||
for attempt in puncher.acceptAttempts:
|
||||
if attempt.srcIp == address and attempt.srcPort == port:
|
||||
result.add(attempt)
|
||||
|
||||
proc predictPortRange(dstPorts: seq[Port]): seq[Port] =
|
||||
# TODO: do real port prediction
|
||||
result = newSeq[Port](1)
|
||||
let basePort = min(dstPorts[1].uint16,
|
||||
uint16.high - (result.len - 1).uint16)
|
||||
for i in 0 .. result.len - 1:
|
||||
result[i] = Port(basePort + i.uint16)
|
||||
|
||||
proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
|
||||
for rule in puncher.firewallRules:
|
||||
try:
|
||||
await iptablesDelete("INPUT", rule)
|
||||
except OSError:
|
||||
# At least we tried
|
||||
discard
|
||||
for attempt in puncher.connectAttempts:
|
||||
await attempt.deleteFirewallRules()
|
||||
for attempt in puncher.acceptAttempts:
|
||||
await attempt.deleteFirewallRules()
|
||||
|
||||
proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
|
||||
future: Future[AsyncSocket]) {.async.} =
|
||||
|
@ -144,39 +186,45 @@ proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
|
|||
echo &"connection {srcIP}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg
|
||||
discard
|
||||
|
||||
proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] =
|
||||
result = newFuture[AsyncSocket]("doConnect")
|
||||
for dstPort in puncher.dstPorts:
|
||||
asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp, dstPort, result)
|
||||
|
||||
proc connect*(puncher: TcpSyniPuncher,
|
||||
proc connect*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
|
||||
dstPorts: seq[Port],
|
||||
progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
|
||||
let iface = fromIpAddress(puncher.srcIp)
|
||||
let localIp = getPrimaryIPAddr(dstIp)
|
||||
if puncher.findConnectAttempt(localIp, srcPort, dstIp, dstPorts) != -1:
|
||||
raise newException(PunchHoleError, "hole punching for given parameters already active")
|
||||
let attempt = ConnectAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
|
||||
dstPorts: predictPortRange(dstPorts))
|
||||
puncher.connectAttempts.add(attempt)
|
||||
await attempt.addFirewallRules()
|
||||
let iface = fromIpAddress(attempt.srcIp)
|
||||
let captureSeqFd = setupEthernetCapturingSocket(iface)
|
||||
let captureAckFd = setupEthernetCapturingSocket(iface)
|
||||
let injectAckFd = setupTcpInjectingSocket()
|
||||
asyncCheck puncher.captureSeqNumbers(captureSeqFd, progressCb)
|
||||
asyncCheck puncher.captureAndResendAck(captureAckFd, injectAckFd)
|
||||
await puncher.addFirewallRules()
|
||||
asyncCheck attempt.captureSeqNumbers(captureSeqFd, progressCb)
|
||||
asyncCheck attempt.captureAndResendAck(captureAckFd, injectAckFd)
|
||||
try:
|
||||
let connectParallelFuture = puncher.connectParallel()
|
||||
await connectParallelFuture or sleepAsync(Timeout)
|
||||
await puncher.cleanup()
|
||||
if connectParallelFuture.finished():
|
||||
result = connectParallelFuture.read()
|
||||
let connectFuture = newFuture[AsyncSocket]("connect")
|
||||
for dstPort in attempt.dstPorts:
|
||||
asyncCheck doConnect(attempt.srcIp, attempt.srcPort, attempt.dstIp,
|
||||
dstPort, connectfuture)
|
||||
await connectFuture or sleepAsync(Timeout)
|
||||
await attempt.deleteFirewallRules()
|
||||
puncher.connectAttempts.del(puncher.connectAttempts.find(attempt))
|
||||
if connectFuture.finished():
|
||||
result = connectFuture.read()
|
||||
else:
|
||||
raise newException(PunchHoleError, "timeout")
|
||||
except OSError as e:
|
||||
raise newException(PunchHoleError, e.msg)
|
||||
|
||||
proc prepareAccept(puncher: TcpSyniPuncher) {.async.} =
|
||||
for dstPort in puncher.dstPorts:
|
||||
proc prepareAccept(attempt: AcceptAttempt) {.async.} =
|
||||
for dstPort in attempt.dstPorts:
|
||||
try:
|
||||
let sock = newAsyncSocket()
|
||||
sock.setSockOpt(OptReuseAddr, true)
|
||||
sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
|
||||
sock.bindAddr(puncher.srcPort, $(puncher.srcIp))
|
||||
let connectFuture = sock.connect($(puncher.dstIp), dstPort)
|
||||
sock.bindAddr(attempt.srcPort, $(attempt.srcIp))
|
||||
let connectFuture = sock.connect($(attempt.dstIp), dstPort)
|
||||
await connectFuture or sleepAsync(Timeout)
|
||||
if connectFuture.finished():
|
||||
echo "connected during accept phase"
|
||||
|
@ -184,38 +232,68 @@ proc prepareAccept(puncher: TcpSyniPuncher) {.async.} =
|
|||
except OSError:
|
||||
discard
|
||||
|
||||
proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} =
|
||||
await puncher.prepareAccept()
|
||||
await puncher.addFirewallRules()
|
||||
try:
|
||||
let rawFd = setupTcpInjectingSocket()
|
||||
for dstPort in puncher.dstPorts:
|
||||
for seqNum in puncher.seqNums:
|
||||
let ipPacket = IpPacket(protocol: tcp,
|
||||
ipAddrSrc: puncher.dstIp,
|
||||
ipAddrDst: puncher.srcIp,
|
||||
ipTTL: 64,
|
||||
tcpPortSrc: dstPort,
|
||||
tcpPortDst: puncher.srcPort,
|
||||
tcpSeqNumber: seqNum,
|
||||
tcpAckNumber: 0,
|
||||
tcpFlags: {SYN},
|
||||
tcpWindowSize: 1452 * 10)
|
||||
echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN"
|
||||
asyncCheck rawFd.injectTcpPacket(ipPacket)
|
||||
let sock = newAsyncSocket()
|
||||
sock.setSockOpt(OptReuseAddr, true)
|
||||
sock.bindAddr(puncher.srcPort, $(puncher.srcIp))
|
||||
sock.listen()
|
||||
echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int}"
|
||||
proc doAccept(puncher: TcpSyniPuncher, srcIp: IpAddress,
|
||||
srcPort: Port) {.async.} =
|
||||
let sock = newAsyncSocket()
|
||||
sock.setSockOpt(OptReuseAddr, true)
|
||||
sock.bindAddr(srcPort, $(srcIp))
|
||||
sock.listen()
|
||||
while true:
|
||||
let acceptFuture = sock.accept()
|
||||
await acceptFuture or sleepAsync(Timeout)
|
||||
await puncher.cleanup()
|
||||
if acceptFuture.finished():
|
||||
result = acceptFuture.read()
|
||||
let peer = acceptFuture.read()
|
||||
let (peerAddr, peerPort) = peer.getPeerAddr()
|
||||
let peerIp = parseIpAddress(peerAddr)
|
||||
let i = puncher.findAcceptAttempt(srcIp, srcPort, peerIp, @[peerPort])
|
||||
if i == -1:
|
||||
echo "Accepted connection, but no attempt found. Discarding."
|
||||
else:
|
||||
let attempt = puncher.acceptAttempts[i]
|
||||
attempt.future.complete(peer)
|
||||
else:
|
||||
let attempts = puncher.findAcceptAttemptsByLocalAddr(srcIp, srcPort)
|
||||
# FIXME: should attempts have timestamps, so we can decide here which ones to delete?
|
||||
if attempts.len() <= 1:
|
||||
break
|
||||
|
||||
proc accept*(puncher: TcpSyniPuncher, srcPort: Port, dstIp: IpAddress,
|
||||
dstPorts: seq[Port],
|
||||
seqNums: seq[uint32]): Future[AsyncSocket] {.async.} =
|
||||
let localIp = getPrimaryIPAddr(dstIp)
|
||||
let existingAttempts = puncher.findAcceptAttemptsByLocalAddr(localIp, srcPort)
|
||||
if existingAttempts.len() == 0:
|
||||
echo &"accepting connections from {dstIp}:{dstPorts[0].int}"
|
||||
asyncCheck puncher.doAccept(localIp, srcPort)
|
||||
else:
|
||||
for a in existingAttempts:
|
||||
if a.dstIp == dstIp and
|
||||
a.dstPorts.any(proc (p: Port): bool = p in dstPorts):
|
||||
raise newException(PunchHoleError, "hole punching for given parameters already active")
|
||||
let attempt = AcceptAttempt(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
|
||||
dstPorts: dstPorts, seqNums: seqNums,
|
||||
future: newFuture[AsyncSocket]("accept"))
|
||||
puncher.acceptAttempts.add(attempt)
|
||||
await attempt.addFirewallRules()
|
||||
await attempt.prepareAccept()
|
||||
try:
|
||||
let rawFd = setupTcpInjectingSocket()
|
||||
for dstPort in attempt.dstPorts:
|
||||
for seqNum in attempt.seqNums:
|
||||
let ipPacket = IpPacket(protocol: tcp, ipAddrSrc: attempt.dstIp,
|
||||
ipAddrDst: attempt.srcIp, ipTTL: 64,
|
||||
tcpPortSrc: dstPort,
|
||||
tcpPortDst: attempt.srcPort,
|
||||
tcpSeqNumber: seqNum, tcpAckNumber: 0,
|
||||
tcpFlags: {SYN}, tcpWindowSize: 1452 * 10)
|
||||
echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN"
|
||||
asyncCheck rawFd.injectTcpPacket(ipPacket)
|
||||
await attempt.future or sleepAsync(Timeout)
|
||||
await attempt.deleteFirewallRules()
|
||||
puncher.acceptAttempts.del(puncher.acceptAttempts.find(attempt))
|
||||
if attempt.future.finished():
|
||||
result = attempt.future.read()
|
||||
else:
|
||||
raise newException(PunchHoleError, "timeout")
|
||||
except OSError as e:
|
||||
echo &"accepting connections from {puncher.dstIP}:{puncher.dstPorts[0].int} failed: ", e.msg
|
||||
await puncher.cleanup()
|
||||
raise newException(PunchHoleError, e.msg)
|
||||
echo &"accepting connections from {dstIP}:{dstPorts[0].int} failed: ", e.msg
|
||||
|
|
Loading…
Reference in New Issue