From a30abc96249c8bb47c0067b0d6392c88940365fc Mon Sep 17 00:00:00 2001
From: Lucas Kim <lucas.kim@shopify.com>
Date: Wed, 29 Nov 2023 15:47:34 -0500
Subject: [PATCH] use foreign_key to reference parent association

---
 lib/identity_cache/parent_model_expiration.rb |  2 +-
 test/custom_primary_keys_test.rb              |  2 +-
 test/helpers/models.rb                        |  2 +-
 test/parent_model_expiration_test.rb          | 29 +++++++++++++++++++
 4 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/lib/identity_cache/parent_model_expiration.rb b/lib/identity_cache/parent_model_expiration.rb
index 1167b9a5..1993a21c 100644
--- a/lib/identity_cache/parent_model_expiration.rb
+++ b/lib/identity_cache/parent_model_expiration.rb
@@ -66,7 +66,7 @@ def add_record_to_cache_expiry_set(parents_to_expire, record)
 
     def parents_to_expire_on_changes(parents_to_expire, association_name, cached_associations)
       parent_association = self.class.reflect_on_association(association_name)
-      foreign_key = parent_association.association_foreign_key
+      foreign_key = parent_association.foreign_key
 
       new_parent = send(association_name)
 
diff --git a/test/custom_primary_keys_test.rb b/test/custom_primary_keys_test.rb
index f60ec442..03ceffe6 100644
--- a/test/custom_primary_keys_test.rb
+++ b/test/custom_primary_keys_test.rb
@@ -5,7 +5,7 @@
 class CustomPrimaryKeysTest < IdentityCache::TestCase
   def setup
     super
-    CustomParentRecord.cache_has_many(:custom_child_record)
+    CustomParentRecord.cache_has_many(:custom_child_records)
     CustomChildRecord.cache_belongs_to(:custom_parent_record)
     @parent_record = CustomParentRecord.create!(parent_primary_key: 1)
     @child_record_1 = CustomChildRecord.create!(custom_parent_record: @parent_record, child_primary_key: 1)
diff --git a/test/helpers/models.rb b/test/helpers/models.rb
index 01cc8c0e..668c5028 100644
--- a/test/helpers/models.rb
+++ b/test/helpers/models.rb
@@ -87,7 +87,7 @@ class StiRecordTypeA < StiRecord
 
 class CustomParentRecord < ActiveRecord::Base
   include IdentityCache
-  has_many :custom_child_record, foreign_key: :parent_id
+  has_many :custom_child_records, foreign_key: :parent_id
   self.primary_key = "parent_primary_key"
 end
 
diff --git a/test/parent_model_expiration_test.rb b/test/parent_model_expiration_test.rb
index 4acb4508..d7bc0d69 100644
--- a/test/parent_model_expiration_test.rb
+++ b/test/parent_model_expiration_test.rb
@@ -33,4 +33,33 @@ def test_recursively_expire_parent_caches
     fetched_name = Item.fetch(item.id).fetch_associated_records.first.fetch_deeply_associated_records.first.name
     assert_equal("updated child", fetched_name)
   end
+
+  def test_custom_parent_foreign_key_expiry
+    define_cache_indexes = lambda do
+      CustomParentRecord.cache_has_many(:custom_child_records, embed: true)
+      CustomChildRecord.cache_belongs_to(:custom_parent_record)
+    end
+    define_cache_indexes.call
+    old_parent = CustomParentRecord.new(parent_primary_key: 1)
+    old_parent.save!
+    child = CustomChildRecord.new(child_primary_key: 10, parent_id: old_parent.id)
+    child.save!
+
+    # Warm the custom_child_records embedded cache on the old parent record
+    assert_equal(10, CustomParentRecord.fetch(1).fetch_custom_child_records.first.child_primary_key)
+
+    new_parent = CustomParentRecord.new(parent_primary_key: 2)
+    new_parent.save!
+    # Warm the custom_child_records embedded cache on the new parent record
+    assert_empty(CustomParentRecord.fetch(2).fetch_custom_child_records)
+
+    # Now invoke a db update, where the child switches parent
+    child.parent_id = new_parent.parent_primary_key
+    child.save!
+
+    # the old parent's custom_child_records embedded cache should be invalidated and empty
+    assert_empty(CustomParentRecord.fetch(1).fetch_custom_child_records)
+    # the new parent's custom_child_records embedded cache should be invalidated and filled with the new association
+    assert_equal(10, CustomParentRecord.fetch(2).fetch_custom_child_records.first.child_primary_key)
+  end
 end