Skip to content
This repository has been archived by the owner on Oct 5, 2023. It is now read-only.

Update play_dm.py to be more compatible with play.py #252

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions play_dm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import time
import argparse

from generator.gpt2.gpt2_generator import *
from generator.human_dm import *
Expand All @@ -11,6 +12,13 @@

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

parser = argparse.ArgumentParser("Play AIDungeon DM Mode")
parser.add_argument(
"--cpu",
action="store_true",
help="Force using CPU instead of GPU."
)


class AIPlayer:
def __init__(self, generator):
Expand All @@ -20,13 +28,33 @@ def get_action(self, prompt):
return self.generator.generate_raw(prompt)


def play_dm():
def play_dm(args):
"""
Entry/main function for starting AIDungeon DM Mode

Arguments:
args (namespace): Arguments returned by the
ArgumentParser
"""

console_print("Initializing AI Dungeon DM Mode")
generator = GPT2Generator(temperature=0.9)
generator = GPT2Generator(force_cpu=args.cpu, temperature=0.9)

story_manager = UnconstrainedStoryManager(HumanDM())
context, prompt = select_game()
(
setting_key,
character_key,
name,
character,
setting_description,
) = select_game()

if setting_key == "custom":
context, prompt = get_custom_prompt()
else:
context, prompt = get_curated_exposition(
setting_key, character_key, name, character, setting_description
)
console_print(context + prompt)
story_manager.start_new_story(prompt, context=context, upload_story=False)

Expand All @@ -48,4 +76,5 @@ def play_dm():


if __name__ == "__main__":
play_dm()
args = parser.parse_args()
play_dm(args)