OSDN Git Service

Add additional drm hal tests
authorJeff Tinker <jtinker@google.com>
Fri, 31 Mar 2017 22:32:12 +0000 (15:32 -0700)
committerJeff Tinker <jtinker@google.com>
Thu, 6 Apr 2017 06:21:13 +0000 (23:21 -0700)
Test: VTS tests passing

bug:34178477
Change-Id: Ie8fa5f2ad193b717b0564bb3046de8a64ccd8d85

drm/1.0/default/DrmPlugin.cpp
drm/1.0/vts/functional/Android.bp
drm/1.0/vts/functional/drm_hal_clearkey_test.cpp
drm/1.0/vts/functional/drm_hal_vendor_module_api.h
drm/1.0/vts/functional/drm_hal_vendor_test.cpp
drm/1.0/vts/functional/vendor_modules.cpp
drm/1.0/vts/functional/vendor_modules.h

index c7428a5..e9133ac 100644 (file)
@@ -327,19 +327,25 @@ namespace implementation {
 
     Return<void> DrmPlugin::sendEvent(EventType eventType,
             const hidl_vec<uint8_t>& sessionId, const hidl_vec<uint8_t>& data) {
-        mListener->sendEvent(eventType, sessionId, data);
+        if (mListener != nullptr) {
+            mListener->sendEvent(eventType, sessionId, data);
+        }
         return Void();
     }
 
     Return<void> DrmPlugin::sendExpirationUpdate(
             const hidl_vec<uint8_t>& sessionId, int64_t expiryTimeInMS) {
-        mListener->sendExpirationUpdate(sessionId, expiryTimeInMS);
+        if (mListener != nullptr) {
+            mListener->sendExpirationUpdate(sessionId, expiryTimeInMS);
+        }
         return Void();
     }
 
     Return<void> DrmPlugin::sendKeysChange(const hidl_vec<uint8_t>& sessionId,
             const hidl_vec<KeyStatus>& keyStatusList, bool hasNewUsableKey) {
-        mListener->sendKeysChange(sessionId, keyStatusList, hasNewUsableKey);
+        if (mListener != nullptr) {
+            mListener->sendKeysChange(sessionId, keyStatusList, hasNewUsableKey);
+        }
         return Void();
     }
 
index 36d7d1c..43ea372 100644 (file)
@@ -34,6 +34,8 @@ cc_test {
         "libhwbinder",
         "liblog",
         "libnativehelper",
+        "libssl",
+        "libcrypto",
         "libutils",
     ],
     static_libs: [
index 2296d2d..6910855 100644 (file)
@@ -27,6 +27,7 @@
 #include <hidl/HidlSupport.h>
 #include <hidlmemory/mapping.h>
 #include <memory>
+#include <openssl/aes.h>
 #include <random>
 
 #include "VtsHalHidlTargetTestBase.h"
@@ -125,6 +126,39 @@ TEST_F(DrmHalClearkeyFactoryTest, InvalidPluginNotSupported) {
 }
 
 /**
+ * Ensure the factory doesn't support an empty UUID
+ */
+TEST_F(DrmHalClearkeyFactoryTest, EmptyPluginUUIDNotSupported) {
+    hidl_array<uint8_t, 16> emptyUUID;
+    EXPECT_FALSE(drmFactory->isCryptoSchemeSupported(emptyUUID));
+    EXPECT_FALSE(cryptoFactory->isCryptoSchemeSupported(emptyUUID));
+}
+
+/**
+ * Ensure empty content type is not supported
+ */
+TEST_F(DrmHalClearkeyFactoryTest, EmptyContentTypeNotSupported) {
+    hidl_string empty;
+    EXPECT_FALSE(drmFactory->isContentTypeSupported(empty));
+}
+
+/**
+ * Ensure invalid content type is not supported
+ */
+TEST_F(DrmHalClearkeyFactoryTest, InvalidContentTypeNotSupported) {
+    hidl_string invalid("abcdabcd");
+    EXPECT_FALSE(drmFactory->isContentTypeSupported(invalid));
+}
+
+/**
+ * Ensure valid content type is supported
+ */
+TEST_F(DrmHalClearkeyFactoryTest, ValidContentTypeSupported) {
+    hidl_string cencType("cenc");
+    EXPECT_TRUE(drmFactory->isContentTypeSupported(cencType));
+}
+
+/**
  * Ensure clearkey drm plugin can be created
  */
 TEST_F(DrmHalClearkeyFactoryTest, CreateClearKeyDrmPlugin) {
@@ -418,6 +452,26 @@ TEST_F(DrmHalClearkeyPluginTest, ProvideKeyResponseEmptyResponse) {
 }
 
 /**
+ * Test that a removeKeys on an empty sessionID returns BAD_VALUE
+ */
+TEST_F(DrmHalClearkeyPluginTest, RemoveKeysEmptySessionId) {
+    SessionId sessionId;
+    Status status = drmPlugin->removeKeys(sessionId);
+    EXPECT_TRUE(status == Status::BAD_VALUE);
+}
+
+/**
+ * Remove keys is not supported for clearkey.
+ */
+TEST_F(DrmHalClearkeyPluginTest, RemoveKeysNewSession) {
+    SessionId sessionId = openSession();
+    Status status = drmPlugin->removeKeys(sessionId);
+    // Clearkey plugin doesn't support remove keys
+    EXPECT_EQ(Status::ERROR_DRM_CANNOT_HANDLE, status);
+    closeSession(sessionId);
+}
+
+/**
  * Test that the clearkey plugin doesn't support getting
  * secure stops.
  */
@@ -617,7 +671,7 @@ TEST_F(DrmHalClearkeyPluginTest, GenericEncryptNotSupported) {
     ;
     hidl_vec<uint8_t> keyId = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1};
     hidl_vec<uint8_t> input = {1, 2, 3, 4, 5};
-    hidl_vec<uint8_t> iv = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+    hidl_vec<uint8_t> iv = std::vector<uint8_t>(AES_BLOCK_SIZE, 0);
     auto res = drmPlugin->encrypt(session, keyId, input, iv,
                                   [&](Status status, const hidl_vec<uint8_t>&) {
                                       EXPECT_EQ(Status::ERROR_DRM_CANNOT_HANDLE,
@@ -629,10 +683,9 @@ TEST_F(DrmHalClearkeyPluginTest, GenericEncryptNotSupported) {
 
 TEST_F(DrmHalClearkeyPluginTest, GenericDecryptNotSupported) {
     SessionId session = openSession();
-    ;
     hidl_vec<uint8_t> keyId = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1};
     hidl_vec<uint8_t> input = {1, 2, 3, 4, 5};
-    hidl_vec<uint8_t> iv = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+    hidl_vec<uint8_t> iv = std::vector<uint8_t>(AES_BLOCK_SIZE, 0);
     auto res = drmPlugin->decrypt(session, keyId, input, iv,
                                   [&](Status status, const hidl_vec<uint8_t>&) {
                                       EXPECT_EQ(Status::ERROR_DRM_CANNOT_HANDLE,
@@ -763,6 +816,17 @@ TEST_F(DrmHalClearkeyPluginTest, SetMediaDrmSessionClosedSession) {
 }
 
 /**
+ * setMediaDrmSession with an empty session id: BAD_VALUE.  An
+ * empty session clears the previously set session and should
+ * return OK.
+ */
+TEST_F(DrmHalClearkeyPluginTest, SetMediaDrmSessionEmptySession) {
+    SessionId sessionId;
+    Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+    EXPECT_EQ(Status::OK, status);
+}
+
+/**
  * Decrypt tests
  */
 
@@ -771,9 +835,15 @@ class DrmHalClearkeyDecryptTest : public DrmHalClearkeyPluginTest {
     void loadKeys(const SessionId& sessionId);
     void fillRandom(const sp<IMemory>& memory);
     hidl_array<uint8_t, 16> toHidlArray(const vector<uint8_t>& vec) {
-        EXPECT_EQ(vec.size(), 16u);
+        EXPECT_EQ(16u, vec.size());
         return hidl_array<uint8_t, 16>(&vec[0]);
     }
+    uint32_t decrypt(Mode mode, uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+            const Pattern& pattern, Status status);
+    void aes_ctr_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
+            const hidl_vec<SubSample>& subSamples, const vector<uint8_t>& key);
+    void aes_cbc_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
+            const hidl_vec<SubSample>& subSamples, const vector<uint8_t>& key);
 };
 
 /**
@@ -847,36 +917,162 @@ void DrmHalClearkeyDecryptTest::fillRandom(const sp<IMemory>& memory) {
     }
 }
 
-/**
- * Positive decrypt test.  "Decrypt" a single clear
- * segment.  Verify data matches.
- */
-TEST_F(DrmHalClearkeyDecryptTest, ClearSegmentTest) {
-    const size_t kSegmentSize = 1024;
+uint32_t DrmHalClearkeyDecryptTest::decrypt(Mode mode,
+        uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+        const Pattern& pattern, Status expectedStatus) {
     const size_t kSegmentIndex = 0;
     const vector<uint8_t> keyId = {0x60, 0x06, 0x1e, 0x01, 0x7e, 0x47,
                                    0x7e, 0x87, 0x7e, 0x57, 0xd0, 0x0d,
                                    0x1e, 0xd0, 0x0d, 0x1e};
-    uint8_t iv[16] = {0};
+    const vector<uint8_t> contentKey = {0x1a, 0x8a, 0x20, 0x95, 0xe4,
+                                        0xde, 0xb2, 0xd2, 0x9e, 0xc8,
+                                        0x16, 0xac, 0x7b, 0xae, 0x20, 0x82};
+    uint8_t localIv[AES_BLOCK_SIZE];
+    memcpy(localIv, iv, AES_BLOCK_SIZE);
+
+    size_t totalSize = 0;
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        totalSize += subSamples[i].numBytesOfClearData;
+        totalSize += subSamples[i].numBytesOfEncryptedData;
+    }
 
+    // The first totalSize bytes of shared memory is the encrypted
+    // input, the second totalSize bytes is the decrypted output.
     sp<IMemory> sharedMemory =
-            getDecryptMemory(kSegmentSize * 2, kSegmentIndex);
+            getDecryptMemory(totalSize * 2, kSegmentIndex);
 
-    SharedBuffer sourceBuffer = {
-            .bufferId = kSegmentIndex, .offset = 0, .size = kSegmentSize};
+    const SharedBuffer sourceBuffer = {
+        .bufferId = kSegmentIndex, .offset = 0, .size = totalSize};
     fillRandom(sharedMemory);
 
-    DestinationBuffer destBuffer = {.type = BufferType::SHARED_MEMORY,
-                                    {.bufferId = kSegmentIndex,
-                                     .offset = kSegmentSize,
-                                     .size = kSegmentSize},
-                                    .secureMemory = nullptr};
+    const DestinationBuffer destBuffer = {.type = BufferType::SHARED_MEMORY,
+                                          {.bufferId = kSegmentIndex,
+                                           .offset = totalSize,
+                                           .size = totalSize},
+                                          .secureMemory = nullptr};
+    const uint64_t offset = 0;
+    const bool kNotSecure = false;
+    uint32_t bytesWritten = 0;
+    auto res = cryptoPlugin->decrypt(kNotSecure, toHidlArray(keyId), localIv, mode,
+            pattern, subSamples, sourceBuffer, offset, destBuffer,
+            [&](Status status, uint32_t count, string detailedError) {
+                EXPECT_EQ(expectedStatus, status) << "Unexpected decrypt status " <<
+                detailedError;
+                bytesWritten = count;
+            });
+    EXPECT_OK(res);
+
+    if (bytesWritten != totalSize) {
+        return bytesWritten;
+    }
+    uint8_t* base = static_cast<uint8_t*>(
+            static_cast<void*>(sharedMemory->getPointer()));
+
+    // generate reference vector
+    vector<uint8_t> reference(totalSize);
+
+    memcpy(localIv, iv, AES_BLOCK_SIZE);
+    switch (mode) {
+    case Mode::UNENCRYPTED:
+        memcpy(&reference[0], base, totalSize);
+        break;
+    case Mode::AES_CTR:
+        aes_ctr_decrypt(&reference[0], base, localIv, subSamples, contentKey);
+        break;
+    case Mode::AES_CBC:
+        aes_cbc_decrypt(&reference[0], base, localIv, subSamples, contentKey);
+        break;
+    case Mode::AES_CBC_CTS:
+        EXPECT_TRUE(false) << "AES_CBC_CTS mode not supported";
+        break;
+    }
 
-    Pattern noPattern = {0, 0};
-    vector<SubSample> subSamples = {{.numBytesOfClearData = kSegmentSize,
-                                     .numBytesOfEncryptedData = 0}};
-    uint64_t offset = 0;
+    // compare reference to decrypted data which is at base + total size
+    EXPECT_EQ(0, memcmp(static_cast<void *>(&reference[0]),
+                        static_cast<void*>(base + totalSize), totalSize))
+            << "decrypt data mismatch";
+    return totalSize;
+}
 
+/**
+ * Decrypt a list of clear+encrypted subsamples using the specified key
+ * in AES-CTR mode
+ */
+void DrmHalClearkeyDecryptTest::aes_ctr_decrypt(uint8_t* dest, uint8_t* src,
+        uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+        const vector<uint8_t>& key) {
+    AES_KEY decryptionKey;
+    AES_set_encrypt_key(&key[0], 128, &decryptionKey);
+
+    size_t offset = 0;
+    unsigned int blockOffset = 0;
+    uint8_t previousEncryptedCounter[AES_BLOCK_SIZE];
+    memset(previousEncryptedCounter, 0, AES_BLOCK_SIZE);
+
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        const SubSample& subSample = subSamples[i];
+
+        if (subSample.numBytesOfClearData > 0) {
+            memcpy(dest + offset, src + offset, subSample.numBytesOfClearData);
+            offset += subSample.numBytesOfClearData;
+        }
+
+        if (subSample.numBytesOfEncryptedData > 0) {
+            AES_ctr128_encrypt(src + offset, dest + offset,
+                    subSample.numBytesOfEncryptedData, &decryptionKey,
+                    iv, previousEncryptedCounter, &blockOffset);
+            offset += subSample.numBytesOfEncryptedData;
+        }
+    }
+}
+
+/**
+ * Decrypt a list of clear+encrypted subsamples using the specified key
+ * in AES-CBC mode
+ */
+void DrmHalClearkeyDecryptTest::aes_cbc_decrypt(uint8_t* dest, uint8_t* src,
+        uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+        const vector<uint8_t>& key) {
+    AES_KEY decryptionKey;
+    AES_set_encrypt_key(&key[0], 128, &decryptionKey);
+
+    size_t offset = 0;
+    size_t num = 0;
+    size_t ecount_buf = 0;
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        memcpy(dest + offset, src + offset, subSamples[i].numBytesOfClearData);
+        offset += subSamples[i].numBytesOfClearData;
+
+        AES_cbc_encrypt(src + offset, dest + offset, subSamples[i].numBytesOfEncryptedData,
+                &decryptionKey, iv, 0 /* decrypt */);
+        offset += subSamples[i].numBytesOfEncryptedData;
+    }
+}
+
+/**
+ * Test query key status
+ */
+TEST_F(DrmHalClearkeyDecryptTest, TestQueryKeyStatus) {
+    auto sessionId = openSession();
+    auto res = drmPlugin->queryKeyStatus(sessionId,
+            [&](Status status, KeyedVector /* info */) {
+                // clearkey doesn't support this method
+                EXPECT_EQ(Status::ERROR_DRM_CANNOT_HANDLE, status);
+            });
+    EXPECT_OK(res);
+}
+
+
+/**
+ * Positive decrypt test.  "Decrypt" a single clear segment
+ */
+TEST_F(DrmHalClearkeyDecryptTest, ClearSegmentTest) {
+    vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+    const Pattern noPattern = {0, 0};
+    const uint32_t kByteCount = 256;
+    const vector<SubSample> subSamples = {
+        {.numBytesOfClearData = kByteCount,
+         .numBytesOfEncryptedData = 0}};
     auto sessionId = openSession();
     loadKeys(sessionId);
 
@@ -884,21 +1080,57 @@ TEST_F(DrmHalClearkeyDecryptTest, ClearSegmentTest) {
     EXPECT_EQ(Status::OK, status);
 
     const bool kNotSecure = false;
-    auto res = cryptoPlugin->decrypt(
-            kNotSecure, toHidlArray(keyId), iv, Mode::UNENCRYPTED, noPattern,
-            subSamples, sourceBuffer, offset, destBuffer,
-            [&](Status status, uint32_t bytesWritten, string detailedError) {
-                EXPECT_EQ(Status::OK, status) << "Failure in decryption:"
-                                              << detailedError;
-                EXPECT_EQ(bytesWritten, kSegmentSize);
-            });
-    EXPECT_OK(res);
+    uint32_t byteCount = decrypt(Mode::UNENCRYPTED, &iv[0], subSamples,
+            noPattern, Status::OK);
+    EXPECT_EQ(kByteCount, byteCount);
 
-    uint8_t* base = static_cast<uint8_t*>(
-            static_cast<void*>(sharedMemory->getPointer()));
+    closeSession(sessionId);
+}
+
+/**
+ * Positive decrypt test.  Decrypt a single segment using AES_CTR.
+ * Verify data matches.
+ */
+TEST_F(DrmHalClearkeyDecryptTest, EncryptedAesCtrSegmentTest) {
+    vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+    const Pattern noPattern = {0, 0};
+    const uint32_t kClearBytes = 512;
+    const uint32_t kEncryptedBytes = 512;
+    const vector<SubSample> subSamples = {
+        {.numBytesOfClearData = kClearBytes,
+         .numBytesOfEncryptedData = kEncryptedBytes
+        }};
+    auto sessionId = openSession();
+    loadKeys(sessionId);
+
+    Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+    EXPECT_EQ(Status::OK, status);
+
+    const bool kNotSecure = false;
+    uint32_t byteCount = decrypt(Mode::AES_CTR, &iv[0], subSamples,
+            noPattern, Status::OK);
+    EXPECT_EQ(kClearBytes + kEncryptedBytes, byteCount);
+
+    closeSession(sessionId);
+}
+/**
+ * Negative decrypt test. Decrypt without loading keys.
+ */
+TEST_F(DrmHalClearkeyDecryptTest, EncryptedAesCtrSegmentTestNoKeys) {
+    vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+    const Pattern noPattern = {0, 0};
+    const vector<SubSample> subSamples = {
+        {.numBytesOfClearData = 256,
+         .numBytesOfEncryptedData = 256}};
+    auto sessionId = openSession();
+
+    Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+    EXPECT_EQ(Status::OK, status);
+
+    const bool kNotSecure = false;
+    uint32_t byteCount = decrypt(Mode::AES_CTR, &iv[0], subSamples,
+            noPattern, Status::ERROR_DRM_NO_LICENSE);
+    EXPECT_EQ(0u, byteCount);
 
-    EXPECT_EQ(0, memcmp(static_cast<void*>(base),
-                        static_cast<void*>(base + kSegmentSize), kSegmentSize))
-            << "decrypt data mismatch";
     closeSession(sessionId);
 }
index db19719..73e0cfe 100644 (file)
@@ -73,21 +73,21 @@ class DrmHalVTSVendorModule {
      * value with initial version 1. The API version indicates which subclass
      * version DrmHalVTSVendorModule this instance is.
      */
-    virtual uint32_t getAPIVersion() = 0;
+    virtual uint32_t getAPIVersion() const = 0;
 
     /**
      * Return the UUID for the DRM HAL implementation. Protection System
      * Specific
      * UUID (see http://dashif.org/identifiers/protection/)
      */
-    virtual std::vector<uint8_t> getUUID() = 0;
+    virtual std::vector<uint8_t> getUUID() const = 0;
 
     /**
      * Return the service name for the DRM HAL implementation. If the hal is a
      * legacy
      * drm plugin, i.e. not running as a HIDL service, return the empty string.
      */
-    virtual std::string getServiceName() = 0;
+    virtual std::string getServiceName() const = 0;
 
    private:
     DrmHalVTSVendorModule(const DrmHalVTSVendorModule&) = delete;
@@ -103,7 +103,7 @@ class DrmHalVTSVendorModule_V1 : public DrmHalVTSVendorModule {
     DrmHalVTSVendorModule_V1() {}
     virtual ~DrmHalVTSVendorModule_V1() {}
 
-    virtual uint32_t getAPIVersion() { return 1; }
+    virtual uint32_t getAPIVersion() const { return 1; }
 
     /**
      * Handle a provisioning request. This function will be called if the HAL
@@ -178,11 +178,10 @@ class DrmHalVTSVendorModule_V1 : public DrmHalVTSVendorModule {
             const std::vector<uint8_t> keyId;
 
             /**
-             * The key value is provided to generate expected values for
-             * validating
-             * decryption.  If isSecure is false, no key value is required.
+             * The clear content key is provided to generate expected values for
+             * validating decryption.
              */
-            const std::vector<uint8_t> keyValue;
+            const std::vector<uint8_t> clearContentKey;
         };
         std::vector<Key> keys;
     };
@@ -191,7 +190,8 @@ class DrmHalVTSVendorModule_V1 : public DrmHalVTSVendorModule {
      * Return a list of content configurations that can be exercised by the
      * VTS test.
      */
-    virtual std::vector<ContentConfiguration> getContentConfigurations() = 0;
+    virtual std::vector<ContentConfiguration>
+            getContentConfigurations() const = 0;
 
     /**
      * Handle a key request. This function will be called if the HAL
index bd78442..7448c42 100644 (file)
 #include <android/hardware/drm/1.0/ICryptoPlugin.h>
 #include <android/hardware/drm/1.0/IDrmFactory.h>
 #include <android/hardware/drm/1.0/IDrmPlugin.h>
+#include <android/hardware/drm/1.0/IDrmPluginListener.h>
 #include <android/hardware/drm/1.0/types.h>
 #include <android/hidl/allocator/1.0/IAllocator.h>
 #include <gtest/gtest.h>
 #include <hidlmemory/mapping.h>
 #include <memory>
+#include <openssl/aes.h>
 #include <random>
 
-#include "VtsHalHidlTargetTestBase.h"
 #include "drm_hal_vendor_module_api.h"
 #include "vendor_modules.h"
+#include "VtsHalHidlTargetTestBase.h"
 
 using ::android::hardware::drm::V1_0::BufferType;
 using ::android::hardware::drm::V1_0::DestinationBuffer;
+using ::android::hardware::drm::V1_0::EventType;
 using ::android::hardware::drm::V1_0::ICryptoFactory;
 using ::android::hardware::drm::V1_0::ICryptoPlugin;
 using ::android::hardware::drm::V1_0::IDrmFactory;
 using ::android::hardware::drm::V1_0::IDrmPlugin;
+using ::android::hardware::drm::V1_0::IDrmPluginListener;
 using ::android::hardware::drm::V1_0::KeyedVector;
-using ::android::hardware::drm::V1_0::KeyValue;
 using ::android::hardware::drm::V1_0::KeyRequestType;
+using ::android::hardware::drm::V1_0::KeyStatus;
+using ::android::hardware::drm::V1_0::KeyStatusType;
 using ::android::hardware::drm::V1_0::KeyType;
+using ::android::hardware::drm::V1_0::KeyValue;
 using ::android::hardware::drm::V1_0::Mode;
 using ::android::hardware::drm::V1_0::Pattern;
 using ::android::hardware::drm::V1_0::SecureStop;
@@ -56,6 +62,7 @@ using ::android::hardware::hidl_memory;
 using ::android::hardware::hidl_string;
 using ::android::hardware::hidl_vec;
 using ::android::hardware::Return;
+using ::android::hardware::Void;
 using ::android::hidl::allocator::V1_0::IAllocator;
 using ::android::hidl::memory::V1_0::IMemory;
 using ::android::sp;
@@ -67,6 +74,9 @@ using std::map;
 using std::mt19937;
 using std::vector;
 
+using ContentConfiguration = ::DrmHalVTSVendorModule_V1::ContentConfiguration;
+using Key = ::DrmHalVTSVendorModule_V1::ContentConfiguration::Key;
+
 #define ASSERT_OK(ret) ASSERT_TRUE(ret.isOk())
 #define EXPECT_OK(ret) EXPECT_TRUE(ret.isOk())
 
@@ -80,10 +90,9 @@ static drm_vts::VendorModules* gVendorModules = nullptr;
 class DrmHalVendorFactoryTest : public testing::TestWithParam<std::string> {
    public:
     DrmHalVendorFactoryTest()
-        : vendorModule(gVendorModules ? static_cast<DrmHalVTSVendorModule_V1*>(
-                                                gVendorModules->getVendorModule(
-                                                        GetParam()))
-                                      : nullptr) {}
+        : vendorModule(static_cast<DrmHalVTSVendorModule_V1*>(
+                        gVendorModules->getModule(GetParam()))),
+          contentConfigurations(vendorModule->getContentConfigurations()) {}
 
     virtual ~DrmHalVendorFactoryTest() {}
 
@@ -117,14 +126,27 @@ class DrmHalVendorFactoryTest : public testing::TestWithParam<std::string> {
     sp<IDrmFactory> drmFactory;
     sp<ICryptoFactory> cryptoFactory;
     unique_ptr<DrmHalVTSVendorModule_V1> vendorModule;
+    const vector<ContentConfiguration> contentConfigurations;
 };
 
-/**
- * Ensure the factory supports its scheme UUID
- */
-TEST_P(DrmHalVendorFactoryTest, VendorPluginSupported) {
-    EXPECT_TRUE(drmFactory->isCryptoSchemeSupported(getVendorUUID()));
-    EXPECT_TRUE(cryptoFactory->isCryptoSchemeSupported(getVendorUUID()));
+TEST_P(DrmHalVendorFactoryTest, ValidateConfigurations) {
+    const char* kVendorStr = "Vendor module ";
+    for (auto config : contentConfigurations) {
+        ASSERT_TRUE(config.name.size() > 0) << kVendorStr << "has no name";
+        ASSERT_TRUE(config.serverUrl.size() > 0) << kVendorStr
+                                                 << "has no serverUrl";
+        ASSERT_TRUE(config.initData.size() > 0) << kVendorStr
+                                                << "has no init data";
+        ASSERT_TRUE(config.mimeType.size() > 0) << kVendorStr
+                                                << "has no mime type";
+        ASSERT_TRUE(config.keys.size() >= 1) << kVendorStr << "has no keys";
+        for (auto key : config.keys) {
+            ASSERT_TRUE(key.keyId.size() > 0) << kVendorStr
+                                              << " has zero length keyId";
+            ASSERT_TRUE(key.keyId.size() > 0) << kVendorStr
+                                              << " has zero length key value";
+        }
+    }
 }
 
 /**
@@ -136,6 +158,48 @@ TEST_P(DrmHalVendorFactoryTest, InvalidPluginNotSupported) {
 }
 
 /**
+ * Ensure the factory doesn't support an empty UUID
+ */
+TEST_P(DrmHalVendorFactoryTest, EmptyPluginUUIDNotSupported) {
+    hidl_array<uint8_t, 16> emptyUUID;
+    EXPECT_FALSE(drmFactory->isCryptoSchemeSupported(emptyUUID));
+    EXPECT_FALSE(cryptoFactory->isCryptoSchemeSupported(emptyUUID));
+}
+
+/**
+ * Ensure the factory supports the scheme uuid in the config
+ */
+TEST_P(DrmHalVendorFactoryTest, EmptyPluginConfigUUIDSupported) {
+    EXPECT_TRUE(drmFactory->isCryptoSchemeSupported(getVendorUUID()));
+    EXPECT_TRUE(cryptoFactory->isCryptoSchemeSupported(getVendorUUID()));
+}
+
+/**
+ * Ensure empty content type is not supported
+ */
+TEST_P(DrmHalVendorFactoryTest, EmptyContentTypeNotSupported) {
+    hidl_string empty;
+    EXPECT_FALSE(drmFactory->isContentTypeSupported(empty));
+}
+
+/**
+ * Ensure invalid content type is not supported
+ */
+TEST_P(DrmHalVendorFactoryTest, InvalidContentTypeNotSupported) {
+    hidl_string invalid("abcdabcd");
+    EXPECT_FALSE(drmFactory->isContentTypeSupported(invalid));
+}
+
+/**
+ * Ensure valid content types in the configs are supported
+ */
+TEST_P(DrmHalVendorFactoryTest, ValidContentTypeSupported) {
+    for (auto config : contentConfigurations) {
+        EXPECT_TRUE(drmFactory->isContentTypeSupported(config.mimeType));
+    }
+}
+
+/**
  * Ensure vendor drm plugin can be created
  */
 TEST_P(DrmHalVendorFactoryTest, CreateVendorDrmPlugin) {
@@ -393,6 +457,26 @@ TEST_P(DrmHalVendorPluginTest, ProvideKeyResponseEmptyResponse) {
 }
 
 /**
+ * Test that a removeKeys on an empty sessionID returns BAD_VALUE
+ */
+TEST_P(DrmHalVendorPluginTest, RemoveKeysEmptySessionId) {
+    SessionId sessionId;
+    Status status = drmPlugin->removeKeys(sessionId);
+    EXPECT_TRUE(status == Status::BAD_VALUE);
+}
+
+/**
+ * Test that remove keys returns okay on an initialized session
+ * that has no keys.
+ */
+TEST_P(DrmHalVendorPluginTest, RemoveKeysNewSession) {
+    SessionId sessionId = openSession();
+    Status status = drmPlugin->removeKeys(sessionId);
+    EXPECT_TRUE(status == Status::OK);
+    closeSession(sessionId);
+}
+
+/**
  * Test that the plugin either doesn't support getting
  * secure stops, or has no secure stops available after
  * clearing them.
@@ -722,6 +806,175 @@ TEST_P(DrmHalVendorPluginTest, RequiresSecureDecoderInvalidMimeType) {
 }
 
 /**
+ * Verify that requiresSecureDecoderComponent returns true for secure
+ * configurations
+ */
+TEST_P(DrmHalVendorPluginTest, RequiresSecureDecoderConfig) {
+    const char* kVendorStr = "Vendor module ";
+    for (auto config : contentConfigurations) {
+        for (auto key : config.keys) {
+            if (key.isSecure) {
+                EXPECT_TRUE(cryptoPlugin->requiresSecureDecoderComponent(config.mimeType));
+                break;
+            }
+        }
+    }
+}
+
+/**
+ *  Event Handling tests
+ */
+
+class TestDrmPluginListener : public IDrmPluginListener {
+public:
+    TestDrmPluginListener() {reset();}
+    virtual ~TestDrmPluginListener() {}
+
+    virtual Return<void> sendEvent(EventType eventType, const hidl_vec<uint8_t>& sessionId,
+            const hidl_vec<uint8_t>& data) override {
+        eventType_ = eventType;
+        sessionId_ = sessionId;
+        data_ = data;
+        gotEvent_ = true;
+        return Void();
+    }
+
+    virtual Return<void> sendExpirationUpdate(const hidl_vec<uint8_t>& sessionId,
+            int64_t expiryTimeInMS) override {
+        sessionId_ = sessionId;
+        expiryTimeInMS_ = expiryTimeInMS;
+        gotExpirationUpdate_ = true;
+        return Void();
+    }
+
+    virtual Return<void> sendKeysChange(const hidl_vec<uint8_t>& sessionId,
+            const hidl_vec<KeyStatus>& keyStatusList, bool hasNewUsableKey) override {
+        sessionId_ = sessionId;
+        keyStatusList_ = keyStatusList;
+        hasNewUsableKey_ = hasNewUsableKey;
+        gotKeysChange_ = true;
+        return Void();
+    }
+
+    EventType getEventType() const {return eventType_;}
+    SessionId getSessionId() const {return sessionId_;}
+    vector<uint8_t> getData() const {return data_;}
+    int64_t getExpiryTimeInMS() const {return expiryTimeInMS_;}
+    hidl_vec<KeyStatus> getKeyStatusList() const {return keyStatusList_;}
+    bool hasNewUsableKey() {return hasNewUsableKey_;}
+    bool gotEvent() {return gotEvent_;}
+    bool gotExpirationUpdate() {return gotExpirationUpdate_;}
+    bool gotKeysChange() {return gotKeysChange_;}
+
+    void reset() {
+        gotEvent_ = gotExpirationUpdate_ = gotKeysChange_ = false;
+        eventType_ = EventType::PROVISION_REQUIRED;
+        sessionId_ = SessionId();
+        data_ = hidl_vec<uint8_t>();
+        expiryTimeInMS_ = 0;
+        keyStatusList_ = hidl_vec<KeyStatus>();
+        hasNewUsableKey_ = false;
+    }
+
+private:
+    bool gotEvent_;
+    bool gotExpirationUpdate_;
+    bool gotKeysChange_;
+
+    EventType eventType_;
+    SessionId sessionId_;
+    hidl_vec<uint8_t> data_;
+    int64_t expiryTimeInMS_;
+    hidl_vec<KeyStatus> keyStatusList_;
+    bool hasNewUsableKey_;
+};
+
+/**
+ * Simulate the plugin sending events. Make sure the listener
+ * gets them.
+ */
+TEST_P(DrmHalVendorPluginTest, ListenerEvents) {
+    sp<TestDrmPluginListener> listener = new TestDrmPluginListener();
+    drmPlugin->setListener(listener);
+    auto sessionId = openSession();
+    vector<uint8_t> data = {0, 1, 2};
+    EventType eventTypes[] = {EventType::PROVISION_REQUIRED,
+                              EventType::KEY_NEEDED,
+                              EventType::KEY_EXPIRED,
+                              EventType::VENDOR_DEFINED,
+                              EventType::SESSION_RECLAIMED};
+    for (auto eventType : eventTypes) {
+        listener->reset();
+        drmPlugin->sendEvent(eventType, sessionId, data);
+        while (!listener->gotEvent()) {usleep(100);}
+        EXPECT_EQ(eventType, listener->getEventType());
+        EXPECT_EQ(sessionId, listener->getSessionId());
+        EXPECT_EQ(data, listener->getData());
+    }
+    closeSession(sessionId);
+}
+
+/**
+ * Simulate the plugin sending expiration updates and make sure
+ * the listener gets them.
+ */
+TEST_P(DrmHalVendorPluginTest, ListenerExpirationUpdate) {
+    sp<TestDrmPluginListener> listener = new TestDrmPluginListener();
+    drmPlugin->setListener(listener);
+    auto sessionId = openSession();
+    drmPlugin->sendExpirationUpdate(sessionId, 100);
+    while (!listener->gotExpirationUpdate()) {usleep(100);}
+    EXPECT_EQ(sessionId, listener->getSessionId());
+    EXPECT_EQ(100, listener->getExpiryTimeInMS());
+    closeSession(sessionId);
+}
+
+/**
+ * Simulate the plugin sending keys change and make sure
+ * the listener gets them.
+ */
+TEST_P(DrmHalVendorPluginTest, ListenerKeysChange) {
+    sp<TestDrmPluginListener> listener = new TestDrmPluginListener();
+    drmPlugin->setListener(listener);
+    auto sessionId = openSession();
+    const hidl_vec<KeyStatus> keyStatusList = {
+        {{1}, KeyStatusType::USABLE},
+        {{2}, KeyStatusType::EXPIRED},
+        {{3}, KeyStatusType::OUTPUTNOTALLOWED},
+        {{4}, KeyStatusType::STATUSPENDING},
+        {{5}, KeyStatusType::INTERNALERROR},
+    };
+
+    drmPlugin->sendKeysChange(sessionId, keyStatusList, true);
+    while (!listener->gotKeysChange()) {usleep(100);}
+    EXPECT_EQ(sessionId, listener->getSessionId());
+    EXPECT_EQ(keyStatusList, listener->getKeyStatusList());
+    EXPECT_EQ(true, listener->hasNewUsableKey());
+}
+
+/**
+ * Negative listener tests. Call send methods with no
+ * listener set.
+ */
+TEST_P(DrmHalVendorPluginTest, NotListening) {
+    sp<TestDrmPluginListener> listener = new TestDrmPluginListener();
+    drmPlugin->setListener(listener);
+    drmPlugin->setListener(nullptr);
+
+    SessionId sessionId;
+    vector<uint8_t> data;
+    hidl_vec<KeyStatus> keyStatusList;
+    drmPlugin->sendEvent(EventType::PROVISION_REQUIRED, sessionId, data);
+    drmPlugin->sendExpirationUpdate(sessionId, 100);
+    drmPlugin->sendKeysChange(sessionId, keyStatusList, true);
+    usleep(1000); // can't wait for the event to be recieved, just wait a long time
+    EXPECT_EQ(false, listener->gotEvent());
+    EXPECT_EQ(false, listener->gotExpirationUpdate());
+    EXPECT_EQ(false, listener->gotKeysChange());
+}
+
+
+/**
  *  CryptoPlugin tests
  */
 
@@ -786,6 +1039,15 @@ TEST_P(DrmHalVendorPluginTest, SetMediaDrmSessionClosedSession) {
 }
 
 /**
+ * setMediaDrmSession with a empty session id: BAD_VALUE
+ */
+TEST_P(DrmHalVendorPluginTest, SetMediaDrmSessionEmptySession) {
+    SessionId sessionId;
+    Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+    EXPECT_EQ(Status::BAD_VALUE, status);
+}
+
+/**
  * Decrypt tests
  */
 
@@ -796,14 +1058,23 @@ class DrmHalVendorDecryptTest : public DrmHalVendorPluginTest {
 
    protected:
     void loadKeys(const SessionId& sessionId,
-                  const DrmHalVTSVendorModule_V1::ContentConfiguration&
-                          configuration);
+                  const ContentConfiguration& configuration);
     void fillRandom(const sp<IMemory>& memory);
     KeyedVector toHidlKeyedVector(const map<string, string>& params);
     hidl_array<uint8_t, 16> toHidlArray(const vector<uint8_t>& vec) {
         EXPECT_EQ(vec.size(), 16u);
         return hidl_array<uint8_t, 16>(&vec[0]);
     }
+    hidl_vec<KeyValue> queryKeyStatus(SessionId sessionId);
+    void removeKeys(SessionId sessionId);
+    uint32_t decrypt(Mode mode, bool isSecure,
+            const hidl_array<uint8_t, 16>& keyId, uint8_t* iv,
+            const hidl_vec<SubSample>& subSamples, const Pattern& pattern,
+            const vector<uint8_t>& key, Status expectedStatus);
+    void aes_ctr_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
+            const hidl_vec<SubSample>& subSamples, const vector<uint8_t>& key);
+    void aes_cbc_decrypt(uint8_t* dest, uint8_t* src, uint8_t* iv,
+            const hidl_vec<SubSample>& subSamples, const vector<uint8_t>& key);
 };
 
 KeyedVector DrmHalVendorDecryptTest::toHidlKeyedVector(
@@ -823,9 +1094,8 @@ KeyedVector DrmHalVendorDecryptTest::toHidlKeyedVector(
  * These tests use predetermined key request/response to
  * avoid requiring a round trip to a license server.
  */
-void DrmHalVendorDecryptTest::loadKeys(
-        const SessionId& sessionId,
-        const DrmHalVTSVendorModule_V1::ContentConfiguration& configuration) {
+void DrmHalVendorDecryptTest::loadKeys(const SessionId& sessionId,
+        const ContentConfiguration& configuration) {
     hidl_vec<uint8_t> keyRequest;
     auto res = drmPlugin->getKeyRequest(
             sessionId, configuration.initData, configuration.mimeType,
@@ -874,111 +1144,326 @@ void DrmHalVendorDecryptTest::fillRandom(const sp<IMemory>& memory) {
     }
 }
 
-TEST_P(DrmHalVendorDecryptTest, ValidateConfigurations) {
-    vector<DrmHalVTSVendorModule_V1::ContentConfiguration> configurations =
-            vendorModule->getContentConfigurations();
-    const char* kVendorStr = "Vendor module ";
-    for (auto config : configurations) {
-        ASSERT_TRUE(config.name.size() > 0) << kVendorStr << "has no name";
-        ASSERT_TRUE(config.serverUrl.size() > 0) << kVendorStr
-                                                 << "has no serverUrl";
-        ASSERT_TRUE(config.initData.size() > 0) << kVendorStr
-                                                << "has no init data";
-        ASSERT_TRUE(config.mimeType.size() > 0) << kVendorStr
-                                                << "has no mime type";
-        ASSERT_TRUE(config.keys.size() >= 1) << kVendorStr << "has no keys";
-        for (auto key : config.keys) {
-            ASSERT_TRUE(key.keyId.size() > 0) << kVendorStr
-                                              << " has zero length keyId";
-            ASSERT_TRUE(key.keyId.size() > 0) << kVendorStr
-                                              << " has zero length key value";
+hidl_vec<KeyValue> DrmHalVendorDecryptTest::queryKeyStatus(SessionId sessionId) {
+    hidl_vec<KeyValue> keyStatus;
+    auto res = drmPlugin->queryKeyStatus(sessionId,
+            [&](Status status, KeyedVector info) {
+                EXPECT_EQ(Status::OK, status);
+                keyStatus = info;
+            });
+    EXPECT_OK(res);
+    return keyStatus;
+}
+
+void DrmHalVendorDecryptTest::removeKeys(SessionId sessionId) {
+    auto res = drmPlugin->removeKeys(sessionId);
+    EXPECT_OK(res);
+}
+
+uint32_t DrmHalVendorDecryptTest::decrypt(Mode mode, bool isSecure,
+        const hidl_array<uint8_t, 16>& keyId, uint8_t* iv,
+        const hidl_vec<SubSample>& subSamples, const Pattern& pattern,
+        const vector<uint8_t>& key, Status expectedStatus) {
+    const size_t kSegmentIndex = 0;
+
+    uint8_t localIv[AES_BLOCK_SIZE];
+    memcpy(localIv, iv, AES_BLOCK_SIZE);
+
+    size_t totalSize = 0;
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        totalSize += subSamples[i].numBytesOfClearData;
+        totalSize += subSamples[i].numBytesOfEncryptedData;
+    }
+
+    // The first totalSize bytes of shared memory is the encrypted
+    // input, the second totalSize bytes is the decrypted output.
+    sp<IMemory> sharedMemory =
+            getDecryptMemory(totalSize * 2, kSegmentIndex);
+
+    SharedBuffer sourceBuffer = {
+            .bufferId = kSegmentIndex, .offset = 0, .size = totalSize};
+    fillRandom(sharedMemory);
+
+    DestinationBuffer destBuffer = {.type = BufferType::SHARED_MEMORY,
+                                    {.bufferId = kSegmentIndex,
+                                     .offset = totalSize,
+                                     .size = totalSize},
+                                    .secureMemory = nullptr};
+    uint64_t offset = 0;
+    uint32_t bytesWritten = 0;
+    auto res = cryptoPlugin->decrypt(isSecure, keyId, localIv, mode, pattern,
+            subSamples, sourceBuffer, offset, destBuffer,
+            [&](Status status, uint32_t count, string detailedError) {
+                EXPECT_EQ(expectedStatus, status) << "Unexpected decrypt status " <<
+                detailedError;
+                bytesWritten = count;
+            });
+    EXPECT_OK(res);
+
+    if (bytesWritten != totalSize) {
+        return bytesWritten;
+    }
+    uint8_t* base = static_cast<uint8_t*>(
+            static_cast<void*>(sharedMemory->getPointer()));
+
+    // generate reference vector
+    vector<uint8_t> reference(totalSize);
+
+    memcpy(localIv, iv, AES_BLOCK_SIZE);
+    switch (mode) {
+    case Mode::UNENCRYPTED:
+        memcpy(&reference[0], base, totalSize);
+        break;
+    case Mode::AES_CTR:
+        aes_ctr_decrypt(&reference[0], base, localIv, subSamples, key);
+        break;
+    case Mode::AES_CBC:
+        aes_cbc_decrypt(&reference[0], base, localIv, subSamples, key);
+        break;
+    case Mode::AES_CBC_CTS:
+        EXPECT_TRUE(false) << "AES_CBC_CTS mode not supported";
+        break;
+    }
+
+    // compare reference to decrypted data which is at base + total size
+    EXPECT_EQ(0, memcmp(static_cast<void*>(&reference[0]),
+                        static_cast<void*>(base + totalSize), totalSize))
+            << "decrypt data mismatch";
+    return totalSize;
+}
+
+/**
+ * Decrypt a list of clear+encrypted subsamples using the specified key
+ * in AES-CTR mode
+ */
+void DrmHalVendorDecryptTest::aes_ctr_decrypt(uint8_t* dest, uint8_t* src,
+        uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+        const vector<uint8_t>& key) {
+
+    AES_KEY decryptionKey;
+    AES_set_encrypt_key(&key[0], 128, &decryptionKey);
+
+    size_t offset = 0;
+    unsigned blockOffset = 0;
+    uint8_t previousEncryptedCounter[AES_BLOCK_SIZE];
+    memset(previousEncryptedCounter, 0, AES_BLOCK_SIZE);
+
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        const SubSample& subSample = subSamples[i];
+
+        if (subSample.numBytesOfClearData > 0) {
+            memcpy(dest + offset, src + offset, subSample.numBytesOfClearData);
+            offset += subSample.numBytesOfClearData;
+        }
+
+        if (subSample.numBytesOfEncryptedData > 0) {
+            AES_ctr128_encrypt(src + offset, dest + offset,
+                    subSample.numBytesOfEncryptedData, &decryptionKey,
+                    iv, previousEncryptedCounter, &blockOffset);
+            offset += subSample.numBytesOfEncryptedData;
         }
     }
 }
 
 /**
- * Positive decrypt test.  "Decrypt" a single clear
- * segment.  Verify data matches.
+ * Decrypt a list of clear+encrypted subsamples using the specified key
+ * in AES-CBC mode
  */
-TEST_P(DrmHalVendorDecryptTest, ClearSegmentTest) {
-    vector<DrmHalVTSVendorModule_V1::ContentConfiguration> configurations =
-            vendorModule->getContentConfigurations();
-    for (auto config : configurations) {
-        const size_t kSegmentSize = 1024;
-        const size_t kSegmentIndex = 0;
-        uint8_t iv[16] = {0};
-
-        sp<IMemory> sharedMemory =
-                getDecryptMemory(kSegmentSize * 2, kSegmentIndex);
-
-        SharedBuffer sourceBuffer = {
-                .bufferId = kSegmentIndex, .offset = 0, .size = kSegmentSize};
-        fillRandom(sharedMemory);
-
-        DestinationBuffer destBuffer = {.type = BufferType::SHARED_MEMORY,
-                                        {.bufferId = kSegmentIndex,
-                                         .offset = kSegmentSize,
-                                         .size = kSegmentSize},
-                                        .secureMemory = nullptr};
-
-        Pattern noPattern = {0, 0};
-        vector<SubSample> subSamples = {{.numBytesOfClearData = kSegmentSize,
-                                         .numBytesOfEncryptedData = 0}};
-        uint64_t offset = 0;
+void DrmHalVendorDecryptTest::aes_cbc_decrypt(uint8_t* dest, uint8_t* src,
+        uint8_t* iv, const hidl_vec<SubSample>& subSamples,
+        const vector<uint8_t>& key) {
+    AES_KEY decryptionKey;
+    AES_set_encrypt_key(&key[0], 128, &decryptionKey);
+
+    size_t offset = 0;
+    size_t num = 0;
+    size_t ecount_buf = 0;
+    for (size_t i = 0; i < subSamples.size(); i++) {
+        const SubSample& subSample = subSamples[i];
+
+        memcpy(dest + offset, src + offset, subSample.numBytesOfClearData);
+        offset += subSample.numBytesOfClearData;
+
+        AES_cbc_encrypt(src + offset, dest + offset, subSample.numBytesOfEncryptedData,
+                &decryptionKey, iv, 0 /* decrypt */);
+        offset += subSample.numBytesOfEncryptedData;
+    }
+}
 
+
+/**
+ * Test key status with empty session id, should return BAD_VALUE
+ */
+TEST_P(DrmHalVendorDecryptTest, QueryKeyStatusInvalidSession) {
+    SessionId sessionId;
+    auto res = drmPlugin->queryKeyStatus(sessionId,
+            [&](Status status, KeyedVector /* info */) {
+                EXPECT_EQ(Status::BAD_VALUE, status);
+            });
+    EXPECT_OK(res);
+}
+
+
+/**
+ * Test key status.  There should be no key status prior to loading keys
+ */
+TEST_P(DrmHalVendorDecryptTest, QueryKeyStatusWithNoKeys) {
+    auto sessionId = openSession();
+    auto keyStatus = queryKeyStatus(sessionId);
+    EXPECT_EQ(0u, keyStatus.size());
+    closeSession(sessionId);
+}
+
+
+/**
+ * Test key status.  There should be key status after loading keys.
+ */
+TEST_P(DrmHalVendorDecryptTest, QueryKeyStatus) {
+    for (auto config : contentConfigurations) {
         auto sessionId = openSession();
         loadKeys(sessionId, config);
+        auto keyStatus = queryKeyStatus(sessionId);
+        EXPECT_NE(0u, keyStatus.size());
+        closeSession(sessionId);
+    }
+}
 
-        Status status = cryptoPlugin->setMediaDrmSession(sessionId);
-        EXPECT_EQ(Status::OK, status);
+/**
+ * Positive decrypt test. "Decrypt" a single clear segment and verify.
+ */
+TEST_P(DrmHalVendorDecryptTest, ClearSegmentTest) {
+    for (auto config : contentConfigurations) {
+        for (auto key : config.keys) {
+            const size_t kSegmentSize = 1024;
+            vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+            const Pattern noPattern = {0, 0};
+            const vector<SubSample> subSamples = {{.numBytesOfClearData = kSegmentSize,
+                                                   .numBytesOfEncryptedData = 0}};
+            auto sessionId = openSession();
+            loadKeys(sessionId, config);
+
+            Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+            EXPECT_EQ(Status::OK, status);
+
+            uint32_t byteCount = decrypt(Mode::UNENCRYPTED, key.isSecure, toHidlArray(key.keyId),
+                    &iv[0], subSamples, noPattern, key.clearContentKey, Status::OK);
+            EXPECT_EQ(kSegmentSize, byteCount);
+
+            closeSession(sessionId);
+        }
+    }
+}
 
-        const bool kNotSecure = false;
-        auto res = cryptoPlugin->decrypt(
-                kNotSecure, toHidlArray(config.keys[0].keyId), iv,
-                Mode::UNENCRYPTED, noPattern, subSamples, sourceBuffer, offset,
-                destBuffer, [&](Status status, uint32_t bytesWritten,
-                                string detailedError) {
-                    EXPECT_EQ(Status::OK, status) << "Failure in decryption "
-                                                     "for configuration "
-                                                  << config.name << ": "
-                                                  << detailedError;
-                    EXPECT_EQ(bytesWritten, kSegmentSize);
-                });
-        EXPECT_OK(res);
-        uint8_t* base = static_cast<uint8_t*>(
-                static_cast<void*>(sharedMemory->getPointer()));
+/**
+ * Positive decrypt test.  Decrypt a single segment using aes_ctr.
+ * Verify data matches.
+ */
+TEST_P(DrmHalVendorDecryptTest, EncryptedAesCtrSegmentTest) {
+    for (auto config : contentConfigurations) {
+        for (auto key : config.keys) {
+            const size_t kSegmentSize = 1024;
+            vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+            const Pattern noPattern = {0, 0};
+            const vector<SubSample> subSamples = {{.numBytesOfClearData = kSegmentSize,
+                                                   .numBytesOfEncryptedData = 0}};
+            auto sessionId = openSession();
+            loadKeys(sessionId, config);
+
+            Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+            EXPECT_EQ(Status::OK, status);
+
+            uint32_t byteCount = decrypt(Mode::AES_CTR, key.isSecure, toHidlArray(key.keyId),
+                    &iv[0], subSamples, noPattern, key.clearContentKey, Status::OK);
+            EXPECT_EQ(kSegmentSize, byteCount);
+
+            closeSession(sessionId);
+        }
+    }
+}
 
-        EXPECT_EQ(0,
-                  memcmp(static_cast<void*>(base),
-                         static_cast<void*>(base + kSegmentSize), kSegmentSize))
-                << "decrypt data mismatch";
-        closeSession(sessionId);
+/**
+ * Negative decrypt test. Decrypt without loading keys.
+ */
+TEST_P(DrmHalVendorDecryptTest, EncryptedAesCtrSegmentTestNoKeys) {
+    for (auto config : contentConfigurations) {
+        for (auto key : config.keys) {
+            vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+            const Pattern noPattern = {0, 0};
+            const vector<SubSample> subSamples = {{.numBytesOfClearData = 256,
+                                                   .numBytesOfEncryptedData = 256}};
+            auto sessionId = openSession();
+
+            Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+            EXPECT_EQ(Status::OK, status);
+
+            uint32_t byteCount = decrypt(Mode::AES_CTR, key.isSecure,
+                    toHidlArray(key.keyId), &iv[0], subSamples, noPattern,
+                    key.clearContentKey, Status::ERROR_DRM_NO_LICENSE);
+            EXPECT_EQ(0u, byteCount);
+
+            closeSession(sessionId);
+        }
     }
 }
 
 /**
+ * Test key removal.  Load keys then remove them and verify that
+ * decryption can't be performed.
+ */
+TEST_P(DrmHalVendorDecryptTest, AttemptDecryptWithKeysRemoved) {
+    for (auto config : contentConfigurations) {
+        for (auto key : config.keys) {
+            vector<uint8_t> iv(AES_BLOCK_SIZE, 0);
+            const Pattern noPattern = {0, 0};
+            const vector<SubSample> subSamples = {{.numBytesOfClearData = 256,
+                                                   .numBytesOfEncryptedData = 256}};
+            auto sessionId = openSession();
+
+            Status status = cryptoPlugin->setMediaDrmSession(sessionId);
+            EXPECT_EQ(Status::OK, status);
+
+            loadKeys(sessionId, config);
+            removeKeys(sessionId);
+
+            uint32_t byteCount = decrypt(Mode::AES_CTR, key.isSecure,
+                    toHidlArray(key.keyId), &iv[0], subSamples, noPattern,
+                    key.clearContentKey, Status::ERROR_DRM_DECRYPT);
+            EXPECT_EQ(0u, byteCount);
+
+            closeSession(sessionId);
+        }
+    }
+}
+
+
+/**
  * Instantiate the set of test cases for each vendor module
  */
 
 INSTANTIATE_TEST_CASE_P(
         DrmHalVendorFactoryTestCases, DrmHalVendorFactoryTest,
-        testing::ValuesIn(gVendorModules->getVendorModulePaths()));
+        testing::ValuesIn(gVendorModules->getPathList()));
 
 INSTANTIATE_TEST_CASE_P(
         DrmHalVendorPluginTestCases, DrmHalVendorPluginTest,
-        testing::ValuesIn(gVendorModules->getVendorModulePaths()));
+        testing::ValuesIn(gVendorModules->getPathList()));
 
 INSTANTIATE_TEST_CASE_P(
         DrmHalVendorDecryptTestCases, DrmHalVendorDecryptTest,
-        testing::ValuesIn(gVendorModules->getVendorModulePaths()));
+        testing::ValuesIn(gVendorModules->getPathList()));
 
 int main(int argc, char** argv) {
 #if defined(__LP64__)
-    const char *kModulePath = "/data/local/tmp/64/lib";
+    const charkModulePath = "/data/local/tmp/64/lib";
 #else
-    const char *kModulePath = "/data/local/tmp/32/lib";
+    const charkModulePath = "/data/local/tmp/32/lib";
 #endif
     gVendorModules = new drm_vts::VendorModules(kModulePath);
+    if (gVendorModules->getPathList().size() == 0) {
+        std::cerr << "No vendor modules found in " << kModulePath <<
+                ", exiting" << std::endl;
+        exit(-1);
+    }
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();
 }
index 34af6f8..bb232ae 100644 (file)
@@ -29,44 +29,37 @@ using std::vector;
 using std::unique_ptr;
 
 namespace drm_vts {
-vector<string> VendorModules::getVendorModulePaths() {
-    if (mModuleList.size() > 0) {
-        return mModuleList;
-    }
-
-    DIR* dir = opendir(mModulesPath.c_str());
+void VendorModules::scanModules(const std::string &directory) {
+    DIR* dir = opendir(directory.c_str());
     if (dir == NULL) {
-        ALOGE("Unable to open drm VTS vendor directory %s",
-              mModulesPath.c_str());
-        return mModuleList;
-    }
-
-    struct dirent* entry;
-    while ((entry = readdir(dir))) {
-        string fullpath = mModulesPath + "/" + entry->d_name;
-        if (endsWith(fullpath, ".so")) {
-            mModuleList.push_back(fullpath);
+        ALOGE("Unable to open drm VTS vendor directory %s", directory.c_str());
+    } else {
+        struct dirent* entry;
+        while ((entry = readdir(dir))) {
+            ALOGD("checking file %s", entry->d_name);
+            string fullpath = directory + "/" + entry->d_name;
+            if (endsWith(fullpath, ".so")) {
+                mPathList.push_back(fullpath);
+            }
         }
+        closedir(dir);
     }
-
-    closedir(dir);
-    return mModuleList;
 }
 
-DrmHalVTSVendorModule* VendorModules::getVendorModule(const string& path) {
-    unique_ptr<SharedLibrary>& library = mOpenLibraries[path];
-    if (!library) {
-        library = unique_ptr<SharedLibrary>(new SharedLibrary(path));
+DrmHalVTSVendorModule* VendorModules::getModule(const string& path) {
+    if (mOpenLibraries.find(path) == mOpenLibraries.end()) {
+        auto library = std::make_unique<SharedLibrary>(path);
         if (!library) {
             ALOGE("failed to map shared library %s", path.c_str());
             return NULL;
         }
+        mOpenLibraries[path] = std::move(library);
     }
+    const unique_ptr<SharedLibrary>& library = mOpenLibraries[path];
     void* symbol = library->lookup("vendorModuleFactory");
     if (symbol == NULL) {
         ALOGE("getVendorModule failed to lookup 'vendorModuleFactory' in %s: "
-              "%s",
-              path.c_str(), library->lastError());
+              "%s", path.c_str(), library->lastError());
         return NULL;
     }
     typedef DrmHalVTSVendorModule* (*ModuleFactory)();
index 5371a0d..ca538f6 100644 (file)
@@ -30,27 +30,33 @@ class VendorModules {
      * Initialize with a file system path where the shared libraries
      * are to be found.
      */
-    explicit VendorModules(const std::string& path) : mModulesPath(path) {}
+    explicit VendorModules(const std::string& dir) {
+        scanModules(dir);
+    }
     ~VendorModules() {}
 
     /**
-     * Return a list of paths to available vendor modules.
-     */
-    std::vector<std::string> getVendorModulePaths();
-
-    /**
      * Retrieve a DrmHalVTSVendorModule given its full path.  The
      * getAPIVersion method can be used to determine the versioned
      * subclass type.
      */
-    DrmHalVTSVendorModule* getVendorModule(const std::string& path);
+    DrmHalVTSVendorModule* getModule(const std::string& path);
+
+    /**
+     * Return the list of paths to available vendor modules.
+     */
+    std::vector<std::string> getPathList() const {return mPathList;}
 
    private:
-    std::string mModulesPath;
-    std::vector<std::string> mModuleList;
+    std::vector<std::string> mPathList;
     std::map<std::string, std::unique_ptr<SharedLibrary>> mOpenLibraries;
 
-    inline bool endsWith(const std::string& str, const std::string& suffix) {
+    /**
+     * Scan the list of paths to available vendor modules.
+     */
+    void scanModules(const std::string& dir);
+
+    inline bool endsWith(const std::string& str, const std::string& suffix) const {
         if (suffix.size() > str.size()) return false;
         return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
     }