import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from mpl_toolkits import mplot3d
from scipy.interpolate import griddata
from scipy.optimize import curve_fit
# Regular sampling grid: m points along x in [-6, 6], n points along y in [-8, 8].
m, n = 200, 300
x = np.linspace(-6, 6, m)
y = np.linspace(-8, 8, n)

# With meshgrid's default 'xy' indexing, both coordinate grids have shape (n, m).
x2, y2 = np.meshgrid(x, y)

# Flatten each grid into a single (1, m*n) row, then stack the two rows so
# that column k of xy holds the (x, y) coordinate of the k-th grid point.
x3, y3 = (grid.reshape(1, -1) for grid in (x2, y2))
xy = np.concatenate((x3, y3), axis=0)
def pfun(t, m1, m2, s):
    """Evaluate an isotropic 2-D Gaussian bump with peak value 1 at points *t*.

    Parameters
    ----------
    t : array_like
        Coordinates indexed as ``t[0]`` (x values) and ``t[1]`` (y values).
        In this script it is the stacked coordinate array ``xy``.
    m1, m2 : float
        Centre of the bump along x and along y.
    s : float
        Width shared by both axes. Only ``s**2`` enters the formula, so
        ``s`` and ``-s`` describe the same surface; ``s == 0`` divides by zero.

    Returns
    -------
    ndarray
        ``exp(-((x - m1)**2 + (y - m2)**2) / (2 * s**2))``, with the shape
        that ``t[0]`` and ``t[1]`` broadcast to.
    """
    return np.exp(-((t[0] - m1) ** 2 + (t[1] - m2) ** 2) / (2 * s ** 2))
# Synthetic data: the noiseless surface is pfun with centre (1, 2) and
# width 3, observed with additive Gaussian noise of standard deviation 0.2.
true_params = (1, 2, 3)
noise_std = 0.2
# A seeded generator makes the noisy sample -- and therefore the fitted
# parameters and the resulting plot -- reproducible from run to run.
rng = np.random.default_rng(0)
z = pfun(xy, *true_params)
zr = z + noise_std * rng.normal(size=z.shape)

# Least-squares fit of pfun's (m1, m2, s) to the noisy samples. No p0 is
# passed, so curve_fit starts from an initial guess of all ones.
# popt is expected to land near true_params; pcov is its estimated covariance.
popt, pcov = curve_fit(pfun, xy, zr)

# Evaluate the fitted model at every grid point and fold the flat result
# back into the meshgrid's shape so it can be drawn as a surface.
zn = pfun(xy, *popt)
zn2 = np.reshape(zn, x2.shape)

plt.rc('font', size=16)
# projection='3d' relies on mpl_toolkits.mplot3d, imported at the top of the file.
ax = plt.axes(projection='3d')
ax.plot_surface(x2, y2, zn2, cmap='gist_rainbow')
plt.show()
# Disabled scratch example, kept as an unused string literal: np.reshape
# turns the 4x6 array `a` into 6 rows, with -1 letting numpy infer the
# column count from the 24 elements, so `b` would have shape (6, 4).
'''
a = np.array([[1, 2, 3, 4, 5, 6],
[7, 8, 9, 10, 11, 12],
[13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24]])
b=np.reshape(a,(6,-1))
print(b)
'''
我还掌握了 numpy 库中的 reshape 函数（例如 np.reshape(a, (6, -1)) 把 4×6 的数组重排为 6×4，其中 -1 表示该维长度由元素总数自动推断），收获很大。
最后的拟合结果如下：