14
votes

.NET objects are free-threaded by default. If marshaled to another thread via COM, they always get marshaled to themselves, regardless of whether the creator thread was STA or not, and regardless of their ThreadingModel registry value. I suspect they aggregate the Free Threaded Marshaler (more details about COM threading can be found here).

I want to make my .NET COM object use the standard COM marshaler proxy when marshaled to another thread. The problem:

using System;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Threading;

namespace ConsoleApplication
{
    // Demo driver: creates a ComObject on one STA thread, marshals it to a second
    // STA thread through a COM stream, and prints which thread each Test() call
    // runs on and whether the unmarshaled reference is the very same object.
    class Program
    {
        static void Main(string[] args)
        {
            // WpfApartment is IDisposable and owns a dedicated thread; the using
            // statements make sure both threads are asked to exit and are joined
            // instead of being abandoned when Main returns.
            using (var apt1 = new WpfApartment())
            using (var apt2 = new WpfApartment())
            {
                apt1.Invoke(() =>
                {
                    // create and call the object on the first apartment's thread
                    var comObj = new ComObject();
                    comObj.Test();

                    // marshal its IUnknown into a stream for another thread
                    IntPtr pStm;
                    NativeMethods.CoMarshalInterThreadInterfaceInStream(NativeMethods.IID_IUnknown, comObj, out pStm);

                    apt2.Invoke(() =>
                    {
                        // unmarshal on the second apartment's thread; this call also releases the stream
                        object unk;
                        NativeMethods.CoGetInterfaceAndReleaseStream(pStm, NativeMethods.IID_IUnknown, out unk);

                        // equal = True means no proxy was created: we got the original object back
                        Console.WriteLine(new { equal = Object.ReferenceEquals(comObj, unk) });

                        var marshaledComObj = (IComObject)unk;
                        marshaledComObj.Test();
                    });
                });

                // keep the console (and both apartments) alive until the user presses Enter
                Console.ReadLine();
            }
        }
    }

    // IComObject - the COM-visible, IDispatch-based interface implemented by ComObject.
    //
    // NOTE(review): the GUID below is IID_IDispatch itself rather than a unique IID.
    // According to the author's remark in the discussion that follows the code, this
    // is a deliberate trick to get OLE Automation IDispatch-style marshaling without
    // running RegAsm - confirm before replacing it with a freshly generated GUID.
    [ComVisible(true)]
    [Guid("00020400-0000-0000-C000-000000000046")] // IID_IDispatch
    [InterfaceType(ComInterfaceType.InterfaceIsIDispatch)]
    public interface IComObject
    {
        // The single test method; the implementations shown here print the id of
        // the managed thread the call executes on.
        void Test();
    }

    // Baseline ComObject: exposed to COM through IComObject only
    // (ClassInterfaceType.None, so no class interface is generated).
    [ComVisible(true), ClassInterface(ClassInterfaceType.None), ComDefaultInterface(typeof(IComObject))]
    public class ComObject : IComObject
    {
        // IComObject.Test: prints the managed id of the thread this call runs on.
        public void Test()
        {
            var threadInfo = new { Environment.CurrentManagedThreadId };
            Console.WriteLine(threadInfo);
        }
    }


    // WpfApartment - owns a background STA thread that runs a WPF Dispatcher
    // message loop, and exposes a TaskScheduler that queues work onto that thread.
    internal class WpfApartment : IDisposable
    {
        Thread _thread; // the STA thread; reset to null once Dispose has joined it

        // Scheduler obtained from the synchronization context of the STA thread.
        public System.Threading.Tasks.TaskScheduler TaskScheduler { get; private set; }

        // Starts the STA thread and blocks until that thread has published its scheduler.
        public WpfApartment()
        {
            var schedulerSource = new TaskCompletionSource<System.Threading.Tasks.TaskScheduler>();

            _thread = new Thread(() => RunMessageLoop(schedulerSource));
            _thread.IsBackground = true;
            _thread.SetApartmentState(ApartmentState.STA);
            _thread.Start();

            // blocks the constructing thread until the callback below has run
            this.TaskScheduler = schedulerSource.Task.Result;
        }

        // Body of the STA thread: initializes OLE, hands the thread's scheduler
        // back through schedulerSource, then pumps messages until the frames exit.
        static void RunMessageLoop(TaskCompletionSource<System.Threading.Tasks.TaskScheduler> schedulerSource)
        {
            NativeMethods.OleInitialize(IntPtr.Zero);
            try
            {
                // queued at idle priority, so it executes once the loop below is running
                Action publishScheduler = () =>
                    schedulerSource.SetResult(System.Threading.Tasks.TaskScheduler.FromCurrentSynchronizationContext());
                Dispatcher.CurrentDispatcher.InvokeAsync(publishScheduler, DispatcherPriority.ApplicationIdle);

                // run the WPF Dispatcher message loop
                Dispatcher.Run();
            }
            finally
            {
                NativeMethods.OleUninitialize();
            }
        }

        // Shuts down the STA thread: asks its dispatcher to exit all frames, then waits for it.
        public void Dispose()
        {
            if (_thread == null || !_thread.IsAlive)
            {
                return;
            }

            InvokeAsync(Dispatcher.ExitAllFrames);
            _thread.Join();
            _thread = null;
        }

        // Queues the action onto the STA thread and returns the task tracking it.
        public Task InvokeAsync(Action action)
        {
            return Task.Factory.StartNew(
                action,
                cancellationToken: CancellationToken.None,
                creationOptions: TaskCreationOptions.None,
                scheduler: this.TaskScheduler);
        }

        // Queues the action onto the STA thread and blocks until it has completed.
        public void Invoke(Action action)
        {
            Task pending = InvokeAsync(action);
            pending.Wait();
        }
    }

    // P/Invoke declarations for the ole32.dll functions used by this sample,
    // together with the interface ids that are passed to them.
    public static class NativeMethods
    {
        // Well-known interface ids.
        public static readonly Guid IID_IUnknown = new Guid("00000000-0000-0000-C000-000000000046");
        public static readonly Guid IID_IDispatch = new Guid("00020400-0000-0000-C000-000000000046");

        // --- OLE initialization / teardown for the calling thread ---

        [DllImport("ole32.dll", PreserveSig = false)]
        public static extern void OleInitialize(IntPtr pvReserved);

        // PreserveSig is left at its default here, matching the native function,
        // which has no result to translate.
        [DllImport("ole32.dll")]
        public static extern void OleUninitialize();

        // --- inter-thread interface marshaling ---

        // Marshals the interface riid of pUnk into a stream whose pointer is returned in ppStm.
        [DllImport("ole32.dll", PreserveSig = false)]
        public static extern void CoMarshalInterThreadInterfaceInStream(
            [MarshalAs(UnmanagedType.LPStruct), In] Guid riid,
            [MarshalAs(UnmanagedType.IUnknown)] object pUnk,
            out IntPtr ppStm);

        // Unmarshals the interface riid from pStm into ppv and releases the stream.
        [DllImport("ole32.dll", PreserveSig = false)]
        public static extern void CoGetInterfaceAndReleaseStream(
            IntPtr pStm,
            [MarshalAs(UnmanagedType.LPStruct), In] Guid riid,
            [MarshalAs(UnmanagedType.IUnknown)] out object ppv);
    }
}

Output:

{ CurrentManagedThreadId = 11 }
{ equal = True }
{ CurrentManagedThreadId = 12 }

Note that I use CoMarshalInterThreadInterfaceInStream/CoGetInterfaceAndReleaseStream to marshal ComObject from one STA thread to another. I want both Test() calls to be invoked on the same original thread (e.g. 11), as would be the case with a typical STA COM object implemented in C++.

One possible solution is to disable the IMarshal interface on the .NET COM object:

// ComObject variant that rejects QueryInterface requests for IMarshal and leaves
// every other interface to the default handling. With IMarshal refused, the
// sample output that accompanies this code shows the second Test() call running
// on the original thread.
[ComVisible(true), ClassInterface(ClassInterfaceType.None), ComDefaultInterface(typeof(IComObject))]
public class ComObject : IComObject, ICustomQueryInterface
{
    // IComObject methods
    public void Test()
    {
        var threadInfo = new { Environment.CurrentManagedThreadId };
        Console.WriteLine(threadInfo);
    }

    public static readonly Guid IID_IMarshal = new Guid("00000003-0000-0000-C000-000000000046");

    // ICustomQueryInterface: fail the query for IMarshal, decline to handle anything else.
    // ppv is always returned as IntPtr.Zero because no interface pointer is ever supplied here.
    public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
    {
        ppv = IntPtr.Zero;

        bool isMarshalQuery = iid == IID_IMarshal;
        return isMarshalQuery
            ? CustomQueryInterfaceResult.Failed
            : CustomQueryInterfaceResult.NotHandled;
    }
}

Output (as desired):

{ CurrentManagedThreadId = 11 }
{ equal = False }
{ CurrentManagedThreadId = 11 }

This works, but it feels like an implementation-specific hack. Is there a cleaner way to get this done, like some special interop attribute I might have overlooked? Note that in real life ComObject is used (and gets marshaled) by a legacy unmanaged application.

2
You always have the most interesting threading questions. I don't know the answer to this, but I will definitely be watching this question to see what the answer will be. – Scott Chamberlain
@ScottChamberlain, thanks :) Where COM is not involved, I'm happy with using techniques like await WpfApartment.InvokeAsync(()=>DoSomething()) for marshaling calls between STA threads. – noseratio
It is a very artificial test; real COM servers of course don't behave this way. They have a proper ThreadingModel; the CLR makes an effort to find it in your test program but of course is doomed to fail. Without any shot at finding a proxy, it punts and doesn't marshal the interface at all. There is little point in pursuing this; just don't bother using COM at all if you are not going to use a real COM server. – Hans Passant
@HansPassant, the actual server does use ComRegisterFunctionAttribute to register the object as ThreadingModel=Apartment. Yet, as I've mentioned above, that doesn't change the marshaling behavior I described. The unmanaged client calls CoMarshalInterThreadInterfaceInStream on one STA thread, then CoGetInterfaceAndReleaseStream on another STA thread, and gets the same object, not a proxy. That's because any .NET CCW implements IMarshal and is free-threaded, regardless of its ThreadingModel in the registry. Don't take my word for that, try it for yourself. – noseratio
@PauloMadeira, in this particular case, it's a trick which allows me to use OLE Automation IDispatch-style marshaling without doing RegAsm. COM uses IDispatch::Invoke, and .NET has just enough metadata to make those IDispatch::Invoke calls strongly typed. – noseratio

2 Answers

7
votes

You can inherit from StandardOleMarshalObject or ServicedComponent to get that effect:

Managed objects that are exposed to COM behave as if they had aggregated the free-threaded marshaler. In other words, they can be called from any COM apartment in a free-threaded manner. The only managed objects that do not exhibit this free-threaded behavior are those objects that derive from ServicedComponent or StandardOleMarshalObject.

5
votes

Paulo Madeira's excellent answer provides a great solution for when the managed class being exposed to COM can be derived from StandardOleMarshalObject.

It got me thinking, though: how to deal with cases where there is already a base class, such as System.Windows.Forms.Control, which doesn't have StandardOleMarshalObject in its inheritance chain?

It turns out, it's possible to aggregate the Standard COM Marshaler. Similar to the Free Threaded Marshaler's CoCreateFreeThreadedMarshaler, there is an API for that: CoGetStdMarshalEx. Here's how it can be done:

// ComObject variant that obtains a marshaler from CoGetStdMarshalEx, passing
// itself as the outer unknown, and answers QueryInterface(IID_IMarshal) with
// that marshaler's IMarshal instead of the default one.
[ComVisible(true), ClassInterface(ClassInterfaceType.None), ComDefaultInterface(typeof(IComObject))]
public class ComObject : IComObject, ICustomQueryInterface
{
    // Unknown pointer returned by CoGetStdMarshalEx; released in the finalizer.
    IntPtr _unkMarshal;

    public ComObject()
    {
        NativeMethods.CoGetStdMarshalEx(this, NativeMethods.SMEXF_SERVER, out _unkMarshal);
    }

    ~ComObject()
    {
        if (_unkMarshal == IntPtr.Zero)
        {
            return; // nothing was acquired, or it has already been released
        }

        Marshal.Release(_unkMarshal);
        _unkMarshal = IntPtr.Zero;
    }

    // IComObject methods
    public void Test()
    {
        var threadInfo = new { Environment.CurrentManagedThreadId };
        Console.WriteLine(threadInfo);
    }

    // ICustomQueryInterface: only IMarshal is intercepted; every other iid is
    // left to the default handling.
    public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
    {
        ppv = IntPtr.Zero;

        if (iid != NativeMethods.IID_IMarshal)
        {
            return CustomQueryInterfaceResult.NotHandled;
        }

        // forward the query to the marshaler; any non-zero HRESULT is reported as a failure
        int hr = Marshal.QueryInterface(_unkMarshal, ref NativeMethods.IID_IMarshal, out ppv);
        return hr == 0
            ? CustomQueryInterfaceResult.Handled
            : CustomQueryInterfaceResult.Failed;
    }

    // Native declarations private to this class.
    static class NativeMethods
    {
        public const UInt32 SMEXF_SERVER = 1;

        // Not readonly because it is passed by ref to Marshal.QueryInterface.
        public static Guid IID_IMarshal = new Guid("00000003-0000-0000-C000-000000000046");

        [DllImport("ole32.dll", PreserveSig = false)]
        public static extern void CoGetStdMarshalEx(
            [MarshalAs(UnmanagedType.IUnknown)] object pUnkOuter,
            UInt32 smexflags,
            out IntPtr ppUnkInner);
    }
}