diff --git a/tcpip.go b/tcpip.go index 335fda6..75b6336 100644 --- a/tcpip.go +++ b/tcpip.go @@ -108,9 +108,6 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log parse failure return false, []byte{} } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) ln, err := net.Listen("tcp", addr) if err != nil { @@ -119,6 +116,10 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go } _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, uint32(destPort)) { + ln.Close() + return false, []byte("port forwarding is disabled") + } h.Lock() h.forwards[addr] = ln h.Unlock()