# I didn't write this code

import os

from litellm import completion, acompletion
from litellm.utils import CustomStreamWrapper, ModelResponse


class LLM:
    '''
    Primary class for interacting with various LLM provider APIs through litellm.
    OpenAI, Anthropic, and Cohere models are known to work.
    '''
    # Non-exhaustive list of supported models; these are known to work.
    valid_models = {
        'openai': [
            'gpt-4-turbo-preview',
            'gpt-4-0125-preview',
            'gpt-4-1106-preview',
            'gpt-3.5-turbo',
            'gpt-3.5-turbo-1106',
            'gpt-3.5-turbo-0125',
        ],
        'anthropic': [
            'claude-3-haiku-20240307',
            'claude-3-sonnet-20240229',
            'claude-3-opus-20240229',
        ],
        'cohere': [
            'command-r',
            'command-r-plus',
        ],
    }

    def __init__(self,
                 model_name: str = 'gpt-3.5-turbo-0125',
                 api_key: str | None = None,
                 api_version: str | None = None,
                 api_base: str | None = None
                 ):
        self.model_name = model_name
        if api_key:
            self._api_key = api_key
        else:
            try:
                self._api_key = os.environ['OPENAI_API_KEY']
            except KeyError as e:
                raise ValueError(
                    'Default api_key expects the OPENAI_API_KEY environment variable. '
                    'Set that variable or pass another api_key explicitly.'
                ) from e
        self.api_version = api_version
        self.api_base = api_base

    def chat_completion(self,
                        system_message: str,
                        user_message: str = '',
                        temperature: float = 0.0,
                        max_tokens: int = 500,
                        stream: bool = False,
                        raw_response: bool = False,
                        **kwargs
                        ) -> str | CustomStreamWrapper | ModelResponse:
        '''
        Generative text completion method.

        Args:
        -----
        system_message: str
            The system message to send to the model.
        user_message: str
            The user message to send to the model.
        temperature: float
            The sampling temperature; higher values produce more varied output.
        max_tokens: int
            The maximum number of tokens to generate.
        stream: bool
            Whether to stream the response.
        raw_response: bool
            If True, return the raw model response instead of just its text.

        Any additional kwargs are passed through to litellm.completion.
        '''
        # Claude models expect the conversation to open with a user turn,
        # so remap the system/user roles accordingly.
        initial_role = 'user' if self.model_name.startswith('claude') else 'system'
        secondary_role = 'assistant' if self.model_name.startswith('claude') else 'user'

        # Anthropic temperatures range from 0 to 1 (vs. 0 to 2 for OpenAI),
        # so scale the requested value down for Claude models.
        if self.model_name.startswith('claude'):
            temperature = temperature / 2

        messages = [
            {'role': initial_role, 'content': system_message},
            {'role': secondary_role, 'content': user_message}
        ]
        
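        # Delegate to litellm, which routes the request to the correct
        # provider based on the model name.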
        response = completion(model=self.model_name,
                              messages=messages,
                              temperature=temperature,
                              max_tokens=max_tokens,
                              stream=stream,
                              api_key=self._api_key,
                              api_base=self.api_base,
                              api_version=self.api_version,
                              **kwargs)
        
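        # Streamed or raw responses are returned unprocessed; otherwise
        # return just the generated text.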
        if raw_response or stream:
            return response
        return response.choices[0].message.content
    
    async def achat_completion(self,
                               system_message: str,
                               user_message: str = '',
                               temperature: float = 0.0,
                               max_tokens: int = 500,
                               stream: bool = False,
                               raw_response: bool = False,
                               **kwargs
                               ) -> str | CustomStreamWrapper | ModelResponse:
        '''
        Asynchronous generative text completion method.

        Args:
        -----
        system_message: str
            The system message to send to the model.
        user_message: str
            The user message to send to the model.
        temperature: float
            The sampling temperature; higher values produce more varied output.
        max_tokens: int
            The maximum number of tokens to generate.
        stream: bool
            Whether to stream the response.
        raw_response: bool
            If True, return the raw model response instead of just its text.

        Any additional kwargs are passed through to litellm.acompletion.
        '''
        # Same Claude role remapping and temperature scaling as chat_completion.
        initial_role = 'user' if self.model_name.startswith('claude') else 'system'
        secondary_role = 'assistant' if self.model_name.startswith('claude') else 'user'
        if self.model_name.startswith('claude'):
            temperature = temperature / 2

        messages = [
            {'role': initial_role, 'content': system_message},
            {'role': secondary_role, 'content': user_message}
        ]
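        # Same call path as chat_completion, but through litellm's async
        # acompletion entry point.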
        response = await acompletion(model=self.model_name,
                                     messages=messages,
                                     temperature=temperature,
                                     max_tokens=max_tokens,
                                     stream=stream,
                                     api_key=self._api_key,
                                     api_base=self.api_base,
                                     api_version=self.api_version,
                                     **kwargs)
        if raw_response or stream:
            return response
        return response.choices[0].message.content
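

# A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment.
# The model name and messages below are illustrative, not part of the class API.
if __name__ == '__main__':
    import asyncio

    llm = LLM(model_name='gpt-3.5-turbo-0125')

    # Synchronous completion returns the response text by default.
    print(llm.chat_completion(
        system_message='You are a terse assistant.',
        user_message='Name one prime number.',
    ))

    # stream=True returns a CustomStreamWrapper; iterate to consume the chunks.
    for chunk in llm.chat_completion(
        system_message='You are a terse assistant.',
        user_message='Count to three.',
        stream=True,
    ):
        print(chunk.choices[0].delta.content or '', end='')
    print()

    # The async variant mirrors the sync call.
    async def main() -> str:
        return await llm.achat_completion(
            system_message='You are a terse assistant.',
            user_message='Name the only even prime.',
        )

    print(asyncio.run(main()))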