You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

check_gm.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  #!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Check basic properties of gram matrices.

For every gram matrix stored under key ``'gms'`` in
``<results_dir>/<ds_name>.gm.npz`` this script

* saves a heat map of the matrix to ``<fig_dir>/<ds_name>.gm.<idx>.eps``;
* prints whether the matrix is symmetric, its diagonal, how many diagonal
  entries are below 0.1, the min/max of the diagonal and of the whole
  matrix, its mean and its smallest/largest eigenvalue;
* raises ``ValueError`` if an entry exceeds 1 (beyond a small tolerance) or
  if the smallest eigenvalue is clearly negative.

Created on Wed Sep 19 15:32:29 2018
@author: ljia
"""
import os

import numpy as np
from numpy.linalg import eigvalsh

# Default location of the gram matrices to check.
RESULTS_DIR = '../results/marginalizedkernel/myria'
DS_NAME = 'ENZYMES'
# Directory the heat maps are written to.
FIG_DIR = '../check_gm'


def check_gram_matrix(x, value_tol=1e-9, eig_tol=1e-10):
    """Print diagnostics for one gram matrix and validate it.

    Parameters
    ----------
    x : array_like
        Square 2-D matrix to check.
    value_tol : float
        An entry counts as "bigger than 1" only if it exceeds ``1 + value_tol``.
    eig_tol : float
        The check fails if the smallest eigenvalue is below ``-eig_tol``.

    Returns
    -------
    dict
        ``'symmetric'`` (exact equality with the transpose), ``'min_eig'``
        and ``'max_eig'`` (extreme eigenvalues of the symmetric part of ``x``).

    Raises
    ------
    ValueError
        If ``x`` is not a non-empty square 2-D matrix, contains NaN/inf, has
        an entry bigger than ``1 + value_tol`` (args: message, i, j), or has
        an eigenvalue smaller than ``-eig_tol``.
    """
    x = np.asarray(x)
    if x.ndim != 2 or x.shape[0] != x.shape[1] or x.size == 0:
        raise ValueError('gram matrix must be a non-empty square 2-D array, '
                         'got shape %s' % (x.shape,))
    if not np.all(np.isfinite(x)):
        raise ValueError('gram matrix contains NaN or infinite values.')

    diag = np.diag(x)
    symmetric = np.array_equal(x, x.T)
    print('if symmetric: ', symmetric)
    # Kernel values are floats, so exact equality can fail on rounding noise
    # alone; report a tolerant comparison as well.
    print('if symmetric (allclose): ', np.allclose(x, x.T))
    print('diag: ', diag)
    print('sum diag < 0.1: ', np.sum(diag < 0.1))
    print('min, max diag: ', diag.min(), diag.max())
    print('min, max matrix: ', x.min(), x.max())

    # Vectorised replacement for a double loop over (i, j): argwhere lists
    # indices in row-major order, so the first hit is the one a loop finds.
    too_big = np.argwhere(x > 1 + value_tol)
    if too_big.size:
        i, j = (int(k) for k in too_big[0])
        print(i, j)
        raise ValueError('value bigger than 1 with index', i, j)

    print('mean x: ', np.mean(x))

    # eigvalsh is meant for symmetric matrices: it returns real eigenvalues in
    # ascending order (eig may return complex ones, which do not order
    # sensibly). It reads only one triangle, so pass the symmetric part to
    # take both triangles into account.
    lam = eigvalsh((x + x.T) / 2)
    min_eig, max_eig = float(lam[0]), float(lam[-1])
    print('min, max lambda: ', min_eig, max_eig)
    if min_eig < -eig_tol:
        raise ValueError('wrong eigen values.')

    return {'symmetric': bool(symmetric), 'min_eig': min_eig,
            'max_eig': max_eig}


def save_heatmap(x, fname):
    """Save a heat map (with colour bar) of matrix ``x`` to ``fname`` as EPS."""
    # Imported lazily so the numerical checks work without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    image = ax.imshow(x)
    fig.colorbar(image, ax=ax)
    fig.savefig(fname, format='eps', dpi=300)
    # Close the figure so figures and colour bars do not accumulate.
    plt.close(fig)


def main(results_dir=RESULTS_DIR, ds_name=DS_NAME, fig_dir=FIG_DIR):
    """Load ``<results_dir>/<ds_name>.gm.npz`` and check every gram matrix."""
    gm_path = os.path.join(results_dir, ds_name + '.gm.npz')
    # Only 'gms' (a sequence of gram matrices) is used here. The archive may
    # hold other keys; earlier versions of this script referred to 'gmtime',
    # 'params' and 'y'.
    with np.load(gm_path) as gmfile:
        gram_matrices = gmfile['gms']

    os.makedirs(fig_dir, exist_ok=True)
    for idx, x in enumerate(gram_matrices):
        print()
        print(idx)
        # One file per matrix: a fixed name would be overwritten on every
        # iteration, keeping only the last heat map.
        save_heatmap(x, os.path.join(fig_dir, '%s.gm.%d.eps' % (ds_name, idx)))
        check_gram_matrix(x)


if __name__ == '__main__':
    main()

A Python package for graph kernels, graph edit distances, and graph pre-image problems.