
I am trying to implement a TCP forwarder in C#. Specifically, the application:

  1. Listens on a TCP port and waits for a client,
  2. When a client connects, connects to a remote host,
  3. Waits for incoming data on both connections and exchanges the data between the two endpoints (acts as a proxy),
  4. Closes one connection when the other one is closed by an endpoint.
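For comparison, the steps above can be sketched with Task-based socket I/O instead of the BeginReceive/EndReceive pattern. This is a minimal sketch, not the question's code: it assumes .NET Core 2.1 or later, TcpPump is a hypothetical name, and it deliberately half-closes (Shutdown(SocketShutdown.Send)) when one side finishes sending instead of closing both sockets at once, so each endpoint still receives whatever the other side has left to send. Any error in either direction closes both sockets.

```csharp
using System;
using System.Net.Sockets;
using System.Threading.Tasks;

// Sketch only: TcpPump is a hypothetical name. Assumes .NET Core 2.1+ / .NET 5+
// (Task-based ReceiveAsync/SendAsync) instead of BeginReceive/EndReceive.
public static class TcpPump
{
    // Copies src -> dst until src reports end of stream, then half-closes dst.
    private static async Task CopyOneWayAsync(Socket src, Socket dst)
    {
        var buffer = new byte[8192];
        try
        {
            int read;
            while ((read = await src.ReceiveAsync(new ArraySegment<byte>(buffer), SocketFlags.None)) > 0)
            {
                // Loop in case SendAsync accepts fewer bytes than requested.
                int sent = 0;
                while (sent < read)
                    sent += await dst.SendAsync(new ArraySegment<byte>(buffer, sent, read - sent), SocketFlags.None);
            }
            // 0 bytes read: the peer finished sending, so pass the EOF on.
            dst.Shutdown(SocketShutdown.Send);
        }
        catch (Exception ex)
        {
            // Any failure (reset, disposed socket, ...) ends the whole connection.
            // Disposing both sockets also unblocks the opposite direction.
            Console.WriteLine($"Forwarding stopped: {ex.GetType().Name}: {ex.Message}");
            src.Dispose();
            dst.Dispose();
        }
    }

    // Forwards in both directions; releases both sockets once both directions have ended.
    public static async Task ForwardAsync(Socket client, Socket server)
    {
        await Task.WhenAll(CopyOneWayAsync(client, server), CopyOneWayAsync(server, client));
        client.Dispose();
        server.Dispose();
    }
}
```

In the question's accept loop this could replace the two StartReceive calls, e.g. _ = TcpPump.ForwardAsync(socketSrc, socketDest); because each direction catches its own exceptions, nothing escapes onto a thread-pool thread.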

I have adapted Simple TCP Forwarder (by Garcia) to forward a range of ports, so that

TCPForwarder.exe 10.1.1.1 192.168.1.100 1000 1100 2000

forwards any connection received on 10.1.1.1, ports 1000-1100, to the remote host 192.168.1.100, ports 2000-2100. I use this to expose an FTP server that is behind a NAT.
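The port mapping is plain offset arithmetic: local port srcPortStart + i is forwarded to remote port destPortStart + i, which is what the loop in Main computes. A small sketch (PortRange and its methods are hypothetical names) also makes the cost visible: the range 1000-1100 is 101 ports, and the question's Main starts one listening thread per port.

```csharp
using System;

// Hypothetical helper reproducing the arithmetic in the question's Main.
public static class PortRange
{
    // Number of listeners Main creates (one thread each): srcPortEnd - srcPortStart + 1.
    public static int Count(int srcPortStart, int srcPortEnd) => srcPortEnd - srcPortStart + 1;

    // Local port srcPortStart + i is forwarded to remote port destPortStart + i.
    public static int MapPort(int srcPort, int srcPortStart, int srcPortEnd, int destPortStart)
    {
        if (srcPort < srcPortStart || srcPort > srcPortEnd)
            throw new ArgumentOutOfRangeException(nameof(srcPort), "port is outside the forwarded range");
        return destPortStart + (srcPort - srcPortStart);
    }
}
```

For example, PortRange.MapPort(1021, 1000, 1100, 2000) returns 2021.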

With the above command running, the client can connect to the FTP server, and the following pattern, which is expected, is printed to the console (see the code):

0 StartReceive: BeginReceive
1 StartReceive: BeginReceive
1 OnDataReceive: EndReceive
1 OnDataReceive: BeginReceive
1 OnDataReceive: EndReceive
1 OnDataReceive: Close (0 read)
0 OnDataReceive: EndReceive
0 OnDataReceive: Close (exception)

But after connecting successfully several times (pressing F5 in FileZilla), no further response is received from TCPForwarder (or from the FTP server).

There seem to be two problems with my implementation that I cannot debug:

  1. In that situation, BeginReceive in the StartReceive method gets called, but no data is received from the FTP server. I don't think this is an FTP server issue, because it is ProFTPD, a well-known FTP server.

  2. Every time a connection is made and closed, the number of threads increases by 1. I don't think garbage collection would fix that: the thread count keeps rising, and forcing the garbage collector to run doesn't decrease it either. I suspect a leak in my code that is also causing issue #1.
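To see what kind of threads accumulate, it can help to log the process thread count together with thread-pool usage after each connection closes. This is a hedged diagnostic sketch (ThreadStats is a hypothetical name, and the exact meaning of the pool numbers depends on the runtime and OS): if the busy counts keep climbing, callbacks are probably stuck in a blocking call; if only the process count rises, the pool may simply be keeping idle threads alive.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

// Hypothetical diagnostic helper; call it after each connection closes and compare the lines.
public static class ThreadStats
{
    public static string Snapshot()
    {
        // All OS threads in the process, including ones the runtime creates itself.
        using var process = Process.GetCurrentProcess();
        int processThreads = process.Threads.Count;

        // Busy = max - available, for worker and I/O completion threads.
        ThreadPool.GetMaxThreads(out int workerMax, out int ioMax);
        ThreadPool.GetAvailableThreads(out int workerFree, out int ioFree);

        return $"process threads={processThreads}, busy worker={workerMax - workerFree}, busy I/O={ioMax - ioFree}";
    }
}
```

Usage: Console.WriteLine(ThreadStats.Snapshot()); right after the "Close" lines in OnDataReceive.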

Edit:

  • Restarting the FTP server didn't fix the problem, so there is definitely a bug in the TCPForwarder.

  • Some issues pointed out by @jgauffin are fixed in the code below.

Here is the full code:

using System;
using System.Net;
using System.Net.Sockets;
using System.Collections.Generic;
using System.Threading;

namespace TCPForwarder
{
    class Program
    {
        private class State
        {
            public int ID { get; private set; } // for debugging purposes
            public Socket SourceSocket { get; private set; }
            public Socket DestinationSocket { get; private set; }
            public byte[] Buffer { get; private set; }
            public State(int id, Socket source, Socket destination)
            {
                ID = id;
                SourceSocket = source;
                DestinationSocket = destination;
                Buffer = new byte[8192];
            }
        }

        public class TcpForwarder
        {
            public void Start(IPEndPoint local, IPEndPoint remote)
            {
                Socket MainSocket;
                try
                {
                    MainSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                    MainSocket.Bind(local);
                    MainSocket.Listen(10);
                }
                catch (Exception exp)
                {
                    Console.WriteLine("Error on listening to " + local.Port + ": " + exp.Message);
                    return;
                }

                while (true)
                {
                    // Accept a new client
                    var socketSrc = MainSocket.Accept();
                    var socketDest = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

                    try
                    {
                        // Connect to the endpoint
                        socketDest.Connect(remote);
                    }
                    catch
                    {
                        socketSrc.Shutdown(SocketShutdown.Both);
                        socketSrc.Close();
                        Console.WriteLine("Exception in connecting to remote host");
                        continue;
                    }

                    // Wait for data sent from client and forward it to the endpoint
                    StartReceive(0, socketSrc, socketDest);

                    // Also, wait for data sent from endpoint and forward it to the client
                    StartReceive(1, socketDest, socketSrc);
                }
            }

            private static void StartReceive(int id, Socket src, Socket dest)
            {
                var state = new State(id, src, dest);

                Console.WriteLine("{0} StartReceive: BeginReceive", id);
                try
                {
                    src.BeginReceive(state.Buffer, 0, state.Buffer.Length, 0, OnDataReceive, state);
                }
                catch
                {
                    Console.WriteLine("{0} Exception in StartReceive: BeginReceive", id);
                }
            }

            private static void OnDataReceive(IAsyncResult result)
            {
                State state = null;
                try
                {
                    state = (State)result.AsyncState;

                    Console.WriteLine("{0} OnDataReceive: EndReceive", state.ID);
                    var bytesRead = state.SourceSocket.EndReceive(result);
                    if (bytesRead > 0)
                    {
                        state.DestinationSocket.Send(state.Buffer, bytesRead, SocketFlags.None);

                        Console.WriteLine("{0} OnDataReceive: BeginReceive", state.ID);
                        state.SourceSocket.BeginReceive(state.Buffer, 0, state.Buffer.Length, 0, OnDataReceive, state);
                    }
                    else
                    {
                        Console.WriteLine("{0} OnDataReceive: Close (0 read)", state.ID);
                        state.SourceSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Close();
                        state.SourceSocket.Close();
                    }
                }
                catch
                {
                    if (state!=null)
                    {
                        Console.WriteLine("{0} OnDataReceive: Close (exception)", state.ID);
                        state.SourceSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Shutdown(SocketShutdown.Both);
                        state.DestinationSocket.Close();
                        state.SourceSocket.Close();
                    }
                }
            }
        }

        static void Main(string[] args)
        {
            List<Socket> sockets = new List<Socket>();

            int srcPortStart = int.Parse(args[2]);
            int srcPortEnd = int.Parse(args[3]);
            int destPortStart = int.Parse(args[4]);

            List<Thread> threads = new List<Thread>();
            for (int i = 0; i < srcPortEnd - srcPortStart + 1; i++)
            {
                int srcPort = srcPortStart + i;
                int destPort = destPortStart + i;

                TcpForwarder tcpForwarder = new TcpForwarder();

                Thread t = new Thread(new ThreadStart(() => tcpForwarder.Start(
                    new IPEndPoint(IPAddress.Parse(args[0]), srcPort),
                    new IPEndPoint(IPAddress.Parse(args[1]), destPort))));
                t.Start();

                threads.Add(t);
            }

            foreach (var t in threads)
            {
                t.Join();
            }
            Console.WriteLine("All threads are closed");
        }
    }
}

Answer


The first problem is that the accept loop carries on after the connection to the destination socket fails. Use continue; in the catch block. There is also no guarantee that the sockets are still up when you invoke the first BeginReceive, so those calls need to be wrapped in a try/catch as well.
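A minimal sketch of wrapping the first BeginReceive, assuming the question's socket setup and a caller-supplied callback (ReceiveGuard, CloseQuietly and TryBeginReceive are hypothetical names). If the source socket is already gone, the helper also closes the partner socket instead of leaving it open.

```csharp
using System;
using System.Net.Sockets;

// Sketch: ReceiveGuard, CloseQuietly and TryBeginReceive are hypothetical helper
// names; the parameters mirror the question's StartReceive.
public static class ReceiveGuard
{
    // Shutdown throws if the socket is not connected or already closed; ignore that here.
    public static void CloseQuietly(Socket s)
    {
        try { s.Shutdown(SocketShutdown.Both); }
        catch (SocketException) { }
        catch (ObjectDisposedException) { }
        s.Close(); // Close/Dispose can safely be called more than once
    }

    // Issues the first BeginReceive; on failure, closes both ends and returns false.
    public static bool TryBeginReceive(Socket src, Socket dest, byte[] buffer,
                                       AsyncCallback callback, object state)
    {
        try
        {
            src.BeginReceive(buffer, 0, buffer.Length, SocketFlags.None, callback, state);
            return true;
        }
        catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException)
        {
            Console.WriteLine($"BeginReceive failed: {ex.GetType().Name}: {ex.Message}");
            CloseQuietly(src);
            CloseQuietly(dest);
            return false;
        }
    }
}
```

In StartReceive, the existing try/catch could then become: if (!ReceiveGuard.TryBeginReceive(src, dest, state.Buffer, OnDataReceive, state)) return;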

Always wrap callback methods (in this case OnDataReceive) in a try/catch, because an exception thrown from a callback can otherwise take down your application.
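A sketch of the same advice applied to the callback: everything after EndReceive runs inside one try/catch, and cleanup is made idempotent so both directions can call it safely. This matters for the question's code, where one direction may already have closed both sockets, so a Shutdown call in the other direction's catch block can itself throw ObjectDisposedException. SocketPair, RelayDirection and Relay are hypothetical names.

```csharp
using System;
using System.Net.Sockets;
using System.Threading;

// One SocketPair is shared by both directions, so whichever callback
// finishes or fails first closes both sockets exactly once.
public sealed class SocketPair
{
    private readonly Socket _a, _b;
    private int _closed; // 0 = open, 1 = closed

    public SocketPair(Socket a, Socket b) { _a = a; _b = b; }

    public void CloseBoth()
    {
        if (Interlocked.Exchange(ref _closed, 1) != 0) return; // already closed by the other direction
        foreach (var s in new[] { _a, _b })
        {
            try { s.Shutdown(SocketShutdown.Both); }
            catch (SocketException) { }
            catch (ObjectDisposedException) { }
            s.Close();
        }
    }
}

public sealed class RelayDirection
{
    public Socket Source, Destination;
    public SocketPair Pair;
    public readonly byte[] Buffer = new byte[8192];
}

public static class Relay
{
    public static void OnDataReceive(IAsyncResult result)
    {
        var d = (RelayDirection)result.AsyncState;
        try
        {
            int read = d.Source.EndReceive(result);
            if (read == 0) { d.Pair.CloseBoth(); return; } // peer closed its side
            d.Destination.Send(d.Buffer, read, SocketFlags.None);
            d.Source.BeginReceive(d.Buffer, 0, d.Buffer.Length, SocketFlags.None, OnDataReceive, d);
        }
        catch (Exception ex)
        {
            // Nothing may escape a callback: log the cause, then clean up once.
            Console.WriteLine($"OnDataReceive: {ex.GetType().Name}: {ex.Message}");
            d.Pair.CloseBoth();
        }
    }
}
```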

Fix that and start logging the exceptions. They will most likely give you a hint about what went wrong.
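A bare catch throws away exactly the information needed here. A minimal logging helper (SocketErrors.Describe is a hypothetical name) keeps the exception type, the message and, for a SocketException, its SocketErrorCode (for example ConnectionReset or OperationAborted), which can help tell a peer reset apart from a socket closed locally.

```csharp
using System;
using System.Net.Sockets;

// Hypothetical helper: builds a log line that keeps the details a bare catch discards.
public static class SocketErrors
{
    public static string Describe(string where, Exception ex)
    {
        var text = $"{where}: {ex.GetType().Name}: {ex.Message}";
        if (ex is SocketException se)
            text += $" (SocketErrorCode={se.SocketErrorCode})";
        return text;
    }
}
```

In the question's catch blocks this could be used as catch (Exception ex) { Console.WriteLine(SocketErrors.Describe(state.ID + " OnDataReceive", ex)); ... }.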