221 lines
8.3 KiB
Nim
221 lines
8.3 KiB
Nim
import asyncfutures, asyncdispatch, asyncnet, strformat
|
|
from net import IpAddress, Port, `$`, `==`, getPrimaryIPAddr, toSockAddr
|
|
from nativesockets import SockAddr, Sockaddr_storage, SockLen, setSockOptInt
|
|
import asyncutils
|
|
import ip_packet
|
|
import network_interface
|
|
import raw_socket
|
|
|
|
# Socket-option constants taken directly from the C header. They are only
# read (passed to `setSockOptInt`) and never assigned.
var
  IPPROTO_IP {.importc: "IPPROTO_IP", header: "<netinet/in.h>".}: cint
  IP_TTL {.importc: "IP_TTL", header: "<netinet/in.h>".}: cint

const
  Timeout = 3_000 ## Milliseconds passed to `sleepAsync` when waiting for a connect/accept.
|
|
|
|
type
  PunchHoleError* = object of ValueError ## Raised when hole punching fails (OS error, firewall setup, or timeout).

  # Awaited with the SYN sequence numbers captured during `connect`.
  PunchProgressCb* = proc (seqNums: seq[uint32]) {.async.}

  TcpSyniPuncher* = ref object
    ## State of one TCP hole-punching attempt between a local endpoint and
    ## a peer address with a set of candidate peer ports.
    srcIp: IpAddress           ## local address used for all sockets
    srcPort: Port              ## local port used for all sockets
    dstIp: IpAddress           ## peer address
    dstPorts: seq[Port]        ## candidate (predicted) peer ports
    seqNums: seq[uint32]       ## sequence numbers used for the SYNs injected by `accept`
    firewallRules: seq[string] ## iptables rules currently installed, removed by `cleanup`
|
|
|
|
proc makeFirewallRule(srcIp: IpAddress, srcPort: Port,
                      dstIp: IpAddress, dstPort: Port): string =
  ## Builds the iptables rule arguments, without the `-I`/`-D <chain>`
  ## prefix. The rule drops ICMP "time-exceeded" messages addressed to
  ## `srcIp` that conntrack relates to the TCP flow
  ## `srcIp:srcPort -> dstIp:dstPort`.
  ##
  ## The rule is returned on one line with single-space separators. It
  ## therefore does not rely on shell backslash-newline continuations and is
  ## valid whether or not the command runner goes through a shell.
  result = fmt"-w -d {srcIp} -p icmp --icmp-type time-exceeded"
  result.add(" -m conntrack --ctstate RELATED --ctproto tcp")
  result.add(fmt" --ctorigsrc {srcIp} --ctorigsrcport {srcPort.int}")
  result.add(fmt" --ctorigdst {dstIp} --ctorigdstport {dstPort.int}")
  result.add(" -j DROP")
|
|
|
|
proc iptablesInsert(chain: string, rule: string) {.async.} =
  ## Inserts `rule` into iptables chain `chain` by running
  ## `iptables -I <chain> <rule>`. The command's result is discarded.
  discard await asyncExecCmd(fmt"iptables -I {chain} {rule}")
|
|
|
|
proc iptablesDelete(chain: string, rule: string) {.async.} =
  ## Removes `rule` from iptables chain `chain` by running
  ## `iptables -D <chain> <rule>`. The command's result is discarded.
  let command = fmt"iptables -D {chain} {rule}"
  discard await command.asyncExecCmd()
|
|
|
|
proc injectTcpPacket(rawFd: AsyncFD, ipPacket: IpPacket) {.async.} =
  ## Serializes the TCP packet `ipPacket` and sends it through the raw
  ## socket `rawFd`. It is addressed to the packet's own destination IP and
  ## TCP port.
  ##
  ## Raises `PunchHoleError`, carrying the OS message, if sending raises
  ## `OSError`.
  assert ipPacket.protocol == tcp
  try:
    let wire = ipPacket.serialize()
    var dest: Sockaddr_storage
    var destLen: SockLen
    toSockAddr(ipPacket.ipAddrDst, ipPacket.tcpPortDst, dest, destLen)
    await sendTo(rawFd, wire.cstring, wire.len,
                 cast[ptr SockAddr](dest.addr), destLen)
  except OSError as err:
    raise newException(PunchHoleError, err.msg)
|
|
|
|
proc captureSeqNumbers(puncher: TcpSyniPuncher, rawFd: AsyncFD,
                       cb: PunchProgressCb) {.async.} =
  ## Reads frames from the capturing socket `rawFd` and records TCP sequence
  ## numbers. A frame counts if it is a pure SYN from
  ## `puncher.srcIp:puncher.srcPort` to `puncher.dstIp` on one of
  ## `puncher.dstPorts`.
  ##
  ## At most one sequence number is recorded per destination port. A
  ## repeated SYN for a port already seen (e.g. a retransmission) therefore
  ## does not count as progress for a port that is still missing.
  ##
  ## `cb` is awaited with the collected numbers, in capture order, once every
  ## candidate port has been seen or a read returns no data.
  # FIXME: timeout?
  var seqNums = newSeq[uint32]()
  var capturedPorts = newSeq[int]()
  while seqNums.len < puncher.dstPorts.len:
    let packet = await rawFd.recv(4000)
    if packet.len == 0:
      break
    let parsed = parseEthernetPacket(packet)
    if parsed.protocol == tcp and
        parsed.ipAddrSrc == puncher.srcIp and
        parsed.tcpPortSrc.int == puncher.srcPort.int and
        parsed.ipAddrDst == puncher.dstIp and
        parsed.tcpFlags == {SYN}:
      let dstPort = parsed.tcpPortDst.int
      if dstPort in capturedPorts:
        continue # a SYN for this port was already recorded
      for port in puncher.dstPorts:
        if dstPort == port.int:
          capturedPorts.add(dstPort)
          seqNums.add(parsed.tcpSeqNumber)
          break
  await cb(seqNums)
|
|
|
|
proc captureAndResendAck(puncher: TcpSyniPuncher, captureFd: AsyncFD,
                         injectFd: AsyncFD) {.async.} =
  ## Watches `captureFd` for the first pure ACK sent from
  ## `puncher.srcIp:puncher.srcPort` to `puncher.dstIp` on one of
  ## `puncher.dstPorts`. That ACK is re-injected through `injectFd` with its
  ## IP TTL set to 64, and the proc returns.
  ##
  ## Also returns, without injecting, if a read yields no data.
  while true:
    let frame = await captureFd.recv(4000)
    if frame == "":
      return
    var pkt = parseEthernetPacket(frame)
    let isOurAck = pkt.protocol == tcp and
      pkt.ipAddrSrc == puncher.srcIp and
      pkt.tcpPortSrc.int == puncher.srcPort.int and
      pkt.ipAddrDst == puncher.dstIp and
      pkt.tcpFlags == {ACK}
    if not isOurAck:
      continue
    var portMatches = false
    for candidate in puncher.dstPorts:
      if pkt.tcpPortDst.int == candidate.int:
        portMatches = true
        break
    if portMatches:
      pkt.ipTTL = 64
      echo &"[{pkt.ipAddrSrc}:{pkt.tcpPortSrc.int} -> {pkt.ipAddrDst}:{pkt.tcpPortDst}, SEQ {pkt.tcpSeqNumber}] resending ACK with TTL {pkt.ipTTL}"
      await injectFd.injectTcpPacket(pkt)
      return
|
|
|
|
proc initPuncher*(srcPort: Port, dstIp: IpAddress, dstPorts: seq[Port],
                  seqNums: seq[uint32] = @[]): TcpSyniPuncher =
  ## Creates a puncher for connections from `srcPort` to `dstIp`. The local
  ## address is the one `getPrimaryIPAddr` picks for reaching `dstIp`.
  ##
  ## The candidate peer ports are "predicted" from `dstPorts`. Currently
  ## this is a single port: `dstPorts[1]`, or `dstPorts[0]` when only one
  ## port is given. It is clamped so the predicted range fits in `uint16`.
  ## `seqNums` is stored unchanged for use by `accept`.
  ##
  ## Raises `PunchHoleError` if `dstPorts` is empty.
  if dstPorts.len == 0:
    raise newException(PunchHoleError, "no destination ports given")
  let localIp = getPrimaryIPAddr(dstIp)
  # TODO: do real port prediction
  const numPredicted = 1
  # Base the prediction on the second given port, falling back to the only
  # one instead of raising IndexDefect when a single port is supplied.
  let observed = dstPorts[min(1, dstPorts.high)]
  let basePort = min(observed.uint16,
                     uint16.high - uint16(numPredicted - 1))
  var predictedDstPorts = newSeq[Port](numPredicted)
  for i in 0 ..< numPredicted:
    predictedDstPorts[i] = Port(basePort + i.uint16)
  result = TcpSyniPuncher(srcIp: localIp, srcPort: srcPort, dstIp: dstIp,
                          dstPorts: predictedDstPorts, seqNums: seqNums)
|
|
|
|
proc addFirewallRules(puncher: TcpSyniPuncher) {.async.} =
  ## Inserts one rule built by `makeFirewallRule` into the INPUT chain for
  ## each candidate peer port. Each inserted rule is recorded in
  ## `puncher.firewallRules` so `cleanup` can remove it.
  ##
  ## Raises `PunchHoleError` if an insertion raises `OSError`. Rules
  ## inserted before the failure stay recorded.
  for port in puncher.dstPorts:
    let rule = makeFirewallRule(puncher.srcIp, puncher.srcPort,
                                puncher.dstIp, port)
    try:
      await iptablesInsert("INPUT", rule)
    except OSError as err:
      echo "cannot add firewall rule: ", err.msg
      raise newException(PunchHoleError, err.msg)
    puncher.firewallRules.add(rule)
|
|
|
|
proc cleanup*(puncher: TcpSyniPuncher) {.async.} =
  ## Best-effort removal of the iptables rules recorded in
  ## `puncher.firewallRules` from the INPUT chain.
  ##
  ## Successfully deleted rules are forgotten, so calling `cleanup` more
  ## than once does not re-run `iptables -D` for them. A rule whose deletion
  ## raises `OSError` stays recorded so a later call can retry it. The error
  ## itself is swallowed.
  # Detach the list before awaiting. This way the seq is not iterated while
  # other code appends to it, and rules added meanwhile are kept.
  let rules = puncher.firewallRules
  puncher.firewallRules = @[]
  for rule in rules:
    try:
      await iptablesDelete("INPUT", rule)
    except OSError:
      # At least we tried; keep the rule so a later cleanup can retry.
      puncher.firewallRules.add(rule)
|
|
|
|
proc doConnect(srcIp: IpAddress, srcPort: Port, dstIp: IpAddress, dstPort: Port,
               future: Future[AsyncSocket]) {.async.} =
  ## One connection attempt from `srcIp:srcPort` to `dstIp:dstPort`.
  ##
  ## The socket is bound with `OptReuseAddr`. Its IP TTL is 2 while
  ## connecting and is raised to 64 once `connect` succeeds.
  ##
  ## The connected socket completes `future`. If another attempt has
  ## already completed it, this socket is closed instead, because completing
  ## a future twice raises.
  ##
  ## A failed connect is logged, its socket is closed and `future` is left
  ## untouched. An `OSError` during socket setup closes the socket and is
  ## re-raised.
  let sock = newAsyncSocket()
  try:
    sock.setSockOpt(OptReuseAddr, true)
    # NOTE(review): TTL 2 presumably lets the SYN create NAT/conntrack state
    # without reaching the peer; confirm against the protocol design.
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
    sock.bindAddr(srcPort, $srcIp)
  except OSError as e:
    sock.close()
    raise e
  try:
    await sock.connect($dstIp, dstPort)
    sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 64)
    if future.finished():
      # Another parallel attempt already won; don't complete twice.
      sock.close()
    else:
      future.complete(sock)
  except OSError as e:
    echo &"connection {srcIp}:{srcPort.int} -> {dstIp}:{dstPort.int} failed: ", e.msg
    sock.close()
|
|
|
|
proc connectParallel(puncher: TcpSyniPuncher): Future[AsyncSocket] =
  ## Starts one `doConnect` attempt per candidate peer port in the
  ## background (`asyncCheck`). Returns the single future shared by all
  ## attempts, which a successful attempt completes with its socket.
  let shared = newFuture[AsyncSocket]("doConnect")
  for port in puncher.dstPorts:
    asyncCheck doConnect(puncher.srcIp, puncher.srcPort, puncher.dstIp,
                         port, shared)
  shared
|
|
|
|
proc connect*(puncher: TcpSyniPuncher,
              progressCb: PunchProgressCb): Future[AsyncSocket] {.async.} =
  ## Actively punches a TCP connection to the peer. The steps are:
  ##
  ## 1. Start background capture of the outgoing SYN sequence numbers,
  ##    reported through `progressCb`.
  ## 2. Start background re-injection of the handshake ACK.
  ## 3. Install the firewall rules.
  ## 4. Connect to every candidate port in parallel, waiting up to
  ##    `Timeout` ms.
  ##
  ## The firewall rules are removed on success, on timeout, and on failure.
  ##
  ## Raises `PunchHoleError` on timeout, on `OSError`, or when the firewall
  ## rules cannot be installed.
  let iface = fromIpAddress(puncher.srcIp)
  # NOTE(review): these capture/inject descriptors are never closed here;
  # their ownership is defined in raw_socket. Confirm whether they leak.
  let captureSeqFd = setupEthernetCapturingSocket(iface)
  let captureAckFd = setupEthernetCapturingSocket(iface)
  let injectAckFd = setupTcpInjectingSocket()
  asyncCheck puncher.captureSeqNumbers(captureSeqFd, progressCb)
  asyncCheck puncher.captureAndResendAck(captureAckFd, injectAckFd)
  try:
    await puncher.addFirewallRules()
  except PunchHoleError as e:
    # Remove any rules that were inserted before the failure.
    await puncher.cleanup()
    raise e
  try:
    let connectParallelFuture = puncher.connectParallel()
    await connectParallelFuture or sleepAsync(Timeout)
    await puncher.cleanup()
    if connectParallelFuture.finished():
      result = connectParallelFuture.read()
    else:
      raise newException(PunchHoleError, "timeout")
  except OSError as e:
    # Don't leave the firewall rules installed when bailing out.
    await puncher.cleanup()
    raise newException(PunchHoleError, e.msg)
|
|
|
|
proc prepareAccept(puncher: TcpSyniPuncher) {.async.} =
  ## For each candidate peer port, one attempt is made, best effort:
  ##
  ## - Open a socket bound to `puncher.srcIp:puncher.srcPort` with IP TTL 2.
  ## - Start connecting to `puncher.dstIp` on that port and wait up to
  ##   `Timeout` ms.
  ## - Close the socket.
  ##
  ## NOTE(review): presumably the outgoing SYNs open NAT/firewall state
  ## toward the peer before `accept` listens; confirm.
  ##
  ## `OSError`s are ignored. The socket is closed on every path.
  for dstPort in puncher.dstPorts:
    var sock: AsyncSocket = nil
    try:
      sock = newAsyncSocket()
      sock.setSockOpt(OptReuseAddr, true)
      sock.getFd.setSockOptInt(IPPROTO_IP, IP_TTL, 2)
      sock.bindAddr(puncher.srcPort, $(puncher.srcIp))
      let connectFuture = sock.connect($(puncher.dstIp), dstPort)
      await connectFuture or sleepAsync(Timeout)
      if connectFuture.finished():
        echo "connected during accept phase"
    except OSError:
      discard
    # Close here rather than inside the try so an OSError (e.g. connection
    # refused or bind failure) does not leak the socket.
    if not sock.isNil:
      sock.close()
|
|
|
|
proc accept*(puncher: TcpSyniPuncher): Future[AsyncSocket] {.async.} =
  ## Passively punches a TCP connection from the peer. The steps are:
  ##
  ## 1. Run `prepareAccept`.
  ## 2. Install the firewall rules.
  ## 3. Listen on `puncher.srcIp:puncher.srcPort`.
  ## 4. Inject one SYN, appearing to come from `puncher.dstIp`, for every
  ##    pair of candidate peer port and sequence number in `puncher.seqNums`.
  ## 5. Wait up to `Timeout` ms for an incoming connection.
  ##
  ## The firewall rules are removed and the listening socket is closed on
  ## every exit path.
  ##
  ## Raises `PunchHoleError` on timeout, on `OSError`, or when the firewall
  ## rules cannot be installed.
  await puncher.prepareAccept()
  try:
    await puncher.addFirewallRules()
  except PunchHoleError as e:
    # Remove any rules that were inserted before the failure.
    await puncher.cleanup()
    raise e
  var listener: AsyncSocket = nil
  try:
    # Listen before injecting. The injected SYNs are addressed to this
    # host's srcIp:srcPort, so the listener should exist before they arrive.
    listener = newAsyncSocket()
    listener.setSockOpt(OptReuseAddr, true)
    listener.bindAddr(puncher.srcPort, $(puncher.srcIp))
    listener.listen()
    # NOTE(review): the injecting descriptor is never closed here; its
    # ownership is defined in raw_socket. Confirm whether it leaks.
    let rawFd = setupTcpInjectingSocket()
    for dstPort in puncher.dstPorts:
      for seqNum in puncher.seqNums:
        let ipPacket = IpPacket(protocol: tcp,
                                ipAddrSrc: puncher.dstIp,
                                ipAddrDst: puncher.srcIp,
                                ipTTL: 64,
                                tcpPortSrc: dstPort,
                                tcpPortDst: puncher.srcPort,
                                tcpSeqNumber: seqNum,
                                tcpAckNumber: 0,
                                tcpFlags: {SYN},
                                tcpWindowSize: 1452 * 10)
        echo &"[{ipPacket.ipAddrSrc}:{ipPacket.tcpPortSrc} -> {ipPacket.ipAddrDst}:{ipPacket.tcpPortDst}, SEQ {ipPacket.tcpSeqNumber}] injecting SYN"
        asyncCheck rawFd.injectTcpPacket(ipPacket)
    echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int}"
    let acceptFuture = listener.accept()
    await acceptFuture or sleepAsync(Timeout)
    await puncher.cleanup()
    # The accepted socket (if any) does not depend on the listener.
    listener.close()
    listener = nil
    if acceptFuture.finished():
      result = acceptFuture.read()
    else:
      raise newException(PunchHoleError, "timeout")
  except OSError as e:
    echo &"accepting connections from {puncher.dstIp}:{puncher.dstPorts[0].int} failed: ", e.msg
    if not listener.isNil:
      listener.close()
    await puncher.cleanup()
    raise newException(PunchHoleError, e.msg)
|