#if !defined(RT_COM_H)
#define RT_COM_H
// com.h
//
// Description:
//  This file provides some helper classes that are useful when
//  dealing with COM and DirectX 8 when using the exception mechanism
//  provided in <rt/hr.h>.
//
// Provides:
//  com_runtime
//      Initialize/Uninitialize the COM runtime in the class'
//      constructor/destructor.
//
//  com_attach_ptr
//      Extends ATL CComPtr<> to provide a constructor that does an
//      Attach() instead of an AddRef().
//
//  com_qi_ptr
//      Extends ATL CComQIPtr<> to throw E_NOINTERFACE when
//      QueryInterface fails or the source pointer is NULL.
//
//  com_ptr
//      Extends ATL CComPtr<> to provide a constructor that takes
//      the CLSID and calls CoCreateInstance, throwing an exception
//      on failure.
//
// Example:
//      bool com_initialized = false;
//      try
//      {
//          THR(CoInitialize(NULL));   // or CoInitializeEx(NULL, flags)
//          com_initialized = true;
//          // ... other things that may throw
//      }
//      catch (...)
//      {
//          if (com_initialized)
//          {
//              CoUninitialize();
//          }
//          throw;
//      }
//
//      becomes
//
//      com_runtime com;    // or com_runtime com(flags)
//      // ... other things that may throw
// -----
//      CComPtr<IDirect3D8> p;
//      p.Attach(Direct3DCreate8(D3D_SDK_VERSION));
//      if (!p)
//      {
//          THR(E_NOINTERFACE);
//      }
//
//      becomes
//
//      rt::com_attach_ptr<IDirect3D8> p(Direct3DCreate8(D3D_SDK_VERSION));
//      if (!p)     // com_attach_ptr does not throw on NULL
//      {
//          THR(E_NOINTERFACE);
//      }
// -----
//
// Copyright (C) 2000-2001, Rich Thomson, all rights reserved.
//

#include <objbase.h>
#include <atlbase.h>
#include <rt/hr.h>

namespace rt {
  ////////////////////////////////////////////////////////////
  // com_runtime
  //
  // Acquire the COM runtime as a constructed resource: the
  // constructor initializes COM on the calling thread and the
  // destructor uninitializes it.  Construct and destroy an
  // instance on the same thread, as COM requires.
  //
  // Initialization failures are reported through THRMT from
  // <rt/hr.h>, which this file relies on to throw on a failed
  // HRESULT.  A constructor that throws leaves no object behind,
  // so ::CoUninitialize() only runs after initialization
  // succeeded.
  //
  // Instances are non-copyable: every copy would call
  // ::CoUninitialize() again in its destructor, unbalancing the
  // thread's COM initialization count.
  //
  class com_runtime
  {
  public:
    // Initialize COM with ::CoInitialize(NULL).
    com_runtime()
    {
      THRMT(::CoInitialize(NULL), "CoInitialize");
    }
#if defined(_WIN32_DCOM)
    // Initialize COM with ::CoInitializeEx(NULL, flags); 'flags'
    // (a COINIT_* combination) is passed through unchanged.
    // NOTE(review): this constructor is not explicit, so a DWORD
    // converts implicitly to com_runtime; consider 'explicit'.
    com_runtime(DWORD flags)
    {
      THRMT(::CoInitializeEx(NULL, flags), "CoInitializeEx");
    }
#endif
    // Balance the initialization performed by the constructor.
    ~com_runtime()
    {
      ::CoUninitialize();
    }

  private:
    // Declared but intentionally not defined (pre-C++11 idiom):
    // copying becomes a compile/link error instead of a second
    // ::CoUninitialize() at runtime.
    com_runtime(const com_runtime &);
    com_runtime &operator=(const com_runtime &);
  };
  
  ////////////////////////////////////////////////////////////
  // com_attach_ptr
  //
  // A CComPtr<> whose constructor adopts an interface pointer
  // with Attach() rather than AddRef().  Use it for pointers that
  // already carry a reference the caller owns, such as the
  // result of Direct3DCreate8().
  //
  // A NULL argument is accepted and simply leaves the pointer
  // empty; nothing is thrown, so test with operator! when NULL
  // is an error for the caller.
  //
  // NOTE(review): the constructor is not explicit, so a raw T*
  // converts implicitly to com_attach_ptr<T> and is adopted
  // without an AddRef(); consider making it explicit.
  //
  template <typename T>
  class com_attach_ptr : public CComPtr<T>
  {
    typedef CComPtr<T> base_type;

  public:
    com_attach_ptr(T *raw) :
      base_type()
    {
      // The base starts out empty, so only a non-NULL pointer
      // needs to be handed over.
      if (raw != NULL)
      {
        base_type::Attach(raw);
      }
    }

    ~com_attach_ptr()
    {
    }
  };

  ////////////////////////////////////////////////////////////
  // com_qi_ptr
  //
  // A CComQIPtr<> whose constructor throws instead of silently
  // leaving an empty pointer.  The base class queries 'unknown'
  // for T's interface.  If the resulting pointer is NULL, the
  // constructor raises E_NOINTERFACE through THR from
  // <rt/hr.h>, which is expected to throw on a failed HRESULT.
  // A NULL argument yields a NULL result and so also throws.
  //
  template <typename T>
  class com_qi_ptr : public CComQIPtr<T>
  {
    typedef CComQIPtr<T> base_type;

  public:
    com_qi_ptr(IUnknown *unknown) :
      base_type(unknown)
    {
      const bool no_interface = base_type::operator!();
      if (no_interface)
      {
        THR(E_NOINTERFACE);
      }
    }

    ~com_qi_ptr()
    {
    }
  };

  ////////////////////////////////////////////////////////////
  // com_ptr
  //
  // Like CComPtr<>, but adds a constructor that creates an
  // object via CComPtr<>::CoCreateInstance() and reports a
  // failed HRESULT through THRM from <rt/hr.h>.
  //
  template <typename T>
  class com_ptr : public CComPtr<T>
  {
  public:
    // Empty pointer.
    com_ptr() : CComPtr<T>() {}

    // Share a raw interface pointer (CComPtr<T>(T*) AddRefs it).
    com_ptr(T *ptr) : CComPtr<T>(ptr) {}

    // Share the interface held by a plain CComPtr<T>.
    com_ptr(const CComPtr<T> &ptr) : CComPtr<T>(ptr) {}

    // Copy constructor.  The cast to the base must keep 'const':
    // static_cast may not cast away constness, so casting to a
    // non-const CComPtr<T>& is ill-formed and breaks compilation
    // on conforming compilers as soon as a com_ptr is copied.
    com_ptr(const com_ptr<T> &ptr) :
      CComPtr<T>(static_cast<const CComPtr<T> &>(ptr))
    {
    }

    // Create an instance of 'clsid' and hold its T interface.
    //   clsid   - class to instantiate.
    //   message - text handed to THRM on failure; when NULL,
    //             _T("CoCreateInstance") is used.
    //   context - CLSCTX_* server context.  The default,
    //             CLSCTX_ALL, is CComPtr::CoCreateInstance's own
    //             default, so existing callers behave as before.
    // The object is created without aggregation (no outer unknown).
    // NOTE(review): not explicit, so a CLSID converts implicitly
    // to com_ptr<T> (creating an object); consider 'explicit'.
    com_ptr(const CLSID &clsid, const TCHAR *message = NULL,
            DWORD context = CLSCTX_ALL) :
      CComPtr<T>()
    {
      THRM(CComPtr<T>::CoCreateInstance(clsid, NULL, context),
           message ? message : _T("CoCreateInstance"));
    }

    ~com_ptr() {}
  };
};

#endif
