As the title says, I am writing a UDP server on raw sockets, and using SocketAsyncEventArgs because i would like to write something fast.
I am aware that UdpClient exists, and that there are easier methods to go about writing a server, but i would like to learn how to properly use SocketAsyncEventArgs and the socket.ReceiveFromAsync / socket.SendToAsync methods for their claimed 'enchanced throughput' and 'better scalability' (MSDN Docs for SocketAsyncEventArgs).
I have essentially followed the MSDN example, as I figured it was a decent starting point to learn how these methods work, and ran into a bit of a problem. The server works initially and can echo back received bytes, but will randomly 'fail' to receive bytes from the correct address. Instead of the correct localhost client address (e.g. 127.0.0.1:7007) the RemoteEndPoint will be populated by the UDP placeholder EndPoint {0.0.0.0:0} (for lack of a better term). Image showing the problem (console is in Polish, sorry. Just trust me that the SocketException message is "Desired Address is invalid in this context").
I have wholesale ripped chunks from the MSDN example at times, changing only the fields that were populated on the SocketAsyncEventArgs instance for the socket.ReceiveFromAsync call (in accordance with the MSDN docs -> socket.ReceiveFromAsync Docs), and the end result was still the same. Furthermore, this is an intermittent issue, not a constant one. There is no time that the server will consistently error out, from what I have noticed.
My thoughts so far are a state issue with UdpServer, some inconsistency on the UdpClient side, or a misuse of the TaskCompletionSource.
Edit 1:
I feel that i should address why im using SocketAsyncEventArgs. I completely understand that there are easier methods to go about sending and receiving data. The async/await socket extensions are a good way to go about this, and are the way i did it initially. I would want to benchmark async/await versus the older api, the SocketArgs, to see by how much the two methods differ.
Code for the UdpClient, UdpServer, and shared structures is included below. I can also try to provide more code on demand, if StackOverflow will let me.
Thank you for taking the time to help me.
Test Code
private static async Task TestNetworking()
{
EndPoint serverEndPoint = new IPEndPoint(IPAddress.Loopback, 12345);
await Task.Factory.StartNew(async () =>
{
SocketClient client = new UdpClient();
bool bound = client.Bind(new IPEndPoint(IPAddress.Any, 7007));
if (bound)
{
Console.WriteLine($"[Client] Bound client socket!");
}
if (bound && client.Connect(serverEndPoint))
{
Console.WriteLine($"[Client] Connected to {serverEndPoint}!");
byte[] message = Encoding.UTF8.GetBytes("Hello World!");
Stopwatch stopwatch = new Stopwatch();
const int packetsToSend = 1_000_000;
for (int i = 0; i < packetsToSend; i++)
{
try
{
stopwatch.Start();
int sentBytes = await client.SendAsync(serverEndPoint, message, SocketFlags.None);
//Console.WriteLine($"[Client] Sent {sentBytes} to {serverEndPoint}");
ReceiveResult result = await client.ReceiveAsync(serverEndPoint, SocketFlags.None);
//Console.WriteLine($"[{result.RemoteEndPoint} > Client] : {Encoding.UTF8.GetString(result.Contents)}");
serverEndPoint = result.RemoteEndPoint;
stopwatch.Stop();
}
catch (Exception ex)
{
Console.WriteLine(ex);
i--;
await Task.Delay(1);
}
}
double approxBandwidth = (packetsToSend * message.Length) / (1_000_000.0 * (stopwatch.ElapsedMilliseconds / 1000.0));
Console.WriteLine($"Sent {packetsToSend} packets of {message.Length} bytes in {stopwatch.ElapsedMilliseconds:N} milliseconds.");
Console.WriteLine($"Approximate bandwidth: {approxBandwidth} MBps");
}
}, TaskCreationOptions.LongRunning);
await Task.Factory.StartNew(async () =>
{
try
{
SocketServer server = new UdpServer();
bool bound = server.Bind(serverEndPoint);
if (bound)
{
//Console.WriteLine($"[Server] Bound server socket!");
//Console.WriteLine($"Starting server at {serverEndPoint}!");
await server.StartAsync();
}
}
catch (Exception ex)
{
Console.WriteLine(ex);
}
}).Result;
}
Shared Code
public readonly struct ReceiveResult
{
public const int PacketSize = 1024;
public readonly Memory<byte> Contents;
public readonly int ReceivedBytes;
public readonly EndPoint RemoteEndPoint;
public ReceiveResult(Memory<byte> contents, int receivedBytes, EndPoint remoteEndPoint)
{
Contents = contents;
ReceivedBytes = receivedBytes;
RemoteEndPoint = remoteEndPoint;
}
}
UDP Client
public class UdpClient : SocketClient
{
/*
public abstract class SocketClient
{
protected readonly Socket socket;
protected SocketClient(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
socket = new Socket(addressFamily, socketType, protocolType);
}
public bool Bind(in EndPoint localEndPoint)
{
try
{
socket.Bind(localEndPoint);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex);
return false;
}
}
public bool Connect(in EndPoint remoteEndPoint)
{
try
{
socket.Connect(remoteEndPoint);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex);
return false;
}
}
public abstract Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags);
public abstract Task<int> SendAsync(EndPoint remoteEndPoint, ArraySegment<byte> buffer, SocketFlags socketFlags);
}
*/
/// <inheritdoc />
public UdpClient() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
{
}
public override async Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags)
{
byte[] buffer = new byte[ReceiveResult.PacketSize];
SocketReceiveFromResult result =
await socket.ReceiveFromAsync(new ArraySegment<byte>(buffer), socketFlags, remoteEndPoint);
return new ReceiveResult(buffer, result.ReceivedBytes, result.RemoteEndPoint);
/*
SocketAsyncEventArgs args = new SocketAsyncEventArgs();
args.SetBuffer(new byte[ReceiveResult.PacketSize]);
args.SocketFlags = socketFlags;
args.RemoteEndPoint = remoteEndPoint;
SocketTask awaitable = new SocketTask(args);
while (ReceiveResult.PacketSize > args.BytesTransferred)
{
await socket.ReceiveFromAsync(awaitable);
}
return new ReceiveResult(args.MemoryBuffer, args.RemoteEndPoint);
*/
}
public override async Task<int> SendAsync(EndPoint remoteEndPoint, ArraySegment<byte> buffer, SocketFlags socketFlags)
{
return await socket.SendToAsync(buffer.ToArray(), socketFlags, remoteEndPoint);
/*
SocketAsyncEventArgs args = new SocketAsyncEventArgs();
args.SetBuffer(buffer);
args.SocketFlags = socketFlags;
args.RemoteEndPoint = remoteEndPoint;
SocketTask awaitable = new SocketTask(args);
while (buffer.Length > args.BytesTransferred)
{
await socket.SendToAsync(awaitable);
}
return args.BytesTransferred;
*/
}
}
UDP Server
public class UdpServer : SocketServer
{
/*
public abstract class SocketServer
{
protected readonly Socket socket;
protected SocketServer(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
{
socket = new Socket(addressFamily, socketType, protocolType);
}
public bool Bind(in EndPoint localEndPoint)
{
try
{
socket.Bind(localEndPoint);
return true;
}
catch (Exception ex)
{
Console.WriteLine(ex);
return false;
}
}
public abstract Task StartAsync();
}
*/
private const int MaxPooledObjects = 100;
private readonly ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>> clients;
private readonly ArrayPool<byte> receiveBufferPool = ArrayPool<byte>.Create(ReceiveResult.PacketSize, MaxPooledObjects);
private readonly ObjectPool<SocketAsyncEventArgs> receiveSocketAsyncEventArgsPool =
new DefaultObjectPool<SocketAsyncEventArgs>(new DefaultPooledObjectPolicy<SocketAsyncEventArgs>(), MaxPooledObjects);
private readonly ObjectPool<SocketAsyncEventArgs> sendSocketAsyncEventArgsPool =
new DefaultObjectPool<SocketAsyncEventArgs>(new DefaultPooledObjectPolicy<SocketAsyncEventArgs>(), MaxPooledObjects);
private void HandleIOCompleted(object? sender, SocketAsyncEventArgs eventArgs)
{
eventArgs.Completed -= HandleIOCompleted;
bool closed = false;
/*
Original (local) methods in ReceiveAsync and SendAsync,
these were assigned to eventArgs.Completed instead of HandleIOCompleted
=======================================================================
void ReceiveCompletedHandler(object? sender, SocketAsyncEventArgs eventArgs)
{
AsyncReadToken asyncReadToken = (AsyncReadToken)eventArgs.UserToken;
eventArgs.Completed -= ReceiveCompletedHandler;
if (eventArgs.SocketError != SocketError.Success)
{
asyncReadToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
}
else
{
eventArgs.MemoryBuffer.CopyTo(asyncReadToken.OutputBuffer);
asyncReadToken.CompletionSource.TrySetResult(
new ReceiveResult(asyncReadToken.OutputBuffer, eventArgs.BytesTransferred, eventArgs.RemoteEndPoint));
}
receiveBufferPool.Return(asyncReadToken.RentedBuffer);
receiveSocketAsyncEventArgsPool.Return(eventArgs);
}
void SendCompletedHandler(object? sender, SocketAsyncEventArgs eventArgs)
{
AsyncWriteToken asyncWriteToken = (AsyncWriteToken)eventArgs.UserToken;
eventArgs.Completed -= SendCompletedHandler;
if (eventArgs.SocketError != SocketError.Success)
{
asyncWriteToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
}
else
{
asyncWriteToken.CompletionSource.TrySetResult(eventArgs.BytesTransferred);
}
sendSocketAsyncEventArgsPool.Return(eventArgs);
}
*/
switch (eventArgs.LastOperation)
{
case SocketAsyncOperation.SendTo:
AsyncWriteToken asyncWriteToken = (AsyncWriteToken)eventArgs.UserToken;
if (eventArgs.SocketError != SocketError.Success)
{
asyncWriteToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
}
else
{
asyncWriteToken.CompletionSource.TrySetResult(eventArgs.BytesTransferred);
}
sendSocketAsyncEventArgsPool.Return(eventArgs);
break;
case SocketAsyncOperation.ReceiveFrom:
AsyncReadToken asyncReadToken = (AsyncReadToken)eventArgs.UserToken;
if (eventArgs.SocketError != SocketError.Success)
{
asyncReadToken.CompletionSource.TrySetException(new SocketException((int)eventArgs.SocketError));
}
else
{
eventArgs.MemoryBuffer.CopyTo(asyncReadToken.OutputBuffer);
asyncReadToken.CompletionSource.TrySetResult(
new ReceiveResult(asyncReadToken.OutputBuffer, eventArgs.BytesTransferred, eventArgs.RemoteEndPoint));
}
receiveBufferPool.Return(asyncReadToken.RentedBuffer);
receiveSocketAsyncEventArgsPool.Return(eventArgs);
break;
case SocketAsyncOperation.Disconnect:
closed = true;
break;
case SocketAsyncOperation.Accept:
case SocketAsyncOperation.Connect:
case SocketAsyncOperation.None:
break;
}
if (closed)
{
// handle the client closing the connection on tcp servers at some point
}
}
private Task<ReceiveResult> ReceiveAsync(EndPoint remoteEndPoint, SocketFlags socketFlags, Memory<byte> outputBuffer)
{
TaskCompletionSource<ReceiveResult> tcs = new TaskCompletionSource<ReceiveResult>();
byte[] buffer = receiveBufferPool.Rent(ReceiveResult.PacketSize);
Memory<byte> memoryBuffer = new Memory<byte>(buffer);
SocketAsyncEventArgs args = receiveSocketAsyncEventArgsPool.Get();
args.SetBuffer(memoryBuffer);
args.SocketFlags = socketFlags;
args.RemoteEndPoint = remoteEndPoint;
args.UserToken = new AsyncReadToken(buffer, outputBuffer, tcs);
args.Completed += HandleIOCompleted;
if (socket.ReceiveFromAsync(args)) return tcs.Task;
byte[] bufferCopy = new byte[ReceiveResult.PacketSize];
args.MemoryBuffer.CopyTo(bufferCopy);
ReceiveResult result = new ReceiveResult(bufferCopy, args.BytesTransferred, args.RemoteEndPoint);
receiveBufferPool.Return(buffer);
receiveSocketAsyncEventArgsPool.Return(args);
return Task.FromResult(result);
}
private Task<int> SendAsync(EndPoint remoteEndPoint, Memory<byte> buffer, SocketFlags socketFlags)
{
TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();
SocketAsyncEventArgs args = sendSocketAsyncEventArgsPool.Get();
args.SetBuffer(buffer);
args.SocketFlags = socketFlags;
args.RemoteEndPoint = remoteEndPoint;
args.UserToken = new AsyncWriteToken(buffer, tcs);
args.Completed += HandleIOCompleted;
if (socket.SendToAsync(args)) return tcs.Task;
int result = args.BytesTransferred;
sendSocketAsyncEventArgsPool.Return(args);
return Task.FromResult(result);
}
private readonly struct AsyncReadToken
{
public readonly TaskCompletionSource<ReceiveResult> CompletionSource;
public readonly Memory<byte> OutputBuffer;
public readonly byte[] RentedBuffer;
public AsyncReadToken(byte[] rentedBuffer, Memory<byte> outputBuffer, TaskCompletionSource<ReceiveResult> tcs)
{
RentedBuffer = rentedBuffer;
OutputBuffer = outputBuffer;
CompletionSource = tcs;
}
}
private readonly struct AsyncWriteToken
{
public readonly TaskCompletionSource<int> CompletionSource;
public readonly Memory<byte> OutputBuffer;
public AsyncWriteToken(Memory<byte> outputBuffer, TaskCompletionSource<int> tcs)
{
OutputBuffer = outputBuffer;
CompletionSource = tcs;
}
}
public UdpServer() : base(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp)
{
clients = new ConcurrentDictionary<EndPoint, ConcurrentQueue<byte[]>>();
}
/// <inheritdoc />
public override async Task StartAsync()
{
EndPoint nullEndPoint = new IPEndPoint(IPAddress.Any, 0);
byte[] receiveBuffer = new byte[ReceiveResult.PacketSize];
Memory<byte> receiveBufferMemory = new Memory<byte>(receiveBuffer);
while (true)
{
ReceiveResult result = await ReceiveAsync(nullEndPoint, SocketFlags.None, receiveBufferMemory);
Console.WriteLine($"[{result.RemoteEndPoint} > Server] : {Encoding.UTF8.GetString(result.Contents.Span)}");
int sentBytes = await SendAsync(result.RemoteEndPoint, result.Contents, SocketFlags.None);
Console.WriteLine($"[Server > {result.RemoteEndPoint}] Sent {sentBytes} bytes to {result.RemoteEndPoint}");
}
}
}