| """Functions for downloading and reading MNIST data.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| import gzip |
| import os |
| import tensorflow.python.platform |
| import numpy |
| from six.moves import urllib |
| from six.moves import xrange |
| import tensorflow as tf |
| SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' |
| |
def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath

def _read32(bytestream):
  # The IDX header fields are stored as big-endian unsigned 32-bit integers.
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

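# Shape sketch (illustrative, assuming the standard MNIST archives):
#   train_images = extract_images('MNIST_data/train-images-idx3-ubyte.gz')
#   train_images.shape  # -> (60000, 28, 28, 1), dtype uint8
# The t10k images file yields a first dimension of 10000 instead.
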
def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot

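# Worked example (illustrative only):
#   dense_to_one_hot(numpy.array([0, 2]), num_classes=3)
#   -> [[1., 0., 0.],
#       [0., 0., 1.]]
# Each row holds a single 1 at the index given by the original label.
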
def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels

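# Shape sketch (illustrative): with the standard training label file,
# extract_labels returns a uint8 array of shape (60000,); with one_hot=True
# it returns the dense_to_one_hot result of shape (60000, 10).
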
class DataSet(object):

  def __init__(self, images, labels, fake_data=False, one_hot=False,
               dtype=tf.float32):
    """Construct a DataSet.

    one_hot arg is used only if fake_data is true. `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    dtype = tf.as_dtype(dtype).base_dtype
    if dtype not in (tf.uint8, tf.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
      self._num_examples = images.shape[0]

      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1).
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      if dtype == tf.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch.
      self._epochs_completed += 1
      # Shuffle the data.
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch.
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]
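
  # Behavior sketch (illustrative, not part of the original code): with
  # num_examples == 250 and batch_size == 100, the first two calls return
  # examples [0:100] and [100:200]; the third call crosses the epoch
  # boundary, so epochs_completed becomes 1, the data is reshuffled, and the
  # batch [0:100] of the new permutation is returned. Note the last 50
  # examples of the finished epoch are skipped rather than batched.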

def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
  class DataSets(object):
    pass
  data_sets = DataSets()

  if fake_data:
    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
    data_sets.train = fake()
    data_sets.validation = fake()
    data_sets.test = fake()
    return data_sets

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
  data_sets.validation = DataSet(validation_images, validation_labels,
                                 dtype=dtype)
  data_sets.test = DataSet(test_images, test_labels, dtype=dtype)

  return data_sets
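
# Usage sketch (not part of the original module; assumes network access to
# SOURCE_URL and a writable local directory named 'MNIST_data'):
if __name__ == '__main__':
  mnist = read_data_sets('MNIST_data', one_hot=True)
  print('train:', mnist.train.num_examples,
        'validation:', mnist.validation.num_examples,
        'test:', mnist.test.num_examples)
  batch_xs, batch_ys = mnist.train.next_batch(100)
  print('batch shapes:', batch_xs.shape, batch_ys.shape)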