forked from stanford-oval/storm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_storm.py
145 lines (121 loc) · 5.63 KB
/
run_storm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import toml
import logging
import argparse
from pathlib import Path
from dataclasses import dataclass
from knowledge_storm.storm_wiki.engine import STORMWikiLMConfigs, STORMWikiRunner, STORMWikiRunnerArguments
from knowledge_storm.lm import OpenAIModel
from knowledge_storm.rm import YouRM
from knowledge_storm.utils import WebPageHelper
def setup_logging(config):
    """Configure root logging for a STORM run.

    Installs two handlers on the root logger: a file handler writing
    detailed records to ``logs/storm.log`` (truncated on every run via
    ``mode='w'``) and a console handler with a compact format. Both use
    the level named in ``config['logging']['level']``.

    Args:
        config: Parsed configuration mapping; must contain
            ``config['logging']['level']``, a logging level name such as
            ``"INFO"`` or ``"debug"`` (case-insensitive).

    Raises:
        AttributeError: If the configured level name is not a valid
            ``logging`` level attribute.
    """
    log_level = getattr(logging, config['logging']['level'].upper())
    log_file = 'logs/storm.log'  # Use specific log file for STORM
    # exist_ok=True avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() pair.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    # Configure root logger
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Clear any existing handlers so repeated calls (or prior basicConfig)
    # do not produce duplicate log records.
    root_logger.handlers = []
    # File handler with detailed formatting
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    root_logger.addHandler(file_handler)
    # Console handler with simpler formatting but showing all levels
    console_handler = logging.StreamHandler()
    console_handler.setLevel(log_level)  # Show all levels in console
    console_handler.setFormatter(
        logging.Formatter('%(levelname)s: %(message)s')
    )
    root_logger.addHandler(console_handler)
    logging.info("STORM Logging initialized. Level: %s, File: %s", log_level, log_file)
def load_config(config_path='config.toml'):
    """Read and parse a TOML configuration file.

    Args:
        config_path: Path to the TOML file (defaults to ``config.toml``).

    Returns:
        The parsed configuration as a dict.
    """
    with open(config_path, 'r') as config_file:
        parsed = toml.load(config_file)
    return parsed
def main():
    """Command-line entry point: generate a STORM article for one topic.

    Parses CLI arguments, loads the TOML config, configures logging,
    wires up the OpenAI language models and the You.com retriever, runs
    the STORM pipeline, and writes the article to ``output/<--output>``.

    Raises:
        Exception: Any failure during setup or generation is logged with
            a traceback and re-raised so the process exits non-zero.
    """
    # Bind the logger before the try block: the original code assigned it
    # inside the try, so a failure in argument parsing, config loading, or
    # logging setup raised a NameError on `logger` in the except clause,
    # masking the real exception.
    logger = logging.getLogger(__name__)
    try:
        parser = argparse.ArgumentParser(description='Run STORM from command line')
        parser.add_argument('topic', type=str, help='Topic to generate article about')
        parser.add_argument('--output', type=str, default='output.md',
                            help='Output file path (default: output.md)')
        parser.add_argument('--config', type=str, default='config.toml',
                            help='Path to config file (default: config.toml)')
        args = parser.parse_args()
        # Load configuration
        config = load_config(args.config)
        # Set up logging (replaces any handlers installed before this point)
        setup_logging(config)
        logger.info(f"Starting STORM with topic: {args.topic}")
        # Export credentials: the knowledge_storm internals read these from
        # the environment rather than taking them as arguments.
        os.environ["OPENAI_API_KEY"] = config['openai']['api_key']
        os.environ["OPENAI_API_TYPE"] = config['openai']['api_type']
        os.environ["YDC_API_KEY"] = config['you']['api_key']
        # Initialize LM configurations
        logger.debug("Initializing language models")
        lm_configs = STORMWikiLMConfigs()
        openai_kwargs = {
            'api_key': config['openai']['api_key'],
            'temperature': config['generation']['temperature'],
            'top_p': config['generation']['top_p'],
        }
        # Configure models for different pipeline stages; max_tokens is
        # sized per stage (polishing needs the largest budget).
        conv_simulator_lm = OpenAIModel(model=config['models']['conv_simulator'], max_tokens=500, **openai_kwargs)
        question_asker_lm = OpenAIModel(model=config['models']['question_asker'], max_tokens=500, **openai_kwargs)
        outline_gen_lm = OpenAIModel(model=config['models']['outline_gen'], max_tokens=400, **openai_kwargs)
        article_gen_lm = OpenAIModel(model=config['models']['article_gen'], max_tokens=700, **openai_kwargs)
        article_polish_lm = OpenAIModel(model=config['models']['article_polish'], max_tokens=4000, **openai_kwargs)
        lm_configs.set_conv_simulator_lm(conv_simulator_lm)
        lm_configs.set_question_asker_lm(question_asker_lm)
        lm_configs.set_outline_gen_lm(outline_gen_lm)
        lm_configs.set_article_gen_lm(article_gen_lm)
        lm_configs.set_article_polish_lm(article_polish_lm)
        # Initialize retriever with configuration
        logger.debug("Initializing retriever")
        retriever = YouRM(
            ydc_api_key=config['you']['api_key'],
            k=config['retrieval']['search_top_k']
        )
        # Create runner arguments
        logger.debug("Setting up runner arguments")
        runner_args = STORMWikiRunnerArguments(
            output_dir="output",
            max_conv_turn=config['retrieval']['max_conv_turn'],
            max_search_queries_per_turn=config['retrieval']['max_search_queries'],
            max_thread_num=config['retrieval']['max_thread_num'],
            search_top_k=config['retrieval']['search_top_k']
        )
        # Initialize STORM with all configurations
        logger.debug("Initializing STORM runner")
        storm = STORMWikiRunner(
            args=runner_args,
            lm_configs=lm_configs,
            rm=retriever
        )
        # Generate article
        logger.info(f"Generating article about: {args.topic}")
        result = storm.run(topic=args.topic)
        if result is None:
            logger.error("Article generation failed - no article was produced")
            return
        # Save to file; exist_ok=True avoids the exists/makedirs race.
        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, args.output)
        logger.info(f"Saving article to: {output_file}")
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(result.to_string())
        logger.info("Article generation completed successfully")
    except Exception as e:
        # Log with traceback, then re-raise so the process exits non-zero.
        logger.error(f"Error during article generation: {str(e)}", exc_info=True)
        raise
# Script entry point: run the STORM CLI when executed directly
# (skipped when this module is imported).
if __name__ == "__main__":
    main()