mirror of
https://github.com/denshooter/gpu_colorization.git
synced 2026-01-21 12:32:57 +01:00
Adding new files
This commit is contained in:
45
Encoder.py
Normal file
45
Encoder.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import tensorflow as tf
|
||||
|
||||
class Encoder(tf.keras.Model): # <-- Needed to make parameters trainable and to be callable
|
||||
def __init__(self):
|
||||
super(Encoder, self).__init__()
|
||||
self.layer_list = [
|
||||
# input (243,243)
|
||||
|
||||
tf.keras.layers.Conv2D(75, kernel_size=(3, 3), strides=2, padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Activation(tf.nn.tanh),
|
||||
# -> (81, 81, 32)
|
||||
|
||||
tf.keras.layers.Conv2D(90, kernel_size=(3, 3), strides=2, padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Activation(tf.nn.tanh),
|
||||
# -> (27, 27, 64)
|
||||
|
||||
tf.keras.layers.Conv2D(105, kernel_size=(3, 3), strides=2, padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Activation(tf.nn.tanh),
|
||||
|
||||
|
||||
# bottleneck
|
||||
tf.keras.layers.Conv2D(3, kernel_size=(1, 1), strides=1, padding='same'),
|
||||
tf.keras.layers.BatchNormalization(),
|
||||
tf.keras.layers.Activation(tf.nn.tanh),
|
||||
|
||||
]
|
||||
|
||||
@tf.function
|
||||
def call(self, x, training):
|
||||
#print("encoder:")
|
||||
for layer in self.layer_list:
|
||||
#print(x.shape)
|
||||
if isinstance(layer, tf.keras.layers.BatchNormalization):
|
||||
x = layer(x,training)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
#print(x.shape)
|
||||
#print("-------------")
|
||||
#exit()
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user