takarajordan committed on
Commit
1a61279
•
1 Parent(s): 49edea6

Upload 2 files

Browse files
Files changed (2) hide show
  1. decoder.py +29 -0
  2. train.py +337 -0
decoder.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torchvision import transforms
from PIL import Image

# Model class and extraction helpers are defined in train.py.
from train import SteganographyNet, extract_message, get_device

# safetensors is optional: it is only required for ``.safetensors`` checkpoints.
# Remember whether the import worked so a missing package is reported clearly
# at the point of use instead of surfacing later as a NameError.
try:
    from safetensors.torch import load_file as load_safetensors
except ImportError:
    load_safetensors = None
    print("safetensors not installed. Run: pip install safetensors")

# Build the network on the device chosen by get_device().
device = get_device()
model = SteganographyNet(message_length=1024).to(device)  # message_length doesn't matter for extraction

# Load model weights based on file extension
model_path = 'model.safetensors'  # or 'stego_model_3.safetensors'
if model_path.endswith('.safetensors'):
    if load_safetensors is None:
        raise ImportError(
            f"Cannot load '{model_path}': the safetensors package is not installed."
        )
    state_dict = load_safetensors(model_path)
else:
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be
    # loaded on whatever device is available here.
    state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

model.eval()

# Test extraction
extracted_message = extract_message(model, 'decode_me_3.png')
print(f"Extracted message: {extracted_message}")
train.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch.backends.mps
8
+ from math import exp
9
+ import torch.nn.functional as F
10
+
11
class SteganographyNet(nn.Module):
    """Convolutional encoder/decoder pair for image steganography.

    The encoder consumes a 4-channel input and produces a 3-channel image
    that is blended with the first three input channels. The decoder maps a
    3-channel image to a single-channel map with values in (0, 1).
    ``message_length`` is stored on the instance; the layers defined here do
    not depend on it.
    """

    def __init__(self, message_length):
        super().__init__()
        self.message_length = message_length

        # Encoder stem: 4 input channels -> 64 feature maps.
        stem_layers = [
            nn.Conv2d(4, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        ]
        self.encoder_initial = nn.Sequential(*stem_layers)

        # Encoder trunk: widen to 128 channels, apply channel attention, a
        # dilated convolution and a residual block, then project back to 64.
        trunk_layers = [
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            SEBlock(128),
            nn.Conv2d(128, 128, 3, padding=2, dilation=2),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            ResidualBlock(128),
            nn.Conv2d(128, 64, 1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
        ]
        self.encoder_backbone = nn.Sequential(*trunk_layers)

        # Encoder head: 3-channel output squashed into (0, 1).
        head_layers = [
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Sigmoid(),
        ]
        self.encoder_final = nn.Sequential(*head_layers)

        # Decoder: 3-channel image -> single-channel map in (0, 1).
        decoder_layers = [
            # Initial feature extraction
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            # Feature processing
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            SEBlock(128),
            ResidualBlock(128),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            # Final message extraction
            nn.Conv2d(64, 1, 3, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the blend of the first three channels of ``x`` with the encoder output."""
        cover = x[:, :3, :, :]
        features = self.encoder_backbone(self.encoder_initial(x))
        generated = self.encoder_final(features)
        # Fixed 90/10 mix keeps the result close to the input image.
        return 0.9 * cover + 0.1 * generated

    def forward(self, x):
        """Encode ``x`` and decode the result; returns ``(encoded, decoded)``."""
        stego = self.encode(x)
        return stego, self.decoder(stego)
83
+
84
+ # Add these new blocks
85
class SEBlock(nn.Module):
    """Channel gating: rescale each channel by a weight derived from its spatial mean.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` for the hidden width of
            the gating network.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Global average pool down to one value per channel.
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        gating_layers = [
            nn.Linear(channels, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.excitation = nn.Sequential(*gating_layers)

    def forward(self, x):
        """Return ``x`` multiplied channel-wise by its gate values in (0, 1)."""
        batch, channels, _, _ = x.shape
        pooled = self.squeeze(x).view(batch, channels)
        gates = self.excitation(pooled)
        gates = gates.view(batch, channels, 1, 1)
        return x * gates.expand_as(x)
101
+
102
class ResidualBlock(nn.Module):
    """Two 3x3 conv + GroupNorm stages with an identity shortcut.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Return ``SiLU(branch(x) + x)`` where ``branch`` is the two conv stages."""
        hidden = self.conv1(x)
        hidden = self.gn1(hidden)
        hidden = self.silu(hidden)
        hidden = self.gn2(self.conv2(hidden))
        # Shortcut is added before the final activation.
        return self.silu(hidden + x)
117
+
118
class SSIM(nn.Module):
    """Structural similarity (SSIM) between two image batches.

    Local statistics are computed with a Gaussian window (sigma 1.5) applied
    per channel via a grouped convolution. The constants ``C1 = 0.01**2`` and
    ``C2 = 0.03**2`` are used as-is, with no dynamic-range scaling.

    Args:
        window_size: Side length of the square Gaussian window.
        size_average: If True, ``forward`` returns one scalar averaged over
            the whole batch; otherwise one value per batch element.
        channel: Channel count the cached window is built for. Inputs with
            a different channel count get a window built on the fly.
    """

    def __init__(self, window_size=11, size_average=True, channel=3):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel
        # Non-persistent buffer: it follows ``.to(device)`` with the module
        # (no host-to-device copy on every forward) but stays out of
        # ``state_dict``.
        self.register_buffer(
            "window", self.create_window(window_size, channel), persistent=False
        )

    def gaussian(self, window_size, sigma):
        """Return a normalized 1-D Gaussian of length ``window_size``."""
        center = window_size // 2
        weights = torch.tensor(
            [exp(-((x - center) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)],
            dtype=torch.float32,
        )
        return weights / weights.sum()

    def create_window(self, window_size, channel):
        """Return a ``(channel, 1, window_size, window_size)`` Gaussian kernel."""
        kernel_1d = self.gaussian(window_size, 1.5).unsqueeze(1)
        kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
        return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

    def ssim(self, img1, img2, window, size_average=True):
        """Compute SSIM of ``img1`` vs ``img2`` using the given window.

        The group count and padding are taken from ``window`` itself so that
        any per-channel window whose first dimension equals the images'
        channel count works.
        """
        groups = window.size(0)
        pad = window.size(-1) // 2

        mu1 = F.conv2d(img1, window, padding=pad, groups=groups)
        mu2 = F.conv2d(img2, window, padding=pad, groups=groups)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances: E[xy] - E[x]E[y] under the Gaussian window.
        sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=groups) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=groups) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=groups) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        ssim_map = numerator / denominator

        if size_average:
            return ssim_map.mean()
        # One value per batch element.
        return ssim_map.mean(dim=(1, 2, 3))

    def forward(self, img1, img2):
        """Return the SSIM of the two batches (scalar, or per-image if ``size_average`` is False)."""
        window = self.window
        channels = img1.size(1)
        if channels != window.size(0):
            # Input has a different channel count than the cached window.
            window = self.create_window(self.window_size, channels)
        # Match device and dtype of the input (no-op when already matching).
        window = window.to(device=img1.device, dtype=img1.dtype)
        return self.ssim(img1, img2, window, self.size_average)
162
+
163
def get_device():
    """Pick the compute device: MPS if available, else CUDA, else CPU."""
    # Checked in order of preference; the first available backend wins.
    backends = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for name, is_available in backends:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
170
+
171
def text_to_binary_tensor(text, height, width):
    """Convert text to a zero-padded ``(1, height, width)`` float tensor of bits.

    The text is UTF-8 encoded and each byte is written most-significant bit
    first. Positions after the message are filled with zeros.

    Args:
        text: Message to encode.
        height: Number of rows of the bit plane.
        width: Number of columns of the bit plane.

    Returns:
        A float32 CPU tensor of shape ``(1, height, width)`` containing 0.0/1.0.

    Raises:
        ValueError: If the encoded message needs more than ``height * width`` bits.
    """
    capacity = height * width
    payload = np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
    # unpackbits emits bits MSB-first, matching format(byte, '08b').
    bits = np.unpackbits(payload)
    if bits.size > capacity:
        raise ValueError(
            f"Message needs {bits.size} bits but a {height}x{width} plane "
            f"holds only {capacity}"
        )
    plane = np.zeros(capacity, dtype=np.float32)
    plane[:bits.size] = bits
    return torch.from_numpy(plane.reshape(1, height, width))
179
+
180
def binary_tensor_to_text(tensor):
    """Convert a tensor of bit values back to text.

    Values above 0.5 are read as 1, everything else as 0. Bits are grouped
    into bytes most-significant bit first, reading stops at the first zero
    byte, and the collected bytes are decoded as UTF-8. Trailing bits that
    do not fill a whole byte are ignored, and byte sequences that are not
    valid UTF-8 are dropped.

    Args:
        tensor: A torch tensor (any shape or device) or array-like of bit values.

    Returns:
        The decoded message, or an empty string if no bytes precede the
        first zero byte.
    """
    if isinstance(tensor, torch.Tensor):
        values = tensor.detach().float().cpu().numpy()
    else:
        values = np.asarray(tensor)
    bits = (values.reshape(-1) > 0.5).astype(np.uint8)

    # Only whole bytes are decodable.
    usable = bits.size - bits.size % 8
    payload = np.packbits(bits[:usable]).tobytes()

    # The zero padding after the message acts as a null terminator.
    payload = payload.split(b'\x00', 1)[0]

    # Decode the byte string as a whole: per-byte chr() would garble any
    # multi-byte UTF-8 character.
    return payload.decode('utf-8', errors='ignore')
198
+
199
def embed_message(model, image_path, message, output_path):
    """Embed a message into an image using the trained model.

    The image is converted to RGB, resized to 512x512 and concatenated with
    a 512x512 bit plane built from ``message``. The result of
    ``model.encode`` is saved to ``output_path`` as a PNG.

    Args:
        model: Network exposing ``encode``; switched to eval mode here.
        image_path: Path of the cover image.
        message: Text to hide.
        output_path: Destination path of the stego PNG.

    Returns:
        True once the image has been written.
    """
    device = get_device()
    # Load and preprocess image (512x512)
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    # Force RGB so grayscale/RGBA/palette files still give exactly three
    # channels; the context manager closes the file handle.
    with Image.open(image_path) as source:
        img = transform(source.convert('RGB')).unsqueeze(0).to(device)

    # Prepare message plane with a leading batch dimension: (1, 1, 512, 512)
    msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
    msg_tensor = msg_tensor.unsqueeze(0)

    # Concatenate image and message along the channel axis
    x = torch.cat([img, msg_tensor], dim=1)

    # Generate stego image
    model.eval()
    with torch.no_grad():
        stego_img = model.encode(x)

    # Clamp before the 8-bit conversion so no value can wrap around.
    stego_img = stego_img.squeeze(0).clamp(0.0, 1.0).cpu()
    transforms.ToPILImage()(stego_img).save(output_path, 'PNG')
    return True
225
+
226
def extract_message(model, image_path):
    """Extract hidden message from image using the trained model.

    The image is converted to RGB, resized to 512x512 and passed through
    ``model.decoder``. Outputs above 0.5 are read as 1-bits, packed into
    bytes most-significant bit first, cut at the first zero byte and decoded
    as UTF-8. Byte sequences that are not valid UTF-8 are dropped.

    Args:
        model: Network exposing ``decoder``; switched to eval mode here.
        image_path: Path of the stego image.

    Returns:
        The decoded text (possibly empty).
    """
    device = get_device()
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    # Force RGB so files with an alpha channel or palette still match the
    # decoder's three input channels; the context manager closes the file.
    with Image.open(image_path) as source:
        stego_img = transform(source.convert('RGB')).unsqueeze(0).to(device)

    # Extract message
    model.eval()
    with torch.no_grad():
        msg_tensor = model.decoder(stego_img)

    # Threshold once and pack the bits into bytes in a single vectorized step.
    bits = (msg_tensor > 0.5).flatten().cpu().numpy().astype(np.uint8)
    usable = bits.size - bits.size % 8  # only whole bytes are decodable
    payload = np.packbits(bits[:usable]).tobytes()

    # The zero padding after the message acts as a null terminator.
    payload = payload.split(b'\x00', 1)[0]

    # errors='ignore' keeps decoding best-effort when the decoder got bits wrong.
    return payload.decode('utf-8', errors='ignore')
269
+
270
def train_model(image_path, message, epochs=600):
    """Train the steganography model on a single image/message pair.

    Args:
        image_path: Path of the cover image (converted to RGB, resized to 512x512).
        message: Text to hide; it is turned into a 512x512 bit plane.
        epochs: Number of optimization steps.

    Returns:
        The trained ``SteganographyNet`` on the device chosen by ``get_device``.
    """
    device = get_device()
    # Message length in bits, counted on the UTF-8 bytes that actually get
    # embedded (not on characters).
    model = SteganographyNet(len(message.encode('utf-8')) * 8).to(device)
    model.train()

    # AdamW with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

    # Cosine annealing over the full run
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-6)

    mse_loss = nn.MSELoss()
    ssim_loss = SSIM().to(device)  # Structural Similarity Loss

    # Prepare data (512x512)
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    # Force RGB so the image always has three channels; the context manager
    # closes the file handle.
    with Image.open(image_path) as source:
        img = transform(source.convert('RGB')).unsqueeze(0).to(device)
    msg_tensor = text_to_binary_tensor(message, 512, 512).to(device)
    msg_tensor = msg_tensor.unsqueeze(0)

    # The network input never changes between steps, so build it once.
    x = torch.cat([img, msg_tensor], dim=1)

    # Training loop
    for epoch in range(epochs):
        # Forward pass
        stego_img = model.encode(x)
        recovered_msg = model.decoder(stego_img)

        # Image fidelity: mostly MSE plus a small SSIM term.
        image_loss = 0.95 * mse_loss(stego_img, img) + 0.05 * (1 - ssim_loss(stego_img, img))
        message_loss = mse_loss(recovered_msg, msg_tensor)
        # alpha ramps up linearly from 0 and is capped at 0.3, so the message
        # term is phased in while image quality keeps most of the weight.
        alpha = min(epoch / (epochs * 0.4), 0.3)
        total_loss = (1 - alpha) * image_loss + (alpha * 5) * message_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss.item():.4f}')

    return model
318
+
319
+ # Example usage
320
if __name__ == "__main__":
    # Cover image, destination of the stego image, and the text to hide.
    input_image = "steno_2(1).jpg"
    output_image = "decode_me_3.png"
    secret_message = "γ€Œη™½γη«‹γ‘, 道を瀺し, ε£°γͺしに, ζ—…δΊΊε°Žγ, 私はθͺ°οΌŸγ‚Έγƒ§γƒΌγƒ€γƒ³γ«η­”γˆγ‚’ι€γ£γ¦γγ γ‘γ„γ€‚γ€"

    # Fit the network to this image/message pair and persist its weights.
    model = train_model(input_image, secret_message)
    weights = model.state_dict()
    torch.save(weights, 'stego_model_3.pth')

    # Write the stego image.
    embed_message(model, input_image, secret_message, output_image)
    print("Message embedded successfully!")

    # Round trip: decode the image that was just written.
    extracted_message = extract_message(model, output_image)
    print(f"Extracted message: {extracted_message}")