Skip to content
Snippets Groups Projects
Commit 1375a91e authored by gruczela's avatar gruczela
Browse files

Update 1D_wave.py

parent 0e78dbc0
No related branches found
No related tags found
1 merge request!14Update 1D_wave.py
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import matplotlib.pylab as plt
import numpy as np
import time
import os
from numba import jit, njit, prange
# In[2]:
def set_pos():
xmin = 0.0
xmax = 10.0
......@@ -25,10 +15,6 @@ def set_pos():
nx, gamma, x, dx = set_pos()
# In[3]:
def set_time():
tmin = 0.0
tmax = 10.0
......@@ -39,10 +25,6 @@ def set_time():
nt, tgrid, dt = set_time()
# In[4]:
def velocity():
v = np.zeros(nx)
dvdt = np.zeros(nx)
......@@ -50,10 +32,6 @@ def velocity():
v, dvdt = velocity()
# In[5]:
y = np.zeros(nx)
y = np.exp(-(np.array(x)-5.0)**2)
......@@ -73,10 +51,6 @@ user_time, sys_time, _, _, _ = os.times()
print("User time: {:.3f} seconds".format(user_time))
print("System time: {:.3f} seconds".format(sys_time))
# In[6]:
plt.plot(x, y, label = "Position")
plt.plot(x, v, label = "Velocity")
plt.plot(x, dvdt, label = "Acceleration")
......@@ -85,10 +59,6 @@ plt.xlabel('Position')
plt.ylabel('Amplitude')
plt.legend()
# In[7]:
@njit(parallel=True)
def fast_pos():
xmin = 0.0
......@@ -100,10 +70,6 @@ def fast_pos():
return nx, gamma, x, dx
nx, gamma, x, dx = fast_pos()
# In[8]:
@njit(parallel=True)
def fast_time():
tmin = 0.0
......@@ -115,10 +81,6 @@ def fast_time():
nt, tgrid, dt = fast_time()
# In[9]:
@njit(parallel=True)
def fast_velocity():
v = np.zeros(nx)
......@@ -127,10 +89,6 @@ def fast_velocity():
v, dvdt = fast_velocity()
# In[10]:
@njit(parallel=True)
def fast_numerical_sol():
nx, gamma, x, dx = fast_pos()
......@@ -146,10 +104,6 @@ def fast_numerical_sol():
v += dvdt*dt
return y, v, dvdt
# In[11]:
start_2 = time.time()
fast_numerical_sol()
end_2 = time.time()
......@@ -162,10 +116,6 @@ end_3 = time.time()
time_3 = end_3 - start_3
print("Seconds elapsed (after compilation) = %s" % (time_3))
# In[12]:
x = 'Without Using Numba', 'Using Numba (after compilation)'
times = time_1, time_3
y = times
......@@ -173,4 +123,3 @@ plt.bar(x, y, width = 0.5)
plt.ylabel('Seconds Elapsed')
plt.title('Timing Results- 1D Wave')
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment