Skip to content

Commit

Permalink
Feature switch to turn off COM support in Windows (#50662)
Browse files Browse the repository at this point in the history
* Incorporating FB

* non-windows build break fix

* FB

* fb

* test fixes + added new guards to missed ones

* test fixed based on FB
  • Loading branch information
LakshanF committed Apr 27, 2021
1 parent 9a31832 commit 1c9e200
Show file tree
Hide file tree
Showing 12 changed files with 400 additions and 16 deletions.
Expand Up @@ -91,10 +91,16 @@ public struct ComActivationContext
public string AssemblyName;
public string TypeName;

[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
[CLSCompliant(false)]
public static unsafe ComActivationContext Create(ref ComActivationContextInternal cxtInt)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return new ComActivationContext()
{
ClassId = cxtInt.ClassId,
Expand Down Expand Up @@ -122,9 +128,15 @@ public static class ComActivator
/// Entry point for unmanaged COM activation API from managed code
/// </summary>
/// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param>
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
public static object GetClassFactoryForType(ComActivationContext cxt)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (cxt.InterfaceId != typeof(IClassFactory).GUID
&& cxt.InterfaceId != typeof(IClassFactory2).GUID)
{
Expand Down Expand Up @@ -154,9 +166,15 @@ public static object GetClassFactoryForType(ComActivationContext cxt)
/// </summary>
/// <param name="cxt">Reference to a <see cref="ComActivationContext"/> instance</param>
/// <param name="register">true if called for register or false to indicate unregister</param>
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
public static void ClassRegistrationScenarioForType(ComActivationContext cxt, bool register)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

// Retrieve the attribute type to use to determine if a function is the requested user defined
// registration function.
string attributeName = register ? "ComRegisterFunctionAttribute" : "ComUnregisterFunctionAttribute";
Expand Down Expand Up @@ -246,11 +264,17 @@ public static void ClassRegistrationScenarioForType(ComActivationContext cxt, bo
/// Internal entry point for unmanaged COM activation API from native code
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
[CLSCompliant(false)]
[UnmanagedCallersOnly]
public static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInternal* pCxtInt)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand Down Expand Up @@ -287,11 +311,17 @@ public static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInte
/// Internal entry point for registering a managed COM server API from native code
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
[CLSCompliant(false)]
[UnmanagedCallersOnly]
public static unsafe int RegisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand Down Expand Up @@ -331,11 +361,17 @@ public static unsafe int RegisterClassForTypeInternal(ComActivationContextIntern
/// <summary>
/// Internal entry point for unregistering a managed COM server API from native code
/// </summary>
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
[CLSCompliant(false)]
[UnmanagedCallersOnly]
public static unsafe int UnregisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
{
#if FEATURE_COMINTEROP_UNMANAGED_ACTIVATION
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand Down
Expand Up @@ -195,6 +195,9 @@ private static void PrelinkCore(MethodInfo m)
[DllImport(RuntimeHelpers.QCall, CharSet = CharSet.Unicode)]
private static extern void InternalPrelink(RuntimeMethodHandleInternal m);

[DllImport(RuntimeHelpers.QCall)]
private static extern bool IsComSupportedInternal();

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern /* struct _EXCEPTION_POINTERS* */ IntPtr GetExceptionPointers();

Expand Down Expand Up @@ -233,6 +236,10 @@ private static object PtrToStructureHelper(IntPtr ptr, Type structureType)
[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern bool IsPinnable(object? obj);

internal static bool IsComSupported { get; } = InitializeIsComSupported();

private static bool InitializeIsComSupported() => IsComSupportedInternal();

#if TARGET_WINDOWS
/// <summary>
/// Returns the HInstance for this module. Returns -1 if the module doesn't have
Expand Down Expand Up @@ -289,6 +296,11 @@ public static string GetTypeInfoName(ITypeInfo typeInfo)
// on Marshal for more consistent API surface.
internal static Type? GetTypeFromCLSID(Guid clsid, string? server, bool throwOnError)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

// Note: "throwOnError" is a vacuous parameter. Any errors due to the CLSID not being registered or the server not being found will happen
// on the Activator.CreateInstance() call. GetTypeFromCLSID() merely wraps the data in a Type object without any validation.

Expand Down Expand Up @@ -429,12 +441,27 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown)
public static extern object GetTypedObjectForIUnknown(IntPtr /* IUnknown* */ pUnk, Type t);

[SupportedOSPlatform("windows")]
public static IntPtr CreateAggregatedObject(IntPtr pOuter, object o)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return CreateAggregatedObjectNative(pOuter, o);
}

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern IntPtr CreateAggregatedObject(IntPtr pOuter, object o);
private static extern IntPtr CreateAggregatedObjectNative(IntPtr pOuter, object o);

[SupportedOSPlatform("windows")]
public static IntPtr CreateAggregatedObject<T>(IntPtr pOuter, T o) where T : notnull
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return CreateAggregatedObject(pOuter, (object)o);
}

Expand All @@ -457,6 +484,11 @@ public static object GetUniqueObjectForIUnknown(IntPtr unknown)
[SupportedOSPlatform("windows")]
public static int ReleaseComObject(object o)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (o is null)
{
// Match .NET Framework behaviour.
Expand All @@ -480,6 +512,11 @@ public static int ReleaseComObject(object o)
[SupportedOSPlatform("windows")]
public static int FinalReleaseComObject(object o)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (o is null)
{
throw new ArgumentNullException(nameof(o));
Expand All @@ -499,6 +536,11 @@ public static int FinalReleaseComObject(object o)
[SupportedOSPlatform("windows")]
public static object? GetComObjectData(object obj, object key)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (obj is null)
{
throw new ArgumentNullException(nameof(obj));
Expand All @@ -525,6 +567,11 @@ public static int FinalReleaseComObject(object o)
[SupportedOSPlatform("windows")]
public static bool SetComObjectData(object obj, object key, object? data)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (obj is null)
{
throw new ArgumentNullException(nameof(obj));
Expand All @@ -550,6 +597,11 @@ public static bool SetComObjectData(object obj, object key, object? data)
[return: NotNullIfNotNull("o")]
public static object? CreateWrapperOfType(object? o, Type t)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

if (t is null)
{
throw new ArgumentNullException(nameof(t));
Expand Down Expand Up @@ -600,6 +652,11 @@ public static bool SetComObjectData(object obj, object key, object? data)
[SupportedOSPlatform("windows")]
public static TWrapper CreateWrapperOfType<T, TWrapper>(T? o)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return (TWrapper)CreateWrapperOfType(o, typeof(TWrapper))!;
}

Expand All @@ -613,32 +670,77 @@ public static bool SetComObjectData(object obj, object key, object? data)
public static extern bool IsTypeVisibleFromCom(Type t);

[SupportedOSPlatform("windows")]
public static void GetNativeVariantForObject(object? obj, /* VARIANT * */ IntPtr pDstNativeVariant)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

GetNativeVariantForObjectNative(obj, pDstNativeVariant);
}

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern void GetNativeVariantForObject(object? obj, /* VARIANT * */ IntPtr pDstNativeVariant);
private static extern void GetNativeVariantForObjectNative(object? obj, /* VARIANT * */ IntPtr pDstNativeVariant);

[SupportedOSPlatform("windows")]
public static void GetNativeVariantForObject<T>(T? obj, IntPtr pDstNativeVariant)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

GetNativeVariantForObject((object?)obj, pDstNativeVariant);
}

[SupportedOSPlatform("windows")]
public static object? GetObjectForNativeVariant(/* VARIANT * */ IntPtr pSrcNativeVariant)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return GetObjectForNativeVariantNative(pSrcNativeVariant);
}

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern object? GetObjectForNativeVariant(/* VARIANT * */ IntPtr pSrcNativeVariant);
private static extern object? GetObjectForNativeVariantNative(/* VARIANT * */ IntPtr pSrcNativeVariant);

[SupportedOSPlatform("windows")]
public static T? GetObjectForNativeVariant<T>(IntPtr pSrcNativeVariant)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return (T?)GetObjectForNativeVariant(pSrcNativeVariant);
}

[SupportedOSPlatform("windows")]
public static object?[] GetObjectsForNativeVariants(/* VARIANT * */ IntPtr aSrcNativeVariant, int cVars)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return GetObjectsForNativeVariantsNative(aSrcNativeVariant, cVars);
}

[MethodImpl(MethodImplOptions.InternalCall)]
public static extern object?[] GetObjectsForNativeVariants(/* VARIANT * */ IntPtr aSrcNativeVariant, int cVars);
private static extern object?[] GetObjectsForNativeVariantsNative(/* VARIANT * */ IntPtr aSrcNativeVariant, int cVars);

[SupportedOSPlatform("windows")]
public static T[] GetObjectsForNativeVariants<T>(IntPtr aSrcNativeVariant, int cVars)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

object?[] objects = GetObjectsForNativeVariants(aSrcNativeVariant, cVars);

T[] result = new T[objects.Length];
Expand All @@ -665,6 +767,11 @@ public static T[] GetObjectsForNativeVariants<T>(IntPtr aSrcNativeVariant, int c
[SupportedOSPlatform("windows")]
public static object BindToMoniker(string monikerName)
{
if (!IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

CreateBindCtx(0, out IBindCtx bindctx);

MkParseDisplayName(bindctx, monikerName, out _, out IMoniker pmoniker);
Expand Down
Expand Up @@ -3680,6 +3680,11 @@ internal static object CreateEnum(RuntimeType enumType, long value)
Type[] aArgsTypes,
Type retType)
{
if (!Marshal.IsComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

Debug.Assert(
aArgs.Length == aArgsIsByRef.Length
&& aArgs.Length == aArgsTypes.Length
Expand Down
9 changes: 5 additions & 4 deletions src/coreclr/vm/ecalllist.h
Expand Up @@ -758,6 +758,7 @@ FCFuncStart(gInteropMarshalFuncs)
FCFuncElement("OffsetOfHelper", MarshalNative::OffsetOfHelper)

QCFuncElement("InternalPrelink", MarshalNative::Prelink)
QCFuncElement("IsComSupportedInternal", MarshalNative::IsComSupported)
FCFuncElement("GetExceptionForHRInternal", MarshalNative::GetExceptionForHR)
FCFuncElement("GetDelegateForFunctionPointerInternal", MarshalNative::GetDelegateForFunctionPointerInternal)
FCFuncElement("GetFunctionPointerForDelegateInternal", MarshalNative::GetFunctionPointerForDelegateInternal)
Expand All @@ -767,14 +768,14 @@ FCFuncStart(gInteropMarshalFuncs)
FCFuncElement("IsComObject", MarshalNative::IsComObject)
FCFuncElement("GetObjectForIUnknownNative", MarshalNative::GetObjectForIUnknownNative)
FCFuncElement("GetUniqueObjectForIUnknownNative", MarshalNative::GetUniqueObjectForIUnknownNative)
FCFuncElement("GetNativeVariantForObject", MarshalNative::GetNativeVariantForObject)
FCFuncElement("GetObjectForNativeVariant", MarshalNative::GetObjectForNativeVariant)
FCFuncElement("GetNativeVariantForObjectNative", MarshalNative::GetNativeVariantForObjectNative)
FCFuncElement("GetObjectForNativeVariantNative", MarshalNative::GetObjectForNativeVariantNative)
FCFuncElement("InternalFinalReleaseComObject", MarshalNative::FinalReleaseComObject)
FCFuncElement("IsTypeVisibleFromCom", MarshalNative::IsTypeVisibleFromCom)
FCFuncElement("CreateAggregatedObject", MarshalNative::CreateAggregatedObject)
FCFuncElement("CreateAggregatedObjectNative", MarshalNative::CreateAggregatedObjectNative)
FCFuncElement("AreComObjectsAvailableForCleanup", MarshalNative::AreComObjectsAvailableForCleanup)
FCFuncElement("InternalCreateWrapperOfType", MarshalNative::InternalCreateWrapperOfType)
FCFuncElement("GetObjectsForNativeVariants", MarshalNative::GetObjectsForNativeVariants)
FCFuncElement("GetObjectsForNativeVariantsNative", MarshalNative::GetObjectsForNativeVariantsNative)
FCFuncElement("GetStartComSlot", MarshalNative::GetStartComSlot)
FCFuncElement("GetEndComSlot", MarshalNative::GetEndComSlot)
FCFuncElement("GetIUnknownForObjectNative", MarshalNative::GetIUnknownForObjectNative)
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/vm/eeconfig.cpp
Expand Up @@ -170,6 +170,7 @@ HRESULT EEConfig::Init()
#ifdef FEATURE_COMINTEROP
bLogCCWRefCountChange = false;
pszLogCCWRefCountChange = NULL;
m_fBuiltInCOMInteropSupported = true;
#endif // FEATURE_COMINTEROP

#ifdef _DEBUG
Expand Down Expand Up @@ -682,6 +683,7 @@ HRESULT EEConfig::sync()
bLogCCWRefCountChange = true;

fEnableRCWCleanupOnSTAShutdown = (CLRConfig::GetConfigValue(CLRConfig::INTERNAL_EnableRCWCleanupOnSTAShutdown) != 0);
m_fBuiltInCOMInteropSupported = Configuration::GetKnobBooleanValue(W("System.Runtime.InteropServices.Marshal.IsComSupported"), true);
#endif // FEATURE_COMINTEROP

#ifdef _DEBUG
Expand Down

0 comments on commit 1c9e200

Please sign in to comment.