aboutsummaryrefslogtreecommitdiffhomepage
path: root/libslang/src/test/arrmult.sl
diff options
context:
space:
mode:
Diffstat (limited to 'libslang/src/test/arrmult.sl')
-rw-r--r--libslang/src/test/arrmult.sl163
1 files changed, 163 insertions, 0 deletions
diff --git a/libslang/src/test/arrmult.sl b/libslang/src/test/arrmult.sl
new file mode 100644
index 0000000..cd0cca8
--- /dev/null
+++ b/libslang/src/test/arrmult.sl
@@ -0,0 +1,163 @@
+_debug_info = 1; () = evalfile ("inc.sl");
+
+print ("Testing Matrix Multiplications ...");
+#ifexists Double_Type
+
+static define dot_prod (a, b)
+{
+ (a # b)[0]; % transpose not needed for 1-d arrays
+}
+
+static define sum (a)
+{
+ variable ones = Double_Type [length (a)] + 1;
+ dot_prod (a, ones);
+}
+
+
+if (1+2+3+4+5 != sum([1,2,3,4,5]))
+ failed ("sum");
+
+#ifexists Complex_Type
+if (1+2i != sum ([1,2i]))
+ failed ("sum complex");
+#endif
+
+define mult (a, b)
+{
+ variable dims_a, dims_b;
+ variable nr_a, nr_b, nc_a, nc_b;
+ variable i, j;
+ variable c;
+
+ (dims_a,,) = array_info (a);
+ (dims_b,,) = array_info (b);
+ nr_a = dims_a[0];
+ nc_a = dims_a[1];
+ nr_b = dims_b[0];
+ nc_b = dims_b[1];
+
+ c = _typeof ([a[0,0]]#[b[0,0]])[nr_a, nc_b];
+
+ for (i = 0; i < nr_a; i++)
+ {
+ for (j = 0; j < nc_b; j++)
+ c[i,j] = dot_prod (a[i,*], b[*,j]);
+ }
+ return c;
+}
+
+static define arr_cmp (a, b)
+{
+ variable i = length (where (b != a));
+ if (i == 0)
+ return 0;
+
+ i = where (b != a);
+ a = a[i];
+ b = b[i];
+ reshape (a, [length(a)]);
+ reshape (b, [length(b)]);
+ vmessage ("%S != %S\n", a[0], b[0]);
+ return 1;
+}
+
+static define test (a, b)
+{
+ if (0 != arr_cmp (mult (a,b), a#b))
+ failed ("%S # %S", a, b);
+}
+
+variable A, B;
+
+#ifexists Complex_Type
+A = [1+2i];
+B = [3+4i];
+reshape (A, [1, 1]);
+reshape (B, [1, 1]);
+test (A,B);
+#endif
+
+% Test intgers
+A = _reshape ([[1, 2, 3], [4, 5, 6]], [2,3]);
+B = _reshape ([[7,8,9],[1,2,4]], [2,3]);
+B = transpose (B);
+
+test (A, B);
+
+B *= 1f;
+test (A, B);
+
+B *= 1.0;
+test (A,B);
+
+A *= 1f;
+test (A,B);
+
+#ifexists Complex_Type
+B += 2i;
+test (A,B);
+
+A += 3i;
+test (A,B);
+
+B = Real(B);
+test (A,B);
+
+% Now try an empty array
+
+if (Complex_Type != _typeof (Complex_Type[0,0,0] # Complex_Type[0]))
+ failed ("[]#[]");
+#endif
+% And finally, do a 3-d array:
+
+A = _reshape ([1:2*3*4], [2,3,4]);
+B = _reshape ([1:4*5*6], [4,5,6]);
+static variable C = A#B;
+
+% C should be a [2,3,5,6] matrix. Let's check via brute force
+
+static define multiply_3d (a, b, c)
+{
+ variable i, j, k, l, m;
+ variable dims_a, dims_b;
+
+ (dims_a,,) = array_info(a);
+ (dims_b,,) = array_info(b);
+
+ _for (0, dims_a[0]-1, 1)
+ {
+ i = ();
+ _for (0, dims_a[1]-1, 1)
+ {
+ j = ();
+ _for (0, dims_b[1]-1, 1)
+ {
+ l = ();
+ _for (0, dims_b[2]-1, 1)
+ {
+ m = ();
+
+ variable sum = 0;
+ _for (0, dims_b[0]-1, 1)
+ {
+ k = ();
+ sum += a[i,j,k] * b[k, l, m];
+ }
+ if (sum != c[i,j,l,m])
+ failed ("multiply_3d");
+ }
+ }
+ }
+ }
+}
+
+multiply_3d (A, B, C);
+
+
+print ("Ok\n");
+#else
+print ("Not available\n");
+#endif
+exit (0);
+