Hands-On Deep Learning: the ESPCN Super-Resolution Algorithm (TensorFlow)

2024-11-22 08:11:42

For the theory behind ESPCN, see the previous post on the paper; this post focuses on the implementation.

For the dataset, any collection of images will do as long as they are of equal size.

  • Code to generate training samples from the dataset (the params.json it reads is sketched after the listing)
  • prepare_data.py

```python
import imageio
from scipy import ndimage
import numpy as np
import imghdr
import shutil
import os
import json

# RGB <-> YCbCr conversion matrices (ITU-R BT.601) and offset
mat = np.array(
    [[ 65.481, 128.553,  24.966],
     [-37.797, -74.203, 112.0  ],
     [112.0,   -93.786, -18.214]])
mat_inv = np.linalg.inv(mat)
offset = np.array([16, 128, 128])

def rgb2ycbcr(rgb_img):
    ycbcr_img = np.zeros(rgb_img.shape, dtype=np.uint8)
    for x in range(rgb_img.shape[0]):
        for y in range(rgb_img.shape[1]):
            ycbcr_img[x, y, :] = np.round(np.dot(mat, rgb_img[x, y, :] * 1.0 / 255) + offset)
    return ycbcr_img

def ycbcr2rgb(ycbcr_img):
    rgb_img = np.zeros(ycbcr_img.shape, dtype=np.uint8)
    for x in range(ycbcr_img.shape[0]):
        for y in range(ycbcr_img.shape[1]):
            rgb_img[x, y, :] = np.maximum(0, np.minimum(255, np.round(np.dot(mat_inv, ycbcr_img[x, y, :] - offset) * 255.0)))
    return rgb_img

def my_anti_shuffle(input_image, ratio):
    # Inverse pixel shuffle: (H, W, C) -> (H/ratio, W/ratio, C*ratio^2),
    # packing each ratio x ratio sub-pixel grid into its own channel
    shape = input_image.shape
    ori_height = int(shape[0])
    ori_width = int(shape[1])
    ori_channels = int(shape[2])
    if ori_height % ratio != 0 or ori_width % ratio != 0:
        print("Error! Height and width must be divisible by ratio!")
        return
    height = ori_height // ratio
    width = ori_width // ratio
    channels = ori_channels * ratio * ratio
    anti_shuffle = np.zeros((height, width, channels), dtype=np.uint8)
    for c in range(0, ori_channels):
        for x in range(0, ratio):
            for y in range(0, ratio):
                anti_shuffle[:, :, c * ratio * ratio + x * ratio + y] = input_image[x::ratio, y::ratio, c]
    return anti_shuffle

def shuffle(input_image, ratio):
    # Pixel shuffle: (H, W, C*ratio^2) -> (H*ratio, W*ratio, C)
    shape = input_image.shape
    height = int(shape[0]) * ratio
    width = int(shape[1]) * ratio
    channels = int(shape[2]) // ratio // ratio
    shuffled = np.zeros((height, width, channels), dtype=np.uint8)
    for i in range(0, height):
        for j in range(0, width):
            for k in range(0, channels):
                shuffled[i, j, k] = input_image[i // ratio, j // ratio, k * ratio * ratio + (i % ratio) * ratio + (j % ratio)]
    return shuffled

def prepare_images(params):
    ratio, training_num, lr_stride, lr_size = params['ratio'], params['training_num'], params['lr_stride'], params['lr_size']
    hr_stride = lr_stride * ratio
    hr_size = lr_size * ratio

    # First clear old images and create new directories
    for ele in ['training', 'validation', 'test']:
        new_dir = params[ele + '_image_dir'].format(ratio)
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        for sub_dir in ['/hr', '/lr']:  # the original had ['/hr', 'lr']
            os.makedirs(new_dir + sub_dir)

    image_num = 0
    folder = params['training_image_dir'].format(ratio)
    for root, dirnames, filenames in os.walk(params['image_dir']):
        for filename in filenames:
            path = os.path.join(root, filename)
            if imghdr.what(path) != 'jpeg':
                continue

            # Crop the HR image so its sides are divisible by the ratio
            hr_image = imageio.imread(path)
            height = hr_image.shape[0]
            new_height = height - height % ratio
            width = hr_image.shape[1]
            new_width = width - width % ratio
            hr_image = hr_image[0:new_height, 0:new_width]
            # Gaussian blur, then subsample, to obtain the LR image
            blurred = ndimage.gaussian_filter(hr_image, sigma=(1, 1, 0))
            lr_image = blurred[::ratio, ::ratio, :]

            height = hr_image.shape[0]
            width = hr_image.shape[1]
            vertical_number = height / hr_stride - 1
            horizontal_number = width / hr_stride - 1
            image_num = image_num + 1
            if image_num % 10 == 0:
                print("Finished image: {}".format(image_num))
            if image_num > training_num and image_num <= training_num + params['validation_num']:
                folder = params['validation_image_dir'].format(ratio)
            elif image_num > training_num + params['validation_num']:
                folder = params['test_image_dir'].format(ratio)
            # Cut overlapping sub-images on a regular grid and save them
            for x in range(0, int(horizontal_number)):
                for y in range(0, int(vertical_number)):
                    hr_sub_image = hr_image[y * hr_stride : y * hr_stride + hr_size, x * hr_stride : x * hr_stride + hr_size]
                    lr_sub_image = lr_image[y * lr_stride : y * lr_stride + lr_size, x * lr_stride : x * lr_stride + lr_size]
                    imageio.imwrite("{}hr/{}_{}_{}.png".format(folder, filename[0:-4], y, x), hr_sub_image)
                    imageio.imwrite("{}lr/{}_{}_{}.png".format(folder, filename[0:-4], y, x), lr_sub_image)
            if image_num >= training_num + params['validation_num'] + params['test_num']:
                break
        else:
            continue
        break

def prepare_data(params):
    ratio = params['ratio']
    params['hr_stride'] = params['lr_stride'] * ratio
    params['hr_size'] = params['lr_size'] * ratio

    for ele in ['training', 'validation', 'test']:
        new_dir = params[ele + '_dir'].format(ratio)
        if os.path.isdir(new_dir):
            shutil.rmtree(new_dir)
        os.makedirs(new_dir)

    ratio, lr_size, edge = params['ratio'], params['lr_size'], params['edge']
    image_dirs = [d.format(ratio) for d in [params['training_image_dir'], params['validation_image_dir'], params['test_image_dir']]]
    data_dirs = [d.format(ratio) for d in [params['training_dir'], params['validation_dir'], params['test_dir']]]
    hr_start_idx = ratio * edge // 2
    hr_end_idx = hr_start_idx + (lr_size - edge) * ratio
    sub_hr_size = (lr_size - edge) * ratio
    for dir_idx, image_dir in enumerate(image_dirs):
        data_dir = data_dirs[dir_idx]
        print("Creating {}".format(data_dir))
        for root, dirnames, filenames in os.walk(image_dir + "/lr"):
            for filename in filenames:
                lr_path = os.path.join(root, filename)
                hr_path = image_dir + "/hr/" + filename
                lr_image = imageio.imread(lr_path)
                hr_image = imageio.imread(hr_path)
                # Convert to the YCbCr color space
                lr_image_y = rgb2ycbcr(lr_image)
                hr_image_y = rgb2ycbcr(hr_image)
                lr_data = lr_image_y.reshape((lr_size * lr_size * 3))
                # Crop the HR patch so that, after anti-shuffle, it matches the
                # network output (edge pixels are lost to the 'VALID' convolutions)
                sub_hr_image_y = hr_image_y[int(hr_start_idx):int(hr_end_idx):1, int(hr_start_idx):int(hr_end_idx):1]
                hr_data = my_anti_shuffle(sub_hr_image_y, ratio).reshape(sub_hr_size * sub_hr_size * 3)
                data = np.concatenate([lr_data, hr_data])
                data.astype('uint8').tofile(data_dir + "/" + filename[0:-4])

def remove_images(params):
    # The split image folders are no longer needed once the binary data is written
    for ele in ['training', 'validation', 'test']:
        rm_dir = params[ele + '_image_dir'].format(params['ratio'])
        if os.path.isdir(rm_dir):
            shutil.rmtree(rm_dir)


if __name__ == '__main__':
    with open("./params.json", 'r') as f:
        params = json.load(f)

    print("Preparing images with scaling ratio: {}".format(params['ratio']))
    print("If you want a different ratio change 'ratio' in params.json")
    print("Splitting images (1/3)")
    prepare_images(params)

    print("Preparing data, this may take a while (2/3)")
    prepare_data(params)

    print("Cleaning up split images (3/3)")
    remove_images(params)
    print("Done, you can now train the model!")
```
  • generate.py

```python
import argparse
from PIL import Image
import imageio
import tensorflow as tf
import numpy as np
from prepare_data import *
from psnr import psnr
import json

from espcn import ESPCN

def get_arguments():
    parser = argparse.ArgumentParser(description='EspcnNet generation script')
    parser.add_argument('--checkpoint', type=str, default="logdir_2x/train",
                        help='Which model checkpoint to generate from')
    parser.add_argument('--lr_image', type=str, default="images/butterfly_GT.jpg",
                        help='The low-resolution image to be processed')
    parser.add_argument('--hr_image', type=str,
                        help='The high-resolution image used to calculate PSNR')
    parser.add_argument('--out_path', type=str, default="result/butterfly_HR",
                        help='The output path for the super-resolution image')
    return parser.parse_args()

def check_params(args, params):
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater than the length of 'channels' by 1.")
        return False
    return True

def generate():
    args = get_arguments()

    with open("./params.json", 'r') as f:
        params = json.load(f)

    if not check_params(args, params):
        return

    sess = tf.Session()

    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=1,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()

    # The network super-resolves only the Y (luma) channel;
    # Cb/Cr are upscaled separately by bicubic interpolation below
    lr_image = tf.placeholder(tf.uint8)
    lr_image_data = imageio.imread(args.lr_image)
    lr_image_ycbcr_data = rgb2ycbcr(lr_image_data)
    lr_image_y_data = lr_image_ycbcr_data[:, :, 0:1]
    lr_image_batch = np.zeros((1,) + lr_image_y_data.shape, dtype=np.uint8)
    lr_image_batch[0] = lr_image_y_data

    sr_image = net.generate(lr_image)

    saver = tf.train.Saver()
    try:
        model_loaded = net.load(sess, saver, args.checkpoint)
    except:
        raise Exception("Failed to load model, does the ratio in params.json match the ratio you trained your checkpoint with?")

    if model_loaded:
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")
        return

    sr_image_y_data = sess.run(sr_image, feed_dict={lr_image: lr_image_batch})

    # Rearrange the r^2 output channels into the upscaled Y image
    sr_image_y_data = shuffle(sr_image_y_data[0], params['ratio'])
    # Bicubic upscaling of the full YCbCr image; PIL expects a (width, height) size
    new_size = tuple(params['ratio'] * np.array(lr_image_data.shape[0:2])[::-1])
    sr_image_ycbcr_data = np.array(Image.fromarray(lr_image_ycbcr_data).resize(new_size, Image.BICUBIC))

    # Pixels trimmed from each side by the 'VALID' convolutions, in HR units
    edge = params['edge'] * params['ratio'] // 2

    sr_image_ycbcr_data = np.concatenate((sr_image_y_data, sr_image_ycbcr_data[edge:-edge, edge:-edge, 1:3]), axis=2)
    sr_image_data = ycbcr2rgb(sr_image_ycbcr_data)

    imageio.imwrite(args.out_path + '.png', sr_image_data)

    if args.hr_image is not None:
        hr_image_data = imageio.imread(args.hr_image)
        model_psnr = psnr(hr_image_data, sr_image_data, edge)
        print('PSNR of the model: {:.2f}dB'.format(model_psnr))

        # Plain bicubic baseline for comparison
        sr_image_bicubic_data = np.array(Image.fromarray(lr_image_data).resize(new_size, Image.BICUBIC))
        imageio.imwrite(args.out_path + '_bicubic.png', sr_image_bicubic_data)
        bicubic_psnr = psnr(hr_image_data, sr_image_bicubic_data, 0)
        print('PSNR of Bicubic: {:.2f}dB'.format(bicubic_psnr))


if __name__ == '__main__':
    generate()
```
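With a trained checkpoint in place, inference is a single command (the values shown are just the script's own defaults):

```bash
python generate.py --lr_image images/butterfly_GT.jpg \
                   --checkpoint logdir_2x/train \
                   --out_path result/butterfly_HR
```

Passing --hr_image as well prints the model's PSNR next to a plain bicubic baseline.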

  • train.py

```python
from __future__ import print_function
import argparse
import os
import json
import time

import tensorflow as tf
from reader import create_inputs
from espcn import ESPCN

# Python 3 compatibility: xrange no longer exists
try:
    xrange
except NameError:
    xrange = range

# Batch size
BATCH_SIZE = 32
# Number of epochs
NUM_EPOCHS = 100
# Learning rate
LEARNING_RATE = 0.0001
# Log directory template
LOGDIR_ROOT = './logdir_{}x'

def get_arguments():
    parser = argparse.ArgumentParser(description='EspcnNet example network')
    # Checkpoint to resume from
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Which model checkpoint to load from')
    # Batch size
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many image files to process at once.')
    # Number of epochs
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS,
                        help='Number of epochs.')
    # Learning rate
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training.')
    # Root log directory
    parser.add_argument('--logdir_root', type=str, default=LOGDIR_ROOT,
                        help='Root directory to place the logging '
                             'output and generated model. These are stored '
                             'under the dated subdirectory of --logdir_root. '
                             'Cannot use with --logdir.')
    return parser.parse_args()

def check_params(args, params):
    if len(params['filters_size']) - len(params['channels']) != 1:
        print("The length of 'filters_size' must be greater than the length of 'channels' by 1.")
        return False
    return True

def train():
    args = get_arguments()
    # Load hyper-parameters from params.json
    with open("./params.json", 'r') as f:
        params = json.load(f)
    if not check_params(args, params):
        return

    logdir_root = args.logdir_root
    if logdir_root == LOGDIR_ROOT:
        logdir_root = logdir_root.format(params['ratio'])  # ./logdir_{RATIO}x
    logdir = os.path.join(logdir_root, 'train')            # ./logdir_{RATIO}x/train

    # Load training data as np arrays
    lr_images, hr_labels = create_inputs(params)

    # Build the network
    net = ESPCN(filters_size=params['filters_size'],
                channels=params['channels'],
                ratio=params['ratio'],
                batch_size=args.batch_size,
                lr_size=params['lr_size'],
                edge=params['edge'])

    loss, images, labels = net.build_model()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)

    # Set up logging for TensorBoard
    writer = tf.summary.FileWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    summaries = tf.summary.merge_all()

    # Set up the session
    sess = tf.Session()

    # Saver for storing/restoring checkpoints of the model
    saver = tf.train.Saver()

    init = tf.global_variables_initializer()
    sess.run(init)

    if net.load(sess, saver, logdir):
        print("[*] Checkpoint load success!")
    else:
        print("[*] Checkpoint load failed/no checkpoint found")

    try:
        steps, start_average, end_average = 0, 0, 0
        start_time = time.time()
        for ep in xrange(1, args.epochs + 1):
            batch_idxs = len(lr_images) // args.batch_size
            batch_average = 0
            for idx in xrange(0, batch_idxs):
                # On-the-fly batch generation instead of a queue, to optimize GPU usage
                batch_images = lr_images[idx * args.batch_size : (idx + 1) * args.batch_size]
                batch_labels = hr_labels[idx * args.batch_size : (idx + 1) * args.batch_size]

                steps += 1
                summary, loss_value, _ = sess.run([summaries, loss, optim],
                                                  feed_dict={images: batch_images, labels: batch_labels})
                writer.add_summary(summary, steps)
                batch_average += loss_value

            # Compare the loss of the first 20% and the last 20% of epochs
            batch_average = float(batch_average) / batch_idxs
            if ep < (args.epochs * 0.2):
                start_average += batch_average
            elif ep >= (args.epochs * 0.8):
                end_average += batch_average

            duration = time.time() - start_time
            print('Epoch: {}, step: {:d}, loss: {:.9f}, ({:.3f} sec/epoch)'.format(ep, steps, batch_average, duration))
            start_time = time.time()
            net.save(sess, saver, logdir, steps)
    except KeyboardInterrupt:
        print()
    finally:
        start_average = float(start_average) / (args.epochs * 0.2)
        end_average = float(end_average) / (args.epochs * 0.2)
        print("Start Average: [%.6f], End Average: [%.6f], Improved: [%.2f%%]"
              % (start_average, end_average, 100 - (100 * end_average / start_average)))

if __name__ == '__main__':
    train()
```
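Training then runs with the defaults defined at the top of the script; any of them can be overridden on the command line:

```bash
python train.py --batch_size 32 --epochs 100 --learning_rate 0.0001
```

Checkpoints and TensorBoard summaries are written to ./logdir_{ratio}x/train, which is also where generate.py looks by default.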

  • espcn.py, the model (TensorFlow implementation)

```python
import tensorflow as tf
import os
import sys

def create_variable(name, shape):
    '''Create a convolution filter variable with the specified name and shape,
    and initialize it using Xavier initialization.'''
    initializer = tf.contrib.layers.xavier_initializer_conv2d()
    variable = tf.Variable(initializer(shape=shape), name=name)
    return variable

def create_bias_variable(name, shape):
    '''Create a bias variable with the specified name and shape and initialize
    it to zero.'''
    initializer = tf.constant_initializer(value=0.0, dtype=tf.float32)
    return tf.Variable(initializer(shape=shape), name=name)  # the original passed name positionally (as trainable)

class ESPCN:
    def __init__(self, filters_size, channels, ratio, batch_size, lr_size, edge):
        self.filters_size = filters_size
        self.channels = channels
        self.ratio = ratio
        self.batch_size = batch_size
        self.lr_size = lr_size
        self.edge = edge
        self.variables = self.create_variables()

    def create_variables(self):
        var = dict()
        var['filters'] = list()
        # The input layer (a single input channel: the Y component)
        var['filters'].append(
            create_variable('filter',
                            [self.filters_size[0],
                             self.filters_size[0],
                             1,
                             self.channels[0]]))
        # The hidden layers
        for idx in range(1, len(self.filters_size) - 1):
            var['filters'].append(
                create_variable('filter',
                                [self.filters_size[idx],
                                 self.filters_size[idx],
                                 self.channels[idx - 1],
                                 self.channels[idx]]))
        # The output layer: ratio^2 channels, one per sub-pixel position
        var['filters'].append(
            create_variable('filter',
                            [self.filters_size[-1],
                             self.filters_size[-1],
                             self.channels[-1],
                             self.ratio**2]))

        var['biases'] = list()
        for channel in self.channels:
            var['biases'].append(create_bias_variable('bias', [channel]))
        # NOTE: the float shape below is almost certainly what triggers the
        # "bias is not supported" error mentioned at the end of the post;
        # [self.ratio**2] would be the correct integer shape
        var['biases'].append(create_bias_variable('bias', [float(self.ratio)**2]))

        image_shape = (self.batch_size, self.lr_size, self.lr_size, 3)
        var['images'] = tf.placeholder(tf.uint8, shape=image_shape, name='images')
        label_shape = (self.batch_size, self.lr_size - self.edge, self.lr_size - self.edge, 3 * self.ratio**2)
        var['labels'] = tf.placeholder(tf.uint8, shape=label_shape, name='labels')

        return var

    def build_model(self):
        images, labels = self.variables['images'], self.variables['labels']
        input_images, input_labels = self.preprocess([images, labels])
        output = self.create_network(input_images)
        reduced_loss = self.loss(output, input_labels)
        return reduced_loss, images, labels

    def save(self, sess, saver, logdir, step):
        sys.stdout.flush()
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        checkpoint = os.path.join(logdir, "model.ckpt")
        saver.save(sess, checkpoint, global_step=step)

    def load(self, sess, saver, logdir):
        print("[*] Reading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(logdir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(logdir, ckpt_name))
            return True
        else:
            return False

    def preprocess(self, input_data):
        # Cast to float32 and normalize the data to [0, 1]
        input_list = list()
        for ele in input_data:
            if ele is None:
                continue
            ele = tf.cast(ele, tf.float32) / 255.0
            input_list.append(ele)

        # Only the Y channel is fed to the network;
        # generate() doesn't use input_labels
        input_images, input_labels = input_list[0][:, :, :, 0:1], None
        ratioSquare = self.ratio * self.ratio
        if input_data[1] is not None:
            input_labels = input_list[1][:, :, :, 0:ratioSquare]
        return input_images, input_labels

    def create_network(self, input_images):
        '''The default structure of the network is:

        input (1 channel, Y) ---> 5x5 conv (64 channels) ---> 3x3 conv (32 channels) ---> 3x3 conv (r^2 channels)

        where `conv` is a 2D convolution with a non-linear activation (tanh)
        after every layer except the last.
        '''
        current_layer = input_images
        for idx in range(len(self.filters_size)):
            conv = tf.nn.conv2d(current_layer, self.variables['filters'][idx], [1, 1, 1, 1], padding='VALID')
            with_bias = tf.nn.bias_add(conv, self.variables['biases'][idx])
            if idx == len(self.filters_size) - 1:
                current_layer = with_bias
            else:
                current_layer = tf.nn.tanh(with_bias)
        return current_layer

    def loss(self, output, input_labels):
        # Mean squared error against the anti-shuffled HR labels
        residual = output - input_labels
        loss = tf.square(residual)
        reduced_loss = tf.reduce_mean(loss)
        tf.summary.scalar('loss', reduced_loss)
        return reduced_loss

    def generate(self, lr_image):
        lr_image = self.preprocess([lr_image, None])[0]
        sr_image = self.create_network(lr_image)
        sr_image = sr_image * 255.0
        sr_image = tf.cast(sr_image, tf.int32)  # the original's tf.cast(sr_image, 32) is not a valid dtype
        sr_image = tf.maximum(sr_image, 0)
        sr_image = tf.minimum(sr_image, 255)
        sr_image = tf.cast(sr_image, tf.uint8)
        return sr_image
```
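The shuffle / my_anti_shuffle pair in prepare_data.py is the periodic-shuffling (sub-pixel) operator that turns the r^2 output channels back into an image r times larger. As a quick sanity check (my own snippet, not part of the original post, and assuming TF 1.x like the rest of the code): for the single-channel Y images used here, the hand-rolled shuffle coincides with TensorFlow's built-in tf.depth_to_space:

```python
import numpy as np
import tensorflow as tf
from prepare_data import my_anti_shuffle, shuffle

r = 2
y = np.random.randint(0, 256, size=(8, 8, 1), dtype=np.uint8)

# Round trip: packing the r x r sub-pixel grids into channels and
# shuffling them back must reproduce the original image
packed = my_anti_shuffle(y, r)                 # shape (4, 4, 4)
assert np.array_equal(shuffle(packed, r), y)

# With one output channel, the numpy shuffle matches tf.depth_to_space
with tf.Session() as sess:
    tf_out = sess.run(tf.depth_to_space(packed[np.newaxis].astype(np.float32), r))
assert np.array_equal(tf_out[0].astype(np.uint8), y)
```

With more than one channel the two differ in channel ordering (the numpy version groups all r^2 sub-pixel positions of a channel together), so they are only interchangeable in the Y-only setting this network uses.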
  • reader.py, which loads the prepared training data

```python
import numpy as np
import os

def create_inputs(params):
    """
    Loads the prepared training files and appends them as np arrays to a list.
    This approach is better because a FIFOQueue with a reader can't utilize
    the GPU while this approach can.
    """
    lr_images, hr_labels = [], []
    training_dir = params['training_dir'].format(params['ratio'])

    # Raise an exception if the user has not run prepare_data.py yet
    if not os.path.isdir(training_dir):
        raise Exception("You must first run prepare_data.py before you can train")

    lr_size = params['lr_size']
    lr_shape = (lr_size, lr_size, 3)
    hr_shape = (lr_size - params['edge'], lr_size - params['edge'], 3 * params['ratio']**2)
    for file in os.listdir(training_dir):
        train_file = open("{}/{}".format(training_dir, file), "rb")
        train_data = np.fromfile(train_file, dtype=np.uint8)

        # Each binary file is the LR patch followed by the anti-shuffled HR patch.
        # (The original hard-coded 17 * 17 * 3 here, which only works when lr_size == 17.)
        lr_image = train_data[:lr_size * lr_size * 3].reshape(lr_shape)
        lr_images.append(lr_image)

        hr_label = train_data[lr_size * lr_size * 3:].reshape(hr_shape)
        hr_labels.append(hr_label)

    return lr_images, hr_labels
```
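With the illustrative params.json above (lr_size 17, edge 8, ratio 2), each training file should hold 17*17*3 = 867 LR bytes followed by 9*9*12 = 972 label bytes. A quick way to confirm the pipeline wrote what the reader expects (my own snippet, not in the original post):

```python
import os
import json

with open("./params.json") as f:
    p = json.load(f)

# Expected record layout: LR patch bytes + anti-shuffled HR label bytes
lr_bytes = p['lr_size'] ** 2 * 3
hr_bytes = (p['lr_size'] - p['edge']) ** 2 * 3 * p['ratio'] ** 2
training_dir = p['training_dir'].format(p['ratio'])
sample = os.path.join(training_dir, os.listdir(training_dir)[0])
assert os.path.getsize(sample) == lr_bytes + hr_bytes, "unexpected record size"
```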

  • psnr.py, the PSNR calculation

```python
import numpy as np
import math

def psnr(hr_image, sr_image, hr_edge):
    # Assumes RGB images; crop hr_edge pixels off the HR border before comparing
    hr_image_data = np.array(hr_image)
    if hr_edge > 0:
        hr_image_data = hr_image_data[hr_edge:-hr_edge, hr_edge:-hr_edge].astype('float32')

    sr_image_data = np.array(sr_image).astype('float32')

    diff = sr_image_data - hr_image_data
    diff = diff.flatten('C')
    rmse = math.sqrt(np.mean(diff ** 2.))
    return 20 * math.log10(255.0 / rmse)
```
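For reference, this is the standard peak signal-to-noise ratio for 8-bit images, comparing the super-resolved image $I^{SR}$ against the ground truth $I^{HR}$ over $N$ pixels:

$$\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\left(I^{SR}_i - I^{HR}_i\right)^2, \qquad \mathrm{PSNR} = 20\log_{10}\frac{255}{\sqrt{\mathrm{MSE}}}\ \mathrm{dB}$$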

There is a bug during training ("bias is not supported", most likely the float bias shape flagged in the NOTE inside create_variables above), but the network still learns.


Copyright notice: this article is a third-party contribution or authorized repost. Original: https://blog.51cto.com/u_13859040/5814435, author: qq5b42bed9cc7e9; copyright belongs to the original author. It is reposted here to share information; this site does not hold the copyright and assumes no legal liability. For content or copyright issues, contact ctyunbbs@chinatelecom.cn.
