#!/usr/bin/env python3
import os
import sys
import argparse
import logging

from .log import LogFormatter
from .prompts import DefaultChatPrompts, DefaultCompletionPrompts

class ArgParser(argparse.ArgumentParser):
    """
    Adds selectable extra args that are commonly used by this project
    """
    Defaults = ['model', 'chat', 'generation', 'log']
    Audio = ['audio_input', 'audio_output']
    Video = ['video_input', 'video_output']
    Riva = ['asr', 'tts']
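    # these lists can be combined when constructing the parser, e.g. an
    # illustrative ArgParser(extras=ArgParser.Defaults + ArgParser.Riva)
    # would enable the default LLM/chat args plus the Riva ASR/TTS args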

    def __init__(self, extras=Defaults, **kwargs):
        super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter, **kwargs)
        # LLM
        if 'model' in extras:
            self.add_argument("--model", type=str, default=None, #required=True,
                              help="path to the model, or repository on HuggingFace Hub")
            self.add_argument("--quant", type=str, default=None,
                              help="path to the quantized weights (AWQ uses this)")
            self.add_argument("--api", type=str, default=None, choices=['auto_gptq', 'awq', 'hf', 'mlc'],
                              help="specify the API to use (otherwise inferred)")
            self.add_argument("--vision-model", type=str, default=None,
                              help="for VLMs, manually select the CLIP vision model to use (e.g. openai/clip-vit-large-patch14-336 for higher-res)")

        if 'chat' in extras or 'prompt' in extras:
            self.add_argument("--prompt", action='append', nargs='*',
                              help="add a prompt (can be prompt text or path to .txt, .json, or image file)")
            self.add_argument("--save-mermaid", type=str, default=None, help="save mermaid diagram of the pipeline to this file")

        if 'chat' in extras:
            self.add_argument("--system-prompt", type=str, default=None, help="override the system prompt instruction")
            self.add_argument("--chat-template", type=str, default=None, #choices=list(ChatTemplates.keys()),
                              help="manually select the chat template ('llama-2', 'llava-v1', 'vicuna-v1')")

        if 'generation' in extras:
            self.add_argument("--max-new-tokens", type=int, default=128,
                              help="the maximum number of new tokens to generate, in addition to the prompt")
            self.add_argument("--min-new-tokens", type=int, default=-1,
                              help="force the model to generate a minimum number of output tokens")
            self.add_argument("--do-sample", action="store_true",
                              help="enable output token sampling using temperature and top_p")
            self.add_argument("--temperature", type=float, default=0.7,
                              help="token sampling temperature setting when --do-sample is used")
            self.add_argument("--top-p", type=float, default=0.95,
                              help="token sampling top_p setting when --do-sample is used")
            self.add_argument("--repetition-penalty", type=float, default=1.0,
                              help="the parameter for repetition penalty. 1.0 means no penalty")
        # VIDEO
        if 'video_input' in extras:
            self.add_argument("--video-input", type=str, default=None, help="video camera device name, stream URL, file/dir path")
            self.add_argument("--video-input-width", type=int, default=None, help="manually set the width of the video input")
            self.add_argument("--video-input-height", type=int, default=None, help="manually set the height of the video input")
            self.add_argument("--video-input-codec", type=str, default=None, choices=['h264', 'h265', 'vp8', 'vp9', 'mjpeg'], help="manually set the input video codec to use")
            self.add_argument("--video-input-framerate", type=int, default=None, help="set the desired framerate of input video")
            self.add_argument("--video-input-save", type=str, default=None, help="path to video file to save the incoming video feed to")

        if 'video_output' in extras:
            self.add_argument("--video-output", type=str, default=None, help="display, stream URL, file/dir path")
            self.add_argument("--video-output-codec", type=str, default=None, choices=['h264', 'h265', 'vp8', 'vp9', 'mjpeg'], help="set the output video codec to use")
            self.add_argument("--video-output-bitrate", type=int, default=None, help="set the output bitrate to use")
            self.add_argument("--video-output-save", type=str, default=None, help="path to video file to save the outgoing video stream to")
        # AUDIO
        if 'audio_input' not in extras and 'asr' in extras:
            extras = extras + ['audio_input']  # build a new list so the shared class-level lists (Defaults, Riva, ...) aren't mutated

        if 'audio_input' in extras:
            self.add_argument("--audio-input-device", type=int, default=None, help="audio input device/microphone to use for ASR")
            self.add_argument("--audio-input-channels", type=int, default=1, help="the number of input audio channels to use")

        if 'audio_output' in extras:
            self.add_argument("--audio-output-device", type=int, default=None, help="audio output interface device index (PortAudio)")
            self.add_argument("--audio-output-file", type=str, default=None, help="save audio output to wav file using the given path")
            self.add_argument("--audio-output-channels", type=int, default=1, help="the number of output audio channels to use")

        if 'audio_input' in extras or 'audio_output' in extras:
            self.add_argument("--list-audio-devices", action="store_true", help="list the output audio device indices, then exit")

        if any(x in extras for x in ('audio_input', 'audio_output', 'asr', 'tts')):
            self.add_argument("--sample-rate-hz", type=int, default=48000, help="the audio sample rate in Hz")  # type=int so values given on the command line aren't left as strings
        # ASR/TTS
        if 'asr' in extras or 'tts' in extras:
            self.add_argument("--riva-server", default="localhost:50051", help="URI of the Riva GRPC server endpoint")
            self.add_argument("--language-code", default="en-US", help="language code of the ASR/TTS to be used")

        if 'tts' in extras:
            self.add_argument("--voice", type=str, default="English-US.Female-1", help="voice model name to use for TTS")
            self.add_argument("--voice-rate", type=str, default="default", help="TTS SSML voice speaker rate (between 25-250%%)")
            self.add_argument("--voice-pitch", type=str, default="default", help="TTS SSML voice pitch shift")
            self.add_argument("--voice-volume", type=str, default="default", help="TTS SSML voice volume attribute")
            #self.add_argument("--voice-min-words", type=int, default=4, help="the minimum number of words the TTS should wait to speak")

        if 'asr' in extras:
            self.add_argument("--asr-confidence", type=float, default=-2.5, help="minimum ASR confidence (only applies to 'final' transcripts)")
            self.add_argument("--asr-silence", type=float, default=-1.0, help="audio with RMS equal to or below this amount will be considered silent (negative disables silence detection)")
            self.add_argument("--asr-chunk", type=int, default=1600, help="the number of audio samples to buffer as input to ASR")
            self.add_argument("--boosted-lm-words", action='append', help="words to boost when decoding")
            self.add_argument("--boosted-lm-score", type=float, default=4.0, help="value by which to boost words when decoding")
            self.add_argument("--profanity-filter", action='store_true', help="enable profanity filtering in ASR transcripts")
            self.add_argument("--inverse-text-normalization", action='store_true', help="apply Inverse Text Normalization to convert numbers to digits, etc.")
            self.add_argument("--no-automatic-punctuation", dest='automatic_punctuation', action='store_false', help="disable punctuation in the ASR transcripts")
        # WEBSERVER
        if 'web' in extras:
            self.add_argument("--web-host", type=str, default='0.0.0.0', help="network interface to bind to for hosting the webserver")
            self.add_argument("--web-port", type=int, default=8050, help="port used for webserver HTTP/HTTPS")
            self.add_argument("--ws-port", type=int, default=49000, help="port used for websocket communication")
            self.add_argument("--ssl-key", default=os.getenv('SSL_KEY'), type=str, help="path to PEM-encoded SSL/TLS key file for enabling HTTPS")
            self.add_argument("--ssl-cert", default=os.getenv('SSL_CERT'), type=str, help="path to PEM-encoded SSL/TLS cert file for enabling HTTPS")
            self.add_argument("--upload-dir", type=str, default='/tmp/uploads', help="the path to save files uploaded from the client")
            self.add_argument("--web-trace", action="store_true", help="output websocket message logs when --log-level=debug")

        # LOGGING
        if 'log' in extras:
            self.add_argument("--log-level", type=str, default='info', choices=['debug', 'info', 'warning', 'error', 'critical'], help="the logging level to stdout")
            self.add_argument("--debug", "--verbose", action="store_true", help="set the logging level to debug/verbose mode")

    def parse_args(self, **kwargs):
        """
        Override for parse_args() that does some additional configuration
        """
        args = super().parse_args(**kwargs)

        if hasattr(args, 'prompt'):
            args.prompt = ArgParser.parse_prompt_args(args.prompt)

        if hasattr(args, 'log_level'):
            if args.debug:
                args.log_level = "debug"
            LogFormatter.config(level=args.log_level)

        if hasattr(args, 'list_audio_devices') and args.list_audio_devices:
            import riva.client
            riva.client.audio_io.list_output_devices()
            sys.exit(0)

        logging.debug(f"{args}")
        return args

    @staticmethod
    def parse_prompt_args(prompts, chat=True):
        """
        Parse the --prompt command-line argument and return a list of prompts.
        It's assumed that the argparse argument was created like this:

            `parser.add_argument('--prompt', action='append', nargs='*')`

        If the prompt text is 'default', then the default chat prompts will
        be assigned if chat=True (otherwise the default completion prompts).
        """
        if prompts is None:
            return None

        prompts = [x[0] for x in prompts]

        if prompts[0].lower() in ('default', 'defaults'):
            prompts = DefaultChatPrompts if chat else DefaultCompletionPrompts

        return prompts
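
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module) - the script
# name and model path are hypothetical, but every flag comes from the argument
# groups defined above:
#
#   parser = ArgParser(extras=ArgParser.Defaults + ['web'])
#   args = parser.parse_args()
#
#   $ python3 my_app.py --model meta-llama/Llama-2-7b-chat-hf \
#       --max-new-tokens 256 --do-sample --temperature 0.2 \
#       --web-port 8050 --log-level debug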