Skip to content

Commit

Permalink
style: fix python lint
Browse files Browse the repository at this point in the history
  • Loading branch information
wlruys committed Nov 10, 2023
1 parent 6c7d015 commit d0ff5bc
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 337 deletions.
135 changes: 11 additions & 124 deletions src/python/parla/common/parray/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,13 @@ class MemoryOperation:
EVICT = 2

# Flag
<<<<<<< HEAD
SWITCH_DEVICE_FLAG = (
101
) # if the flag is set, it means dst is not the current device
LOAD_SUBARRAY = (
102
) # if the flag is set, it means a subarray of src should be loaded
ENSURE_IS_COMPLETE = (
103
) # if the flag is set, check data will also check if the data is complete
=======
SWITCH_DEVICE_FLAG = 101 # if the flag is set, it means dst is not the current device

# if the flag is set, it means dst is not the current device
SWITCH_DEVICE_FLAG = 101
# if the flag is set, it means a subarray of src should be loaded
LOAD_SUBARRAY = 102
# if the flag is set, check data will also check if the data is complete
ENSURE_IS_COMPLETE = 103
>>>>>>> origin/dev

def __init__(self, inst: int = NOOP, dst: int = -1, src: int = -1, flag: int = []):
self.inst = inst
Expand All @@ -61,15 +51,10 @@ def error() -> MemoryOperation:
return MemoryOperation(MemoryOperation.ERROR)

@staticmethod
<<<<<<< HEAD
def load(
dst: int, src: int, on_different_device: bool = False, is_subarray: bool = False
) -> MemoryOperation:
"""load all data from src to dst
=======
def load(dst: int, src: int, on_different_device: bool = False, is_subarray: bool = False) -> MemoryOperation:
""" load all data from src to dst
>>>>>>> origin/dev
Need to switch device if `on_different_device` is true
This could known by checking flag = SWITCH_DEVICE_FLAG
Expand Down Expand Up @@ -120,33 +105,8 @@ def __init__(self, init_owner: int, num_gpu: int, cyparray_state: CyPArrayState)
# If copy is complete, value is state
# if not, value is a Dict{slices_hash: state}
self._local_states = {
<<<<<<< HEAD
n: self.INVALID for n in range(num_gpu)
} # init GPU status
self._local_states[CPU_INDEX] = self.INVALID # init CPU status
# If copy is complete, value is version
# if not, value is a Dict{slices_hash: version}
self._versions = {
n: -1 for n in range(num_gpu)
} # init copy version (-1 means no data)
self._versions[CPU_INDEX] = -1
# fields used to support fine grained data movement
self._is_complete = {
n: None for n in range(num_gpu)
} # does the device own a complete copy? None means neither
self._is_complete[CPU_INDEX] = None
self._local_states[init_owner] = self.MODIFIED # initial state is MODIFIED
self.owner = (
init_owner
) # the device that has the complete copy (take the role of main memory)
self._versions[init_owner] = 0 # the first version is 0
self._is_complete[init_owner] = True # the copy is complete
self._latest_version = 0 # the latest version in the system
=======
n: self.INVALID for n in range(num_gpu)} # init GPU status
# init CPU status
self._local_states[CPU_INDEX] = self.INVALID

Expand All @@ -171,7 +131,6 @@ def __init__(self, init_owner: int, num_gpu: int, cyparray_state: CyPArrayState)
self._is_complete[init_owner] = True
# the latest version in the system
self._latest_version = 0
>>>>>>> origin/dev

# held the lock when updating states
self._lock = threading.Lock()
Expand Down Expand Up @@ -200,18 +159,13 @@ def _owner_is_latest(self) -> bool:
"""True if owner's has latest version"""
return self._versions[self.owner] == self._latest_version

<<<<<<< HEAD
def _write_back_to(
self,
device_id: int,
new_state: int,
on_different_device: bool = False,
this_device_id: int = None,
) -> List[MemoryOperation]:
=======
def _write_back_to(self, device_id: int, new_state: int, on_different_device: bool = False,
this_device_id: int = None) -> List[MemoryOperation]:
>>>>>>> origin/dev
"""
Generate the list of write back MemoryOperation.
Which make `device_id` has the latest version with a complete copy.
Expand Down Expand Up @@ -289,20 +243,14 @@ def _write_back_to(self, device_id: int, new_state: int, on_different_device: bo
evict_list.remove(this_device_id)

if current_version < latest_complete_version:
target = [latest_complete_copy_id] + \
list(target) # complete copy first
target = [latest_complete_copy_id] + list(target) # complete copy first

# update latest version
self._versions[device_id] = self._latest_version
<<<<<<< HEAD
return [
MemoryOperation.load(device_id, t, on_different_device=on_different_device)
for t in target
] + [MemoryOperation.evict(t) for t in evict_list]
=======
return [MemoryOperation.load(device_id, t, on_different_device=on_different_device) for t in target] \
+ [MemoryOperation.evict(t) for t in evict_list]
>>>>>>> origin/dev

def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]:
"""Tell the protocol that this device read from the copy.
Expand All @@ -320,14 +268,8 @@ def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]
operations = []

if slices_hash is not None: # move a subarray
<<<<<<< HEAD
if (
self._is_complete[device_id] is True
): # use existing complete data at this device
=======
# use existing complete data at this device
if self._is_complete[device_id] is True:
>>>>>>> origin/dev
device_local_state = self._local_states[device_id]
else:
if not isinstance(self._local_states[device_id], dict):
Expand All @@ -344,14 +286,9 @@ def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]
# writeback this subarrays and then copy complete data from owner

# write back to owner
<<<<<<< HEAD
operations.extend(
self._write_back_to(self.owner, self.SHARED, on_different_device=True)
)
=======
operations.extend(self._write_back_to(
self.owner, self.SHARED, on_different_device=True))
>>>>>>> origin/dev

# evict previous subarries at device_id
operations.append(MemoryOperation.evict(device_id))
Expand All @@ -361,14 +298,8 @@ def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]

self._is_complete[device_id] = True
self._versions[device_id] = self._versions[self.owner]
<<<<<<< HEAD
self._local_states[
self.owner
] = self.SHARED # owner is updated, so it is in SHARED states
=======
# owner is updated, so it is in SHARED states
self._local_states[self.owner] = self.SHARED
>>>>>>> origin/dev
self._local_states[device_id] = self.SHARED
self._cyparray_state.set_valid_on_device(self.owner, True)
self._cyparray_state.set_valid_on_device(device_id, True)
Expand Down Expand Up @@ -398,8 +329,7 @@ def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]
else:
if device_local_state == self.INVALID:
if self._is_complete[device_id]:
operations.extend(
self._write_back_to(device_id, self.SHARED))
operations.extend(self._write_back_to(device_id, self.SHARED))

# change owner
if self._owner_is_latest():
Expand All @@ -413,16 +343,11 @@ def read(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation]
self.owner = device_id
self._versions[device_id] = self._latest_version
else: # since we assume all array are disjoint, so could load directly
<<<<<<< HEAD
operations.append(
MemoryOperation.load(
dst=device_id, src=self.owner, is_subarray=True
)
)
=======
operations.append(MemoryOperation.load(
dst=device_id, src=self.owner, is_subarray=True))
>>>>>>> origin/dev

self._versions[device_id][slices_hash] = self._versions[self.owner]
else:
Expand Down Expand Up @@ -454,14 +379,8 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation
operations = []

if slices_hash is not None: # move a subarray
<<<<<<< HEAD
if (
self._is_complete[device_id] is True
): # use existing complete data at this device
=======
# use existing complete data at this device
if self._is_complete[device_id] is True:
>>>>>>> origin/dev
device_local_state = self._local_states[device_id]
else:
if not isinstance(self._local_states[device_id], dict):
Expand All @@ -478,14 +397,9 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation
# writeback this subarrays and then copy complete data from owner

# write back to owner
<<<<<<< HEAD
operations.extend(
self._write_back_to(self.owner, self.MODIFIED, on_different_device=True)
)
=======
operations.extend(self._write_back_to(self.owner, self.MODIFIED,
on_different_device=True))
>>>>>>> origin/dev

# copy from owner
operations.append(MemoryOperation.load(device_id, self.owner))
Expand All @@ -509,8 +423,7 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation

if device_id == self.owner:
if device_local_state != self.MODIFIED:
operations.extend(self._write_back_to(
device_id, self.MODIFIED))
operations.extend(self._write_back_to(device_id, self.MODIFIED))

self._latest_version += 1
self._versions[device_id] = self._latest_version
Expand All @@ -519,39 +432,27 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation
else:
if device_local_state == self.INVALID:
if self._is_complete[device_id]:
operations.extend(self._write_back_to(
device_id, self.MODIFIED))
operations.extend(self._write_back_to(device_id, self.MODIFIED))

self._latest_version += 1
self._versions[device_id] = self._latest_version

# change owner
self.owner = device_id
else: # since we assume all subarrays are disjoint, could load directly
<<<<<<< HEAD
operations.append(
MemoryOperation.load(
dst=device_id, src=self.owner, is_subarray=True
)
)
=======
operations.append(MemoryOperation.load(
dst=device_id, src=self.owner, is_subarray=True))
>>>>>>> origin/dev

self._versions[device_id][slices_hash] = (
self._versions[self.owner] + 1
)
if self._owner_is_latest():
self._latest_version += 1
<<<<<<< HEAD
self._local_states[
self.owner
] = self.INVALID # invalidate overlapping copy
=======
# invalidate overlapping copy
self._local_states[self.owner] = self.INVALID
>>>>>>> origin/dev
self._cyparray_state.set_valid_on_device(self.owner, False)
elif device_local_state == self.SHARED:
if self._is_complete[device_id]:
Expand All @@ -562,8 +463,7 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation
self.owner = device_id

# evict others
operations.extend(self._write_back_to(
device_id, self.MODIFIED))
operations.extend(self._write_back_to(device_id, self.MODIFIED))
else:
self._latest_version += 1
self._versions[device_id][slices_hash] = self._latest_version
Expand All @@ -574,21 +474,13 @@ def write(self, device_id: int, slices_hash: int = None) -> List[MemoryOperation
if not isinstance(state, dict):
if id != device_id:
self._local_states[id] = self.INVALID
self._cyparray_state.set_valid_on_device(
id, False)
<<<<<<< HEAD
if (
id != self.owner
): # owner's buffer will be kept (so won't lost the last complete copy)
=======
self._cyparray_state.set_valid_on_device(id, False)

# owner's buffer will be kept (so won't lost the last complete copy)
if id != self.owner:
>>>>>>> origin/dev
self._versions[id] = -1
self._is_complete[id] = None
operations.append(
MemoryOperation.evict(id))
operations.append(MemoryOperation.evict(id))
if len(operations) == 0:
operations.append(MemoryOperation.noop())
else:
Expand Down Expand Up @@ -646,7 +538,6 @@ def evict(
if evict_last_copy:
if keep_one_copy: # write back to CPU
if device_id != CPU_INDEX:
<<<<<<< HEAD
operations.extend(
self._write_back_to(
CPU_INDEX,
Expand All @@ -655,10 +546,6 @@ def evict(
this_device_id=device_id,
)
)
=======
operations.extend(self._write_back_to(
CPU_INDEX, self.MODIFIED, on_different_device=True, this_device_id=device_id))
>>>>>>> origin/dev
# special case, since `this_device_id` is set, _write_back will not evict this devic
# need to do it manually
operations.append(MemoryOperation.evict(device_id))
Expand Down
Loading

0 comments on commit d0ff5bc

Please sign in to comment.