minor bug + better defaults in test()
Browse files
cli.py
CHANGED
|
@@ -223,7 +223,7 @@ def train(*, dataset='mnist', folder='mnist', resume=False, model='convae', walk
|
|
| 223 |
nb_updates += 1
|
| 224 |
|
| 225 |
|
| 226 |
-
def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=
|
| 227 |
if not os.path.exists(folder):
|
| 228 |
os.makedirs(folder, exist_ok=True)
|
| 229 |
dataset = load_dataset(dataset, split='train')
|
|
@@ -235,6 +235,7 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
| 235 |
model_path = os.path.join(folder, "model.th")
|
| 236 |
ae = torch.load(model_path, map_location="cpu")
|
| 237 |
ae = ae.to(device)
|
|
|
|
| 238 |
def enc(X):
|
| 239 |
batch_size = 64
|
| 240 |
h_list = []
|
|
@@ -267,12 +268,12 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
| 267 |
np.savez('{}/generated.npz'.format(folder), X=g.numpy())
|
| 268 |
g_subset = g[:, 0:100]
|
| 269 |
gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
|
| 270 |
-
imsave('{}/gen_full_iters.png'.format(folder), gr)
|
| 271 |
|
| 272 |
g = g[-1] # last iter
|
| 273 |
print(g.shape)
|
| 274 |
gr = grid_of_images_default(g.numpy())
|
| 275 |
-
imsave('{}/gen_full.png'.format(folder), gr)
|
| 276 |
|
| 277 |
if tsne:
|
| 278 |
from sklearn.manifold import TSNE
|
|
@@ -300,13 +301,13 @@ def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=100, nb_gene
|
|
| 300 |
print('fit tsne...')
|
| 301 |
ah = sne.fit_transform(ah)
|
| 302 |
print('grid embedding...')
|
| 303 |
-
|
| 304 |
asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
|
| 305 |
ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
|
| 306 |
rows = grid_embedding(ahsmall)
|
| 307 |
asmall = asmall[rows]
|
| 308 |
gr = grid_of_images_default(asmall)
|
| 309 |
-
imsave('{}/sne_grid.png'.format(folder), gr)
|
| 310 |
|
| 311 |
fig = plt.figure(figsize=(10, 10))
|
| 312 |
plot_dataset(ah, labels)
|
|
|
|
| 223 |
nb_updates += 1
|
| 224 |
|
| 225 |
|
| 226 |
+
def test(*, dataset='mnist', folder='out', model_path=None, nb_iter=25, nb_generate=100, nb_active=160, tsne=False):
|
| 227 |
if not os.path.exists(folder):
|
| 228 |
os.makedirs(folder, exist_ok=True)
|
| 229 |
dataset = load_dataset(dataset, split='train')
|
|
|
|
| 235 |
model_path = os.path.join(folder, "model.th")
|
| 236 |
ae = torch.load(model_path, map_location="cpu")
|
| 237 |
ae = ae.to(device)
|
| 238 |
+
ae.nb_active = nb_active # for fc_sparse.th only
|
| 239 |
def enc(X):
|
| 240 |
batch_size = 64
|
| 241 |
h_list = []
|
|
|
|
| 268 |
np.savez('{}/generated.npz'.format(folder), X=g.numpy())
|
| 269 |
g_subset = g[:, 0:100]
|
| 270 |
gr = grid_of_images_default(g_subset.reshape((g_subset.shape[0]*g_subset.shape[1], h, w, 1)).numpy(), shape=(g_subset.shape[0], g_subset.shape[1]))
|
| 271 |
+
imsave('{}/gen_full_iters.png'.format(folder), (gr*255).astype("uint8") )
|
| 272 |
|
| 273 |
g = g[-1] # last iter
|
| 274 |
print(g.shape)
|
| 275 |
gr = grid_of_images_default(g.numpy())
|
| 276 |
+
imsave('{}/gen_full.png'.format(folder), (gr*255).astype("uint8") )
|
| 277 |
|
| 278 |
if tsne:
|
| 279 |
from sklearn.manifold import TSNE
|
|
|
|
| 301 |
print('fit tsne...')
|
| 302 |
ah = sne.fit_transform(ah)
|
| 303 |
print('grid embedding...')
|
| 304 |
+
assert nb_generate >= 450
|
| 305 |
asmall = np.concatenate((a[0:450], a[nb:nb + 450]), axis=0)
|
| 306 |
ahsmall = np.concatenate((ah[0:450], ah[nb:nb + 450]), axis=0)
|
| 307 |
rows = grid_embedding(ahsmall)
|
| 308 |
asmall = asmall[rows]
|
| 309 |
gr = grid_of_images_default(asmall)
|
| 310 |
+
imsave('{}/sne_grid.png'.format(folder), (gr*255).astype("uint8") )
|
| 311 |
|
| 312 |
fig = plt.figure(figsize=(10, 10))
|
| 313 |
plot_dataset(ah, labels)
|
viz.py
CHANGED
|
@@ -116,8 +116,8 @@ def grid_of_images(M, border=0, bordercolor=[0.0, 0.0, 0.0], shape=None, normali
|
|
| 116 |
height, width, color = M[0].shape
|
| 117 |
assert color == 3, 'Nb of color channels are {}'.format(color)
|
| 118 |
if shape is None:
|
| 119 |
-
n0 = np.
|
| 120 |
-
n1 = np.
|
| 121 |
else:
|
| 122 |
n0 = shape[0]
|
| 123 |
n1 = shape[1]
|
|
|
|
| 116 |
height, width, color = M[0].shape
|
| 117 |
assert color == 3, 'Nb of color channels are {}'.format(color)
|
| 118 |
if shape is None:
|
| 119 |
+
n0 = np.int32(np.ceil(np.sqrt(numimages)))
|
| 120 |
+
n1 = np.int32(np.ceil(np.sqrt(numimages)))
|
| 121 |
else:
|
| 122 |
n0 = shape[0]
|
| 123 |
n1 = shape[1]
|