jaeyoungk commited on
Commit
b6a60d6
1 Parent(s): 2300a2f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +31 -30
README.md CHANGED
@@ -35,25 +35,35 @@ This is the model card of a 🤗 transformers model that has been pushed on the
35
 
36
  ## Uses
37
 
38
- use under gen function to parse output from the LLM
39
- Loaded function of LLM is as the same as other LLM
40
-
41
  import re
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def gen(x):
44
  system_prompt = f"""
45
  Make a trading decision based on the following data.
46
  Please respond with a JSON object in the following format:
47
  {{"investment_decision": string, "summary_reason": string, "short_memory_index": number, "middle_memory_index": number, "long_memory_index": number, "reflection_memory_index": number}}
48
  investment_decision must always be one of {{buy, sell, hold}}
49
- Print the memory index value to 4 decimal places. If it exceeds, round up.
50
  """
51
 
52
  # Tokenizing the input and generating the output
53
 
54
  inputs = tokenizer(
55
  [
56
- f"system{system_prompt}user{x}"
57
  ], return_tensors = "pt").to("cuda")
58
 
59
 
@@ -66,36 +76,27 @@ def gen(x):
66
 
67
  full_text = tokenizer.decode(gened[0])
68
 
69
- # Define possible start phrases
70
- possible_start_phrases = ["{\"investment_decision\": \"buy\"", "{\"investment_decision\": \"sell\"", "{\"investment_decision\": \"hold\""]
71
- start_idx = -1
72
-
73
- # Find the index for the start phrase
74
- for phrase in possible_start_phrases:
75
- start_idx = full_text.find(phrase)
76
- if start_idx != -1:
77
- break
78
-
79
- if start_idx == -1:
80
- return "No valid investment decision found in the output."
81
-
82
- # Find the index for the end phrase
83
- end_phrase = "\"reflection_memory_index\":"
84
- end_idx = full_text.find(end_phrase, start_idx)
85
-
86
- if end_idx == -1:
87
- return "No valid reflection_memory_index found in the output."
88
-
89
- # Find the end of the reflection_memory_index value
90
- end_idx = full_text.find('}', end_idx)
91
- if end_idx == -1:
92
- return "No closing bracket found in the output."
93
 
94
  # Extract the text between start_idx and end_idx
95
- extracted_text = full_text[start_idx:end_idx+1].strip()
96
 
97
  return extracted_text
98
 
 
 
99
 
100
 
101
  ### Direct Use
 
35
 
36
  ## Uses
37
 
38
+ import torch
39
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
40
  import re
41
 
42
+ model_id = "jaeyoungk/albatross" # repository converted to safetensors
43
+ bnb_config = BitsAndBytesConfig(
44
+ load_in_4bit=True,
45
+ bnb_4bit_use_double_quant=True,
46
+ bnb_4bit_quant_type="nf4",
47
+ bnb_4bit_compute_dtype=torch.bfloat16
48
+ )
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')
51
+ model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map='auto')
52
+
53
+
54
  def gen(x):
55
  system_prompt = f"""
56
  Make a trading decision based on the following data.
57
  Please respond with a JSON object in the following format:
58
  {{"investment_decision": string, "summary_reason": string, "short_memory_index": number, "middle_memory_index": number, "long_memory_index": number, "reflection_memory_index": number}}
59
  investment_decision must always be one of {{buy, sell, hold}}
 
60
  """
61
 
62
  # Tokenizing the input and generating the output
63
 
64
  inputs = tokenizer(
65
  [
66
+ f"<|start_header_id|>system<|end_header_id|>{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>{x}<|end_header_id|>"
67
  ], return_tensors = "pt").to("cuda")
68
 
69
 
 
76
 
77
  full_text = tokenizer.decode(gened[0])
78
 
79
+ # Find the second occurrence of 'user<|end_header_id|>'
80
+ start_phrase = "user<|end_header_id|>"
81
+ first_occurrence = full_text.find(start_phrase)
82
+ second_occurrence = full_text.find(start_phrase, first_occurrence + len(start_phrase))
83
+
84
+ if second_occurrence == -1:
85
+ # If the second occurrence is not found, fall back to the first occurrence
86
+ start_idx = first_occurrence + len(start_phrase)
87
+ else:
88
+ start_idx = second_occurrence + len(start_phrase)
89
+
90
+ # Find the index of the next special token after the start index
91
+ end_idx = full_text.find('<|eot_id|>', start_idx)
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Extract the text between start_idx and end_idx
94
+ extracted_text = full_text[start_idx:end_idx].strip()
95
 
96
  return extracted_text
97
 
98
+ gen('input your text here')
99
+
100
 
101
 
102
  ### Direct Use